package ru.autosome.ytilib;

import ru.autosome.assist.AMatrix;
import ru.autosome.di.ytilib.Din;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

/**
 * Mononucleotide count matrix with one column per motif position.
 *
 * <p>Rows 0-3 hold the counts of A, C, G and T. Rows 4-14 hold the IUPAC ambiguity
 * codes R, Y, K, M, S, W, B, D, H, V and N (in that order); after IUPAC processing each
 * of them contains the average of the rows of the letters it stands for.
 *
 * <p>{@code matrix}, {@code length} and {@code N} are inherited from {@link AMatrix};
 * {@code N} is the total count (weight) the matrix was built from.
 */
public class WPCM extends AMatrix {

  /** Row index of the first ambiguity code (R); rows below it are A, C, G, T. */
  private static final int FIRST_IUPAC_ROW = 4;

  /** Row index of the last ambiguity code (N). */
  private static final int LAST_IUPAC_ROW = 14;

  /**
   * Letters (as row indices 0=A, 1=C, 2=G, 3=T) covered by each ambiguity code.
   * Entry {@code k} describes matrix row {@code FIRST_IUPAC_ROW + k}.
   */
  private static final int[][] IUPAC_MEMBERS = {
      {0, 2},       //  4: R = A, G
      {1, 3},       //  5: Y = C, T
      {2, 3},       //  6: K = G, T
      {0, 1},       //  7: M = A, C
      {1, 2},       //  8: S = C, G
      {0, 3},       //  9: W = A, T
      {1, 2, 3},    // 10: B = C, G, T
      {0, 2, 3},    // 11: D = A, G, T
      {0, 1, 3},    // 12: H = A, C, T
      {0, 1, 2},    // 13: V = A, C, G
      {0, 1, 2, 3}  // 14: N = A, C, G, T
  };

  /** Normalisation constant for the kdic values: {@code -ln(0.25)} (100% conservative column). */
  public static final double KDIC_MAX = -Math.log(0.25); // 100% conservative column

  /**
   * Builds the matrix from the hits of a subset of the sequences.
   *
   * <p>Only the first {@code sequenceCount} entries of {@code sequenceOrder} are used; each is an
   * index into {@code sequences}, {@code prihits} and {@code revhits}. {@code N} becomes the sum
   * of the weights of the selected sequences, and every selected sequence spreads its weight
   * evenly over all of its hits on both strands.
   *
   * @param length        motif length (number of columns)
   * @param prihits       per-sequence hit offsets on the direct strand
   * @param revhits       per-sequence hit offsets on the reverse-complement strand
   * @param sequences     all sequences
   * @param sequenceOrder indices of the sequences to use, in order
   * @param sequenceCount how many entries of {@code sequenceOrder} to use
   */
  public WPCM(int length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences, List<Integer> sequenceOrder, int sequenceCount) {

    this.length = length;

    double totalWeight = 0;
    for (int i = 0; i < sequenceCount; i++) {
      totalWeight += sequences[sequenceOrder.get(i)].weight;
    }
    this.N = totalWeight;

    matrix = new double[Sequence.IUPACOUNT][length];

    for (int i = 0; i < sequenceCount; i++) {
      int k = sequenceOrder.get(i);
      Sequence sequence = sequences[k];
      double hitWeight = sequence.weight / (prihits[k].size() + revhits[k].size());
      addHitCounts(sequence, prihits[k], revhits[k], hitWeight);
    }

    processIUPAC();
  }

  /**
   * Builds a matrix from a single word: the {@code length} letters of the direct strand of
   * {@code sequence} starting at {@code shift}, each counted with weight 1.
   *
   * @param length   motif length (number of columns)
   * @param sequence sequence the word is taken from
   * @param shift    offset of the first letter of the word on the direct strand
   */
  public WPCM(int length, Sequence sequence, int shift) {
    this.length = length;

    double totalWeight = 1.0;
    matrix = new double[LAST_IUPAC_ROW + 1][length];
    for (int i = 0; i < length; i++) {
      matrix[sequence.direct()[shift + i]][i] = totalWeight;
    }
    N = totalWeight;
    processIUPAC();
  }

