package de.christofreichardt.scala.krypto.algorithms

import de.christofreichardt.scala.krypto.DLProblem
import de.christofreichardt.scala.krypto.Constants
import de.christofreichardt.diagnosis.AbstractTracer
import de.christofreichardt.diagnosis.TracerFactory
import de.christofreichardt.diagnosis.LogLevel
import scala.annotation.tailrec
import scala.collection.immutable.StreamIterator
import scala.util.Random

/**
 * Common base for algorithms that look for an exponent x with a^x = b (mod p).
 *
 * The constructor rejects (via `require`, i.e. with an IllegalArgumentException) every problem
 * where one of the four numbers is not positive, where `p` fails the probable-prime test,
 * where `a^order (mod p)` is not 1, or where `a`, `b` or `order` is not smaller than `p`.
 *
 * `outcome` and `withTracer` are not defined here; presumably they are inherited from
 * `Algorithm` -- verify against that class.
 *
 * @param dlProblem supplies the base `a`, the target `b`, the modulus `p` and the order of `a`
 */
abstract class DLAlgorithm(dlProblem: DLProblem) extends Algorithm[DLProblem, Option[BigInt]](dlProblem) {
  val a = dlProblem.a
  val b = dlProblem.b
  val p = dlProblem.p
  val order = dlProblem.order

  // The checks run top to bottom, so their order decides which failure a caller gets to see.
  require(a > 0 && b > 0 && p > 0 && order > 0, dlProblem)
  require(p.isProbablePrime(Constants.certainty), "Provide a prime modulus.")
  require(a.modPow(order, p) == BigInt(1), "Incorrect order")
  require(a < p && b < p && order < p, dlProblem)

  /** Renders the problem as a congruence together with the order of the base. */
  override def toString: String = s"${a}^x = ${b} (mod ${p}), ord(${a}) = ${order}"

  /**
   * Verifies the computed exponent by raising `a` to it modulo `p`.
   *
   * @return true if the result equals `b`
   * @throws IllegalArgumentException if `outcome` holds no solution
   */
  override def crossCheck(): Boolean = {
    withTracer("Boolean", this, "crossCheck()") {
      outcome match {
        case Some(exponent) => a.modPow(exponent, p) == b
        case None => throw new IllegalArgumentException("Algorithm hasn't found any solution for: " + this)
      }
    }
  }

  /** Returns the tracer registered as "TestTracer", or the default tracer if that lookup fails. */
  override def getCurrentTracer(): AbstractTracer = {
    val factory = TracerFactory.getInstance()
    try {
      factory.getTracer("TestTracer")
    }
    catch {
      case _: TracerFactory.Exception => TracerFactory.getInstance().getDefaultTracer()
    }
  }
}

/**
 * Solves a^x = b (mod p) by trying the exponents 0, 1, ..., order - 1 one after another.
 *
 * Only problems with `order < Short.MaxValue` are accepted.
 */
class EnumerationAlgo(dlProblem: DLProblem) extends DLAlgorithm(dlProblem) {
  require(order < Short.MaxValue)

  /**
   * @return the smallest exponent in `[0, order)` that maps `a` onto `b`, or None if there is none
   */
  def calculate(): Option[BigInt] = {
    withTracer("Option[BigInt]", this, "calculate()") {
      val candidates = 0 until order.toInt
      candidates
        .find(exponent => a.modPow(BigInt(exponent), p) == b)
        .map(exponent => BigInt(exponent))
    }
  }
}

/**
 * Solves a^x = b (mod p) by writing x = i*m + j with m = ceil(sqrt(order)):
 * the powers a^j (0 <= j < m) are tabulated once and then each b * a^(-m*i) (0 <= i < m)
 * is looked up in that table.
 */
class BabyStepGiantStep(dlProblem: DLProblem) extends DLAlgorithm(dlProblem) {
  /** Step width; derived through a Double square root of `order`. */
  val m = BigInt(scala.math.ceil(scala.math.sqrt(order.toDouble)).toLong)

  /** Computes b * a^(-m*i) (mod p). */
  def giantStep(i: BigInt) = (a.modPow(-m*i, p)*b).mod(p)

