//
//  regularization_path.hpp
//  pense
//
//  Created by David Kepplinger on 2019-01-30.
//  Copyright © 2016 David Kepplinger. All rights reserved.
//

#ifndef REGULARIZATION_PATH_HPP_
#define REGULARIZATION_PATH_HPP_

#include <memory>
#include <tuple>
#include <type_traits>
#include <cstdint>

#include "nsoptim.hpp"

#include "alias.hpp"
#include "m_loss.hpp"
#include "omp_utils.hpp"

namespace pense {
namespace regpath {

//! Test two coefficients for approximate equivalence.
//!
//! Two coefficients are equivalent if the squared difference of the intercepts plus the squared
//! Euclidean norm of the difference of the slopes is less than `eps`.
//! The function assumes that the coefficient vectors are of the same dimension.
//!
//! @param a coefficients to be compared
//! @param b coefficients to be compared
//! @param eps numerical tolerance for the squared distance between the coefficients
//! @return true if the two coefficient vectors are approximately equivalent, false otherwise.
template<class Coefficients>
bool CoefficientsEquivalent(const Coefficients& a, const Coefficients& b,
                            const double eps) noexcept {
  const double int_diff = a.intercept - b.intercept;
  const double int_diff_sq = int_diff * int_diff;

  // Cheap pre-check on the intercept alone. The total squared distance can not be below `eps`
  // if the intercept's contribution is not, so the norm of the slope difference is only
  // computed if necessary. The bound does not depend on the number of slope coefficients and
  // hence also works for coefficients without any slope.
  if (!(int_diff_sq < eps)) {
    return false;
  }

  // The intercept is similar. Check if the slope is also similar.
  const double beta_diff = arma::norm(a.beta - b.beta, 2);
  return int_diff_sq + beta_diff * beta_diff < eps;
}

//! Outcome codes for an attempted insertion into an ordered container.
//! Note that `OrderedTuples` declares its own nested enumeration with the same enumerators.
enum class InsertResult {
  kGood,
  kBad,
  kDuplicate
};

//! Graded result of comparing the objective function value of an element already in a list
//! with the objective function value of a candidate element.
//! The numeric values are ordered, so the enumerators can be compared with `<`, `<=`, etc.
enum class TupleComparison : std::int8_t {
  kLowerObjf = -2,          //!< the existing element's value is lower
  kSlightlyLowerObjf = -1,  //!< the existing element's value is slightly lower
  kEqualObjf = 0,           //!< the values are equal
  kSlightlyHigherObjf = 1,  //!< the existing element's value is slightly higher
  kHigherObjf = 2           //!< the existing element's value is higher
};

//! An ordered list of tuples which rejects duplicates and can be limited in size.
//!
//! The tuples are stored in a singly-linked list. All decisions about the position of a new
//! tuple and about duplicates are delegated to the `Ordering` object, which must provide
//!  * `CompareObjf(element, args...)` returning a `TupleComparison` which grades the objective
//!    function value of the stored `element` relative to the candidate, and
//!  * `Equivalent(element, args...)` returning whether the stored `element` and the candidate
//!    are to be treated as the same.
//!
//! A new tuple is inserted in front of the first stored tuple graded as having a lower
//! objective function value. The front of the list therefore holds the tuple graded worst,
//! and this is the tuple dropped when the list grows beyond `max_size`.
//! A `max_size` of 0 means the size is not limited.
template<class Ordering, typename... Ts>
class OrderedTuples {
public:
  using Element = std::tuple<Ts...>;

  //! Result of `Emplace()`:
  //!  * `kGood` the new tuple was inserted (or replaced an equivalent, slightly worse tuple),
  //!  * `kBad` the list is full and the new tuple is graded worse than the front element,
  //!  * `kDuplicate` an equivalent tuple which is not worse is already in the list.
  enum class InsertResult { kGood, kBad, kDuplicate };

  //! Create a new list of unique elements of unlimited size with a default-constructed ordering.
  explicit OrderedTuples() noexcept
    : max_size_(0), order_(), size_(0), elements_() {}

  //! Create a new list of unique items of unlimited size using the given ordering.
  //!
  //! @param order ordering of the elements.
  explicit OrderedTuples(const Ordering& order) noexcept
    : max_size_(0), order_(order), size_(0), elements_() {}

  //! Create a new list of unique items of limited size with a default-constructed ordering.
  //!
  //! @param max_size maximum number of elements retained. If 0, the size is not limited.
  explicit OrderedTuples(const size_t max_size) noexcept
    : max_size_(max_size), order_(), size_(0), elements_() {}