  /**
   * Collapses a dinucleotide count matrix into a mononucleotide one.
   *
   * <p>Column {@code i} (for every dinucleotide column) is the marginal over the second letter,
   * i.e. the counts of the first letter. The extra last column is the marginal over the first
   * letter of the last dinucleotide column, i.e. the counts of its second letter.
   *
   * <p>NOTE: be careful, use for WPCMs only, not for PWMs.
   *
   * @param dim dinucleotide count matrix; the result is one column longer
   */
  public WPCM(ru.autosome.di.ytilib.WPCM dim) {
    int diLength = dim.length();
    this.length = diLength + 1;

    matrix = new double[LAST_IUPAC_ROW + 1][length];
    double[][] dimatrix = dim.getMatrix();
    for (int i = 0; i < diLength; i++) {
      matrix[0][i] = sumDinucleotides(dimatrix, i, Din.AA, Din.AC, Din.AG, Din.AT); // A
      matrix[1][i] = sumDinucleotides(dimatrix, i, Din.CA, Din.CC, Din.CG, Din.CT); // C
      matrix[2][i] = sumDinucleotides(dimatrix, i, Din.GA, Din.GC, Din.GG, Din.GT); // G
      matrix[3][i] = sumDinucleotides(dimatrix, i, Din.TA, Din.TC, Din.TG, Din.TT); // T
    }

    // last position: second letters of the last dinucleotide column
    int last = diLength - 1;
    matrix[0][last + 1] = sumDinucleotides(dimatrix, last, Din.AA, Din.CA, Din.GA, Din.TA); // A
    matrix[1][last + 1] = sumDinucleotides(dimatrix, last, Din.AC, Din.CC, Din.GC, Din.TC); // C
    matrix[2][last + 1] = sumDinucleotides(dimatrix, last, Din.AG, Din.CG, Din.GG, Din.TG); // G
    matrix[3][last + 1] = sumDinucleotides(dimatrix, last, Din.AT, Din.CT, Din.GT, Din.TT); // T

    N = dim.getN();
    processIUPAC();
  }

  /**
   * Wraps an existing matrix without copying it and without IUPAC processing.
   *
   * @param n      total count to store as {@code N}
   * @param matrix letter-by-position matrix; stored by reference, its first row defines the length
   */
  public WPCM(double n, double[][] matrix) {
    this.length = matrix[0].length;
    N = n;
    this.matrix = matrix;
  }

  /**
   * Wraps an existing matrix (by reference), taking {@code N} from the sum of the A, C, G and T
   * counts of the first column.
   *
   * @param matrix       letter-by-position matrix
   * @param processIUPAC whether to run IUPAC processing on the matrix afterwards
   */
  public WPCM(double[][] matrix, boolean processIUPAC) {
    this(matrix[0][0] + matrix[1][0] + matrix[2][0] + matrix[3][0], matrix);
    if (processIUPAC) processIUPAC();
  }

  /**
   * Copy constructor: duplicates every row of {@code base} and copies its length and {@code N}.
   *
   * @param base matrix to copy
   */
  public WPCM(WPCM base) {
    matrix = new double[base.matrix.length][];
    for (int i = 0; i < matrix.length; i++) {
      matrix[i] = base.matrix[i].clone();
    }
    this.length = base.length;
    this.N = base.N;
  }

  /**
   * Fills entries 4-14 of {@code probs} from the A, C, G, T probabilities in entries 0-3.
   *
   * <p>Every ambiguity code receives the average probability of the letters it covers, except N
   * (entry 14), which is always set to 0.25.
   *
   * @param probs array of at least 15 entries; entries 0-3 are read, entries 4-14 are overwritten
   */
  public static void iupacomprobs(double[] probs) {
    // N is written first so that a too-short array fails before anything is modified.
    probs[LAST_IUPAC_ROW] = 0.25;

    for (int code = LAST_IUPAC_ROW - 1; code >= FIRST_IUPAC_ROW; code--) {
      int[] members = IUPAC_MEMBERS[code - FIRST_IUPAC_ROW];
      double sum = probs[members[0]];
      for (int m = 1; m < members.length; m++) {
        sum += probs[members[m]];
      }
      probs[code] = sum / members.length;
    }
  }

