diff --git a/geomagio/algorithm/SpikesAlgorithm.py b/geomagio/algorithm/SpikesAlgorithm.py index 185b50239b998bc563f8494557561e76241dbb8e..b73a90132ee48ea56e224a92b2878a4290235550 100644 --- a/geomagio/algorithm/SpikesAlgorithm.py +++ b/geomagio/algorithm/SpikesAlgorithm.py @@ -4,6 +4,7 @@ import numpy as np from obspy.core.utcdatetime import UTCDateTime from obspy import Trace from typing import List, Optional +from enum import Enum from .MetadataAlgorithm import MetadataAlgorithm from ..metadata import Metadata @@ -11,6 +12,12 @@ from ..metadata.flag.Flag import ArtificialDisturbance, ArtificialDisturbanceTyp from ..metadata import MetadataCategory +class SpikesMode(str, Enum): + MERGE = "merge" + OVERWRITE = "overwrite" + + +# TODO: Add option to choose _despike method(s) and add more methods class SpikesAlgorithm(MetadataAlgorithm): """ Algorithm that identifies spikes in data and generates metadata flags. @@ -21,10 +28,14 @@ class SpikesAlgorithm(MetadataAlgorithm): Size of the rolling window for dynamic alignment (no shrinking or padding). Defaults to 1000. threshold: int Threshold for spike detection (threshold * sigma). Defaults to 6. + mode:SpikesMode + Merge will automatically merge new spikes with preexisting spikes from requested period (if they exist), overwrite will automatically overwrite preexisting spikes with + new spikes from requested period (if they exist). Default is overwrite, meaning only new spikes are saved. """ window_size: Optional[int] = 1000 threshold: Optional[int] = 6 + mode: Optional[SpikesMode] = SpikesMode.OVERWRITE def _despike( self, signal: Trace, timestamps: List[UTCDateTime] @@ -114,14 +125,30 @@ class SpikesAlgorithm(MetadataAlgorithm): if len(spike_timestamps) > 0: spike_metadata = self.create_spike_metadata(spike_timestamps, trace) if existing_metadata: - # update ID to preexisting one - spike_metadata.id = existing_metadata[0].id + if self.mode == SpikesMode.OVERWRITE: + # only use new metadata by updating ID to preexisting one + spike_metadata.id = existing_metadata[0].id + if self.mode == SpikesMode.MERGE: + # concatenate existing timestamps with new timestamps + existing_timestamps = existing_metadata[0].metadata["spikes"] + merged_timestamps = np.unique( + np.concatenate((spike_timestamps, existing_timestamps)) + ) + spike_metadata = self.create_spike_metadata( + merged_timestamps, trace + ) + spike_metadata.id = existing_metadata[0].id return spike_metadata if len(spike_timestamps) == 0 and existing_metadata is not None: - # TODO:change existing spike metadata to rejected or data_valid to false? Update array? - # convert existing_metadata back into Metadata - pass + if self.mode == SpikesMode.OVERWRITE: + # update ID to preexisting one with empty spike array + spike_metadata = self.create_spike_metadata(spike_timestamps, trace) + spike_metadata.id = existing_metadata[0].id + if self.mode == SpikesMode.MERGE: + # simply return preexisting metadata because you are adding zero spikes + spike_metadata = existing_metadata + return spike_metadata return None @@ -164,6 +191,14 @@ class SpikesAlgorithm(MetadataAlgorithm): ArtificialDisturbance Metadata object for the detected spikes. """ + # use trace starttime and endtime if spike_timestamps is empty + if len(spike_timestamps) >= 1: + starttime = spike_timestamps[0] + endtime = spike_timestamps[-1] + else: + starttime = trace.stats.starttime + endtime = trace.stats.endtime + metadata_dict = { "description": f"Spikes detected using a window of {self.window_size} and a threshold of {self.threshold}*sigma", "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES, @@ -171,8 +206,8 @@ class SpikesAlgorithm(MetadataAlgorithm): } return self.create_metadata( - starttime=spike_timestamps[0], - endtime=spike_timestamps[-1], + starttime=starttime, + endtime=endtime, metadata_class=ArtificialDisturbance, metadata_dict=metadata_dict, category=MetadataCategory.FLAG,