From 70b9b650fc845268af3561bb616345fb2ddfcdcd Mon Sep 17 00:00:00 2001
From: Alex Wernle <awernle@usgs.gov>
Date: Fri, 4 Oct 2024 10:13:08 -0600
Subject: [PATCH] New SpikesAlgorithm Class to create spike metadata.

---
 geomagio/algorithm/SpikesAlgorithm.py | 112 ++++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 geomagio/algorithm/SpikesAlgorithm.py

diff --git a/geomagio/algorithm/SpikesAlgorithm.py b/geomagio/algorithm/SpikesAlgorithm.py
new file mode 100644
index 00000000..8ca3c490
--- /dev/null
+++ b/geomagio/algorithm/SpikesAlgorithm.py
@@ -0,0 +1,112 @@
+from __future__ import absolute_import
+
+import numpy as np
+from obspy.core.utcdatetime import UTCDateTime, Trace
+from typing import List
+
+from .MetadataAlgorithm import MetadataAlgorithm
+from ..metadata.flag.Flag import ArtificialDisturbance, ArtificialDisturbanceType
+from ..metadata import MetadataCategory
+
+
+class SpikesAlgorithm(MetadataAlgorithm):
+    """
+    Algorithm that identifies spikes in data and generates metadata flags.
+
+    Attributes
+    ----------
+    window_size: int
+        Size of the rolling window for dynamic alignment (no shrinking or padding).
+    threshold: int
+        Threshold for spike detection (threshold * sigma).
+    """
+
+    window_size: int
+    threshold: int
+
+    def _despike(
+        self, signal: Trace, timestamps: List[UTCDateTime]
+    ) -> List[UTCDateTime]:
+        """
+        Internal method to compute spike timestamps using a rolling window with dynamic alignment.
+
+        Parameters
+        ----------
+        signal: Trace
+            Signal trace data to analyze.
+        timestamps: List[UTCDateTime]
+            Corresponding timestamps for the signal data.
+
+        Returns
+        -------
+        List[UTCDateTime]
+            Timestamps where spikes were detected.
+        """
+        n = len(signal)
+        half_window = self.window_size // 2
+
+        rolling_mean = np.zeros(n)
+        rolling_std = np.zeros(n)
+
+        for i in range(n):
+            start = max(0, i - half_window)
+            end = min(n, i + half_window + 1)
+
+            window = signal[start:end]
+            rolling_mean[i] = np.mean(window)
+            rolling_std[i] = np.std(window)
+
+        spikes = np.abs(signal - rolling_mean) > self.threshold * rolling_std
+        return timestamps[spikes]
+
+    def run(self) -> List[ArtificialDisturbance]:
+        """
+        Run the despiking algorithm and generate spike metadata.
+
+        Returns
+        -------
+        List[ArtificialDisturbance]
+            List of metadata objects for detected spikes.
+        """
+        stream = self.get_stream()
+        daily_streams = self.split_stream_by_day(stream)
+        spike_metadata_list = []
+
+        for day_stream in daily_streams:
+            for trace in day_stream:
+                timestamps = trace.times("UTCDateTime")
+                signal = trace.data
+                spike_timestamps = self._despike(signal, timestamps)
+
+                if len(spike_timestamps) > 0:
+                    metadata_dict = {
+                        "description": f"Spikes detected using a window of {self.window_size} and a threshold of {self.threshold}*sigma",
+                        "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES,
+                        "spikes": list(spike_timestamps),
+                    }
+
+                    spike_metadata = self.create_metadata(
+                        starttime=spike_timestamps[0],
+                        endtime=spike_timestamps[-1],
+                        metadata_class=ArtificialDisturbance,
+                        metadata_dict=metadata_dict,
+                        category=MetadataCategory.FLAG,
+                        network=trace.stats.network,
+                        channel=trace.stats.channel,
+                        location=trace.stats.location,
+                        created_by="SpikesAlgorithm",
+                        status="New",
+                    )
+
+                    # Validate spikes match times
+                    ArtificialDisturbance.check_spikes_match_times(
+                        spikes=metadata_dict["spikes"],
+                        values={
+                            "starttime": spike_metadata.starttime,
+                            "endtime": spike_metadata.endtime,
+                        },
+                    )
+
+                    spike_metadata_list.append(spike_metadata)
+
+        return spike_metadata_list
-- 
GitLab