package ru.autosome.ytilib;

import ru.autosome.assist.ASequence;

/**
 * Nucleotide sequence stored as an array of IUPAC codes together with its
 * reverse complement and a weight.
 *
 * <p>Codes are assigned by {@link #char2byte(char)}: 0..3 are A, C, G, T (U is read
 * as T), 4..9 are the two-letter ambiguity codes R, Y, K, M, S, W, 10..13 are the
 * three-letter codes B, D, H, V and 14 is N.
 *
 * <p>The reverse complement may be {@code null} (see the three-argument constructor);
 * methods of this class that read it tolerate that case.
 */
public class Sequence extends ASequence {
  /** Number of distinct IUPAC codes handled by this class. */
  public static final byte IUPACOUNT = 15;

  /** Code of the fully ambiguous symbol N (any of A, C, G, T). */
  public static final byte N = 14;

  /** Symbol for each code; the index is the code returned by {@link #char2byte(char)}. */
  protected static final char[] BYTE2CHAR = new char[]{'A', 'C', 'G', 'T', 'R', 'Y', 'K', 'M', 'S', 'W', 'B', 'D', 'H', 'V', 'N'};

  /**
   * Complement code for each code: A&lt;-&gt;T, C&lt;-&gt;G, R&lt;-&gt;Y, K&lt;-&gt;M, B&lt;-&gt;V, D&lt;-&gt;H;
   * S, W and N map to themselves.
   */
  protected static final byte[] REVCOMP = new byte[]{3, 2, 1, 0, 5, 4, 7, 6, 8, 9, 13, 12, 11, 10, 14};

  /**
   * Background with every one of the {@link #IUPACOUNT} entries set to 0.25.
   * The array is shared; callers should treat it as read-only.
   */
  public static final double[] uniformBackground = {0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25};

  /**
   * Creates a sequence that shares the strand arrays and the weight of {@code sour}
   * but owns freshly allocated, zero-filled {@code hdirect}/{@code hrevcomp} arrays.
   *
   * @param sour sequence to take strands and weight from; its reverse complement may be {@code null}
   */
  public Sequence(Sequence sour) {
    this(sour.direct, sour.revcomp, sour.weight);
  }

  /**
   * Parses a sequence from text with weight 1.0.
   *
   * <p>The text is upper-cased (locale-independently) before conversion, so lower-case
   * input is accepted. The reverse complement is computed via {@link #revcomp()}.
   *
   * @param sour sequence text made of IUPAC nucleotide symbols
   * @throws RuntimeException if a symbol cannot be converted
   */
  public Sequence(String sour) {
    weight = 1.0;
    // Locale.ROOT keeps the case mapping independent of the JVM default locale.
    String upper = sour.toUpperCase(java.util.Locale.ROOT);
    direct = new byte[upper.length()];
    for (int i = 0; i < upper.length(); i++) {
      direct[i] = char2byte(upper.charAt(i));
      // char2byte is overridable; guard against codes outside the lookup tables.
      if (direct[i] >= IUPACOUNT) {
        throw new RuntimeException("found unknown symbol '" + upper.charAt(i) + "'; cannot convert sequence");
      }
    }

    revcomp = revcomp();

    hdirect = new double[this.direct.length];
    hrevcomp = new double[this.revcomp.length];
  }

  /**
   * Parses a sequence from text and assigns it the given weight.
   *
   * @param sour sequence text made of IUPAC nucleotide symbols
   * @param weight weight of the sequence
   * @throws RuntimeException if a symbol cannot be converted
   */
  public Sequence(String sour, double weight) {
    this(sour);
    this.weight = weight;
  }

  /**
   * Creates a sequence that shares the strand arrays of {@code seqc} but has its own
   * weight and freshly allocated, zero-filled {@code hdirect}/{@code hrevcomp} arrays.
   *
   * @param seqc sequence to take strands from; its reverse complement may be {@code null}
   * @param weight weight of the new sequence
   */
  public Sequence(Sequence seqc, double weight) {
    this(seqc.direct, seqc.revcomp, weight);
  }

