import functools
import os
import re
import sys
import time

import numpy as np
import psutil
import sympy
import sympy.physics.quantum

import CompactMatrix as cm

# Short aliases for the SymPy names used throughout the simulator.
pi, I = sympy.pi, sympy.I
exp, sqrt = sympy.exp, sympy.sqrt
# NOTE: this rebinding shadows the builtin abs() for the rest of the module.
abs = sympy.Abs
real, imag = sympy.re, sympy.im

# Matrix backend: CompactMatrix. To fall back to dense SymPy matrices, bind
# these to sympy.eye / sympy.zeros / sympy.Matrix instead (kron was
# sympy.physics.quantum.TensorProduct in that setup).
eye, zeros, Matrix = cm.eye, cm.zeros, cm.Matrix

# Wall-clock reference points: start_time stamps every interpreted instruction,
# last_time is the previous checkpoint reported by the TIME command.
start_time = last_time = time.time()

def process_memory():
  """Return the resident set size (RSS) of the running process, as reported by psutil."""
  current = psutil.Process(os.getpid())
  return current.memory_info().rss

def profile(func):
  """Decorate `func` so each call prints how much process RSS it added.

  The wrapped function's arguments and return value pass through unchanged;
  functools.wraps keeps its name and docstring.
  """
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    mem_before = process_memory()
    result = func(*args, **kwargs)
    mem_after = process_memory()
    # Report the delta first; the absolute readings are for context.
    print("{}: consumed memory: {:,} (before {:,}, after {:,})".format(
      func.__name__, mem_after - mem_before, mem_before, mem_after))
    return result
  return wrapper