  /**
   * @return `i*m + j` for the first `i` in `[0, m)` whose giant step occurs among the baby steps,
   *         or None if no giant step matches
   */
  def calculate(): Option[BigInt] = {
    withTracer("Option[BigInt]", this, "calculate()") {
      val tracer = getCurrentTracer()
      val indices = BigInt(0).until(m)
      // Lookup table: a^j (mod p) -> j
      val babySteps = indices.map(j => a.modPow(j, p) -> j).toMap
      tracer.out().printfIndentln("babySteps = %s", babySteps)
      indices
        .find(i => babySteps.contains(giantStep(i)))
        .map(i => i*m + babySteps(giantStep(i)))
    }
  }

  /** Always traces to the default tracer instead of the lookup done by the base class. */
  override def getCurrentTracer(): AbstractTracer = TracerFactory.getInstance().getDefaultTracer()
}

/**
 * Solves a^x = b (mod p) by walking a pseudo-random sequence of group elements until an
 * element repeats, and then solving the linear congruence that the repetition yields.
 *
 * Every sequence element is a quadruple `(index, beta, x, y)` with beta = a^x * b^y (mod p);
 * the exponents x and y are tracked modulo p - 1.
 */
class PollardRho(dlProblem: DLProblem) extends DLAlgorithm(dlProblem) {
  /** Upper bound (exclusive) of the first third of the residues. */
  val d1 = p/3
  /** Upper bound (exclusive) of the second third of the residues. */
  val d2 = p*2/3

  /**
   * Builds the lazy sequence of quadruples, starting with `(0, a^x0, x0, 0)`.
   *
   * @param x0 the start exponent
   * @return the infinite stream of quadruples
   */
  def quadrupleSequence(x0: BigInt): Stream[Tuple4[BigInt,BigInt,BigInt,BigInt]] = {
    // The trace label now names the actual result type (the original said Tuple3).
    withTracer("Stream[Tuple4[BigInt,BigInt,BigInt,BigInt]]", this, "quadrupleSequence(x0: BigInt)") {
      val tracer = getCurrentTracer
      tracer.out().printfIndentln("x0 = %s", x0)
      val quadruple = (BigInt(0), a.modPow(x0, p), x0, BigInt(0))
      tracer.out().printfIndentln("quadruple = %s", quadruple)
      Stream.cons(quadruple, tripleSequence(quadruple))
    }
  }

  /**
   * Produces the successors of the given quadruple. Depending on the third of the residues
   * beta falls into, beta is multiplied by `a` (x + 1), squared (2x, 2y) or multiplied by
   * `b` (y + 1).
   *
   * No tracer is fetched here: this method runs once per generated element and never traced.
   */
  private def tripleSequence(quadruple: Tuple4[BigInt,BigInt,BigInt,BigInt]): Stream[Tuple4[BigInt,BigInt,BigInt,BigInt]] = {
    val (index, beta, x, y) = quadruple
    val nextIndex = index + 1
    val exponentModulus = p - 1
    val nextQuadruple =
      if (beta < d1)
        (nextIndex, (a * beta).mod(p), (x + 1).mod(exponentModulus), y)
      else if (beta < d2)
        (nextIndex, (beta * beta).mod(p), (2 * x).mod(exponentModulus), (2 * y).mod(exponentModulus))
      else
        (nextIndex, (b * beta).mod(p), x, (y + 1).mod(exponentModulus))
    // Stream.cons takes its tail by name, so the recursion is only evaluated on demand.
    Stream.cons(nextQuadruple, tripleSequence(nextQuadruple))
  }

  /**
   * Takes the next quadruple from the iterator as reference and scans the following ones,
   * as long as their index stays below `iterations - 1`, for an equal beta. Without a match
   * the search restarts on the same iterator with a doubled bound.
   *
   * The method only returns once a match has been found.
   *
   * @param quadruples iterator over the quadruple sequence; it is advanced by this call
   * @param iterations index bound for the current round
   * @return `(xRef - xHit, yHit - yRef)`, i.e. exponents with a^_1 = b^_2 (mod p)
   */
  def searchForMatch(quadruples: StreamIterator[Tuple4[BigInt,BigInt,BigInt,BigInt]], iterations: BigInt): Tuple2[BigInt,BigInt] = {
    val tracer = getCurrentTracer
    val firstQuadruple = quadruples.next()
    tracer.out().printfIndentln("firstQuadruple = %s", firstQuadruple)
    tracer.out().printfIndentln("iterations = %s", iterations)
    val hit = quadruples
      .takeWhile(quadruple => quadruple._1 < iterations - 1)
      .find(quadruple => quadruple._2 == firstQuadruple._2)
    hit match {
      case Some(matching) => (firstQuadruple._3 - matching._3, matching._4 - firstQuadruple._4)
      case None => searchForMatch(quadruples, iterations*2)
    }
  }

