From 0e3949f26972c7d600acd2b0ba8dc27fc8d39060 Mon Sep 17 00:00:00 2001
From: "E. Joshua Rigler" <erigler@usgs.gov>
Date: Thu, 12 Sep 2024 09:58:23 -0600
Subject: [PATCH] Fix FilterAlgorithm.align_trace() method and unit tests

The method FilterAlgorithm.align_trace() pre-processes the `trace.data` input
array for subsequent processing by FilterAlgorithm.firfilter() such that the
first input array element corresponds to a time step on which output samples
must fall (as defined in the dictionary `step`) minus half the fir window
width. In short, it ensures that output samples fall on desired time steps.

Prior to this fix, it only worked as intended when input trace's starttime fell
on an even time step. A bug became obvious when attempting to filter data from
non-Geomag stations that did not have nice time stamps. This fix addresses that
issue, and also ensures that `align_trace()` does what was claimed in its own
original docstrings, which is to handle trailing misalignments as well.

Note: one thing `align_trace()` does NOT do is ensure that all needed input
data are available to generate desired outputs. The user is responsible for
providing this, but can use the FilterAlgorithm.get_input_interal() method
to calculate the actual required input starttime and endtime. The method
`align_trace()` will trim or pad with NaNs only enough to align time stamps,
and may actually result in `firfilter()` output that are NaNs if the input
trace was not adequate.
---
 geomagio/algorithm/FilterAlgorithm.py       | 44 ++++++++++++++++-----
 test/algorithm_test/FilterAlgorithm_test.py |  7 ++--
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/geomagio/algorithm/FilterAlgorithm.py b/geomagio/algorithm/FilterAlgorithm.py
index da0f2df18..6cd104487 100644
--- a/geomagio/algorithm/FilterAlgorithm.py
+++ b/geomagio/algorithm/FilterAlgorithm.py
@@ -307,7 +307,11 @@ class FilterAlgorithm(Algorithm):
         return out
 
     def align_trace(self, step, trace):
-        """Aligns trace to handle trailing or missing values.
+        """Aligns `trace.data` with `step`'s window to half an output_sample_period
+        by padding or trimming samples in preparation for processing by `firfilter`;
+        this ensures `firfilter` output always falls on desired time stamps as
+        defined by `step`.
+
         Parameters
         ----------
         step: dict
@@ -324,20 +328,42 @@ class FilterAlgorithm(Algorithm):
         data = trace.data
         start = trace.stats.starttime
         filter_start = get_nearest_time(step=step, output_time=start, left=False)
-        while filter_start["data_start"] < start:
-            # filter needs more data, shift one output right
+        # roughly align starttime with filter-start to half an output_sample_period
+        while (start - filter_start["data_start"]) > (
+            step["output_sample_period"] / 2.0
+        ):
+            # shift one output sample to the right from start
             filter_start = get_nearest_time(
                 step=step,
                 output_time=filter_start["time"] + step["output_sample_period"],
                 left=False,
             )
-
-        if start != filter_start["data_start"]:
-            offset = int(
-                1e-6
-                + (filter_start["data_start"] - start) / step["input_sample_period"]
+        # pad or trim trace.data to get sample-resolution alignment
+        start_offset = round(
+            (filter_start["data_start"] - start) / step["input_sample_period"]
+        )
+        if start_offset > 0:
+            data = data[start_offset:]
+        else:
+            data = np.concatenate((np.tile(np.nan, -start_offset), data))
+
+        end = trace.stats.endtime
+        filter_end = get_nearest_time(step=step, output_time=end, left=True)
+        # roughly align endtime with filter_end to half an output_sample_period
+        while (filter_end["data_end"] - end) > (step["output_sample_period"] / 2.0):
+            # shift one output sample to the left from end
+            filter_end = get_nearest_time(
+                step=step,
+                output_time=filter_end["time"] - step["output_sample_period"],
+                left=True,
             )
-            data = data[offset:]
+        # pad or trim trace.data to get sample-resolution alignment
+        end_offset = round((end - filter_end["data_end"]) / step["input_sample_period"])
+        if end_offset > 0:
+            data = data[:-end_offset]
+        else:
+            data = np.concatenate((data, np.tile(np.nan, -end_offset)))
+
         return filter_start["time"], data
 
     @staticmethod
diff --git a/test/algorithm_test/FilterAlgorithm_test.py b/test/algorithm_test/FilterAlgorithm_test.py
index b4c21d45a..2a0b9833b 100644
--- a/test/algorithm_test/FilterAlgorithm_test.py
+++ b/test/algorithm_test/FilterAlgorithm_test.py
@@ -270,10 +270,11 @@ def test_starttime_shift():
     filtered = f.process(precise)
     assert_equal(filtered[0].stats.starttime, UTCDateTime("2020-01-01T00:01:00Z"))
     assert_equal(filtered[0].stats.endtime, UTCDateTime("2020-01-01T00:14:00Z"))
-    # remove one extra sample (filter no longer has enough to generate first/last)
+    # remove slightly more than half an output_sample_period
+    # (filter no longer has enough to generate first/last)
     trimmed = bou.trim(
-        starttime=UTCDateTime("2020-01-01T00:00:16Z"),
-        endtime=UTCDateTime("2020-01-01T00:14:44Z"),
+        starttime=UTCDateTime("2020-01-01T00:00:15Z") + 31,
+        endtime=UTCDateTime("2020-01-01T00:14:45Z") - 31,
     )
     filtered = f.process(trimmed)
     assert_equal(filtered[0].stats.starttime, UTCDateTime("2020-01-01T00:02:00Z"))
-- 
GitLab