// Copyright (c) 2003 Nick Mathewson.  See LICENSE for licensing information.
// $Id: trials.cpp 1339 2004-09-12 02:39:32Z nickm $
#include <iostream>
#include <vector>
#include <cmath>
#include "trials.h"
#include "netparams.h"

using namespace std;

// Write this result as a single "Result(key=value,...)" record with no
// trailing newline.  The observed-round fields are included only when
// nRoundsMaybeAlice is nonzero; roundsToGuessN is printed last via pvec().
void
TrialResult::write(std::ostream &out) const
{
  // Every field uses the bare key=value form (no space after '='), so the
  // record can be split uniformly on ',' and '='.
  out << "Result(succeeded="<<!failed
      << ",rounds="<< nRounds<<",roundsWithAlice="<< nRoundsAlice
      << ",msgs="<< nMsgs << ",msgsFromAlice="<< nMsgsAlice
      << ",nMsgsMaybeReal="<<  nMsgsAliceReal;
  if (nRoundsMaybeAlice) {
    out << ",nObservedRounds="<<nRoundsObserved<<",nObsRoundsMaybeAlice="
        << nRoundsMaybeAlice;
  }
  out << ",nRoundsToGuessNth=";
  pvec(out, roundsToGuessN) << ")";
}

// Stream a TrialResult by delegating to TrialResult::write; returns the
// same stream so insertions can be chained.
std::ostream &
operator<<(std::ostream &out, const TrialResult &r)
{
  r.write(out);
  return out;
}

// Run one batch-mix simulation until the attacker's top guesses match all of
// Alice's true recipients, or until more than `cutoff` rounds have elapsed
// (in which case the result is marked failed).
//
// Each round Alice adds her traffic, background traffic fills the rest of
// the nBatch-sized batch, and the mix hands the round to the attacker.
// The attacker is consulted every `gran` rounds; `gran` starts at
// `granularity` and doubles (while it is at most 128*granularity) once
// n >= 20*gran, so guesses become sparser as the trial runs on.
//
// roundsToGuessN[k] is the first checked round at which at least k of
// Alice's recipients were among the attacker's top truth.size() guesses;
// roundsToGuessN[0] is always 0.
TrialResult
BatchTrial::attempt()
{
  TrialResult res;
  unsigned int nGuessed = 0;
  int gran = granularity;
  // One leading 0 plus at most truth.size() further entries.
  res.roundsToGuessN.reserve(truth.size()+1);
  res.roundsToGuessN.push_back(0);

  mixnet->reset();
  attacker->reset();
  alice->reset();

  vec<int> trafficIn(nRecips, 0);
  vec<double> recips(nRecips, 0.0);

  int n = 0;
  for (;;) {
    trafficIn.reset(0);
    // Initialized so it is never read indeterminate, whatever addTraffic does.
    int nAlice = 0;
    alice->addTraffic(trafficIn, nAlice);
    if (nAlice) {
      res.nMsgsAlice += nAlice;
      res.nMsgsAliceReal += trafficIn.total(0);
      res.nRoundsAlice += 1;
    }
    // Background traffic fills whatever room Alice left in the batch.
    int bgvol = nBatch-nAlice; if (bgvol<0) bgvol=0;
    background->addNTraffic(trafficIn, bgvol);
    // NOTE(review): this counts nBatch even when nAlice > nBatch (bgvol is
    // clamped to 0), whereas NonbatchTrial counts nAlice + nBackground;
    // confirm which total is intended.
    res.nMsgs += nBatch;
    mixnet->addRound(trafficIn, nAlice, bgvol, attacker);
    ++n;

    // Only consult the attacker every `gran` rounds.
    if (n % gran)
      continue;
    if (n>=(gran*20) && gran <= granularity*128)
      gran *= 2;
    if (n>cutoff) {
      res.failed = true;
      res.nRounds = n;
      return res;
    }
    // No usable guess from the attacker this time; keep simulating.
    if (!attacker->guessAlice(recips))
      continue;
#ifndef QUIET
    std::cout << "!!!!!!!" << n << " " << recips << std::endl;
    std::cout << n << " ------ " << std::endl;
    pvec(std::cout, truth);
    pvec(std::cout, recips.topN(truth.size()));
    std::cout << n << " ------ " << std::endl;
#endif
    unsigned matches = nSortedVecMatch(recips.topN(truth.size()), truth);
    // Record the current round for every newly reached match count.
    while (matches > nGuessed) {
      res.roundsToGuessN.push_back(n);
      ++nGuessed;
    }
    if (matches == truth.size()) {
      res.nRounds = n;
      return res;
    }
  }
}