  /**
   * Builds the matrix from the hits of all sequences, optionally ignoring sequence weights.
   *
   * <p>{@code N} is set to the number of sequences. Each sequence spreads its weight (or 1.0 when
   * {@code discardWeights} is set) evenly over all of its hits on both strands.
   *
   * @param length         motif length (number of columns)
   * @param prihits        per-sequence hit offsets on the direct strand
   * @param revhits        per-sequence hit offsets on the reverse-complement strand
   * @param sequences      all sequences
   * @param discardWeights when {@code true} every sequence counts with weight 1.0
   */
  public WPCM(int length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences, boolean discardWeights) {
    this.length = length;
    N = sequences.length;

    matrix = new double[Sequence.IUPACOUNT][length];

    for (int k = 0; k < sequences.length; k++) {
      Sequence sequence = sequences[k];
      double hitWeight = ( discardWeights ? 1.0 : sequence.weight ) / (prihits[k].size() + revhits[k].size());
      addHitCounts(sequence, prihits[k], revhits[k], hitWeight);
    }

    processIUPAC();
  }

  /**
   * Builds the matrix from the hits of all sequences using their weights.
   *
   * <p>{@code N} is set to the number of sequences. Each sequence spreads its weight evenly over
   * all of its hits on both strands.
   *
   * @param length    motif length (number of columns)
   * @param prihits   per-sequence hit offsets on the direct strand
   * @param revhits   per-sequence hit offsets on the reverse-complement strand
   * @param sequences all sequences
   */
  public WPCM(int length, List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences) {
    this.length = length;
    N = sequences.length;

    matrix = new double[Sequence.IUPACOUNT][length];

    for (int k = 0; k < sequences.length; k++) {
      Sequence sequence = sequences[k];
      double hitWeight = sequence.weight / (prihits[k].size() + revhits[k].size());
      addHitCounts(sequence, prihits[k], revhits[k], hitWeight);
    }

    processIUPAC();
  }

  /**
   * Builds the matrix from an alignment of equally long words, counting every letter once.
   *
   * @param alignment words as arrays of letter codes (row indices of the matrix); the first word
   *                  defines the motif length and {@code N} is the number of words
   */
  public WPCM(byte[][] alignment) {
    N = alignment.length;
    this.length = alignment[0].length;
    matrix = new double[Sequence.IUPACOUNT][length];

    // j - pos, word[j] - letter
    for (byte[] word : alignment) {
      for (int j = 0; j < length; j++) {
        matrix[word[j]][j] += 1;
      }
    }

    processIUPAC();
  }

  /**
   * Wraps an existing matrix (by reference) without IUPAC processing; {@code N} is the sum of
   * the A, C, G and T counts of the first column.
   *
   * @param matrix letter-by-position matrix
   */
  public WPCM(double[][] matrix) {
    this(matrix, false);
  }

  /**
   * Converts this matrix to log-odds weights <em>in place</em>, using {@code ln(N)} as the
   * pseudocount ({@code ln(2)} when {@code N} equals 1).
   *
   * @param background background probabilities, one per matrix row
   * @return this matrix, now holding weights instead of counts
   */
  public WPCM toPWM(double[] background) {
    double pseudocount = N == 1 ? Math.log(2) : Math.log(N);
    convertToLogOdds(pseudocount, background);
    return this;
  }

