// Copyright (c) 2003 Nick Mathewson.  See LICENSE for licensing information.
// $Id: sim.cpp 1339 2004-09-12 02:39:32Z nickm $
#include <iostream>
#include "sim.h"
#include "rng.h"

// Empty out-of-line destructors for Alice and Background: neither class
// releases anything at this level.
Alice::~Alice()
{
}

Background::~Background()
{
}

// ======================================================================

// Generate one round of Alice's traffic.
//
// When pSend < 1, Alice is silent this round whenever rng() > pSend
// (presumably rng() is uniform on [0,1) -- see rng.h). In that case n_out = 0.
// Otherwise she draws a count of real messages from nMessageDist and delivers
// each one to a recipient drawn from recipientDist by incrementing v_out.
// n_out receives the number of messages an observer sees her send.
//
// Padding (padding != 0): Alice always appears to send at least `padding`
// messages.
//   - A short round is topped up first from the backlog (nPending) and then
//     with dummies.
//   - A round that exceeds `padding` is sent in full, unless smoothPadding
//     is set. Then only `padding` real messages go out and the excess is
//     queued in nPending for later rounds.
// Without padding, nDummyDist (if set) adds dummy messages on top.
//
// Fix: the smoothPadding excess used to be delivered in the current round
// AND queued, so those messages reached recipients twice.
void
DistAlice::addTraffic(vec<int> &v_out, int &n_out)
{
  if (pSend < 1.0 && rng() > pSend) {
    n_out = 0;
    return;
  }

  // Number of real messages delivered to recipients this round.
  int nReal = nMessageDist->get();
  if (padding) {
    if (nReal < padding && nPending) {
      // Use spare padding slots to drain the backlog.
      int room = padding - nReal;
      int drained = (nPending < room) ? nPending : room;
      nReal += drained;
      nPending -= drained;
    } else if (nReal > padding && smoothPadding) {
      // Send exactly `padding` real messages now and queue the rest
      // (without delivering them).
      nPending += nReal - padding;
      nReal = padding;
    }
  }

  for (int i = 0; i < nReal; ++i) {
    ++ v_out[recipientDist->get()];
  }

  if (padding) {
    // Dummies fill any shortfall; an unsmoothed burst is visible as-is.
    n_out = (nReal > padding) ? nReal : padding;
  } else {
    n_out = nReal;
    if (nDummyDist)
      n_out += nDummyDist->get();
  }
}

// Drop the backlog of real messages that smooth padding has queued.
void
DistAlice::reset()
{
  nPending = 0;  // nothing queued at the start of a trial
}

// A DistAlice whose recipient distribution is a UniformChoiceDist over r
// (presumably uniform over r's entries). mD, dD, p and pOn are forwarded to
// DistAlice unchanged.
UniformAlice::UniformAlice(const std::vector<int> &r,
                           Dist<int> *mD, Dist<int> *dD, int p, double pOn)
  : DistAlice(new UniformChoiceDist<int>(r),
              mD, dD,
              p, pOn)
{
}

// Background traffic of nPR messages per round spread over nR recipients.
UniformBackground::UniformBackground(int nR, int nPR)
  : nRecipients(nR),
    nPerRound(nPR)
{
}

// Add nMessages background messages to v_out, each to a recipient index
// returned by rng(nRecipients). A non-positive count adds nothing.
void
UniformBackground::addNTraffic(vec<int> &v_out, int nMessages)
{
  for (int k = 0; k < nMessages; ++k)
    ++v_out[rng(nRecipients)];
}

// One round of background: always exactly nPerRound messages, reported back
// through nOut.
void
UniformBackground::addTraffic(vec<int> &v_out, int &nOut)
{
  assert(nPerRound >= 0);
  nOut = nPerRound;
  addNTraffic(v_out, nOut);
}

// Construct a batch mix that records a batch size of b.
BatchMix::BatchMix(int b)
  : Mixnet(),
    batchSize(b)
{
}

// BatchMix::addRound forwards each round straight to the attacker and keeps
// nothing between rounds, so there is nothing to clear here.
void
BatchMix::reset()
{
  // intentionally empty
}

void
BatchMix::addRound(const vec<int> &input, int nAlice, int nOther,
                   FwdAttacker *a)
{
  a->addRound(nAlice, nOther, input);
}

