package ru.autosome.di.ytilib;

import ru.autosome.assist.AMatrix;

import java.util.List;

public class WPCM extends AMatrix {

  public WPCM(int length, Sequence seq, int pos) {
    N = 1.0;
    matrix = new double[Din.dins][length];
    for (int i = 0; i < length; i++) {
      matrix[seq.direct()[pos+i]][i] = N;
    }
    this.length = length;
  }

  public WPCM(int length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences, List<Integer> sequenceOrder, int sequenceCount) {

    this.length = length;

    double totalWeight = 0;
    for (int i = 0; i < sequenceCount; i++) {
      totalWeight += sequences[sequenceOrder.get(i)].weight;
    }
    this.N = totalWeight;

    matrix = new double[Din.dins][];
    for (int i = 0; i < Din.dins; i++) {
      matrix[i] = new double[length];
    }

    for (int i = 0; i < sequenceCount; i++) {

      int k = sequenceOrder.get(i);

      Sequence sequence = sequences[k];
      double hitWeight = sequence.weight / (prihits[k].size() + revhits[k].size());

      for (Integer prihit : prihits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.direct()[prihit + j]][j] += hitWeight;
        }
      }
      for (Integer revhit : revhits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.revcomp()[revhit + j]][j] += hitWeight;
        }
      }
    }

    processNs();
  }

  public WPCM(double n, double[][] matrix) {
    this.length = matrix[0].length;
    N = n;
    this.matrix = matrix;
  }

 public WPCM(double n, double[][] matrix, boolean processIUPAC) {
    this(n, matrix);
    if (processIUPAC) processNs();
 }

  public WPCM(WPCM base) {
    matrix = new double[base.matrix.length][];
    for (int i = 0; i < matrix.length; i++) {
      matrix[i] = base.matrix[i].clone();
    }
    this.length = base.length;
    this.N = base.N;
  }

  public static void iupacomprobs(double[] probs) {

    // N
    probs[(byte)Din.NN.ordinal()] = 0.0625;

    fillProb(Din.AN, probs, Din.AA, Din.AC, Din.AG, Din.AT);
    fillProb(Din.CN, probs, Din.CA, Din.CC, Din.CG, Din.CT);
    fillProb(Din.GN, probs, Din.GA, Din.GC, Din.GG, Din.GT);
    fillProb(Din.TN, probs, Din.TA, Din.TC, Din.TG, Din.TT);

    fillProb(Din.NA, probs, Din.AA, Din.CA, Din.GA, Din.TA);
    fillProb(Din.NC, probs, Din.AC, Din.CC, Din.GC, Din.TC);
    fillProb(Din.NG, probs, Din.AG, Din.CG, Din.GG, Din.TG);
    fillProb(Din.NT, probs, Din.AT, Din.CT, Din.GT, Din.TT);

  }

  private static void fillProb(Din what, double[] probs, Din w1, Din w2, Din w3, Din w4) {
    probs[(byte)what.ordinal()] = ( probs[(byte)w1.ordinal()] + probs[(byte)w2.ordinal()] + probs[(byte)w3.ordinal()] + probs[(byte)w4.ordinal()]) / 4.0;
  }

  public WPCM(Integer length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences, boolean discardWeights) {
    this.length = length;
    N = sequences.length;

    matrix = new double[Din.dins][];
    for (int i = 0; i < Din.dins; i++) {
      matrix[i] = new double[length];
    }

    for (int k = 0; k < sequences.length; k++) {

      Sequence sequence = sequences[k];
      double total_weight = ( discardWeights ? 1.0 : sequence.weight ) / (prihits[k].size() + revhits[k].size());

      for (Integer prihit : prihits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.direct()[prihit + j]][j] += total_weight;
        }
      }
      for (Integer revhit : revhits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.revcomp()[revhit + j]][j] += total_weight;
        }
      }
    }

    processNs();
  }

  public WPCM(int length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences) {
    this.length = length;
    N = sequences.length;

    matrix = new double[Din.dins][];
    for (int i = 0; i < Din.dins; i++) {
      matrix[i] = new double[length];
    }

    for (int k = 0; k < sequences.length; k++) {

      Sequence sequence = sequences[k];
      double total_weight = sequence.weight / (prihits[k].size() + revhits[k].size());

      for (Integer prihit : prihits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.direct()[prihit + j]][j] += total_weight;
        }
      }
      for (Integer revhit : revhits[k]) {
        for (int j = 0; j < length; j++) {
          matrix[sequence.revcomp()[revhit + j]][j] += total_weight;
        }
      }
    }

    processNs();
  }

  public WPCM(byte[][] alignment) {
    N = alignment.length;
    this.length = alignment[0].length;
    matrix = new double[Din.dins][];
    for (int i = 0; i < Din.dins; i++) {
      matrix[i] = new double[length];
    }
    // j - pos, i - letter
    for (byte[] word : alignment) {
      for (int j = 0; j < length; j++) {
        matrix[word[j]][j] += 1;
      }
    }

    processNs();
  }

  public WPCM toPWM(double[] background) {
    double pseudocount = N == 1 ? Math.log(2) : Math.log(N);

    for (int j = 0; j < length; j++)
      for (int i = 0; i < Din.dins; i++)
        matrix[i][j] = Math.log((matrix[i][j] + (background[i] * pseudocount)) / ((N + pseudocount) * background[i]));
        //matrix[i][j] = Math.log((matrix[i][j] + (uniformBackground[i] * pseudocount)) / ((N + pseudocount) * uniformBackground[i]));    
    return this;
  }

  public void rebuild(List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences) {

    for (int i = 0; i < Din.dins; i++) {
      java.util.Arrays.fill(matrix[i], 0);
    }

    for (int k = 0; k < sequences.length; k++) {
      Sequence sequence = sequences[k];
      sequence.rebuild(this, prihits[k], revhits[k]);
    }

    processNs();
  }

  private void processNs() {

   /* System.out.println("==");
    for (int j = 0; j < length; j++) {
      double sum = 0;
      for (byte b = (byte)Din.AA.ordinal(); b <= (byte)Din.NN.ordinal(); b++)
        sum += matrix[b][j];
      System.out.println(sum);
    }*/

    // remake letter matrix counting N-s
    for (int j = 0; j < length; j++) {

      // N-s
      double NA = matrix[(byte)Din.NA.ordinal()][j] / 4.0;
      matrix[(byte)Din.AA.ordinal()][j] += NA;
      matrix[(byte)Din.CA.ordinal()][j] += NA;
      matrix[(byte)Din.GA.ordinal()][j] += NA;
      matrix[(byte)Din.TA.ordinal()][j] += NA;

      double NC = matrix[(byte)Din.NC.ordinal()][j] / 4.0;
      matrix[(byte)Din.AC.ordinal()][j] += NC;
      matrix[(byte)Din.CC.ordinal()][j] += NC;
      matrix[(byte)Din.GC.ordinal()][j] += NC;
      matrix[(byte)Din.TC.ordinal()][j] += NC;

      double NG = matrix[(byte)Din.NG.ordinal()][j] / 4.0;
      matrix[(byte)Din.AG.ordinal()][j] += NG;
      matrix[(byte)Din.CG.ordinal()][j] += NG;
      matrix[(byte)Din.GG.ordinal()][j] += NG;
      matrix[(byte)Din.TG.ordinal()][j] += NG;

      double NT = matrix[(byte)Din.NT.ordinal()][j] / 4.0;
      matrix[(byte)Din.AT.ordinal()][j] += NT;
      matrix[(byte)Din.CT.ordinal()][j] += NT;
      matrix[(byte)Din.GT.ordinal()][j] += NT;
      matrix[(byte)Din.TT.ordinal()][j] += NT;

      // other side

      double AN = matrix[(byte)Din.AN.ordinal()][j] / 4.0;
      matrix[(byte)Din.AA.ordinal()][j] += AN;
      matrix[(byte)Din.AC.ordinal()][j] += AN;
      matrix[(byte)Din.AG.ordinal()][j] += AN;
      matrix[(byte)Din.AT.ordinal()][j] += AN;

      double CN = matrix[(byte)Din.CN.ordinal()][j] / 4.0;
      matrix[(byte)Din.CA.ordinal()][j] += CN;
      matrix[(byte)Din.CC.ordinal()][j] += CN;
      matrix[(byte)Din.CG.ordinal()][j] += CN;
      matrix[(byte)Din.CT.ordinal()][j] += CN;

      double GN = matrix[(byte)Din.GN.ordinal()][j] / 4.0;
      matrix[(byte)Din.GA.ordinal()][j] += GN;
      matrix[(byte)Din.GC.ordinal()][j] += GN;
      matrix[(byte)Din.GG.ordinal()][j] += GN;
      matrix[(byte)Din.GT.ordinal()][j] += GN;

      double TN = matrix[(byte)Din.TN.ordinal()][j] / 4.0;
      matrix[(byte)Din.TA.ordinal()][j] += TN;
      matrix[(byte)Din.TC.ordinal()][j] += TN;
      matrix[(byte)Din.TG.ordinal()][j] += TN;
      matrix[(byte)Din.TT.ordinal()][j] += TN;

      // NNs

      double NN = matrix[(byte)Din.NN.ordinal()][j] / 16.0;
      for (byte b = (byte)Din.AA.ordinal(); b <= (byte)Din.TT.ordinal(); b++)
        matrix[b][j] += NN;

    }

    // remake N-s counting matrix letters
    for (int j = 0; j < length; j++) {

      // N-s
      matrix[(byte)Din.NN.ordinal()][j] = N / 16;

      fillMatrix(j, Din.AN, matrix, Din.AA, Din.AC, Din.AG, Din.AT);
      fillMatrix(j, Din.CN, matrix, Din.CA, Din.CC, Din.CG, Din.CT);
      fillMatrix(j, Din.GN, matrix, Din.GA, Din.GC, Din.GG, Din.GT);
      fillMatrix(j, Din.TN, matrix, Din.TA, Din.TC, Din.TG, Din.TT);

      fillMatrix(j, Din.NA, matrix, Din.AA, Din.CA, Din.GA, Din.TA);
      fillMatrix(j, Din.NC, matrix, Din.AC, Din.CC, Din.GC, Din.TC);
      fillMatrix(j, Din.NG, matrix, Din.AG, Din.CG, Din.GG, Din.TG);
      fillMatrix(j, Din.NT, matrix, Din.AT, Din.CT, Din.GT, Din.TT);

    }

    // debug
    /*System.out.println("--");
    for (int j = 0; j < length; j++) {
      double sum = 0;
      for (byte b = (byte)Din.AA.ordinal(); b <= (byte)Din.TT.ordinal(); b++)
        sum += matrix[b][j];
      System.out.println(sum);
    }*/
    
  }

  private void fillMatrix(int j, Din what, double[][] matrix, Din w1, Din w2, Din w3, Din w4) {
    matrix[(byte)what.ordinal()][j] = ( matrix[(byte)w1.ordinal()][j] + matrix[(byte)w2.ordinal()][j] + matrix[(byte)w3.ordinal()][j] + matrix[(byte)w4.ordinal()][j]) / 4.0;
  }

  public double kdidic(double[] background, int j) {

    double kdidic = logFact(N);
    for (int i = 0; i <= 15; i++) {
      kdidic -= logFact(matrix[i][j]);
      kdidic += matrix[i][j] * Math.log(background[i]);
    }

    kdidic = -kdidic / N;
    return kdidic / KDIDIC_MAX;
  }

  public double kdidic(double[] background) {

    double kdidic = 0.0;
    for (int j = 0; j < length; j++) {

      kdidic += kdidic(background, j);
    }
    return kdidic / length;
  }

  /*public double kdidic(double[] background, int start, int end) {
    double kdidic = 0.0;
    for (int j = start; j <= end; j++) {
      kdidic += kdidic(background, j);
    }
    return kdidic;
  }

  public double thresholdLC0() {
    double r = logFact(N);
    double o = N / 24.0;
    double[] vs = {o, o, o, o, o, o, o, o, 2*o, 2*o, 2*o, 2*o, 2*o, 2*o, 2*o, 2*o};

    for (double v: vs) {
      r -= logFact(v);
      r += v * Math.log(0.0625);
    }
    return -r / N;
  }*/

  public double thresholdLC() {
    double r = logFact(N);
    double o = N / 14.0;
    double[] vs = {o, o, o, o, o, o, o, o, o, o, o, o, o, o, 0.0, 0.0};

    for (double v: vs) {
      r -= logFact(v);
      r += v * Math.log(0.0625);
    }
    double threshold = -r / N;

    return threshold / KDIDIC_MAX;
  }

 

  public WPCM toPWM(double pseudocount, double[] background) {
    for (int j = 0; j < length; j++) {
      for (int i = 0; i < Din.dins; i++)
        matrix[i][j] = Math.log((matrix[i][j] + (background[i] * pseudocount)) / ((N + pseudocount) * background[i]));
    }
    return this;
  }

  public static final double KDIDIC_MAX = -Math.log(0.0625); // 100% conservative column

  @Override
  public AMatrix makePWM(double[] background) {
    return new WPCM(this).toPWM(background);
  }
}
