from typing import Dict, Union, List, Optional
from datetime import timedelta

from obspy import UTCDateTime
from pydantic import BaseModel, Field, field_validator
from enum import Enum

from ...pydantic_utcdatetime import CustomUTCDateTimeType


class FlagCategory(str, Enum):
    """
    Categories that can be assigned to a flag.

    Members are ``str``-valued and every value is identical to its member
    name, so a member compares equal to the plain string of the same name.
    """

    ARTIFICIAL_DISTURBANCE = "ARTIFICIAL_DISTURBANCE"
    GAP = "GAP"
    EVENT = "EVENT"
    FIELD_WORK = "FIELD_WORK"
    # Category used as the default on the Flag base model.
    OTHER = "OTHER"


class Flag(BaseModel):
    """
    Base class for flagging features in magnetic timeseries data.

    Attributes
    ----------
    description: str
        Required description of the flag.
    flag_category: FlagCategory
        Category of the flag. Defaults to ``FlagCategory.OTHER``.

    Flag example:
    ```
    automatic_flag = Metadata(
        created_by='ex_algorithm',
        start_time=UTCDateTime('2023-01-01T03:05:10'),
        end_time=UTCDateTime('2023-01-01T03:10:11'),
        network='NT',
        station='BOU',
        channel='BEH',
        category=MetadataCategory.FLAG,
        comment="spike detected",
        priority=1,
        data_valid=True,
        metadata=ArtificialDisturbance(
            description="Spikes in magnetic field strength",
            artificial_disturbance_type=ArtificialDisturbanceType.SPIKES,
            deviation=None,
            spikes=[
                '2023-01-01T03:05:10',
                '2023-01-01T03:07:20',
                '2023-01-01T03:10:11',
            ],
        ),
    )
    ```
    """

    description: str = Field(..., description="Description of the flag")
    # Use the enum member rather than the raw string: pydantic does not
    # validate default values unless asked to, so a plain "OTHER" default
    # would be stored as a str instead of a FlagCategory.
    flag_category: FlagCategory = FlagCategory.OTHER


class ArtificialDisturbanceType(str, Enum):
    """
    Types of artificial disturbance that can be flagged.

    Members are ``str``-valued and every value is identical to its member
    name.
    """

    # Single data points that are outliers in the timeseries.
    SPIKES = "SPIKES"
    # A relatively constant shift or deviation in the baseline magnetic field.
    OFFSET = "OFFSET"
    # Catch-all for a continuous period of unwanted variations; may include
    # multiple spikes, offsets and/or gaps.
    ARTIFICIAL_DISTURBANCES = "ARTIFICIAL_DISTURBANCES"


class ArtificialDisturbance(Flag):
    """
    This class is used to flag artificial disturbances.

    Artificial disturbances consist of the following types:

    SPIKES = Single data points that are outliers in the timeseries.
    OFFSET = A relatively constant shift or deviation in the baseline magnetic field.
    ARTIFICIAL_DISTURBANCES = A catch-all for a continuous period of unwanted variations, may include multiple spikes, offsets and/or gaps.

    Attributes
    ----------
    artificial_disturbance_type: ArtificialDisturbanceType
        The type of artificial disturbance(s). Required.
    deviation: Optional[float]
        Deviation of an offset in nT. Defaults to None.
    spikes: Optional[List[CustomUTCDateTimeType]]
        Array of timestamps as UTCDateTime. Can be a single spike or many
        spikes. Defaults to None.

    """

    artificial_disturbance_type: ArtificialDisturbanceType
    deviation: Optional[float] = None
    spikes: Optional[List[CustomUTCDateTimeType]] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base initializer runs, so the category is always
        # ARTIFICIAL_DISTURBANCE regardless of any flag_category passed in.
        self.flag_category = FlagCategory.ARTIFICIAL_DISTURBANCE

    @field_validator("spikes")
    @classmethod
    def check_spikes_duration(cls, spikes):
        """
        Validate that the spike timestamps span no more than one day.

        Parameters
        ----------
        spikes:
            The validated ``spikes`` field value; may be None.

        Returns
        -------
        The ``spikes`` value, unchanged.

        Raises
        ------
        ValueError
            If the last timestamp minus the first exceeds one day.
        """
        # Nothing to measure with fewer than two timestamps.
        if spikes is None or len(spikes) < 2:
            return spikes

        # The difference is treated as a number of seconds.
        # NOTE(review): only the first and last entries are compared, so this
        # presumes the list is in chronological order - confirm with callers.
        duration = spikes[-1] - spikes[0]
        if duration > timedelta(days=1).total_seconds():
            raise ValueError(
                f"The duration between the first and last spike timestamp must not exceed 1 day. Duration: {duration} seconds"
            )

        return spikes

    @classmethod
    def check_spikes_match_times(cls, spikes, values):
        """
        Check that the first/last spike equal the enclosing start/end times.

        This method is not registered as a pydantic validator; it only runs
        when called explicitly.

        Parameters
        ----------
        spikes:
            Sequence of spike timestamps. None or an empty sequence is
            accepted and results in no check being performed.
        values:
            Mapping read via ``.get`` for the keys "starttime" and "endtime".

        Raises
        ------
        ValueError
            If the first spike differs from "starttime" or the last spike
            differs from "endtime".
        """
        # ``spikes`` is optional on the model; without this guard a None or
        # empty value would fail with TypeError/IndexError on indexing.
        if not spikes:
            return

        metadata_starttime = values.get("starttime")
        metadata_endtime = values.get("endtime")

        if spikes[0] != metadata_starttime:
            raise ValueError(
                f"The first spike timestamp {spikes[0]} does not match the starttime {metadata_starttime}."
            )

        if spikes[-1] != metadata_endtime:
            raise ValueError(
                f"The last spike timestamp {spikes[-1]} does not match the endtime {metadata_endtime}."
            )


