#include <cstddef>
#include <iostream>
#include <vector>

#include "ibex.h"

using namespace ibex;

// Helper that builds the negation -f of an Ibex function f (used below to obtain the backward dynamics).
Function operator-(Function& f) {
  // Builds a new Function g with g(x) = -f(x).
  // g is defined over fresh copies of f's argument symbols (varcopy), so it
  // does not share any symbol with f.
  //
  //  - non-scalar constant expressions: the constant value is negated;
  //  - ExprVector expressions: each component is negated, keeping the
  //    row/column orientation of the original vector;
  //  - anything else (scalar constants, non-ExprVector expressions): the
  //    whole copied expression is wrapped in a unary minus.
  Array<const ExprSymbol> args(f.nb_arg());
  varcopy(f.args(), args);

  const ExprConstant* cst = dynamic_cast<const ExprConstant*>(&f.expr());
  if (cst && !cst->dim.is_scalar()) {
    if (cst->dim.is_vector()) return Function(args, -cst->get_vector_value());
    else return Function(args, -cst->get_matrix_value());
  }

  const ExprVector* vec = dynamic_cast<const ExprVector*>(&f.expr());
  if (!vec) {
    // Generic case. An assert alone is not enough here: with NDEBUG a null
    // `vec` would be dereferenced below.
    return Function(args, -ExprCopy().copy(f.args(), args, f.expr()));
  }

  // Component-wise negation: copy each component onto the fresh symbols,
  // then negate it.
  Array<const ExprNode> minus_vec(vec->nb_args);
  for (int i = 0; i < vec->nb_args; i++) {
    minus_vec.set_ref(i, -ExprCopy().copy(f.args(), args, vec->arg(i)));
  }

  return Function(args, ExprVector::new_(minus_vec, vec->row_vector()));
}

// IVP Forward contractor
class CtcIVPFwd : public Ctc {
public:
  CtcIVPFwd(const Function& f, double t0, double t1) : Ctc(f.nb_var()), f(f), t0(t0), x0(f.nb_var()), t1(t1) {
  }

  void set_x0(const IntervalVector& x0) {
    this->x0 = x0;
  }

  void contract(IntervalVector& x1) {
    int n = f.nb_var();
    ivp_ode ode = ivp_ode(f, t0, x0);
    simulation simu = simulation(&ode, t1, RK4, 1e-9);
    simu.run_simulation();
    x1 &= simu.get_last();
  }

  Function f; // TODO: avoid copy?
  IntervalVector x0;
  double t0;
  double t1;

};

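// IVP forward-backward contractor.
// The backward contractor is an IVP forward contractor with dynamics -f; integrating -f
// over the same duration reverses time only when the dynamics are autonomous (NOTE(review): confirm for non-autonomous f).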
class CtcIVP : public Ctc {
public:
  CtcIVP(Function& f, double t0, double t1) : Ctc(f.nb_var()*2), fwd(f,t0,t1), bwd(-f,t0,t1) {

  }

  void contract(IntervalVector& x0x1) {
    int n = x0x1.size() / 2;
    IntervalVector x0 = x0x1.subvector(0,n-1);
    IntervalVector x1(n);
    IntervalVector x1_save = x0x1.subvector(n,2*n-1);

    fwd.set_x0(x0);
    fwd.contract(x1);

    if (!x1_save.is_superset(x1)) {
      x1 &= x1_save;
      if (!x1.is_empty()) {
	bwd.set_x0(x1);
	bwd.contract(x0);
      }
    }

    if (x0.is_empty() || x1.is_empty()) {
      x0x1.set_empty();
      throw EmptyBoxException();
    } else {
      x0x1.put(0,x0);
      x0x1.put(n,x1);
    }
  }

  CtcIVPFwd fwd;
  CtcIVPFwd bwd;
};

int main(){
  // IVP dimension
  const int n = 2;

  // Some uncertainty
  Interval eps = 1000.0 * Interval(-1,1);

  // Initial condition at time 0.0
  IntervalVector yinit(n);
  yinit[0] = 10.0 + eps;
  yinit[1] = 0.0 + eps;

  // Final condition a time 2.0
  IntervalVector yfinal(n);
  yfinal[0] = -9.62 + eps;
  yfinal[1] = -19.62;


  // Numerical contractor
  NumConstraint dist ("x0[2]","x1[2]","x0(1)+x1(1)=0");
  CtcFwdBwd c2 (dist);

  // IVP contractor
  Variable y(n);
  Function ydot = Function(y,Return(y[1], -Interval(9.81)));
  CtcIVP c1 (ydot,0.0,2.0);
  IntervalVector box(4);
  box.put(0,yinit);
  box.put(n,yfinal);

  // Classical composition of contractors
  CtcCompo compo(c1, c2);
  CtcFixPoint fix(compo);

  // Bisection strategy
  LargestFirst lf(0.0001);
  CellStack _stack;

  // Default solver
  Solver solver(fix, lf, _stack);

  vector<IntervalVector> sols = solver.solve(box);

  if (sols.empty()) {
    cout << "no solution" << endl;
  }
  for (int i = 0; i < sols.size(); i++) {
    cout << "sols n°" << i << "= " << sols[i] << endl;
  }

  return 0;
}