// Copyright (c) 2003 Nick Mathewson.  See LICENSE for licensing information.
// $Id: sim.h 1223 2004-04-11 01:17:01Z nickm $
// sim.h -- Simulation classes
#ifndef _SIM_H
#define _SIM_H

#include "vec.h"
#include "rng.h"

//  Message flow:
//
//   Background -->
//                 --> Mixnet ---> FwdAttacker
//        Alice -->


// A target sender whose recipients the attacker is trying to identify.
class Alice {
public:
  Alice() {}

  // Produce one round of traffic from the target sender: append its
  // messages to v_add and store in nAdd how many messages were sent.
  virtual void addTraffic(vec<int> &v_add, int &nAdd) = 0;

  // Restore the sender's initial state.  The base version is a no-op.
  virtual void reset() {}

  virtual ~Alice();
};

// Source of background traffic masking the target sender.
class Background {
public:
  Background() {}

  // Append exactly nMessages background messages to v_out.
  virtual void addNTraffic(vec<int> &v_out, int nMessages) = 0;

  // Append one round of background messages to v_out, with the round size
  // drawn from this source's default distribution; nOut receives the
  // number of messages that were appended.
  virtual void addTraffic(vec<int> &v_out, int &nOut) = 0;

  virtual ~Background();
};

// An attack algorithm: attempts to learn Alice's recipients.
class FwdAttacker {
 public:
  FwdAttacker() {}

  // Clear the attacker's memory.
  virtual void reset() = 0;
  // Add a round to the attacker's observations.  The attacker sees that
  // Alice has sent nSentAlice messages, that other users have sent nSentOther
  // messages, and that each recipient i has received nReceived[i] messages.
  virtual void addRound(int nSentAlice, int nSentOther,
                        const vec<int> &nReceived) = 0;
  // Guess Alice's recipients, based on past observations.  Set recipients[i]
  // to our estimate of the probability that a message from Alice is to
  // recipient i.
  virtual bool guessAlice(vec<double> &recipients) = 0;
  // Report the number of rounds observed (nObs) and the number of those
  // rounds attributed to Alice (nAlice).  The base version reports zero for
  // both; attackers that track these counts override it.
  //
  // This must be virtual: wrappers such as POAttacker call it through a
  // FwdAttacker pointer, and a non-virtual base version would always answer
  // 0/0 instead of the wrapped attacker's real counts.
  virtual void getRoundCounts(int &nObs, int &nAlice) {
    nObs = nAlice = 0;
  }

  virtual ~FwdAttacker() {}
};

// A mixnet: distributes messages to their recipients (and to the attacker)
class Mixnet {
public:
  Mixnet() {}

  // Forget everything the mixnet is currently holding.
  virtual void reset() = 0;

  // Push one round through the mixnet.  The round contains nAlice messages
  // from Alice and nBackground messages from other users; input[i] is the
  // number of messages in the round addressed to recipient i.  What the
  // mixnet delivers is also reported to the observing attacker r.
  virtual void addRound(const vec<int> &input,
                        int nAlice, int nBackground,
                        FwdAttacker *r) = 0;

  virtual ~Mixnet() {}
};


// ======================================================================
// Original statistical disclosure attack

// Sender that chooses recipients, messages, and dummies through random
// distributions.  Alternatively, if padding is set, then whenever we send
// we send at least padding messages and ignore the dummy distribution.  The
// sender considers whether to participate in a round with probability pSend.
class DistAlice : public Alice {
 protected:
  // Owned copies of the distributions used to choose the number of
  // messages, the number of dummies, and each message's recipient.  They
  // are deleted in the destructor.
  // NOTE(review): the class has no copy constructor/assignment, so copying
  // a DistAlice would delete these pointers twice -- do not copy.
  Dist<int> *nMessageDist;
  Dist<int> *nDummyDist;
  Dist<int> *recipientDist;
  // NOTE(review): presumably messages queued for later rounds; maintained
  // by addTraffic()/reset() -- confirm against the implementation.
  int nPending;
  // if padding is >0, and we send any messages, we instead send >=padding
  int padding;
  // Probability with which the sender considers participating in a round.
  double pSend;
  // if smoothPadding, we never send > padding.
  bool smoothPadding;
  // Protected constructor for subclasses that set up the distributions
  // themselves.  Every scalar gets the same default as the public
  // constructor (pSend=1.0, smoothPadding=false), so that no member is left
  // uninitialized.
  DistAlice()
    : nMessageDist(0), nDummyDist(0), recipientDist(0), nPending(0),
      padding(0), pSend(1.0), smoothPadding(false) {}
 public:
  // Build a sender from private copies of the given distributions.
  //   recips   - distribution of recipients
  //   msgs     - distribution of the number of messages
  //   dummies  - distribution of the number of dummy messages
  //   p        - padding (see above); 0 disables it
  //   pSendAny - probability of considering participation in a round
  //   sm       - smoothPadding flag
  DistAlice(Dist<int> *recips, Dist<int> *msgs, Dist<int> *dummies,
            int p=0, double pSendAny=1.0, bool sm=false)
    : nMessageDist(msgs->copy()), nDummyDist(dummies->copy()),
      recipientDist(recips->copy()), nPending(0), padding(p),
      pSend(pSendAny), smoothPadding(sm)
  {}
  void addTraffic(vec<int> &v_out, int &nOut);
  void reset();
  ~DistAlice() { delete nMessageDist; delete nDummyDist; delete recipientDist;}
};


