diff --git a/geomagio/api/ws/DataApiQuery.py b/geomagio/api/ws/DataApiQuery.py
index ac88e6092a94967b8d2a3907936b88fd776c6545..9f5217e09a7707913c26cdbb37230b770c83f6d3 100644
--- a/geomagio/api/ws/DataApiQuery.py
+++ b/geomagio/api/ws/DataApiQuery.py
@@ -9,8 +10,12 @@ from pydantic import ConfigDict, field_validator, model_validator, Field, BaseMo
 from .Element import ELEMENTS
 from .Observatory import OBSERVATORY_INDEX, ASL_OBSERVATORY_INDEX
 from ...pydantic_utcdatetime import CustomUTCDateTimeType
+import logging
 
-
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
 DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"]
 REQUEST_LIMIT = 3456000  # Increased the request limit by 10x what was decided by Jeremy
 VALID_ELEMENTS = [e.id for e in ELEMENTS]
@@ -68,7 +73,7 @@ class DataApiQuery(BaseModel):
     # endtime default is dependent on start time, so it's handled after validation in the model_validator
     endtime: Optional[CustomUTCDateTimeType] = None
     elements: List[str] = DEFAULT_ELEMENTS
-    sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
+    sampling_period: Optional[SamplingPeriod] = None
     data_type: Union[DataType, str] = DataType.VARIATION
     format: Union[OutputFormat, str] = OutputFormat.IAGA2002
     data_host: Union[DataHost, str] = DataHost.DEFAULT
@@ -123,11 +128,20 @@ class DataApiQuery(BaseModel):
             self.endtime = self.starttime + (86400 - 0.001)
         if self.starttime > self.endtime:
             raise ValueError("Starttime must be before endtime.")
-        # check data volume
-        samples = int(
-            len(self.elements) * (self.endtime - self.starttime) / self.sampling_period
-        )
-        if samples > REQUEST_LIMIT:
-            raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
-        # otherwise okay
+
+        # check data volume, unless sampling_period is unset (None)
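+        # NOTE: a None sampling_period is allowed here so the filter endpoint
+        # can resolve an appropriate period later (see ws/filter.py).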
+        if self.sampling_period is None:
+            logging.warning(
+                "Sampling period is None; skipping request size validation."
+            )
+        else:
+            samples = int(
+                len(self.elements)
+                * (self.endtime - self.starttime)
+                / self.sampling_period
+            )
+            if samples > REQUEST_LIMIT:
+                raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
+        # otherwise okay
         return self
diff --git a/geomagio/api/ws/FilterApiQuery.py b/geomagio/api/ws/FilterApiQuery.py
index cb2e4e027207a41410037a41901d63206b630cfb..cee262d9004d219e47dceccac31971f266b626a8 100644
--- a/geomagio/api/ws/FilterApiQuery.py
+++ b/geomagio/api/ws/FilterApiQuery.py
@@ -1,5 +1,17 @@
-from .DataApiQuery import DataApiQuery, SamplingPeriod, REQUEST_LIMIT
+from .DataApiQuery import (
+    DataApiQuery,
+    SamplingPeriod,
+    REQUEST_LIMIT,
+)
 from pydantic import ConfigDict, model_validator
+from typing import Optional
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
 
 """This class inherits all the fields and validation on DataApiQuery and adds
-the fields input_sampling_period and output_sampling_period."""
+the field input_sampling_period; the output sampling period is carried by the
+inherited sampling_period field."""
@@ -8,20 +20,26 @@ the fields input_sampling_period and output_sampling_period."""
 class FilterApiQuery(DataApiQuery):
     model_config = ConfigDict(extra="forbid")
 
-    input_sampling_period: SamplingPeriod = SamplingPeriod.SECOND
-    output_sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
+    input_sampling_period: Optional[SamplingPeriod] = None
 
     @model_validator(mode="after")
     def validate_sample_size(self):
-        # Calculate the number of samples based on the input sampling period
-        samples = int(
-            len(self.elements)
-            * (self.endtime - self.starttime)
-            / self.input_sampling_period
-        )
-
-        # Validate the request size
-        if samples > REQUEST_LIMIT:
-            raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
+        if self.sampling_period is None:
+            # Log a warning indicating that the sampling period is missing
+            logging.warning(
+                "Sampling period is None. Please provide a valid sampling period."
+            )
+        else:
+            # Calculate the number of samples based on the output sampling period
+            samples = int(
+                len(self.elements)
+                * (self.endtime - self.starttime)
+                / self.sampling_period
+            )
+
+            # Validate the request size
+            if samples > REQUEST_LIMIT:
+                raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
 
         return self
diff --git a/geomagio/api/ws/algorithms.py b/geomagio/api/ws/algorithms.py
index ec18ae9fd2c2ed01e646a6d0f514a3248550021e..c4696a4c5323115c7f715eb325006ae1394cccf8 100644
--- a/geomagio/api/ws/algorithms.py
+++ b/geomagio/api/ws/algorithms.py
@@ -2,20 +2,21 @@ import json
 
 from fastapi import APIRouter, Depends, HTTPException, Query
 from starlette.responses import Response
-from obspy.core import Stream, Stats
-from typing import List, Union
-from ...algorithm import DbDtAlgorithm, FilterAlgorithm
+
+from ...algorithm import DbDtAlgorithm
 from ...residual import (
     calculate,
     Reading,
 )
-from .DataApiQuery import DataApiQuery, SamplingPeriod
+from .DataApiQuery import DataApiQuery
 from .FilterApiQuery import FilterApiQuery
 from .data import format_timeseries, get_data_factory, get_data_query, get_timeseries
 from .filter import get_filter_data_query
 from . import filter
+import logging
+
+logger = logging.getLogger(__name__)
 
 router = APIRouter()
@@ -43,42 +44,16 @@ def get_dbdt(
 
 
-####################################### The router .get filter isnt visible on the docs page
-# Look for register routers in the backend
-
-
 @router.get(
     "/algorithms/filter/",
     description="Filtered data dependent on requested interval",
     name="Filtered Algorithm",
 )
-# New query parameter defined below. I am using a new query defined in DataApitQuery.
-# This relies on the new filter.py module and the get_filter_data function
-# to define input and output sampling period.
-
-
-def get_filter(
-    query: FilterApiQuery = Depends(get_filter_data_query),
-) -> Response:
-
-    filt = FilterAlgorithm(
-        input_sample_period=query.input_sampling_period,
-        output_sample_period=query.output_sampling_period,
-    )
-
-    # Grab the correct starttime and endtime for the timeseries from get_input_interval
-    starttime, endtime = filt.get_input_interval(query.starttime, query.endtime)
+def get_filter(query: FilterApiQuery = Depends(get_filter_data_query)) -> Response:
 
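+    # Sampling-period resolution, data fetching, and filtering are delegated
+    # to filter.get_timeseries (geomagio/api/ws/filter.py).
+    # Example request (parameters illustrative only):
+    #   /algorithms/filter/?id=BOU&elements=X,Y&output_sampling_period=3600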
-    # Reassign the actual start/endtime to the query parameters
-    query.starttime = starttime
-    query.endtime = endtime
-
-    data_factory = get_data_factory(query=query)
-    # read data
-    raw = filter.get_timeseries(data_factory, query)
-
-    filtered_timeseries = filt.process(raw)
+    filtered_timeseries = filter.get_timeseries(query)
     elements = [f"{element}" for element in query.elements]
-    # output response
     return format_timeseries(
         timeseries=filtered_timeseries, format=query.format, elements=elements
     )
diff --git a/geomagio/api/ws/data.py b/geomagio/api/ws/data.py
index 3a77dbb8a51b4cfe4641b728789c2c7e77f9f514..1eb735fd54c39d89f7ec75a38dc8725b5cff09b4 100644
--- a/geomagio/api/ws/data.py
+++ b/geomagio/api/ws/data.py
@@ -74,7 +74,7 @@ def get_data_query(
         " NOTE: when using 'iaga2002' output format, a maximum of 4 elements is allowed",
     ),
     sampling_period: Union[SamplingPeriod, float] = Query(
-        SamplingPeriod.MINUTE,
+        None,
         title="data rate",
         description="Interval in seconds between values.",
     ),
diff --git a/geomagio/api/ws/filter.py b/geomagio/api/ws/filter.py
index 54c5f6825a9922f74dfa55c4ba7776909a9d76f8..69516903bc7085621a48985709c5edc8068ae9d0 100644
--- a/geomagio/api/ws/filter.py
+++ b/geomagio/api/ws/filter.py
@@ -1,7 +1,8 @@
 from typing import List, Union, Optional
 from fastapi import Query
-from obspy import UTCDateTime, Stream
-from ... import TimeseriesFactory, TimeseriesUtility
+from obspy import Stream
+from ... import TimeseriesUtility
+import numpy as np
 from .DataApiQuery import (
     DEFAULT_ELEMENTS,
     DataHost,
@@ -10,6 +11,9 @@ from .DataApiQuery import (
     SamplingPeriod,
 )
 from .FilterApiQuery import FilterApiQuery
+from ...algorithm import FilterAlgorithm
+import logging
+from .data import get_data_factory
 from ...pydantic_utcdatetime import CustomUTCDateTimeType
+
+logger = logging.getLogger(__name__)
@@ -26,11 +30,15 @@ def get_filter_data_query(
     format: Union[OutputFormat, str] = Query(
         OutputFormat.IAGA2002, title="Output Format"
     ),
-    input_sampling_period: Union[SamplingPeriod, float] = Query(
-        SamplingPeriod.SECOND, title="Initial sampling period"
+    input_sampling_period: Optional[SamplingPeriod] = Query(
+        None,
+        title="Input Sampling Period",
+        description="Leave unset to determine the input sampling period dynamically.",
     ),
-    output_sampling_period: Union[SamplingPeriod, float] = Query(
-        SamplingPeriod.MINUTE, title="Output sampling period"
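+    # Exposed to clients as "output_sampling_period"; the alias stores the
+    # value on the inherited DataApiQuery.sampling_period field.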
+    sampling_period: Optional[SamplingPeriod] = Query(
+        None,
+        alias="output_sampling_period",
+        title="Output sampling period",
     ),
     data_host: Union[DataHost, str] = Query(
         DataHost.DEFAULT, title="Data Host", description="Edge host to pull data from."
@@ -43,28 +51,107 @@
         endtime=endtime,
         elements=elements,
         input_sampling_period=input_sampling_period,
-        output_sampling_period=output_sampling_period,
+        sampling_period=sampling_period,
         data_type=data_type,
         data_host=data_host,
         format=format,
     )
 
 
-def get_timeseries(data_factory: TimeseriesFactory, query: FilterApiQuery) -> Stream:
-    """Get timeseries data for variometers
+# Main filter function
+def get_timeseries(query: FilterApiQuery) -> Stream:
+    data_factory = get_data_factory(query=query)
 
-    Parameters
-    ----------
-    data_factory: where to read data
-    query: parameters for the data to read
-    """
-    # get data
-    timeseries = data_factory.get_timeseries(
-        starttime=query.starttime,
-        endtime=query.endtime,
+    # Determine the input sampling period if not provided; the probe data is
+    # discarded because the filter needs the padded interval fetched below
+    if query.input_sampling_period is None:
+        input_sampling_period, _ = determine_available_period(
+            query.sampling_period, query, data_factory
+        )
+    else:
+        input_sampling_period = query.input_sampling_period
+
+    filt = FilterAlgorithm(
+        input_sample_period=input_sampling_period,
+        output_sample_period=query.sampling_period,
+    )
+    # Expand the requested interval to cover the filter's input window
+    starttime, endtime = filt.get_input_interval(query.starttime, query.endtime)
+
+    data = data_factory.get_timeseries(
+        starttime=starttime,
+        endtime=endtime,
         observatory=query.id,
         channels=query.elements,
         type=query.data_type,
-        interval=TimeseriesUtility.get_interval_from_delta(query.input_sampling_period),
+        interval=TimeseriesUtility.get_interval_from_delta(filt.input_sample_period),
     )
-    return timeseries
+
+    # Apply the filter
+    filtered_timeseries = filt.process(data)
+    return filtered_timeseries
 
 
+def determine_available_period(output_sampling_period: float, query, data_factory):
+    """
+    Finds the longest sampling period (lowest resolution) that is less than or
+    equal to output_sampling_period and has valid data available.
+    """
+    if output_sampling_period is None:
+        raise ValueError("Output sampling period cannot be None.")
+
+    # Candidate periods, longest first, no longer than the output period
+    sorted_periods: List[SamplingPeriod] = sorted(
+        SamplingPeriod, key=lambda p: p.value, reverse=True
+    )
+    valid_sampling_periods = [
+        p for p in sorted_periods if p.value <= output_sampling_period
+    ]
+
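+    # Try the longest (coarsest) candidate first, since it is closest to the
+    # output period and needs the least data and decimation; fall back to
+    # finer periods until one returns real (non-NaN) samples.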
+    for period in valid_sampling_periods:
+        data = data_factory.get_timeseries(
+            starttime=query.starttime,
+            endtime=query.endtime,
+            observatory=query.id,
+            channels=query.elements,
+            type=query.data_type,
+            interval=TimeseriesUtility.get_interval_from_delta(period.value),
+        )
+        # Check if the fetched data is valid
+        if is_valid_data(data):
+            logger.info(f"Valid data found for sampling period: {period.name}")
+            return period.value, data  # Return the sampling period and the data
+        logger.warning(f"No valid data found for sampling period: {period.name}")
+
+    raise ValueError("No valid data found for the requested output sampling period.")
+
+
+def is_valid_data(data: Stream) -> bool:
+    """
+    Checks if the fetched data contains actual values and not just filler values (e.g., NaN).
+    A Stream is invalid if any trace contains only NaN values.
+    """
+    if not data or len(data) == 0:
+        return False  # No data in the stream
+
+    for trace in data:
+        # Check if trace.data exists and has data
+        if trace.data is None or len(trace.data) == 0:
+            return False  # Trace has no data
+
+        # Check if all values in trace.data are NaN
+        if np.all(np.isnan(trace.data)):
+            return False  # Invalid if all values are NaN
+
+    return True  # All traces are valid
diff --git a/poetry.lock b/poetry.lock
index ca13b8bfa7cfececd73c93ac1c96aff690ed4049..5946b52da5e31c6b97862e02ff98f9b2cc62304e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2247,13 +2247,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
 [[package]]
 name = "ruamel-yaml"
-version = "0.18.8"
+version = "0.18.10"
 description = "ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruamel.yaml-0.18.8-py3-none-any.whl", hash = "sha256:a7c02af6ec9789495b4d19335addabc4d04ab1e0dad3e491c0c9457bbc881100"},
-    {file = "ruamel.yaml-0.18.8.tar.gz", hash = "sha256:1b7e14f28a4b8d09f8cd40dca158852db9b22ac84f22da5bb711def35cb5c548"},
+    {file = "ruamel.yaml-0.18.10-py3-none-any.whl", hash = "sha256:30f22513ab2301b3d2b577adc121c6471f28734d3d9728581245f1e76468b4f1"},
+    {file = "ruamel.yaml-0.18.10.tar.gz", hash = "sha256:20c86ab29ac2153f80a428e1254a8adf686d3383df04490514ca3b79a362db58"},
 ]
 
 [package.dependencies]
diff --git a/pytest.ini b/pytest.ini
index 3d8cf70409816af31181c2160f79f4d3b8b4b17e..92ad01e1f519b02369ca11e0dd347de146d86b91 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,3 +3,9 @@ norecursedirs = */site-packages
 testpaths = test
 asyncio_mode=auto
 asyncio_default_fixture_loop_scope="function"
+# Only capture log records at WARNING level and above
+log_level = WARNING
+
+# Ignore UserWarning messages raised during tests
+filterwarnings =
+    ignore::UserWarning
\ No newline at end of file
diff --git a/test/DataApiQuery_test.py b/test/DataApiQuery_test.py
index 117f1758c1be5f4d9668dc25208ec8fcee5a3596..737286b3b77e47d77292bac556db72bc9fda8f52 100644
--- a/test/DataApiQuery_test.py
+++ b/test/DataApiQuery_test.py
@@ -23,7 +23,7 @@ def test_DataApiQuery_defaults():
     assert_equal(query.starttime, expected_start_time)
     assert_equal(query.endtime, expected_endtime)
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.sampling_period, None)
     assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     # assumes the env var DATA_HOST is not set
@@ -41,7 +41,7 @@ def test_DataApiQuery_starttime_is_none():
     assert_equal(query.starttime, expected_start_time)
     assert_equal(query.endtime, expected_endtime)
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.sampling_period, None)
     assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     # assumes the env var DATA_HOST is not set
@@ -100,7 +100,7 @@ def test_DataApiQuery_default_endtime():
     # endtime is 1 day after start time
     assert_equal(query.endtime, UTCDateTime("2024-11-02T00:00:00.999"))
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.sampling_period, None)
    assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     assert_equal(query.data_host, DataHost.DEFAULT)
@@ -122,7 +122,7 @@ def test_DataApiQuery_default_only_endtime():
     assert_equal(query.endtime, hour_later)
 
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.sampling_period, None)
     assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     assert_equal(query.data_host, DataHost.DEFAULT)
diff --git a/test/FilterApiQuery_test.py b/test/FilterApiQuery_test.py
index f4f49a563aa0d82ffb0103aea6f1b832a699d190..6f734f133f8c6f66bad32a2840bac469bbdcc92f 100644
--- a/test/FilterApiQuery_test.py
+++ b/test/FilterApiQuery_test.py
@@ -22,8 +22,8 @@ def test_FilterApiQuery_defaults():
     assert_equal(query.starttime, expected_start_time)
     assert_equal(query.endtime, expected_endtime)
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.input_sampling_period, SamplingPeriod.SECOND)
-    assert_equal(query.output_sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.input_sampling_period, None)
+    assert_equal(query.sampling_period, None)
     assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     assert_equal(query.data_host, DataHost.DEFAULT)
@@ -36,7 +36,7 @@ def test_FilterApiQuery_valid():
         endtime="2024-09-01T01:00:01",
         elements=["Z"],
         input_sampling_period=60,
-        output_sampling_period=3600,
+        sampling_period=3600,
         data_type="adjusted",
         format="json",
         data_host="cwbpub.cr.usgs.gov",
@@ -47,7 +47,7 @@
     assert_equal(query.endtime, UTCDateTime("2024-09-01T01:00:01"))
     assert_equal(query.elements, ["Z"])
     assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE)
-    assert_equal(query.output_sampling_period, SamplingPeriod.HOUR)
+    assert_equal(query.sampling_period, SamplingPeriod.HOUR)
     assert_equal(query.data_type, "adjusted")
     assert_equal(query.format, "json")
     assert_equal(query.data_host, "cwbpub.cr.usgs.gov")
@@ -83,7 +83,7 @@ def test_FilterApiQuery_default_endtime():
     # endtime is 1 day after start time
     assert_equal(query.endtime, UTCDateTime("2024-11-02T00:00:00.999"))
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
+    assert_equal(query.sampling_period, None)
     assert_equal(query.data_type, DataType.VARIATION)
     assert_equal(query.format, OutputFormat.IAGA2002)
     assert_equal(query.data_host, DataHost.DEFAULT)
@@ -190,3 +190,13 @@ def test_FilterApiQuery_extra_fields():
     assert "Extra inputs are not permitted" == err[0]["msg"]
 
     assert_equal(query, None)
+
+
+def test_FilterApiQuery_no_output_sampling_period():
+    query = FilterApiQuery(id="ANMO", sampling_period=None)
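+    # A missing output sampling period is accepted by the model itself:
+    # validate_sample_size only logs a warning, and the hard failure
+    # ("Output sampling period cannot be None.") is raised later by
+    # filter.determine_available_period.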
+    assert_equal(query.sampling_period, None)
+    assert_equal(query.input_sampling_period, None)
diff --git a/test/api_test/ws_test/data_test.py b/test/api_test/ws_test/data_test.py
index 3376e31100056e7186ce47543fad8f81d352e404..765fa407b013db3f9e87b87d04c56a3b576d6d99 100644
--- a/test/api_test/ws_test/data_test.py
+++ b/test/api_test/ws_test/data_test.py
@@ -28,7 +28,7 @@ def test_client():
 def test_get_data_query(test_client):
     """test.api_test.ws_test.data_test.test_get_data_query()"""
     response = test_client.get(
-        "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002"
+        "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002"
     )
     query = DataApiQuery(**response.json())
     assert_equal(query.id, "BOU")
@@ -37,7 +37,7 @@
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
     assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
     assert_equal(query.format, "iaga2002")
-    assert_equal(query.data_type, "R1")
+    assert_equal(query.data_type, "variation")
 
 
 def test_get_data_query_no_starttime(test_client):
@@ -57,38 +57,18 @@
 
-async def test_get_data_query_extra_params(test_client):
+def test_get_data_query_extra_params(test_client):
     with pytest.raises(ValueError) as error:
-        response = await test_client.get(
+        response = test_client.get(
             "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
         )
+        DataApiQuery(**response.json())
     assert error.match("Invalid query parameter(s): location, network")
 
 
-# def test_get_data_query_extra_params(test_client):
-#     """test.api_test.ws_test.data_test.test_get_data_query_extra_params()"""
-#     with pytest.raises(ValueError) as error:
-#         test_client.get(
-#             "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
-#         )
-#     assert error.message == "Invalid query parameter(s): location, network"
-
-
-async def test_get_data_query_bad_params(test_client):
+def test_get_data_query_bad_params(test_client):
     """test.api_test.ws_test.data_test.test_get_data_query_bad_params()"""
     with pytest.raises(ValueError) as error:
-        response = await test_client.get(
+        response = test_client.get(
             "/query/?id=BOU&startime=2020-09-01T00:00:01&elements=X,Y,Z,F&data_type=variation&sampling_period=60&format=iaga2002"
         )
+        DataApiQuery(**response.json())
     assert error.match == "Invalid query parameter(s): startime, data_type"
-
-
-# def test_filter_data_query(test_client):
-#     """test.api_test.ws_test.data_test.test_filter_data_query()"""
-#     response = test_client.get(
-#         "/algorithms/filter/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002&input_sampling_period=60&output_sampling_period=30"
-#     )
-#     filter_query = FilterDataApiQuery(**response.json())
-#     assert_equal(filter_query.id, "BOU")
-#     assert_equal(filter_query.starttime, UTCDateTime("2020-09-01T00:00:01"))
-#     assert_equal(filter_query.elements, ["X", "Y", "Z", "F"])
-#     assert_equal(filter_query.input_sampling_period, 60)
-#     assert_equal(filter_query.output_sampling_period, 30)
diff --git a/test/api_test/ws_test/filter_test.py b/test/api_test/ws_test/filter_test.py
index dbd2964ea2072d8dd678ae745fe60a5f37c54f64..67a09a4217317ddba87fce88a3e63e4df30fc5df 100644
--- a/test/api_test/ws_test/filter_test.py
+++ b/test/api_test/ws_test/filter_test.py
@@ -31,11 +31,10 @@ def test_get_filter_data_query(test_client):
     assert_equal(query.starttime, UTCDateTime("2020-09-01T00:00:01"))
     assert_equal(query.endtime, UTCDateTime("2020-09-02T00:00:00.999"))
     assert_equal(query.elements, ["X", "Y", "Z", "F"])
-    assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
     assert_equal(query.format, "iaga2002")
     assert_equal(query.data_type, "variation")
     assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE)
-    assert_equal(query.output_sampling_period, SamplingPeriod.HOUR)
+    assert_equal(query.sampling_period, SamplingPeriod.HOUR)
 
 
 def test_get_filter_data_query_no_starttime(test_client):
diff --git a/test/edge_test/FDSNFactory_test.py b/test/edge_test/FDSNFactory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f409674a5a0be0a5540e3b52b075e9708fa782fa
--- /dev/null
+++ b/test/edge_test/FDSNFactory_test.py
@@ -0,0 +1,206 @@
+import numpy
+from numpy.testing import assert_equal, assert_array_equal
+import numpy as np
+from obspy.core import Stream, Trace, UTCDateTime
+from obspy.core.inventory import Inventory, Network, Station, Channel, Site
+import pytest
+
+from geomagio.edge import FDSNFactory
+from geomagio.metadata.instrument.InstrumentCalibrations import (
+    get_instrument_calibrations,
+)
+from .mseed_FDSN_test_clients import MockFDSNSeedClient
+
+
+@pytest.fixture(scope="class")
+def FDSN_factory() -> FDSNFactory:
+    """instance of FDSNFactory with MockFDSNClient"""
+    factory = FDSNFactory()
+    factory.client = MockFDSNSeedClient()
+    yield factory
+
+
+@pytest.fixture()
+def anmo_u_metadata():
+    metadata = get_instrument_calibrations(observatory="ANMO")
+    instrument = metadata[0]["instrument"]
+    channels = instrument["channels"]
+    yield channels["X"]
+
+
+def test__get_timeseries_add_empty_channels(FDSN_factory: FDSNFactory):
+    """test.edge_test.FDSNFactory_test.test__get_timeseries_add_empty_channels()"""
+    FDSN_factory.client.return_empty = True
+    starttime = UTCDateTime("2024-09-07T00:00:00Z")
+    endtime = UTCDateTime("2024-09-07T00:10:00Z")
+    trace = FDSN_factory._get_timeseries(
+        starttime=starttime,
+        endtime=endtime,
+        observatory="ANMO",
+        channel="X",
+        type="variation",
+        interval="second",
+        add_empty_channels=True,
+    )[0]
+
+    assert_array_equal(trace.data, numpy.ones(trace.stats.npts) * numpy.nan)
+    assert trace.stats.starttime == starttime
+    assert trace.stats.endtime == endtime
+
+    with pytest.raises(IndexError):
+        trace = FDSN_factory._get_timeseries(
+            starttime=starttime,
+            endtime=endtime,
+            observatory="ANMO",
+            channel="X",
+            type="variation",
+            interval="second",
+            add_empty_channels=False,
+        )[0]
+
+
+def test__set_metadata():
+    """edge_test.FDSNFactory_test.test__set_metadata()"""
+    # Call _set_metadata with 2 traces, and make certain the stats get
+    # set for both traces.
+    trace1 = Trace()
+    trace2 = Trace()
+    stream = Stream(traces=[trace1, trace2])
+    FDSNFactory()._set_metadata(stream, "ANMO", "X", "variation", "second")
+    assert_equal(stream[0].stats["channel"], "X")
+    assert_equal(stream[1].stats["channel"], "X")
+
+
+def test_get_timeseries(FDSN_factory):
+    """edge_test.FDSNFactory_test.test_get_timeseries()"""
+    # Call get_timeseries, and test stats for confirmation that it came back.
+    # TODO, need to pass in host and port from a config file, or manually
+    # change for a single test.
+    timeseries = FDSN_factory.get_timeseries(
+        starttime=UTCDateTime(2024, 3, 1, 0, 0, 0),
+        endtime=UTCDateTime(2024, 3, 1, 1, 0, 0),
+        observatory="ANMO",
+        channels=("X",),
+        type="variation",
+        interval="second",
+    )
+
+    assert_equal(
+        timeseries.select(channel="X")[0].stats.station,
+        "ANMO",
+        "Expect timeseries to have stats",
+    )
+    assert_equal(
+        timeseries.select(channel="X")[0].stats.channel,
+        "X",
+        "Expect timeseries stats channel to be equal to X",
+    )
+    assert_equal(
+        timeseries.select(channel="X")[0].stats.data_type,
+        "variation",
+        "Expect timeseries stats data_type to be equal to variation",
+    )
+
+
+def test_get_timeseries_by_location(FDSN_factory):
+    """test.edge_test.FDSNFactory_test.test_get_timeseries_by_location()"""
+    timeseries = FDSN_factory.get_timeseries(
+        UTCDateTime(2024, 3, 1, 0, 0, 0),
+        UTCDateTime(2024, 3, 1, 1, 0, 0),
+        "ANMO",
+        ("X",),
+        "R0",
+        "second",
+    )
+    assert_equal(
+        timeseries.select(channel="X")[0].stats.data_type,
+        "R0",
+        "Expect timeseries stats data_type to be equal to R0",
+    )
+
+
+def test_rotate_trace():
+    # Initialize the factory
+    factory = FDSNFactory(
+        observatory="ANMO",
+        channels=["X", "Y", "Z"],
+        type="variation",
+        interval="second",
+    )
+
+    # Simulate input traces for X, Y, Z channels
+    starttime = UTCDateTime("2024-01-01T00:00:00")
+    endtime = UTCDateTime(2024, 1, 1, 0, 10)
+    data_x = Trace(
+        data=np.array([1, 2, 3, 4, 5]),
+        header={"channel": "X", "starttime": starttime, "delta": 60},
+    )
+    data_y = Trace(
+        data=np.array([6, 7, 8, 9, 10]),
+        header={"channel": "Y", "starttime": starttime, "delta": 60},
+    )
+    data_z = Trace(
+        data=np.array([11, 12, 13, 14, 15]),
+        header={"channel": "Z", "starttime": starttime, "delta": 60},
+    )
+    input_stream = Stream(traces=[data_x, data_y, data_z])
+
+    # Mock the Client.get_waveforms method to return the simulated stream
+    factory.Client.get_waveforms = lambda *args, **kwargs: input_stream
+
+    # Create a mock inventory object for get_stations
+    mock_inventory = create_mock_inventory()
+    # Mock the Client.get_stations method to return dummy inventory (if required for rotation)
+    factory.Client.get_stations = lambda *args, **kwargs: mock_inventory
+
+    # Call get_timeseries with channel "X" to trigger rotation
+    rotated_stream = factory.get_timeseries(
+        starttime=starttime,
+        endtime=endtime,
+        observatory="ANMO",
+        channels=["X"],  # Requesting any channel in [X, Y, Z] should trigger rotation
+    )
+
+    # Assertions
+    assert (
+        len(rotated_stream) == 1
+    ), "Expected only the requested channel after rotation"
+    assert rotated_stream[0].stats.channel in [
+        "X",
+    ], "Unexpected channel names after rotation"
+    assert (
+        rotated_stream[0].stats.starttime == starttime
+    ), "Start time mismatch in rotated data"
+
+
+def create_mock_inventory():
+    """Creates a mock inventory for testing purposes."""
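+    # A single X channel under station ANMO / network XX; coordinates and
+    # orientation are placeholder values, just enough for get_stations().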
+    # Create a dummy channel
+    channel = Channel(
+        code="X",
+        location_code="",
+        latitude=0.0,
+        longitude=0.0,
+        elevation=0.0,
+        depth=0.0,
+        azimuth=0.0,
+        dip=0.0,
+        sample_rate=1.0,
+    )
+    # Create a dummy station
+    station = Station(
+        code="ANMO",
+        latitude=0.0,
+        longitude=0.0,
+        elevation=0.0,
+        site=Site(name="TestSite"),
+        channels=[channel],
+    )
+    # Create a dummy network
+    network = Network(code="XX", stations=[station])
+    # Create an inventory
+    inventory = Inventory(networks=[network], source="MockInventory")
+    return inventory
diff --git a/test/edge_test/mseed_FDSN_test_clients.py b/test/edge_test/mseed_FDSN_test_clients.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcad34d2a32b1a258f61935a67aba5ee385d5841
--- /dev/null
+++ b/test/edge_test/mseed_FDSN_test_clients.py
@@ -0,0 +1,47 @@
+import numpy
+from obspy import Stream, UTCDateTime
+from obspy.clients.neic.client import Client
+
+from geomagio import TimeseriesUtility
+from geomagio.edge import FDSNSNCL
+
+
+class MockFDSNSeedClient(Client):
+    """replaces the default client's get_waveforms method to return a trace of ones
+
+    Note: includes 'return_empty' parameter to simulate situations where no data is received
+    """
+
+    def __init__(self, return_empty: bool = False):
+        self.return_empty = return_empty
+
+    def get_waveforms(
+        self,
+        network: str,
+        station: str,
+        location: str,
+        channel: str,
+        starttime: UTCDateTime,
+        endtime: UTCDateTime,
+    ):
+        if self.return_empty:
+            return Stream()
+        sncl = FDSNSNCL(
+            station=station,
+            network=network,
+            channel=channel,
+            location=location,
+        )
+        trace = TimeseriesUtility.create_empty_trace(
+            starttime=starttime,
+            endtime=endtime,
+            observatory=station,
+            channel=channel,
+            type=sncl.data_type,
+            interval=sncl.interval,
+            network=network,
+            station=station,
+            location=location,
+        )
+        trace.data = numpy.ones(trace.stats.npts)
+        return Stream([trace])