// Copyright (c) 2003 Nick Mathewson.  See LICENSE for licensing information.
// $Id: rng.h 961 2003-12-22 04:37:34Z nickm $
// rng.h -- random number code
#ifndef _RNG_H
#define _RNG_H

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

#include "comb.h"

// Return a pseudorandom number between 0.0 and 1.0.
// NOTE(review): this header does not say whether 1.0 itself can be
// returned.  Callers that scale this value by a count and truncate (to pick
// an index) rely on it being strictly below 1.0 -- confirm against the
// implementation.
double rng();
// Return a pseudorandom number according to a normal distribution with
// mean 0.0 and std deviation 1.0
double normal_rng();

// Presumably seeds the generator behind rng()/normal_rng(); the definition
// (and the seed source) is not in this header -- TODO confirm.
void seed_rng();

// Return a pseudorandom integer x such that 0 <= x < m, for m > 0.
// rng() is only documented as returning a value "between 0.0 and 1.0", so
// for m > 0 the scaled value is clamped to m-1: if rng() ever returned
// exactly 1.0, the unclamped result would be m, which is an out-of-range
// index for callers that use this to pick an element.  For m <= 0 no
// clamping is applied and the result is (int)(rng()*m).
inline int rng(int m) {
  int x = static_cast<int>(rng() * m);
  if (m > 0 && x >= m) x = m - 1;
  return x;
}

// Abstract base class: random distribution of elements of type C.
// Subclasses implement get() to draw one sample and copy() to clone
// themselves polymorphically.
template <class C>
class Dist
{
 public:
  // return a random element from this distribution
  virtual C get() const = 0;
  // return a new copy of this distribution.  The subclasses in this header
  // allocate the copy with `new`, so the caller owns the returned pointer
  // and must delete it.
  virtual Dist<C> *copy() const = 0;
  // Virtual so that a copy() result can be deleted through a Dist<C>*.
  virtual ~Dist() {}
};

// Abstract base class: an 'invertible' discrete distribution of elements of
// type C -- besides drawing samples (inherited get()), it can report the
// probability it assigns to a particular value.
template <class C>
class InvDist : public Dist<C>
{
 public:
  // Return the probability of v in this distribution, i.e. the chance that
  // get() yields a value equal to v.
  virtual double getP(const C &v) const = 0;
};

// Distribution with only a single element of probability 1.0: every draw
// returns the value passed to the constructor.
template <class C>
class ConstDist : public InvDist<C>
{
 private:
  C value;  // the one and only possible outcome
 public:
  ConstDist(const C &only) : value(only) {}
  ~ConstDist() {}

  // Always the stored value.
  C get() const { return value; }

  // Certain for the stored value, impossible for anything else.
  double getP(const C &v) const {
    if (v == value)
      return 1.0;
    return 0.0;
  }

  // Heap-allocated clone; the caller owns the result.
  Dist<C> *copy() const { return new ConstDist<C>(value); }
};

// Geometric distribution over {0, 1, 2, ...}: the number of failed trials
// before the first success, where a trial succeeds when rng() <= p.
//
// NOTE: get() keeps drawing while rng() > p, so for p <= 0 it only
// terminates if rng() returns a value <= p; callers should pass p > 0.
class GeometricDist : public InvDist<int>
{
 private:
  double p;  // per-trial success probability
 public:
  GeometricDist(double param) : p(param) {}
  Dist<int> *copy() const { return new GeometricDist(p); }
  // Count failures (rng() > p) until the first success.
  int get() const { int v = 0; while (rng() > p) ++v; return v; }
  // P(X = val) = p * (1-p)^val for val >= 0.  get() never yields a negative
  // value, so those get probability 0 (the bare formula would give a
  // nonzero, possibly >1, result for a negative exponent).
  double getP(const int &val) const {
    if (val < 0) return 0.0;
    return p * std::pow(1.0 - p, val);
  }
  ~GeometricDist() {}
};

