// Modular integer operations.

#ifndef _CL_MODINTEGER_H
#define _CL_MODINTEGER_H

#include "cl_object.h"
#include "cl_integer.h"
#include "cl_random.h"
#include "cl_malloc.h"
#include "cl_io.h"
#undef random // Linux defines random() as a macro!


// Representation of an element of a ring Z/mZ.

// To protect against mixing elements of different modular rings, such as
// (3 mod 4) + (2 mod 5), every modular integer carries its ring in itself.


// Representation of a ring Z/mZ.

class cl_heap_modint_ring;

// Handle to a ring Z/mZ: a reference-counted pointer to a cl_heap_modint_ring.
class cl_modint_ring : public cl_rcpointer {
public:
	// Default constructor. Refers to cl_modint0_ring (Z/0Z); defined below.
	cl_modint_ring ();
	// Constructor. Takes a cl_heap_modint_ring*, increments its refcount.
	cl_modint_ring (cl_heap_modint_ring* r);
	// Copy constructor.
	cl_modint_ring (const cl_modint_ring&);
	// Assignment operator.
	cl_modint_ring& operator= (const cl_modint_ring&);
	// Automatic dereferencing.
	// NOTE(review): cl_heap_modint_ring is still incomplete at this point, so
	// the standard leaves it unspecified whether this C-style cast acts as a
	// static_cast (derived-from-base adjustment) or a reinterpret_cast (same
	// address). It is presumably meant to reinterpret the address stored by
	// the constructor from cl_heap_modint_ring* (`pointer = r`); keep the two
	// consistent, and confirm against cl_object.h.
	cl_heap_modint_ring* operator-> () const
	{ return (cl_heap_modint_ring*)heappointer; }
};

// Z/0Z. Every default-constructed cl_modint_ring refers to this ring.
extern cl_modint_ring cl_modint0_ring;
// Default constructor. This avoids dealing with NULL pointers.
// NOTE(review): there is no explicit refcount increment here. Presumably
// as_cl_private_thing() takes the extra reference; confirm in cl_object.h.
inline cl_modint_ring::cl_modint_ring ()
{ pointer = as_cl_private_thing(cl_modint0_ring); }
// Copy constructor and assignment operator (generated by the project macros).
CL_DEFINE_COPY_CONSTRUCTOR(cl_modint_ring)
CL_DEFINE_ASSIGNMENT_OPERATOR(cl_modint_ring,cl_modint_ring)

// Operations on modular integer rings.

// Two ring handles compare equal exactly when they point at the same heap
// object. Since modular integer rings are kept unique in memory, this is
// ring equality.
inline bool operator== (const cl_modint_ring& R1, const cl_modint_ring& R2)
{
	return R1.pointer == R2.pointer;
}
inline bool operator!= (const cl_modint_ring& R1, const cl_modint_ring& R2)
{
	return !(R1 == R2);
}
// Same comparisons, against a raw heap ring pointer.
inline bool operator== (const cl_modint_ring& R1, cl_heap_modint_ring* R2)
{
	return R1.pointer == R2;
}
inline bool operator!= (const cl_modint_ring& R1, cl_heap_modint_ring* R2)
{
	return !(R1 == R2);
}


// Representation of an element of a ring Z/mZ.

class cl_MI {
public:
	cl_modint_ring ring;	// ring Z/mZ
	cl_I rep;		// representative, integer >=0, <m
				// (maybe the Montgomery representative!)
	// Default constructor.
	cl_MI (); // : ring (), rep () {}
public: /* ugh */
	// Constructor.
	cl_MI (const cl_modint_ring& R, const cl_I& r) : ring (R), rep (r) {}
public:	// Ability to place an object at a given address.
	void* operator new (size_t size) { return cl_malloc_hook(size); }
	void* operator new (size_t size, cl_MI* ptr) { (void)size; return ptr; }
};


