// Copyright (c) 2003 Nick Mathewson.  See LICENSE for licensing information.
// $Id: trials.h 1339 2004-09-12 02:39:32Z nickm $
#ifndef _TRIALS_H
#define _TRIALS_H

#include <assert.h>
#include <iostream>
#include <vector>
#include "sim.h"
#include "rng.h"

class Trial;

// Abstract description of one kind of trial.  Each concrete spec can print
// itself to a stream and can manufacture a Trial configured from its own
// parameters.
class TrialSpec
{
 public:
  virtual ~TrialSpec() {}
  TrialSpec() {}

  // Print a human-readable description of this spec to 'o'.
  virtual void write(std::ostream &o) const = 0;

  // Build a new Trial from this spec.  NOTE(review): ownership of the
  // returned pointer is not stated here; presumably the caller deletes it.
  virtual Trial *create() const = 0;
};

// Counters and flags describing the outcome of one Trial::attempt().
// NOTE(review): the exact meaning of each counter is defined by the code
// that fills them in (not visible in this header); names are the only hint.
class TrialResult
{
 public:
  // All counters start at zero, the result vector starts empty, and the
  // trial is not marked as failed.
  TrialResult()
    : nRounds(0), nRoundsAlice(0),
      nMsgs(0L), nMsgsAlice(0L), nMsgsAliceReal(0L),
      roundsToGuessN(),
      failed(false),
      nRoundsObserved(0), nRoundsMaybeAlice(0)
  {}

  int nRounds;
  int nRoundsAlice;
  long nMsgs;
  long nMsgsAlice;
  long nMsgsAliceReal;
  std::vector<int> roundsToGuessN;

  bool failed;

  // only interesting for DelayMix attacks
  int nRoundsObserved;
  int nRoundsMaybeAlice;

  // Print this result to 'o' (defined outside this header).
  virtual void write(std::ostream &o) const;
  virtual ~TrialResult() {}
};


// Stream insertion for TrialResult.  Only declared here; the definition is
// not part of this header.
std::ostream &operator<<(std::ostream &out, const TrialResult &r);

// Interface for a single run of an attack simulation.
class Trial {
 public:
  virtual ~Trial() {}
  Trial() {}

  // Execute the trial and return what happened.
  virtual TrialResult attempt() = 0;
};

// Trial with a batch mix.  Not constructed directly (its constructor is
// protected); subclasses such as SDTrial and UnkBGBatchTrial derive from it
// and, via their own init() methods, presumably fill in the fields below.
class BatchTrial : public Trial {
 protected:
  Alice *alice;           // What does alice do?
  Background *background; // What do the other senders do?
  BatchMix *mixnet;       // How does the mix behave?
  FwdAttacker *attacker;  // How does the attacker analyze the network?
  std::vector<int> truth; // What does the attacker guess to win?
  int nRecips;            // How many recipients are there in total?
  int nBatch;             // How large is the batch?
  int granularity;        // How often do we check whether the attacker has won?
  int cutoff;

  // Null pointers and zeroed counts, so that no member is left with an
  // indeterminate value if a subclass's setup leaves it untouched.
  // Initializers are listed in declaration order.
  // NOTE(review): the pointers are raw; if ~BatchTrial (defined elsewhere)
  // deletes them, copying a BatchTrial would double-free -- confirm.
  BatchTrial()
    : alice(0), background(0), mixnet(0), attacker(0), truth(),
      nRecips(0), nBatch(0), granularity(0), cutoff(0) {}
 public:
  TrialResult attempt();
  virtual ~BatchTrial();
};

// Parameters for an SDTrial.  The required counts start at 0 and must be
// set before use (assertFilled() checks them); granularity and cutoff have
// defaults.  Setters return *this so they can be chained.
class SDTrialSpec : public TrialSpec
{
  friend class SDTrial;
 protected:
  int nRecipients;
  int nAliceRecipients;
  int batchSize;
  int granularity;
  int cutoff;
 public:
  SDTrialSpec()
    : nRecipients(0), nAliceRecipients(0), batchSize(0),
      granularity(5), cutoff(1000000000)
  {}
  SDTrialSpec &setNRecipients(int n) { nRecipients = n; return *this; }
  SDTrialSpec &setAliceRecipients(int n) { nAliceRecipients = n; return *this; }
  // Alias for setAliceRecipients, spelled like the setNAliceRecipients
  // setters on the other *TrialSpec classes.
  SDTrialSpec &setNAliceRecipients(int n) { return setAliceRecipients(n); }
  SDTrialSpec &setBatchSize(int n) { batchSize = n; return *this; }
  SDTrialSpec &setGranularity(int n) { granularity = n; return *this; }
  SDTrialSpec &setCutoff(int n) { cutoff = n; return *this; }
  // Abort (unless NDEBUG is defined) if a required parameter was not set.
  void assertFilled() const {
    assert(nRecipients > 0);
    assert(nAliceRecipients > 0);
    assert(batchSize > 0);
  }
  void write(std::ostream &o) const;
  Trial *create() const;
};