// Binomial distribution: the number of successes in n independent trials,
// each succeeding (rng() < p) with probability p.  Values range over 0..n.
class BinomialDist : public InvDist<int>
{
 private:
  double p;  // per-trial success probability
  int n;     // number of trials, i.e. the maximum possible value
 public:
  BinomialDist(double prob, int maximum) : p(prob), n(maximum) {}
  Dist<int> *copy() const { return new BinomialDist(p, n); }
  // Simulate n trials and count the successes.
  int get() const {
    int v = 0;
    for (int i = 0; i < n; ++i) {
      if (rng() < p) ++v;
    }
    return v;
  }
  // P(X = v) = C(n, v) * p^v * (1-p)^(n-v) for 0 <= v <= n, else 0.
  // The range check is explicit: outside 0..n one of the pow() factors has
  // a negative exponent (and is >= 1), and comb()'s behavior for
  // out-of-range arguments is defined in comb.h, not here.
  double getP(const int &v) const {
    if (v < 0 || v > n) return 0.0;
    return comb(n, v) * std::pow(p, v) * std::pow(1 - p, n - v);
  }
  ~BinomialDist() {}
};

// Distribution assigning an arbitrary fixed probability to values 0..M-1, for
// some M.  Values outside that range have probability 0.
class CumulativeDist : public InvDist<int>
{
 private:
  // M-element vector, containing the probability of each item in range 0..M-1
  std::vector<double> prob;
  // M-element vector containing, for each element i, the probability
  // of that element or any lesser one.
  std::vector<double> cumDist;
 public:
  // Defined out of line.  Presumably `dist` holds the per-value
  // probabilities used to fill prob/cumDist -- TODO confirm.
  CumulativeDist(const std::vector<double> &dist);
  CumulativeDist(const CumulativeDist &d) : prob(d.prob), cumDist(d.cumDist) {}
  Dist<int> *copy() const { return new CumulativeDist(*this); }
  CumulativeDist & operator=(const CumulativeDist &d)
    { prob=d.prob; cumDist = d.cumDist; return *this; }
  // Defined out of line.
  int get() const;
  // Probability of v, or 0.0 outside 0..M-1.  The upper bound is exclusive:
  // prob has exactly M entries, so prob[M] would be out of range.
  double getP(const int &v) const
    { return (0 <= v && v < static_cast<int>(prob.size())) ? prob[v] : 0.0; }
  ~CumulativeDist() {}
};

// Optimized version of CumulativeDist: get() is a single lookup into a
// precomputed table.
class OCumulativeDist : public InvDist<int>
{
 private:
  // M-element vector, containing the probability of each item in range 0..M-1
  std::vector<double> prob;
  // N-element vector of items in the range 0..M-1, such that the number of
  // elements with value x is proportional to the probability of x.
  std::vector<int> lookupTable;
 public:
  // Defined out of line.  Presumably `dist` gives integer weights used to
  // build prob and lookupTable -- TODO confirm.
  OCumulativeDist(const std::vector<int> &dist);
  //OCumulativeDist(const std::vector<double> &dist, double granularity);
  OCumulativeDist(const OCumulativeDist &d) :
    prob(d.prob), lookupTable(d.lookupTable) {}
  Dist<int> *copy() const { return new OCumulativeDist(*this); }
  OCumulativeDist & operator=(const OCumulativeDist &d)
    { prob=d.prob; lookupTable = d.lookupTable; return *this; }
  // Pick a uniformly random table entry; since each value occurs in the
  // table in proportion to its probability, this samples the distribution.
  // Requires a non-empty lookupTable.
  int get() const
    { return lookupTable[rng(static_cast<int>(lookupTable.size()))]; }
  // Probability of v, or 0.0 outside 0..M-1.  The upper bound is exclusive:
  // prob has exactly M entries, so prob[M] would be out of range.
  double getP(const int &v) const
    { return (0 <= v && v < static_cast<int>(prob.size())) ? prob[v] : 0.0; }
  ~OCumulativeDist() {}
};

