
#include <iostream>
#include "vec.h"
#include "comb.h"
#include "rng.h"
#include "netparams.h"

using namespace std;

// Minimal assertion macros for the void test functions in this file.
// On failure each prints the source line plus the failing expression or
// values to cerr, then `return;`s from the enclosing function, so they
// may only be used inside functions returning void.
//
// Each body is wrapped in do { ... } while (0) so a use behaves as a single
// statement: it needs a trailing ';' and cannot swallow a following `else`
// (a bare `if` expansion breaks `if (c) test_eq(a, b); else ...`).
// NOTE: on failure the arguments are evaluated a second time for printing,
// so do not pass expressions with side effects.
#define test(x) do { if (!(x)) { \
     cerr << __LINE__ << ": !(" << #x << ")" << endl; return; } } while (0)
#define test_eq(x,y) do { if ((x)!=(y)) { \
     cerr << __LINE__ << ": " << (x) << "!=" << (y) << endl; return; } } while (0)
#define test_neq(x,y) do { if ((x)==(y)) { \
     cerr << __LINE__ << ": " << (x) << "==" << (y) << endl; return; } } while (0)

// NOTE(review): this out-of-class redeclaration of vec<double>::topN is not a
// definition, and its purpose is not visible from this file (vec.h is not
// shown). Standard C++ only allows a member function to appear outside its
// class as a definition or as an explicit specialization (`template<>`), so
// some compilers may reject or warn on this line (GCC reports it as
// "declaration ... outside of class is not definition") — confirm whether it
// is still needed or can be removed.
vector<int> vec<double>::topN(int n) const;

void
test_vec() {
    vec<int> vi1(5);
    vec<int> vi2(5, 7);
    vec<double> vd(6, 2.0);
    vec<int> vi3(vd);
    vec<double> vd2(vi2);

    test_eq(vi3[0], 2);

    vi1[0] = 3;
    test_eq(3, vi1[0]);
    vi1[1] = 4;
    vi1[2] = 5;
    vi1[3] = 6;
    vi1[4] = 7;
    vi3 = vi1+vi2;
    test_eq(vi3[0], 10);
    test_eq(vi3[4], 14);
    vi1 += vi2;
    test_eq(vi1[0], 10);
    test_eq(vi1[4], 14);

    vi1 *= 2;
    test_eq(vi1[0], 20);
    test_eq(vi1[4], 28);

    vd /= 3;
    test_eq(vd[0], 2.0/3);
    test_eq(vd[5], 2.0/3);

    test_eq(vi1.minidx(), 0);
    test_eq(vi1.maxidx(), 4);


    vec<double> vdx(10, 0.0);
    vdx[7] = 3;
    vdx[3] = 10;
    vdx[9] = 60;
    vdx[1] = 1;
    std::vector<int> top3 = vdx.topN(3);
    test_eq(top3[0], 3);
    test_eq(top3[1], 7);
    test_eq(top3[2], 9);
    std::vector<int> top4 = vdx.topN(4);
    test_eq(top4[0], 8);
    test_eq(top4[1], 3);
    test_eq(top4[2], 7);
    test_eq(top4[3], 9);

    cerr << "Vectors OK" << endl;
}

/*
 * Checks fact() and comb() against known small values: a few factorials,
 * then C(10,0), C(1,1) and every binomial coefficient C(5,k) and C(6,k).
 * Cases run in table order; the first mismatch returns early (via test_eq),
 * in which case "Combinatorics OK" is not printed.
 */
void
test_comb() {
    // Each row is {n, n!}.
    static const int factCases[][2] = {
        { 0, 1 },
        { 1, 1 },
        { 2, 2 },
        { 5, 120 },
    };
    const int numFact = static_cast<int>(sizeof(factCases) / sizeof(factCases[0]));
    for (int i = 0; i < numFact; ++i) {
        test_eq(factCases[i][1], fact(factCases[i][0]));
    }

    // Each row is {n, k, C(n, k)}.
    static const int combCases[][3] = {
        { 10, 0, 1 },
        { 1, 1, 1 },
        { 5, 1, 5 }, { 5, 2, 10 }, { 5, 3, 10 }, { 5, 4, 5 }, { 5, 5, 1 },
        { 6, 0, 1 }, { 6, 1, 6 }, { 6, 2, 15 }, { 6, 3, 20 },
        { 6, 4, 15 }, { 6, 5, 6 }, { 6, 6, 1 },
    };
    const int numComb = static_cast<int>(sizeof(combCases) / sizeof(combCases[0]));
    for (int i = 0; i < numComb; ++i) {
        const int *row = combCases[i];
        test_eq(row[2], comb(row[0], row[1]));
    }

    cerr << "Combinatorics OK" << endl;
}

// Test driver: runs each test group in turn. Each group prints its own
// "... OK" line to cerr only when all of its checks pass.
// NOTE(review): the exit status is always 0, even when a check fails,
// because the test functions return void. A caller such as CI cannot
// detect failures from the status alone.
int
main(int, char **) {
    test_vec();
    test_comb();
    return 0;
}