class QuantumRegister:
  """State-vector simulator for a register of `num_qubits` qubits.

  The state is a (2**num_qubits x 1) CompactMatrix column vector. Qubit q is
  bit q of the basis-state index (qubit 0 is the least-significant bit): this
  is the convention of apply_gate and measure_qubit. measure() and
  list_states() render basis states as bit strings, highest-numbered qubit
  first.

  NOTE(review): amplitude k is addressed as `self.state.state[0][k]`, which
  suggests CompactMatrix keeps one dict-like mapping (row -> value) per
  column. TODO confirm against CompactMatrix.
  """

  def __init__(self, num_qubits):
    self.num_qubits = num_qubits
    self.state = zeros(2**num_qubits, 1)
    self.state.state[0][0] = 1  # start in the all-zeros basis state |0...0>

  def _check_qubit(self, qubit):
    """Raise ValueError unless `qubit` is a valid index into this register."""
    if not 0 <= qubit < self.num_qubits:
      raise ValueError(
        f"qubit {qubit} out of range for a {self.num_qubits}-qubit register")

  def apply_gate(self, gate, target_qubits=None, control_qubits=None):
    """Apply the single-qubit `gate` to every target qubit, optionally controlled.

    Args:
      gate: 2x2 CompactMatrix applied to each qubit in `target_qubits`.
      target_qubits: int or iterable of ints (required).
      control_qubits: int, iterable of ints, or None. When given, the gate only
        acts on basis states in which every control qubit is 1.

    Raises:
      ValueError: if no target is given, a qubit index is out of range, or a
        qubit is used both as a control and as a target.
    """
    if target_qubits is None:
      raise ValueError("Target qubit must be specified for uncontrolled gate.")
    if isinstance(target_qubits, int):
      target_qubits = [target_qubits]
    if isinstance(control_qubits, int):
      control_qubits = [control_qubits]  # convert before the truthiness test: qubit 0 is falsy
    control_qubits = list(control_qubits) if control_qubits else []
    targets = set(target_qubits)
    for q in targets.union(control_qubits):
      self._check_qubit(q)  # negative indices would otherwise wrap silently
    if targets.intersection(control_qubits):
      raise ValueError("a qubit cannot be both a control and a target")

    # Full-register operator: `gate` on target qubits, identity elsewhere.
    # Each factor is kron'ed on the left, so (assuming a.kron(b) is a (x) b)
    # qubit q lands on bit q of the basis index.
    gate_matrix = eye(1)
    for q in range(self.num_qubits):
      factor = gate if q in targets else eye(2)
      gate_matrix = factor.kron(gate_matrix)

    if control_qubits:
      dim = 2**self.num_qubits
      controlled_gate = zeros(dim, dim)
      for i in range(dim):
        if all((i >> q) & 1 for q in control_qubits):
          # All controls set: take the uncontrolled operator's entries.
          for j in gate_matrix.state[i]:
            controlled_gate.state[i][j] = gate_matrix.state[i][j]
        else:
          # Some control is 0: leave this basis state untouched.
          controlled_gate.state[i][i] = 1
      gate_matrix = controlled_gate

    self.state = gate_matrix.dot(self.state)

  def measure(self):
    """Measure all qubits, collapse to the observed basis state and return it.

    Returns the outcome as a bit string of length num_qubits, highest-numbered
    qubit first (e.g. '01' means qubit 0 read 1).
    """
    probabilities = np.array(np.abs(self.state.get_matrix()) ** 2,
                             dtype=np.float64).flatten()
    # Absorb floating-point drift; np.random.choice rejects p not summing to 1.
    probabilities /= probabilities.sum()
    outcome = np.random.choice(range(len(probabilities)), p=probabilities)
    self.state = zeros(*self.state.shape)
    self.state.state[0][outcome] = 1
    return bin(outcome)[2:].zfill(self.num_qubits)

  def measure_qubit(self, target_qubit):
    """Measure one qubit, collapse the state accordingly and return 0 or 1.

    Amplitudes inconsistent with the outcome are zeroed, and the survivors are
    rescaled by 1/sqrt(P(outcome)) so the state stays normalised.

    Raises:
      ValueError: if `target_qubit` is out of range.
    """
    self._check_qubit(target_qubit)
    num_states = 2 ** self.num_qubits
    probabilities = np.array(np.abs(self.state.get_matrix()) ** 2,
                             dtype=np.float64).flatten()

    # Basis-state indices in which the target qubit reads 0 / 1.
    states_0 = [i for i in range(num_states) if (i >> target_qubit) % 2 == 0]
    states_1 = [i for i in range(num_states) if (i >> target_qubit) % 2 == 1]

    prob_0 = float(probabilities[states_0].sum())
    prob_1 = float(probabilities[states_1].sum())
    total = prob_0 + prob_1  # ~1; dividing by it absorbs round-off
    outcome = np.random.choice([0, 1], p=[prob_0 / total, prob_1 / total])

    if outcome == 0:
      discarded, kept_probability = states_1, prob_0
    else:
      discarded, kept_probability = states_0, prob_1
    column = self.state.state[0]
    for index in discarded:
      column[index] = 0
    # Renormalise the surviving amplitudes. Assumes `column` is dict-like
    # (iterating yields row indices), as apply_gate also relies on.
    normalization_factor = kept_probability ** 0.5
    for index in list(column):
      column[index] = column[index] / normalization_factor
    return outcome

  def set_qubit(self, target_qubit, state):
    """Validate a requested single-qubit state; setting it is not implemented yet.

    Accepted labels: '0', '1', '+', '-', 'i', '-i' (case-insensitive, since the
    interactive console upper-cases its input).

    Raises:
      ValueError: if `state` is not one of the accepted labels.
    """
    if state.lower() not in ('0', '1', '+', '-', 'i', '-i'):
      raise ValueError(f"state {state} not recognized")
    print(f"set qubit {target_qubit} to state |{state}> (not implemented)")

  def list_states(self, percentages = True):
    """Map each basis state with a non-negligible amplitude to its weight.

    Keys are bit strings (highest-numbered qubit first). Values are
    |amplitude|**2 when `percentages` is true (a probability in [0, 1], not
    scaled to 100), otherwise the raw amplitude. Amplitudes whose real and
    imaginary parts are both within 1e-10 of zero are treated as round-off
    and omitted.
    """
    num_states = len(self.state)
    states_with_probabilities = {}
    for i in range(num_states):
      binary_state = bin(i)[2:].zfill(self.num_qubits)
      amplitude = self.state[i]
      if abs(real(amplitude)) > 1e-10 or abs(imag(amplitude)) > 1e-10:
        if percentages:
          states_with_probabilities[binary_state] = abs(amplitude)**2
        else:
          states_with_probabilities[binary_state] = amplitude
    return states_with_probabilities

  def probabilities(self, percentages = True):
    """Return a (num_qubits x 2) matrix of per-qubit marginals.

    Row j corresponds to the j-th character of the bit string, i.e. qubit
    num_qubits-1-j (highest-numbered qubit first). Column 0/1 accumulates
    |amplitude|**2 (or the raw amplitudes when `percentages` is false) over
    the basis states in which that qubit reads 0/1.
    """
    num_states = len(self.state)
    probabilities = zeros(self.num_qubits, 2)
    for i in range(num_states):
      binary_state = bin(i)[2:].zfill(self.num_qubits)
      value = self.state[i]
      if percentages:
        value = abs(value)**2
      for j, bit in enumerate(binary_state):
        probabilities[j, int(bit)] += value
    return probabilities