// Sender that chooses recipients uniformly from a list
class UniformAlice : public DistAlice {
 public:
  // r lists the candidate recipients; the remaining arguments supply the
  // message-count and dummy-count distributions, the padding, and the
  // per-round online probability.  Defined out of line.
  UniformAlice(const std::vector<int> &r,
               Dist<int> *msgDist,
               Dist<int> *dummyDist,
               int padding = 0,
               double pOnline = 1.0);

  ~UniformAlice() {}
};

// Background that sends messages to recipients with uniform probability
class UniformBackground : public Background {
 private:
  // Number of possible recipients (presumably the range recipients are
  // drawn from -- confirm in the .cpp).
  int nRecipients;
  // Messages per round; NOTE(review): the meaning of the default -1 passed
  // by the constructor is decided in the implementation -- confirm.
  int nPerRound;

 public:
  UniformBackground(int nR, int nPR = -1);
  ~UniformBackground() {}

  void addNTraffic(vec<int> &v_out, int nMessages);
  void addTraffic(vec<int> &v_out, int &nOut);
};

// Mix that accumulates B messages and then relays them all.
class BatchMix : public Mixnet {
 private:
  int batchSize;   // presumably the B messages accumulated before a flush

 public:
  BatchMix(int b);
  ~BatchMix() {}

  void reset();
  void addRound(const vec<int> &input,
                int nAlice,
                int nBackground,
                FwdAttacker *f);
};

// Implements the original statistical disclosure attack
class SDAttacker : public FwdAttacker {
 private:
  // Running tallies across observed rounds.
  // NOTE(review): presumably rounds seen and Alice/other message totals --
  // confirm against addRound() in the implementation.
  int nRounds;
  int nAlice;
  int nOther;
  // Background distribution handed to the constructor.
  vec<double> background;
  // Accumulated per-recipient observations (presumably receive counts).
  vec<int> observed;

 public:
  SDAttacker(vec<double> &background);
  ~SDAttacker() {}

  void reset();
  void addRound(int nSentAlice, int nSentOther,
                const vec<int> &nReceived);
  bool guessAlice(vec<double> &recipients);
};

// ======================================================================
// Attack with unknown background, unknown sender behavior.

class DistBackground : public Background {
 private:
  // Owned copies (made in the constructor, deleted in the destructor) of the
  // recipient distribution and the messages-per-round distribution.
  // NOTE(review): copying a DistBackground would double-delete these.
  Dist<int> *recipientDist;
  Dist<int> *nMessages;

 public:
  DistBackground(const Dist<int> &d, const Dist<int> &nMsgs)
    : recipientDist(d.copy()),
      nMessages(nMsgs.copy())
  {}

  ~DistBackground() {
    delete recipientDist;
    delete nMessages;
  }

  void addNTraffic(vec<int> &v_out, int nMessages);
  void addTraffic(vec<int> &v_out, int &nOut);
};

class UnkBGBatchAttacker : public FwdAttacker {
 private:
  // NOTE(review): presumably the "v" observations come from rounds in which
  // Alice sent, and the "u" observations from background-only rounds
  // (nBg of them) -- confirm against addRound().
  int nAlice;
  int nOther;
  vec<int> vObservations;
  int nBg;
  vec<int> uObservations;