// Statistical-disclosure attacker with a known background distribution bg.
// Starts with zero counters and an all-zero observation vector sized like bg.
SDAttacker::SDAttacker(vec<double> &bg)
  : nRounds(0),
    nAlice(0),
    nOther(0),
    background(bg),
    observed(bg.size(), 0)
{
}

// Forget everything observed so far; the background is kept.
void
SDAttacker::reset()
{
  observed.reset(0);
  nRounds = nAlice = nOther = 0;
}

// Record one round. Every round is counted in nRounds, but the observations
// and sender totals are accumulated only for rounds in which Alice sent.
void
SDAttacker::addRound(int nA, int nO, const vec<int> &rcvd)
{
  ++nRounds;
  if (!nA)
    return;

  observed += rcvd;
  nAlice += nA;
  nOther += nO;
}

// Estimate Alice's recipient distribution into r.
// The estimate is (observed - background * nOther) / nAlice. Returns false,
// leaving r untouched, until some round with Alice traffic has been recorded.
bool
SDAttacker::guessAlice(vec<double> &r)
{
  if (nAlice == 0)
    return false;

  vec<double> estimate(observed);
  estimate -= background * nOther;  // remove expected background traffic
  estimate /= nAlice;               // normalise per Alice message
  r = estimate;
  return true;
}

// ======================================================================
// Unknown background

// Add nMessages background messages, each to a recipient drawn from
// recipientDist. A non-positive count adds nothing.
void
DistBackground::addNTraffic(vec<int> &v_out, int nMessages)
{
  // XXX rng bottleneck: one draw per message.
  for (int k = 0; k < nMessages; ++k) {
    int recip = recipientDist->get();
    ++v_out[recip];
  }
}

// One round of background whose size is drawn from nMessages. The drawn
// count is reported back through nOut.
void
DistBackground::addTraffic(vec<int> &v_out, int &nOut)
{
  const int count = nMessages->get();
  addNTraffic(v_out, count);
  nOut = count;
}

// Attacker for nR recipients that learns the background from rounds in which
// Alice is silent. All counters start cleared.
UnkBGBatchAttacker::UnkBGBatchAttacker(int nR)
  : vObservations(nR, 0),
    uObservations(nR, 0)
{
  reset();
}

// Clear both observation vectors and all message counters.
void
UnkBGBatchAttacker::reset()
{
  vObservations.reset(0);
  uObservations.reset(0);
  nBg = 0;
  nOther = 0;
  nAlice = 0;
}

// Sort each round by whether Alice sent.
//  - Rounds where she sent accumulate into vObservations, nAlice and nOther.
//  - Rounds where she was silent are pure background samples and go into
//    uObservations and nBg.
void
UnkBGBatchAttacker::addRound(int nA, int nO, const vec<int> &nReceived)
{
  if (nA == 0) {
    uObservations += nReceived;
    nBg += nO;
    return;
  }

  vObservations += nReceived;
  nAlice += nA;
  nOther += nO;
}

// Estimate Alice's recipient distribution into r when the background is
// unknown.
//
// Alice-rounds (vObservations) mix nAlice Alice messages with nOther
// background messages. Silent rounds (uObservations, nBg messages in total)
// sample the background alone. The background sample is rescaled to nOther
// messages and subtracted, and the remainder is divided by nAlice.
// Returns false, leaving r untouched, until both kinds of round have been seen.
bool
UnkBGBatchAttacker::guessAlice(vec<double> &r)
{
  if (!nBg || !nAlice)
    return false;
  vec<double> u(uObservations);
  assert(u.total(0.0) == nBg);
  // Divide in floating point. With integer counters, nOther / nBg truncated
  // (often to 0), so the background was under-subtracted or ignored.
  u *= (static_cast<double>(nOther) / nBg);
  r = vec<double>(vObservations);
  r -= u;
  r /= nAlice;
  return true;
}

// ======================================================================

// A delaying mix. It keeps its own copy of the delay distribution d and a
// ring of mD per-recipient pools, each heap-allocated with nR zero counts.
DelayMix::DelayMix(int nR, int mD,
                   Dist<int> *d)
  : maxDelay(mD),
    poolIdx(0),
    pools(mD),
    delayDist(d->copy())
{
  for (int slot = 0; slot < mD; ++slot)
    pools[slot] = new vec<int>(nR, 0);
}