  //! Create a new list of unique items of limited size.
  //!
  //! @param max_size maximum number of elements retained. If 0, the size is not limited.
  //! @param order instance of the ordering class.
  OrderedTuples(const size_t max_size, const Ordering& order) noexcept
    : max_size_(max_size), order_(order), size_(0), elements_() {}

  //! Default copy constructor. Copy assignment is not possible.
  OrderedTuples(const OrderedTuples&) = default;
  OrderedTuples& operator=(const OrderedTuples&) = delete;

  //! Move constructor. The moved-from list reports a size of 0 afterwards.
  OrderedTuples(OrderedTuples&& other) noexcept :
    max_size_(other.max_size_), order_(std::move(other.order_)),
    size_(other.size_), elements_(std::move(other.elements_)) {
    other.size_ = 0;
  }
  //! Move assignment is not possible.
  OrderedTuples& operator=(OrderedTuples&&) = delete;

  //! Get the number of elements in the list.
  size_t Size() const noexcept {
    return size_;
  }

  //! Clear all elements from the list.
  void Clear() noexcept {
    elements_.clear();
    size_ = 0;
  }

  //! Insert a new tuple into the list, unless it is a duplicate or not good enough.
  //!
  //! The arguments are rvalue references to exactly the types of the tuple's members
  //! (`Ts` are the template parameters of the class, not deduced forwarding references).
  //! The arguments are only moved from if the tuple is actually stored in the list.
  //!
  //! @param args the members of the new tuple.
  //! @return see `InsertResult`.
  InsertResult Emplace(Ts&&... args) {
    // Check if the optimum's objective function value is good enough.
    // If the list is full, the candidate must not be worse than the front (i.e., worst) element.
    if (max_size_ > 0 && size_ >= max_size_) {
      auto first_compared_to_new = order_.CompareObjf(elements_.front(), std::forward<Ts>(args)...);
      if (first_compared_to_new < TupleComparison::kEqualObjf) {
        return InsertResult::kBad;
      }
    }

    // Determine insert position.
    // `insert_it` always trails `current_it` by one element, because a forward list can only
    // insert *after* a given position.
    const auto elements_end = elements_.end();
    auto current_it = elements_.begin();
    auto insert_it = elements_.before_begin();

    while (current_it != elements_end) {
      auto objf_comparison = order_.CompareObjf(*current_it, std::forward<Ts>(args)...);

      if (objf_comparison < TupleComparison::kSlightlyLowerObjf) {
        // current element has much better objective function value. Insert here.
        break;
      } else if (objf_comparison <= TupleComparison::kSlightlyHigherObjf) {
        // current element has similar objective function as the new element
        const bool equivalent = order_.Equivalent(*current_it, std::forward<Ts>(args)...);
        if (equivalent) {
          if (objf_comparison > TupleComparison::kEqualObjf) {
            // current element is equivalent, but the new element is slightly better.
            // Replace: insert the new element directly in front of the current element and
            // then erase the element following the new one, i.e., the current element.
            // The number of elements does not change.
            auto updated_it = elements_.emplace_after(insert_it, std::forward<Ts>(args)...);
            elements_.erase_after(updated_it);
            return InsertResult::kGood;
          } else {
            return InsertResult::kDuplicate;
          }
        } else if (objf_comparison < TupleComparison::kEqualObjf) {
          // current element is not equivalent and has slightly better objective function.
          // Insert here. Elements after this position are not checked for equivalence.
          break;
        }
        // else current element is not equivalent and slightly worse than the new one.
      }
      // else current element is worse. Continue search.

      ++current_it;
      ++insert_it;
    }

    elements_.emplace_after(insert_it, std::forward<Ts>(args)...);
    // Ensure that the size stays within the limits by dropping the front (i.e., worst) element.
    if (++size_ > max_size_ && max_size_ > 0) {
      elements_.erase_after(elements_.before_begin());
      --size_;
    }
    return InsertResult::kGood;
  }

  //! Read-only access to the underlying list of tuples.
  const alias::FwdList<Element>& Elements() const noexcept {
    return elements_;
  }

