From b94c90b4f2beb977d27c1cc67d4d28fdcf4ae051 Mon Sep 17 00:00:00 2001
From: Alex Wernle <awernle@usgs.gov>
Date: Wed, 11 Dec 2024 16:32:18 -0700
Subject: [PATCH] Added mode for managing existing spikes. Added logic to
 create_spike_metadata to allow empty arrays.

---
 geomagio/algorithm/SpikesAlgorithm.py | 49 +++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/geomagio/algorithm/SpikesAlgorithm.py b/geomagio/algorithm/SpikesAlgorithm.py
index 185b5023..b73a9013 100644
--- a/geomagio/algorithm/SpikesAlgorithm.py
+++ b/geomagio/algorithm/SpikesAlgorithm.py
@@ -4,6 +4,7 @@ import numpy as np
 from obspy.core.utcdatetime import UTCDateTime
 from obspy import Trace
 from typing import List, Optional
+from enum import Enum
 
 from .MetadataAlgorithm import MetadataAlgorithm
 from ..metadata import Metadata
@@ -11,6 +12,12 @@ from ..metadata.flag.Flag import ArtificialDisturbance, ArtificialDisturbanceTyp
 from ..metadata import MetadataCategory
 
 
+class SpikesMode(str, Enum):
+    MERGE = "merge"
+    OVERWRITE = "overwrite"
+
+
+# TODO: Add option to choose _despike method(s) and add more methods
 class SpikesAlgorithm(MetadataAlgorithm):
     """
     Algorithm that identifies spikes in data and generates metadata flags.
@@ -21,10 +28,14 @@ class SpikesAlgorithm(MetadataAlgorithm):
         Size of the rolling window for dynamic alignment (no shrinking or padding). Defaults to 1000.
     threshold: int
         Threshold for spike detection (threshold * sigma). Defaults to 6.
+    mode:SpikesMode
+        Merge will automatically merge new spikes with preexisting spikes from requested period (if they exist), overwrite will automatically overwrite preexisting spikes with
+        new spikes from requested period (if they exist). Default is overwrite, meaning only new spikes are saved.
     """
 
     window_size: Optional[int] = 1000
     threshold: Optional[int] = 6
+    mode: Optional[SpikesMode] = SpikesMode.OVERWRITE
 
     def _despike(
         self, signal: Trace, timestamps: List[UTCDateTime]
@@ -114,14 +125,30 @@ class SpikesAlgorithm(MetadataAlgorithm):
         if len(spike_timestamps) > 0:
             spike_metadata = self.create_spike_metadata(spike_timestamps, trace)
             if existing_metadata:
-                # update ID to preexisting one
-                spike_metadata.id = existing_metadata[0].id
+                if self.mode == SpikesMode.OVERWRITE:
+                    # only use new metadata by updating ID to preexisting one
+                    spike_metadata.id = existing_metadata[0].id
+                if self.mode == SpikesMode.MERGE:
+                    # concatenate existing timestamps with new timestamps
+                    existing_timestamps = existing_metadata[0].metadata["spikes"]
+                    merged_timestamps = np.unique(
+                        np.concatenate((spike_timestamps, existing_timestamps))
+                    )
+                    spike_metadata = self.create_spike_metadata(
+                        merged_timestamps, trace
+                    )
+                    spike_metadata.id = existing_metadata[0].id
             return spike_metadata
 
         if len(spike_timestamps) == 0 and existing_metadata is not None:
-            # TODO:change existing spike metadata to rejected or data_valid to false? Update array?
-            # convert existing_metadata back into Metadata
-            pass
+            if self.mode == SpikesMode.OVERWRITE:
+                # update ID to preexisting one with empty spike array
+                spike_metadata = self.create_spike_metadata(spike_timestamps, trace)
+                spike_metadata.id = existing_metadata[0].id
+            if self.mode == SpikesMode.MERGE:
+                # simply return preexisting metadata because you are adding zero spikes
+                spike_metadata = existing_metadata
+            return spike_metadata
 
         return None
 
@@ -164,6 +191,14 @@ class SpikesAlgorithm(MetadataAlgorithm):
         ArtificialDisturbance
             Metadata object for the detected spikes.
         """
+        # use trace starttime and endtime if spike_timestamps is empty
+        if len(spike_timestamps) >= 1:
+            starttime = spike_timestamps[0]
+            endtime = spike_timestamps[-1]
+        else:
+            starttime = trace.stats.starttime
+            endtime = trace.stats.endtime
+
         metadata_dict = {
             "description": f"Spikes detected using a window of {self.window_size} and a threshold of {self.threshold}*sigma",
             "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES,
@@ -171,8 +206,8 @@ class SpikesAlgorithm(MetadataAlgorithm):
         }
 
         return self.create_metadata(
-            starttime=spike_timestamps[0],
-            endtime=spike_timestamps[-1],
+            starttime=starttime,
+            endtime=endtime,
             metadata_class=ArtificialDisturbance,
             metadata_dict=metadata_dict,
             category=MetadataCategory.FLAG,
-- 
GitLab