package de.christofreichardt.scala.krypto.algorithms

import de.christofreichardt.diagnosis.AbstractTracer
import de.christofreichardt.diagnosis.TracerFactory

/**
 * Solves a system of simultaneous congruences with the Chinese Remainder Theorem.
 *
 * Each equation `(a, m)` stands for the congruence `x ≡ a (mod m)`. All moduli must be at
 * least 1 and pairwise coprime. The computed result is reduced modulo the product of all
 * moduli, so it lies in `[0, m_1 * m_2 * ... * m_k)`.
 *
 * @param equations the congruences as `(value, modulus)` pairs; at least two are required
 * @throws IllegalArgumentException if fewer than two equations are given, if any modulus is
 *         less than 1, or if the moduli are not pairwise coprime
 */
class ChineseRemainderTheorem(val equations: List[Tuple2[BigInt, BigInt]]) extends Algorithm[List[Tuple2[BigInt, BigInt]], BigInt](equations) {
  require(equations.length >= 2, "More than one equation required.")

  /** The moduli `m_i`, in the same order as the equations. */
  val moduli: List[BigInt] = equations.map(equation => equation._2)
  require(checkPairwisePrime(moduli), "The moduli must be at least 1 and pairwise coprime.")

  /** The product of all moduli; the result of [[calculate]] is reduced modulo this value. */
  val modulus: BigInt = moduli.product

  /**
   * Checks that every modulus is at least 1 and that no two moduli share a common factor.
   *
   * Every element, including the last one, is checked against the lower bound. Otherwise a
   * trailing 0 or negative modulus would slip through, and the result would depend on the
   * order of the moduli.
   *
   * @param moduli the moduli to check
   * @return `true` if all moduli are at least 1 and pairwise coprime, `false` otherwise
   */
  private def checkPairwisePrime(moduli: List[BigInt]): Boolean = {
    withTracer("Boolean", this, "checkPairwisePrime(moduli: List[BigInt])") {
      moduli match {
        case Nil => true
        case List(m) => m >= BigInt(1)
        case m1 :: ms =>
          // Short-circuit: skip the gcd checks and the recursion once a violation is found.
          m1 >= BigInt(1) &&
            ms.forall(m2 => m1.gcd(m2) == BigInt(1)) &&
            checkPairwisePrime(ms)
      }
    }
  }

  /**
   * Computes the solution of the congruence system.
   *
   * For each equation `(a_i, m_i)`, with `M = modulus` and `M_i = M / m_i`, the term
   * `a_i * M_i * (M_i^-1 mod m_i)` is formed. The sum of all terms, reduced modulo `M`,
   * is the result.
   *
   * @return the solution in `[0, modulus)`
   */
  def calculate: BigInt = {
    val terms = equations.map { case (value, m) =>
      val partialProduct = modulus / m
      // M_i is coprime to m_i (the moduli are pairwise coprime), so the inverse exists.
      value * partialProduct * partialProduct.modInverse(m)
    }
    terms.sum.mod(modulus)
  }

  /**
   * Verifies that `outcome` satisfies every congruence of the system.
   *
   * @return `true` if `outcome ≡ a (mod m)` holds for every equation `(a, m)`
   */
  override def crossCheck(): Boolean = {
    equations.forall { case (value, m) => outcome.mod(m) == value.mod(m) }
  }

  /**
   * Returns the tracer registered under the name "TestTracer". Falls back to the factory's
   * default tracer if that lookup throws a `TracerFactory.Exception`.
   */
  override def getCurrentTracer(): AbstractTracer = {
    try {
      TracerFactory.getInstance().getTracer("TestTracer")
    }
    catch {
      case ex: TracerFactory.Exception => TracerFactory.getInstance().getDefaultTracer
    }
  }
}