  //! Mutable access to the underlying list of tuples.
  //! The number of elements reported by `Size()` is tracked separately and is not updated if
  //! elements are added or removed through this reference.
  alias::FwdList<Element>& Elements() noexcept {
    return elements_;
  }

private:
  const size_t max_size_;  //< maximum number of elements, 0 for no limit.
  Ordering order_;  //< grades objective function values and detects equivalent tuples.
  size_t size_;  //< number of elements currently in `elements_`.
  alias::FwdList<Element> elements_;  //< the tuples, with the worst-graded tuple at the front.
};

//! Ordering for lists of plain coefficients, used to detect duplicate coefficients.
//!
//! Coefficients do not have an associated objective function value, hence all coefficients are
//! graded as having an equal objective function value. `OrderedTuples::Emplace()` therefore
//! checks new coefficients for equivalence with *every* element in the list and appends them
//! at the end of the list if no equivalent coefficients are found.
template<class Coefficients>
class DuplicateCoefficients {
public:
  //! @param eps numerical tolerance for comparing coefficients, passed on to
  //!   `CoefficientsEquivalent()`.
  explicit DuplicateCoefficients(const double eps) noexcept : eps_(eps) {}

  //! Always returns `kEqualObjf` since coefficients do not have an associated objective function
  //! value.
  //! Any grade other than `kEqualObjf` would stop the search for duplicates in
  //! `OrderedTuples::Emplace()` at the first element which is not equivalent, and duplicates
  //! further down the list would go undetected.
  template<typename Element, typename... Args>
  TupleComparison CompareObjf(const Element&, const Coefficients&, Args&&...) const noexcept {
    return TupleComparison::kEqualObjf;
  }

  //! Is the existing element equivalent to new element?
  //!
  //! @param el existing element; the coefficients are the first member of the tuple.
  //! @param coefs new coefficients.
  template<typename Element, typename... Args>
  bool Equivalent(const Element& el, const Coefficients& coefs, Args&&...) const noexcept {
    return CoefficientsEquivalent(std::get<0>(el), coefs, eps_);
  }

private:
  const double eps_;
};

//! Ordering for tuples with a known objective function value.
//!
//! Supports tuples of the form `(Optimum, ...)` compared with a new `Optimum`, as well as tuples
//! of the form `(Coefficients, objective value, ...)` compared with new coefficients and their
//! objective function value.
template<class Optimizer>
class OptimaOrder {
  using Coefficients = typename Optimizer::Coefficients;
  using Optimum = typename Optimizer::Optimum;

public:
  //! @param eps relative tolerance for grading objective function values. The same value is
  //!   used as tolerance when testing coefficients for equivalence.
  explicit OptimaOrder(const double eps) noexcept : eps_(eps) {}

  //! Grade the objective function value of the optimum stored in `el` relative to
  //! `opt.objf_value`.
  //!
  //! @param el existing element; the optimum is the first member of the tuple.
  //! @param opt new optimum.
  //! @return `kLowerObjf`/`kSlightlyLowerObjf` if the stored value is lower than
  //!   `opt.objf_value`, `kHigherObjf`/`kSlightlyHigherObjf` if it is higher, and `kEqualObjf`
  //!   otherwise.
  template<typename Element, typename... Args>
  TupleComparison CompareObjf(const Element& el, const Optimum& opt, Args&&...) const noexcept {
    return Grade(std::get<0>(el).objf_value, opt.objf_value);
  }

  //! Grade the objective function value stored in `el` relative to `objf_value`.
  //!
  //! @param el existing element; the objective function value is the second member of the tuple.
  //! @param coefs new coefficients (not used for the comparison).
  //! @param objf_value objective function value of the new coefficients.
  //! @return `kLowerObjf`/`kSlightlyLowerObjf` if the stored value is lower than `objf_value`,
  //!   `kHigherObjf`/`kSlightlyHigherObjf` if it is higher, and `kEqualObjf` otherwise.
  template<typename Element, typename... Args>
  TupleComparison CompareObjf(const Element& el,
                              const Coefficients& coefs,
                              const double objf_value,
                              Args&&...) const noexcept {
    return Grade(std::get<1>(el), objf_value);
  }

  //! Are the coefficients stored in `el` equivalent to the new coefficients?
  //!
  //! @param el existing element; the coefficients are the first member of the tuple.
  //! @param coefs new coefficients.
  template<typename Element, typename... Args>
  bool Equivalent(const Element& el, const Coefficients& coefs, Args&&...) const noexcept {
    return CoefficientsEquivalent(std::get<0>(el), coefs, eps_);
  }