  /**
   * Recomputes the counts from a new set of hits.
   *
   * <p>All rows are zeroed, every sequence is asked to add its hits via
   * {@code Sequence.rebuild}, and IUPAC processing is applied. {@code N} is not touched here.
   *
   * @param prihits   per-sequence hit offsets on the direct strand
   * @param revhits   per-sequence hit offsets on the reverse-complement strand
   * @param sequences all sequences
   */
  public void rebuild(List<Integer>[] prihits, List<Integer>[] revhits, Sequence[] sequences) {

    for (int i = 0; i < Sequence.IUPACOUNT; i++) {
      Arrays.fill(matrix[i], 0);
    }

    for (int k = 0; k < sequences.length; k++) {
      Sequence sequence = sequences[k];
      sequence.rebuild(this, prihits[k], revhits[k]);
    }

    processIUPAC();
  }

  /**
   * Adds {@code hitWeight} to the matrix for every letter of every hit of one sequence.
   * Direct-strand hits read the direct strand, reverse hits read the reverse complement.
   */
  private void addHitCounts(Sequence sequence, List<Integer> prihits, List<Integer> revhits, double hitWeight) {
    for (Integer prihit : prihits) {
      for (int j = 0; j < length; j++) {
        matrix[sequence.direct()[prihit + j]][j] += hitWeight;
      }
    }
    for (Integer revhit : revhits) {
      for (int j = 0; j < length; j++) {
        matrix[sequence.revcomp()[revhit + j]][j] += hitWeight;
      }
    }
  }

  /** Sums four dinucleotide rows of one column, left to right. */
  private static double sumDinucleotides(double[][] dimatrix, int column, Din first, Din second, Din third, Din fourth) {
    return dimatrix[first.ordinal()][column] + dimatrix[second.ordinal()][column]
        + dimatrix[third.ordinal()][column] + dimatrix[fourth.ordinal()][column];
  }

  /** Replaces every cell by {@code ln((count + bg * pseudocount) / ((N + pseudocount) * bg))}. */
  private void convertToLogOdds(double pseudocount, double[] background) {
    for (int j = 0; j < length; j++) {
      for (int i = 0; i < Sequence.IUPACOUNT; i++) {
        matrix[i][j] = Math.log((matrix[i][j] + (background[i] * pseudocount)) / ((N + pseudocount) * background[i]));
      }
    }
  }

  /**
   * Resolves ambiguity codes column by column.
   *
   * <p>First the count of every ambiguity code is distributed in equal shares over the letters
   * it covers (codes are handled from N down to R, which fixes the order of the floating-point
   * additions). Then every ambiguity row is overwritten with the average of the rows of its
   * letters.
   */
  private void processIUPAC() {

    for (int j = 0; j < length; j++) {

      // remake letter matrix counting IUPAC-s
      for (int code = LAST_IUPAC_ROW; code >= FIRST_IUPAC_ROW; code--) {
        int[] members = IUPAC_MEMBERS[code - FIRST_IUPAC_ROW];
        for (int base : members) {
          matrix[base][j] += matrix[code][j] / members.length;
        }
      }

      // remake IUPAC-s counting matrix letters
      for (int code = LAST_IUPAC_ROW; code >= FIRST_IUPAC_ROW; code--) {
        int[] members = IUPAC_MEMBERS[code - FIRST_IUPAC_ROW];
        double sum = matrix[members[0]][j];
        for (int m = 1; m < members.length; m++) {
          sum += matrix[members[m]][j];
        }
        matrix[code][j] = sum / members.length;
      }
    }
  }

