import {e, log, mean, pow, sqrt} from "mathjs";
import {REGRESSION_TYPES} from "src/components/ChartBuilder/constants.jsx";
import {isNumber} from "src/utils/validators.js";


/**
 * Coefficient of determination (R²) of a set of predictions.
 *
 * Computed as 1 - SS_res / SS_tot, where SS_res is the sum of squared
 * residuals (actual - predicted) and SS_tot is the total sum of squares of
 * `actual` around its mean.
 *
 * When every observed value is identical, SS_tot is zero and the ratio is
 * undefined (the plain formula yields NaN or -Infinity). In that case this
 * follows scikit-learn's `r2_score` "force_finite" convention: 1 for a
 * perfect fit, 0 otherwise.
 *
 * @param {number[]} actual - Observed values.
 * @param {number[]} predicted - Predictions aligned index-by-index with `actual`.
 * @returns {number} The R² score; NaN if any residual is NaN.
 * @throws {Error} If the two arrays differ in length.
 */
export function calculateR2(actual, predicted) {
    if (actual.length !== predicted.length) {
        throw new Error("The lengths of the actual and predicted arrays must be the same.");
    }

    // Evaluated first so empty input is handled exactly as before, by mathjs `mean`.
    const meanActual = mean(actual);

    const residualSumOfSquares = actual.reduce((sum, value, index) => sum + pow(value - predicted[index], 2), 0);

    // Detect zero variance exactly instead of testing totalSumOfSquares === 0:
    // rounding in the mean can leave a tiny non-zero total for constant data
    // (e.g. [0.1, 0.1, 0.1]), which would turn the ratio into a huge negative number.
    const isConstant = actual.length > 0 && actual.every((value) => value === actual[0]);
    if (isConstant) {
        if (Number.isNaN(residualSumOfSquares)) {
            return NaN;
        }
        return residualSumOfSquares === 0 ? 1 : 0;
    }

    const totalSumOfSquares = actual.reduce((sum, value) => sum + pow(value - meanActual, 2), 0);

    return 1 - (residualSumOfSquares / totalSumOfSquares);
}

/**
 * Root-mean-squared error between observed values and model predictions.
 *
 * @param {number[]} actual - Observed values.
 * @param {number[]} predicted - Predictions aligned index-by-index with `actual`.
 * @returns {number} sqrt(sum((actual[i] - predicted[i])^2) / actual.length).
 * @throws {Error} If the two arrays differ in length.
 */
export function calculateRMSE(actual, predicted) {
    const count = actual.length;
    if (count !== predicted.length) {
        throw new Error("The lengths of the actual and predicted arrays must be the same.");
    }

    let squaredErrorTotal = 0;
    actual.forEach((observed, i) => {
        squaredErrorTotal += pow(observed - predicted[i], 2);
    });

    const meanSquaredError = squaredErrorTotal / count;
    return sqrt(meanSquaredError);
}


/**
 * Build a straight-line model y = gradient * x + intercept.
 *
 * @param {{gradient: number, intercept: number}} parameters - Fitted line parameters.
 * @returns {function(number): number} Evaluates the line at a given x.
 */
function createLinearModel({gradient, intercept}) {
    return (x) => gradient * x + intercept;
}

/**
 * Build an exponential model y = coefficient * e^(index * x).
 *
 * @param {{coefficient: number, index: number}} parameters - Fitted scale and growth rate.
 * @returns {function(number): *} Evaluates the curve at a given x (via mathjs `pow`).
 */
function createExponentialModel({coefficient, index}) {
    return (x) => {
        const exponent = x * index;
        return coefficient * pow(e, exponent);
    };
}

/**
 * Build a logarithmic model y = gradient * ln(x) + intercept.
 *
 * @param {{gradient: number, intercept: number}} parameters - Fitted parameters.
 * @returns {function(number): *} Evaluates the curve at a given x (via mathjs `log`).
 */
function createLogarithmicModel({gradient, intercept}) {
    return (x) => {
        const logTerm = log(x);
        return gradient * logTerm + intercept;
    };
}

/**
 * Build a polynomial model y = sum_i parameters[i] * x^i.
 *
 * @param {number[]} parameters - Coefficients in ascending power order
 *     (parameters[0] is the constant term).
 * @returns {function(number): number} Evaluates the polynomial at a given x;
 *     an empty coefficient list evaluates to 0.
 */
function createPolynomialModel(parameters) {
    return (x) => parameters.reduce(
        (total, coefficient, power) => total + coefficient * pow(x, power),
        0
    );
}


/**
 * Resolve a regression type to an evaluable model function.
 *
 * @param {*} modelType - One of the `REGRESSION_TYPES` values.
 * @param {*} parameters - Fitted parameters in the shape the chosen factory reads:
 *     an object for linear/exponential/logarithmic, a coefficient array for polynomial.
 * @returns {Function} The model, taking an x value.
 * @throws {Error} "Invalid model" when `modelType` matches no known regression type.
 */
export function getRegressionFunction(modelType, parameters) {
    // Ordered pairs keep the first-match, strict-equality semantics of a switch.
    const factories = [
        [REGRESSION_TYPES.LINEAR, createLinearModel],
        [REGRESSION_TYPES.EXPONENTIAL, createExponentialModel],
        [REGRESSION_TYPES.LOGARITHMIC, createLogarithmicModel],
        [REGRESSION_TYPES.POLYNOMIAL, createPolynomialModel],
    ];

    const entry = factories.find(([type]) => type === modelType);
    if (entry === undefined) {
        throw new Error("Invalid model");
    }

    const [, createModel] = entry;
    return createModel(parameters);
}


/**
 * Score a fitted regression model against a set of data points.
 *
 * Points whose x or y fails `isNumber` are skipped. For each remaining point,
 * y is kept as the observed value and the model's output at x as the prediction.
 *
 * @param {Array<Array>} data - [x, y] pairs.
 * @param {*} regressionType - One of the `REGRESSION_TYPES` values.
 * @param {*} parameters - Fitted parameters for that regression type.
 * @returns {{r2: number, rmse: number}} Goodness-of-fit statistics.
 * @throws {Error} "Invalid model" for an unknown regression type, or whatever
 *     the metric functions throw.
 */
export function calculateRegressionResult(data, regressionType, parameters) {
    // Resolve the model before touching the data so an unknown type fails first.
    const model = getRegressionFunction(regressionType, parameters);

    const {actual, predicted} = data.reduce((collected, [x, y]) => {
        if (isNumber(x) && isNumber(y)) {
            collected.actual.push(y);
            collected.predicted.push(model(x));
        }
        return collected;
    }, {actual: [], predicted: []});

    const r2 = calculateR2(actual, predicted);
    const rmse = calculateRMSE(actual, predicted);
    return {r2, rmse};
}