  //! Are the coefficients of the optimum stored in `el` equivalent to those of the new optimum?
  //!
  //! @param el existing element; the optimum is the first member of the tuple.
  //! @param opt new optimum.
  template<typename Element, typename... Args>
  bool Equivalent(const Element& el, const Optimum& opt, Args&&...) const noexcept {
    return CoefficientsEquivalent(std::get<0>(el).coefs, opt.coefs, eps_);
  }

private:
  //! Grade `stored` relative to `candidate`.
  //!
  //! The tests are evaluated in this exact order. The band separating "slightly" from "clearly"
  //! different values is relative to `candidate`, i.e., it is `candidate * (1 -/+ eps_)`.
  //! NOTE(review): this reads as if objective function values are expected to be positive;
  //! for a negative `candidate` (and positive `eps_`) the first test already holds if both
  //! values are equal.
  TupleComparison Grade(const double stored, const double candidate) const noexcept {
    if (stored < candidate * (1 - eps_)) {
      return TupleComparison::kLowerObjf;
    }
    if (stored < candidate) {
      return TupleComparison::kSlightlyLowerObjf;
    }
    if (stored > candidate * (1 + eps_)) {
      return TupleComparison::kHigherObjf;
    }
    if (stored > candidate) {
      return TupleComparison::kSlightlyHigherObjf;
    }
    return TupleComparison::kEqualObjf;
  }

  const double eps_;
};

//! List of unique coefficients. The coefficients are the first member of each tuple, followed
//! by any additional members `Ts`.
template<class Coefficients, typename... Ts>
using UniqueCoefficients = OrderedTuples<DuplicateCoefficients<Coefficients>, Coefficients, Ts...>;

//! List of unique starting points, ordered by the value of the objective function.
//! Each tuple holds the coefficients, the value of the objective function, and an optimizer,
//! followed by any additional members `Ts`.
template<class Optimizer, typename... Ts>
using UniqueStartPoints = OrderedTuples<OptimaOrder<Optimizer>, typename Optimizer::Coefficients,
                                        double, Optimizer, Ts...>;

//! List of unique optima, ordered by the value of the objective function.
//! Each tuple holds the optimum and an optimizer, followed by any additional members `Ts`.
template<class Optimizer, typename... Ts>
using UniqueOptima = OrderedTuples<OptimaOrder<Optimizer>, typename Optimizer::Optimum, Optimizer,
                                   Ts...>;

} // namespace regpath

template<class Optimizer>
class RegularizationPath {
  // Types provided by the optimizer.
  using LossFunction = typename Optimizer::LossFunction;
  using PenaltyFunction = typename Optimizer::PenaltyFunction;
  using Coefficients = typename Optimizer::Coefficients;
  using PenaltyList = alias::FwdList<PenaltyFunction>;
  // Tag type used to select the `Explore()` overload at compile time.
  using IsIterativeAlgorithmTag = typename nsoptim::traits::is_iterative_algorithm<Optimizer>::type;
  using Optimum = typename Optimizer::Optimum;
  // One list of starting points per penalty.
  using IndividualStartingPoints = alias::FwdList<alias::FwdList<Coefficients>>;
  // Lists of starting points (coefficients only).
  using UniqueCoefficients = regpath::UniqueCoefficients<Coefficients>;
  using UniqueCoefficientsOrder = regpath::DuplicateCoefficients<Coefficients>;
  using MetricsPtr = std::unique_ptr<nsoptim::Metrics>;
  // Explored solutions: tuples of (coefficients, objective value, optimizer, exploration metrics).
  using ExploredSolutions = regpath::UniqueStartPoints<Optimizer, MetricsPtr>;
  using ExploredSolutionsOrder = regpath::OptimaOrder<Optimizer>;
  // Best optima at a penalty level: tuples of (optimum, optimizer).
  using BestOptima = regpath::UniqueOptima<Optimizer>;
  using BestOptimaOrder = regpath::OptimaOrder<Optimizer>;

public:
  //! The optima found at a single penalty level, as returned by `Next()`.
  struct Solutions {
    //! The penalty at which the optima were computed. This is a reference to an element of the
    //! list of penalties given to the constructor of the regularization path.
    const PenaltyFunction& penalty;
    //! The optima retained for this penalty.
    alias::Optima<Optimizer> optima;
  };

  //! Create a regularization path using the given optimizer, loss function, and list of
  //! penalties.
  //!
  //! Only a reference to the list of penalties is stored, so the list must outlive the
  //! regularization path.
  //!
  //! @param optimizer the optimizer to use. The loss function has to be set already.
  //! @param penalties a list of penalty functions.
  //! @param max_optima the maximum number of optima per penalty level.
  //! @param comparison_tol numeric tolerance for comparing two optima.
  //! @param num_threads number of threads to use.
  RegularizationPath(const Optimizer& optimizer,
                     const PenaltyList& penalties, const int max_optima,
                     const double comparison_tol, const int num_threads)
      : optimizer_template_(optimizer),
        penalties_(penalties),
        max_optima_(max_optima),
        comparison_tol_(comparison_tol),
        num_threads_(num_threads),
        shared_starts_(UniqueCoefficientsOrder(comparison_tol_)),
        best_starts_(max_optima, BestOptimaOrder(comparison_tol)),
        penalties_it_(penalties_.begin()) {
    // Prepare one (initially empty) list of individual starting points for every penalty.
    for (auto pen_it = penalties_.begin(); pen_it != penalties_.end(); ++pen_it) {
      individual_starts_.emplace_front(
        UniqueCoefficients(UniqueCoefficientsOrder(comparison_tol_)));
    }
    // `Next()` advances this iterator *before* using it, hence start one before the first list.
    individual_starts_it_ = individual_starts_.before_begin();
  }

