package de.christofreichardt.scala.krypto.algorithms

import de.christofreichardt.diagnosis.AbstractTracer
import de.christofreichardt.diagnosis.TracerFactory

/**
 * Computes the solution set of the linear congruence `a*x = b (mod n)`.
 *
 * With `g = gcd(a, n)` the congruence is solvable iff `g` divides `b`. In that case it has exactly
 * `g` solutions in `[0, n)`, namely `x1 + i*(n/g)` for `i = 0 until g`, where `x1` is the unique
 * solution of the reduced congruence `(a/g)*x = b/g (mod n/g)`.
 *
 * @param equation the triple `(a, b, n)`. `b` is reduced modulo `n` on construction, so `n` must be
 *                 positive (`BigInt.mod` throws an `ArithmeticException` otherwise).
 */
class LinearCongruence(equation: Tuple3[BigInt,BigInt,BigInt]) extends Algorithm[Tuple3[BigInt,BigInt,BigInt], Set[BigInt]](equation) {
  /** Coefficient of `x`, taken as given (not reduced modulo `n`). */
  val a: BigInt = equation._1
  /** Right-hand side, reduced into `[0, n)`. */
  val b: BigInt = equation._2.mod(equation._3)
  /** The modulus. */
  val n: BigInt = equation._3
  /** `gcd(a, n)`: the number of solutions whenever the congruence is solvable. */
  val g: BigInt = a.gcd(n)

  /**
   * Solves the congruence.
   *
   * @return all solutions in `[0, n)`, or the empty set if `g` does not divide `b`
   * @throws IllegalArgumentException if `g`, and hence the number of solutions, is not below
   *                                  `Short.MaxValue`
   */
  def calculate(): Set[BigInt] = {
    withTracer("Set[BigInt]", this, "calculate()") {
      val tracer = getCurrentTracer()
      tracer.out.printfIndentln("equation = %s", this.toString)
      tracer.out.printfIndentln("ggt(%s,%s) = %s", a, n, g)

      if (b % g == 0) {
        // Every one of the g solutions is materialized below, so refuse huge solution sets up front.
        require(g < BigInt(Short.MaxValue), "Too many solutions")
        val n0 = n / g
        // a/g and n/g are coprime, so this inverse always exists.
        val inverse = (a / g).modInverse(n0)
        // The unique solution of the reduced congruence, lying in [0, n0).
        val x1 = ((b / g) * inverse).mod(n0)
        // The full solution set is x1 shifted by every multiple of n0 that stays below n.
        (0 until g.toInt).map(i => x1 + i * n0).toSet
      }
      else {
        Set.empty
      }
    }
  }

  /**
   * Verifies the computed result (`outcome`, presumably the value produced by [[calculate]] —
   * confirm against `Algorithm`).
   *
   * @return when `g` divides `b`: true iff every reported `x` satisfies `a*x = b (mod n)` and exactly
   *         `g` solutions were reported; otherwise: true iff no solution was reported
   */
  override def crossCheck(): Boolean = {
    withTracer("Boolean", this, "crossCheck()") {
      val tracer = getCurrentTracer()
      tracer.out().printfIndentln("%d solutions = %s", int2Integer(outcome.size), outcome)
      if (b % g == 0) {
        // b is already reduced into [0, n), so it can be compared directly with (a*x) mod n.
        outcome.forall(x => (a * x).mod(n) == b) && outcome.size == g.toInt
      }
      else {
        outcome.isEmpty
      }
    }
  }

  /** Renders the congruence as `a*x = b (mod n)`, showing the already reduced `b`. */
  override def toString(): String = s"$a*x = $b (mod $n)"

  /**
   * Returns the tracer registered under "TestTracer", falling back to the factory's default tracer
   * when the lookup throws a `TracerFactory.Exception`.
   */
  override def getCurrentTracer(): AbstractTracer = {
    try {
      TracerFactory.getInstance().getTracer("TestTracer")
    }
    catch {
      case _: TracerFactory.Exception => TracerFactory.getInstance().getDefaultTracer
    }
  }
}