// Vectors of function pointers are more efficient than virtual member
// functions. But it constrains us not to use multiple or virtual inheritance.
//
// Note! We are passing raw `cl_heap_modint_ring*' pointers here for efficiency
// (compared to passing `const cl_modint_ring&', we save a memory access, and
// it is easier to cast to a `cl_heap_modint_ring_specialized*').
// These raw pointers are meant to be used downward (in the dynamic extent
// of the call) only. If you need to save them in a data structure, cast
// to `cl_modint_ring'; this will correctly increment the reference count.
// (This technique is safe because the inline wrapper functions make sure
// that we have a `cl_modint_ring' somewhere containing the pointer, so there
// is no danger of dangling pointers.)
//
// Set-level operations of a modular integer ring. They are invoked through
// cl_heap_modint_ring::equal() and cl_heap_modint_ring::random().
// NOTE(review): if instances are built by aggregate initialization, the field
// order below is significant; do not reorder.
struct _cl_modint_setops {
	// equality (used by operator== / operator!= on cl_MI)
	cl_boolean (* equal) (cl_heap_modint_ring* R, const cl_MI& x, const cl_MI& y);
	// random number, drawn from randomstate
	cl_MI (* random) (cl_heap_modint_ring* R, cl_random_state& randomstate);
};
// Additive operations of a modular integer ring. They are invoked through the
// cl_heap_modint_ring member functions of the same names.
// NOTE(review): the field order is significant if instances are built by
// aggregate initialization.
struct _cl_modint_addops {
	// 0
	cl_MI (* zero) (cl_heap_modint_ring* R);
	// x+y
	cl_MI (* plus) (cl_heap_modint_ring* R, const cl_MI& x, const cl_MI& y);
	// x-y
	cl_MI (* minus) (cl_heap_modint_ring* R, const cl_MI& x, const cl_MI& y);
	// -x
	cl_MI (* uminus) (cl_heap_modint_ring* R, const cl_MI& x);
};
// Multiplicative operations of a modular integer ring, plus conversions
// between integers and ring elements. They are invoked through the
// cl_heap_modint_ring member functions of the same names.
// NOTE(review): the field order is significant if instances are built by
// aggregate initialization.
struct _cl_modint_mulops {
	// 1
	cl_MI (* one) (cl_heap_modint_ring* R);
	// x*y
	cl_MI (* mul) (cl_heap_modint_ring* R, const cl_MI& x, const cl_MI& y);
	// x^2
	cl_MI (* square) (cl_heap_modint_ring* R, const cl_MI& x);
	// x^-1 (presumably x must be a unit; see notify_composite in the ring class)
	cl_MI (* recip) (cl_heap_modint_ring* R, const cl_MI& x);
	// x*y^-1
	cl_MI (* div) (cl_heap_modint_ring* R, const cl_MI& x, const cl_MI& y);
	// x^y, y Integer >0
	cl_MI (* expt_pos) (cl_heap_modint_ring* R, const cl_MI& x, const cl_I& y);
	// x^y, y Integer (a negative y presumably requires x to be invertible)
	cl_MI (* expt) (cl_heap_modint_ring* R, const cl_MI& x, const cl_I& y);
	// x -> x mod m for x>=0
	cl_I (* reduce_modulo) (cl_heap_modint_ring* R, const cl_I& x);
	// canonical homomorphism Z -> Z/mZ
	cl_MI (* canonhom) (cl_heap_modint_ring* R, const cl_I& x);
	// some inverse of canonical homomorphism
	cl_I (* retract) (cl_heap_modint_ring* R, const cl_MI& x);
};
// Public names for the operation tables.
// NOTE(review): __GNUC__ is defined by every GCC release (and by compilers
// that emulate GCC, such as Clang), not only by g++-2.7.0. So under those
// compilers these names are plain, non-const aliases of the structs, and
// only other compilers get the const-qualified typedefs. Code meant to build
// with both should not rely on either constness.
#ifdef __GNUC__ // workaround two g++-2.7.0 bugs
  #define cl_modint_setops  _cl_modint_setops
  #define cl_modint_addops  _cl_modint_addops
  #define cl_modint_mulops  _cl_modint_mulops
#else
  typedef const _cl_modint_setops  cl_modint_setops;
  typedef const _cl_modint_addops  cl_modint_addops;
  typedef const _cl_modint_mulops  cl_modint_mulops;
#endif

// Representation of the ring Z/mZ.

// Currently rings are not garbage collected. (I am afraid of the performance
// penalty of always incrementing and decrementing a ring's reference count.
// A drawback is that once you have used a ring, it will stay in memory forever.)

// Modular integer rings are kept unique in memory. This way, ring equality
// can be checked very efficiently by a simple pointer comparison.