  //! Configure the exploration of starting points.
  //!
  //! @param explore_it the number of iterations for exploration. If <= 0, no exploration will
  //!   be done and all starting points will be iterated to full convergence.
  //! @param explore_tol the numeric tolerance for exploring solutions.
  //! @param explored_keep how many explored solutions to keep for full concentration.
  //!   If 0, all explored solutions are kept.
  void ExplorationOptions(const int explore_it, const double explore_tol,
                          const int explored_keep) noexcept {
    explored_keep_ = explored_keep;
    explore_tol_ = explore_tol;
    explore_it_ = explore_it;
  }

  //! Enable/disable carrying forward solutions from the previous penalty.
  //!
  //! Even if warm starts are disabled, the solutions from the previous penalty are used as
  //! starting points at penalties for which no other starting points are available.
  //!
  //! @param enabled whether to enable warm starts or not.
  void EnableWarmStarts(const bool enabled) noexcept {
    use_warm_start_ = enabled;
  }

  //! Add starting points which are used only at a specific penalty.
  //!
  //! The first list of starting points is used at the first penalty, the second list at the
  //! second penalty, and so on. If there are more lists than penalties, the surplus lists
  //! are ignored. If there are fewer lists than penalties, the remaining penalties do not get
  //! any additional individual starting points.
  //!
  //! @param coefs_lists one list of starting points per penalty, in the order of the penalties.
  //!   The coefficients are moved out of the lists.
  void EmplaceIndividualStartingPoints(IndividualStartingPoints&& coefs_lists) {
    auto emplace_it = individual_starts_.begin();
    const auto emplace_end = individual_starts_.end();

    for (auto&& coefs_list : coefs_lists) {
      if (emplace_it == emplace_end) {
        // There is no penalty left to associate the remaining lists with.
        break;
      }
      for (auto&& coefs : coefs_list) {
        emplace_it->Emplace(std::move(coefs));
      }
      ++emplace_it;
    }
  }

  //! Add a starting point to be used for all penalties.
  //!
  //! The starting point is handed to the list of shared starting points, which decides
  //! whether it is stored or discarded as a duplicate.
  //!
  //! @param coefs starting point. Moved into the list if stored.
  void EmplaceSharedStartingPoint(Coefficients&& coefs) {
    shared_starts_.Emplace(std::move(coefs));
  }

  //! Compute the optima at the next penalty in the list.
  //!
  //! The starting points are first explored (if the number of exploration iterations is
  //! positive) and the retained candidates are then iterated to full convergence.
  //! Must not be called once `End()` returns true, as there is no check for more penalties.
  //!
  //! @return the penalty and the optima at this penalty.
  Solutions Next() {
    const PenaltyFunction& current_penalty = *penalties_it_;
    ++penalties_it_;
    ++individual_starts_it_;
    optimizer_template_.penalty(current_penalty);

    if (explore_it_ > 0) {
      return Solutions { current_penalty, Concentrate(Explore(IsIterativeAlgorithmTag{})) };
    }
    return Solutions { current_penalty, Concentrate(SkipExploration()) };
  }

  //! Check if the end of the regularization path is reached.
  //!
  //! @return true if all penalties have been processed by `Next()`, false otherwise.
  bool End() const noexcept {
    return penalties_it_ == penalties_.end();
  }

private:
  // Optimizer which is copied for every new starting point. Its penalty is updated by `Next()`.
  Optimizer optimizer_template_;
  // Reference to the caller's list of penalties. The list is not copied.
  const PenaltyList& penalties_;
  // Maximum number of optima per penalty. The limit itself is enforced by `best_starts_`.
  const int max_optima_;
  // Numeric tolerance for comparing coefficients and optima.
  const double comparison_tol_;
  int num_threads_;  //< OpenMP requires it to be an lvalue!
  // Use the optima from the previous penalty as starting points?
  bool use_warm_start_ = true;
  // Number of exploration iterations. Exploration is skipped if <= 0.
  int explore_it_ = 1;
  // Numeric tolerance used while exploring starting points.
  double explore_tol_ = 1e-3;
  // Number of explored solutions to retain. 0 retains all.
  int explored_keep_ = 0;