// Delete the mix, attacker, background and Alice objects held by this trial.
// The deletion order is left exactly as written in case these destructors
// depend on one another.
// NOTE(review): these are owning raw pointers; if a BatchTrial were ever
// copied they would be deleted twice -- confirm copying is disabled in
// trials.h.
BatchTrial::~BatchTrial()
{
  delete mixnet;
  delete attacker;
  delete background;
  delete alice;
}

// Write a one-line description of this SDTrial configuration,
// "SDTrial(b=...,m=...,N=...,granularity=...)", followed by std::endl.
void
SDTrialSpec::write(std::ostream &o) const
{
  o << "SDTrial(b=" << batchSize;
  o << ",m=" << nAliceRecipients;
  o << ",N=" << nRecipients;
  o << ",granularity=" << granularity;
  o << ")" << std::endl;
}

// Validate this spec with assertFilled() and build a heap-allocated SDTrial
// from it (presumably owned and deleted by the caller).
Trial *
SDTrialSpec::create() const
{
  assertFilled();
  Trial *trial = new SDTrial(*this);
  return trial;
}

// Configure a batch-mix trial whose true recipient set is 0..nAR-1.
// Alice uses UniformAlice with ConstDist(1) messages and ConstDist(0)
// dummies; background traffic is UniformBackground over nR recipients; the
// attacker is an SDAttacker; the mix is a BatchMix of size b.
//   nR  - number of recipients        nAR - number of Alice's recipients
//   b   - batch size                  g   - rounds between attacker guesses
//   c   - round count beyond which attempt() reports failure
void
SDTrial::init(int nR, int nAR, int b, int g, int c)
{
  ConstDist<int> aMsgs(1);
  ConstDist<int> aDummies(0);

  // Alice's recipients are simply the first nAR recipient ids.
  std::vector<int> recipients(nAR, 0);
  for (int r = 0; r < nAR; ++r)
    recipients[r] = r;
  truth = recipients;

  alice = new UniformAlice(truth, &aMsgs, &aDummies);
  background = new UniformBackground(nR);
  // NOTE(review): every entry of this background estimate is 1.0/b, so it
  // sums to nR/b rather than 1 (a uniform distribution over nR recipients
  // would be 1.0/nR).  Confirm what SDAttacker expects here.
  vec<double> bgEstimate(nR, 1.0/b);
  attacker = new SDAttacker(bgEstimate);
  mixnet = new BatchMix(b);

  nRecips = nR;
  nBatch = b;
  granularity = g;
  cutoff = c;
}

// Nothing extra to release here; the objects created in init() are
// presumably freed by the BatchTrial base destructor (confirm in trials.h).
SDTrial::~SDTrial() { }


// Write a one-line description of this UnkBGBatchTrial configuration
// followed by std::endl.  Alice's strategy is rendered as one quoted token
// built from three parts, e.g. "smallworld-geometric-weighted" or
// "uniform-uniform".
void
UnkBGBatchTrialSpec::write(std::ostream &o) const
{
  const char *recipientModel = aliceIsSmallworld ? "\"smallworld" : "\"uniform";
  const char *volumeModel = expMsgDist ? "-geometric" : "-uniform";
  const char *suffix = weightAlice ? "-weighted\"" : "\"";

  o << "UnkBGBatchTrial(b=" << batchSize << ",m=" << nAliceRecipients
    << ",N=" << nRecipients << ",aliceStrategy="
    << recipientModel << volumeModel << suffix;
  o << ",pMessage=" << pMsgAlice << ",pDummy=" << pDummyAlice;
  // XXX An older note said "paddingLevel is wrong, so include it when it's
  // right", yet it is printed here; confirm whether it should be.
  o << ",paddingLevel=" << paddingLevel;
  o << ",granularity=" << granularity << ")" << std::endl;
}