  /**
   * Loads a count matrix from a text file.
   *
   * <p>Every non-empty line is split on single spaces; lines that do not yield exactly four
   * fields are skipped. The four fields are parsed as the A, C, G and T counts of the next
   * column. The ambiguity rows are filled by IUPAC processing and {@code N} is the sum of the
   * first column. A file without any usable line fails when the first column is read.
   *
   * @param arg path of the file to read
   * @return the loaded matrix
   * @throws RuntimeException      if the file cannot be read (the {@link IOException} is the cause)
   * @throws NumberFormatException if a field of a four-field line is not a number
   */
  public static WPCM load(String arg) {
    List<Double> aw = new ArrayList<Double>();
    List<Double> cw = new ArrayList<Double>();
    List<Double> gw = new ArrayList<Double>();
    List<Double> tw = new ArrayList<Double>();

    BufferedReader bin = null;
    try {
      bin = new BufferedReader(new InputStreamReader(new FileInputStream(arg)));
      String s;
      while ((s = bin.readLine()) != null) {
        if (s.length() != 0) {
          String[] tores = s.split(" ");
          if (tores.length != 4) continue;
          aw.add(Double.parseDouble(tores[0]));
          cw.add(Double.parseDouble(tores[1]));
          gw.add(Double.parseDouble(tores[2]));
          tw.add(Double.parseDouble(tores[3]));
        }
      }
    } catch (IOException e) {
      throw new RuntimeException("unable to load matrix from " + arg, e);
    } finally {
      if (bin != null) {
        try {
          bin.close();
        } catch (IOException ignored) {
          // the stream was only read from; a failure to close it cannot lose data
        }
      }
    }

    double[][] resmat = new double[Sequence.IUPACOUNT][aw.size()];

    for (int i = 0; i < aw.size(); i++) {
      resmat[0][i] = aw.get(i);
      resmat[1][i] = cw.get(i);
      resmat[2][i] = gw.get(i);
      resmat[3][i] = tw.get(i);
    }

    return new WPCM(resmat, true);
  }

  /**
   * Computes the kdic value of a single column, normalised by {@link #KDIC_MAX}.
   *
   * <p>The raw value is {@code -(logFact(N) - sum(logFact(count_i)) + sum(count_i * ln(bg_i))) / N}
   * over the four letters A, C, G, T.
   *
   * @param background background probabilities; entries 0-3 are used
   * @param j          column index
   * @return normalised kdic of column {@code j}
   */
  public double kdic(double[] background, int j) {

    double kdic = logFact(N);
    for (int i = 0; i < 4; i++) {
      kdic -= logFact(matrix[i][j]);
      kdic += matrix[i][j] * Math.log(background[i]);
    }

    kdic = -kdic / N;
    return kdic / KDIC_MAX;
  }

  /**
   * Computes the mean normalised kdic over all columns.
   *
   * @param background background probabilities; entries 0-3 are used
   * @return average of {@link #kdic(double[], int)} over the columns
   */
  public double kdic(double[] background) {

    double kdic = 0.0;
    for (int j = 0; j < length; j++) {
      kdic += kdic(background, j);
    }
    return kdic / length;
  }

  /**
   * Normalised kdic of a reference column with counts {@code N/3, N/3, N/6, N/6} under a
   * uniform (0.25) background; used as the lower consensus threshold.
   *
   * @return the low-conservation threshold
   */
  public double thresholdLC() {
    double io = N / 6.0;
    double icdpart = (logFact(io * 2.0) * 2 + logFact(io) * 2 - logFact(N)) / N;
    double kpart = Math.log(0.25);
    return (icdpart - kpart) / KDIC_MAX;
  }

  /**
   * Normalised kdic of a reference column with three letters at {@code N/3} each under a
   * uniform (0.25) background; used as the higher consensus threshold.
   *
   * @return the high-conservation threshold
   */
  public double thresholdHC() {
    double io = N / 3.0;
    double icdpart = (logFact(io) * 3 - logFact(N)) / N;
    double kpart = Math.log(0.25);
    return (icdpart - kpart) / KDIC_MAX;
  }

  /**
   * Normalised kdic of a reference column with two letters at {@code N/2} each under a
   * uniform (0.25) background.
   *
   * @return the two-of-four threshold
   */
  public double kdic2of4() {
    double io = N / 2.0;
    double icdpart = (logFact(io) * 2 - logFact(N)) / N;
    double kpart = Math.log(0.25);
    return (icdpart - kpart) / KDIC_MAX;
  }