// Heap object describing the ring Z/mZ. Every ring operation is forwarded
// through one of the three function-pointer tables (setops/addops/mulops)
// passed to the constructor.
class cl_heap_modint_ring : public cl_heap {
	// Allocation.
	// NOTE(review): these two operators appear before any access specifier,
	// so they are private. Only members and friends of this class can use
	// them, and a subclass's `new` expression cannot. Confirm this is intended.
	void* operator new (size_t size) { return cl_malloc_hook(size); }
	// Deallocation.
	void operator delete (void* ptr) { cl_free_hook(ptr); }
public:
	cl_I modulus;	// m, normalized to be >= 0
protected:
	// Operation tables, presumably set from the constructor arguments.
	cl_modint_setops* setops;
	cl_modint_addops* addops;
	cl_modint_mulops* mulops;
public:
	// Set operations.
	cl_boolean equal (const cl_MI& x, const cl_MI& y)
		{ return setops->equal(this,x,y); }
	cl_MI random (cl_random_state& randomstate = cl_default_random_state)
		{ return setops->random(this,randomstate); }
	// Ring operations.
	cl_MI zero ()
		{ return addops->zero(this); }
	cl_MI plus (const cl_MI& x, const cl_MI& y)
		{ return addops->plus(this,x,y); }
	cl_MI minus (const cl_MI& x, const cl_MI& y)
		{ return addops->minus(this,x,y); }
	cl_MI uminus (const cl_MI& x)
		{ return addops->uminus(this,x); }
	cl_MI one ()
		{ return mulops->one(this); }
	cl_MI mul (const cl_MI& x, const cl_MI& y)
		{ return mulops->mul(this,x,y); }
	cl_MI square (const cl_MI& x)
		{ return mulops->square(this,x); }
	cl_MI recip (const cl_MI& x)
		{ return mulops->recip(this,x); }
	cl_MI div (const cl_MI& x, const cl_MI& y)
		{ return mulops->div(this,x,y); }
	cl_MI expt_pos (const cl_MI& x, const cl_I& y)
		{ return mulops->expt_pos(this,x,y); }
	cl_MI expt (const cl_MI& x, const cl_I& y)
		{ return mulops->expt(this,x,y); }
	// Conversions between integers and ring elements.
	cl_I reduce_modulo (const cl_I& x)
		{ return mulops->reduce_modulo(this,x); }
	cl_MI canonhom (const cl_I& x)
		{ return mulops->canonhom(this,x); }
	cl_I retract (const cl_MI& x)
		{ return mulops->retract(this,x); }
	// Miscellaneous.
	uintL bits; // number of bits needed to represent a representative, or 0
	int log2_bits; // log_2(bits), or -1
	// Function which is called when a nontrivial divisor of m is found.
	void (* notify_composite) (const cl_modint_ring& R, const cl_I& nonunit);
// Constructor.
	cl_heap_modint_ring (cl_I m, cl_modint_setops*, cl_modint_addops*, cl_modint_mulops*);
// This class is intended to be subclassable, hence needs a virtual destructor.
// NOTE(review): the virtual functions make this class polymorphic. If cl_heap
// itself is not polymorphic, some ABIs place the vptr before the cl_heap base
// subobject. Then reinterpreting between cl_heap* and cl_heap_modint_ring*
// (as cl_modint_ring::operator-> does) would be off; confirm the layout.
	virtual ~cl_heap_modint_ring () {}
private:
	// Out-of-line virtual, presumably to pin the vtable to one translation
	// unit; confirm.
	virtual void dummy ();
};

// Converting constructor for `cl_modint_ring' from a raw heap ring pointer:
// stores the pointer and takes a reference (increments its refcount).
inline cl_modint_ring::cl_modint_ring (cl_heap_modint_ring* r)
{
	pointer = r;
	cl_inc_pointer_refcount(r);
}
// Default constructor for `cl_MI'. The ring member is default-constructed,
// so the element belongs to Z/0Z; the representative is default-constructed.
inline cl_MI::cl_MI ()
	: ring (), rep ()
{
}

// Lookup or create a modular integer ring Z/mZ. Rings are kept unique in
// memory, so repeated lookups of the same m are expected to yield handles
// that compare equal.
extern cl_modint_ring cl_find_modint_ring (const cl_I& m);


// Operations on modular integers.

// Add.
inline cl_MI operator+ (const cl_MI& x, const cl_MI& y)
{
	return x.ring->plus(x, y);
}
// Mixed forms: the integer operand is first mapped into the ring of the
// modular operand via the canonical homomorphism.
inline cl_MI operator+ (const cl_MI& x, const cl_I& y)
{
	const cl_MI y_in_ring = x.ring->canonhom(y);
	return x.ring->plus(x, y_in_ring);
}
inline cl_MI operator+ (const cl_I& x, const cl_MI& y)
{
	const cl_MI x_in_ring = y.ring->canonhom(x);
	return y.ring->plus(x_in_ring, y);
}