// ========================================
// Simple statistical disclosure attack
//    Batch mix, alice sends 1 message per batch
//    Background distribution is known
//    Alice has fixed list of recipients, chooses with equal probability.
class SDTrial : public BatchTrial {
 protected:
  void init(int nRecipients, int nAliceRecipients, int batchSize,
	    int granularity, int cutoff);
 public:
  SDTrial(const SDTrialSpec &s) {
    s.assertFilled();
    init(s.nRecipients, s.nAliceRecipients, s.batchSize, s.granularity,
	 s.cutoff);
  }
  ~SDTrial();
};

// ========================================
// Generalized statistical disclosure attack against a batch mix.
//    Batch mix, alice either:
//         - sends 1 or 0 messages per batch.
//         - follows smallworld distribution.
//    Background distribution is unknown smallworld instance.

// Parameters for an UnkBGBatchTrial.  Counts start at 0 and probabilities
// at -1; assertFilled() requires them to be set, along with the two booleans
// tracked by ais_set / emd_set.  Setters return *this for chaining.
class UnkBGBatchTrialSpec : public TrialSpec
{
  friend class UnkBGBatchTrial;
 protected:
  int nRecipients, nAliceRecipients, batchSize;
  bool aliceIsSmallworld, expMsgDist, weightAlice, ais_set, emd_set;
  double pMsgAlice, pDummyAlice, pOnline;
  int paddingLevel, granularity;
  int cutoff;
 public:
  UnkBGBatchTrialSpec() {
    paddingLevel = 0; granularity = 5; cutoff = 1000000000; pOnline=1.0;
    nRecipients = nAliceRecipients = batchSize = 0;
    pMsgAlice = pDummyAlice = -1;
    // ais_set / emd_set record whether these were explicitly chosen, but the
    // values themselves still need to be defined: assertFilled() is a no-op
    // under NDEBUG, and UnkBGBatchTrial reads both fields unconditionally.
    aliceIsSmallworld = expMsgDist = false;
    ais_set = emd_set = false;
    weightAlice = false;
  }
  UnkBGBatchTrialSpec &setNRecipients(int n) { nRecipients = n; return *this; }
  UnkBGBatchTrialSpec &setNAliceRecipients(int n) { nAliceRecipients = n; return *this; }
  UnkBGBatchTrialSpec &setBatchSize(int n) { batchSize = n; return *this; }
  UnkBGBatchTrialSpec &setPaddingLevel(int n) { paddingLevel = n; return *this; }
  UnkBGBatchTrialSpec &setGranularity(int n) { granularity = n; return *this; }
  UnkBGBatchTrialSpec &setCutoff(int n) { cutoff = n; return *this; }

  UnkBGBatchTrialSpec &setAliceIsSmallworld(bool b) {
    aliceIsSmallworld = b; ais_set = true; return *this; }
  UnkBGBatchTrialSpec &setExpMsgDist(bool b) {
    expMsgDist = b; emd_set = true; return *this; }
  UnkBGBatchTrialSpec &setWeightAlice(bool b) { weightAlice = b; return *this;}

  UnkBGBatchTrialSpec &setPMsgAlice(double p) { pMsgAlice = p; return *this; }
  UnkBGBatchTrialSpec &setPDummyAlice(double p) { pDummyAlice = p; return *this; }
  // pOnline is forwarded to UnkBGBatchTrial::init(); it defaults to 1.0.
  UnkBGBatchTrialSpec &setPOnline(double p) { pOnline = p; return *this; }

  // Abort (unless NDEBUG is defined) if a required parameter was not set.
  void assertFilled() const {
    assert(nRecipients && nAliceRecipients && batchSize);
    assert(pMsgAlice >= 0.0 && pDummyAlice >= 0.0);
    assert(ais_set && emd_set);
  }
  void write(std::ostream &o) const;
  Trial *create() const;
};

