From cd040a1ff7734c12d3195dfaa8b52014f3d9d1f3 Mon Sep 17 00:00:00 2001
From: pcain <pcain@usgs.gov>
Date: Wed, 22 Sep 2021 13:35:34 -0600
Subject: [PATCH] add type hints to TimeseriesFactory

---
 geomagio/TimeseriesFactory.py | 141 ++++++++++++++++++----------------
 1 file changed, 76 insertions(+), 65 deletions(-)

diff --git a/geomagio/TimeseriesFactory.py b/geomagio/TimeseriesFactory.py
index 3d92f9c5..06b32c0b 100644
--- a/geomagio/TimeseriesFactory.py
+++ b/geomagio/TimeseriesFactory.py
@@ -1,10 +1,14 @@
 """Abstract Timeseries Factory Interface."""
 from __future__ import absolute_import, print_function
-
-import numpy
-import obspy.core
+from io import BytesIO
 import os
 import sys
+from typing import BinaryIO, List, Optional, Sequence
+
+import numpy
+from obspy import Stream, Trace, UTCDateTime
+
+from .geomag_types import DataInterval, DataType
 from .TimeseriesFactoryException import TimeseriesFactoryException
 from . import TimeseriesUtility
 from . import Util
@@ -28,10 +32,10 @@ class TimeseriesFactory(object):
     channels : array_like
         default list of channels to load, optional.
         default ('H', 'D', 'Z', 'F')
-    type : {'definitive', 'provisional', 'quasi-definitive', 'variation'}
+    type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
         default data type, optional.
         default 'variation'.
-    interval : {'day', 'hour', 'minute', 'month, 'second'}
+    interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
         data interval, optional.
         default 'minute'.
     urlTemplate : str
@@ -45,12 +49,12 @@ class TimeseriesFactory(object):
 
     def __init__(
         self,
-        observatory=None,
-        channels=("H", "D", "Z", "F"),
-        type="variation",
-        interval="minute",
-        urlTemplate="",
-        urlInterval=-1,
+        observatory: Optional[str] = None,
+        channels: Sequence[str] = ("H", "D", "Z", "F"),
+        type: DataType = "variation",
+        interval: DataInterval = "minute",
+        urlTemplate: str = "",
+        urlInterval: int = -1,
     ):
         self.observatory = observatory
         self.channels = channels
@@ -61,14 +65,14 @@ class TimeseriesFactory(object):
 
     def get_timeseries(
         self,
-        starttime,
-        endtime,
-        observatory=None,
-        channels=None,
-        type=None,
-        interval=None,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        observatory: Optional[str] = None,
+        channels: Optional[List[str]] = None,
+        type: Optional[DataType] = None,
+        interval: Optional[DataInterval] = None,
         add_empty_channels: bool = True,
-    ):
+    ) -> Stream:
         """Get timeseries data.
 
         Support for specific channels, types, and intervals varies
@@ -88,18 +92,18 @@ class TimeseriesFactory(object):
         channels : array_like
             list of channels to load, optional.
             uses default if unspecified.
-        type : {'definitive', 'provisional', 'quasi-definitive', 'variation'}
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
             data type, optional.
             uses default if unspecified.
-        interval : {'day', 'hour', 'minute', 'month', 'second'}
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
             data interval, optional.
             uses default if unspecified.
-        add_empty_channels
+        add_empty_channels : bool
             if True, returns channels without data as empty traces
 
         Returns
         -------
-        obspy.core.Stream
+        timeseries : Stream
             stream containing traces for requested timeseries.
 
         Raises
@@ -112,7 +116,7 @@ class TimeseriesFactory(object):
         type = type or self.type
         interval = interval or self.interval
 
-        timeseries = obspy.core.Stream()
+        timeseries = Stream()
         urlIntervals = Util.get_intervals(
             starttime=starttime, endtime=endtime, size=self.urlInterval
         )
@@ -143,7 +147,7 @@ class TimeseriesFactory(object):
                 print("Error parsing data: " + str(e), file=sys.stderr)
                 print(data, file=sys.stderr)
         if channels is not None:
-            filtered = obspy.core.Stream()
+            filtered = Stream()
             for channel in channels:
                 filtered += timeseries.select(channel=channel)
             timeseries = filtered
@@ -157,7 +161,7 @@ class TimeseriesFactory(object):
         )
         return timeseries
 
-    def parse_string(self, data, **kwargs):
+    def parse_string(self, data: str, **kwargs):
         """Creates error message that this functions is not implemented by
         TimeseriesFactory.
 
@@ -174,18 +178,18 @@ class TimeseriesFactory(object):
 
     def put_timeseries(
         self,
-        timeseries,
-        starttime=None,
-        endtime=None,
-        channels=None,
-        type=None,
-        interval=None,
+        timeseries: Stream,
+        starttime: Optional[UTCDateTime] = None,
+        endtime: Optional[UTCDateTime] = None,
+        channels: Optional[List[str]] = None,
+        type: Optional[DataType] = None,
+        interval: Optional[DataInterval] = None,
     ):
         """Store timeseries data.
 
         Parameters
         ----------
-        timeseries : obspy.core.Stream
+        timeseries : Stream
             stream containing traces to store.
         starttime : UTCDateTime
             time of first sample in timeseries to store.
@@ -196,10 +200,10 @@ class TimeseriesFactory(object):
         channels : array_like
             list of channels to store, optional.
             uses default if unspecified.
-        type : {'definitive', 'provisional', 'quasi-definitive', 'variation'}
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
             data type, optional.
             uses default if unspecified.
-        interval : {'day', 'hour', 'minute', 'month', 'second'}
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
             data interval, optional.
             uses default if unspecified.
         Raises
@@ -283,14 +287,14 @@ class TimeseriesFactory(object):
                 except NotImplementedError:
                     raise NotImplementedError('"put_timeseries" not implemented')
 
-    def write_file(self, fh, timeseries, channels):
+    def write_file(self, fh: BinaryIO, timeseries: Stream, channels: List[str]):
         """Write timeseries data to the given file object.
 
         Parameters
         ----------
         fh : writable
             file handle where data is written.
