From 4d3d8bd5aa8fd1f2568df93b60131f1f209aac66 Mon Sep 17 00:00:00 2001
From: "E. Joshua Rigler" <erigler@usgs.gov>
Date: Wed, 9 Nov 2022 16:49:39 -0700
Subject: [PATCH] Add --average-min-count option to AverageAlgorithm

- added a `min_count` keyword to the AvergeAlgorithm's __init__() method;
- added a --average-min-count option via the add_arguments() classmethod;
- set `self.min_count = arguments.average_min_count` in configure() method;
- refactored process() method to respect `min_count`, but now it defaults
  to a requirement that all inputs be valid (this modified a recent merge
  by @awernle that only required that any inputs be valid;
- modified AverageAlgorithm_test.py to properly assess things in light
  of the change to default behavior just mentioned.
---
 geomagio/algorithm/AverageAlgorithm.py       | 40 ++++++++++++++++----
 test/algorithm_test/AverageAlgorithm_test.py |  4 +-
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/geomagio/algorithm/AverageAlgorithm.py b/geomagio/algorithm/AverageAlgorithm.py
index 2b07bda4d..75ee9170d 100644
--- a/geomagio/algorithm/AverageAlgorithm.py
+++ b/geomagio/algorithm/AverageAlgorithm.py
@@ -18,7 +18,14 @@ class AverageAlgorithm(Algorithm):
 
     """
 
-    def __init__(self, observatories=None, channel=None, location=None, scales=None):
+    def __init__(
+        self,
+        observatories=None,
+        channel=None,
+        location=None,
+        scales=None,
+        min_count=None,
+    ):
         Algorithm.__init__(self)
         self._npts = -1
         self._stt = -1
@@ -27,6 +34,7 @@ class AverageAlgorithm(Algorithm):
         self.observatories = observatories
         self.outchannel = channel
         self.outlocation = location
+        self.min_count = min_count
         self.observatoryMetadata = ObservatoryMetadata()
 
     def check_stream(self, timeseries):
@@ -86,6 +94,8 @@ class AverageAlgorithm(Algorithm):
 
         self.outlocation = self.outlocation or timeseries[0].stats.location
 
+        self.min_count = self.min_count or len(timeseries)
+
         scale_values = self.scales or ([1] * len(timeseries))
         lat_corr = {}
         i = 0
@@ -111,8 +121,18 @@ class AverageAlgorithm(Algorithm):
             ts = timeseries.select(station=obsy)[0]
             combined.append(ts.data * latcorr)
 
+        # calculate counts
+        count_data = numpy.count_nonzero(
+            ~numpy.isnan(numpy.column_stack(timeseries)), axis=1
+        )
+
         # calculate averages
-        average_data = numpy.nanmean(combined, axis=0)
+        average_data = numpy.nansum(combined, axis=0) / count_data
+
+        # apply min_count
+        average_data[count_data < self.min_count] = numpy.nan
+
+        # create first output trace metadata
         average_stats = obspy.core.Stats()
         average_stats.station = "USGS"
         average_stats.channel = self.outchannel
@@ -121,10 +141,7 @@ class AverageAlgorithm(Algorithm):
         average_stats.starttime = timeseries[0].stats.starttime
         average_stats.npts = timeseries[0].stats.npts
         average_stats.delta = timeseries[0].stats.delta
-
-        # calculate counts
-        ts_counter = numpy.column_stack(timeseries)
-        count_data = numpy.count_nonzero(~numpy.isnan(ts_counter), axis=1)
+        # create second output trace metadata
         count_stats = average_stats.copy()
         count_stats.channel = average_stats.channel + "_count"
 
@@ -132,7 +149,7 @@ class AverageAlgorithm(Algorithm):
         stream = obspy.core.Stream(
             (
                 obspy.core.Trace(average_data, average_stats),
-                obspy.core.Trace(numpy.asarray(count_data), count_stats),
+                obspy.core.Trace(count_data, count_stats),
             )
         )
         return stream
@@ -154,6 +171,13 @@ class AverageAlgorithm(Algorithm):
             nargs="*",
             type=float,
         )
+        parser.add_argument(
+            "--average-min-count",
+            default=None,
+            help="Minimum number of inputs required to calculate average",
+            nargs="*",
+            type=int,
+        )
 
     def configure(self, arguments):
         """Configure algorithm using comand line arguments.
@@ -178,3 +202,5 @@ class AverageAlgorithm(Algorithm):
                 )
 
         self.outlocation = arguments.outlocationcode or arguments.locationcode
+
+        self.min_count = arguments.average_min_count
diff --git a/test/algorithm_test/AverageAlgorithm_test.py b/test/algorithm_test/AverageAlgorithm_test.py
index 2480f655b..1c3a88009 100644
--- a/test/algorithm_test/AverageAlgorithm_test.py
+++ b/test/algorithm_test/AverageAlgorithm_test.py
@@ -39,7 +39,7 @@ def test_process():
     timeseries[2].stats.station = "SJG"
 
     # initialize the algorithm factory with Observatories and Channel
-    a = AverageAlgorithm(("HON", "GUA", "SJG"), "H")
+    a = AverageAlgorithm(("HON", "GUA", "SJG"), "H", min_count=1)
     outstream = a.process(timeseries)
     # Ensure the average of two
     np.testing.assert_array_equal(outstream[0].data, expected_solution)
@@ -76,7 +76,7 @@ def test_gaps():
     timeseries += gap_trace
     timeseries += full_trace
     # Initialize the AverageAlgorithm factory with observatories and channel
-    alg = AverageAlgorithm(("HON", "SJG"), "H")
+    alg = AverageAlgorithm(("HON", "SJG"), "H", min_count=1)
     # Run timeseries through the average process
     outstream = alg.process(timeseries)
 
-- 
GitLab