// Validate this spec with assertFilled() and build a heap-allocated
// UnkBGBatchTrial from it (presumably owned and deleted by the caller).
Trial *
UnkBGBatchTrialSpec::create() const
{
  assertFilled();
  Trial *trial = new UnkBGBatchTrial(*this);
  return trial;
}

// Configure a batch-mix trial (BatchMix of size b) attacked by an
// UnkBGBatchAttacker.
//   nR, nAR          - total recipients / number of Alice's recipients
//   smallworldAlice,
//   weightAlice      - forwarded to getCommunicationLinks(), which supplies
//                      Alice's recipient set and distribution plus the
//                      background traffic distribution
//   expMsgDist       - true: GeometricDist(1-p) for Alice's message and dummy
//                      counts; false: BinaryDist<int>(p, 1, 0)
//   pMsgA, pDA       - the p used for those message / dummy distributions
//   padding, pOnline - forwarded to DistAlice
//   g, c             - rounds between attacker guesses / failure cutoff
// Also prints each of Alice's recipients with its background probability
// ("bgAlice={id:p,...}") to stdout.
void
UnkBGBatchTrial::init(int nR, int nAR, int b,
		      bool smallworldAlice,
		      bool expMsgDist,
		      double pMsgA, double pDA,
		      int padding, double pOnline,
		      int g, int c, bool weightAlice)
{
  GeometricDist aMsgsE(1.0-pMsgA);
  GeometricDist aDummiesE(1.0-pDA);
  BinaryDist<int> aMsgsB(pMsgA, 1, 0);
  BinaryDist<int> aDummiesB(pDA, 1, 0);

  InvDist<int> *aliceRecipDist = 0;
  InvDist<int> *backgroundTrafficDist = 0;
  std::vector<int> *aliceRecipients = 0;
  getCommunicationLinks(aliceRecipDist, backgroundTrafficDist,
                        aliceRecipients,
                        nAR, nR, weightAlice, smallworldAlice);

  truth = *aliceRecipients;
  if (expMsgDist)
    alice = new DistAlice(aliceRecipDist, &aMsgsE, &aDummiesE, padding, pOnline);
  else
    alice = new DistAlice(aliceRecipDist, &aMsgsB, &aDummiesB, padding, pOnline);
  // Look up each probability by recipient id (truth[i]), not by position i,
  // and iterate over the actual recipient list so truth[] is never overrun.
  cout << "bgAlice={";
  for (unsigned int i = 0; i < truth.size(); ++i)
    cout << truth[i] << ":" << backgroundTrafficDist->getP(truth[i]) << ",";
  cout << "}" << endl;
  ConstDist<int> oneDist(1); // ignored.
  background = new DistBackground(*backgroundTrafficDist, oneDist);

  delete aliceRecipDist;
  delete backgroundTrafficDist;
  delete aliceRecipients;

  attacker = new UnkBGBatchAttacker(nR);
  mixnet = new BatchMix(b);

  nRecips = nR;
  nBatch = b;  granularity = g;
  cutoff = c;
}

// Nothing extra to release here; the objects created in init() are
// presumably freed by the BatchTrial base destructor (confirm in trials.h).
UnkBGBatchTrial::~UnkBGBatchTrial() { }

// ======================================================================

