#!/usr/bin/env python
# Copyright (c) 2011-2012, Universite de Versailles St-Quentin-en-Yvelines
#
# This file is part of ASK.  ASK is free software: you can redistribute
# it and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation, version 2.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

"""
This module calls the hierarchical sampling module to find interesting
points to label.
"""

import os
import json
import math
import numpy as np
import pprint
from random import randint, choice, random

from common.configuration import Configuration
from common.util import fatal, get_cov_ub
from common.tree import node_iterator, leaf_iterator
from common.fit import residues
from common.constraints import get_tagged_models
from scipy import stats
import common.cart

# Confidence level read by the leaferror_* estimators below.  This is only the
# initial value: compute_error_pernode() overwrites it from the configuration.
confi = 0.9


def leaferror_hoeffding(N, threshold):
    """
    Estimate the error of node N as the fraction of its points that lie
    at least `threshold` away from the node model, plus a Hoeffding-style
    correction term that shrinks as the node holds more points.

    N.model[0] is used as the node prediction and the last entry of each
    item of N.data as the observed value.  The module-level `confi` sets
    the confidence level of the correction.

    Returns a dict with the keys:
      "error"      -- raw frequency plus correction, capped at 1.0
      "confidence" -- the correction term alone
      "raw_error"  -- the uncorrected frequency
    """
    prediction = N.model[0]
    deviations = [abs(point[-1] - prediction) for point in N.data]
    card = len(deviations)
    misses = sum(1 for deviation in deviations if deviation >= threshold)

    raw_frequency = float(misses) / float(card)
    assert(0 <= raw_frequency <= 1)

    # Scale factor of the bound; it is fixed to 1 here.
    scale = 1
    log_term = math.log(1.0 / (1 - confi))
    correction = math.sqrt(scale * scale * log_term / (3 * card))

    return {"error": min(raw_frequency + correction, 1.0),
            "confidence": correction,
            "raw_error": raw_frequency}


def leaferror_variance(N, threshold):
    """
    Estimate the error of node N as an upper bound on the variance of
    its points around the node model.

    N.model[0] is used as the node prediction and the last entry of each
    item of N.data as the observed value.  `threshold` is not used; it is
    accepted so that every leaferror_* function can be called the same
    way.  The module-level `confi` sets the confidence level.

    Returns a dict with the keys:
      "error"      -- upper bound on the variance
      "confidence" -- distance between the bound and the raw estimate
      "raw_error"  -- sample variance (divided by card - 1)
    """
    prediction = N.model[0]
    squared_residues = [(point[-1] - prediction) ** 2 for point in N.data]
    card = len(squared_residues)
    dof = card - 1

    sample_variance = sum(squared_residues) / float(dof)

    # The quantile is taken from the chi distribution and then squared,
    # which yields the chi-square quantile for the same probability.
    inflation = dof / stats.chi.ppf((1 - confi) / 2, dof) ** 2
    upper_bound = sample_variance * inflation

    return {"error": upper_bound,
            "confidence": upper_bound - sample_variance,
            "raw_error": sample_variance}

def leaferror_cov(N, threshold):
    """
    Estimate the error of node N from the coefficient of variation
    (standard deviation divided by |model|) of its points.

    N.model[0] is used as the node prediction and the last entry of each
    item of N.data as the observed value.  `threshold` is not used; it is
    accepted so that every leaferror_* function can be called the same
    way.  The upper bound on the coefficient of variation comes from
    get_cov_ub(); when that bound is infinite, twice the raw coefficient
    is used instead.

    Returns a dict with the keys:
      "error"      -- squared upper bound of the coefficient of variation
      "confidence" -- distance between "error" and "raw_error"
      "raw_error"  -- squared raw coefficient of variation
    """
    prediction = N.model[0]
    squared_residues = [(point[-1] - prediction) ** 2 for point in N.data]
    card = len(squared_residues)

    sample_sd = math.sqrt(sum(squared_residues) / float(card - 1))
    magnitude = abs(prediction)
    raw_cov = sample_sd / magnitude

    upper_cov = get_cov_ub(magnitude, sample_sd, card, confi)
    if upper_cov == float("inf"):
        upper_cov = 2 * raw_cov

    upper_squared = upper_cov ** 2
    raw_squared = raw_cov ** 2
    return {"error": upper_squared,
            "confidence": upper_squared - raw_squared,
            "raw_error": raw_squared}