 public:
  UnkBGBatchAttacker(int nRecips);
  ~UnkBGBatchAttacker() {}

  void reset();
  void addRound(int nSentAlice, int nSentOther, const vec<int> &nReceived);
  bool guessAlice(vec<double> &nRecipients);

  // NOTE(review): traffic-generator style methods on an attacker; these
  // were once marked protected (the access specifier is commented out).
  // Kept public to preserve the existing interface.
  void addNTraffic(vec<int> &v_out, int nMessages);
  void addTraffic(vec<int> &v_out, int &nOut);
};

// ======================================================================
// Timed mixes and mix-nets.

class DelayMix : public Mixnet
{
private:
  int maxDelay;
  // NOTE(review): presumably a ring of per-delay message pools with poolIdx
  // marking the current slot -- confirm in the implementation.
  int poolIdx;
  std::vector<vec<int> *> pools;
  Dist<int> *delayDist;

protected:
  // Sample a delay from delayDist.  Any sample at or beyond maxDelay is
  // clamped to maxDelay-1.
  int getDelay() {
    int delay = delayDist->get();
    if (delay >= maxDelay)
      delay = maxDelay - 1;
    return delay;
  }

public:
  DelayMix(int nRecips, int maxDelay, Dist<int> *delayDist);
  ~DelayMix();

  void reset();
  void addRound(const vec<int> &input,
                int nAlice,
                int background,
                FwdAttacker *a);
};

class DelayMixAttacker : public FwdAttacker
{
private:
  int nRecips;
  int maxDelay;

  // Circular per-round history of Alice's and others' sending counts; both
  // rings share the current position nAliceIdx.
  int nAliceIdx;
  std::vector<int> nAliceHist;
  std::vector<int> nOtherHist;

  // Owned (deleted in the destructor); set via setKnownBackground().
  vec<double> *knownBackground;

  vec<double> background;
  vec<double> observed;
  double nObservedOther;
  double nObservedAlice;

  double exOtherInBackground;
  double exAliceInBackground;

  // Owned (deleted in the destructor).
  InvDist<int> *delayDist;

  // for TrialResult
  int nRoundsObserved;
  int nRoundsMaybeAlice;

  // Alice's send count from rds rounds ago, or 0 once rds reaches maxDelay
  // (older rounds are outside the history window).
  int aHist(int rds) {
    if (rds >= maxDelay)
      return 0;
    return nAliceHist[(maxDelay + nAliceIdx - rds) % maxDelay];
  }
  // Same lookup for the other senders' history.
  int oHist(int rds) {
    if (rds >= maxDelay)
      return 0;
    return nOtherHist[(maxDelay + nAliceIdx - rds) % maxDelay];
  }

  double expectedAliceMsgs();
  double expectedOtherMsgs();

public:
  DelayMixAttacker(int nRecips, int maxDelay, InvDist<int> *delayDist);
  void setKnownBackground(vec<double> &bg);
  void reset();
  void addRound(int nSentAlice, int nSentOther, const vec<int> &nReceived);
  bool guessAlice(vec<double> &nRecipients);
  // Report observed rounds and rounds possibly involving Alice.
  void getRoundCounts(int &nObs, int &nAlice) {
    nObs = nRoundsObserved;
    nAlice = nRoundsMaybeAlice;
  }
  ~DelayMixAttacker() {
    delete delayDist;
    delete knownBackground;
  }
};


// Partially observant attacker.
class POAttacker : public FwdAttacker
{
 private:
  // Wrapped attacker.  Not deleted by ~POAttacker, so the caller keeps
  // ownership.
  FwdAttacker *base;
  // NOTE(review): presumably the probability that a given round is
  // observed -- confirm in addRound().
  double pObserve;

 public:
  POAttacker(FwdAttacker *baseA, double p)
    : base(baseA),
      pObserve(p)
  {}

  void reset() {
    base->reset();
  }

  void addRound(int nAlice, int nOther, const vec<int> &rcvd);

  bool guessAlice(vec<double> &guess) {
    return base->guessAlice(guess);
  }

  // Forwards through a FwdAttacker pointer.
  // NOTE(review): this reaches the wrapped attacker's own counts only if
  // FwdAttacker::getRoundCounts is virtual.
  void getRoundCounts(int &r, int &ra) {
    base->getRoundCounts(r, ra);
  }

  ~POAttacker() {}
};


#endif /* _SIM_H */

