/*Copyright 2014 Francisco Alvaro

This file is part of SESHAT.

  SESHAT is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  SESHAT is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with SESHAT.  If not, see <http://www.gnu.org/licenses/>.


This file is a modification of the RNNLIB original software covered by
the following copyright and permission notice:

*/
/*Copyright 2009,2010 Alex Graves

This file is part of RNNLIB.

RNNLIB is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

RNNLIB is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with RNNLIB.  If not, see <http://www.gnu.org/licenses/>.*/

#ifndef _INCLUDED_Rprop_h
#define _INCLUDED_Rprop_h

#include <algorithm>
#include <string>
#include <vector>

#include "DataExporter.hpp"
#include "Helpers.hpp"
#include "Optimiser.hpp"

using namespace std;

struct Rprop: public DataExporter, public Optimiser {
  // data
  ostream& out;
  vector<real_t> deltas;
  vector<real_t> prevDerivs;
  real_t etaChange;
  real_t etaMin;
  real_t etaPlus;
  real_t minDelta;
  real_t maxDelta;
  real_t initDelta;
  real_t prevAvgDelta;
  bool online;
  WeightContainer *wc;

  // functions
  Rprop(
      const string& name, ostream& o, vector<real_t>& weights,
      vector<real_t>& derivatives, WeightContainer *weight, DataExportHandler *deh, bool on = false):
    DataExporter(name, deh), Optimiser(weights, derivatives), out(o),
      etaChange(0.01), etaMin(0.5), etaPlus(1.2), minDelta(1e-9), maxDelta(0.2),
    initDelta(0.01), prevAvgDelta(0), online(on), wc(weight) {
    if (online) {
      SAVE(prevAvgDelta);
      SAVE(etaPlus);
    }
    build();
  }

  void update_weights() {
    assert(wts.size() == derivs.size());
    assert(wts.size() == deltas.size());
    assert(wts.size() == prevDerivs.size());
    LOOP(int i, indices(wts)) {
      real_t deriv = derivs[i];
      real_t delta = deltas[i];
      real_t derivTimesPrev =  deriv * prevDerivs[i];
      if (derivTimesPrev > 0) {
        deltas[i] = bound(delta * etaPlus, minDelta, maxDelta);
        wts[i] -= sign(deriv) * delta;
        prevDerivs[i] = deriv;
      } else if (derivTimesPrev < 0) {
        deltas[i] = bound(delta * etaMin, minDelta, maxDelta);
        prevDerivs[i] = 0;
      } else {
        wts[i] -= sign(deriv) * delta;
        prevDerivs[i] = deriv;
      }
    }
    // use eta adaptations for online training (from Mike Schuster's thesis)
    if (online) {
      real_t avgDelta = mean(deltas);
      if (avgDelta > prevAvgDelta) {
        etaPlus = max((real_t)1.0, etaPlus - etaChange);
      } else {
        etaPlus += etaChange;
      }
      prevAvgDelta = avgDelta;
    }
    if (verbose) {
      PRINT(minmax(wts), out);
      PRINT(minmax(derivs), out);
      PRINT(minmax(deltas), out);
      PRINT(minmax(prevDerivs), out);
    }
  }

  // NOTE must be called after any change to weightContainer
  void build() {
    if (deltas.size() != wts.size()) {
      deltas.resize(wts.size());
      prevDerivs.resize(wts.size());
      fill(deltas, initDelta);
      fill(prevDerivs, 0);
      wc->save_by_conns(
          deltas, ((name == "optimiser") ? "" : name + "_") + "deltas");
      wc->save_by_conns(
          prevDerivs, ((name == "optimiser") ? "" : name + "_") + "prevDerivs");
    }
  }

  void print(ostream& out = cout) const {
    out << "RPROP" << endl;
    PRINT(online, out);
    if (online) {
      PRINT(prevAvgDelta, out);
      PRINT(etaChange, out);
    }
    PRINT(etaMin, out);
    PRINT(etaPlus, out);
    PRINT(minDelta, out);
    PRINT(maxDelta, out);
    PRINT(initDelta, out);
  }
};

#endif