def compute_error_pernode(data, T, configuration):
    global confi
    confi = configuration("modules.samples.params.confidence", float, 0.9)

    # fill the tree nodes
    for d in data:
        T.fill(d)

    use_hoeffding = configuration("modules.sampler.params.use_hoeffding",
                              bool,
                              False)
    use_cov = configuration("modules.sampler.params.use_cov",
                              bool,
                              False)

    if use_hoeffding:
        print "Using Hoeffding correction"
        # compute threshold
        threshold = np.median(map(abs, residues(data, T)))
        method = leaferror_hoeffding
    elif use_cov:
        print "Using COV correction"
        threshold = None
        method = leaferror_cov
    else:
        print "Using Variance Chi^2 correction"
        threshold = None
        method = leaferror_variance

    # compute the error on each node
    for N in node_iterator(T):
        N.error = method(N, threshold)


def compute_size(constraint):
    """
    Return the size of the region described by `constraint`.

    The size is the product, over every value of the mapping, of
      - max - min   when the value is a dict with "min" and "max" keys
      - len(value)  otherwise (e.g. a list of allowed values)

    An empty mapping has size 1.  The result is asserted non-negative.
    """
    volume = 1

    for bounds in constraint.itervalues():
        if isinstance(bounds, dict):
            extent = bounds["max"] - bounds["min"]
        else:
            extent = len(bounds)
        volume *= extent

    assert(volume >= 0)
    return volume


def posargmin(x):
    """
    Return the index of the first occurrence of the smallest strictly
    positive element of the list x.

    Raises ValueError when x holds no strictly positive element.
    """
    smallest_positive = min(value for value in x if value > 0)
    return x.index(smallest_positive)


def find_sampling_distribution(configuration, T, samplesize):
    factors = configuration("factors")
    ponderate = configuration("modules.sampler.params.ponderate_by_size",
                              bool,
                              True)

    cons_raw = get_tagged_models(factors,
                                 T)

    if ponderate:
        print "Ponderating by size"
        prob = [c["error"]["error"] * compute_size(c["constraint"])
                for c in cons_raw]
    else:
        print "Not ponderating by size"
        prob = [c["error"]["error"] for c in cons_raw]

    totalp = sum(prob)
    proportions = [p * samplesize / totalp for p in prob]
    rounded = map(int, map(math.ceil, proportions))

    for k, p in zip(cons_raw, proportions):
        k["raw_proportion"] = p

    # remove rounding error
    while (sum(rounded) - samplesize):
        p = posargmin(proportions)
        rounded[p] -= 1
        proportions[p] -= 1

    proportions = rounded

    for k, p in zip(cons_raw, proportions):
        k["sampling_proportion"] = p


    # Print also in the logs
    pprint.pprint(cons_raw)

    # Get name of region file
    region_file = os.path.join(configuration["output_directory"],
            "hierarchical.regions")

    # Load previous region dump if it exists
    if os.path.exists(region_file):
        f = file(region_file)
        old_dump = json.load(f)
        f.close()
    else:
        old_dump = []

    # Update region dump with current iteration tree
    # and write to disk
    f = file(region_file, "w")
    cons_dump = []
    for c in cons_raw:
        d = c
        del(d['leaf'])
        cons_dump.append(d)
    old_dump.append(cons_dump)
    json.dump(old_dump, f, indent=4)
    f.close()

    # ensure we satisfy exactly the samplesize
    assert(sum(proportions) == samplesize)
    # ensure there are no negative proportions
    assert(not [p for p in proportions if p < 0])
    return zip(proportions, [c["constraint"] for c in cons_raw])


def draw_point(factors, constraints):
    """
    Draw one random point satisfying `constraints`.

    factors     -- sequence of factor descriptions (dicts with "name" and
                   "type"; categorical factors also carry "values")
    constraints -- mapping from factor name to its constraint: a dict
                   with "min" and "max" for integer and float factors, a
                   sequence of allowed values for categorical factors

    Returns a tuple with one coordinate per factor, in factor order:
      integer     -- uniform integer between min and max (both included)
      float       -- uniform float scaled to [min, max)
      categorical -- index, in the factor "values", of a random allowed
                     value
    Any other type is reported through fatal().
    """
    coordinates = []
    for factor in factors:
        bounds = constraints[factor["name"]]
        kind = factor["type"]
        if kind == "integer":
            low, high = int(bounds["min"]), int(bounds["max"])
            coordinates.append(randint(low, high))
        elif kind == "float":
            coordinates.append(
                random() * (bounds["max"] - bounds["min"]) + bounds["min"])
        elif kind == "categorical":
            picked = choice(bounds)
            coordinates.append(factor["values"].index(picked))
        else:
            fatal("Wrong factor type: <{0}>"
                  .format(kind))
    return tuple(coordinates)


