package ru.autosome.engine;

import ru.autosome.assist.ASequence;
import ru.autosome.assist.Conductor;
import ru.autosome.assist.WordRecord;
import ru.autosome.ytilib.*;

import java.lang.reflect.Array;
import java.util.*;

/**
 * Holds a set of sequences together with, for every sequence, the list of hit
 * positions on its direct strand and on its reverse-complement strand, and
 * iteratively refits a matrix ({@link #pm}) to those hits.
 *
 * <p>Typical use as far as this class shows: call {@link #optimize(WPCM)} first
 * (it populates {@link #pm} and the hit lists), then optionally
 * {@link #zoops(double, Conductor)} / {@link #zoops2(double)} or
 * {@link #extractWordList(WPCM)}, all of which read that state.
 *
 * <p>NOTE(review): "ZOOPS" presumably stands for "zero or one occurrence per
 * sequence" and "kdic" for an information-content measure; confirm against the
 * {@code WPCM} documentation.
 */
public class SoptiStep {

  /** Upper bound accepted for the per-instance iteration limit. */
  public static final int GLOBAL_ITERATION_LIMIT = 1000;

  private final Sequence[] sequences;
  private final double totalWeight;
  private final double[] background;

  private final List<Integer>[] primaryHits;
  private final List<Integer>[] revcompHits;
  private final double[] customShape;

  private final int iterationLimit;
  /** kdic values already produced during the current {@link #optimize(WPCM)} run. */
  private final HashSet<Double> prevKDIC = new HashSet<Double>(GLOBAL_ITERATION_LIMIT, 1.0f);

  /** Current matrix; assigned by {@link #optimize(WPCM)} and read by the zoops methods. */
  protected WPCM pm;

  /**
   * Creates a step without a custom shape.
   *
   * @param joinedSets     sequences to work on (stored by reference, not copied)
   * @param weight         value later passed to {@code pm.setN(...)} in {@link #optimize(WPCM)}
   * @param background     background passed to every matrix conversion and to {@code kdic}
   * @param iterationLimit iteration bound for {@link #optimize(WPCM)}
   * @throws IllegalArgumentException if {@code iterationLimit} exceeds {@link #GLOBAL_ITERATION_LIMIT}
   */
  public SoptiStep(Sequence[] joinedSets, double weight, double[] background, int iterationLimit) {
    this(joinedSets, weight, background, iterationLimit, null);
  }

  /**
   * Creates a step; when {@code customShape} is non-null, {@link #optimize(WPCM)}
   * starts from a {@code ShapedWPCM} built with it instead of a plain {@code WPCM}.
   *
   * @param joinedSets     sequences to work on (stored by reference, not copied)
   * @param weight         value later passed to {@code pm.setN(...)} in {@link #optimize(WPCM)}
   * @param background     background passed to every matrix conversion and to {@code kdic}
   * @param iterationLimit iteration bound for {@link #optimize(WPCM)}
   * @param customShape    optional shape, may be {@code null}
   * @throws IllegalArgumentException if {@code iterationLimit} exceeds {@link #GLOBAL_ITERATION_LIMIT}
   */
  public SoptiStep(Sequence[] joinedSets, double weight, double[] background, int iterationLimit, double[] customShape) {
    sequences = joinedSets;
    primaryHits = newHitLists(sequences.length);
    revcompHits = newHitLists(sequences.length);

    this.background = background;
    if (iterationLimit > GLOBAL_ITERATION_LIMIT) {
      throw new IllegalArgumentException("too big iteration limit");
    }
    this.iterationLimit = iterationLimit;
    this.totalWeight = weight;
    this.customShape = customShape;
  }

  /**
   * Allocates {@code count} empty hit lists.
   *
   * <p>The unchecked cast is confined to this method: the array is created for
   * {@code List.class} and every slot is filled here with an
   * {@code ArrayList<Integer>}, so treating it as {@code List<Integer>[]} is safe.
   */
  @SuppressWarnings("unchecked")
  private static List<Integer>[] newHitLists(int count) {
    List<Integer>[] lists = (List<Integer>[]) Array.newInstance(List.class, count);
    for (int i = 0; i < count; i++) {
      lists[i] = new ArrayList<Integer>(17); // in memory of Sosinskiy the Great
    }
    return lists;
  }

  /** @return the internal per-sequence reverse-complement hit lists (the live array, not a copy) */
  public List<Integer>[] getRevcompHits() {
    return revcompHits;
  }

  /** @return the internal per-sequence direct-strand hit lists (the live array, not a copy) */
  public List<Integer>[] getPrimaryHits() {
    return primaryHits;
  }