  // Starting points used only at a specific penalty (one list per penalty).
  alias::FwdList<UniqueCoefficients> individual_starts_;
  // Starting points used at every penalty.
  UniqueCoefficients shared_starts_;
  // The best optima (and their optimizers) from the most recently processed penalty.
  BestOptima best_starts_;

  // Iterators to the individual starting points for the current penalty and to the next penalty.
  typename alias::FwdList<UniqueCoefficients>::iterator individual_starts_it_;
  typename PenaltyList::const_iterator penalties_it_;

  //! Overload selected if the optimizer is not tagged as iterative algorithm: the starting
  //! points can not be explored, so they are all passed on as they are.
  ExploredSolutions Explore(std::false_type) {
    return SkipExploration();
  }

  //! Overload selected if the optimizer is tagged as iterative algorithm: explore all starting
  //! points, using the multi-threaded implementation if OpenMP is enabled for the configured
  //! number of threads and the sequential implementation otherwise.
  ExploredSolutions Explore(std::true_type) {
    return omp::Enabled(num_threads_) ?
      MTExplore(std::true_type{}) :
      MTExplore(std::false_type{});
  }

  //! Explore all starting points for the current penalty using OpenMP tasks.
  //!
  //! Every starting point is iterated for at most `explore_it_` iterations with the relaxed
  //! tolerance `explore_tol_`. The resulting coefficients are stored together with the value
  //! of the objective function, the optimizer (reset to the original tolerance) and the
  //! exploration metrics. Individual and shared starting points are explored with a fresh copy
  //! of the optimizer template, while the optima from the previous penalty are explored with
  //! their own optimizers, which are moved into the returned list.
  //!
  //! @return the explored solutions, limited to `explored_keep_` solutions (if > 0).
  ExploredSolutions MTExplore(std::true_type) {
    const double orig_tol = optimizer_template_.convergence_tolerance();
    ExploredSolutions explored_solutions(explored_keep_, ExploredSolutionsOrder(explore_tol_));

    const auto is_end = individual_starts_it_->Elements().end();
    const auto sh_end = shared_starts_.Elements().end();

    // The optima from the previous penalty are explored if warm starts are enabled or if there
    // are no other starting points. This must be decided *before* any task is created:
    // the tasks created below may not have finished (or may be inserting concurrently) at
    // the time the third `single` region is executed, so the size of `explored_solutions` can
    // not be used for this decision. The first insertion into an empty list always succeeds,
    // hence the list remains empty only if there are no individual and no shared starts.
    const bool explore_previous_optima = use_warm_start_ ||
      (individual_starts_it_->Elements().begin() == is_end &&
       shared_starts_.Elements().begin() == sh_end);

#pragma omp parallel          \
    num_threads(num_threads_) \
      default(shared)
      {
#pragma omp single nowait
        for (auto is_it = individual_starts_it_->Elements().begin(); is_it != is_end; ++is_it) {
#pragma omp task                            \
          default(none)                     \
          firstprivate(is_it)               \
          shared(explore_tol_, explore_it_) \
          shared(explored_solutions, optimizer_template_) const_local_shared(orig_tol)
          {
            Optimizer optimizer(optimizer_template_);
            optimizer.convergence_tolerance(explore_tol_);
            auto optimum = optimizer.Optimize(std::get<0>(*is_it), explore_it_);
            optimizer.convergence_tolerance(orig_tol);

#pragma omp critical(insert_explored)
            explored_solutions.Emplace(std::move(optimum.coefs), std::move(optimum.objf_value),
                                       std::move(optimizer), std::move(optimum.metrics));

          }
        }

#pragma omp single nowait
        for (auto sh_it = shared_starts_.Elements().begin(); sh_it != sh_end; ++sh_it) {
#pragma omp task                            \
          firstprivate(sh_it)               \
          default(none)                     \
          shared(explore_tol_, explore_it_) \
          shared(explored_solutions, optimizer_template_) const_local_shared(orig_tol)
          {
            Optimizer optimizer(optimizer_template_);
            optimizer.convergence_tolerance(explore_tol_);
            auto optimum = optimizer.Optimize(std::get<0>(*sh_it), explore_it_);
            optimizer.convergence_tolerance(orig_tol);

#pragma omp critical(insert_explored)
            explored_solutions.Emplace(std::move(optimum.coefs), std::move(optimum.objf_value),
                                       std::move(optimizer), std::move(optimum.metrics));

          }
        }

#pragma omp single nowait
        if (explore_previous_optima) {
          const auto bs_end = best_starts_.Elements().end();

          for (auto bs_it = best_starts_.Elements().begin(); bs_it != bs_end; ++bs_it) {
#pragma omp task                                                  \
            firstprivate(bs_it)                                   \
            default(none)                                         \
            shared(explore_tol_, explore_it_, explored_solutions) \
            shared(optimizer_template_) const_local_shared(orig_tol, bs_end)
            {
              // Continue from the state of the optimizer used at the previous penalty.
              auto&& optimizer = std::get<1>(*bs_it);
              optimizer.convergence_tolerance(explore_tol_);
              optimizer.penalty(optimizer_template_.penalty());
              auto optimum = optimizer.Optimize(explore_it_);
              optimizer.convergence_tolerance(orig_tol);

#pragma omp critical(insert_explored)
              explored_solutions.Emplace(std::move(optimum.coefs), std::move(optimum.objf_value),
                                         std::move(optimizer), std::move(optimum.metrics));

            }
          }
        }
      }

    // Interrupts are checked only once, after the parallel region.
    Rcpp::checkUserInterrupt();
    return explored_solutions;
  }