class UnkBGBatchTrial : public BatchTrial {
 protected:
  void init(int nRecipients, int nAliceRecipients,
	    int batchSize, bool aliceIsSmallworld,
	    bool expMsgDist, // if true, alice sends N msgs on exp. dist
	    double pMsgAlice, double pDummyAlice,
	    int paddingLevel, // if true, and if alice sends dummies in a round, she sends fillDummies messages total.
	    double pOnline,
	    int granularity, int cutoff, bool weightAlice);
 public:
  UnkBGBatchTrial(const UnkBGBatchTrialSpec &s) {
    s.assertFilled();
    init(s.nRecipients, s.nAliceRecipients, s.batchSize,
	 s.aliceIsSmallworld, s.expMsgDist, s.pMsgAlice, s.pDummyAlice,
	 s.paddingLevel, s.pOnline, s.granularity, s.cutoff,
	 s.weightAlice);
  }

  ~UnkBGBatchTrial();
};

// Shared code for case when background decides how many messages to send
// per round.
class NonbatchTrial : public Trial {
 private:
  // Private, so subclasses must go through the (nR, g) constructor below.
  NonbatchTrial();
 protected:
  Alice *alice;
  Background *background;
  Mixnet *mixnet;
  FwdAttacker *attacker;
  std::vector<int> truth;
  int nRecips;
  int granularity;
  int cutoff;

  // 'truth' is sized to nR entries.  cutoff is zeroed so it is never
  // indeterminate; subclass init() methods are presumably expected to
  // overwrite it (MixTrial's receives the spec's cutoff).
  NonbatchTrial(int nR, int g) : alice(0), background(0), mixnet(0),
    attacker(0), truth(nR), nRecips(nR), granularity(g), cutoff(0) {}
 public:
  TrialResult attempt();
  virtual ~NonbatchTrial();
};

// Parameters for a MixTrial.  Counts start at 0 and the probability /
// background-volume parameters at -1; assertFilled() requires them to be
// set, along with the flags tracked by ea_set and p_set.  Setters return
// *this for chaining.
class MixTrialSpec : public TrialSpec {
  friend class MixTrial;
 protected:
  int nRecipients, nAliceRecipients, pathLen, padding, granularity, cutoff;
  bool expAlice, partial, ea_set, p_set, smoothPadding, knownBackground;
  bool pseudonyms;
  double pOnline, pDelay, pMessage, pDummy, bgVolMean, bgVolDev, pObserve;
 public:
  MixTrialSpec() {
    padding = 0; granularity = 5; partial=false; pObserve=1.0; pOnline=1.0;
    cutoff = 1000000000;

    nRecipients = nAliceRecipients = pathLen = 0;
    // ea_set records whether expAlice was chosen explicitly, but the value
    // itself must still be defined: assertFilled() is a no-op under NDEBUG
    // and the MixTrial constructor reads expAlice unconditionally.
    expAlice = false;
    ea_set = p_set = smoothPadding = false;
    pDelay = pMessage = pDummy = bgVolMean = bgVolDev = -1;
    knownBackground = pseudonyms = false;
  }
  MixTrialSpec &setNRecipients(int n) { nRecipients = n; return *this; }
  MixTrialSpec &setNAliceRecipients(int n) { nAliceRecipients = n; return *this; }
  MixTrialSpec &setPathLen(int n) { pathLen = n; return *this; }
  MixTrialSpec &setPadding(int n) { padding = n; return *this; }
  MixTrialSpec &setGranularity(int n) { granularity = n; return *this; }
  MixTrialSpec &setCutoff(int n) { cutoff = n; return *this; }
  MixTrialSpec &setSmoothPadding(bool b) { smoothPadding = b; return *this; }

  // TODO: rename to geometric.
  MixTrialSpec &setExpAlice(bool b) { expAlice = b; ea_set = true; return *this; }
  MixTrialSpec &setPartial(bool b) { partial = b; p_set = true; return *this; }

  MixTrialSpec &setPDelay(double d) { pDelay = d; return *this; }
  MixTrialSpec &setPMessage(double d) { pMessage = d; return *this; }
  MixTrialSpec &setPDummy(double d) { pDummy = d; return *this; }
  MixTrialSpec &setBGVolMean(double d) { bgVolMean = d; return *this; }
  MixTrialSpec &setBGVolDev(double d) { bgVolDev = d; return *this; }
  MixTrialSpec &setPObserve(double d) { pObserve = d; return *this; }
  MixTrialSpec &setPOnline(double d) { pOnline = d; return *this; }
  MixTrialSpec &setKnownBackground(bool b) { knownBackground = b; return *this; }
  MixTrialSpec &setPseudonyms(bool b) { pseudonyms = b;  return *this; }