class Gap(Flag):
    """
    This class is used to flag gaps in data.

    A gap is a period where data is missing or not recorded.

    Attributes
    ----------
    cause: Optional[str]
        Cause of gap, e.g., network outage. Defaults to None.
    handling: Optional[str]
        How the gap is being handled, e.g., backfilled. Defaults to None.
    """

    cause: Optional[str] = None
    handling: Optional[str] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base initializer runs, so the category is always
        # GAP regardless of any flag_category passed in.
        self.flag_category = FlagCategory.GAP


class EventType(str, Enum):
    """
    Types of event that an Event flag can describe.

    Members are ``str``-valued and every value is identical to its member
    name.
    """

    GEOMAGNETIC_STORM = "GEOMAGNETIC_STORM"
    GEOMAGNETIC_SUBSTORM = "GEOMAGNETIC_SUBSTORM"
    EARTHQUAKE = "EARTHQUAKE"
    OTHER = "OTHER"


class Event(Flag):
    """
    This class is used to flag an event of interest such as a geomagnetic storm or earthquake.

    Attributes
    ----------
    event_type : EventType
        The type of event. Required.
    index : Optional[int]
        Planetary K-index, DST index or some other index. Defaults to None.
    scale : Optional[str]
        Geomagnetic storm scale or Richter scale magnitude. Defaults to None.
    url : Optional[str]
        A url related to the event. Could be NOAA SWPC, USGS Earthquakes page
        or another site. Defaults to None.
    """

    event_type: EventType
    index: Optional[int] = None
    scale: Optional[str] = None
    url: Optional[str] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base initializer runs, so the category is always
        # EVENT regardless of any flag_category passed in.
        self.flag_category = FlagCategory.EVENT


# More example usage: sample payloads for the flag models defined in this
# module. Dictionary keys match the model field names so each payload can be
# unpacked into its model.

# Spike timestamps; the first and last entries equal the "starttime" and
# "endtime" values in spikes_data.
timestamps_array = [
    UTCDateTime("2023-11-16T12:00:00"),
    UTCDateTime("2023-11-16T12:01:10"),
    UTCDateTime("2023-11-16T12:02:30"),
]

# Payload for an ArtificialDisturbance of type SPIKES. "starttime" and
# "endtime" are not fields of ArtificialDisturbance; they are the keys read by
# ArtificialDisturbance.check_spikes_match_times.
spikes_data = {
    "starttime": "2023-11-16 12:00:00",
    "endtime": "2023-11-16 12:02:30",
    "description": "Spikes description",
    "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES,
    "spikes": timestamps_array,
}

# Payload for an ArtificialDisturbance of type OFFSET.
offset_data = {
    "description": "Offset description",
    "artificial_disturbance_type": ArtificialDisturbanceType.OFFSET,
    "deviation": 10.0,
}

# Payload for an Event describing a geomagnetic storm.
geomagnetic_storm_data = {
    "description": "Geomagnetic storm",
    "event_type": EventType.GEOMAGNETIC_STORM,
    "scale": "G3",
    "index": 7,
    "url": "https://www.swpc.noaa.gov/products/planetary-k-index",
}