  /**
   * Scores, for every sequence, its first recorded hit with {@code zpm}: the first
   * direct-strand hit when there is one, otherwise the first reverse-complement hit.
   * A sequence with no hits on either strand makes {@code get(0)} fail, exactly as
   * in the inline code this helper replaces.
   */
  private double[] bestHitScores(WPCM zpm) {
    double[] scores = new double[sequences.length];
    for (int i = 0; i < sequences.length; i++) {
      scores[i] = primaryHits[i].size() > 0
          ? zpm.score(sequences[i].getDirect(), primaryHits[i].get(0))
          : zpm.score(sequences[i].getRevcomp(), revcompHits[i].get(0));
    }
    return scores;
  }

  /** Returns the sequence indices sorted by descending score (stable for equal scores). */
  private List<Integer> rankByScore(final double[] scores) {
    List<Integer> order = new ArrayList<Integer>(scores.length);
    for (int i = 0; i < scores.length; i++) {
      order.add(i);
    }
    java.util.Collections.sort(order, new Comparator<Integer>() {
      @Override
      public int compare(Integer o1, Integer o2) {
        return Double.compare(scores[o2], scores[o1]);
      }
    });
    return order;
  }

  /**
   * Word-based ZOOPS selection: ranks sequences by best-hit score, accumulates a
   * per-word score delta over the ranked words, and keeps the sequences up to the
   * word where the accumulated delta is largest.
   *
   * @param zoopsFactor multiplier applied to the word/sequence count when converting matrices
   * @param conductor   receives a one-line message with the number of selected sequences
   * @return result built from the words of the selected sequences and the matrix rebuilt from them
   * @throws IllegalStateException if there are no hit words at all (e.g. no sequences)
   */
  public MunkResult zoops(double zoopsFactor, Conductor conductor) {

    WPCM zpm = new WPCM(pm).toPWM(zoopsFactor*sequences.length, background);

    double[] bestScores = bestHitScores(zpm);
    List<Integer> sequenceOrder = rankByScore(bestScores);

    // Collect every hit word in ranked-sequence order; wordRanks.get(w) is the
    // 1-based rank of the sequence that word w came from.
    List<byte[]> wordsal = new ArrayList<byte[]>();
    List<Integer> wordRanks = new ArrayList<Integer>();
    for (int i = 0; i < sequences.length; i++) {
      int index = sequenceOrder.get(i);
      for (Integer hit : primaryHits[index]) {
        wordsal.add(java.util.Arrays.copyOfRange(sequences[index].getDirect(), hit, hit + zpm.length()));
        wordRanks.add(i + 1);
      }
      for (Integer hit : revcompHits[index]) {
        wordsal.add(java.util.Arrays.copyOfRange(sequences[index].getRevcomp(), hit, hit + zpm.length()));
        wordRanks.add(i + 1);
      }
    }
    byte[][] words = wordsal.toArray(new byte[wordsal.size()][]);
    if (words.length == 0) {
      throw new IllegalStateException("ZOOPS requires at least one hit word, but none were found");
    }

    // deltaSI[i] = running sum, over words 0..i, of (score under the matrix built
    // from words 0..i) minus (score under the full-set matrix zpm).
    double[] deltaSI = new double[words.length];
    for (int i = 0; i < words.length; i++) {
      byte[][] partialWords = java.util.Arrays.copyOfRange(words, 0, i + 1);
      byte[] lastWord = partialWords[partialWords.length - 1];

      WPCM partialMatrix = new WPCM(partialWords).toPWM(zoopsFactor*partialWords.length, background);
      double deltasi = partialMatrix.score(lastWord) - zpm.score(lastWord);

      deltaSI[i] = deltasi;
      if (i > 0) deltaSI[i] += deltaSI[i - 1];
    }

    // First strict maximum of the running sum decides how many ranked sequences to keep.
    double maxDeltaSI = deltaSI[0];
    int sequenceCount = wordRanks.get(0);
    for (int i = 1; i < deltaSI.length; i++) {
      if (deltaSI[i] > maxDeltaSI) {
        maxDeltaSI = deltaSI[i];
        sequenceCount = Math.max(sequenceCount, wordRanks.get(i));
      }
    }

    conductor.message("ZOOPS detected motif in " + sequenceCount + " sequences of " + sequences.length);

    // Named distinctly from the field `pm` so the two are not confused.
    WPCM selectedMatrix = new WPCM(zpm.length(), primaryHits, revcompHits, sequences, sequenceOrder, sequenceCount);
    WPCM pwm = new WPCM(selectedMatrix).toPWM(background);

    List<WordRecord> wordList = extractWordList(pwm, sequenceOrder, sequenceCount);
    return new MunkResult(wordList, selectedMatrix, background, sequenceCount, sequences.length, ASequence.medianLength(sequences));
  }