  // Abort (unless NDEBUG is defined) if a required parameter was not set.
  void assertFilled() const {
    assert(nRecipients && nAliceRecipients && pathLen);
    assert(ea_set && p_set);
    assert(pDelay >= 0.0);
    assert(pMessage >= 0.0);
    assert(pDummy >= 0.0);
    assert(bgVolMean >= 0.0);
    assert(bgVolDev >= 0.0);
  }

  void write(std::ostream &o) const;
  Trial *create() const;

};

// Trial with a single mix or a whole mixnet.
class MixTrial : public NonbatchTrial {
 private:
  void init(int nRecipients, int nAliceRecipients,
	    int pathLen,   // 1 if using a single mix.
	    double pDelay, // probability of being delayed in a randomly chosen round.
	    bool expAlice,  // alice uses exponential distribution.
	    double pMessage, // probability of Alice sending a single message.
	    double pDummy,   // probability of Alice sending a dummy message
	    double bgVolMean, // \ Taken together, these two decide how many
	    double bgVolDev,  // / messages the background sends in a round.
	    double pOnline,
	    int padding,
	    bool smoothPadding,
	    int granularity,
	    bool partial,
	    double pObserve,
	    int cutoff,
	    bool knownBackground,
	    bool pseudonyms
	    );
 public:
  MixTrial(const MixTrialSpec &s) : NonbatchTrial(s.nRecipients, s.granularity)
    { s.assertFilled();
      init(s.nRecipients, s.nAliceRecipients, s.pathLen, s.pDelay,
	   s.expAlice, s.pMessage, s.pDummy, s.bgVolMean, s.bgVolDev,
	   s.pOnline, s.padding, s.smoothPadding,
	   s.granularity, s.partial, s.pObserve, s.cutoff, s.knownBackground,
	   s.pseudonyms); }
  TrialResult attempt();
  ~MixTrial() {}
};


// ======================================================================
// Parameters for a NymTrial.  Counts and probabilities start at 0 and
// assertFilled() requires them to be positive; setPartial() must also have
// been called.  Setters return *this for chaining.
class NymTrialSpec : public TrialSpec {
  friend class NymTrial;
 protected:
  int nRecipients, pathLen;
  int granularity, cutoff;
  double pDelay, pMsgAlice, bgVolMean, bgVolDev, pObserve;
  bool partial;
  bool partialSet;
 public:
  NymTrialSpec() {
    nRecipients = pathLen = 0;
    pDelay = pMsgAlice = bgVolMean = bgVolDev = 0.0;
    granularity = 5;
    // NOTE(review): the other spec classes default cutoff to 1000000000;
    // this one is 100x smaller.  Left unchanged -- confirm it is intended.
    cutoff = 10000000;
    pObserve = 1.0;
    // partialSet records whether setPartial() was called, but partial itself
    // still needs a defined value: assertFilled() is a no-op under NDEBUG.
    partial = false;
    partialSet = false;
  }

  NymTrialSpec &setNRecipients(int i) { nRecipients = i; return *this; }
  NymTrialSpec &setPathLen(int i) { pathLen = i; return *this; }
  NymTrialSpec &setCutoff(int n) { cutoff = n; return *this; }
  NymTrialSpec &setGranularity(int n) { granularity = n; return *this; }

  NymTrialSpec &setPDelay(double d) { pDelay = d; return *this; }
  NymTrialSpec &setPMsgAlice(double d) { pMsgAlice = d; return *this; }
  NymTrialSpec &setBGVolMean(double d) { bgVolMean = d; return *this; }
  NymTrialSpec &setBGVolDev(double d) { bgVolDev = d; return *this; }
  NymTrialSpec &setPObserve(double d) { pObserve = d; return *this; }

  NymTrialSpec &setPartial(bool b) { partial = b; partialSet=true; return *this; }
  // Abort (unless NDEBUG is defined) if a required parameter was not set.
  void assertFilled() const {
    assert(nRecipients > 0);
    assert(pathLen > 0);
    assert(pDelay > 0.0);
    assert(pMsgAlice > 0.0);
    assert(bgVolMean > 0.0);
    assert(bgVolDev > 0.0);
    assert(partialSet);
  }
  void write(std::ostream &o) const;
  Trial *create() const;
};

class NymTrial : public NonbatchTrial {
 private:
  void init(const NymTrialSpec &s);
 public:
  NymTrial(const NymTrialSpec &s) : NonbatchTrial(s.nRecipients,
                                                  s.granularity)
    { s.assertFilled();
      init(s); }
  TrialResult attempt();
  ~NymTrial() {}
};

#endif

