import functools
import math
import os
import re
import sys
import time

import numpy as np
import psutil
#import sympy

#pi = sympy.pi
#e = sympy.e

# Wall-clock timestamp captured when the module is loaded; the interpreter's
# TIME command prints the number of seconds elapsed since this moment.
start_time: float = time.time()

def process_memory():
  """Return the resident set size (RSS) of the running process, in bytes (via psutil)."""
  current = psutil.Process(os.getpid())
  return current.memory_info().rss

def profile(func):
  """Decorator that reports how much RSS memory a call to ``func`` added.

  After each call it prints ``<name>:consumed memory: <bytes>``, where the
  byte count is the RSS after the call minus the RSS before it (so it can be
  negative if memory was released). The wrapped function's return value is
  passed through unchanged.
  """
  @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
  def wrapper(*args, **kwargs):
    mem_before = process_memory()
    result = func(*args, **kwargs)
    mem_after = process_memory()
    # Report the delta; previously the format string consumed only the
    # first two arguments and printed the absolute "before" value instead.
    print("{}:consumed memory: {:,}".format(func.__name__, mem_after - mem_before))
    return result
  return wrapper

class QuantumRegister:
  """Dense state-vector simulator for a register of ``num_qubits`` qubits.

  ``state`` is a complex128 vector of length ``2**num_qubits``. Qubit ``q``
  corresponds to bit ``q`` of the basis-state index (qubit 0 is the
  least-significant bit). Bit strings returned by ``measure`` and
  ``list_states`` are written most-significant qubit first.
  """

  def __init__(self, num_qubits):
    """Create a register of ``num_qubits`` qubits initialised to |0...0>."""
    self.num_qubits = num_qubits
    self.state = np.zeros(2**num_qubits, dtype=np.complex128)
    self.state[0] = 1.0 # Initialize to |0>

  def apply_gate(self, gate, target_qubit=None, control_qubits=None):
    """Apply a single-qubit ``gate`` to ``target_qubit``, optionally controlled.

    Args:
      gate: 2x2 matrix to apply.
      target_qubit: index of the qubit the gate acts on (0 = least-significant bit).
      control_qubits: ``None``, an int, or an iterable of ints. The gate acts
        only on basis states in which every control qubit is 1. An empty
        iterable means the gate is uncontrolled.

    Raises:
      ValueError: if ``target_qubit`` is missing, a qubit index is outside
        ``range(num_qubits)``, or ``gate`` is not 2x2.
    """
    if isinstance(control_qubits, int):
      control_qubits = [control_qubits]
    controls = list(control_qubits) if control_qubits is not None else []
    if target_qubit is None:
      if controls:
        raise ValueError("Both target and control qubits must be specified for controlled gates.")
      raise ValueError("Target qubit must be specified for uncontrolled gate.")
    for qubit in [target_qubit] + controls:
      if not 0 <= qubit < self.num_qubits:
        raise ValueError("Qubit index {} is out of range for a {}-qubit register.".format(qubit, self.num_qubits))
    if np.shape(gate) != (2, 2):
      raise ValueError("apply_gate expects a 2x2 (single-qubit) gate.")

    dim = 2**self.num_qubits
    # Embed the gate on bit ``target_qubit``: identity on the higher bits
    # (left factor) and on the lower bits (right factor, which varies fastest).
    gate_matrix = np.kron(
      np.eye(2**(self.num_qubits - target_qubit - 1), dtype=np.complex128),
      np.kron(gate, np.eye(2**target_qubit, dtype=np.complex128)))
    if controls:
      # Rows for basis states whose control bits are all 1 come from the
      # embedded gate; all other rows stay identity rows.
      indices = np.arange(dim)
      active = np.ones(dim, dtype=bool)
      for qubit in controls:
        active &= ((indices >> qubit) & 1) == 1
      controlled_gate = np.eye(dim, dtype=np.complex128)
      controlled_gate[active] = gate_matrix[active]
      gate_matrix = controlled_gate
    self.state = gate_matrix @ self.state

  def apple_gate(self, gate, target_qubit=None, control_qubit=None):
    """Experimental, debug-printing variant of ``apply_gate``.

    NOTE(review): not called by the interpreter in this file. For gates wider
    than 2x2 it multiplies a CNOT-style permutation (flip ``target_qubit``
    when ``control_qubit`` is 1) by a Kronecker embedding of ``gate`` that
    ignores ``control_qubit``; this looks unfinished, so verify before use.
    """
    print_gate("gate", gate)
    if gate.shape[0] == 2: #single qubit gate
      if target_qubit is None:
        raise ValueError("Target qubit must be specified for single qubit gate.")
      gate_matrix = np.kron(np.eye(int(2**(self.num_qubits - target_qubit - 1)), dtype=np.complex128), np.kron(gate, np.eye(2**target_qubit, dtype=np.complex128)))
    else: #two+ qubit gate
      if target_qubit is None or control_qubit is None:
        raise ValueError("Both target and control qubits must be specified for two qubit gate.")
      # Apply controlled gate
      gate_matrix = np.eye(2**(self.num_qubits - 2), dtype=np.complex128)
      gate_matrix = np.kron(gate_matrix, gate) if target_qubit == self.num_qubits - 1 else np.kron(gate, gate_matrix)
      print_gate("init gate_matrix", gate_matrix)
      controlled_gate = np.zeros((2**self.num_qubits, 2**self.num_qubits), dtype=np.complex128)
      for i in range(2**(self.num_qubits)):
        if (i // 2**control_qubit) % 2 == 1:
          controlled_gate[i, i ^ (1 << target_qubit)] = 1
        else:
          controlled_gate[i, i] = 1
      gate_matrix = controlled_gate @ gate_matrix
    if self.num_qubits < 5:
      print_gate("gate_matrix", gate_matrix)
    self.state = gate_matrix @ self.state

  def appled_gate(self, gate, target_qubits, control_qubits=None):
    """Experimental multi-target variant of ``apply_gate``; renormalises afterwards.

    NOTE(review): not called by the interpreter in this file and appears
    unfinished. The control-qubit dimension check compares against
    ``2**len(controls) + 2**len(targets)``. The "controlled" matrix maps each
    basis state to the one with its control bits flipped, which is not a
    controlled operation. With more than one target, the repeated Kronecker
    products grow larger than ``2**num_qubits``. Verify before relying on it.
    """
    # Ensure target_qubits and gate dimensions match
    if isinstance(target_qubits, int):
        target_qubits = [target_qubits]

    if control_qubits is None:
      if gate.shape[0] != 2**len(target_qubits):
          raise ValueError("Gate dimensions do not match the number of target qubits.")

    # Ensure control_qubits and gate dimensions match
    if control_qubits is not None:
        if isinstance(control_qubits, int):
            control_qubits = [control_qubits]
        if gate.shape[0] != 2**len(control_qubits) + 2**len(target_qubits):
            raise ValueError("Gate dimensions do not match the number of control qubits.")

    # Apply gate directly to target qubits if no control qubits are specified
    if control_qubits is None:
        # Construct the gate matrix for the target qubits
        gate_matrix = gate
        for target_qubit in reversed(target_qubits):
            gate_matrix = np.kron(gate_matrix, np.eye(2**(target_qubit), dtype=np.complex128))
            gate_matrix = np.kron(np.eye(int(2**(self.num_qubits - target_qubit - 1)), dtype=np.complex128), gate_matrix)

    # Apply gate with control qubits
    else:
        # Ensure control_qubits are sorted in ascending order
        control_qubits = sorted(control_qubits)

        # Construct the controlled gate matrix
        control_mask = sum(1 << qubit for qubit in control_qubits)
        control_states = [(state, state ^ control_mask) for state in range(2**self.num_qubits)]
        controlled_gate_matrix = np.zeros((2**self.num_qubits, 2**self.num_qubits), dtype=np.complex128)
        for control_state, target_state in control_states:
            controlled_gate_matrix[control_state, target_state] = 1

        # Construct the gate matrix for the target qubits
        gate_matrix = gate
        for target_qubit in reversed(target_qubits):
            gate_matrix = np.kron(gate_matrix, np.eye(2**(target_qubit), dtype=np.complex128))
            gate_matrix = np.kron(np.eye(2**(self.num_qubits - target_qubit - 2), dtype=np.complex128), gate_matrix)

        # Apply the controlled gate
        print(gate_matrix.shape, controlled_gate_matrix.shape, self.num_qubits)
        gate_matrix = controlled_gate_matrix.dot(gate_matrix)

    # Apply the gate to the state vector
    self.state = gate_matrix.dot(self.state)

    # Normalize the state vector
    self.state /= np.linalg.norm(self.state)

  def measure(self):
    """Measure every qubit, collapse the state and return the outcome.

    Returns:
      The observed basis state as a bit string of length ``num_qubits``,
      most-significant qubit first (e.g. ``'01'`` means qubit 0 read 1).
    """
    probabilities = np.abs(self.state)**2
    probabilities /= np.sum(probabilities)
    outcome = np.random.choice(range(len(probabilities)), p=probabilities)
    # Collapse onto the observed basis state.
    self.state = np.zeros_like(self.state)
    self.state[outcome] = 1.0
    return bin(outcome)[2:].zfill(self.num_qubits)

  def measure_qubit(self, target_qubit):
    """Measure one qubit, collapse the state accordingly and return 0 or 1.

    Raises:
      ValueError: if ``target_qubit`` is outside ``range(num_qubits)``.
    """
    if not 0 <= target_qubit < self.num_qubits:
      raise ValueError("Qubit index {} is out of range for a {}-qubit register.".format(target_qubit, self.num_qubits))
    probabilities = np.abs(self.state)**2
    # Mask of basis states in which the target qubit is 1.
    is_one = ((np.arange(len(self.state)) >> target_qubit) & 1) == 1
    prob_0 = probabilities[~is_one].sum()
    prob_1 = probabilities[is_one].sum()
    total = prob_0 + prob_1
    # Normalise the sampling weights (as ``measure`` does) so floating-point
    # drift in the state's norm cannot make np.random.choice reject them.
    outcome = np.random.choice([0, 1], p=[prob_0 / total, prob_1 / total])
    if outcome == 0:
      self.state[is_one] = 0
      kept_probability = prob_0
    else:
      self.state[~is_one] = 0
      kept_probability = prob_1
    # Rescale the surviving amplitudes to unit norm.
    self.state /= np.sqrt(kept_probability)
    return outcome

  def list_states(self, percentages = True):
    """Return the basis states with non-negligible amplitude.

    Args:
      percentages: if true, map each state to its probability ``|amp|**2``
        (a fraction in [0, 1], not a percentage); otherwise to the raw
        complex amplitude.

    Returns:
      dict from bit string (most-significant qubit first) to probability or
      amplitude. States whose real and imaginary parts are both within 1e-10
      of zero are omitted.
    """
    states_with_probabilities = {}
    for i, amplitude in enumerate(self.state):
      if abs(amplitude.real) > 1e-10 or abs(amplitude.imag) > 1e-10:
        binary_state = bin(i)[2:].zfill(self.num_qubits)
        states_with_probabilities[binary_state] = abs(amplitude)**2 if percentages else amplitude
    return states_with_probabilities

  def probabilities(self, percentages = True):
    """Return per-qubit marginals as an array of shape ``(num_qubits, 2)``.

    Row ``j`` corresponds to character ``j`` of the bit strings produced by
    ``list_states``/``measure``. So row 0 is the most-significant qubit
    (qubit ``num_qubits - 1``) and the last row is qubit 0. Column 0/1 holds
    the value for that qubit being 0/1.
    NOTE(review): this row order is the reverse of the qubit indices used by
    ``apply_gate``/``measure_qubit``; confirm that is intended for display.

    Args:
      percentages: if true, accumulate probabilities ``|amp|**2`` (fractions,
        not percentages). Otherwise accumulate raw complex amplitudes, which is
        not itself a probability and is presumably meant for debugging.
    """
    num_states = len(self.state)
    if percentages:
      probabilities = np.zeros((self.num_qubits, 2))  # Initialize probabilities array
    else:
      probabilities = np.zeros((self.num_qubits, 2), np.complex128)
    for i in range(num_states):
      binary_state = bin(i)[2:].zfill(self.num_qubits)
      probability = self.state[i]
      if percentages:
        probability = abs(probability)**2
      for j in range(self.num_qubits):
        qubit_state = int(binary_state[j])
        probabilities[j, qubit_state] += probability
    return probabilities

def print_gate(name, gate):
  """Print ``name`` followed by the matrix ``gate``, one bracketed row per line.

  Every entry is shown as its signed real and imaginary parts rounded to whole
  numbers. Entries with a real or imaginary magnitude above 0.01 are wrapped
  in ANSI red colour codes so non-zero structure stands out.
  """
  print(name, ':')
  for row in gate:
    cells = []
    for entry in row:
      text = "{0.real:+2.0f}{0.imag:+2.0f}".format(entry)
      if abs(entry.real) > 0.01 or abs(entry.imag) > 0.01:
        text = '\033[31m' + text + '\033[0m'
      cells.append(text)
    print('[' + ', '.join(cells) + ']')

# Gate matrices. All are complex128 so they combine with the complex128 state
# vector without dtype promotion.

# Identity gate
I = np.array([[1, 0],
              [0, 1]], dtype=np.complex128)

# Pauli-X gate (bit-flip gate)
X = np.array([[0, 1],
              [1, 0]], dtype=np.complex128)

# Pauli-Y gate
Y = np.array([[0, -1j],
              [1j, 0]], dtype=np.complex128)

# Pauli-Z gate (phase-flip gate)
Z = np.array([[1, 0],
              [0, -1]], dtype=np.complex128)

# Hadamard gate
H = np.array([[math.sqrt(0.5), math.sqrt(0.5)],
              [math.sqrt(0.5), -math.sqrt(0.5)]], dtype=np.complex128)

# Phase gate (S, P) and its inverse
S = np.array([[1, 0],
              [0, 1j]], dtype=np.complex128)

SD = np.array([[1, 0],
               [0, -1j]], dtype=np.complex128)

# pi / 8 gate (T) and its inverse
T = np.array([[1, 0],
              [0, np.exp(1j * np.pi / 4)]], dtype=np.complex128)

TD = np.array([[1, 0],
               [0, np.exp(-1j * np.pi / 4)]], dtype=np.complex128)

# Two-qubit SWAP gate (4x4)
SWAP = np.array([[1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1]], dtype=np.complex128)

# Gate names recognised by the interpreter, mapped to their matrices.
# SWAP is left out, presumably because apply_gate embeds single-qubit (2x2) gates.
GATES = {
  'X': X,
  'Y': Y,
  'Z': Z,
  'H': H,
  'S': S,
  'SD': SD,
  'T': T,
  'TD': TD,
#  'SWAP': SWAP,
}

def build_controlled_gate(gate, num_control_gates):
  """Return ``gate`` controlled by ``num_control_gates`` extra qubits.

  The result is an identity matrix of size ``2**(num_control_gates + k)``
  (``k`` = number of qubits ``gate`` acts on) whose bottom-right block is
  replaced by ``gate``. That is, ``gate`` acts only when every control qubit
  (the most-significant bits of the index) is 1.

  Raises:
    ValueError: if ``gate`` is not a square matrix whose size is a power of
      two, or ``num_control_gates`` is negative.
  """
  shape = np.shape(gate)
  if len(shape) != 2 or shape[0] != shape[1] or shape[0] < 1 or shape[0] & (shape[0] - 1):
    raise ValueError("gate must be a square matrix whose size is a power of two.")
  if num_control_gates < 0:
    raise ValueError("num_control_gates must be non-negative.")
  gate_dim = shape[0]
  # Exact integer log2 of the gate size (math.log can be off by rounding).
  num_qubits = num_control_gates + gate_dim.bit_length() - 1
  size = 2 ** num_qubits
  controlled_gate = np.eye(size, dtype=np.complex128)
  # The gate occupies the block in which all control bits are 1.
  offset = size - gate_dim
  controlled_gate[offset:, offset:] = gate
  return controlled_gate

def quantum_parse(quantum_code):
  """Split program text into instructions, each a list of string tokens.

  Blank lines and lines whose first non-blank character is ``#`` are skipped.
  Tokens are split on any run of whitespace (spaces or tabs), so extra spacing
  between arguments does not produce empty tokens.
  """
  parsed_instructions = []
  for line in quantum_code.splitlines():
    instruction = line.strip()
    if instruction and not instruction.startswith('#'):
      parsed_instructions.append(instruction.split())
  return parsed_instructions

#@profile
def quantum_interpreter(quantum_code, qreg=None):
  """Execute a line-oriented quantum program and return the register it acted on.

  The first instruction is treated as a header. When ``qreg`` is ``None``, a
  new ``QuantumRegister`` is created whose size is the header's second token
  (e.g. ``QREG 2``). When a register is passed in, the header line is skipped
  and the program operates on that register.

  Instructions (one per line, whitespace-separated tokens):
    <gate> target                  apply a gate from GATES
    C...<gate> c1 ... target       controlled gate, one control per leading C
    C<n><gate> c1 ... cn target    controlled gate with n controls
    Q file [args...]               run another program file ('.q' appended if needed)
    PEEK, PEEK_Q, PEEK_S           print qubit and/or state probabilities
    STATE                          print the full state vector
    MEASURE / MEASURE_Q q          measure all qubits / qubit q
    TIME                           print seconds elapsed since start_time
    PRINT text...                  echo the remaining tokens
    name [args...]                 run '<name>.q' (lower-cased) if it exists
  Program files are ``str.format`` templates filled with the extra tokens.

  Raises:
    ValueError: if ``qreg`` is None and the first instruction does not carry
      a qubit count.
  """
  gate_pattern = re.compile(r'(C*)(\d*)(\w+)')
  parsed_instructions = quantum_parse(quantum_code)
  if qreg is None:
    if not parsed_instructions or len(parsed_instructions[0]) < 2:
      raise ValueError("Program must start with a register header such as 'QREG 2'.")
    qreg = QuantumRegister(int(parsed_instructions[0][1]))
  # The header is skipped even when a register is supplied, presumably so an
  # included file can carry its own header and also run standalone.
  for instruction in parsed_instructions[1:]:
    op = instruction[0]
    # fullmatch (instead of the first match anywhere in the token) keeps
    # tokens such as "X-1" from being mistaken for the gate "X".
    gate_match = gate_pattern.fullmatch(op)
    if gate_match and gate_match.group(3) in GATES:
      prefix, count, gate_name = gate_match.groups()
      num_control_qubits = int(count) if count else len(prefix)
      control_qubits = [int(q) for q in instruction[1:num_control_qubits + 1]]
      target_qubit = int(instruction[num_control_qubits + 1])
      qreg.apply_gate(GATES[gate_name], target_qubit, control_qubits)
    elif op == 'Q':
      filename = instruction[1].lower()
      if not os.path.isfile(filename):
        filename += ".q"
      if os.path.isfile(filename):
        qreg = _run_program_file(filename, instruction[2:], qreg)
      else:
        print("No file", filename)
    elif op == 'PEEK':
      probabilities = qreg.probabilities()
      states = qreg.list_states()
      print("qubit probabilities:", *probabilities)
      print("qreg probabilities:", states)
    elif op == 'PEEK_Q':
      probabilities = qreg.probabilities(percentages=False)
      probabilities_p = qreg.probabilities()
      print("qubit probabilities:")
      for q, q_p in zip(probabilities, probabilities_p):
        print("[{0.real:+.2f}{0.imag:+.2f}j, {1.real:+.2f}{1.imag:+.2f}j] ({2:3.0f}%, {3:3.0f}%)".format(q[0], q[1], q_p[0] * 100, q_p[1] * 100))
    elif op == 'PEEK_S':
      states = qreg.list_states(percentages=False)
      states_p = qreg.list_states()
      print("qreg probabilities:")
      for s in states:
        print("|{0}>: [{1.real:+.2f}{1.imag:+.2f}] ({2:6.2f}%)".format(s, states[s], states_p[s] * 100))
    elif op == 'STATE':
      _print_state(qreg)
    elif op == 'MEASURE':
      result = qreg.measure()
      print("measured state: |{}>".format(result))
    elif op == 'MEASURE_Q':
      result = qreg.measure_qubit(int(instruction[1]))
      print("measured state: |{}>".format(result))
    elif op == 'TIME':
      print("time: ", time.time() - start_time)
    elif op == 'PRINT':
      print(*instruction[1:])
    elif os.path.isfile(op.lower() + ".q"):
      qreg = _run_program_file(op.lower() + ".q", instruction[1:], qreg)
    else:
      print("No command", op)
  return qreg


def _print_state(qreg):
  """Print a row of basis-state labels, then the register's amplitudes."""
  # Labels use the register's own width (previously a fixed 8 bits) so they
  # match the bit strings printed by MEASURE and PEEK.
  width = qreg.num_qubits
  print(''.join('|{}>'.format(bin(i)[2:].zfill(width)) for i in range(2**width)))
  amplitudes = ', '.join("{0.real:+4.1f}{0.imag:+4.1f}".format(s) for s in qreg.state)
  print('[' + amplitudes + ']')


def _run_program_file(path, args, qreg):
  """Run the program in ``path`` on ``qreg`` after filling its format placeholders with ``args``."""
  with open(path) as file:
    code = file.read().format(*args)
  return quantum_interpreter(code, qreg)

# Run the program file named on the command line. With no argument, or a path
# that is not an existing file, fall back to a built-in two-qubit demo
# (Hadamard on qubit 0, CNOT 0 -> 1, then measure).
filename = sys.argv[1] if len(sys.argv) > 1 else ""
if os.path.isfile(filename):
  with open(filename) as file:
    quantum_code = file.read()
else:
  if filename:
    print("File not found: {}; running built-in demo.".format(filename), file=sys.stderr)
  quantum_code = """
QREG 2
H 0
CX 0 1
MEASURE
"""

result = quantum_interpreter(quantum_code)