-        timeseries : obspy.core.Stream
+        timeseries : Stream
             stream containing traces to store.
         channels : list
             list of channels to store.
@@ -299,15 +303,15 @@ class TimeseriesFactory(object):
 
     def _get_empty_trace(
         self,
-        starttime: obspy.core.UTCDateTime,
-        endtime: obspy.core.UTCDateTime,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
         observatory: str,
         channel: str,
         data_type: str,
         interval: str,
         network: str = "NT",
         location: str = "",
-    ) -> obspy.core.Trace:
+    ) -> Trace:
         """creates empty trace"""
         trace = TimeseriesUtility.create_empty_trace(
             starttime=starttime,
@@ -322,7 +326,7 @@ class TimeseriesFactory(object):
         )
         return trace
 
-    def _get_file_from_url(self, url):
+    def _get_file_from_url(self, url: str) -> str:
         """Get a file for writing.
 
         Ensures parent directory exists.
@@ -334,7 +338,7 @@ class TimeseriesFactory(object):
 
         Returns
         -------
-        str
+        filename : str
             path to file without file:// prefix
 
         Raises
@@ -351,8 +355,13 @@ class TimeseriesFactory(object):
         return filename
 
     def _get_url(
-        self, observatory, date, type="variation", interval="minute", channels=None
-    ):
+        self,
+        observatory: str,
+        date: UTCDateTime,
+        type: DataType = "variation",
+        interval: DataInterval = "minute",
+        channels: Optional[List[str]] = None,
+    ) -> str:
         """Get the url for a specified file.
 
         Replaces patterns (described in class docstring) with values based on
@@ -362,12 +371,11 @@ class TimeseriesFactory(object):
         ----------
         observatory : str
             observatory code.
-        date : obspy.core.UTCDateTime
+        date : UTCDateTime
             day to fetch (only year, month, day are used)
-        type : {'variation', 'reported', 'provisional', 'adjusted',
-                'quasi-definitive', 'definitive'}
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
             data type.
-        interval : {'minute', 'second', 'hour', 'day'}
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
             data interval.
         channels : list
             list of data channels being requested
@@ -405,14 +413,15 @@ class TimeseriesFactory(object):
         # use old style string interpolation
         return self.urlTemplate % params
 
-    def _get_interval_abbreviation(self, interval):
+    def _get_interval_abbreviation(self, interval: DataInterval) -> str:
         """Get abbreviation for a data interval.
 
         Used by ``_get_url`` to replace ``%(i)s`` in urlTemplate.
 
         Parameters
         ----------
-        interval : {'day', 'hour', 'minute', 'month', 'second'}
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
+            data interval
 
         Returns
         -------
@@ -439,14 +448,15 @@ class TimeseriesFactory(object):
             raise TimeseriesFactoryException('Unexpected interval "%s"' % interval)
         return interval_abbr
 
-    def _get_interval_name(self, interval):
+    def _get_interval_name(self, interval: DataInterval) -> str:
         """Get name for a data interval.
 
         Used by ``_get_url`` to replace ``%(interval)s`` in urlTemplate.
 
         Parameters
         ----------
-        interval : {'minute', 'second'}
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
+            data interval
 
         Returns
         -------
@@ -468,14 +478,15 @@ class TimeseriesFactory(object):
             raise TimeseriesFactoryException('Unsupported interval "%s"' % interval)
         return interval_name
 
-    def _get_type_abbreviation(self, type):
+    def _get_type_abbreviation(self, type: DataType) -> str:
         """Get abbreviation for a data type.
 
         Used by ``_get_url`` to replace ``%(t)s`` in urlTemplate.
 
         Parameters
         ----------
-        type : {'definitive', 'provisional', 'quasi-definitive', 'variation'}
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
+            data type
 
         Returns
         -------
@@ -499,15 +510,15 @@ class TimeseriesFactory(object):
             raise TimeseriesFactoryException('Unexpected type "%s"' % type)
         return type_abbr
 
-    def _get_type_name(self, type):
+    def _get_type_name(self, type: DataType) -> str:
         """Get name for a data type.
 
         Used by ``_get_url`` to replace ``%(type)s`` in urlTemplate.
 
         Parameters
         ----------
-        type : {'variation', 'reported', 'provisional', 'adjusted',
-                'quasi-definitive', 'quasidefinitive', 'definitive' }
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
+            data type
 
         Returns
         -------
@@ -533,22 +544,22 @@ class TimeseriesFactory(object):
 
     def _set_metadata(
         self,
-        stream: obspy.core.Stream,
+        stream: Stream,
         observatory: str,
         channel: str,
-        type: str,
-        interval: str,
+        type: DataType,
+        interval: DataInterval,
     ):
         """set metadata for a given stream/channel
         Parameters
         ----------
-        observatory
+        observatory : str
             observatory code
-        channel
+        channel : str
             edge channel code {MVH, MVE, MVD, ...}
-        type
-            data type {definitive, quasi-definitive, variation}
-        interval
-            interval length {minute, second}
+        type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'}
+            data type
+        interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'}
+            interval length
         """
         pass
-- 
GitLab