From cd040a1ff7734c12d3195dfaa8b52014f3d9d1f3 Mon Sep 17 00:00:00 2001 From: pcain <pcain@usgs.gov> Date: Wed, 22 Sep 2021 13:35:34 -0600 Subject: [PATCH] add type hints to TimeseriesFactory --- geomagio/TimeseriesFactory.py | 141 ++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 65 deletions(-) diff --git a/geomagio/TimeseriesFactory.py b/geomagio/TimeseriesFactory.py index 3d92f9c5..06b32c0b 100644 --- a/geomagio/TimeseriesFactory.py +++ b/geomagio/TimeseriesFactory.py @@ -1,10 +1,14 @@ """Abstract Timeseries Factory Interface.""" from __future__ import absolute_import, print_function - -import numpy -import obspy.core +from io import BytesIO import os import sys +from typing import List, Optional + +import numpy +from obspy import Stream, Trace, UTCDateTime + +from .geomag_types import DataInterval, DataType from .TimeseriesFactoryException import TimeseriesFactoryException from . import TimeseriesUtility from . import Util @@ -28,10 +32,10 @@ class TimeseriesFactory(object): channels : array_like default list of channels to load, optional. default ('H', 'D', 'Z', 'F') - type : {'definitive', 'provisional', 'quasi-definitive', 'variation'} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} default data type, optional. default 'variation'. - interval : {'day', 'hour', 'minute', 'month, 'second'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} data interval, optional. default 'minute'. urlTemplate : str @@ -45,12 +49,12 @@ class TimeseriesFactory(object): def __init__( self, - observatory=None, - channels=("H", "D", "Z", "F"), - type="variation", - interval="minute", - urlTemplate="", - urlInterval=-1, + observatory: Optional[str] = None, + channels: List[str] = ("H", "D", "Z", "F"), + type: DataType = "variation", + interval: DataInterval = "minute", + urlTemplate: str = "", + urlInterval: int = -1, ): self.observatory = observatory self.channels = channels @@ -61,14 +65,14 @@ class TimeseriesFactory(object): def get_timeseries( self, - starttime, - endtime, - observatory=None, - channels=None, - type=None, - interval=None, + starttime: UTCDateTime, + endtime: UTCDateTime, add_empty_channels: bool = True, - ): + observatory: Optional[str] = None, + channels: Optional[List[str]] = None, + type: Optional[DataType] = None, + interval: Optional[DataInterval] = None, + ) -> Stream: """Get timeseries data. Support for specific channels, types, and intervals varies @@ -88,18 +92,18 @@ class TimeseriesFactory(object): channels : array_like list of channels to load, optional. uses default if unspecified. - type : {'definitive', 'provisional', 'quasi-definitive', 'variation'} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} data type, optional. uses default if unspecified. - interval : {'day', 'hour', 'minute', 'month', 'second'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} data interval, optional. uses default if unspecified. - add_empty_channels + add_empty_channels : bool if True, returns channels without data as empty traces Returns ------- - obspy.core.Stream + timeseries : Stream stream containing traces for requested timeseries. Raises @@ -112,7 +116,7 @@ class TimeseriesFactory(object): type = type or self.type interval = interval or self.interval - timeseries = obspy.core.Stream() + timeseries = Stream() urlIntervals = Util.get_intervals( starttime=starttime, endtime=endtime, size=self.urlInterval ) @@ -143,7 +147,7 @@ class TimeseriesFactory(object): print("Error parsing data: " + str(e), file=sys.stderr) print(data, file=sys.stderr) if channels is not None: - filtered = obspy.core.Stream() + filtered = Stream() for channel in channels: filtered += timeseries.select(channel=channel) timeseries = filtered @@ -157,7 +161,7 @@ class TimeseriesFactory(object): ) return timeseries - def parse_string(self, data, **kwargs): + def parse_string(self, data: str, **kwargs): """Creates error message that this functions is not implemented by TimeseriesFactory. @@ -174,18 +178,18 @@ class TimeseriesFactory(object): def put_timeseries( self, - timeseries, - starttime=None, - endtime=None, - channels=None, - type=None, - interval=None, + timeseries: Stream, + starttime: Optional[UTCDateTime] = None, + endtime: Optional[UTCDateTime] = None, + channels: Optional[List[str]] = None, + type: Optional[DataType] = None, + interval: Optional[DataInterval] = None, ): """Store timeseries data. Parameters ---------- - timeseries : obspy.core.Stream + timeseries : Stream stream containing traces to store. starttime : UTCDateTime time of first sample in timeseries to store. @@ -196,10 +200,10 @@ class TimeseriesFactory(object): channels : array_like list of channels to store, optional. uses default if unspecified. - type : {'definitive', 'provisional', 'quasi-definitive', 'variation'} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} data type, optional. uses default if unspecified. - interval : {'day', 'hour', 'minute', 'month', 'second'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} data interval, optional. uses default if unspecified. Raises @@ -283,14 +287,14 @@ class TimeseriesFactory(object): except NotImplementedError: raise NotImplementedError('"put_timeseries" not implemented') - def write_file(self, fh, timeseries, channels): + def write_file(self, fh: BytesIO, timeseries: Stream, channels: List[str]): """Write timeseries data to the given file object. Parameters ---------- fh : writable file handle where data is written. - timeseries : obspy.core.Stream + timeseries : Stream stream containing traces to store. channels : list list of channels to store. @@ -299,15 +303,15 @@ class TimeseriesFactory(object): def _get_empty_trace( self, - starttime: obspy.core.UTCDateTime, - endtime: obspy.core.UTCDateTime, + starttime: UTCDateTime, + endtime: UTCDateTime, observatory: str, channel: str, data_type: str, interval: str, network: str = "NT", location: str = "", - ) -> obspy.core.Trace: + ) -> Trace: """creates empty trace""" trace = TimeseriesUtility.create_empty_trace( starttime=starttime, @@ -322,7 +326,7 @@ class TimeseriesFactory(object): ) return trace - def _get_file_from_url(self, url): + def _get_file_from_url(self, url: str) -> str: """Get a file for writing. Ensures parent directory exists. @@ -334,7 +338,7 @@ class TimeseriesFactory(object): Returns ------- - str + filename : str path to file without file:// prefix Raises @@ -351,8 +355,13 @@ class TimeseriesFactory(object): return filename def _get_url( - self, observatory, date, type="variation", interval="minute", channels=None - ): + self, + observatory: str, + date: UTCDateTime, + type: DataType = "variation", + interval: DataInterval = "minute", + channels: Optional[List[str]] = None, + ) -> str: """Get the url for a specified file. Replaces patterns (described in class docstring) with values based on @@ -362,12 +371,11 @@ class TimeseriesFactory(object): ---------- observatory : str observatory code. - date : obspy.core.UTCDateTime + date : UTCDateTime day to fetch (only year, month, day are used) - type : {'variation', 'reported', 'provisional', 'adjusted', - 'quasi-definitive', 'definitive'} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} data type. - interval : {'minute', 'second', 'hour', 'day'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} data interval. channels : list list of data channels being requested @@ -405,14 +413,15 @@ class TimeseriesFactory(object): # use old style string interpolation return self.urlTemplate % params - def _get_interval_abbreviation(self, interval): + def _get_interval_abbreviation(self, interval: DataInterval) -> str: """Get abbreviation for a data interval. Used by ``_get_url`` to replace ``%(i)s`` in urlTemplate. Parameters ---------- - interval : {'day', 'hour', 'minute', 'month', 'second'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} + data interval Returns ------- @@ -439,14 +448,15 @@ class TimeseriesFactory(object): raise TimeseriesFactoryException('Unexpected interval "%s"' % interval) return interval_abbr - def _get_interval_name(self, interval): + def _get_interval_name(self, interval: DataInterval) -> str: """Get name for a data interval. Used by ``_get_url`` to replace ``%(interval)s`` in urlTemplate. Parameters ---------- - interval : {'minute', 'second'} + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} + data interval Returns ------- @@ -468,14 +478,15 @@ class TimeseriesFactory(object): raise TimeseriesFactoryException('Unsupported interval "%s"' % interval) return interval_name - def _get_type_abbreviation(self, type): + def _get_type_abbreviation(self, type: DataType) -> str: """Get abbreviation for a data type. Used by ``_get_url`` to replace ``%(t)s`` in urlTemplate. Parameters ---------- - type : {'definitive', 'provisional', 'quasi-definitive', 'variation'} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} + data type Returns ------- @@ -499,15 +510,15 @@ class TimeseriesFactory(object): raise TimeseriesFactoryException('Unexpected type "%s"' % type) return type_abbr - def _get_type_name(self, type): + def _get_type_name(self, type: DataType) -> str: """Get name for a data type. Used by ``_get_url`` to replace ``%(type)s`` in urlTemplate. Parameters ---------- - type : {'variation', 'reported', 'provisional', 'adjusted', - 'quasi-definitive', 'quasidefinitive', 'definitive' } + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} + data type Returns ------- @@ -533,22 +544,22 @@ class TimeseriesFactory(object): def _set_metadata( self, - stream: obspy.core.Stream, + stream: Stream, observatory: str, channel: str, - type: str, - interval: str, + type: DataType, + interval: DataInterval, ): """set metadata for a given stream/channel Parameters ---------- - observatory + observatory : str observatory code - channel + channel : str edge channel code {MVH, MVE, MVD, ...} - type - data type {definitive, quasi-definitive, variation} - interval - interval length {minute, second} + type : {'adjusted', 'definitive', 'provisional', 'quasi-definitive', 'reported', 'variation'} + data type + interval : {'tenhertz', 'second', 'minute', 'hour', 'day', 'month'} + interval length """ pass -- GitLab