  /**
   * ZOOPS mode based on sequences as the primitive instead of single words:
   * the running delta is accumulated once per ranked sequence, and ties for the
   * maximum resolve to the larger sequence count ({@code >=}).
   *
   * @param zoopsFactor multiplier applied to the sequence count when converting matrices
   * @return matrix built from the hits of the selected top-ranked sequences
   * @throws IllegalStateException if there are no sequences
   */
  protected WPCM zoops2(double zoopsFactor) {

    WPCM zpm = new WPCM(pm).toPWM(zoopsFactor*sequences.length, background);

    double[] bestScores = bestHitScores(zpm);
    List<Integer> sequenceOrder = rankByScore(bestScores);

    double[] deltaSI = new double[sequences.length];

    for (int i = 0; i < sequences.length; i++) {
      WPCM partialMatrix = new WPCM(zpm.length(), primaryHits, revcompHits, sequences, sequenceOrder, i + 1).toPWM(zoopsFactor*(i+1), background);
      int k = sequenceOrder.get(i);
      double deltasi = partialMatrix.bestScore(sequences[k], primaryHits[k], revcompHits[k]) - bestScores[k];
      deltaSI[i] = deltasi;
      if (i > 0) deltaSI[i] += deltaSI[i - 1];
    }

    if (deltaSI.length == 0) {
      throw new IllegalStateException("ZOOPS requires at least one sequence");
    }

    // Last position attaining the maximum of the running sum.
    double maxDeltaSI = deltaSI[0];
    int sequenceCount = 1;
    for (int i = 1; i < deltaSI.length; i++) {
      if (deltaSI[i] >= maxDeltaSI) {
        maxDeltaSI = deltaSI[i];
        sequenceCount = Math.max(sequenceCount, i+1);
      }
    }

    return new WPCM(zpm.length(), primaryHits, revcompHits, sequences, sequenceOrder, sequenceCount);
  }

  /**
   * Iteratively refits {@link #pm} starting from {@code base}.
   *
   * <p>Each round asks every sequence for its best hits under the current matrix,
   * rebuilds the matrix from those hits and computes its kdic. The loop stops when
   * a kdic value repeats (the iteration has entered a cycle or converged) or when
   * the number of recorded values reaches {@code iterationLimit - 1}; on exit
   * {@link #pm} is replaced by a copy of the rebuilt matrix.
   *
   * @param base starting matrix
   * @return the kdic value of the final matrix
   */
  public double optimize(WPCM base) {

    prevKDIC.clear();
    pm = (customShape == null) ? new WPCM(base).toPWM(background) : new ShapedWPCM(base, customShape).toPWM(background);

    pm.setN(totalWeight);

    while (true) {

      for (int i = 0; i < sequences.length; i++) {
        sequences[i].bestHits(pm, primaryHits[i], revcompHits[i]);
      }
      pm.rebuild(primaryHits, revcompHits, sequences);

      double newInfocod = pm.kdic(background);
      if (prevKDIC.contains(newInfocod) || prevKDIC.size() >= iterationLimit-1) {
        pm = new WPCM(pm);
        return newInfocod;
      }

      prevKDIC.add(newInfocod);

      // Result is intentionally unused: the call is relied on to convert pm in
      // place for the next round (presumably returns the same instance; verify in WPCM).
      pm.toPWM(background);
    }

  }

  /**
   * Appends one {@link WordRecord} per hit of sequence {@code k} to {@code result}:
   * direct-strand hits first, then reverse-complement hits. Reverse-complement
   * positions are reported as {@code revcomp.length - matrix.length() - hit}; each
   * record carries the sequence weight divided by the sequence's total hit count.
   */
  private void appendWords(List<WordRecord> result, WPCM matrix, int k) {
    // Only used inside the loops below, so it is never a divisor when it is zero.
    int hitc = primaryHits[k].size() + revcompHits[k].size();

    for (Integer hit : primaryHits[k]) {
      byte[] word = java.util.Arrays.copyOfRange(sequences[k].getDirect(), hit, hit + matrix.length());
      result.add(new WordRecord(k, hit, word, WordRecord.DIRECT, matrix.score(word), sequences[k].weight / hitc));
    }
    for (Integer hit : revcompHits[k]) {
      byte[] word = java.util.Arrays.copyOfRange(sequences[k].getRevcomp(), hit, hit + matrix.length());
      result.add(new WordRecord(k, sequences[k].getRevcomp().length - matrix.length() - hit, word, WordRecord.REVCOMP, matrix.score(word), sequences[k].weight / hitc));
    }
  }

  /** Collects the hit words of the first {@code sequenceCount} sequences of {@code sequenceOrder}. */
  private List<WordRecord> extractWordList(WPCM pm, List<Integer> sequenceOrder, int sequenceCount) {

    List<WordRecord> result = new ArrayList<WordRecord>(sequenceCount);

    for (int i = 0; i < sequenceCount; i++) {
      appendWords(result, pm, sequenceOrder.get(i));
    }

    return result;
  }

  /**
   * Collects the hit words of every sequence, in sequence-index order, scored with {@code pm}.
   *
   * @param pm matrix used for word length and scoring
   * @return one record per recorded hit
   */
  public List<WordRecord> extractWordList(WPCM pm) {

    List<WordRecord> result = new ArrayList<WordRecord>(sequences.length*2);

    for (int k = 0; k < sequences.length; k++) {
      appendWords(result, pm, k);
    }

    return result;
  }

}