  //! Explore all starting points for the current penalty sequentially.
  //!
  //! Every starting point is iterated for at most `explore_it_` iterations with the relaxed
  //! tolerance `explore_tol_`. The resulting coefficients are stored together with the value
  //! of the objective function, the optimizer (reset to the original tolerance) and the
  //! exploration metrics. User interrupts are checked after every starting point.
  //!
  //! @return the explored solutions, limited to `explored_keep_` solutions (if > 0).
  ExploredSolutions MTExplore(std::false_type) {
    const double full_tol = optimizer_template_.convergence_tolerance();
    ExploredSolutions candidates(explored_keep_, ExploredSolutionsOrder(explore_tol_));

    // Explore a single "cold" starting point with a fresh copy of the optimizer template.
    const auto explore_cold_start = [this, full_tol, &candidates](Coefficients& start_coefs) {
      Optimizer optimizer(optimizer_template_);
      optimizer.convergence_tolerance(explore_tol_);
      auto optimum = optimizer.Optimize(start_coefs, explore_it_);
      optimizer.convergence_tolerance(full_tol);
      candidates.Emplace(std::move(optimum.coefs), std::move(optimum.objf_value),
                         std::move(optimizer), std::move(optimum.metrics));

      Rcpp::checkUserInterrupt();
    };

    for (auto& individual_start : individual_starts_it_->Elements()) {
      explore_cold_start(std::get<0>(individual_start));
    }

    for (auto& shared_start : shared_starts_.Elements()) {
      explore_cold_start(std::get<0>(shared_start));
    }

    // Without warm starts, the previous optima are only used if there is nothing else.
    if (!use_warm_start_ && candidates.Size() != 0) {
      return candidates;
    }

    for (auto& previous : best_starts_.Elements()) {
      // Continue from the state of the optimizer used at the previous penalty.
      auto& optimizer = std::get<1>(previous);
      optimizer.convergence_tolerance(explore_tol_);
      optimizer.penalty(optimizer_template_.penalty());
      auto optimum = optimizer.Optimize(explore_it_);
      optimizer.convergence_tolerance(full_tol);
      candidates.Emplace(std::move(optimum.coefs), std::move(optimum.objf_value),
                         std::move(optimizer), std::move(optimum.metrics));

      Rcpp::checkUserInterrupt();
    }
    return candidates;
  }

  //! Pass on all starting points for the current penalty without exploring them.
  //!
  //! Every starting point is stored with an objective function value of -1, which marks it as
  //! "not explored": concentration starts at the stored coefficients if the value is not
  //! positive. Individual starting points and the coefficients of the previous optima are
  //! moved into the returned list, shared starting points are copied.
  //!
  //! @return all starting points for the current penalty.
  ExploredSolutions SkipExploration() {
    ExploredSolutions unexplored(0, ExploredSolutionsOrder(comparison_tol_));

    // Individual starts are used only once and can thus be moved.
    for (auto& individual_start : individual_starts_it_->Elements()) {
      unexplored.Emplace(std::move(std::get<0>(individual_start)), -1,
                         Optimizer(optimizer_template_), MetricsPtr());
    }

    // Shared starts must be copied, not moved! They are needed again at the next penalty.
    for (auto&& shared_start : shared_starts_.Elements()) {
      auto coefs_copy = std::get<0>(shared_start);
      unexplored.Emplace(std::move(coefs_copy), -1, Optimizer(optimizer_template_),
                         MetricsPtr());
    }

    // Without warm starts, the previous optima are only used if there is nothing else.
    if (!use_warm_start_ && unexplored.Size() != 0) {
      return unexplored;
    }

    for (auto& previous : best_starts_.Elements()) {
      auto&& optimizer = std::get<1>(previous);
      optimizer.penalty(optimizer_template_.penalty());
      unexplored.Emplace(std::move(std::get<0>(previous).coefs), -1, std::move(optimizer),
                         MetricsPtr());
    }
    return unexplored;
  }