  /**
   * Starts the walk at a random exponent, derives a relation a^u = b^v (mod p) from the first
   * repetition and tests the solutions of the resulting linear congruence modulo `order`.
   *
   * @return the first solution x of the congruence with a^x = b (mod p), or None if no
   *         solution passes that test
   */
  def calculate(): Option[BigInt] = {
    withTracer("Option[BigInt]", this, "calculate()") {
      val tracer = getCurrentTracer
      val x0 = (BigInt(p.bitLength, new Random) + 1).mod(p)
      val hit = searchForMatch(new StreamIterator(quadrupleSequence(x0)), 1)
      tracer.out().printfIndentln("hit = %s", hit)
      assert(a.modPow(hit._1, p) == b.modPow(hit._2, p), "Wrong exponents.")
      val linearCongruence = new LinearCongruence(hit._2, hit._1, order)
      assert(linearCongruence.crossCheck, "Computation of linear congruence failed.")
      assert(!linearCongruence.outcome.isEmpty, "Empty solution set for linear congruence.")
      linearCongruence.outcome.find(x => {
        tracer.out().printfIndentln("x = %s", x)
        a.modPow(x, p) == b
      })
    }
  }
}

/**
 * Solves a^x = b (mod p) by splitting the problem along the prime factorization of `order`:
 * for every prime power q^e dividing `order` the solution is computed modulo q^e, digit by
 * digit in base q, and the residues are combined with the Chinese Remainder Theorem.
 *
 * @tparam T the algorithm used for the subordinate problems of prime order
 */
abstract class PohligHellmanAlgo[T <: DLAlgorithm](dlProblem: DLProblem) extends DLAlgorithm(dlProblem) {
  /** The factors of `order` as (prime, exponent) pairs. */
  lazy val primeFactors: List[Tuple2[BigInt,Int]] = factorizeOrder
  /** One entry per prime factor: (base-q digits, (residue, modulus q^e)). */
  lazy val equations: List[Tuple2[List[BigInt],Tuple2[BigInt,BigInt]]] = buildEquations
  lazy val listOfcoefficients = equations.unzip._1

  /** Creates the solver for a subordinate problem. */
  def createAlgorithmInstance(dlProblem: DLProblem): T

  /**
   * Factorizes `order`.
   *
   * @return the outcome of the prime factorization in reversed sequence
   */
  def factorizeOrder: List[Tuple2[BigInt, Int]] = {
    withTracer("List[Tuple2[BigInt,Int]]", this, "factorizeOrder()") {
      val tracer = getCurrentTracer()
      val primeFactorization = new PrimeFactorization(order)
      // NOTE(review): the assertion passes if either check succeeds -- confirm that `||` (not `&&`) is intended.
      assert(primeFactorization.crossCheck || primeFactorization.productCheck, "Prime factorization of " + order + " failed.")
      tracer.out.printfIndentln("primeFactorization(%s).outcome = %s", order, primeFactorization.outcome)
      primeFactorization.outcome.reverse
    }
  }