// Empty every pool and rewind the ring to slot 0.
void
DelayMix::reset()
{
  for (int slot = 0; slot < maxDelay; ++slot)
    pools[slot]->reset(0);
  poolIdx = 0;
}

// Free the delay-distribution copy and the pools allocated in the constructor.
DelayMix::~DelayMix()
{
  delete delayDist;
  for (int slot = 0; slot < maxDelay; ++slot)
    delete pools[slot];
}

void
DelayMix::addRound(const vec<int> &inp, int nA, int nO,
                   FwdAttacker *a)
{
  int sz = inp.size();
  for (int i = 0; i < sz; ++i) {
    int totalForRecip = inp[i];
    for (int j = 0; j < totalForRecip; ++j) {
      // XXX rng bottleneck
      ++( (*pools[(poolIdx+getDelay())%maxDelay]) [i]);
    }
  }
  a->addRound(nA, nO, *pools[poolIdx]);
  pools[poolIdx]->reset(0);
  ++poolIdx;
  if (poolIdx >= maxDelay)
    poolIdx = 0;
}

// Attacker against a DelayMix with nR recipients.
//  - It keeps circular sender-count histories of length mD.
//  - It keeps its own copy of the delay distribution dDist, which must copy()
//    to an InvDist.
//  - It starts with no known background.
DelayMixAttacker::DelayMixAttacker(int nR, int mD, InvDist<int> *dDist)
  : nRecips(nR), maxDelay(mD), nAliceIdx(0), nAliceHist(mD, 0),
    nOtherHist(mD, 0),
    knownBackground(0),
    background(nR, 0.0),
    observed(nR, 0.0),
    nObservedOther(0.0), nObservedAlice(0.0),
    delayDist(dynamic_cast<InvDist<int>*>(dDist->copy())),
    nRoundsObserved(0), nRoundsMaybeAlice(0)
{
  // addRound() accumulates into these with +=. Previously only reset() zeroed
  // them, so a freshly constructed attacker started from indeterminate values.
  exOtherInBackground = exAliceInBackground = 0.0;
  // A failed dynamic_cast would otherwise surface later as a null dereference.
  assert(delayDist);
  // std::cout << "DMA : maxDelay = " << maxDelay << std::endl;
}

// Install a heap copy of bg as the known background distribution. While it is
// set, addRound() treats every round as a possible Alice round and
// guessAlice() subtracts this vector instead of a learned one.
//
// Any previously installed copy is freed; repeated calls used to leak it. The
// new copy is built before the old one is deleted, so passing the currently
// installed vector back in is still safe.
void
DelayMixAttacker::setKnownBackground(vec<double> &bg)
{
  // std::cerr << "set known background" << std::endl;
  vec<double> *replacement = new vec<double>(bg);
  delete knownBackground;  // no-op while still null
  knownBackground = replacement;
}

// Expected number of Alice's messages leaving the mix this round. It is the
// sum over d in [0, maxDelay) of delayDist->getP(d) * aHist(d).
// aHist(d) is presumably Alice's send count from d rounds ago; it is defined
// in the header, so confirm there.
double
DelayMixAttacker::expectedAliceMsgs()
{
  double expected = 0.0;
  for (int d = 0; d < maxDelay; ++d)
    expected += delayDist->getP(d) * aHist(d);
  return expected;
}

// Expected number of other senders' messages leaving the mix this round. It is
// the sum over d in [0, maxDelay) of delayDist->getP(d) * oHist(d).
// oHist is defined in the header; presumably it mirrors aHist for non-Alice
// senders.
double
DelayMixAttacker::expectedOtherMsgs()
{
  double expected = 0.0;
  for (int d = 0; d < maxDelay; ++d)
    expected += delayDist->getP(d) * oHist(d);
  return expected;
}

void
DelayMixAttacker::reset()
{
  nAliceIdx = 0;
  nAliceHist = std::vector<int>(nRecips, 0);
  nOtherHist = std::vector<int>(nRecips, 0);
  background.reset(0.0);
  observed.reset(0.0);
  nObservedOther = nObservedAlice = 0.0;
  nRoundsObserved = nRoundsMaybeAlice = 0;
  exOtherInBackground = exAliceInBackground = 0.0;
}