  //! Iterate the explored solutions to full convergence and retain the best optima.
  //!
  //! The optima from the previous penalty are discarded and replaced by the optima found from
  //! the given candidates. The retained optima and their optimizers remain stored in the
  //! regularization path, and copies of the optima are returned.
  //!
  //! @param explored candidates from the exploration step.
  //! @return copies of the retained optima, in reverse order of the internal list.
  alias::Optima<Optimizer> Concentrate(ExploredSolutions&& explored) {
    best_starts_.Clear();

    const bool multithreaded = omp::Enabled(num_threads_);
    if (multithreaded) {
      Concentrate(std::move(explored), std::true_type{});
    } else {
      Concentrate(std::move(explored), std::false_type{});
    }

    alias::Optima<Optimizer> result;
    for (auto& retained : best_starts_.Elements()) {
      // The optimum is copied, because it is still needed as starting point.
      result.emplace_front(std::get<0>(retained));
    }
    return result;
  }

  //! Sequentially iterate every explored solution to full convergence and insert the
  //! resulting optima into the list of best optima.
  //!
  //! @param explored candidates from the exploration step. The optimizers and the exploration
  //!   metrics are moved out of the candidates.
  void Concentrate(ExploredSolutions&& explored, std::false_type) {
    for (auto& candidate : explored.Elements()) {
      auto& start_coefs = std::get<0>(candidate);
      auto& optimizer = std::get<2>(candidate);
      auto& exploration_metrics_ptr = std::get<3>(candidate);

      // A positive objective function value means the candidate was explored, and the
      // optimizer continues from its current state. Otherwise start at the coefficients.
      const bool was_explored = std::get<1>(candidate) > 0;
      auto optim = was_explored ? optimizer.Optimize() : optimizer.Optimize(start_coefs);

      // Attach the metrics from the exploration step to the metrics of the optimum.
      if (optim.metrics && exploration_metrics_ptr) {
        auto&& sub_metrics = optim.metrics->CreateSubMetrics("exploration");
        sub_metrics.AddSubMetrics(std::move(*exploration_metrics_ptr));
        exploration_metrics_ptr.reset();
      }
      best_starts_.Emplace(std::move(optim), std::move(optimizer));

      Rcpp::checkUserInterrupt();
    }
  }

  //! Iterate every explored solution to full convergence using OpenMP tasks and insert the
  //! resulting optima into the list of best optima.
  //!
  //! One task is created per explored solution. Insertions into the list of best optima are
  //! guarded by a named critical section. User interrupts are checked only once, after the
  //! parallel region.
  //!
  //! @param explored candidates from the exploration step. The optimizers and the exploration
  //!   metrics are moved out of the candidates.
  void Concentrate(ExploredSolutions&& explored, std::true_type) {
    const auto ex_end = explored.Elements().end();

#pragma omp parallel          \
    num_threads(num_threads_) \
      default(shared)
      {
#pragma omp single nowait
        for (auto ex_it = explored.Elements().begin(); ex_it != ex_end; ++ex_it) {
#pragma omp task              \
          default(none)       \
          firstprivate(ex_it) \
          shared(best_starts_)
          {
            auto&& optimizer = std::get<2>(*ex_it);
            // A positive objective function value means the candidate was explored, and the
            // optimizer continues from its current state. Otherwise start at the coefficients.
            auto optim = (std::get<1>(*ex_it) > 0) ?
            optimizer.Optimize() :
              optimizer.Optimize(std::get<0>(*ex_it));

            // Attach the metrics from the exploration step to the metrics of the optimum.
            if (optim.metrics && std::get<3>(*ex_it)) {
              auto&& exploration_metrics = optim.metrics->CreateSubMetrics("exploration");
              exploration_metrics.AddSubMetrics(std::move(*std::get<3>(*ex_it)));
              std::get<3>(*ex_it).reset();
            }
#pragma omp critical(insert_concentrated)
            best_starts_.Emplace(std::move(optim), std::move(optimizer));
          }
        }
      }
    Rcpp::checkUserInterrupt();
  }
};
} // namespace pense

#endif // REGULARIZATION_PATH_HPP_