  /**
   * Wraps already encoded strands. The arrays are stored as given, not copied.
   *
   * @param cdirect codes of the direct strand
   * @param crevcomp codes of the reverse complement strand, or {@code null} when there is none
   * @param weight weight of the sequence
   */
  public Sequence(byte[] cdirect, byte[] crevcomp, double weight) {
    this.weight = weight;
    this.direct = cdirect;
    this.revcomp = crevcomp;

    hdirect = new double[this.direct.length];
    // TODO: beautify RNASequence usage - when no revcomp is present
    hrevcomp = new double[this.revcomp == null ? 0 : this.revcomp.length];
  }

  /**
   * Builds the reverse complement of the direct strand.
   *
   * @return a new array holding the complemented codes of {@code direct} in reverse order
   */
  public byte[] revcomp() {
    byte[] strand2 = new byte[direct.length];
    for (int i = 0; i < direct.length; i++) {
      strand2[strand2.length - i - 1] = REVCOMP[direct[i]];
    }
    return strand2;
  }

  /**
   * Converts an upper-case IUPAC nucleotide symbol to its code.
   *
   * @param c symbol to convert; {@code 'U'} is mapped to the same code as {@code 'T'}
   * @return code in the range 0..14
   * @throws RuntimeException if {@code c} is not a known symbol
   */
  protected byte char2byte(char c) {
    switch (c) {
      case 'A':
        return 0;
      case 'C':
        return 1;
      case 'G':
        return 2;
      case 'T':
        return 3;
      case 'U':
        return 3;

      case 'R':
        return 4; //AG
      case 'Y':
        return 5; //CT
      case 'K':
        return 6; //GT
      case 'M':
        return 7; //AC
      case 'S':
        return 8; //CG
      case 'W':
        return 9; //AT

      case 'B':
        return 10; // CGT
      case 'D':
        return 11; // AGT
      case 'H':
        return 12; // ACT
      case 'V':
        return 13; // ACG

      case 'N':
        return Sequence.N; // ACGT
    }
    throw new RuntimeException("found unknown symbol '" + c + "'; cannot convert sequence");
  }

  /**
   * Encodes sequence text into an array of codes.
   *
   * @param s sequence text made of IUPAC nucleotide symbols
   * @return codes of the direct strand
   * @throws RuntimeException if a symbol cannot be converted
   */
  public static byte[] str2seq(String s) {
    return new Sequence(s).direct;
  }

  /**
   * Decodes an array of codes into text.
   *
   * @param seq codes to decode
   * @return text with one symbol per code
   */
  public static String seq2str(byte[] seq) {
    StringBuilder str = new StringBuilder(seq.length);
    for (byte b : seq) {
      str.append(BYTE2CHAR[b]);
    }
    return str.toString();
  }

  /**
   * Decodes a slice of an array of codes into text.
   *
   * @param seq codes to decode
   * @param offset index of the first code to decode
   * @param length number of codes to decode; a non-positive value yields an empty string
   * @return text with one symbol per decoded code
   */
  public String word2str(byte[] seq, int offset, int length) {
    StringBuilder str = new StringBuilder(Math.max(length, 0));
    for (int i = offset; i < offset + length; i++) {
      str.append(BYTE2CHAR[seq[i]]);
    }
    return str.toString();
  }

  /**
   * Creates a copy with the same weight whose strand arrays are independent of this
   * sequence's arrays. A {@code null} reverse complement stays {@code null} in the copy.
   *
   * @return the copy
   */
  @Override public Sequence copy() {
    byte[] cdirect = direct.clone();
    // revcomp is optional (see the three-argument constructor), so it must not be dereferenced blindly.
    byte[] crevcomp = (revcomp == null) ? null : revcomp.clone();

    return new Sequence(cdirect, crevcomp, weight);
  }

