//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Fitting/SimDataPair.cpp
//! @brief     Implements class SimDataPair.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Fitting/SimDataPair.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/Scale.h"
#include "Base/Math/Numeric.h"
#include "Base/Util/Assert.h"
#include "Device/Coord/ICoordSystem.h"
#include "Device/Data/Datafield.h"
#include "Device/Detector/IDetector.h"
#include "Device/Detector/SimulationAreaIterator.h" // roiIndex
#include "Device/Histo/SimulationResult.h"
#include "Sim/Simulation/ScatteringSimulation.h"
#include <utility>

namespace {

//! Returns a new datafield with the same frame as `shape`, every value set to `value`.
std::unique_ptr<Datafield> initUserWeights(const Datafield& shape, double value)
{
    auto result = std::make_unique<Datafield>(shape.frame().clone());
    result->setAllTo(value);
    return result;
}

//! Returns true if `data` is two-dimensional and each of its two axes has as many bins as
//! the corresponding detector axis.
bool haveSameSizes(const IDetector& detector, const Datafield& data)
{
    if (data.rank() != 2)
        return false;

    for (size_t i = 0; i < 2; ++i)
        if (data.axis(i).size() != detector.axis(i).size())
            return false;

    return true;
}

//! Convert user data to SimulationResult object for later drawing in various axes units.
//! User data will be cropped to the ROI defined in the simulation, amplitudes in areas
//! corresponding to the masked areas of the detector will be set to zero.
//! Accepts data either already cropped to the ROI, or with the full detector shape.
//! Throws std::runtime_error if the data matches neither shape.

SimulationResult convertData(const ScatteringSimulation& simulation, const Datafield& data)
{
    // SimulationResult takes ownership of the coordinate system it is given, so the pointer
    // returned by simCoordSystem() is owned here until handed over. Keep it in a unique_ptr
    // so it is not leaked if the shape check below throws.
    std::unique_ptr<const ICoordSystem> coordSystem(simulation.simCoordSystem());
    Datafield roi_data(coordSystem->defaultAxes());

    if (roi_data.frame().hasSameSizes(data.frame())) {
        // data is already cropped to ROI
        simulation.detector().iterateOverNonMaskedPoints([&](IDetector::const_iterator it) {
            roi_data[it.roiIndex()] = data[it.roiIndex()];
        });
    } else if (haveSameSizes(simulation.detector(), data)) {
        // experimental data has same shape as the detector, we have to copy the original
        // data to a smaller roi map
        simulation.detector().iterateOverNonMaskedPoints([&](IDetector::const_iterator it) {
            roi_data[it.roiIndex()] = data[it.detectorIndex()];
        });
    } else
        throw std::runtime_error(
            "Error in SimDataPair: detector and experimental data have different shape");

    return SimulationResult(roi_data, coordSystem.release());
}

} // namespace

//  ************************************************************************************************
//  class implementation
//  ************************************************************************************************

//! Builds a data pair whose user weights are all equal to `user_weight`.
SimDataPair::SimDataPair(simulation_builder_t builder, const Datafield& raw_data,
                         std::unique_ptr<Datafield>&& raw_stdv, double user_weight)
    : m_simulation_builder(std::move(builder))
    , m_raw_data(raw_data.clone())
    , m_raw_uncertainties(std::move(raw_stdv))
    , m_raw_user_weights(initUserWeights(raw_data, user_weight)) // same frame as m_raw_data
{
    validate();
}

//! Builds a data pair with explicit user weights; a null `user_weights` means unit weights.
SimDataPair::SimDataPair(simulation_builder_t builder, const Datafield& raw_data,
                         std::unique_ptr<Datafield>&& raw_stdv,
                         std::unique_ptr<Datafield>&& user_weights)
    : m_simulation_builder(std::move(builder))
    , m_raw_data(raw_data.clone())
    , m_raw_uncertainties(std::move(raw_stdv))
    , m_raw_user_weights(user_weights ? std::move(user_weights)
                                      : initUserWeights(raw_data, 1.0))
{
    validate();
}

//! Memberwise move construction (defaulted).
SimDataPair::SimDataPair(SimDataPair&& other) = default;

//! Defaulted out of line. NOTE(review): presumably so that the header can hold smart pointers
//! to types it only forward-declares — confirm against SimDataPair.h.
SimDataPair::~SimDataPair() = default;

//! Builds a simulation for `params`, runs it and stores the result.
//! If any of the converted experimental data, uncertainties or user weights is missing or
//! empty, all three are (re)derived from the raw user input, laid out like the simulation
//! result: cropped to the ROI for ScatteringSimulation, taken as-is otherwise.
void SimDataPair::execSimulation(const mumufit::Parameters& params)
{
    std::unique_ptr<ISimulation> simulation = m_simulation_builder(params);
    ASSERT(simulation);
    m_sim_data = std::make_unique<SimulationResult>(simulation->simulate());
    ASSERT(!m_sim_data->empty());

    // The converted user arrays are kept from an earlier call once they are all available.
    if (m_exp_data && !m_exp_data->empty() && m_uncertainties && !m_uncertainties->empty()
        && m_user_weights && !m_user_weights->empty())
        return;

    const auto* const sim2d = dynamic_cast<const ScatteringSimulation*>(simulation.get());
    const ICoordSystem& converter = m_sim_data->converter();

    // Wraps a raw user array into a SimulationResult matching the layout of m_sim_data.
    const auto wrap = [&](const Datafield& raw) {
        if (sim2d)
            return std::make_unique<SimulationResult>(convertData(*sim2d, raw));
        return std::make_unique<SimulationResult>(raw, converter.clone());
    };

    m_exp_data = wrap(*m_raw_data);
    m_user_weights = wrap(*m_raw_user_weights);

    // User uncertainties must be kept for every simulation type (validate() guarantees they
    // share the frame of the raw data); previously they were replaced by a zero-filled
    // placeholder for non-ScatteringSimulation, although containsUncertainties() still
    // reported true.
    if (containsUncertainties())
        m_uncertainties = wrap(*m_raw_uncertainties);
    else {
        // No user uncertainties: store a placeholder with the simulation's default axes.
        Datafield dummy(converter.defaultAxes());
        m_uncertainties = std::make_unique<SimulationResult>(dummy, converter.clone());
    }
}

