From 70b9b650fc845268af3561bb616345fb2ddfcdcd Mon Sep 17 00:00:00 2001
From: Alex Wernle <awernle@usgs.gov>
Date: Fri, 4 Oct 2024 10:13:08 -0600
Subject: [PATCH] New SpikesAlgorithm Class to create spike metadata.

---
 geomagio/algorithm/SpikesAlgorithm.py | 112 ++++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 geomagio/algorithm/SpikesAlgorithm.py

diff --git a/geomagio/algorithm/SpikesAlgorithm.py b/geomagio/algorithm/SpikesAlgorithm.py
new file mode 100644
index 00000000..8ca3c490
--- /dev/null
+++ b/geomagio/algorithm/SpikesAlgorithm.py
@@ -0,0 +1,112 @@
+from __future__ import absolute_import
+
+import numpy as np
+from obspy.core.utcdatetime import UTCDateTime
+from typing import List
+
+from .MetadataAlgorithm import MetadataAlgorithm
+from ..metadata.flag.Flag import ArtificialDisturbance, ArtificialDisturbanceType
+from ..metadata import MetadataCategory
+
+
+class SpikesAlgorithm(MetadataAlgorithm):
+    """
+    Algorithm that identifies spikes in data and generates metadata flags.
+
+    Attributes
+    ----------
+    window_size: int
+        Size of the rolling window, centered on each sample and truncated at the
+        series boundaries rather than padded.
+    threshold: int
+        Spike detection threshold in multiples of the rolling standard deviation
+        (threshold * sigma).
+    """
+
+    window_size: int
+    threshold: int
+
+    def _despike(
+        self, signal: np.ndarray, timestamps: List[UTCDateTime]
+    ) -> List[UTCDateTime]:
+        """
+        Internal method to compute spike timestamps using a centered rolling window
+        that is truncated at the series boundaries.
+
+        Parameters
+        ----------
+        signal: np.ndarray
+            Signal samples (trace.data) to analyze.
+        timestamps: List[UTCDateTime]
+            Corresponding timestamps for the signal data.
+
+        Returns
+        -------
+        List[UTCDateTime]
+            Timestamps where spikes were detected.
+        """
+        n = len(signal)
+        half_window = self.window_size // 2
+
+        rolling_mean = np.zeros(n)
+        rolling_std = np.zeros(n)
+
+        for i in range(n):
+            start = max(0, i - half_window)
+            end = min(n, i + half_window + 1)
+
+            window = signal[start:end]
+            rolling_mean[i] = np.mean(window)
+            rolling_std[i] = np.std(window)
+
+        spikes = np.abs(signal - rolling_mean) > self.threshold * rolling_std
+        return timestamps[spikes]
+
+    def run(self) -> List[ArtificialDisturbance]:
+        """
+        Run the despiking algorithm and generate spike metadata.
+
+        Returns
+        -------
+        List[ArtificialDisturbance]
+            List of metadata objects for detected spikes.
+        """
+        stream = self.get_stream()
+        daily_streams = self.split_stream_by_day(stream)
+        spike_metadata_list = []
+
+        for day_stream in daily_streams:
+            for trace in day_stream:
+                timestamps = trace.times("UTCDateTime")
+                signal = trace.data
+                spike_timestamps = self._despike(signal, timestamps)
+
+                if len(spike_timestamps) > 0:
+                    metadata_dict = {
+                        "description": f"Spikes detected using a window of {self.window_size} and a threshold of {self.threshold}*sigma",
+                        "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES,
+                        "spikes": list(spike_timestamps),
+                    }
+
+                    spike_metadata = self.create_metadata(
+                        starttime=spike_timestamps[0],
+                        endtime=spike_timestamps[-1],
+                        metadata_class=ArtificialDisturbance,
+                        metadata_dict=metadata_dict,
+                        category=MetadataCategory.FLAG,
+                        network=trace.stats.network,
+                        channel=trace.stats.channel,
+                        location=trace.stats.location,
+                        created_by="SpikesAlgorithm",
+                        status="New",
+                    )
+
+                    # Validate spikes match times
+                    ArtificialDisturbance.check_spikes_match_times(
+                        spikes=metadata_dict["spikes"],
+                        values={
+                            "starttime": spike_metadata.starttime,
+                            "endtime": spike_metadata.endtime,
+                        },
+                    )
+
+                    spike_metadata_list.append(spike_metadata)
+
+        return spike_metadata_list
-- 
GitLab