import os
from typing import List, Optional, Union

from fastapi import APIRouter, Depends, Query
from obspy import UTCDateTime, Stream
from starlette.responses import Response

from ...algorithm.DbDtAlgorithm import DbDtAlgorithm
from ... import TimeseriesFactory, TimeseriesUtility
from ...edge import EdgeFactory
from ...iaga2002 import IAGA2002Writer
from ...imfjson import IMFJSONWriter
from .DataApiQuery import (
    DEFAULT_ELEMENTS,
    DataApiQuery,
    DataType,
    OutputFormat,
    SamplingPeriod,
)


def get_data_factory() -> Optional[TimeseriesFactory]:
    """Reads environment variable to determine the factory to be used

    Returns
    -------
    data_factory
        Edge or miniseed factory object
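
    Examples
    --------
    A minimal sketch; the values shown are this function's defaults, and an
    EDGE server must be reachable before the factory can actually read data:

    >>> os.environ["DATA_TYPE"] = "edge"
    >>> os.environ["DATA_HOST"] = "cwbpub.cr.usgs.gov"
    >>> os.environ["DATA_PORT"] = "2060"
    >>> data_factory = get_data_factory()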
    """
    data_type = os.getenv("DATA_TYPE", "edge")
    data_host = os.getenv("DATA_HOST", "cwbpub.cr.usgs.gov")
    data_port = int(os.getenv("DATA_PORT", "2060"))
    if data_type == "edge":
        return EdgeFactory(host=data_host, port=data_port)
    else:
        return None


def format_timeseries(
    timeseries: Stream, format: OutputFormat, elements: List[str]
) -> Response:
    """Formats timeseries output

    Parameters
    ----------
    timeseries: obspy.core.Stream
        timeseries object with requested data
    format: OutputFormat
        output format (IAGA2002 or JSON)
    elements: List[str]
        list of requested elements

    Returns
    -------
    Response
        formatted response with the appropriate media type
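
    Examples
    --------
    A minimal sketch; ``timeseries`` would normally come from get_timeseries,
    and the element name "H" is illustrative:

    >>> response = format_timeseries(
    ...     timeseries=timeseries, format=OutputFormat.JSON, elements=["H"]
    ... )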
    """
    if format == OutputFormat.JSON:
        data = IMFJSONWriter.format(timeseries, elements)
        media_type = "application/json"
    else:
        data = IAGA2002Writer.format(timeseries, elements)
        media_type = "text/plain"
    return Response(data, media_type=media_type)


def get_timeseries(data_factory: TimeseriesFactory, query: DataApiQuery) -> Stream:
    """Get timeseries data

    Parameters
    ----------
    data_factory: TimeseriesFactory
        factory used to read data
    query: DataApiQuery
        parameters describing the data to read

    Returns
    -------
    obspy.core.Stream
        requested timeseries, including any derived "_DDT" channels
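
    Examples
    --------
    A minimal sketch; assumes an EDGE server is reachable and that "BOU" and
    the element names are valid for it:

    >>> query = DataApiQuery(
    ...     id="BOU",
    ...     starttime=UTCDateTime("2020-09-01T00:00:00Z"),
    ...     endtime=UTCDateTime("2020-09-02T00:00:00Z"),
    ...     elements=["H", "H_DDT"],
    ...     sampling_period=SamplingPeriod.MINUTE,
    ...     data_type=DataType.ADJUSTED,
    ...     format=OutputFormat.IAGA2002,
    ... )
    >>> timeseries = get_timeseries(get_data_factory(), query)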
    """

    # gather non-dbdt elements first
    base_elements = [
        element for element in query.elements if not element.endswith("_DDT")
    ]
    # gather interval
    interval = TimeseriesUtility.get_interval_from_delta(query.sampling_period)
    # get data
    base_timeseries = data_factory.get_timeseries(
        starttime=query.starttime,
        endtime=query.endtime,
        observatory=query.id,
        channels=base_elements,
        type=query.data_type,
        interval=interval,
    )

    # derive requested "_DDT" channels from their base elements
    dbdt_elements = [
        element.replace("_DDT", "")
        for element in query.elements
        if element.endswith("_DDT")
    ]
    if dbdt_elements:
        # read the base channels needed to compute db/dt
        timeseries = data_factory.get_timeseries(
            starttime=query.starttime,
            endtime=query.endtime,
            observatory=query.id,
            channels=dbdt_elements,
            type=query.data_type,
            interval=interval,
        )
        # compute db/dt and append the derived traces to the output stream
        dbdt_timeseries = DbDtAlgorithm(
            inchannels=dbdt_elements,
            outchannels=[element + "_DDT" for element in dbdt_elements],
        ).process(timeseries)
        base_timeseries += dbdt_timeseries

    return base_timeseries


router = APIRouter()
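
# This router is typically included in a FastAPI application, e.g. (sketch;
# the application object and any URL prefix are assumptions, not defined here):
#     app = FastAPI()
#     app.include_router(router)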


@router.get("/data/")
def get_data(
    id: str,
    starttime: UTCDateTime = Query(None),
    endtime: UTCDateTime = Query(None),
    elements: List[str] = Query(DEFAULT_ELEMENTS),
    sampling_period: Union[SamplingPeriod, float] = Query(SamplingPeriod.MINUTE),
    data_type: Union[DataType, str] = Query(DataType.ADJUSTED, alias="type"),
    format: OutputFormat = Query(OutputFormat.IAGA2002),
    data_factory: TimeseriesFactory = Depends(get_data_factory),
) -> Response:
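    """Get timeseries data for one observatory.

    A sample request; query parameters mirror this function's arguments, and
    the observatory id and element names below are illustrative::

        /data/?id=BOU&starttime=2020-09-01T00:00:00Z&endtime=2020-09-02T00:00:00Z&elements=X&elements=Y
    """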
    # parse query
    query = DataApiQuery(
        id=id,
        starttime=starttime,
        endtime=endtime,
        elements=elements,
        sampling_period=sampling_period,
        data_type=data_type,
        format=format,
    )
    # read data
    timeseries = get_timeseries(data_factory, query)
    # output response
    return format_timeseries(
        timeseries=timeseries, format=format, elements=query.elements
    )