// Run one non-batch simulation until the attacker's top guesses match all of
// Alice's true recipients, or until more than `cutoff` rounds have elapsed.
//
// Each round Alice and the background each add traffic (the counts come
// back through addTraffic's second argument) and the mix processes the
// round.  The attacker is consulted every `gran` rounds; `gran` starts at
// `granularity` and doubles (while it is at most 128*granularity) once
// n >= 20*gran.  On cutoff, one last guess is taken only to print how many
// recipients it got, and the result is marked failed.
//
// roundsToGuessN[k] is the first checked round at which at least k of
// Alice's recipients were among the attacker's top truth.size() guesses;
// roundsToGuessN[0] is always 0.
TrialResult
NonbatchTrial::attempt()
{
  TrialResult res;
  int gran = granularity;
  // One leading 0 plus at most truth.size() further entries.
  res.roundsToGuessN.reserve(truth.size()+1);
  res.roundsToGuessN.push_back(0);
  unsigned int nGuessed = 0;

  attacker->reset();
  mixnet->reset();
  alice->reset();

  vec<int> trafficIn(nRecips, 0);
  vec<double> recips(nRecips, 0.0);

  int n = 0;
  for (;;) {
    trafficIn.reset(0);
    // Counts are initialized so they are never read indeterminate.
    int nAlice = 0;
    alice->addTraffic(trafficIn, nAlice);
    if (nAlice) {
      res.nMsgsAlice += nAlice;
      res.nMsgsAliceReal += trafficIn.total(0);
      res.nRoundsAlice += 1;
    }
    int nBackground = 0;
    background->addTraffic(trafficIn, nBackground);
    res.nMsgs += nAlice + nBackground;
    mixnet->addRound(trafficIn, nAlice, nBackground, attacker);
    ++n;

    // Only consult the attacker every `gran` rounds.
    if (n % gran)
      continue;
    if (n>=(gran*20) && gran <= granularity*128)
      gran *= 2;
    if (n > cutoff) {
      attacker->guessAlice(recips);
      res.nRounds = n;
      res.failed = true;
      std::cout << "Got only "<<nSortedVecMatch(recips.topN(truth.size()),truth)
                << " / " << truth.size() << std::endl;
      return res;
    }
    // No usable guess from the attacker this time; keep simulating.
    if (!attacker->guessAlice(recips))
      continue;
#ifndef QUIET
    std::cout << "########" << n << " ------ " << std::endl;
    pvec(std::cout, truth);
    pvec(std::cout, recips.topN(truth.size()));
    std::cout << n << " ------ " << std::endl;
#endif
    unsigned matches = nSortedVecMatch(recips.topN(truth.size()), truth);
    // Record the current round for every newly reached match count.
    while (matches > nGuessed) {
      res.roundsToGuessN.push_back(n);
      ++nGuessed;
    }
    if (matches == truth.size()) {
      res.nRounds = n;
      return res;
    }
  }
}

// Delete the Alice, background, mix and attacker objects held by this
// trial.  The order (which differs from BatchTrial's) is left exactly as
// written in case these destructors depend on one another.
// NOTE(review): attacker may be a POAttacker wrapping the real attacker
// (see MixTrial::init / NymTrial::init); presumably POAttacker deletes the
// wrapped object -- confirm.  These are owning raw pointers, so copying a
// NonbatchTrial would double-delete; confirm copying is disabled.
NonbatchTrial::~NonbatchTrial()
{
  delete alice;
  delete background;
  delete mixnet;
  delete attacker;
}


// Write a one-line description of this MixTrial configuration,
// "MixTrial(N=...,m=...,...,granularity=...)", followed by std::endl.
// XXX Presumably not every spec field is printed yet (an older note read
// "padding XXX all other vars").
void
MixTrialSpec::write(std::ostream &o) const
{
  o << "MixTrial(N=" << nRecipients << ",m=" << nAliceRecipients;
  o << ",l=" << pathLen << ",pDelay=" << pDelay;
  o << ",bgVolume=" << bgVolMean << ",bgVolumeStdDev=" << bgVolDev;
  o << ",pObserve=" << pObserve << ",padding=" << padding;
  o << ",smoothPadding=" << smoothPadding;
  o << ",pOnline=" << pOnline << ",pMessage=" << pMessage;
  o << ",pDummy=" << pDummy;
  o << ",granularity=" << granularity << ")" << std::endl;
}

// Validate this spec with assertFilled() and build a heap-allocated
// MixTrial from it (presumably owned and deleted by the caller).
Trial *
MixTrialSpec::create() const
{
  assertFilled();
  Trial *trial = new MixTrial(*this);
  return trial;
}