  /**
   * Converts this matrix to log-odds weights <em>in place</em> with an explicit pseudocount.
   *
   * @param pseudocount pseudocount distributed according to {@code background}
   * @param background  background probabilities, one per matrix row
   * @return this matrix, now holding weights instead of counts
   */
  public WPCM toPWM(double pseudocount, double[] background) {
    convertToLogOdds(pseudocount, background);
    return this;
  }

  /**
   * Strand-independent consensus: the lexicographically smaller of the consensus and its
   * reverse complement.
   *
   * @param background background probabilities; entries 0-3 are used
   * @return consensus string
   */
  public String consensus2(double[] background) {
    String strand1 = consensus1(background);
    String strand2 = revcomp(strand1);
    return strand1.compareTo(strand2) < 0 ? strand1 : strand2;
  }

  /**
   * Consensus of this matrix under the uniform background.
   *
   * @return consensus string
   */
  public String consensus1() {
    return consensus1(Sequence.uniformBackground);
  }

  /**
   * Builds the consensus string, one symbol per column.
   *
   * <p>The column's kdic decides how many of the most frequent letters are requested (1, 2, 3
   * or 4); symbols built from a request of more than two letters are written in lower case.
   *
   * @param background background probabilities; entries 0-3 are used
   * @return consensus string
   */
  public String consensus1(double[] background) {
    StringBuilder resc = new StringBuilder(length);
    for (int j = 0; j < length; j++) {
      double lk = this.kdic(background, j);
      int letn = 4;
      if (lk >= this.kdic2of4()) {
        // preferably 1, may be 2
        letn = 1;
      } else if (lk >= this.thresholdHC()) {
        // preferably 2, may be 3
        letn = 2;
      } else if (lk >= this.thresholdLC()) {
        // preferably 3, may be N
        letn = 3;
      }
      String toAppend = conslets(this.matrix[0][j], this.matrix[1][j], this.matrix[2][j], this.matrix[3][j], letn);
      if (letn > 2) toAppend = toAppend.toLowerCase(); // CHECK if this happens actually
      resc.append(toAppend);
    }
    return resc.toString();
  }

  /**
   * Picks the consensus symbol for one column.
   *
   * <p>The {@code letn} most frequent letters are taken, plus any further letters whose count
   * equals that of the last letter taken. The chosen letters, sorted alphabetically and
   * concatenated, are looked up in {@code CONSENSUS}.
   *
   * @param a    count of A
   * @param c    count of C
   * @param g    count of G
   * @param t    count of T
   * @param letn number of top letters requested (at least 1)
   * @return the symbol registered in {@code CONSENSUS} for the chosen letter set
   */
  private String conslets(double a, double c, double g, double t, int letn) {
    String[] letters = { "A", "C", "G", "T" };
    final HashMap<String, Double> letcon = new HashMap<String, Double>(4);
    letcon.put("A", a);
    letcon.put("C", c);
    letcon.put("G", g);
    letcon.put("T", t);

    // descending by count
    Arrays.sort(letters, new Comparator<String>() {
      public int compare(String o1, String o2) {
        return letcon.get(o2).compareTo(letcon.get(o1));
      }
    });

    // take letn letters, then keep going while the next letter ties with the last one taken
    List<String> rcl = new ArrayList<String>(letn);
    int i = 0;
    while (i < letn || (i < 4 && letcon.get(letters[i]).equals(letcon.get(rcl.get(rcl.size() - 1))))) {
      rcl.add(letters[i++]);
    }

    String[] rcls = rcl.toArray(new String[rcl.size()]);
    Arrays.sort(rcls);
    StringBuilder sbrcls = new StringBuilder(rcls.length);
    for (String s : rcls) {
      sbrcls.append(s);
    }

    return CONSENSUS.get(sbrcls.toString());
  }

  /**
   * Returns a log-odds copy of this matrix; unlike {@link #toPWM(double[])} this matrix itself
   * is left unchanged.
   *
   * @param background background probabilities, one per matrix row
   * @return a new matrix holding the weights
   */
  @Override
  public AMatrix makePWM(double[] background) {
    return new WPCM(this).toPWM(background);
  }
}