//! Tells whether the user supplied an uncertainties array.
bool SimDataPair::containsUncertainties() const
{
    return m_raw_uncertainties != nullptr;
}

//! Returns a copy of the most recent simulation result; asserts that it exists and is non-empty.
SimulationResult SimDataPair::simulationResult() const
{
    ASSERT(m_sim_data);
    ASSERT(!m_sim_data->empty());
    const SimulationResult& stored = *m_sim_data;
    return stored;
}

//! Returns a copy of the converted experimental data; asserts that it exists and is non-empty.
SimulationResult SimDataPair::experimentalData() const
{
    ASSERT(m_exp_data);
    ASSERT(!m_exp_data->empty());
    const SimulationResult& stored = *m_exp_data;
    return stored;
}

//! Returns a copy of the converted uncertainties; asserts that they exist and are non-empty.
SimulationResult SimDataPair::uncertainties() const
{
    ASSERT(m_uncertainties);
    ASSERT(!m_uncertainties->empty());
    const SimulationResult& stored = *m_uncertainties;
    return stored;
}

//! Returns a copy of the user weights as converted by execSimulation() (cut to the ROI area
//! for ScatteringSimulation); asserts that they exist and are non-empty.
SimulationResult SimDataPair::userWeights() const
{
    ASSERT(m_user_weights);
    ASSERT(!m_user_weights->empty());
    return *m_user_weights;
}

std::vector<double> SimDataPair::simulation_array() const
{
    return simulationResult().flatVector();
}

std::vector<double> SimDataPair::experimental_array() const
{
    return experimentalData().flatVector();
}

std::vector<double> SimDataPair::uncertainties_array() const
{
    return uncertainties().flatVector();
}

std::vector<double> SimDataPair::user_weights_array() const
{
    return userWeights().flatVector();
}

//! Returns relative difference between simulation and experimental data.
//! Each value is Numeric::relativeDifference(simulated, experimental); the result uses the
//! frame and coordinate system of the simulation result.
//! Throws std::runtime_error if there is no (or an empty) simulation result, or if the
//! experimental data is missing or differs in size.

SimulationResult SimDataPair::relativeDifference() const
{
    // m_sim_data is only set by execSimulation(); do not dereference it before that.
    const size_t N = m_sim_data ? m_sim_data->size() : 0;
    if (!N)
        throw std::runtime_error("Empty simulation data => won't compute relative difference");
    if (!m_exp_data || m_exp_data->size() != N)
        throw std::runtime_error("Different data shapes => won't compute relative difference");

    std::vector<double> data(N, 0.);
    for (size_t i = 0; i < N; ++i)
        data[i] = Numeric::relativeDifference((*m_sim_data)[i], (*m_exp_data)[i]);

    const Frame* f = m_sim_data->frame().clone();
    Datafield df(f, std::move(data));
    return {df, m_sim_data->converter().clone()};
}

//! Returns |simulated - experimental| for each value, on the frame and coordinate system of
//! the simulation result.
//! Throws std::runtime_error if there is no (or an empty) simulation result, or if the
//! experimental data is missing or differs in size.
SimulationResult SimDataPair::absoluteDifference() const
{
    // m_sim_data is only set by execSimulation(); do not dereference it before that.
    const size_t N = m_sim_data ? m_sim_data->size() : 0;
    if (!N)
        throw std::runtime_error("Empty simulation data => won't compute absolute difference");
    if (!m_exp_data || m_exp_data->size() != N)
        throw std::runtime_error("Different data shapes => won't compute absolute difference");

    std::vector<double> data(N, 0.);
    for (size_t i = 0; i < N; ++i)
        data[i] = std::abs((*m_sim_data)[i] - (*m_exp_data)[i]);

    const Frame* f = m_sim_data->frame().clone();
    Datafield df(f, std::move(data));
    return {df, m_sim_data->converter().clone()};
}

void SimDataPair::validate() const
{
    if (!m_simulation_builder)
        throw std::runtime_error("Error in SimDataPair: simulation builder is empty");

    if (!m_raw_data)
        throw std::runtime_error("Error in SimDataPair: passed experimental data array is empty");

    if (m_raw_uncertainties && m_raw_uncertainties->frame() != m_raw_data->frame())
        throw std::runtime_error(
            "Error in SimDataPair: experimental data and uncertainties have different shape.");

    if (!m_raw_user_weights || m_raw_user_weights->frame() != m_raw_data->frame())
        throw std::runtime_error(
            "Error in SimDataPair: user weights are not initialized or have invalid shape");
}