// Write a one-line description of this NymTrial configuration,
// "NymTrial(N=...,pA=...,pD=...,l=...,bgVol=...,bgDev=...,pObs=...)",
// followed by std::endl.
void
NymTrialSpec::write(std::ostream &o) const
{
  o << "NymTrial(N=" << nRecipients;
  o << ",pA=" << pMsgAlice << ",pD=" << pDelay << ",l=" << pathLen;
  o << ",bgVol=" << bgVolMean << ",bgDev=" << bgVolDev;
  o << ",pObs=" << pObserve << ")" << std::endl;
}

// Build a heap-allocated NymTrial from this spec (presumably owned and
// deleted by the caller).
// NOTE(review): unlike the other *TrialSpec::create() methods in this file,
// no assertFilled() check is made here; confirm that is intentional.
Trial *
NymTrialSpec::create() const
{
  Trial *trial = new NymTrial(*this);
  return trial;
}

// Tail probability of the delay distribution that may be ignored.
// MixTrial::init and NymTrial::init stop extending maxDelay once the
// accumulated delay probability reaches 1 - DELAY_SLOP, and use it to pick
// the truncation bound passed to getMixNetDelays() for multi-hop paths.
#define DELAY_SLOP 0.0001

void
MixTrial::init(int nR, int nAR, int pathLen, double pDelay,
	       bool expAlice, double pMessage, double pDummy,
	       double bgVolMean, double bgVolDev,
	       double pOnline, int padding, bool smoothPadding,
	       int g, bool partial, double pObserve, int c,
	       bool knownBackground, bool pseudonyms)
{
  assert(pathLen > 0);
  assert(!pseudonyms);

  ////
  // Set up mixnet and attacker.
  InvDist<int> *delayDist;
#if 0
  int maxDelay;
  if (pDelay < 0.00001)
    maxDelay = pathLen+1;
  else
    maxDelay = static_cast<int>(std::log(0.001)/std::log(pDelay))*2+pathLen;
#endif

  if (pathLen == 1) {
    delayDist = getSingleMixDelays(pDelay);
  } else {
    int md;
    if (pDelay < 0.00001)
      md = pathLen + 1;
    else
      md = (int)( (std::log(DELAY_SLOP) / std::log(pDelay))+1)*pathLen;
    delayDist = getMixNetDelays(ConstDist<int>(pathLen), pathLen+1,
                                pDelay, md);
  }
  double totalPDelay = 0.0;
  int maxDelay = 0;
  while (totalPDelay < 1.0-DELAY_SLOP) {
    totalPDelay += delayDist->getP(maxDelay++);
  }
#if 0
  std::cout << "MAXDELAY="<<maxDelay<<std::endl;
  std::cout << "Prob of delay (max-1) is = " << delayDist->getP(maxDelay-1)
	    << std::endl;
#endif

  //// Set up small world.
  InvDist<int> *aliceRecipDist = 0;
  InvDist<int> *backgroundTrafficDist = 0;
  std::vector<int> *aliceRecipients = 0;
  assert(!pseudonyms);
  getCommunicationLinks(aliceRecipDist, backgroundTrafficDist,
                        aliceRecipients,
                        nAR, nR, false, true);

  //// Set up attacker and mixnet
  mixnet = new DelayMix(nR, maxDelay, delayDist);
  attacker = new DelayMixAttacker(nR, maxDelay, delayDist);
  if (knownBackground) {
    vec<double> bg(nR);
    for (int i = 0; i < nR; ++i) {
      bg[i] = backgroundTrafficDist->getP(i);
    }
    ((DelayMixAttacker*)attacker)->setKnownBackground(bg);
  }

  if (partial)
    attacker = new POAttacker(attacker, pObserve);

  delete delayDist;

  //// set up alice.
  GeometricDist aMsgsE(1.0-pMessage);
  GeometricDist aDummiesE(1.0-pDummy);
  BinaryDist<int> aMsgsB(pMessage, 1, 0);
  BinaryDist<int> aDummiesB(pDummy, 1, 0);

  truth = *aliceRecipients;
  if (expAlice)
    alice = new DistAlice(aliceRecipDist,&aMsgsE, &aDummiesE, padding, pOnline,
			  smoothPadding);
  else
    alice = new DistAlice(aliceRecipDist,&aMsgsB, &aDummiesB, padding, pOnline,
			  smoothPadding);
  cout << "bgAlice={";
  for (int i = 0; i < nAR; ++i)
    cout << truth[i] << ":" << backgroundTrafficDist->getP(i) << ",";
  cout << "}" << endl;

  //// set up background.
  background = new DistBackground(*backgroundTrafficDist,
                                  IntNormalDist(bgVolMean, bgVolDev, true));

  cutoff = c;

  delete aliceRecipDist;
  delete backgroundTrafficDist;
  delete aliceRecipients;
}

