import casperfpga
import numpy as np
import matplotlib.pyplot as plt

"""
  Use `rfdc.set_fine_mixer_freq` to change the NCO frequency for both ADC_TILE and DAC_TILE.
  See the doc string in caserfpga's rfdc.py file for more information for input parameters
"""
def set_dac_nco(fpga, nco_freq):
  """
    Program the same fine-mixer NCO frequency on every DAC tile/block.

    Walks tiles 0..3 and, within each tile, blocks 0..3, calling
    rfdc.set_fine_mixer_freq(tile, block, rfdc.DAC_TILE, nco_freq) for each pair.
  """
  rfdc = fpga.adcs.rfdc
  dac_kind = rfdc.DAC_TILE
  tiles = blocks = range(4)
  for tile in tiles:
    for blk in blocks:
      rfdc.set_fine_mixer_freq(tile, blk, dac_kind, nco_freq)

def set_adc_nco(fpga, nco_freq, tile=3):
  """
    Program the fine-mixer NCO frequency on all four ADC blocks of a single ADC tile.

    fpga     -- casperfpga object exposing `adcs.rfdc`
    nco_freq -- NCO frequency, passed unchanged to rfdc.set_fine_mixer_freq
    tile     -- ADC tile index to program; defaults to 3, the tile this helper has always used
  """
  rfdc = fpga.adcs.rfdc
  for blk in range(4):
    rfdc.set_fine_mixer_freq(tile, blk, rfdc.ADC_TILE, nco_freq)

""" The following helper functions are only to facilitate the visualization in nco.py """

def set_dac_src(fpga, mode=0):
  """
    Select the DAC waveform source by writing `mode` into each source-select register.

    mode: 0 = CW, 1 = pulsed complex constant, 2 = OFF
  """
  for sel_name in ('s00_src_sel', 's02_src_sel', 's10_src_sel', 's12_src_sel'):
    sel = fpga.registers[sel_name]
    sel.write(reg=mode)

def arm_snapshots(fpga):
  """ Call arm() on every snapshot block exposed by `fpga.snapshots`. """
  snapshot_blocks = fpga.snapshots
  for snap in snapshot_blocks:
    snap.arm()

def trigger_snapshot(fpga):
  """
    Arm all snapshot blocks, then pulse the `snapshot_ctrl` register high and back low.
  """
  arm_snapshots(fpga)
  ctrl = fpga.registers.snapshot_ctrl
  for level in (1, 0):
    ctrl.write(reg=level)

def extract_snapshot(fpga):
  """
    Read the four calibration snapshot blocks and return their contents as complex samples.

    Each snapshot word in field 'd' is serialized as a `wordsize_bytes` (16-byte) little-endian
    integer. The resulting byte stream is decoded as interleaved little-endian int16 values
    (re, im, re, im, ...), so every 16-byte word carries four complex samples
    (e.g. a 4096-word-deep snapshot yields 16384 samples).

    Returns a complex128 array of shape (4, num_samples), one row per snapshot in the order
    m20_ss, m22_ss, m30_ss, m32_ss. The snapshots are read with arm=False; arming is
    expected to have been done beforehand (trigger_snapshot does this).
  """
  calibration_snapshots = ['m20_ss', 'm22_ss', 'm30_ss', 'm32_ss']

  data_field_name = 'd'
  wordsize_bytes = 16

  rows = []
  for css in calibration_snapshots:
    ss = fpga.snapshots[css]

    # fetch the captured words and serialize them back into a contiguous byte stream
    raw = ss.read(arm=False)['data']
    blob = b"".join(w.to_bytes(wordsize_bytes, 'little', signed=False) for w in raw[data_field_name])

    # the blob is little-endian by construction, so decode with an explicit '<i2' dtype
    # instead of relying on the host's native byte order
    iq = np.frombuffer(blob, dtype='<i2').astype(np.float64)

    # each consecutive (re, im) float64 pair reinterprets as one complex128 sample
    rows.append(iq.view(np.complex128))

  # sample count is taken from the data rather than hard-coded; rows of unequal length raise ValueError
  return np.vstack(rows)

def get_adc_samples(fpga):
  """ Trigger a fresh snapshot capture and return the calibration snapshots as complex samples. """
  trigger_snapshot(fpga)
  samples = extract_snapshot(fpga)
  return samples

def compute_PSD(x, nfft=512):
  """
    Estimate the power spectral density along the last axis of `x` by frame averaging.

    The last axis is split into non-overlapping frames of `nfft` samples. Trailing samples
    that do not fill a whole frame are discarded; previously they caused a reshape error.
    Each frame is FFT'd and |X|^2 is averaged over frames. No window or scaling is applied.

    x    -- array whose last axis holds time samples, e.g. the (num_snapshots, num_samples)
            array from extract_snapshot; 1-D input is also accepted
    nfft -- FFT length / frame size
    Returns an array of shape x.shape[:-1] + (nfft,) in unshifted FFT bin order (DC first).
    Raises ValueError if the last axis has fewer than `nfft` samples.
  """
  x = np.asarray(x)
  nframes = x.shape[-1] // nfft
  if nframes == 0:
    raise ValueError("compute_PSD needs at least nfft={} samples, got {}".format(nfft, x.shape[-1]))
  frames = x[..., :nframes*nfft].reshape(*x.shape[:-1], nframes, nfft)
  Xfft = np.fft.fft(frames, nfft, axis=-1)
  # average the per-frame power over the frame axis
  return np.mean(np.abs(Xfft)**2, axis=-2)

def plot_PSD(x, nco_freq=0, nfft=512, fadc=2000, D=2):
  """
    Plot the frame-averaged PSD of each row of `x` against a shared frequency axis.

    x        -- samples to analyse (see compute_PSD), e.g. the output of get_adc_samples
    nco_freq -- offset subtracted from the frequency axis (same units as the axis, MHz)
    nfft     -- FFT length; sets the frequency resolution fs/nfft
    fadc     -- ADC sample rate (MHz, matching the axis label); default is the value
                previously hard-coded here
    D        -- decimation factor; the plotted sample rate is fs = fadc/D
  """
  # pass nfft through: the PSD length must match the frequency axis built from nfft below
  Xpsd = compute_PSD(x, nfft)
  fs = fadc/D
  df = fs/nfft
  bins = np.arange(-nfft//2, nfft//2)
  faxis = df*bins - nco_freq
  # shift only the frequency axis; an axis-less fftshift would also roll the row (snapshot) order
  Xshift = np.fft.fftshift(Xpsd, axes=-1)
  plt.plot(faxis, 10*np.log10(Xshift.T))
  plt.xlabel('Frequency (MHz)')
  plt.ylabel('Power (arb. dB)')
  plt.grid()
  plt.show()

