#ifndef LLOYDPROBMEDIAN_HPP
#define	LLOYDPROBMEDIAN_HPP

#include <algorithm>
#include <cstddef>
#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "Metric.hpp"
#include "WeightedPoint.hpp"
#include "AdaptiveSampling.hpp"
#include "Weiszfeld.hpp"
#include "CenterOfGravity.hpp"
#include "KumarMedian.hpp"
#include "ProbabilisticPoint.hpp"

/**
 * @brief Lloyd-style heuristic for the probabilistic k-median problem.
 *
 * The initial centers are seeded with k-means++-like adaptive sampling over
 * all weighted realizations of the input; afterwards computeCenterSet
 * alternates assignment and center-update steps.
 */
class LloydProbMedian
{
public:
    /**
     * @param createMetric Factory for a metric on Point (presumably used to
     *        initialize the metric member - see the constructor definition).
     * @param createNorm Factory for a norm on Point (presumably forwarded to
     *        the median solvers - see the constructor definition).
     */
    LloydProbMedian(std::function<Metric<Point>*()> createMetric,
                    std::function<Norm<Point>*()> createNorm);

    /**
     * Computes a center set and writes it to an output iterator.
     * @param begin Input set of probabilistic points: begin
     * @param end Input set of probabilistic points: end
     * @param output Receives the final centers, one per center produced by
     *        the initial sampling
     * @param k Number of centers
     * @param maxIterations Upper bound on the number of Lloyd iterations
     * @param n Size of the input set (optional; pass 0 to have it determined
     *        from the range)
     */
    template<typename ForwardIterator, typename OutputIterator>
    void computeCenterSet(ForwardIterator begin, ForwardIterator end,
                          OutputIterator output, size_t k,
                          size_t maxIterations, size_t n = 0);

    virtual ~LloydProbMedian();

private:
    // Seeding of the initial centers.
    AdaptiveSampling sampling;
    // Primary 1-median solver for the center-update step.
    Weiszfeld weiszfeld;
    // Center-update step when compiled with KMEANS.
    CenterOfGravity centerOfGravity;
    // Fallback 1-median solver used when Weiszfeld throws IterationFailed.
    KumarMedian kumar;
    // Distance used in the assignment step. NOTE(review): raw pointer whose
    // ownership is decided in the out-of-line constructor/destructor; if it
    // is owned, implicit copies of this class would share it - confirm.
    Metric<Point>* metric;

    // Iteration limit handed to the median solvers.
    int weiszfeldMaxIt;
    // NOTE(review): not referenced by computeCenterSet in this header.
    int kumarMedianIterations;
};

template<typename ForwardIterator, typename OutputIterator>
void LloydProbMedian::computeCenterSet(ForwardIterator begin, ForwardIterator end, OutputIterator output, size_t k, size_t maxIterations, size_t n)
{
    //std::cout << "LloydProbMedian" << std::endl;
    if (n == 0)
        for (ForwardIterator it = begin; it != end; ++it)
            ++n;
    int dimension = (*begin)[0].getDimension();

    //std::cout << "Bastele Realisationsvektor (" << n << ")" << std::endl;
    std::vector<WeightedPoint> allRealizations;
    int z = 0;
    for (auto it = begin; it != end; ++it)
    {
        ++z;
        //std::cout << z << ", ";
        ProbabilisticPoint const & pp = *it;
        for (auto wpit = pp.cbegin(); wpit != pp.cend(); ++wpit)
        {
            WeightedPoint wp = *wpit;
            wp.setWeight(wp.getWeight() * pp.getWeight());
            allRealizations.push_back(wp);
        }
    }
    int m = allRealizations.size();
    std::unique_ptr < std::vector < Point >> initialCenters = sampling.computeCenterSet(allRealizations.begin(), allRealizations.end(), k, m);

    std::vector<Point> centers(*initialCenters);
    std::vector<size_t> centerAssignmentIndices(n, 0);
    bool assignmentChanged = true;
    for (size_t i = 0; i < maxIterations && assignmentChanged; ++i)
    {
        //std::cout << "Iteration " << i << std::endl;
        assignmentChanged = false;
        std::vector < std::vector < WeightedPoint >> centerAssignments(k, std::vector<WeightedPoint>());
        size_t p = 0;
        for (ForwardIterator it = begin; it != end; ++it)
        {
            ProbabilisticPoint const & pp = *it;
            double minDist = std::numeric_limits<double>::infinity();
            size_t minCenter = 0;
            for (size_t c = 0; c < centers.size(); ++c)
            {
                double tmpDist = 0;
                for (auto wpit = pp.cbegin(); wpit != pp.cend(); ++wpit)
                {
                    WeightedPoint const & wp = *wpit;
                    tmpDist += wp.getWeight() * metric->distance(centers[c], *wpit);
                }
                if (tmpDist < minDist)
                {
                    minDist = tmpDist;
                    minCenter = c;
                }
            }
            if (centerAssignmentIndices[p] != minCenter)
            {
                centerAssignmentIndices[p] = minCenter;
                assignmentChanged = true;
            }
            for (auto wpit = pp.cbegin(); wpit != pp.cend(); ++wpit)
                centerAssignments[minCenter].push_back(*wpit);
            ++p;
        }

        for (size_t c = 0; c < centers.size(); ++c)
        {
            if (centerAssignments[c].size() > 0)
#ifdef KMEANS
                centers[c] = centerOfGravity.cog(centerAssignments[c].begin(), centerAssignments[c].end());
#else
                try
                {
                    centers[c] = weiszfeld.approximateOneMedian(centerAssignments[c].begin(), centerAssignments[c].end(), weiszfeldMaxIt);
                }
                catch (Weiszfeld::IterationFailed err)
                {
                    centers[c] = kumar.approximateOneMedianRounds(centerAssignments[c].begin(), centerAssignments[c].end(), 0.9999999999, weiszfeldMaxIt);
                }
#endif
else
                centers[c] = Point(dimension);
        }
    }

    for (size_t i = 0; i < centers.size(); ++i)
    {
        *output = centers[i];
        ++output;
    }
}

#endif	/* LLOYDPROBMEDIAN_HPP */

