import numpy as np
from obspy import Stream, UTCDateTime
from pydantic import BaseModel
from typing import Any, List, Optional

from .. import ChannelConverter
from .. import pydantic_utcdatetime
from ..residual.Reading import Reading, get_absolutes_xyz, get_ordinates
from .Metric import Metric


class AdjustedMatrix(BaseModel):
    """Attributes pertaining to adjusted (affine) matrices, applied by the AdjustedAlgorithm

    Attributes
    ----------
    matrix: affine matrix generated by Affine's calculate method
    pier_correction: pier correction generated by Affine's calculate method
    metrics: list of Metric objects (presumably as produced by get_metrics;
        TODO confirm against callers)
    starttime: beginning of interval that matrix is valid for
    endtime: end of interval that matrix is valid for
    time: timestamp associated with this matrix (semantics not shown in this
        file; TODO confirm)
    NOTE: valid intervals are only generated when bad data is encountered.
    Matrix is non-constrained otherwise
    """

    matrix: Optional[Any] = None
    pier_correction: Optional[float] = None
    metrics: Optional[List[Metric]] = None
    starttime: Optional[UTCDateTime] = None
    endtime: Optional[UTCDateTime] = None
    time: Optional[UTCDateTime] = None

    def process(
        self,
        stream: Stream,
        inchannels=("H", "E", "Z", "F"),
        outchannels=("X", "Y", "Z", "F"),
    ):
        """Apply matrix to raw data. Apply pier correction to F when necessary.

        Parameters
        ----------
        stream: stream holding a trace for every channel in inchannels
        inchannels: input channel names; every channel except "F" is stacked,
            in this order, as one row of the matrix input
        outchannels: output channel names; only consulted to decide whether
            F is passed through

        Returns
        -------
        numpy array equal to self.matrix applied to the stacked non-F input
        rows plus a row of ones. When "F" appears in both inchannels and
        outchannels, the last row is replaced by the input F data plus
        pier_correction.
        """
        # Non-F channels plus a row of ones, so the matrix's affine
        # translation column is applied along with the linear part.
        raws = np.vstack(
            [
                stream.select(channel=channel)[0].data
                for channel in inchannels
                if channel != "F"
            ]
            + [np.ones_like(stream[0].data)]
        )
        adjusted = self.matrix @ raws
        if "F" in inchannels and "F" in outchannels:
            # F is not transformed by the matrix: it is only offset by the
            # pier correction and overwrites the last output row.
            f = stream.select(channel="F")[0].data + self.pier_correction
            adjusted[-1] = f
        return adjusted

    def get_metrics(self, readings: List[Reading]) -> List[Metric]:
        """Compute residual statistics for X, Y, Z, and dF between expected and predicted values.

        Expected values are the readings' absolutes (X, Y, Z) plus F computed
        from them. Predicted values are self.matrix applied to the readings'
        ordinates (with an appended row of ones), plus F computed from the
        predicted X, Y, Z.

        Parameters
        ----------
        readings: list of Readings

        Returns
        -------
        metrics: list of Metric objects, one per element ("X", "Y", "Z", "dF").
            absmean is the absolute value of the mean residual
            (expected - predicted), and stddev is the standard deviation of
            the residual. Both statistics ignore NaN residuals.
        """
        absolutes = get_absolutes_xyz(readings=readings)
        ordinates = get_ordinates(readings=readings)
        # Append a row of ones so the affine offset column is applied.
        stacked_ordinates = np.vstack((ordinates, np.ones_like(ordinates[0])))
        predicted_xyz = (self.matrix @ stacked_ordinates)[0:3]
        expected = np.vstack(
            (
                absolutes,
                ChannelConverter.get_computed_f_using_squares(*absolutes),
            )
        )
        predicted = np.vstack(
            (
                predicted_xyz,
                ChannelConverter.get_computed_f_using_squares(*predicted_xyz),
            )
        )
        elements = ["X", "Y", "Z", "dF"]
        metrics = []
        for element, expected_row, predicted_row in zip(
            elements, expected, predicted
        ):
            diff = expected_row - predicted_row
            metrics.append(
                Metric(
                    element=element,
                    absmean=abs(np.nanmean(diff)),
                    # nanstd keeps NaN handling consistent with absmean;
                    # np.std would return NaN if any residual were NaN.
                    stddev=np.nanstd(diff),
                )
            )
        return metrics