diff --git a/geomagio/api/ws/DataApiQuery.py b/geomagio/api/ws/DataApiQuery.py index ac88e6092a94967b8d2a3907936b88fd776c6545..5e90a33c01eea2739777e5312833be57c01b7c57 100644 --- a/geomagio/api/ws/DataApiQuery.py +++ b/geomagio/api/ws/DataApiQuery.py @@ -1,5 +1,6 @@ import datetime import enum +from enum import Enum import os from typing import List, Optional, Union @@ -40,6 +41,15 @@ class SamplingPeriod(float, enum.Enum): DAY = 86400.0 +class SamplingPeriodWithAuto(str, enum.Enum): + AUTO = "auto" + TEN_HERTZ = SamplingPeriod.TEN_HERTZ.value + SECOND = SamplingPeriod.SECOND.value + MINUTE = SamplingPeriod.MINUTE.value + HOUR = SamplingPeriod.HOUR.value + DAY = SamplingPeriod.DAY.value + + class DataHost(str, enum.Enum): # recognized public Edge data hosts, plus one user-specified DEFAULT = os.getenv("DATA_HOST", "edgecwb.usgs.gov") @@ -69,9 +79,9 @@ class DataApiQuery(BaseModel): endtime: Optional[CustomUTCDateTimeType] = None elements: List[str] = DEFAULT_ELEMENTS sampling_period: SamplingPeriod = SamplingPeriod.MINUTE - data_type: Union[DataType, str] = DataType.VARIATION - format: Union[OutputFormat, str] = OutputFormat.IAGA2002 - data_host: Union[DataHost, str] = DataHost.DEFAULT + data_type: DataType = DataType.VARIATION + format: OutputFormat = OutputFormat.IAGA2002 + data_host: DataHost = DataHost.DEFAULT @field_validator("starttime", mode="before") def validate_starttime( diff --git a/geomagio/api/ws/FilterApiQuery.py b/geomagio/api/ws/FilterApiQuery.py index cb2e4e027207a41410037a41901d63206b630cfb..f602fc44d1a3a7d8573c2efb88cc85e3e20b34d6 100644 --- a/geomagio/api/ws/FilterApiQuery.py +++ b/geomagio/api/ws/FilterApiQuery.py @@ -1,5 +1,11 @@ -from .DataApiQuery import DataApiQuery, SamplingPeriod, REQUEST_LIMIT -from pydantic import ConfigDict, model_validator +from .DataApiQuery import ( + DataApiQuery, + SamplingPeriod, + REQUEST_LIMIT, + SamplingPeriodWithAuto, +) +from pydantic import ConfigDict, model_validator, field_validator +from typing import 
Optional, Union """This class inherits all the fields and validation on DataApiQuery and adds the fields input_sampling_period and output_sampling_period.""" @@ -8,16 +14,41 @@ the fields input_sampling_period and output_sampling_period.""" class FilterApiQuery(DataApiQuery): model_config = ConfigDict(extra="forbid") - input_sampling_period: SamplingPeriod = SamplingPeriod.SECOND - output_sampling_period: SamplingPeriod = SamplingPeriod.MINUTE + input_sampling_period: Union[SamplingPeriodWithAuto, float] = ( + SamplingPeriodWithAuto.AUTO + ) + + @field_validator("input_sampling_period", mode="before") + def normalize_sampling_period(cls, value): + if isinstance(value, str) and value == SamplingPeriodWithAuto.AUTO: + return float("nan") # Map 'auto' to NaN internally + if isinstance(value, str): + return float(value) # Map string values to float + try: + value = float(value) # Coerce numeric-like strings to float + except ValueError: + raise ValueError( + f"Invalid sampling period. Must be one of " + f"{[item.value for item in SamplingPeriodWithAuto]}" + ) + # If the value is a float, find its matching enum based on the string value + if isinstance(value, float): + # Find matching enum value by checking if the float matches the string equivalent but skip the string "auto" + for period in list(SamplingPeriodWithAuto)[1:]: + if float(period.value) == value: + + return period + + # If no match is found, return the float as-is + return value + + return value # Return the value if it's not a recognized string or float @model_validator(mode="after") def validate_sample_size(self): # Calculate the number of samples based on the input sampling period samples = int( - len(self.elements) - * (self.endtime - self.starttime) - / self.input_sampling_period + len(self.elements) * (self.endtime - self.starttime) / self.sampling_period ) # Validate the request size diff --git 
a/geomagio/api/ws/algorithms.py b/geomagio/api/ws/algorithms.py index ec18ae9fd2c2ed01e646a6d0f514a3248550021e..879ac70cf4a8f8ef447059642906621848b3acdc 100644 --- a/geomagio/api/ws/algorithms.py +++ b/geomagio/api/ws/algorithms.py @@ -15,7 +15,10 @@ from .FilterApiQuery import FilterApiQuery from .data import format_timeseries, get_data_factory, get_data_query, get_timeseries from .filter import get_filter_data_query from . import filter +from ...algorithm.FilterAlgorithm import STEPS +import logging +logger = logging.getLogger(__name__) router = APIRouter() @@ -43,42 +46,16 @@ def get_dbdt( ####################################### The router .get filter isnt visible on the docs page # Look for register routers in the backend - - @router.get( "/algorithms/filter/", description="Filtered data dependent on requested interval", name="Filtered Algorithm", ) -# New query parameter defined below. I am using a new query defined in DataApitQuery. -# This relies on the new filter.py module and the get_filter_data function -# to define input and output sampling period. 
- - -def get_filter( - query: FilterApiQuery = Depends(get_filter_data_query), -) -> Response: - - filt = FilterAlgorithm( - input_sample_period=query.input_sampling_period, - output_sample_period=query.output_sampling_period, - ) +def get_filter(query: FilterApiQuery = Depends(get_filter_data_query)) -> Response: - # Grab the correct starttime and endtime for the timeseries from get_input_interval - starttime, endtime = filt.get_input_interval(query.starttime, query.endtime) - - # Reassign the actual start/endtime to the query parameters - query.starttime = starttime - query.endtime = endtime - - data_factory = get_data_factory(query=query) - # read data - raw = filter.get_timeseries(data_factory, query) - - filtered_timeseries = filt.process(raw) + filtered_timeseries = filter.get_timeseries(query) elements = [f"{element}" for element in query.elements] - # output response return format_timeseries( timeseries=filtered_timeseries, format=query.format, elements=elements ) diff --git a/geomagio/api/ws/filter.py b/geomagio/api/ws/filter.py index 54c5f6825a9922f74dfa55c4ba7776909a9d76f8..234aaa7cd2f659776589be4d098d5c71c1146f70 100644 --- a/geomagio/api/ws/filter.py +++ b/geomagio/api/ws/filter.py @@ -1,15 +1,23 @@ from typing import List, Union, Optional from fastapi import Query -from obspy import UTCDateTime, Stream +from obspy.clients.fdsn.header import FDSNNoDataException +from obspy import UTCDateTime, Stream, Trace from ... 
import TimeseriesFactory, TimeseriesUtility +import math +import numpy as np from .DataApiQuery import ( DEFAULT_ELEMENTS, DataHost, DataType, OutputFormat, SamplingPeriod, + SamplingPeriodWithAuto, ) from .FilterApiQuery import FilterApiQuery +from ...algorithm.FilterAlgorithm import STEPS +from ...algorithm import FilterAlgorithm +import logging as logger +from .data import get_data_factory from ...pydantic_utcdatetime import CustomUTCDateTimeType @@ -26,11 +34,13 @@ def get_filter_data_query( format: Union[OutputFormat, str] = Query( OutputFormat.IAGA2002, title="Output Format" ), - input_sampling_period: Union[SamplingPeriod, float] = Query( - SamplingPeriod.SECOND, title="Initial sampling period" + input_sampling_period: Union[SamplingPeriodWithAuto, float, int] = Query( + SamplingPeriodWithAuto.AUTO, # Assign Default value as auto + title="Initial Sampling Period", + description="Auto will dynamically determine necessary input_sampling_period.", ), - output_sampling_period: Union[SamplingPeriod, float] = Query( - SamplingPeriod.MINUTE, title="Output sampling period" + sampling_period: Union[SamplingPeriod, float] = Query( + SamplingPeriod.MINUTE, alias="output_sampling_period", title="Output sampling period" ), data_host: Union[DataHost, str] = Query( DataHost.DEFAULT, title="Data Host", description="Edge host to pull data from." 
@@ -43,28 +53,103 @@ def get_filter_data_query( endtime=endtime, elements=elements, input_sampling_period=input_sampling_period, - output_sampling_period=output_sampling_period, + sampling_period=sampling_period, data_type=data_type, data_host=data_host, format=format, ) -def get_timeseries(data_factory: TimeseriesFactory, query: FilterApiQuery) -> Stream: - """Get timeseries data for variometers +# Main filter function +def get_timeseries(query: FilterApiQuery) -> Stream: + data_factory = get_data_factory(query=query) - Parameters - ---------- - data_factory: where to read data - query: parameters for the data to read - """ - # get data - timeseries = data_factory.get_timeseries( - starttime=query.starttime, - endtime=query.endtime, + # Determine input sampling period if not provided + # if query.input_sampling_period == SamplingPeriodWithAuto.AUTO or query.input_sampling_period is None: + if query.input_sampling_period is None or (isinstance(query.input_sampling_period, float) and math.isnan(query.input_sampling_period)): + # Dynamically determine the input sampling period + input_sampling_period, data = determine_available_period( + query.sampling_period, query, data_factory + ) + else: + input_sampling_period = query.input_sampling_period + + filt = FilterAlgorithm( + input_sample_period=input_sampling_period, + output_sample_period=query.sampling_period, + ) + # Fetch data + starttime, endtime = filt.get_input_interval(query.starttime, query.endtime) + + data = data_factory.get_timeseries( + starttime=starttime, + endtime=endtime, observatory=query.id, channels=query.elements, type=query.data_type, - interval=TimeseriesUtility.get_interval_from_delta(query.input_sampling_period), + interval=TimeseriesUtility.get_interval_from_delta(filt.input_sample_period), + ) + + # Apply filtering + filtered_timeseries = filt.process(data) + return filtered_timeseries + + +def determine_available_period(output_sampling_period: float, query, data_factory): + """ + Finds the lowest resolution (longest sampling period) <= 
output_sampling_period + that has valid data available. + """ + + # Sort and filter periods starting from output_sampling_period + sorted_periods: List[SamplingPeriod] = sorted( + SamplingPeriod, key=lambda p: p.value, reverse=True ) - return timeseries + valid_sampling_periods = [ + p for p in sorted_periods if p.value <= output_sampling_period + ] + + for period in valid_sampling_periods: + if period <= output_sampling_period: + + data = data_factory.get_timeseries( + starttime=query.starttime, + endtime=query.endtime, + observatory=query.id, + channels=query.elements, + type=query.data_type, + interval=TimeseriesUtility.get_interval_from_delta(period.value), + ) + # Check if the fetched data is valid + if is_valid_data(data): + + logger.info(f"Valid data found for sampling period: {period.name}") + return period.value, data # Return the sampling period and the data + + else: + logger.error( + f"No valid data found for requested sampling period: {period.name}" + ) + continue + + raise ValueError("No valid data found for the requested output sampling period.") + + +def is_valid_data(data: Stream) -> bool: + """ + Checks if the fetched data contains actual values and not just filler values (e.g., NaN). + A Stream is invalid if any trace contains only NaN values. 
+ """ + if not data or len(data) == 0: + return False # No data in the stream + + for trace in data: + # Check if trace.data exists and has data + if trace.data is None or len(trace.data) == 0: + return False # Trace has no data + + # Check if all values in trace.data are NaN + if np.all(np.isnan(trace.data)): + return False # Invalid if all values are NaN + + return True # All traces are valid diff --git a/test/DataApiQuery_test.py b/test/DataApiQuery_test.py index 117f1758c1be5f4d9668dc25208ec8fcee5a3596..4434af7d15411e38502c56acfcd4d9fc0e4de318 100644 --- a/test/DataApiQuery_test.py +++ b/test/DataApiQuery_test.py @@ -65,9 +65,9 @@ def test_DataApiQuery_valid(): assert_equal(query.endtime, UTCDateTime("2024-09-01T01:00:01")) assert_equal(query.elements, ["F"]) assert_equal(query.sampling_period, SamplingPeriod.SECOND) - assert_equal(query.data_type, "adjusted") - assert_equal(query.format, "json") - assert_equal(query.data_host, "cwbpub.cr.usgs.gov") + assert_equal(query.data_type, DataType.ADJUSTED) + assert_equal(query.format, OutputFormat.JSON) + assert_equal(query.data_host, DataHost.CWBPUB) def test_DataApiQuery_no_id(): diff --git a/test/FilterApiQuery_test.py b/test/FilterApiQuery_test.py index f4f49a563aa0d82ffb0103aea6f1b832a699d190..1c7c384547b71c413ae340a31e6e2860a4ad2804 100644 --- a/test/FilterApiQuery_test.py +++ b/test/FilterApiQuery_test.py @@ -5,6 +5,7 @@ from obspy import UTCDateTime from geomagio.api.ws.FilterApiQuery import FilterApiQuery from geomagio.api.ws.DataApiQuery import ( SamplingPeriod, + SamplingPeriodWithAuto, DataType, OutputFormat, DataHost, @@ -22,8 +23,8 @@ def test_FilterApiQuery_defaults(): assert_equal(query.starttime, expected_start_time) assert_equal(query.endtime, expected_endtime) assert_equal(query.elements, ["X", "Y", "Z", "F"]) - assert_equal(query.input_sampling_period, SamplingPeriod.SECOND) - assert_equal(query.output_sampling_period, SamplingPeriod.MINUTE) + assert_equal(query.input_sampling_period, 
SamplingPeriodWithAuto.AUTO) + assert_equal(query.sampling_period, SamplingPeriod.MINUTE) assert_equal(query.data_type, DataType.VARIATION) assert_equal(query.format, OutputFormat.IAGA2002) assert_equal(query.data_host, DataHost.DEFAULT) @@ -36,7 +37,7 @@ def test_FilterApiQuery_valid(): endtime="2024-09-01T01:00:01", elements=["Z"], input_sampling_period=60, - output_sampling_period=3600, + sampling_period=3600, data_type="adjusted", format="json", data_host="cwbpub.cr.usgs.gov", @@ -46,11 +47,11 @@ def test_FilterApiQuery_valid(): assert_equal(query.starttime, UTCDateTime("2024-09-01T00:00:01")) assert_equal(query.endtime, UTCDateTime("2024-09-01T01:00:01")) assert_equal(query.elements, ["Z"]) - assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE) - assert_equal(query.output_sampling_period, SamplingPeriod.HOUR) - assert_equal(query.data_type, "adjusted") - assert_equal(query.format, "json") - assert_equal(query.data_host, "cwbpub.cr.usgs.gov") + assert_equal(query.input_sampling_period, SamplingPeriodWithAuto.MINUTE) + assert_equal(query.sampling_period, SamplingPeriod.HOUR) + assert_equal(query.data_type, DataType.ADJUSTED) + assert_equal(query.format, OutputFormat.JSON) + assert_equal(query.data_host, DataHost.CWBPUB) def test_FilterApiQuery_no_id(): diff --git a/test/api_test/ws_test/data_test.py b/test/api_test/ws_test/data_test.py index 3376e31100056e7186ce47543fad8f81d352e404..96fba2f0d3bc02b2eb4428ee2295969286593db0 100644 --- a/test/api_test/ws_test/data_test.py +++ b/test/api_test/ws_test/data_test.py @@ -28,7 +28,7 @@ def test_client(): def test_get_data_query(test_client): """test.api_test.ws_test.data_test.test_get_data_query()""" response = test_client.get( - "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002" + "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002" ) query = DataApiQuery(**response.json()) 
assert_equal(query.id, "BOU") @@ -36,8 +36,8 @@ def test_get_data_query(test_client): assert_equal(query.endtime, UTCDateTime("2020-09-02T00:00:00.999")) assert_equal(query.elements, ["X", "Y", "Z", "F"]) assert_equal(query.sampling_period, SamplingPeriod.MINUTE) - assert_equal(query.format, "iaga2002") - assert_equal(query.data_type, "R1") + assert_equal(query.format, OutputFormat.IAGA2002) + assert_equal(query.data_type, DataType.VARIATION) def test_get_data_query_no_starttime(test_client): diff --git a/test/api_test/ws_test/filter_test.py b/test/api_test/ws_test/filter_test.py index dbd2964ea2072d8dd678ae745fe60a5f37c54f64..470e8474364bdf88a3fde8313ae3cf5ac13e344c 100644 --- a/test/api_test/ws_test/filter_test.py +++ b/test/api_test/ws_test/filter_test.py @@ -31,9 +31,8 @@ def test_get_filter_data_query(test_client): assert_equal(query.starttime, UTCDateTime("2020-09-01T00:00:01")) assert_equal(query.endtime, UTCDateTime("2020-09-02T00:00:00.999")) assert_equal(query.elements, ["X", "Y", "Z", "F"]) - assert_equal(query.sampling_period, SamplingPeriod.MINUTE) - assert_equal(query.format, "iaga2002") - assert_equal(query.data_type, "variation") + assert_equal(query.format, OutputFormat.IAGA2002) + assert_equal(query.data_type, DataType.VARIATION) assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE) assert_equal(query.output_sampling_period, SamplingPeriod.HOUR)