// Negate.
inline cl_MI operator- (const cl_MI& x)
{
	return x.ring->uminus(x);
}

// Subtract. In the mixed forms the integer operand is lifted into the
// modular operand's ring first.
inline cl_MI operator- (const cl_MI& x, const cl_MI& y)
{
	return x.ring->minus(x, y);
}
inline cl_MI operator- (const cl_MI& x, const cl_I& y)
{
	const cl_MI y_in_ring = x.ring->canonhom(y);
	return x.ring->minus(x, y_in_ring);
}
inline cl_MI operator- (const cl_I& x, const cl_MI& y)
{
	const cl_MI x_in_ring = y.ring->canonhom(x);
	return y.ring->minus(x_in_ring, y);
}

// Shifts (defined out of line). Presumably x << y multiplies by 2^y and
// x >> y divides by 2^y, which is why m must be odd for >>; confirm.
extern cl_MI operator<< (const cl_MI& x, sintL y); // assume 0 <= y < 2^31
extern cl_MI operator>> (const cl_MI& x, sintL y); // assume m odd, 0 <= y < 2^31

// Equality.
// Element equality is decided by the ring of the modular operand. The
// inequality forms are simply the negation of the matching equality.
inline bool operator== (const cl_MI& x, const cl_MI& y)
{
	return x.ring->equal(x, y);
}
inline bool operator!= (const cl_MI& x, const cl_MI& y)
{
	return !(x == y);
}
inline bool operator== (const cl_MI& x, const cl_I& y)
{
	const cl_MI y_in_ring = x.ring->canonhom(y);
	return x.ring->equal(x, y_in_ring);
}
inline bool operator!= (const cl_MI& x, const cl_I& y)
{
	return !(x == y);
}
inline bool operator== (const cl_I& x, const cl_MI& y)
{
	const cl_MI x_in_ring = y.ring->canonhom(x);
	return y.ring->equal(x_in_ring, y);
}
inline bool operator!= (const cl_I& x, const cl_MI& y)
{
	return !(x == y);
}

// Compare against 0.
// Test the stored representative directly, without dispatching through the
// ring. A Montgomery representative of 0 is also 0, so this works for both
// representations.
inline cl_boolean zerop (const cl_MI& x)
{
	const cl_I& representative = x.rep;
	return zerop(representative);
}

// Multiply.
inline cl_MI operator* (const cl_MI& x, const cl_MI& y)
{
	return x.ring->mul(x, y);
}
// Mixed forms: lift the integer into the modular operand's ring first.
inline cl_MI operator* (const cl_MI& x, const cl_I& y)
{
	const cl_MI y_in_ring = x.ring->canonhom(y);
	return x.ring->mul(x, y_in_ring);
}
inline cl_MI operator* (const cl_I& x, const cl_MI& y)
{
	const cl_MI x_in_ring = y.ring->canonhom(x);
	return y.ring->mul(x_in_ring, y);
}

// Squaring.
inline cl_MI square (const cl_MI& x)
{
	return x.ring->square(x);
}

// Reciprocal, computed by x's ring.
inline cl_MI recip (const cl_MI& x)
{
	return x.ring->recip(x);
}

// Division.
inline cl_MI div (const cl_MI& x, const cl_MI& y)
{
	return x.ring->div(x, y);
}
// Mixed forms: lift the integer into the modular operand's ring first.
inline cl_MI div (const cl_MI& x, const cl_I& y)
{
	const cl_MI y_in_ring = x.ring->canonhom(y);
	return x.ring->div(x, y_in_ring);
}
inline cl_MI div (const cl_I& x, const cl_MI& y)
{
	const cl_MI x_in_ring = y.ring->canonhom(x);
	return y.ring->div(x_in_ring, y);
}

// Exponentiation x^y, where y > 0.
inline cl_MI expt_pos (const cl_MI& x, const cl_I& y)
{
	return x.ring->expt_pos(x, y);
}
// Exponentiation x^y for an arbitrary integer exponent y, delegated to x's ring.
inline cl_MI expt (const cl_MI& x, const cl_I& y)
{
	return x.ring->expt(x, y);
}

// Output: print x on stream (defined out of line).
extern void fprint (cl_ostream stream, const cl_MI &x);
// Stream insertion. Forwards to fprint() and hands the stream back so that
// insertions can be chained.
inline cl_ostream operator<< (cl_ostream os, const cl_MI &x)
{
	fprint(os, x);
	return os;
}

// TODO: implement gcd, index (= gcd), unitp, sqrtp


#endif /* _CL_MODINTEGER_H */