  /**
   * Estimates background letter frequencies from a set of sequences.
   *
   * <p>Each position contributes 0.5 for the direct strand code and 0.5 for the reverse
   * complement code, or 1.0 for the direct strand code when there is no reverse
   * complement. Counts of ambiguity codes are then spread evenly over the bases they
   * stand for, and the four base entries are divided by the total sequence length.
   * Finally the array is handed to {@code WPCM.iupacomprobs}, which presumably fills
   * the ambiguity-code entries from the base frequencies - confirm against WPCM.
   *
   * <p>Sequence weights are not taken into account. If the total length is zero the
   * four base entries become NaN before the {@code WPCM.iupacomprobs} call.
   *
   * @param sequences sequences to count letters in
   * @return array of {@link #IUPACOUNT} values indexed by code
   */
  public static double[] background(Sequence[] sequences) {

    double[] probs = new double[IUPACOUNT];

    // long, not int: the combined length of many sequences can exceed Integer.MAX_VALUE.
    long totalLength = 0;
    for (Sequence sequence : sequences) {
      byte[] forward = sequence.direct;
      byte[] reverse = sequence.revcomp;
      totalLength += forward.length;
      if (reverse != null) {
        for (int i = 0; i < forward.length; i++) {
          probs[forward[i]] += 0.5;
          probs[reverse[i]] += 0.5;
        }
      } else { // for RNASequence
        for (int i = 0; i < forward.length; i++) {
          probs[forward[i]] += 1.0;
        }
      }
    }

    // remake letter matrix counting IUPAC-s
    // N-s ACGT
    probs[0] += probs[14] / 4;
    probs[1] += probs[14] / 4;
    probs[2] += probs[14] / 4;
    probs[3] += probs[14] / 4;

    // V-s ACG
    probs[0] += probs[13] / 3;
    probs[1] += probs[13] / 3;
    probs[2] += probs[13] / 3;

    // H-s ACT
    probs[0] += probs[12] / 3;
    probs[1] += probs[12] / 3;
    probs[3] += probs[12] / 3;

    // D-s AGT
    probs[0] += probs[11] / 3;
    probs[2] += probs[11] / 3;
    probs[3] += probs[11] / 3;

    // B-s CGT
    probs[1] += probs[10] / 3;
    probs[2] += probs[10] / 3;
    probs[3] += probs[10] / 3;

    // W-s AT
    probs[0] += probs[9] / 2;
    probs[3] += probs[9] / 2;

    // S-s CG
    probs[1] += probs[8] / 2;
    probs[2] += probs[8] / 2;

    // M-s AC
    probs[0] += probs[7] / 2;
    probs[1] += probs[7] / 2;

    // K-s GT
    probs[2] += probs[6] / 2;
    probs[3] += probs[6] / 2;

    // Y-s CT
    probs[1] += probs[5] / 2;
    probs[3] += probs[5] / 2;

    // R-s AG
    probs[0] += probs[4] / 2;
    probs[2] += probs[4] / 2;

    // Only the four base entries are normalised here; entries 4..14 still hold raw counts.
    for (int i = 0; i < 4; i++) {
      probs[i] /= totalLength;
    }

    WPCM.iupacomprobs(probs);

    return probs;
  }

  /**
   * Builds a background from a GC fraction.
   *
   * <p>C and G each get {@code gcp / 2}, A and T each get {@code (1 - gcp) / 2}; every
   * ambiguity code gets the average of the bases it stands for.
   *
   * @param gcp combined fraction of C and G; not range-checked
   * @return new array of {@link #IUPACOUNT} values indexed by code
   */
  public static double[] background(double gcp) {
    double[] background = new double[IUPACOUNT];
    background[1] = background[2] = gcp / 2;
    background[0] = background[3] = (1 - gcp) / 2;

    double a = background[0];
    double c = background[1];
    double g = background[2];
    double t = background[3];

    // remake IUPAC-s counting matrix letters

    // N-s ACGT
    background[14] = (a + c + g + t) / 4;

    // V-s ACG
    background[13] = (a + c + g) / 3;

    // H-s ACT
    background[12] = (a + c + t) / 3;

    // D-s AGT
    background[11] = (a + g + t) / 3;

    // B-s CGT
    background[10] = (c + g + t) / 3;

    // W-s AT
    background[9] = (a + t) / 2;

    // S-s CG
    background[8] = (c + g) / 2;

    // M-s AC
    background[7] = (a + c) / 2;

    // K-s GT
    background[6] = (g + t) / 2;

    // Y-s CT
    background[5] = (c + t) / 2;

    // R-s AG
    background[4] = (a + g) / 2;

    return background;
  }

}