diff --git a/geomagio/algorithm/MetadataAlgorithm.py b/geomagio/algorithm/MetadataAlgorithm.py
new file mode 100644
index 0000000000000000000000000000000000000000..43b96804c9bd8f94fc039e00016ff9dfd6a196b0
--- /dev/null
+++ b/geomagio/algorithm/MetadataAlgorithm.py
@@ -0,0 +1,134 @@
+from pydantic import BaseModel
+from obspy import UTCDateTime, Stream
+from datetime import timedelta
+from enum import Enum
+from typing import Optional
+
+from ..metadata import Metadata, MetadataFactory, MetadataCategory
+from ..edge.MiniSeedFactory import MiniSeedFactory
+from ..edge.EdgeFactory import EdgeFactory
+from ..pydantic_utcdatetime import CustomUTCDateTimeType
+
+
+class DataFactory(str, Enum):
+    MINISEED = "miniseed"
+    EDGE = "edge"
+
+
+class MetadataAlgorithm(BaseModel):
+    factory: Optional[DataFactory] = DataFactory.MINISEED
+    observatory: Optional[str] = None
+    channels: Optional[str] = None
+    metadata_token: Optional[str] = None
+    metadata_url: Optional[str] = None
+    type: Optional[str] = None
+    interval: Optional[str] = None
+    starttime: Optional[CustomUTCDateTimeType] = None
+    endtime: Optional[CustomUTCDateTimeType] = None
+
+    def get_data_factory(self):
+        """Return the data factory matching the configured factory type."""
+        factory_class = (
+            MiniSeedFactory if self.factory == DataFactory.MINISEED else EdgeFactory
+        )
+        return factory_class(
+            port=22061,
+            observatory=self.observatory,
+            channels=self.channels,
+            type=self.type,
+            interval=self.interval,
+        )
+
+    def get_stream(self) -> Stream:
+        """Retrieve the data stream based on the factory type."""
+        data_factory = self.get_data_factory()
+        try:
+            return data_factory.get_timeseries(
+                starttime=self.starttime,
+                endtime=self.endtime,
+                add_empty_channels=False,
+            )
+        except Exception as e:
+            raise ValueError(f"Failed to retrieve data stream from {self.factory}: {e}")
+
+    def create_metadata(
+        self,
+        metadata_class: type[BaseModel],
+        metadata_dict: dict,
+        category: MetadataCategory,
+        network: str,
+        channel: str,
+        location: str,
+        created_by: str,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+        status: str,
+    ) -> Metadata:
+        """Create metadata using the provided dictionary."""
+        return Metadata(
+            category=category,
+            created_by=created_by,
+            starttime=starttime,
+            endtime=endtime,
+            metadata=metadata_class(**metadata_dict).model_dump(),
+            network=network,
+            channel=channel,
+            location=location,
+            station=self.observatory,
+            status=status,
+        )
+
+    def check_existing_metadata(
+        self, metadata_obj: Metadata
+    ) -> Optional[list[Metadata]]:
+        """Check if similar metadata already exists; return the matching metadata if it does."""
+        query_metadata = Metadata(
+            category=metadata_obj.category,
+            station=metadata_obj.station,
+            starttime=metadata_obj.starttime,
+            endtime=metadata_obj.endtime,
+            channel=metadata_obj.channel,
+        )
+        metadata_factory = self._get_metadata_factory()
+        prior_metadata = metadata_factory.get_metadata(query=query_metadata)
+        return prior_metadata if prior_metadata else None
+
+    def update_metadata(self, metadata_obj: Metadata) -> Metadata:
+        """Update existing metadata."""
+        return self._get_metadata_factory().update_metadata(metadata=metadata_obj)
+
+    def create_new_metadata(self, metadata_obj: Metadata) -> Metadata:
+        """Create new metadata."""
+        return self._get_metadata_factory().create_metadata(metadata=metadata_obj)
+
+    def split_stream_by_day(self, stream: Stream) -> list[Stream]:
+        """Split a stream into daily streams to bound the size of spike arrays."""
+        daily_streams = []
+        # get min and max time across all traces
+        current_time = min(trace.stats.starttime for trace in stream)
+        end_time = max(trace.stats.endtime for trace in stream)
+
+        # loop through each day and slice the stream accordingly
+        while current_time <= end_time:
+            day_endtime = min(
+                UTCDateTime(
+                    current_time.year, current_time.month, current_time.day, 23, 59, 59
+                ),
+                end_time,
+            )
+
+            # slice stream for the current day
+            daily_stream = stream.slice(
+                starttime=current_time, endtime=day_endtime, nearest_sample=True
+            )
+            if daily_stream:
+                daily_streams.append(daily_stream)
+
+            # advance to the start of the next day; adding a flat day to a
+            # mid-day starttime would silently skip the first hours of every
+            # following day
+            current_time = (
+                UTCDateTime(current_time.year, current_time.month, current_time.day)
+                + timedelta(days=1)
+            )
+
+        return daily_streams
+
+    def _get_metadata_factory(self) -> MetadataFactory:
+        """Instantiate a MetadataFactory for the configured service."""
+        return MetadataFactory(token=self.metadata_token, url=self.metadata_url)
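The day-splitting helper above is easy to sanity-check in isolation. A minimal sketch, assuming only `obspy` and `numpy` are available (the station code, channel, and cadence below are made up for illustration):

```python
import numpy as np
from obspy import Stream, Trace, UTCDateTime

from geomagio.algorithm.MetadataAlgorithm import MetadataAlgorithm

# synthetic one-sample-per-minute trace spanning two and a half days
trace = Trace(
    data=np.random.default_rng(42).normal(size=3600),
    header={
        "station": "BOU",  # illustrative codes
        "channel": "F",
        "starttime": UTCDateTime("2024-01-01T12:00:00"),
        "delta": 60.0,
    },
)

daily = MetadataAlgorithm().split_stream_by_day(Stream([trace]))
print(len(daily))  # 3: a partial first day, one full day, a partial last day
```

Because the loop advances to midnight after the first slice, a stream that starts mid-day still yields full-day slices for every later day.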
+ """ + stream = self.get_stream() + daily_streams = self.split_stream_by_day(stream) + spike_metadata_list = [] + + for day_stream in daily_streams: + for trace in day_stream: + spike_metadata = self.process_trace(trace) + if spike_metadata: + spike_metadata_list.append(spike_metadata) + + return spike_metadata_list + + def process_trace(self, trace: Trace) -> Optional[ArtificialDisturbance]: + """ + Process a single trace, detecting spikes and creating or updating metadata as needed. + + Parameters + ---------- + trace : Trace + The trace to process for spikes. + + Returns + ------- + Optional[ArtificialDisturbance] + The metadata for the trace if spikes are detected, None otherwise. + """ + + timestamps = trace.times("UTCDateTime") + signal = trace.data + + # throw error if singal len < window + if len(signal) < self.window_size: + raise ValueError( + f"Signal length ({len(signal)}) must be at least as large as the window size ({self.window_size})." + ) + + spike_timestamps = self._despike(signal, timestamps) + + # check existing metadata + existing_metadata = self.check_existing_spike_metadata(trace) + + if len(spike_timestamps) > 0: + spike_metadata = self.create_spike_metadata(spike_timestamps, trace) + if existing_metadata: + if self.mode == SpikesMode.OVERWRITE: + # only use new metadata by updating ID to preexisting one + spike_metadata.id = existing_metadata[0].id + if self.mode == SpikesMode.MERGE: + # concatenate existing timestamps with new timestamps + existing_timestamps = existing_metadata[0].metadata["spikes"] + merged_timestamps = np.unique( + np.concatenate((spike_timestamps, existing_timestamps)) + ) + spike_metadata = self.create_spike_metadata( + merged_timestamps, trace + ) + spike_metadata.id = existing_metadata[0].id + return spike_metadata + + if len(spike_timestamps) == 0 and existing_metadata is not None: + if self.mode == SpikesMode.OVERWRITE: + # update ID to preexisting one with empty spike array + spike_metadata = self.create_spike_metadata(spike_timestamps, trace) + spike_metadata.id = existing_metadata[0].id + if self.mode == SpikesMode.MERGE: + # simply return preexisting metadata because you are adding zero spikes + spike_metadata = existing_metadata + return spike_metadata + + return None + + def check_existing_spike_metadata(self, trace: Trace) -> Optional[str] | None: + """ + Find existing spike metadata for a specified period. + + Parameters + ---------- + trace : Trace + The trace from which the metadata is searched. + Returns + ------- + The existing metadata if found, None otherwise. + """ + spike_metadata = Metadata( + category=MetadataCategory.FLAG, + station=trace.stats.station, + starttime=trace.stats.starttime, + endtime=trace.stats.endtime, + channel=trace.stats.channel, + ) + return self.check_existing_metadata(spike_metadata) + + def create_spike_metadata( + self, spike_timestamps: List[UTCDateTime], trace: Trace + ) -> ArtificialDisturbance: + """ + Create new metadata for detected spikes in a trace. + + Parameters + ---------- + spike_timestamps : List[UTCDateTime] + List of timestamps where spikes were detected. + trace : Trace + The trace from which the metadata is created. + + Returns + ------- + ArtificialDisturbance + Metadata object for the detected spikes. 
+ """ + # use trace starttime and endtime if spike_timestamps is empty + if len(spike_timestamps) >= 1: + starttime = spike_timestamps[0] + endtime = spike_timestamps[-1] + else: + starttime = trace.stats.starttime + endtime = trace.stats.endtime + + metadata_dict = { + "description": f"Spikes detected using a window of {self.window_size} and a threshold of {self.threshold}*sigma", + "artificial_disturbance_type": ArtificialDisturbanceType.SPIKES, + "spikes": list(spike_timestamps), + } + + return self.create_metadata( + starttime=starttime, + endtime=endtime, + metadata_class=ArtificialDisturbance, + metadata_dict=metadata_dict, + category=MetadataCategory.FLAG, + network=trace.stats.network, + channel=trace.stats.channel, + location=trace.stats.location, + created_by="SpikesAlgorithm", + status="New", + ) diff --git a/geomagio/metadata/flag/Flag.py b/geomagio/metadata/flag/Flag.py index a16c63632eec66467fe437d7ce31bedc6a89c491..1fe5e19468aa8e7e3e623b29ff82809ce51c98fd 100644 --- a/geomagio/metadata/flag/Flag.py +++ b/geomagio/metadata/flag/Flag.py @@ -1,16 +1,18 @@ from typing import Dict, Union, List, Optional from datetime import timedelta -import numpy as np from obspy import UTCDateTime -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from enum import Enum +from ...pydantic_utcdatetime import CustomUTCDateTimeType + class FlagCategory(str, Enum): ARTIFICIAL_DISTURBANCE = "ARTIFICIAL_DISTURBANCE" GAP = "GAP" EVENT = "EVENT" + FIELD_WORK = "FIELD_WORK" OTHER = "OTHER" @@ -23,30 +25,26 @@ class Flag(BaseModel): automatic_flag = Metadata( created_by='ex_algorithm', start_time=UTCDateTime('2023-01-01T03:05:10'), - end_time=UTCDateTime('2023-01-01T03:05:11'), + end_time=UTCDateTime('2023-01-01T03:10:11'), network='NT', station='BOU', channel='BEH', category=MetadataCategory.FLAG, comment="spike detected", priority=1, - data_valid=False, + data_valid=True, metadata= ArtificialDisturbance{ - "description": "Spike in magnetic field strength", - "field_work": false, - "corrected": false, + "description": "Spikes in magnetic field strength", "flag_category": ARTIFICIAL_DISTURBANCE, "artificial_disturbance_type": ArtificialDisturbanceType.SPIKE, - "source": "Lightning", "deviation": None, + "spikes": ['2023-01-01T03:05:10','2023-01-01T03:07:20','2023-01-01T03:10:11'] } ) ``` """ description: str = Field(..., description="Description of the flag") - field_work: bool = Field(..., description="Flag signaling field work") - corrected: int = Field(..., description="Corrected ID for processing stage") flag_category: FlagCategory = "OTHER" @@ -70,23 +68,48 @@ class ArtificialDisturbance(Flag): ---------- artificial_disturbance_type:ArtificialDisturbanceType The type of artificial disturbance(s). - source: str - Source of the disturbance if known or suspected. deviation: float Deviation of an offset in nt. - spikes: np.ndarray - NumPy array of timestamps as UTCDateTime. Can be a single spike or many spikes. + spikes: List[CustomUTCDateTimeType] + Array of timestamps as UTCDateTime. Can be a single spike or many spikes. 
""" artificial_disturbance_type: ArtificialDisturbanceType deviation: Optional[float] = None - source: Optional[str] = None - spikes: Optional[np.ndarray] = None + spikes: Optional[List[CustomUTCDateTimeType]] = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.flag_category = "ARTIFICIAL_DISTURBANCE" + self.flag_category = FlagCategory.ARTIFICIAL_DISTURBANCE + + @field_validator("spikes") + def check_spikes_duration(cls, spikes): + if spikes is None or len(spikes) < 2: + return spikes + + duration = spikes[-1] - spikes[0] + if duration > timedelta(days=1).total_seconds(): + raise ValueError( + f"The duration between the first and last spike timestamp must not exceed 1 day. Duration: {duration} seconds" + ) + + return spikes + + @classmethod + def check_spikes_match_times(cls, spikes, values): + metadata_starttime = values.get("starttime") + metadata_endtime = values.get("endtime") + + if spikes[0] != metadata_starttime: + raise ValueError( + f"The first spike timestamp {spikes[0]} does not match the starttime {metadata_starttime}." + ) + + if spikes[-1] != metadata_endtime: + raise ValueError( + f"The last spike timestamp {spikes[-1]} does not match the endtime {metadata_endtime}." + ) class Gap(Flag): @@ -103,12 +126,12 @@ class Gap(Flag): How the gap is being handled, e.g., backfilled. """ - cause: str = None - handling: str = None + cause: Optional[str] = None + handling: Optional[str] = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.flag_category = "GAP" + self.flag_category = FlagCategory.GAP class EventType(str, Enum): @@ -135,55 +158,39 @@ class Event(Flag): """ event_type: EventType - index: int = None - scale: str = None - url: str = None + index: Optional[int] = None + scale: Optional[str] = None + url: Optional[str] = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.flag_category = "EVENT" + self.flag_category = FlagCategory.EVENT # More example usage: -timestamps_array = np.array( - [ - UTCDateTime("2023-11-16T12:00:0"), - UTCDateTime("2023-11-16T12:01:10"), - UTCDateTime("2023-11-16T12:02:30"), - ] -) +timestamps_array = [ + UTCDateTime("2023-11-16T12:00:0"), + UTCDateTime("2023-11-16T12:01:10"), + UTCDateTime("2023-11-16T12:02:30"), +] spikes_data = { "starttime": "2023-11-16 12:00:00", "endtime": "2023-11-16 12:02:30", "description": "Spikes description", - "field_work": False, - "corrected": 32265, "disturbance_type": ArtificialDisturbanceType.SPIKES, - "source": "processing", "spikes": timestamps_array, } offset_data = { "description": "Offset description", - "field_work": False, - "corrected": 47999, "disturbance_type": ArtificialDisturbanceType.OFFSET, - "source": "Bin change", "deviation": 10.0, } geomagnetic_storm_data = { "description": "Geomagnetic storm", - "field_work": False, - "corrected": 36999, "event_type": EventType.GEOMAGNETIC_STORM, "scale": "G3", "index": 7, "url": "https://www.swpc.noaa.gov/products/planetary-k-index", } - -spike_instance = ArtificialDisturbance(**spikes_data) -offset_instance = ArtificialDisturbance(**offset_data) - -print(spike_instance.model_dump()) -print(offset_instance.model_dump()) diff --git a/geomagio/processing/flag_spikes.py b/geomagio/processing/flag_spikes.py new file mode 100644 index 0000000000000000000000000000000000000000..7e84f042e6b2962dec7afc2acf862324c1da3efa --- /dev/null +++ b/geomagio/processing/flag_spikes.py @@ -0,0 +1,123 @@ +import os +from datetime import datetime + +import typer +from typing import Optional 
diff --git a/geomagio/processing/flag_spikes.py b/geomagio/processing/flag_spikes.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e84f042e6b2962dec7afc2acf862324c1da3efa
--- /dev/null
+++ b/geomagio/processing/flag_spikes.py
@@ -0,0 +1,123 @@
+import os
+from datetime import datetime
+
+import typer
+from typing import Optional
+
+from ..metadata import Metadata
+from ..algorithm.SpikesAlgorithm import SpikesAlgorithm, SpikesMode
+
+
+def flag_spikes(
+    observatory: str = typer.Option(..., help="Observatory code"),
+    channels: str = typer.Option(
+        default="F", help="Channels to despike, default is F. Example input: HEZF"
+    ),
+    starttime: datetime = typer.Option(
+        default=None,
+        help="Start time for metadata, default is None. CLI example usage: --starttime '2024-01-01'",
+    ),
+    endtime: datetime = typer.Option(
+        default=None,
+        help="End time for metadata, default is None. CLI example usage: --endtime '2024-01-01'",
+    ),
+    window_size: Optional[int] = typer.Option(
+        default=1000,
+        help="Window size for the despike algorithm",
+    ),
+    threshold: Optional[int] = typer.Option(
+        default=6,
+        help="Threshold for spike size",
+    ),
+    mode: Optional[SpikesMode] = typer.Option(
+        SpikesMode.OVERWRITE, help="Merge or overwrite spikes if they already exist."
+    ),
+    metadata_token: str = typer.Option(
+        default=os.getenv("GITLAB_API_TOKEN"),
+        help="Token for metadata web service.",
+        metavar="TOKEN",
+        show_default="environment variable GITLAB_API_TOKEN",
+    ),
+    metadata_url: str = typer.Option(
+        default="https://geomag.usgs.gov/ws/secure/metadata",
+        help="URL to metadata web service",
+        metavar="URL",
+    ),
+    force: bool = typer.Option(
+        default=False,
+        help="Skip the confirmation prompt before creating metadata",
+    ),
+):
+    """Flag spikes and create spike metadata in the metadata webservice."""
+    # TODO: consider exposing type and interval as options
+    # run spike detection algorithm on data and return spike metadata
+    spike_algorithm = SpikesAlgorithm(
+        starttime=starttime,
+        endtime=endtime,
+        observatory=observatory,
+        window_size=window_size,
+        threshold=threshold,
+        mode=mode,
+        channels=channels,
+        type="variation",
+        interval="second",
+        metadata_token=metadata_token,
+        metadata_url=metadata_url,
+    )
+
+    spikes = spike_algorithm.run()
+
+    if len(spikes) == 0:
+        print("No spikes found")
+        raise typer.Abort()
+
+    # confirm whether or not to create spike metadata
+    if not force:
+        typer.confirm("Are you sure you want to create flag metadata?", abort=True)
+    print("Creating flag metadata")
+
+    # write flag metadata to metadata service
+    with typer.progressbar(
+        iterable=spikes, label="Uploading to metadata service"
+    ) as progressbar:
+        for spike in progressbar:
+            upload_spike_metadata(algorithm=spike_algorithm, spike=spike)
+
+
+def main() -> None:
+    """Entrypoint for the flag-spikes command."""
+    # a single command can be run directly with typer.run
+    typer.run(flag_spikes)
+
+
+def upload_spike_metadata(algorithm: SpikesAlgorithm, spike: Metadata) -> Metadata:
+    """Upload new or existing spike metadata to the metadata service.
+
+    Parameters
+    ----------
+    algorithm: SpikesAlgorithm
+        algorithm used to update/load spike metadata
+    spike: Metadata
+        spike metadata to be uploaded.
+
+    Returns
+    -------
+    Metadata
+        The created or updated metadata object.
+    """
+    # check if metadata already exists for the period before uploading
+    prior_metadata = algorithm.check_existing_metadata(spike)
+    if prior_metadata:
+        # TODO: Confirm whether or not to add force or simply update automatically
+        typer.confirm(
+            "Spikes already exist for this period, would you like to update this metadata?",
+            abort=True,
+        )
+        return algorithm.update_metadata(spike)
+    else:
+        return algorithm.create_new_metadata(spike)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 4ec57dad4a03f3e463e3a8ffa00d468182729886..130a69e02edf034d71dbb0c6b33ff071a1238c94 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,6 +81,7 @@ copy-absolutes = "geomagio.processing.copy_absolutes:main"
 copy-observatory = "geomagio.processing.copy_observatory:main"
 make-iaga = "geomagio.processing.make_iaga:main"
 copy-instrument = "geomagio.processing.copy_instrument:main"
+flag-spikes = "geomagio.processing.flag_spikes:main"
 
 [tool.poe.tasks]
 # e.g. "poetry run poe lint"