From af8ea2e699fee3fee8c17707cd92095b6b6d765c Mon Sep 17 00:00:00 2001
From: pcain-usgs <pcain@usgs.gov>
Date: Mon, 15 Mar 2021 14:33:01 -0600
Subject: [PATCH] Make algorithm an optional parameter for run methods

---
 geomagio/Controller.py                | 21 ++++++++---
 geomagio/algorithm/FilterAlgorithm.py |  2 +-
 geomagio/processing/obsrio.py         | 53 ++++++++++++++++-----------
 3 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/geomagio/Controller.py b/geomagio/Controller.py
index 199ab64e4..e91ef4d25 100644
--- a/geomagio/Controller.py
+++ b/geomagio/Controller.py
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple, Union
 
 from obspy.core import Stream, UTCDateTime
 
-from .algorithm import algorithms, AlgorithmException
+from .algorithm import Algorithm, algorithms, AlgorithmException
 from .PlotTimeseriesFactory import PlotTimeseriesFactory
 from .StreamTimeseriesFactory import StreamTimeseriesFactory
 from . import TimeseriesUtility, Util
@@ -51,7 +51,7 @@ class Controller(object):
         self,
         inputFactory,
         outputFactory,
-        algorithm,
+        algorithm: Optional[Algorithm] = None,
         inputInterval: Optional[str] = None,
         outputInterval: Optional[str] = None,
     ):
@@ -61,7 +61,9 @@ class Controller(object):
         self._outputFactory = outputFactory
         self._outputInterval = outputInterval
 
-    def _get_input_timeseries(self, observatory, channels, starttime, endtime):
+    def _get_input_timeseries(
+        self, observatory, channels, starttime, endtime, algorithm=None
+    ):
         """Get timeseries from the input factory for requested options.
 
         Parameters
@@ -84,12 +86,13 @@ class Controller(object):
         -------
         timeseries : obspy.core.Stream
         """
+        algorithm = algorithm or self._algorithm
         timeseries = Stream()
         for obs in observatory:
             # get input interval for observatory
             # do this per observatory in case an
             # algorithm needs different amounts of data
