From 3339c20723865cf82a2f0040f72f7a64b0b74156 Mon Sep 17 00:00:00 2001
From: pcain <pcain@usgs.gov>
Date: Mon, 4 Oct 2021 10:26:19 -0600
Subject: [PATCH] type hinting for MiniSeedFactory

---
 geomagio/edge/MiniSeedFactory.py | 271 +++++++++++++++++--------------
 1 file changed, 150 insertions(+), 121 deletions(-)

diff --git a/geomagio/edge/MiniSeedFactory.py b/geomagio/edge/MiniSeedFactory.py
index 94103caee..83011d608 100644
--- a/geomagio/edge/MiniSeedFactory.py
+++ b/geomagio/edge/MiniSeedFactory.py
@@ -9,15 +9,16 @@ to take advantage of it's newer realtime abilities.
 Edge is the USGS earthquake hazard centers replacement for earthworm.
 """
 from __future__ import absolute_import
-
 import sys
+from typing import List, Optional
+
 import numpy
 import numpy.ma
-
-import obspy.core
 from obspy.clients.neic import client as miniseed
+from obspy.core import Stats, Stream, Trace, UTCDateTime
 
 from .. import ChannelConverter, TimeseriesUtility
+from ..geomag_types import DataInterval, DataType
 from ..Metadata import get_instrument
 from ..TimeseriesFactory import TimeseriesFactory
 from ..TimeseriesFactoryException import TimeseriesFactoryException
@@ -41,16 +42,18 @@ class MiniSeedFactory(TimeseriesFactory):
         an array of channels {H, D, E, F, Z, MGD, MSD, HGD}.
         Known since channel names are mapped based on interval and type,
         others are passed through, see #_get_edge_channel().
-    type: str
-        the data type {variation, quasi-definitive, definitive}
-    interval: str
-        the data interval {'day', 'hour', 'minute', 'second', 'tenhertz'}
+    type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+        data type
+    interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+        data interval
     observatoryMetadata: ObservatoryMetadata object
         an ObservatoryMetadata object used to replace the default
         ObservatoryMetadata.
     locationCode: str
         the location code for the given edge server, overrides type
         in get_timeseries/put_timeseries
+    convert_channels: array
+        list of channels to convert from volt/bin to nT
 
     See Also
     --------
@@ -66,16 +69,16 @@ class MiniSeedFactory(TimeseriesFactory):
 
     def __init__(
         self,
-        host="cwbpub.cr.usgs.gov",
-        port=2061,
-        write_port=7974,
-        observatory=None,
-        channels=None,
-        type=None,
-        interval=None,
-        observatoryMetadata=None,
-        locationCode=None,
-        convert_channels=None,
+        host: str = "cwbpub.cr.usgs.gov",
+        port: int = 2061,
+        write_port: int = 7974,
+        observatory: Optional[str] = None,
+        channels: Optional[List[str]] = None,
+        type: Optional[DataType] = None,
+        interval: Optional[DataInterval] = None,
+        observatoryMetadata: Optional[ObservatoryMetadata] = None,
+        locationCode: Optional[str] = None,
+        convert_channels: Optional[List[str]] = None,
     ):
         TimeseriesFactory.__init__(self, observatory, channels, type, interval)
 
@@ -91,36 +94,36 @@ class MiniSeedFactory(TimeseriesFactory):
 
     def get_timeseries(
         self,
-        starttime,
-        endtime,
-        observatory=None,
-        channels=None,
-        type=None,
-        interval=None,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        observatory: Optional[str] = None,
+        channels: Optional[List[str]] = None,
+        type: Optional[DataType] = None,
+        interval: Optional[DataInterval] = None,
         add_empty_channels: bool = True,
-    ):
+    ) -> Stream:
         """Get timeseries data
 
         Parameters
         ----------
-        starttime: obspy.core.UTCDateTime
-            time of first sample.
-        endtime: obspy.core.UTCDateTime
-            time of last sample.
+        starttime: UTCDateTime
+            time of first sample
+        endtime: UTCDateTime
+            time of last sample
+        add_empty_channels: bool
+            if True, returns channels without data as empty traces
         observatory: str
-            observatory code.
-        channels: array_like
+            observatory code
+        channels: array
             list of channels to load
-        type: {'variation', 'quasi-definitive', 'definitive'}
-            data type.
-        interval: {'day', 'hour', 'minute', 'second', 'tenhertz'}
-            data interval.
-        add_empty_channels
-            if True, returns channels without data as empty traces
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
 
         Returns
         -------
-        obspy.core.Stream
+        timeseries: Stream
             timeseries object with requested data.
 
         Raises