  /**
   * Computes for every prime factor q with exponent e the residue of the solution modulo q^e.
   *
   * @return per prime factor the tuple (digits, (residue, q^e)); the digits are listed with
   *         the most significant one first
   * @throws java.lang.ArithmeticException if a subordinate problem has no solution
   */
  def buildEquations: List[Tuple2[List[BigInt],Tuple2[BigInt,BigInt]]] = {

    // Solves base^x = target (mod p) within the subgroup of order `prime`.
    def solveSubordinate(base: BigInt, target: BigInt, prime: BigInt): BigInt = {
      val tracer = getCurrentTracer()
      val subordinateDLAlgorithm = createAlgorithmInstance(new DLProblem(base, target, p, prime))
      tracer.out.printfIndentln("subordinateDLAlgorithm = %s", subordinateDLAlgorithm)
      subordinateDLAlgorithm.outcome match {
        case Some(digit) => digit
        case None => throw new java.lang.ArithmeticException(base + "^x = " + target + " [mod " + p + "] has no solution.")
      }
    }

    // Returns the digits 0..index (most significant first) and the partial sum of digit_i * prime^i.
    def computeCoefficients(prime: BigInt, index: Int, c: BigInt): Tuple2[List[BigInt],BigInt] = {
      withTracer("Tuple2[List[BigInt],BigInt]", this, "computeCoefficients(prime: BigInt, index: Int, c: BigInt)") {
        val tracer = getCurrentTracer()
        tracer.out.printfIndentln("prime = %s", prime)
        tracer.out.printfIndentln("index = %d", int2Integer(index))
        tracer.out.printfIndentln("c = %s", c)

        if (index == 0) {
          val digit = solveSubordinate(c, b.modPow(order/prime, p), prime)
          (List(digit), digit)
        }
        else {
          val (previousDigits, partialSum) = computeCoefficients(prime, index - 1, c)
          val exp = order/prime.pow(index + 1)
          // Strip the already known lower digits from b before projecting into the subgroup.
          val inverseA = a.modPow(partialSum, p).modInverse(p)
          val digit = solveSubordinate(c, (b*inverseA).modPow(exp, p), prime)
          (digit :: previousDigits, prime.pow(index)*digit + partialSum)
        }
      }
    }

    withTracer("List[Tuple2[List[BigInt],Tuple2[BigInt,BigInt]]]", this, "buildEquations()") {
      for {
        (prime, exponent) <- primeFactors
        c = a.modPow(order/prime, p)
        coefficients = computeCoefficients(prime, exponent - 1, c)
        modulus = prime.pow(exponent)
      } yield (coefficients._1, (coefficients._2, modulus))
    }
  }

  /**
   * Combines the residues of all prime powers into the solution.
   *
   * An ArithmeticException raised while the equations are built is logged and yields None.
   * If there are no equations at all (presumably only when `order` is 1, so the factorization
   * is empty) the exponent 0 is returned when `b` is 1, because a^0 = 1, and None otherwise;
   * the original code failed with a NoSuchElementException in that case.
   *
   * @return the solution, or None if none was found
   */
  def calculate() : Option[BigInt] = {
    withTracer("Option[BigInt]", this, "calculate()") {
      val tracer = getCurrentTracer()
      try {
        equations.map(_._2) match {
          case Nil => if (b == BigInt(1)) Some(BigInt(0)) else None
          case (residue, _) :: Nil => Some(residue)
          case congruences => Some(new ChineseRemainderTheorem(congruences).outcome)
        }
      }
      catch {
        case ex: java.lang.ArithmeticException => {
          tracer.logException(LogLevel.ERROR, ex, getClass(), "calculate()")
          None
        }
      }
    }
  }
}

/** Pohlig-Hellman variant whose subordinate problems are solved by [[EnumerationAlgo]]. */
class PohligHellmanWithEnumeration(dlProblem: DLProblem) extends PohligHellmanAlgo[EnumerationAlgo](dlProblem) {
  /** Wraps the given subordinate problem in a new [[EnumerationAlgo]]. */
  override def createAlgorithmInstance(dlProblem: DLProblem): EnumerationAlgo = {
    new EnumerationAlgo(dlProblem)
  }
}

/** Pohlig-Hellman variant whose subordinate problems are solved by [[BabyStepGiantStep]]. */
class PohligHellmanWithBSGS(dlProblem: DLProblem) extends PohligHellmanAlgo[BabyStepGiantStep](dlProblem) {
  /** Wraps the given subordinate problem in a new [[BabyStepGiantStep]]. */
  override def createAlgorithmInstance(dlProblem: DLProblem): BabyStepGiantStep = {
    new BabyStepGiantStep(dlProblem)
  }
}

/** Pohlig-Hellman variant whose subordinate problems are solved by [[PollardRho]]. */
class PohligHellmanWithPollardRho(dlProblem: DLProblem) extends PohligHellmanAlgo[PollardRho](dlProblem) {
  /** Wraps the given subordinate problem in a new [[PollardRho]]. */
  override def createAlgorithmInstance(dlProblem: DLProblem): PollardRho = {
    new PollardRho(dlProblem)
  }
}