"""Factory that loads data from earthworm and writes to Edge.

EdgeFactory uses obspy earthworm class to read data from any
earthworm standard Waveserver using the obspy getWaveform call.

Writing will be implemented with Edge specific capabilities,
to take advantage of its newer realtime abilities.

Edge is the USGS earthquake hazard centers replacement for earthworm.
"""

from __future__ import absolute_import
from datetime import datetime
import sys
from typing import List, Optional

import numpy
import numpy.ma
from obspy.core import Stats, Stream, Trace, UTCDateTime
from obspy.clients import earthworm

from .. import ChannelConverter, TimeseriesUtility
from ..geomag_types import DataInterval, DataType
from ..metadata.instrument.InstrumentCalibrations import get_instrument_calibrations
from ..TimeseriesFactory import TimeseriesFactory
from ..TimeseriesFactoryException import TimeseriesFactoryException
from ..ObservatoryMetadata import ObservatoryMetadata
from .RawInputClient import RawInputClient
from .SNCL import SNCL
from .LegacySNCL import LegacySNCL


class EdgeFactory(TimeseriesFactory):
    """TimeseriesFactory for Edge related data.

    Parameters
    ----------
    host: str
        a string representing the IP number of the host to connect to.
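    port: integer
        port number on which EdgeCWB's CWB waveserver (CWBWS, obspy's
        "earthworm" client) listens for requests to retrieve timeseries data.
        A list of port numbers may also be given; they are tried in order
        until one returns data. Defaults to [2060].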
    write_port: integer
        the port number on which EdgeCWB's RawInputServer is listening for
        requests to write timeseries data.
    tag: str
        A tag used by edge to log and associate a socket with a given data
        source
    forceout: bool
        Tells edge to forceout a packet to miniseed.  Generally used when
        the user knows no more data is coming.
    observatory: str
        the observatory code for the desired observatory.
    channels: array
        an array of channels {H, D, E, F, Z, MGD, MSD, HGD}.
        Known channel names are mapped based on interval and type,
        others are passed through, see #_get_edge_channel().
    type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
        data type
    interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
        data interval
    observatoryMetadata: ObservatoryMetadata object
        an ObservatoryMetadata object used to replace the default
        ObservatoryMetadata.
    locationCode: str
        the location code for the given edge server, overrides type
        in get_timeseries/put_timeseries
    convert_channels: array
        list of channels to convert from volt/bin to nT
    scale_factor: float
        override default scalings when reading/writing miniseed blocks;
        (reading integer blocks divides by 1000, or by 1e7 for channels
         whose SEED channel code has "E" as its second character; reading
         floating point blocks divides by 1; writing all data multiplies
         by the same 1000 or 1e7 factor)
        default = None
    sncl_mode: {'geomag','legacy'}
        force mode to convert common names to SEED SNCL codes (that is,
        station, network, channel, location codes); default = legacy
    timeout: float
        timeout for Earthworm client; default=10

    See Also
    --------
    TimeseriesFactory

    Notes
    -----
    The EdgeFactory reads so-called trace-bufs and writes integer data
    to an EdgeCWB RawInputServer, which places these data, sample-by-sample,
    into a RAM buffer where it can be immediately retrieved, even before
    full miniseed blocks can be constructed and written to disk. The
    EdgeFactory cannot handle non-integer data, and is inefficient for
    larger data transfers. In those cases, consider using MiniSeedFactory.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        write_port: Optional[int] = None,
        tag: str = "GeomagAlg",
        forceout: bool = False,
        observatory: Optional[str] = None,
        channels: Optional[List[str]] = None,
        type: Optional[DataType] = None,
        interval: Optional[DataInterval] = None,
        observatoryMetadata: Optional[ObservatoryMetadata] = None,
        locationCode: Optional[str] = None,
        convert_channels: Optional[List[str]] = None,
        scale_factor: Optional[int] = None,
        sncl_mode: Optional[str] = None,
        timeout: Optional[float] = None,
    ):
        """Store connection settings and request defaults.

        Falsy values for host, port, observatoryMetadata, convert_channels
        and timeout are replaced with defaults ("edgecwb.usgs.gov", [2060],
        a new ObservatoryMetadata, an empty list and 10 respectively).

        Raises
        ------
        TimeseriesFactoryException
            if sncl_mode is not None, "legacy" or "geomag".
        """
        TimeseriesFactory.__init__(self, observatory, channels, type, interval)

        # where to read from / write to
        self.host = host if host else "edgecwb.usgs.gov"
        self.port = port if port else [2060]
        # there is deliberately no default write port
        self.write_port = write_port
        self.timeout = timeout if timeout else 10

        # how packets are labelled and flushed when writing
        self.tag = tag
        self.forceout = forceout

        # metadata / naming / scaling options
        self.observatoryMetadata = (
            observatoryMetadata if observatoryMetadata else ObservatoryMetadata()
        )
        self.locationCode = locationCode
        self.convert_channels = convert_channels if convert_channels else []
        self.scale_factor = scale_factor
        self.sncl_mode = sncl_mode

        # pick the function that maps common names to SEED SNCL codes;
        # an unset mode behaves like "legacy"
        if sncl_mode is None or sncl_mode == "legacy":
            sncl_lookup = LegacySNCL.get_sncl
        elif sncl_mode == "geomag":
            sncl_lookup = SNCL.get_sncl
        else:
            raise TimeseriesFactoryException("Unrecognized SNCL mode")
        self.get_sncl = sncl_lookup

    def get_timeseries(
        self,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
        observatory: Optional[str] = None,
        channels: Optional[List[str]] = None,
        type: Optional[DataType] = None,
        interval: Optional[DataInterval] = None,
        add_empty_channels: bool = True,
    ) -> Stream:
        """Get timeseries data

        Parameters
        ----------
        starttime: UTCDateTime
            time of first sample.
        endtime: UTCDateTime
            time of last sample.
        observatory: str
            observatory code; defaults to the factory's observatory.
        channels: array
            list of channels to load; defaults to the factory's channels.
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type; defaults to the factory's type.
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval; defaults to the factory's interval.
        add_empty_channels: bool
            if True, returns channels without data as empty traces.
            Only passed along for channels that are read directly; channels
            listed in convert_channels go through _convert_timeseries,
            which does not take this flag.

        Returns
        -------
        timeseries: Stream
            timeseries object with requested data

        Raises
        ------
        TimeseriesFactoryException
            if starttime is after endtime.
        """
        observatory = observatory or self.observatory
        channels = channels or self.channels
        type = type or self.type
        interval = interval or self.interval

        if starttime > endtime:
            # message previously read "Starttime before endtime", which is
            # the opposite of the condition being reported
            raise TimeseriesFactoryException(
                'Starttime after endtime "%s" "%s"' % (starttime, endtime)
            )

        # obspy factories sometimes write to stdout, instead of stderr
        original_stdout = sys.stdout
        try:
            # send stdout to stderr
            sys.stdout = sys.stderr
            # get the timeseries, one channel at a time
            timeseries = Stream()
            for channel in channels:
                if channel in self.convert_channels:
                    # channel is computed from component channels
                    data = self._convert_timeseries(
                        starttime, endtime, observatory, channel, type, interval
                    )
                else:
                    data = self._get_timeseries(
                        starttime,
                        endtime,
                        observatory,
                        channel,
                        type,
                        interval,
                        add_empty_channels,
                    )
                    if len(data) == 0:
                        # nothing returned and empty traces not requested
                        continue
                timeseries += data
        finally:
            # restore stdout, even if a request raised
            sys.stdout = original_stdout
        # fill masked gaps with nan, convert D, and pad to requested range
        self._post_process(timeseries, starttime, endtime, channels)

        return timeseries

    def put_timeseries(
        self,
        timeseries: Stream,
        starttime: Optional[UTCDateTime] = None,
        endtime: Optional[UTCDateTime] = None,
        observatory: Optional[str] = None,
        channels: Optional[List[str]] = None,
        type: Optional[DataType] = None,
        interval: Optional[DataInterval] = None,
    ):
        """Put timeseries data

        Parameters
        ----------
        timeseries: Stream
            timeseries object with data to be written; must contain at
            least one trace, because defaults are read from the first
            trace's stats.
        starttime: UTCDateTime
            time of first sample to write; when omitted, the start time
            reported by TimeseriesUtility.get_stream_start_end_times is used.
        endtime: UTCDateTime
            time of last sample to write; when omitted, the end time
            reported by TimeseriesUtility.get_stream_start_end_times is used.
        observatory: str
            observatory code
        channels: array
            list of channels to load
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval

        Raises
        ------
        TimeseriesFactoryException
            if any requested channel is missing from timeseries.

        Notes
        -----
        Streams sent to timeseries are expected to have a single trace per
            channel and that trace should have an ndarray, with nan's
            representing gaps.
        """
        stats = timeseries[0].stats
        observatory = observatory or stats.station or self.observatory
        channels = channels or self.channels
        type = type or self.type or stats.data_type
        interval = interval or self.interval or stats.data_interval

        if starttime is None or endtime is None:
            # only fill in the bound(s) the caller did not supply, so that
            # an explicitly requested starttime or endtime is not discarded
            stream_start, stream_end = TimeseriesUtility.get_stream_start_end_times(
                timeseries
            )
            if starttime is None:
                starttime = stream_start
            if endtime is None:
                endtime = stream_end

        # validate every channel before writing any of them
        for channel in channels:
            if timeseries.select(channel=channel).count() == 0:
                raise TimeseriesFactoryException(
                    'Missing channel "%s" for output, available channels %s'
                    % (channel, str(TimeseriesUtility.get_channels(timeseries)))
                )
        for channel in channels:
            self._put_channel(
                timeseries.select(channel=channel),
                observatory,
                channel,
                type,
                interval,
                starttime,
                endtime,
            )

    def get_calculated_timeseries(
        self,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
        observatory: str,
        channel: str,
        type: DataType,
        interval: DataInterval,
        components: List[dict],
    ) -> Trace:
        """Build one channel as a scaled-and-offset sum of component channels.

        Parameters
        ----------
        starttime: UTCDateTime
            the starttime of the requested data
        endtime: UTCDateTime
            the endtime of the requested data
        observatory: str
            observatory code
        channel: str
            name assigned to the resulting trace's channel
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval
        components: list
            each component is a dictionary with the following keys:
                channel: str
                offset: float
                scale: float

        Returns
        -------
        out: Trace
            trace holding the sum over components of data * scale + offset;
            its stats are derived from the first component's trace
        """
        total = None
        first_stats = None
        for component in components:
            # only the first trace returned for the component is used
            source = self._get_timeseries(
                starttime, endtime, observatory, component["channel"], type, interval
            )[0]
            scaled = source.data * component["scale"] + component["offset"]
            if total is not None:
                # accumulate in place onto the running sum
                total += scaled
                continue
            # first component: start the sum and remember its stats
            total = scaled
            first_stats = Stats(source.stats)
        # relabel the copied stats with the calculated channel's name
        first_stats.channel = channel
        # create empty trace with adapted stats, then attach the summed data
        out = TimeseriesUtility.create_empty_trace(
            first_stats.starttime,
            first_stats.endtime,
            first_stats.station,
            first_stats.channel,
            first_stats.data_type,
            first_stats.data_interval,
            first_stats.network,
            first_stats.station,
            first_stats.location,
        )
        out.data = total
        return out

    def _convert_stream_to_masked(self, timeseries: Stream, channel: str) -> Stream:
        """Return a copy of a stream with one channel's data masked.

        Invalid samples (as defined by numpy.ma.masked_invalid) are masked,
        which allows for gaps and splitting.

        Parameters
        ----------
        timeseries: Stream
            a stream retrieved from a geomag edge representing one channel
        channel: str
            the channel to be masked

        Returns
        -------
        stream: Stream
            a copy of timeseries in which every trace of the given channel
            holds a MaskedArray; the input stream is not modified
        """
        masked_stream = timeseries.copy()
        selected = masked_stream.select(channel=channel)
        for selected_trace in selected:
            selected_trace.data = numpy.ma.masked_invalid(selected_trace.data)
        return masked_stream

    def _get_timeseries(
        self,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
        observatory: str,
        channel: str,
        type: DataType,
        interval: DataInterval,
        add_empty_channels: bool = True,
    ) -> Stream:
        """get timeseries data for a single channel.

        Requests the channel from each configured port in turn until one
        returns data, rescales the samples to 64-bit floats, resolves
        overlapping traces in favor of later ones, and sets trace metadata.

        Parameters
        ----------
        starttime: UTCDateTime
            the starttime of the requested data
        endtime: UTCDateTime
            the endtime of the requested data
        observatory: str
            observatory code
        channel: str
            single character channel {H, E, D, Z, F}
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval
        add_empty_channels: bool
            if True, returns channels without data as empty traces

        Returns
        -------
        data: Stream
            stream with the requested channel data; empty when no data was
            retrieved and add_empty_channels is False

        Notes
        -----
        Errors raised while requesting data from a port, or while merging a
        trace, are printed and ignored rather than raised.
        When a port returns data, self.port is replaced with that single
        port, so later calls on this instance only use it.
        """
        sncl = self.get_sncl(
            station=observatory,
            data_type=type,
            interval=interval,
            element=channel,
            location=self.locationCode,
        )
        # geomag-algorithms *should* treat starttime/endtime as inclusive everywhere;
        # according to its author, EdgeCWB is inclusive of starttime, but exclusive of
        # endtime, to satisfy seismic standards/requirements, to precision delta/2;
        # half a sample interval is therefore added to endtime in the request below
        half_delta = TimeseriesUtility.get_delta_from_interval(interval) / 2

        # list-type ports variable needed for fail-back logic
        # (self.port may be a single, non-iterable port number)
        try:
            ports = list(self.port)
        except TypeError:
            ports = [self.port]

        data = Stream()
        for port in ports:
            try:
                client = earthworm.Client(self.host, port, timeout=self.timeout)

                data += client.get_waveforms(
                    sncl.network,
                    sncl.station,
                    sncl.location,
                    sncl.channel,
                    starttime,
                    endtime + half_delta,
                )

                if data:
                    # if data was returned, force this port in subsequent calls
                    # to get_timeseries() that use this instance of EdgeFactory
                    self.port = [port]
                    break
                print(
                    "No data returned from ",
                    self.host,
                    "on port ",
                    port,
                    " - SNCL:",
                    sncl,
                )
                # try alternate port(s) if provided
            except Exception as e:
                print("Failed to get data from ", self.host, " on port ", port)
                print("Ignoring error: ", e.__class__, e)
                # try alternate port(s) if provided
                continue

        for trace in data:
            if trace.data.dtype.kind == "i":
                # convert all integer traces to rescaled 64-bit floats;
                if sncl.channel[1] == "E":
                    # instrumental voltages are stored as 1/10 microvolts
                    trace.data = trace.data.astype(float) / (self.scale_factor or 1e7)
                else:
                    # everything else (mostly magnetics stored as picoteslas)
                    trace.data = trace.data.astype(float) / (self.scale_factor or 1e3)
            elif trace.data.dtype.kind == "f":
                # convert all float traces to 64-bit floats;
                trace.data = trace.data.astype(float) / (self.scale_factor or 1.0)

        # when Traces with identical NSCL codes overlap, prioritize samples
        # that come  "later" in Stream; this depends on Edge returning miniseed
        # packets in the order written
        # NOTE: this is not possible with single calls to Stream.merge()
        st_tmp = Stream()
        for tr in data:
            try:
                # add tr to temporary stream
                st_tmp += tr
                # replace time overlaps with gaps
                st_tmp.merge(0)
                # add tr to temporary stream again
                st_tmp += tr
                # replace gaps with tr's data
                st_tmp.merge(0)
            except Exception as e:
                # pop() removes the last trace in st_tmp, presumably the
                # trace that was just appended and failed to merge
                tr = st_tmp.pop()  # remove bad trace
                print("Dropped trace: ", tr)
                print("Ignoring error: ", e.__class__, e)

        # point `data` to the new stream and continue processing
        data = st_tmp
        if data.count() == 0 and add_empty_channels:
            # nothing usable was retrieved; substitute an empty trace
            data += self._get_empty_trace(
                starttime=starttime,
                endtime=endtime,
                observatory=observatory,
                channel=channel,
                data_type=type,
                interval=interval,
                network=sncl.network,
                location=sncl.location,
            )
        self._set_metadata(data, observatory, channel, type, interval)
        return data

    def _convert_timeseries(
        self,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
        observatory: str,
        channel: str,
        type: DataType,
        interval: DataInterval,
    ) -> Stream:
        """Generate a single channel using multiple components.

        Looks up instrument calibrations for the request, then calls
        get_calculated_timeseries for the actual conversion of each
        calibration entry that knows the channel.

        Parameters
        ----------
        starttime: UTCDateTime
            the starttime of the requested data
        endtime: UTCDateTime
            the endtime of the requested data
        observatory : str
            observatory code
        channel : str
            single character channel {H, E, D, Z, F}
        type : {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval : {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval

        Returns
        -------
        out: Stream
            timeseries of the requested channel data, merged so that it is
            expected to contain just a single Trace; empty when no
            calibration entry lists the channel.
        """
        converted = Stream()
        calibrations = get_instrument_calibrations(observatory, starttime, endtime)
        # more than one entry when the request spans different configurations
        for calibration in calibrations:
            valid_until = calibration["end_time"]
            valid_from = calibration["start_time"]
            channel_components = calibration["instrument"]["channels"]
            if channel not in channel_components:
                # this configuration has no recipe for the channel
                continue

            # clip the request to the period this calibration covers;
            # a None bound means the calibration is open-ended on that side
            if valid_from is None or valid_from < starttime:
                start = starttime
            else:
                start = valid_from
            if valid_until is None or valid_until > endtime:
                end = endtime
            else:
                end = valid_until

            converted += self.get_calculated_timeseries(
                start,
                end,
                observatory,
                channel,
                type,
                interval,
                channel_components[channel],
            )

        # cast every trace to the first trace's dtype before merging
        if len(converted) > 0:
            first_dtype = converted[0].data.dtype
            for piece in converted:
                piece.data = piece.data.astype(first_dtype)
        # merge to force a single trace
        # (precedence given to later interval when metadata overlap)
        converted.merge(1)
        return converted

    def _post_process(
        self,
        timeseries: Stream,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
        channels: List[str],
    ):
        """Post process a timeseries stream after the raw data is fetched
        from the waveserver.

        Changes any MaskedArray to an ndarray with nans representing gaps,
        converts channel "D" with ChannelConverter.get_radians_from_minutes
        when "D" was requested, then calls pad_timeseries to deal with gaps
        at the beginning or end of the streams.

        Parameters
        ----------
        timeseries: Stream
            The timeseries stream as returned by the call to get_waveforms
        starttime: UTCDateTime
            the starttime of the requested data
        endtime: UTCDateTime
            the endtime of the requested data
        channels: array
            list of channels to load

        Notes: the original timeseries object is changed.
        """
        # masked samples become nan
        for trace in timeseries:
            samples = trace.data
            if not isinstance(samples, numpy.ma.MaskedArray):
                continue
            samples.set_fill_value(numpy.nan)
            trace.data = samples.filled()

        # declination: minutes -> radians, only when D was requested
        declination_traces = timeseries.select(channel="D") if "D" in channels else []
        for d_trace in declination_traces:
            d_trace.data = ChannelConverter.get_radians_from_minutes(d_trace.data)

        TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime)

    def _put_channel(
        self,
        timeseries: Stream,
        observatory: str,
        channel: str,
        type: DataType,
        interval: DataInterval,
        starttime: UTCDateTime,
        endtime: UTCDateTime,
    ):
        """Put a channel worth of data

        The stream is masked and split at gaps; each contiguous piece is
        trimmed to the requested range, scaled to integer counts, and sent
        through its own RawInputClient connection.

        Parameters
        ----------
        timeseries: Stream
            timeseries object with data to be written
        observatory: str
            observatory code
        channel: str
            channel to load
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval
        starttime: UTCDateTime
            time of first sample to write
        endtime: UTCDateTime
            time of last sample to write

        Notes
        -----
        Errors raised while sending a trace are printed and ignored.
        """
        sncl = self.get_sncl(
            station=observatory,
            data_type=type,
            interval=interval,
            element=channel,
            location=self.locationCode,
        )
        # the scale depends only on the SEED channel code, so pick it once
        if sncl.channel[1] == "E":
            # instrumental voltages are stored as 1/10 microvolts
            scale = self.scale_factor or 1e7
        else:
            # everything else (mostly magnetics stored as picoteslas)
            scale = self.scale_factor or 1e3

        stream_masked = self._convert_stream_to_masked(
            timeseries=timeseries, channel=channel
        )
        # split() breaks masked traces into contiguous, gap-free pieces
        stream_split = stream_masked.split()

        for trace in stream_split:
            trace_send = trace.copy()
            trace_send.trim(starttime, endtime)
            if channel == "D":
                trace_send.data = ChannelConverter.get_minutes_from_radians(
                    trace_send.data
                )
            # ric requires ints; round to the nearest count before casting,
            # because a bare astype(int) truncates toward zero and turns
            # e.g. 1.001 * 1000 == 1000.9999999999999 into 1000
            trace_send.data = numpy.rint(trace_send.data * scale).astype(int)

            ric = RawInputClient(
                self.tag,
                self.host,
                self.write_port,
                sncl.station,
                sncl.channel,
                sncl.location,
                sncl.network,
            )
            try:
                try:
                    ric.send_trace(interval, trace_send)
                except Exception as e:
                    print("No data sent to", ric.host, ":", ric.port)
                    print("Ignoring error: ", e.__class__, e)
                if self.forceout:
                    ric.forceout()
            finally:
                # always release the connection, even if forceout() raises
                ric.close()

    def _set_metadata(
        self,
        stream: Stream,
        observatory: str,
        channel: str,
        type: str,
        interval: str,
    ):
        """Apply observatory metadata to the stats of every trace in a stream.

        Parameters
        ----------
        stream: Stream
            stream whose traces' stats are updated in place
        observatory: str
            observatory code
        channel: str
            edge channel code {MVH, MVE, MVD, ...}
        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
            data type
        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
            data interval
        """
        metadata = self.observatoryMetadata
        for trace_stats in (trace.stats for trace in stream):
            metadata.set_metadata(trace_stats, observatory, channel, type, interval)