@@ -145,7 +148,7 @@ class MiniSeedFactory(TimeseriesFactory):
             # send stdout to stderr
             sys.stdout = sys.stderr
             # get the timeseries
-            timeseries = obspy.core.Stream()
+            timeseries = Stream()
             for channel in channels:
                 if channel in self.convert_channels:
                     data = self._convert_timeseries(
@@ -173,28 +176,28 @@ class MiniSeedFactory(TimeseriesFactory):
 
     def put_timeseries(
         self,
-        timeseries,
-        starttime=None,
-        endtime=None,
-        observatory=None,
-        channels=None,
-        type=None,
-        interval=None,
+        timeseries: Stream,
+        starttime: Optional[UTCDateTime] = None,
+        endtime: Optional[UTCDateTime] = None,
+        observatory: Optional[str] = None,
+        channels: Optional[List[str]] = None,
+        type: Optional[DataType] = None,
+        interval: Optional[DataInterval] = None,
     ):
         """Put timeseries data
 
         Parameters
         ----------
-        timeseries: obspy.core.Stream
+        timeseries: Stream
             timeseries object with data to be written
         observatory: str
-            observatory code.
-        channels: array_like
+            observatory code
+        channels: array
             list of channels to load
-        type: {'variation', 'quasi-definitive', 'definitive'}
-            data type.
-        interval: {'day', 'hour', 'minute', 'second', 'tenhertz'}
-            data interval.
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
 
         Notes
         -----
@@ -226,24 +229,31 @@ class MiniSeedFactory(TimeseriesFactory):
         self.write_client.close()
 
     def get_calculated_timeseries(
-        self, starttime, endtime, observatory, channel, type, interval, components
-    ):
+        self,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        observatory: str,
+        channel: str,
+        type: DataType,
+        interval: DataInterval,
+        components: List[dict],
+    ) -> Trace:
         """Calculate a single channel using multiple component channels.
 
         Parameters
         ----------
-        starttime: obspy.core.UTCDateTime
+        starttime: UTCDateTime
             the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
+        endtime: UTCDateTime
             the endtime of the requested data
-        observatory : str
+        observatory: str
             observatory code
-        channel : str
+        channel: str
             single character channel {H, E, D, Z, F}
-        type : str
-            data type {definitive, quasi-definitive, variation}
-        interval : str
-            interval length {'day', 'hour', 'minute', 'second', 'tenhertz'}
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
         components: list
             each component is a dictionary with the following keys:
                 channel: str
@@ -252,7 +262,7 @@ class MiniSeedFactory(TimeseriesFactory):
 
         Returns
         -------
-        obspy.core.trace
+        out: Trace
             timeseries trace of the converted channel data
         """
         # sum channels
@@ -268,7 +278,7 @@ class MiniSeedFactory(TimeseriesFactory):
             # add to converted
             if converted is None:
                 converted = nt
-                stats = obspy.core.Stats(data.stats)
+                stats = Stats(data.stats)
             else:
                 converted += nt
         # set channel parameter to U, V, or W
@@ -288,19 +298,19 @@ class MiniSeedFactory(TimeseriesFactory):
         out.data = converted
         return out
 
-    def _convert_stream_to_masked(self, timeseries, channel):
+    def _convert_stream_to_masked(self, timeseries: Stream, channel: str) -> Stream:
         """convert geomag edge traces in a timeseries stream to a MaskedArray
             This allows for gaps and splitting.
         Parameters
         ----------
-        stream : obspy.core.stream
-            a stream retrieved from a geomag edge representing one channel.
-        channel: string
-            the channel to be masked.
+        stream: Stream
+            a stream retrieved from a geomag edge representing one channel
+        channel: str
+            the channel to be masked
         Returns
         -------
-        obspy.core.stream
-            a stream with all traces converted to masked arrays.
+        stream: Stream
+            a stream with all traces converted to masked arrays
         """
         stream = timeseries.copy()
         for trace in stream.select(channel=channel):
@@ -309,36 +319,36 @@ class MiniSeedFactory(TimeseriesFactory):
 
     def _get_timeseries(
         self,
-        starttime,
-        endtime,
-        observatory,
-        channel,
-        type,
-        interval,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        observatory: str,
+        channel: str,
+        type: DataType,
+        interval: DataInterval,
         add_empty_channels: bool = True,
-    ):
+    ) -> Trace:
         """get timeseries data for a single channel.
 
         Parameters
         ----------
-        starttime: obspy.core.UTCDateTime
+        starttime: UTCDateTime
             the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
+        endtime: UTCDateTime
             the endtime of the requested data
-        observatory : str
+        observatory: str
             observatory code
-        channel : str
+        channel: str
             single character channel {H, E, D, Z, F}
-        type : str
-            data type {definitive, quasi-definitive, variation}
-        interval : str
-            interval length {'day', 'hour', 'minute', 'second', 'tenhertz'}
-        add_empty_channels
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            interval length
+        add_empty_channels: bool
             if True, returns channels without data as empty traces
 
         Returns
         -------
-        obspy.core.trace
+        data: Trace
             timeseries trace of the requested channel data
         """
         sncl = SNCL.get_sncl(
@@ -367,8 +377,14 @@ class MiniSeedFactory(TimeseriesFactory):
         return data
 
     def _convert_timeseries(
-        self, starttime, endtime, observatory, channel, type, interval
-    ):
+        self,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        observatory: str,
+        channel: str,
+        type: DataType,
+        interval: DataInterval,
+    ) -> Trace:
         """Generate a single channel using multiple components.
 
         Finds metadata, then calls _get_converted_timeseries for actual
@@ -376,25 +392,25 @@ class MiniSeedFactory(TimeseriesFactory):
 
         Parameters
         ----------
-        starttime: obspy.core.UTCDateTime
+        starttime: UTCDateTime
             the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
+        endtime: UTCDateTime
             the endtime of the requested data
         observatory : str
             observatory code
         channel : str
             single character channel {H, E, D, Z, F}
-        type : str
-            data type {definitive, quasi-definitive, variation}
-        interval : str
-            interval length {'day', 'hour', 'minute', 'second', 'tenhertz'}
+        type : {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
 
         Returns
         -------
-        obspy.core.trace
+        out: Trace
             timeseries trace of the requested channel data
         """
-        out = obspy.core.Stream()
+        out = Stream()
         metadata = get_instrument(observatory, starttime, endtime)
         # loop in case request spans different configurations
         for entry in metadata:
@@ -428,7 +444,13 @@ class MiniSeedFactory(TimeseriesFactory):
             )
         return out
 
-    def _post_process(self, timeseries, starttime, endtime, channels):
+    def _post_process(
+        self,
+        timeseries: Stream,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        channels: List[str],
+    ):
         """Post process a timeseries stream after the raw data is
                 is fetched from querymom. Specifically changes
                 any MaskedArray to a ndarray with nans representing gaps.
@@ -437,13 +459,13 @@ class MiniSeedFactory(TimeseriesFactory):
 
         Parameters
         ----------
-        timeseries: obspy.core.stream
+        timeseries: Stream
             The timeseries stream as returned by the call to get_waveforms
-        starttime: obspy.core.UTCDateTime
+        starttime: UTCDateTime
             the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
+        endtime: UTCDateTime
             the endtime of the requested data
-        channels: array_like
+        channels: array
             list of channels to load
 
         Notes: the original timeseries object is changed.
@@ -460,24 +482,31 @@ class MiniSeedFactory(TimeseriesFactory):
         TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime)
 
     def _put_channel(
-        self, timeseries, observatory, channel, type, interval, starttime, endtime
+        self,
+        timeseries: Stream,
+        observatory: str,
+        channel: str,
+        type: DataType,
+        interval: DataInterval,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
     ):
         """Put a channel worth of data
 
         Parameters
         ----------
-        timeseries: obspy.core.Stream
+        timeseries: Stream
             timeseries object with data to be written
         observatory: str
-            observatory code.
+            observatory code
         channel: str
             channel to load
-        type: {'variation', 'quasi-definitive', 'definitive'}
-            data type.
-        interval: {'day', 'hour', 'minute', 'second', 'tenhertz'}
-            data interval.
-        starttime: obspy.core.UTCDateTime
-        endtime: obspy.core.UTCDateTime
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
+        starttime: UTCDateTime
+        endtime: UTCDateTime
         """
         # use separate traces when there are gaps
         to_write = timeseries.select(channel=channel)
@@ -502,23 +531,23 @@ class MiniSeedFactory(TimeseriesFactory):
 
     def _set_metadata(
         self,
-        stream: obspy.core.Stream,
+        stream: Stream,
         observatory: str,
         channel: str,
-        type: str,
-        interval: str,
+        type: DataType,
+        interval: DataInterval,
     ):
         """set metadata for a given stream/channel
         Parameters
         ----------
-        observatory
+        observatory: str
             observatory code
-        channel
+        channel: str
             edge channel code {MVH, MVE, MVD, ...}
-        type
-            data type {definitive, quasi-definitive, variation}
-        interval
-            interval length {minute, second}
+        type: {'adjusted', 'definitive', 'quasi-definitive', 'variation'}
+            data type
+        interval: {'tenhertz', 'second', 'minute', 'hour', 'day'}
+            data interval
         """
         for trace in stream:
             self.observatoryMetadata.set_metadata(
-- 
GitLab