-            input_start, input_end = self._algorithm.get_input_interval(
+            input_start, input_end = algorithm.get_input_interval(
                 start=starttime, end=endtime, observatory=obs, channels=channels
             )
             if input_start is None or input_end is None:
@@ -217,6 +220,7 @@ class Controller(object):
         observatory: List[str],
         starttime: UTCDateTime,
         endtime: UTCDateTime,
+        algorithm: Optional[Algorithm] = None,
         input_channels: Optional[List[str]] = None,
         input_timeseries: Optional[Stream] = None,
         output_channels: Optional[List[str]] = None,
@@ -243,13 +247,14 @@ class Controller(object):
         # ensure realtime is a valid value:
         if realtime <= 0:
             realtime = False
-        algorithm = self._algorithm
+        algorithm = algorithm or self._algorithm
         input_channels = input_channels or algorithm.get_input_channels()
         output_channels = output_channels or algorithm.get_output_channels()
         next_starttime = algorithm.get_next_starttime()
         starttime = next_starttime or starttime
         # input
         timeseries = input_timeseries or self._get_input_timeseries(
+            algorithm=algorithm,
             observatory=observatory,
             starttime=starttime,
             endtime=endtime,
@@ -298,6 +303,7 @@ class Controller(object):
         output_observatory: List[str],
         starttime: UTCDateTime,
         endtime: UTCDateTime,
+        algorithm: Optional[Algorithm] = None,
         input_channels: Optional[List[str]] = None,
         output_channels: Optional[List[str]] = None,
         no_trim: bool = False,
@@ -337,7 +343,7 @@ class Controller(object):
         # If an update_limit is set, make certain we don't step past it.
         if update_limit > 0 and update_count >= update_limit:
             return
-        algorithm = self._algorithm
+        algorithm = algorithm or self._algorithm
         if algorithm.get_next_starttime() is not None:
             raise AlgorithmException("Stateful algorithms cannot use run_as_update")
         input_channels = input_channels or algorithm.get_input_channels()
@@ -373,6 +379,7 @@ class Controller(object):
             ]
         for output_gap in output_gaps:
             input_timeseries = self._get_input_timeseries(
+                algorithm=algorithm,
                 observatory=observatory,
                 starttime=output_gap[0],
                 endtime=output_gap[1],
@@ -389,6 +396,7 @@ class Controller(object):
                 recurse_starttime = starttime - interval
                 recurse_endtime = starttime - 1
                 self.run_as_update(
+                    algorithm=algorithm,
                     observatory=observatory,
                     output_observatory=output_observatory,
                     starttime=recurse_starttime,
@@ -414,6 +422,7 @@ class Controller(object):
                 file=sys.stderr,
             )
             self.run(
+                algorithm=algorithm,
                 observatory=observatory,
                 starttime=gap_starttime,
                 endtime=gap_endtime,
diff --git a/geomagio/algorithm/FilterAlgorithm.py b/geomagio/algorithm/FilterAlgorithm.py
index 50aa2f729..da0f2df18 100644
--- a/geomagio/algorithm/FilterAlgorithm.py
+++ b/geomagio/algorithm/FilterAlgorithm.py
@@ -123,7 +123,7 @@ class FilterAlgorithm(Algorithm):
         outchannels=None,
     ):
 
-        Algorithm.__init__(self, inchannels=None, outchannels=None)
+        Algorithm.__init__(self, inchannels=inchannels, outchannels=outchannels)
         self.coeff_filename = coeff_filename
         self.filtertype = filtertype
         self.input_sample_period = input_sample_period
diff --git a/geomagio/processing/obsrio.py b/geomagio/processing/obsrio.py
index dc7c1b7f6..9a761e58c 100644
--- a/geomagio/processing/obsrio.py
+++ b/geomagio/processing/obsrio.py
@@ -140,9 +140,6 @@ def obsrio_day(
     starttime, endtime = get_realtime_interval(realtime_interval)
     # filter 10Hz U,V,W to H,E,Z
     controller = Controller(
-        algorithm=FilterAlgorithm(
-            input_sample_period=60.0, output_sample_period=86400.0
-        ),
         inputFactory=input_factory or get_edge_factory(data_type="variation"),
         inputInterval="minute",
         outputFactory=output_factory or get_miniseed_factory(data_type="variation"),
@@ -151,9 +148,13 @@ def obsrio_day(
     renames = {"H": "U", "E": "V", "Z": "W", "F": "F"}
     for input_channel in renames.keys():
         output_channel = renames[input_channel]
-        controller._algorithm._inchannels = (input_channel,)
-        controller._algorithm._outchannels = (output_channel,)
         controller.run_as_update(
+            algorithm=FilterAlgorithm(
+                input_sample_period=60.0,
+                output_sample_period=86400.0,
+                inchannels=(input_channel,),
+                outchannels=(output_channel,),
+            ),
             observatory=(observatory,),
             output_observatory=(observatory,),
             starttime=starttime,
@@ -177,9 +178,6 @@ def obsrio_hour(
     starttime, endtime = get_realtime_interval(realtime_interval)
     # filter 10Hz U,V,W to H,E,Z
     controller = Controller(
-        algorithm=FilterAlgorithm(
-            input_sample_period=60.0, output_sample_period=3600.0
-        ),
         inputFactory=input_factory or get_edge_factory(data_type="variation"),
         inputInterval="minute",
         outputFactory=output_factory or get_miniseed_factory(data_type="variation"),
@@ -188,9 +186,13 @@ def obsrio_hour(
     renames = {"H": "U", "E": "V", "Z": "W", "F": "F"}
     for input_channel in renames.keys():
         output_channel = renames[input_channel]
-        controller._algorithm._inchannels = (input_channel,)
-        controller._algorithm._outchannels = (output_channel,)
         controller.run_as_update(
+            algorithm=FilterAlgorithm(
+                input_sample_period=60.0,
+                output_sample_period=3600.0,
+                inchannels=(input_channel,),
+                outchannels=(output_channel,),
+            ),
             observatory=(observatory,),
             output_observatory=(observatory,),
             starttime=starttime,
@@ -217,16 +219,19 @@ def obsrio_minute(
     """
     starttime, endtime = get_realtime_interval(realtime_interval)
     controller = Controller(
-        algorithm=FilterAlgorithm(input_sample_period=1, output_sample_period=60),
         inputFactory=input_factory or get_edge_factory(data_type="variation"),
         inputInterval="second",
         outputFactory=output_factory or get_edge_factory(data_type="variation"),
         outputInterval="minute",
     )
     for channel in ["H", "E", "Z", "F"]:
-        controller._algorithm._inchannels = (channel,)
-        controller._algorithm._outchannels = (channel,)
         controller.run_as_update(
+            algorithm=FilterAlgorithm(
+                input_sample_period=1,
+                output_sample_period=60,
+                inchannels=(channel,),
+                outchannels=(channel,),
+            ),
             observatory=(observatory,),
             output_observatory=(observatory,),
             starttime=starttime,
@@ -248,14 +253,12 @@ def obsrio_second(
     """Copy 1Hz miniseed F to 1Hz legacy F."""
     starttime, endtime = get_realtime_interval(realtime_interval)
     controller = Controller(
-        algorithm=Algorithm(),
+        algorithm=Algorithm(inchannels=("F",), outchannels=("F",)),
         inputFactory=input_factory or get_miniseed_factory(data_type="variation"),
         inputInterval="second",
         outputFactory=output_factory or get_edge_factory(data_type="variation"),
         outputInterval="second",
     )
-    controller._algorithm._inchannels = ("F",)
-    controller._algorithm._outchannels = ("F",)
     controller.run_as_update(
         observatory=(observatory,),
         output_observatory=(observatory,),
@@ -278,7 +281,6 @@ def obsrio_temperatures(
     """Filter temperatures 1Hz miniseed (LK1-4) to 1 minute legacy (UK1-4)."""
     starttime, endtime = get_realtime_interval(realtime_interval)
     controller = Controller(
-        algorithm=FilterAlgorithm(input_sample_period=1, output_sample_period=60),
         inputFactory=input_factory or get_miniseed_factory(data_type="variation"),
         inputInterval="second",
         outputFactory=output_factory or get_edge_factory(data_type="variation"),
@@ -287,9 +289,13 @@ def obsrio_temperatures(
     renames = {"LK1": "UK1", "LK2": "UK2", "LK3": "UK3", "LK4": "UK4"}
     for input_channel in renames.keys():
         output_channel = renames[input_channel]
-        controller._algorithm._inchannels = (input_channel,)
-        controller._algorithm._outchannels = (output_channel,)
         controller.run_as_update(
+            algorithm=FilterAlgorithm(
+                input_sample_period=1,
+                output_sample_period=60,
+                inchannels=(input_channel,),
+                outchannels=(output_channel,),
+            ),
             observatory=(observatory,),
             output_observatory=(observatory,),
             starttime=starttime,
@@ -313,7 +319,6 @@ def obsrio_tenhertz(
     starttime, endtime = get_realtime_interval(realtime_interval)
     # filter 10Hz U,V,W to H,E,Z
     controller = Controller(
-        algorithm=FilterAlgorithm(input_sample_period=0.1, output_sample_period=1),
         inputFactory=input_factory or get_miniseed_factory(data_type="variation"),
         inputInterval="tenhertz",
         outputFactory=output_factory or get_edge_factory(data_type="variation"),
@@ -322,9 +327,13 @@ def obsrio_tenhertz(
     renames = {"U": "H", "V": "E", "W": "Z"}
     for input_channel in renames.keys():
         output_channel = renames[input_channel]
-        controller._algorithm._inchannels = (input_channel,)
-        controller._algorithm._outchannels = (output_channel,)
         controller.run_as_update(
+            algorithm=FilterAlgorithm(
+                input_sample_period=0.1,
+                output_sample_period=1,
+                inchannels=(input_channel,),
+                outchannels=(output_channel,),
+            ),
             observatory=(observatory,),
             output_observatory=(observatory,),
             starttime=starttime,
-- 
GitLab