def generate_possible_points(factors, constraints, labelled, n):
    """
    Generator yielding n random points that satisfy `constraints`.

    Each point comes from draw_point().  A point found in `labelled` is
    discarded and drawn again, but after MAX_TRIES discards in a row the
    next draw is yielded whether or not it is already labelled, so the
    generator always produces n points.
    """
    MAX_TRIES = 50
    produced = 0
    rejected = 0
    while produced < n:
        candidate = draw_point(factors, constraints)
        is_duplicate = candidate in labelled
        if is_duplicate and rejected < MAX_TRIES:
            rejected += 1
            continue
        # accepted: start counting rejections afresh for the next point
        rejected = 0
        produced += 1
        yield candidate


def predict(configuration, input_file, output_file, n, cp):
    print "using cut cp = {0}".format(cp)
    # build partition tree
    T = common.cart.build_tree(configuration, cp, input_file)

    # read the labelled points
    labelled = np.genfromtxt(input_file)

    # when there is only one observation we have to
    # reshape the data to say it is multidimensionnal
    if len(labelled.shape) == 1:
        labelled = labelled.reshape(1, labelled.shape[0])

    # compute the leaf error
    compute_error_pernode(labelled, T, configuration)

    # compute sampling distribution
    distribution = find_sampling_distribution(configuration, T, n)

    # open a file to write suggested points
    of = open(output_file, "w")

    # take all the labelled points and create a set
    already_labelled = set()
    all_labelled = list(labelled)
    for v in labelled[:,:-1]:
        already_labelled.add(tuple(v))

    for size, cons in distribution:
        for p in generate_possible_points(configuration["factors"],
                                          cons,
                                          already_labelled,
                                          size):
            already_labelled.add(p)
            all_labelled.append(p)
            of.write(" ".join(map(str, p)) + "\n")
    of.close()

    compute_weights = configuration("modules.sampler.params.use_weights",
                                    bool,
                                    True)


    if compute_weights:
        # all_labelled contains now all the labelled points, plus the new
        # points chosen in the current iteration. We shall compute the weights
        # for the model.

        # First count for each leaf the number of points
        # after the coming sample
        total_points = len(all_labelled)
        total_region_size = 0.0

        nodes = []
        region_sizes = []

        for l in leaf_iterator(T):
            total_region_size += compute_size(l.constraint.get_range())

        for p in all_labelled:
            n = T.whichnode(p)
            nodes.append(n)
            n.future_points += 1
            region_size = compute_size(n.constraint.get_range())
            region_sizes.append(region_size)


        # Now compute the weights
        weights = []
        for i,p in enumerate(all_labelled):
            region_size = region_sizes[i]
            future_points = nodes[i].future_points
            ratio_region_size = float(region_size) / float(total_region_size)
            ratio_region_points = float(future_points) / float(total_points)
            weights.append(ratio_region_size / ratio_region_points)

        # Get name of weight file
        region_file = os.path.join(configuration["output_directory"],
                "hierarchical.weights")

        # Dump weights
        print "weights:", weights
        f = file(region_file, "w")
        f.write(" ".join(map(str, weights)))
        f.write("\n")
        f.close()

if __name__ == "__main__":
    # Command line entry point:
    #   <script> configuration input_file output_file
    import argparse

    parser = argparse.ArgumentParser(
        description="Using the currently labelled points,"
        " detects the more imprecise zones and draws points from them")
    for positional in ('configuration', 'input_file', 'output_file'):
        parser.add_argument(positional)
    args = parser.parse_args()

    conf = Configuration(args.configuration)
    # number of points to draw (mandatory) and tree cut parameter
    sample_count = conf("modules.sampler.params.n", int)
    cut_cp = conf("modules.sampler.params.cp", float, 0.0000001)

    predict(conf, args.input_file, args.output_file, sample_count, cut_cp)