def print_gate(name, gate):
  """Print `gate` row by row, highlighting non-negligible entries in red.

  Each entry is rendered as its real and imaginary parts with signs and no
  decimals (e.g. "+1+0"). An entry is highlighted when either part exceeds
  0.01 in magnitude.
  """
  print(name, ':')
  red, reset = '\033[31m', '\033[0m'
  for row in gate:
    print('[', end='')
    for position, entry in enumerate(row):
      re_part, im_part = real(entry), imag(entry)
      highlight = abs(re_part) > 0.01 or abs(im_part) > 0.01
      text = "{0:+2.0f}{1:+2.0f}".format(re_part, im_part)
      separator = ', ' if position else ''
      print(separator + (red + text + reset if highlight else text), end='')
    print(']')

# ---------------------------------------------------------------------------
# Elementary gate matrices, built with the CompactMatrix constructor.
# ---------------------------------------------------------------------------

# Pauli-X (NOT / bit flip): swaps |0> and |1>.
X = Matrix([[0, 1],
            [1, 0]])

# Pauli-Y: bit flip combined with a phase of +/-i.
Y = Matrix([[0, -I],
            [I, 0]])

# Pauli-Z (phase flip): negates the |1> amplitude.
Z = Matrix([[1, 0],
            [0, -1]])

# Hadamard: maps |0> and |1> to equal-weight superpositions.
H = Matrix([[sqrt(2) / 2, sqrt(2) / 2],
            [sqrt(2) / 2, -sqrt(2) / 2]])

# Phase gate S (also written P): multiplies the |1> amplitude by i.
S = Matrix([[1, 0],
            [0, I]])

# S-dagger: multiplies the |1> amplitude by -i.
SD = Matrix([[1, 0],
             [0, -I]])

# pi/8 gate T: multiplies the |1> amplitude by exp(i*pi/4).
T = Matrix([[1, 0],
            [0, exp(I * pi / 4)]])

# T-dagger: multiplies the |1> amplitude by exp(-i*pi/4).
TD = Matrix([[1, 0],
             [0, exp(-I * pi / 4)]])

# Two-qubit SWAP: exchanges the states of the two qubits.
SWAP = Matrix([[1, 0, 0, 0],
               [0, 0, 1, 0],
               [0, 1, 0, 0],
               [0, 0, 0, 1]])

# Single-qubit identity.
EYE = Matrix([[1, 0],
              [0, 1]])

# Gate table built from the matrices defined above, keyed by instruction
# mnemonic. It is kept only as a reference / alternative backend: the
# interpreter looks gates up in GATES, which is bound to CompactMatrix's own
# table below. (Previously this dict was assigned to GATES and immediately
# overwritten.)
LOCAL_GATES = {
  'X': X,
  'Y': Y,
  'Z': Z,
  'H': H,
  'S': S,
  'SD': SD,
  'T': T,
  'TD': TD,
  'EYE': EYE,
  # 'SWAP' is left out: apply_gate tensors the gate once per target qubit with
  # 2x2 identities, so a 4x4 matrix would not fit that construction.
}

GATES = cm.GATES

def quantum_parse(quantum_code):
  """Split quantum source text into a list of tokenised instructions.

  Blank lines and lines whose first non-blank character is '#' are dropped.
  Each remaining line is split on runs of whitespace, so repeated spaces or
  tabs between tokens do not yield empty tokens (which would break int()
  parsing and shift format arguments downstream).

  Returns a list of token lists, e.g. [['QREG', '2'], ['H', '0']].
  """
  parsed_instructions = []
  for line in quantum_code.strip().split("\n"):
    line = line.strip()
    if line and not line.startswith('#'):
      parsed_instructions.append(line.split())
  return parsed_instructions

def _run_q_file(filename, args, qreg):
  """Run the quantum script in `filename` on `qreg` and return the resulting register.

  The file text is used as a str.format template: `args` fill its `{}`
  placeholders before interpretation. Like every nested script, its first
  instruction is skipped by quantum_interpreter, so a `QREG n` header is
  expected.
  """
  with open(filename) as file:
    code = file.read().format(*args)
  return quantum_interpreter(code, qreg)