// Distribution choosing uniformly among the entries of a vector.  Duplicate
// entries are allowed and weight their value accordingly: a value that
// appears k times among n entries is drawn with probability k/n.
template<class C>
class UniformChoiceDist : public InvDist<C>
{
  std::vector<C> choices;
 public:
  UniformChoiceDist(const std::vector<C> &c) : choices(c) {}
  Dist<C> *copy() const { return new UniformChoiceDist<C>(choices); }
  // Return a uniformly chosen entry.  Requires a non-empty choice list.
  C get() const { return choices[rng(static_cast<int>(choices.size()))]; }
  // Probability that get() returns a value equal to v: the fraction of
  // entries equal to v.  Duplicates are counted so this agrees with get().
  // Returns 0.0 when there are no choices.
  double getP(const C &v) const {
    const int n = static_cast<int>(choices.size());
    if (n == 0) return 0.0;
    const int hits =
      static_cast<int>(std::count(choices.begin(), choices.end(), v));
    return static_cast<double>(hits) / n;
  }
  ~UniformChoiceDist() {}
};

// Distribution choosing between two elements: choice1 with probability
// prob and choice2 with probability 1-prob.
template<class C>
class BinaryDist : public InvDist<C>
{
  C c1;       // outcome drawn with probability p
  C c2;       // outcome drawn with probability 1-p
  double p;
 public:
  BinaryDist(double prob, const C &choice1, const C &choice2)
    : c1(choice1), c2(choice2), p(prob)
  {
    // A probability outside [0, 1] is a programming error.
    assert(0.0 <= p && p <= 1.0);
  }
  ~BinaryDist() {}

  // Heap-allocated clone; the caller owns the result.
  Dist<C> *copy() const { return new BinaryDist<C>(p, c1, c2); }

  C get() const {
    const bool pickFirst = rng() < p;
    return pickFirst ? c1 : c2;
  }

  // c1 is tested first, so if c1 == c2 the answer is p.
  double getP(const C& v) const {
    return (v == c1) ? p
         : (v == c2) ? 1 - p
         : 0.0;
  }
};

// Normal distribution with mean `mean` and std deviation `stddev`, rounded to
// the nearest integer (exact halves round up).  If clampToZero is set,
// results below zero are replaced by 0.
class IntNormalDist : public Dist<int>
{
  double m, s;
  bool clamp;
 public:
  IntNormalDist(double mean, double stddev, bool clampToZero)
    : m(mean), s(stddev), clamp(clampToZero) {}
  Dist<int> *copy() const { return new IntNormalDist(m,s,clamp); }
  int get() const {
    // floor(x + 0.5) rounds to nearest for negative x as well; a plain int
    // cast truncates toward zero and would turn e.g. -0.7 into 0, not -1.
    int i = static_cast<int>(std::floor(m + normal_rng()*s + 0.5));
    return (!clamp || i > 0) ? i : 0;
  }
};

// Return a random element of the vector v.
template<class C>
const C &rng_pick(const std::vector<C> &v) { return v[rng(v.size())]; }
template<class C>
C &rng_pick(std::vector<C> &v) { return v[rng(v.size())]; }

// Re-order the elements of a vector v at random (partial Fisher-Yates).
//
// If n is negative (the default is -1), the whole vector is shuffled.
// Otherwise only the first n positions are randomized: each v[i], i < n, is
// swapped with a uniformly chosen element of v[i..end], so the first n
// elements become a random selection of the vector in random order.
// An n that is >= v.size() is treated as a full shuffle.
template<class C> void
rng_shuffle(std::vector<C> &v, int n=-1) {
  const int sz = static_cast<int>(v.size());
  // The final position has a single candidate left, so a full shuffle
  // needs only sz-1 swaps.  This also keeps sz - i >= 1 in the loop.
  if (n < 0 || n > sz - 1) n = sz - 1;
  for (int i = 0; i < n; ++i) {
    // Choose from the not-yet-fixed tail v[i..sz-1].
    const int j = i + rng(sz - i);
    if (j != i) {
      C tmp = v[j];
      v[j] = v[i];
      v[i] = tmp;
    }
  }
}

#endif // _RNG_H