// Run the shared NonbatchTrial simulation loop, then copy the attacker's
// round counters (nRoundsObserved, nRoundsMaybeAlice) into the result.
TrialResult
MixTrial::attempt()
{
  TrialResult result(NonbatchTrial::attempt());
  attacker->getRoundCounts(result.nRoundsObserved, result.nRoundsMaybeAlice);
  return result;
}

// ======================================================================

// Configure a pseudonym trial from spec s: a delay mixnet (path length
// s.pathLen, delay parameter s.pDelay) whose traffic comes from Alice's
// pseudonym and background pseudonyms supplied by getNymStats().  The
// attacker is a DelayMixAttacker, wrapped in a POAttacker when s.partial is
// set.  Alice's true "recipient set" is the single id 0.
void
NymTrial::init(const NymTrialSpec &s)
{
  // Build the end-to-end delay distribution.
  InvDist<int> *delayDist = 0;
  if (s.pathLen != 1) {
    const int md = (s.pDelay < 0.00001)
      ? s.pathLen + 1
      : (int)((std::log(DELAY_SLOP) / std::log(s.pDelay)) + 1) * s.pathLen;
    delayDist = getMixNetDelays(ConstDist<int>(s.pathLen), s.pathLen+1,
                                s.pDelay, md);
  } else {
    delayDist = getSingleMixDelays(s.pDelay);
  }

  // Smallest number of leading delay values whose probabilities sum to at
  // least 1 - DELAY_SLOP.
  int maxDelay = 0;
  for (double mass = 0.0; mass < 1.0 - DELAY_SLOP; ++maxDelay)
    mass += delayDist->getP(maxDelay);

  //// Alice's pseudonym and the other pseudonyms.
  InvDist<int> *alicePseudDist = 0;
  InvDist<int> *backgroundPseudDist = 0;
  getNymStats(alicePseudDist, backgroundPseudDist, s.nRecipients);

  //// Mixnet and attacker.
  mixnet = new DelayMix(s.nRecipients, maxDelay, delayDist);
  attacker = new DelayMixAttacker(s.nRecipients, maxDelay, delayDist);
  if (s.partial)
    attacker = new POAttacker(attacker, s.pObserve);
  delete delayDist;

  //// Alice: GeometricDist message counts; BinaryDist(0.0, 0, 0) dummies,
  //// presumably meaning she never sends dummies.
  GeometricDist aMsgs(1.0-s.pMsgAlice);
  BinaryDist<int> aDummies(0.0, 0, 0);
  truth = std::vector<int>(1, 0);
  alice = new DistAlice(alicePseudDist, &aMsgs, &aDummies, 0, 1.0, false);

  //// Background.
  background = new DistBackground(*backgroundPseudDist,
                                  IntNormalDist(s.bgVolMean, s.bgVolDev, true));
  cutoff = s.cutoff;

  delete alicePseudDist;
  delete backgroundPseudDist;
}

// Run the shared NonbatchTrial simulation loop, then copy the attacker's
// round counters (nRoundsObserved, nRoundsMaybeAlice) into the result.
TrialResult
NymTrial::attempt()
{
  TrialResult result(NonbatchTrial::attempt());
  attacker->getRoundCounts(result.nRoundsObserved, result.nRoundsMaybeAlice);
  return result;
}

