// Copyright (C) 2010, Guy Barrand. All rights reserved.
// See the file tools.license for terms.

#ifndef tools_randT
#define tools_randT

namespace tools {

// Gaussian random number generator built on top of a flat generator.
//   FLAT : flat generator type. It must provide shoot() and, if set_seed()
//          is used, set_seed(unsigned int). shoot() is presumably uniform
//          over [0,1] : the 2*u-1 rescaling in shoot() relies on that.
//   REAL : floating point type used for the computation.
// The flat generator is held by reference : it is not owned and must
// outlive this object.
template <class FLAT,class REAL>
class rgauss {
  typedef REAL(*math_func)(REAL); //for rootcint.
//public:
//  typedef REAL value_t;
public:
  // a_flat : flat generator to draw from (kept by reference).
  // a_mean, a_std_dev : shift and scale applied to the unit deviate.
  rgauss(FLAT& a_flat,REAL a_mean = 0,REAL a_std_dev = 1)
  :m_flat(a_flat),m_mean(a_mean),m_std_dev(a_std_dev){}
  virtual ~rgauss(){}
public:
  // The copy shares the same flat generator as a_from.
  rgauss(const rgauss& a_from):m_flat(a_from.m_flat),m_mean(a_from.m_mean),m_std_dev(a_from.m_std_dev){}
  // Only the parameters are copied : a reference member can't be reseated,
  // so this object keeps drawing from its own flat generator.
  rgauss& operator=(const rgauss& a_from) {
    m_mean = a_from.m_mean;
    m_std_dev = a_from.m_std_dev;
    return *this;
  }
public:
  // Return one deviate : m_mean + m_std_dev * g, g being produced with the
  // polar rejection method from two flat draws per trial.
  //   a_sqrt, a_log : square root and natural logarithm for REAL
  //                   (passed in so that this header does not need <cmath>).
  REAL shoot(math_func a_sqrt,math_func a_log) const {
    REAL v1,v2,r,fac;
    do {
      // map the two flat draws to a point (v1,v2) of the [-1,1]x[-1,1] square.
      v1 = REAL(2) * m_flat.shoot() - REAL(1);
      v2 = REAL(2) * m_flat.shoot() - REAL(1);
      r = v1*v1 + v2*v2;
      // reject points outside the unit disk, and also the exact origin :
      // r==0 would evaluate a_log(0)/0 and the result would be NaN.
    } while ((r>REAL(1)) || (r==REAL(0)));
    fac = a_sqrt(-REAL(2)*a_log(r)/r);
    // v1*fac is a second, independent, deviate; it is not used here.
    return (v2 * fac) * m_std_dev + m_mean;
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
  // Forward the seed to the flat generator.
  void set_seed(unsigned int a_seed) {m_flat.set_seed(a_seed);}
  // Change the distribution parameters.
  void set(REAL a_mean = 0,REAL a_std_dev = 1) {m_mean = a_mean;m_std_dev = a_std_dev;}
protected:
  FLAT& m_flat;    // not owned.
  REAL m_mean;
  REAL m_std_dev;
};

// Generator returning m_mean + (m_gamma/2) * tan(angle), with angle derived
// from one draw of a flat generator (presumably a Breit-Wigner/Cauchy shape,
// going by the class name - confirm against callers).
//   FLAT : flat generator type; must provide shoot().
//   REAL : floating point type used for the computation.
// The flat generator is referenced, not owned.
template <class FLAT,class REAL>
class rbw {
public:
  rbw(FLAT& a_flat,REAL a_mean = 0,REAL a_gamma = 1)
  :m_flat(a_flat)
  ,m_mean(a_mean)
  ,m_gamma(a_gamma)
  {}
  virtual ~rbw(){}
public:
  // The copy draws from the same flat generator as a_from.
  rbw(const rbw& a_from)
  :m_flat(a_from.m_flat)
  ,m_mean(a_from.m_mean)
  ,m_gamma(a_from.m_gamma)
  {}
  // Parameters only : the m_flat reference can't be rebound.
  rbw& operator=(const rbw& a_from) {
    m_mean = a_from.m_mean;
    m_gamma = a_from.m_gamma;
    return *this;
  }
public:
  // Return one deviate.
  //   a_half_pi : multiplier turning the [-1,1] variable into an angle
  //               (pi/2 according to its name).
  //   a_tan     : tangent function for REAL.
  REAL shoot(const REAL& a_half_pi,REAL(*a_tan)(REAL)) const {
    // one flat draw, rescaled with 2*u-1.
    const REAL centered = REAL(2) * m_flat.shoot() - REAL(1);
    const REAL angle = centered * a_half_pi;
    const REAL half_gamma = (REAL(1)/REAL(2)) * m_gamma;
    const REAL offset = half_gamma * a_tan(angle);
    return m_mean + offset;
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
protected:
  FLAT& m_flat;  // not owned.
  REAL m_mean;
  REAL m_gamma;
};

// Generator returning -log(u)/m_rate, u being a strictly positive draw of a
// flat generator (inverse transform of an exponential law if u is uniform
// over ]0,1] - the flat range is not visible from here).
//   FLAT : flat generator type; must provide shoot().
//   REAL : floating point type used for the computation.
// The flat generator is referenced, not owned.
template <class FLAT,class REAL>
class rexp {
  typedef REAL(*math_func)(REAL); //for rootcint.
public:
  rexp(FLAT& a_flat,REAL a_rate = 1)
  :m_flat(a_flat)
  ,m_rate(a_rate)
  {}
  virtual ~rexp(){}
public:
  // The copy draws from the same flat generator as a_from.
  rexp(const rexp& a_from)
  :m_flat(a_from.m_flat)
  ,m_rate(a_from.m_rate)
  {}
  // Rate only : the m_flat reference can't be rebound.
  rexp& operator=(const rexp& a_from) {
    m_rate = a_from.m_rate;
    return *this;
  }
public:
  // Return one deviate. a_log : natural logarithm for REAL.
  REAL shoot(math_func a_log) const {
    REAL u = m_flat.shoot();
    // draw again as long as the logarithm could not be taken.
    while(u<=REAL(0)) {
      u = m_flat.shoot();
    }
    return -a_log(u)/m_rate;
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
protected:
  FLAT& m_flat;  // not owned.
  REAL m_rate;
};

// Random 2D unit vector generator (algorithm of gsl_ran_dir_2d).
//   FLAT : flat generator type; must provide shoot().
//   REAL : floating point type used for the computation.
// The flat generator is referenced, not owned.
template <class FLAT,class REAL>
class rdir2 {
public:
  rdir2(FLAT& a_flat)
  :m_flat(a_flat)
  {}
  virtual ~rdir2(){}
public:
  // The copy draws from the same flat generator as a_from.
  rdir2(const rdir2& a_from)
  :m_flat(a_from.m_flat)
  {}
  // Nothing to copy : the only member is a reference, which can't be rebound.
  rdir2& operator=(const rdir2&) {return *this;}
public:
  // Fill (a_x,a_y) with one direction.
  void shoot(REAL& a_x,REAL& a_y) const {
    // from gsl_ran_dir_2d.
    REAL px,py,norm2;
    for(;;) {
      // candidate point, each flat draw being rescaled with 2*u-1.
      px = REAL(2) * m_flat.shoot()-REAL(1);
      py = REAL(2) * m_flat.shoot()-REAL(1);
      norm2 = px * px + py * py;
      if(norm2 > REAL(1)) continue;  // outside the unit disk.
      if(norm2==REAL(0)) continue;   // origin : can't divide by norm2.
      break;
    }
    // (px*px-py*py)/norm2 and 2*px*py/norm2 are the cosine and sine of
    // twice the polar angle of (px,py) : no square root is needed.
    a_x = (px * px - py * py) / norm2;
    a_y = REAL(2) * px * py / norm2;
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
protected:
  FLAT& m_flat;  // not owned.
};

// Random 3D unit vector generator (algorithm of gsl_ran_dir_3d).
//   FLAT : flat generator type; must provide shoot().
//   REAL : floating point type used for the computation.
// The flat generator is referenced, not owned.
template <class FLAT,class REAL>
class rdir3 {
public:
  rdir3(FLAT& a_flat)
  :m_flat(a_flat)
  {}
  virtual ~rdir3(){}
public:
  // The copy draws from the same flat generator as a_from.
  rdir3(const rdir3& a_from)
  :m_flat(a_from.m_flat)
  {}
  // Nothing to copy : the only member is a reference, which can't be rebound.
  rdir3& operator=(const rdir3&) {return *this;}
public:
  // Fill (a_x,a_y,a_z) with one direction. a_sqrt : square root for REAL.
  // a_x and a_y are used as scratch storage during the rejection loop.
  void shoot(REAL& a_x,REAL& a_y,REAL& a_z,REAL(*a_sqrt)(REAL)) const {
    // from gsl_ran_dir_3d.
    REAL norm2;
    for(;;) {
      // candidate point, each flat draw being rescaled with 2*u-1.
      a_x = REAL(2) * m_flat.shoot() - REAL(1);
      a_y = REAL(2) * m_flat.shoot() - REAL(1);
      norm2 = (a_x) * (a_x) + (a_y) * (a_y);
      if(norm2 > REAL(1)) continue;  // outside the unit disk : draw again.
      break;
    }
    a_z = REAL(2) * norm2 - REAL(1);
    // scale (a_x,a_y) so that a_x^2 + a_y^2 + a_z^2 is 1.
    const REAL scale = REAL(2) * a_sqrt(REAL(1) - norm2);
    a_x *= scale;
    a_y *= scale;
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
protected:
  FLAT& m_flat;  // not owned.
};

// from Geant4/G4Poisson.hh :
// NOTE : not used yet.
// Poisson random number generator built on top of a flat generator.
//   FLAT : flat generator type. It must provide shoot() and, if set_seed()
//          is used, set_seed(unsigned int). shoot() is presumably uniform
//          over [0,1[ - confirm against the flat generators in use.
//   REAL : floating point type used for the computation.
//   UINT : integer type of the returned count.
// The flat generator is held by reference : it is not owned and must
// outlive this object.
template <class FLAT,class REAL,class UINT>
class rpoiss {
public:
  rpoiss(FLAT& a_flat,REAL a_mean = 1):m_flat(a_flat),m_mean(a_mean){}
  virtual ~rpoiss(){}
public:
  // The copy shares the same flat generator as a_from.
  rpoiss(const rpoiss& a_from):m_flat(a_from.m_flat),m_mean(a_from.m_mean){}
  // Only the mean is copied : a reference member can't be reseated.
  rpoiss& operator=(const rpoiss& a_from) {
    m_mean = a_from.m_mean;
    return *this;
  }
public:
  // Return one count.
  //   m_mean <= 16 : one flat draw, inverted by summing the Poisson
  //                  probabilities exp(-mean)*mean^k/k! until the cumulative
  //                  sum exceeds the draw.
  //   m_mean > 16  : Gaussian approximation (mean m_mean, std dev
  //                  sqrt(m_mean)) from two flat draws, rounded to nearest
  //                  and clamped to [0,2e9].
  //   a_two_pi : multiplier turning a flat draw into an angle (2*pi
  //              according to its name).
  //   a_sqrt, a_log, a_exp, a_cos : math functions for REAL.
  UINT shoot(const REAL& a_two_pi,REAL(*a_sqrt)(REAL),REAL(*a_log)(REAL),REAL(*a_exp)(REAL),REAL(*a_cos)(REAL)) const {
    UINT number = 0;
    if(m_mean <= 16) {
      REAL position = m_flat.shoot();
      REAL poissonValue = a_exp(-m_mean);
      REAL poissonSum = poissonValue;
      // The second condition bounds the loop : once the term has underflowed
      // to zero the sum can't grow anymore, and a draw at or above the
      // saturated sum would otherwise spin until number wraps around.
      while((poissonSum <= position) && (poissonValue != REAL(0))) {
        number++;
        poissonValue *= m_mean/REAL(number);
        poissonSum += poissonValue;
      }
      return number;
    }
    // m_mean > 16 :
    REAL t = a_sqrt(-REAL(2)*a_log(m_flat.shoot()));
    REAL y = a_two_pi*m_flat.shoot();
    t *= a_cos(y);
    REAL value = m_mean + t*a_sqrt(m_mean) + REAL(0.5);
    // written as !(value > 0) so that a NaN is caught here too : converting
    // a NaN to an integer is undefined.
    if(!(value > REAL(0))) return UINT(0);
    static const REAL limit = REAL(2e9);
    if(value >= limit) return UINT(limit);
    return UINT(value);
  }
  // Access to the underlying flat generator.
  FLAT& flat() {return m_flat;}
  // Forward the seed to the flat generator.
  void set_seed(unsigned int a_seed) {m_flat.set_seed(a_seed);}
protected:
  FLAT& m_flat;  // not owned.
  REAL m_mean;
};

}

#endif