//XXXX make this configurable.  Low values really seem to hurt.
#define BG_THRESHOLD 1.0
// Record one output round r of the mix.
//
// This round's sender counts go into the circular history. The expected
// numbers of Alice/other messages in the output are then computed from that
// history and the delay distribution.
//  - Background round: when no known background is set and fewer than
//    BG_THRESHOLD Alice messages are expected, r is added to the background
//    estimate.
//  - Otherwise: r is added to `observed`, weighted by the expected fraction
//    of Alice messages.
//
// Fix: when both expectations are 0 (possible with a known background, e.g.
// before any traffic has had time to leave the mix), the weight was 0/0. The
// resulting NaN permanently poisoned `observed`; such rounds now contribute
// nothing to it.
void
DelayMixAttacker::addRound(int nAlice, int nOther,
                           const vec<int> &r)
{
  nAliceHist[nAliceIdx] = nAlice;
  nOtherHist[nAliceIdx] = nOther;

  double exAlice = expectedAliceMsgs();
  double exOther = expectedOtherMsgs();
  ++nRoundsObserved;
  // std::cerr << exAlice << std::endl;
  if (exAlice < BG_THRESHOLD && !knownBackground) {
    for (int i = 0; i < nRecips; ++i)
      background[i] += r[i];//*(exOther+exAlice);
    exOtherInBackground += exOther;
    exAliceInBackground += exAlice;
  } else {
    ++nRoundsMaybeAlice;
    double exTotal = exOther + exAlice;
    if (exTotal > 0.0) {
      double aliceShare = exAlice / exTotal;
      for (int i = 0; i < nRecips; ++i)
        observed[i] += r[i]*aliceShare;
    }
    nObservedOther += exOther;
    nObservedAlice += exAlice;
  }

  ++nAliceIdx;
  if (nAliceIdx >= maxDelay)
    nAliceIdx = 0;
}

// Estimate Alice's recipient distribution into res.
//
// The estimate is (observed - scaled background) / nObservedAlice.
//  - With a known background, that distribution is scaled by the total
//    weighted observation mass.
//  - Otherwise, the learned background is scaled by that mass divided by
//    exOtherInBackground.
// Returns false, leaving res untouched, when no Alice messages are expected
// yet, or (learned-background case only) when no background has been
// collected. Unless QUIET is defined, diagnostics are printed to std::cout.
bool
DelayMixAttacker::guessAlice(vec<double> &res)
{
  const double uTotal = background.total(0.0);
  const double oTotal = observed.total(0.0);
#ifndef QUIET // XXXX
  std::cout << exOtherInBackground << "++++" << exAliceInBackground
            << "  " << static_cast<double>(exAliceInBackground)/exOtherInBackground
            << "(" << uTotal << ")" <<std::endl;
  std::cout << nObservedOther << "++++" << nObservedAlice
            << "  " << static_cast<double>(nObservedAlice)/nObservedOther
            << "(" << oTotal << ")" << std::endl;
  std::cout << observed << std::endl;
  std::cout << background << std::endl;
#endif
  if (!nObservedAlice)
    return false;
  if (!knownBackground && !uTotal)
    return false;

  if (knownBackground) {
    vec<double> expectedNoise(*knownBackground);
    expectedNoise *= oTotal;
    res = observed - expectedNoise;
  } else {
    vec<double> expectedNoise(background);
    expectedNoise *= (oTotal / exOtherInBackground);
    res = observed - expectedNoise;
  }
  res /= nObservedAlice;
#ifndef QUIET
  std::cout << res << std::endl << std::endl;
#endif
  return true;
}

// ======================================================================

void
POAttacker::addRound(int nAlice, int nOther, const vec<int> &rcvd)
{
  int nA = 0, nO = 0;
  int nR = rcvd.size();
  vec<int> r(nR, 0);
  for (int i = 0; i < nAlice; ++i) {
    if (rng() < pObserve) ++nA;
  }
  for (int i = 0; i < nOther; ++i) {
    if (rng() < pObserve) ++nO;
  }
  for (int i = 0; i < nR; ++i) {
    for (int j = 0; j < rcvd[i]; ++j) {
      if (rng() < pObserve) ++r[i];
    }
  }
  /*
  std::cout << nAlice << " -> " << nA << ", "
            << nOther << " -> " << nO << ", "
            << rcvd << " -> " << r << std::endl;
  */
  base->addRound(nA, nO, r);
}