def quantum_interpreter(quantum_code, qreg=None):
  """Execute `quantum_code` and return the quantum register it ends with.

  The first instruction is a header. When `qreg` is None, its second token
  gives the qubit count of a fresh register. When `qreg` is passed in, the
  header is skipped unread and `qreg` is used as-is. Nested runs (Q,
  script-name commands, CONSOLE) rely on this.

  The first token of each instruction is upper-cased, then dispatched:
    [C..|n]GATE[m] q...  apply a gate from GATES; leading C's (or the number n)
                         give the control count, m the target count (default
                         1); control qubits are listed before targets
    QREG n               replace the register with a fresh n-qubit one
    SET q s              QuantumRegister.set_qubit
    Q name args...       run script `name` (or `name.q`) with format args
    STATE                print the state vector
    MEASURE / MEASURE_Q q  measure all qubits / one qubit
    TIME                 print seconds since the previous TIME (or start)
    PRINT text...        echo the text
    CONSOLE              interactive prompt until BREAK, EXIT or end of input
    name args...         run `name.q` if such a file exists
  """
  global last_time
  parsed_instructions = quantum_parse(quantum_code)
  if qreg is None:
    num_qubits = int(parsed_instructions[0][1])
    qreg = QuantumRegister(num_qubits)
    print(f"initialized qreg with {num_qubits} qubits")
  for instruction in parsed_instructions[1:]:
    instruction[0] = instruction[0].upper()
    command = instruction[0]
    print(f"{time.time() - start_time:8.3f}s:", *instruction)
    # Gate syntax, e.g. CX, CCX, 2X (two controls), H2 (two targets).
    cmd = re.findall(r'(C*)(\d*)([A-Za-z_]+)(\d*)', command)
    cmd = cmd[0] if cmd else ('', '', '', '')
    if cmd[2] in GATES.keys():
      controls, control_count, gate_name, target_count = cmd
      num_control_qubits = int(control_count) if control_count else len(controls)
      num_target_qubits = int(target_count) if target_count else 1
      # Explicit indexing (not slicing) so missing operands raise IndexError.
      control_qubits = [int(instruction[1 + q]) for q in range(num_control_qubits)]
      target_qubits = [int(instruction[1 + num_control_qubits + q])
                       for q in range(num_target_qubits)]
      qreg.apply_gate(GATES[gate_name], target_qubits, control_qubits)
    elif command == 'QREG':
      num_qubits = int(instruction[1])
      qreg = QuantumRegister(num_qubits)
      print(f"initialized qreg with {num_qubits} qubits")
    elif command == 'SET':
      qreg.set_qubit(int(instruction[1]), instruction[2])
    elif command == 'Q':
      filename = instruction[1].lower()
      if not os.path.isfile(filename):
        filename += ".q"
      if os.path.isfile(filename):
        qreg = _run_q_file(filename, instruction[2:], qreg)
      else:
        print("No file", instruction[1])
    elif command == 'STATE':
      print(qreg.state)
    elif command == 'MEASURE':
      result = qreg.measure()
      print("measured state: |{}>".format(result))
    elif command == 'MEASURE_Q':
      result = qreg.measure_qubit(int(instruction[1]))
      print("measured state: |{}>".format(result))
    elif command == 'TIME':
      this_time = time.time()
      print("time: ", this_time - last_time)
      last_time = this_time
    elif command == 'PRINT':
      print(*instruction[1:])
    elif command == 'CONSOLE':
      while True:
        print(f'{time.time() - start_time:8.3f}s: Q{qreg.num_qubits}> ', end='')
        try:
          line = input().upper()
        except EOFError:
          # End of input (e.g. Ctrl-D or a drained pipe) leaves like BREAK.
          print()
          break
        if line.startswith('BREAK'):
          break
        elif line.startswith('EXIT'):
          sys.exit()
        try:
          # Dummy header: nested runs skip their first instruction.
          qreg = quantum_interpreter(f'QREG 8\n{line}', qreg)
        except (ValueError, IndexError, KeyError) as err:
          # A typo must not kill the session and lose the register state.
          print(f"error: {err}")
    elif os.path.isfile(command.lower() + ".q"):
      qreg = _run_q_file(command.lower() + ".q", instruction[1:], qreg)
    else:
      print("No command", command)
  return qreg

# Default program: a 4-qubit register, an interactive console, then a final
# measurement. A script path on the command line replaces it.
quantum_code = "QREG 4\nCONSOLE\nMEASURE"
if len(sys.argv) > 1:
  filename = sys.argv[1]
  # Like the interpreter's Q command, accept the name with or without ".q".
  if not os.path.isfile(filename) and os.path.isfile(filename + ".q"):
    filename += ".q"
  if os.path.isfile(filename):
    with open(filename) as file:
      quantum_code = file.read()
  else:
    # Previously a mistyped path silently dropped into the default console.
    print(f"file {sys.argv[1]} not found, running the default console",
          file=sys.stderr)

result = quantum_interpreter(quantum_code)
