Commit c25b5b3e authored by Wilbur, Spencer Franklin

Merge branch 'CLEAN_Recursive_Filter_Branch' into 'master'

Created a fresh branch to commit changes to files for adding recursive logic,...

See merge request !366
parents fbfb312c e639ddd2
Pipeline #544645 passed
import datetime
import enum
from enum import Enum
import os
from typing import List, Optional, Union
@@ -9,8 +10,12 @@ from pydantic import ConfigDict, field_validator, model_validator, Field, BaseModel
from .Element import ELEMENTS
from .Observatory import OBSERVATORY_INDEX, ASL_OBSERVATORY_INDEX
from ...pydantic_utcdatetime import CustomUTCDateTimeType
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"]
REQUEST_LIMIT = 3456000  # Increased the request limit to 10x the value originally decided by Jeremy
VALID_ELEMENTS = [e.id for e in ELEMENTS]
@@ -68,7 +73,7 @@ class DataApiQuery(BaseModel):
# endtime default is dependent on start time, so it's handled after validation in the model_validator
endtime: Optional[CustomUTCDateTimeType] = None
elements: List[str] = DEFAULT_ELEMENTS
sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
sampling_period: Optional[SamplingPeriod] = None
data_type: Union[DataType, str] = DataType.VARIATION
format: Union[OutputFormat, str] = OutputFormat.IAGA2002
data_host: Union[DataHost, str] = DataHost.DEFAULT
@@ -123,11 +128,20 @@ class DataApiQuery(BaseModel):
self.endtime = self.starttime + (86400 - 0.001)
if self.starttime > self.endtime:
raise ValueError("Starttime must be before endtime.")
# check data volume
samples = int(
len(self.elements) * (self.endtime - self.starttime) / self.sampling_period
)
if samples > REQUEST_LIMIT:
raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
# otherwise okay
# check data volume and if SamplingPeriod is assigned as None
if self.sampling_period is None:
logging.warning(
"Sampling period is None. Default value or further processing needed."
)
else:
samples = int(
len(self.elements)
* (self.endtime - self.starttime)
/ self.sampling_period
)
if samples > REQUEST_LIMIT:
raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
# otherwise okay
return self
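For a sense of scale, a minimal sketch of the request-size guard under the default query (illustrative numbers only, not part of this diff):
# Hypothetical walk-through of the REQUEST_LIMIT check:
# four default elements over the default one-day window at minute cadence.
elements = ["X", "Y", "Z", "F"]
window_seconds = 86400 - 0.001   # default endtime is starttime + (86400 - 0.001)
sampling_period = 60.0           # SamplingPeriod.MINUTE
samples = int(len(elements) * window_seconds / sampling_period)
assert samples == 5759           # comfortably below REQUEST_LIMIT (3456000)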
from .DataApiQuery import DataApiQuery, SamplingPeriod, REQUEST_LIMIT
from pydantic import ConfigDict, model_validator
from .DataApiQuery import (
DataApiQuery,
SamplingPeriod,
REQUEST_LIMIT,
)
from pydantic import ConfigDict, model_validator, field_validator, ValidationError
from typing import Optional
import logging
from fastapi import HTTPException
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
"""This class inherits all the fields and validation on DataApiQuery and adds
the fields input_sampling_period and output_sampling_period."""
@@ -8,20 +20,26 @@ the fields input_sampling_period and output_sampling_period."""
class FilterApiQuery(DataApiQuery):
model_config = ConfigDict(extra="forbid")
input_sampling_period: SamplingPeriod = SamplingPeriod.SECOND
output_sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
input_sampling_period: Optional[SamplingPeriod] = None
@model_validator(mode="after")
def validate_sample_size(self):
# Calculate the number of samples based on the input sampling period
samples = int(
len(self.elements)
* (self.endtime - self.starttime)
/ self.input_sampling_period
)
# Validate the request size
if samples > REQUEST_LIMIT:
raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
if self.sampling_period is None:
# Log a warning indicating that the sampling period is missing
logging.warning(
"Sampling period is None. Please provide a valid Sampling Period."
)
else:
# Calculate the number of samples based on the requested sampling period
samples = int(
len(self.elements)
* (self.endtime - self.starttime)
/ self.sampling_period
)
# Validate the request size
if samples > REQUEST_LIMIT:
raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
return self
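A minimal usage sketch of the reworked validator (the import path is assumed from this repository's layout, not shown in the diff):
# Hypothetical: with sampling_period omitted the validator only logs a warning;
# with both periods given, the sample-count guard runs as before.
from geomagio.api.ws.FilterApiQuery import FilterApiQuery

query = FilterApiQuery(id="BOU")                                        # warns, no error
query = FilterApiQuery(id="BOU", input_sampling_period=1, sampling_period=60)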
@@ -2,20 +2,21 @@ import json
from fastapi import APIRouter, Depends, HTTPException, Query
from starlette.responses import Response
from obspy.core import Stream, Stats
from typing import List, Union
from ...algorithm import DbDtAlgorithm, FilterAlgorithm
from ...algorithm import DbDtAlgorithm
from ...residual import (
calculate,
Reading,
)
from .DataApiQuery import DataApiQuery, SamplingPeriod
from .DataApiQuery import DataApiQuery
from .FilterApiQuery import FilterApiQuery
from .data import format_timeseries, get_data_factory, get_data_query, get_timeseries
from .filter import get_filter_data_query
from . import filter
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -43,42 +44,16 @@ def get_dbdt(
####################################### The router .get filter isn't visible on the docs page
# Look for registered routers in the backend
@router.get(
"/algorithms/filter/",
description="Filtered data dependent on requested interval",
name="Filtered Algorithm",
)
# New query parameter defined below. I am using a new query defined in DataApiQuery.
# This relies on the new filter.py module and the get_filter_data_query function
# to define the input and output sampling periods.
def get_filter(
query: FilterApiQuery = Depends(get_filter_data_query),
) -> Response:
filt = FilterAlgorithm(
input_sample_period=query.input_sampling_period,
output_sample_period=query.output_sampling_period,
)
# Grab the correct starttime and endtime for the timeseries from get_input_interval
starttime, endtime = filt.get_input_interval(query.starttime, query.endtime)
def get_filter(query: FilterApiQuery = Depends(get_filter_data_query)) -> Response:
# Reassign the actual start/endtime to the query parameters
query.starttime = starttime
query.endtime = endtime
data_factory = get_data_factory(query=query)
# read data
raw = filter.get_timeseries(data_factory, query)
filtered_timeseries = filt.process(raw)
filtered_timeseries = filter.get_timeseries(query)
elements = [f"{element}" for element in query.elements]
# output response
return format_timeseries(
timeseries=filtered_timeseries, format=query.format, elements=elements
)
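With input_sampling_period now optional, a client can name only the output cadence; a hypothetical request against this route (observatory, dates, and host are placeholders):
# GET /algorithms/filter/?id=BOU&starttime=2024-11-01T00:00:00Z
#     &elements=X,Y,Z,F&output_sampling_period=60&format=iaga2002
# input_sampling_period is omitted, so filter.get_timeseries() resolves it
# dynamically via determine_available_period() (see filter.py below).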
......
@@ -74,7 +74,7 @@ def get_data_query(
" NOTE: when using 'iaga2002' output format, a maximum of 4 elements is allowed",
),
sampling_period: Union[SamplingPeriod, float] = Query(
SamplingPeriod.MINUTE,
None,
title="data rate",
description="Interval in seconds between values.",
),
......
from typing import List, Union, Optional
from fastapi import Query
from obspy import UTCDateTime, Stream
from ... import TimeseriesFactory, TimeseriesUtility
from obspy import Stream
from ... import TimeseriesUtility
import numpy as np
from .DataApiQuery import (
DEFAULT_ELEMENTS,
DataHost,
@@ -10,6 +11,9 @@ from .DataApiQuery import (
SamplingPeriod,
)
from .FilterApiQuery import FilterApiQuery
from ...algorithm import FilterAlgorithm
import logging as logger
from .data import get_data_factory
from ...pydantic_utcdatetime import CustomUTCDateTimeType
@@ -26,11 +30,15 @@ def get_filter_data_query(
format: Union[OutputFormat, str] = Query(
OutputFormat.IAGA2002, title="Output Format"
),
input_sampling_period: Union[SamplingPeriod, float] = Query(
SamplingPeriod.SECOND, title="Initial sampling period"
input_sampling_period: Optional[SamplingPeriod] = Query(
None,
title="Input Sampling Period",
description="`--` dynamically determines a necessary sampling period.",
),
output_sampling_period: Union[SamplingPeriod, float] = Query(
SamplingPeriod.MINUTE, title="Output sampling period"
sampling_period: Optional[SamplingPeriod] = Query(
None,
alias="output_sampling_period",
title="Output sampling period",
),
data_host: Union[DataHost, str] = Query(
DataHost.DEFAULT, title="Data Host", description="Edge host to pull data from."
@@ -43,28 +51,107 @@
endtime=endtime,
elements=elements,
input_sampling_period=input_sampling_period,
output_sampling_period=output_sampling_period,
sampling_period=sampling_period,
data_type=data_type,
data_host=data_host,
format=format,
)
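Note the alias: the model field is now sampling_period, but clients keep the old query-string name. A sketch of the assumed mapping (FastAPI resolves the alias before the model is built):
# ?output_sampling_period=3600  ->  query.sampling_period == SamplingPeriod.HOUR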
def get_timeseries(data_factory: TimeseriesFactory, query: FilterApiQuery) -> Stream:
"""Get timeseries data for variometers
# Main filter function
def get_timeseries(query: FilterApiQuery) -> Stream:
data_factory = get_data_factory(query=query)
Parameters
----------
data_factory: where to read data
query: parameters for the data to read
"""
# get data
timeseries = data_factory.get_timeseries(
starttime=query.starttime,
endtime=query.endtime,
# Determine input sampling period if not provided
if query.input_sampling_period is None:
# Dynamically determine the input sampling period
input_sampling_period, data = determine_available_period(
query.sampling_period, query, data_factory
)
else:
input_sampling_period = query.input_sampling_period
filt = FilterAlgorithm(
input_sample_period=input_sampling_period,
output_sample_period=query.sampling_period,
)
# Fetch filtered data
starttime, endtime = filt.get_input_interval(query.starttime, query.endtime)
data = data_factory.get_timeseries(
starttime=starttime,
endtime=endtime,
observatory=query.id,
channels=query.elements,
type=query.data_type,
interval=TimeseriesUtility.get_interval_from_delta(query.input_sampling_period),
interval=TimeseriesUtility.get_interval_from_delta(filt.input_sample_period),
)
return timeseries
# Apply filtering if needed
filtered_timeseries = filt.process(data)
return filtered_timeseries
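End to end, the new pipeline can be driven with only an output cadence; a hypothetical sketch (names follow this diff, values are placeholders, and the edge data host must be reachable):
query = FilterApiQuery(
    id="BOU",
    starttime="2024-11-01T00:00:00Z",
    sampling_period=SamplingPeriod.MINUTE,  # desired output cadence
    # input_sampling_period omitted -> determine_available_period() picks one
)
stream = get_timeseries(query)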
def determine_available_period(output_sampling_period: float, query, data_factory):
"""
Finds the lowest resolution (longest sampling period) <= output_sampling_period
that has valid data available.
"""
# Sort and filter periods starting from output_sampling_period
sorted_periods: List[SamplingPeriod] = sorted(
SamplingPeriod, key=lambda p: p.value, reverse=True
)
if output_sampling_period is None:
raise ValueError("Output sampling period cannot be None.")
else:
valid_sampling_periods = [
p for p in sorted_periods if p.value <= output_sampling_period
]
for period in valid_sampling_periods:
if period <= output_sampling_period:
data = data_factory.get_timeseries(
starttime=query.starttime,
endtime=query.endtime,
observatory=query.id,
channels=query.elements,
type=query.data_type,
interval=TimeseriesUtility.get_interval_from_delta(period.value),
)
# Check if the fetched data is valid
if is_valid_data(data):
logger.info(f"Valid data found for sampling period: {period.name}")
return period.value, data # Return the sampling period and the data
else:
logger.error(
f"No valid data found for requested sampling period: {period.name}"
)
continue
raise ValueError("No valid data found for the requested output sampling period.")
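To make the fallback order concrete, assuming the standard SamplingPeriod members (TEN_HERTZ=0.1, SECOND=1.0, MINUTE=60.0, HOUR=3600.0, DAY=86400.0), a request with output_sampling_period=60.0 tries candidates longest period first:
# MINUTE (60.0) -> SECOND (1.0) -> TEN_HERTZ (0.1)
# The first candidate whose fetched Stream passes is_valid_data() is returned;
# if none passes, the ValueError above is raised.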
def is_valid_data(data: Stream) -> bool:
"""
Checks if the fetched data contains actual values and not just filler values (e.g., NaN).
A Stream is invalid if any trace contains only NaN values.
"""
if not data or len(data) == 0:
return False # No data in the stream
for trace in data:
# Check if trace.data exists and has data
if trace.data is None or len(trace.data) == 0:
return False # Trace has no data
# Check if all values in trace.data are NaN
if np.all(np.isnan(trace.data)):
return False # Invalid if all values are NaN
return True # All traces are valid
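A quick sketch of the three cases the helper distinguishes (numpy/obspy only; mirrors the logic above):
import numpy as np
from obspy import Stream, Trace

assert is_valid_data(Stream()) is False                                   # no traces
assert is_valid_data(Stream([Trace(data=np.full(5, np.nan))])) is False   # all NaN
assert is_valid_data(Stream([Trace(data=np.ones(5))])) is True            # real values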
@@ -3,3 +3,9 @@ norecursedirs = */site-packages
testpaths = test
asyncio_mode=auto
asyncio_default_fixture_loop_scope="function"
# Suppress warnings of level WARNING and below
log_level = WARNING
# Optionally, you can filter out UserWarnings generated by logging
filterwarnings =
ignore::UserWarning
\ No newline at end of file
@@ -23,7 +23,7 @@ def test_DataApiQuery_defaults():
assert_equal(query.starttime, expected_start_time)
assert_equal(query.endtime, expected_endtime)
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
# assumes the env var DATA_HOST is not set
@@ -41,7 +41,7 @@ def test_DataApiQuery_starttime_is_none():
assert_equal(query.starttime, expected_start_time)
assert_equal(query.endtime, expected_endtime)
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
# assumes the env var DATA_HOST is not set
@@ -100,7 +100,7 @@ def test_DataApiQuery_default_endtime():
# endtime is 1 day after start time
assert_equal(query.endtime, UTCDateTime("2024-11-02T00:00:00.999"))
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
assert_equal(query.data_host, DataHost.DEFAULT)
@@ -122,7 +122,7 @@ def test_DataApiQuery_default_only_endtime():
assert_equal(query.endtime, hour_later)
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
assert_equal(query.data_host, DataHost.DEFAULT)
......
@@ -22,8 +22,8 @@ def test_FilterApiQuery_defaults():
assert_equal(query.starttime, expected_start_time)
assert_equal(query.endtime, expected_endtime)
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.input_sampling_period, SamplingPeriod.SECOND)
assert_equal(query.output_sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.input_sampling_period, None)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
assert_equal(query.data_host, DataHost.DEFAULT)
@@ -36,7 +36,7 @@ def test_FilterApiQuery_valid():
endtime="2024-09-01T01:00:01",
elements=["Z"],
input_sampling_period=60,
output_sampling_period=3600,
sampling_period=3600,
data_type="adjusted",
format="json",
data_host="cwbpub.cr.usgs.gov",
@@ -47,7 +47,7 @@ def test_FilterApiQuery_valid():
assert_equal(query.endtime, UTCDateTime("2024-09-01T01:00:01"))
assert_equal(query.elements, ["Z"])
assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.output_sampling_period, SamplingPeriod.HOUR)
assert_equal(query.sampling_period, SamplingPeriod.HOUR)
assert_equal(query.data_type, "adjusted")
assert_equal(query.format, "json")
assert_equal(query.data_host, "cwbpub.cr.usgs.gov")
@@ -83,7 +83,7 @@ def test_FilterApiQuery_default_endtime():
# endtime is 1 day after start time
assert_equal(query.endtime, UTCDateTime("2024-11-02T00:00:00.999"))
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.sampling_period, None)
assert_equal(query.data_type, DataType.VARIATION)
assert_equal(query.format, OutputFormat.IAGA2002)
assert_equal(query.data_host, DataHost.DEFAULT)
@@ -190,3 +190,13 @@ def test_FilterApiQuery_extra_fields():
assert "Extra inputs are not permitted" == err[0]["msg"]
assert_equal(query, None)
def test_FilterApiQuery_no_output_sampling_period():
query = None
try:
query = FilterApiQuery(id="ANMO", sampling_period=None)
except Exception as e:
err = e.errors()
assert "Output sampling period cannot be None." == err[0]["msg"]
assert_equal(query, None)
@@ -28,7 +28,7 @@ def test_client():
def test_get_data_query(test_client):
"""test.api_test.ws_test.data_test.test_get_data_query()"""
response = test_client.get(
"/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002"
"/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002"
)
query = DataApiQuery(**response.json())
assert_equal(query.id, "BOU")
@@ -37,7 +37,7 @@ def test_get_data_query(test_client):
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.format, "iaga2002")
assert_equal(query.data_type, "R1")
assert_equal(query.data_type, "variation")
def test_get_data_query_no_starttime(test_client):
@@ -57,38 +57,18 @@ def test_get_data_query_no_starttime(test_client):
async def test_get_data_query_extra_params(test_client):
with pytest.raises(ValueError) as error:
response = await test_client.get(
response = test_client.get(
"/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
)
DataApiQuery(**response.json())
assert error.match("Invalid query parameter(s): location, network")
# def test_get_data_query_extra_params(test_client):
# """test.api_test.ws_test.data_test.test_get_data_query_extra_params()"""
# with pytest.raises(ValueError) as error:
# test_client.get(
# "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
# )
# assert error.message == "Invalid query parameter(s): location, network"
async def test_get_data_query_bad_params(test_client):
def test_get_data_query_bad_params(test_client):
"""test.api_test.ws_test.data_test.test_get_data_query_bad_params()"""
with pytest.raises(ValueError) as error:
response = await test_client.get(
response = test_client.get(
"/query/?id=BOU&startime=2020-09-01T00:00:01&elements=X,Y,Z,F&data_type=variation&sampling_period=60&format=iaga2002"
)
DataApiQuery(**response.json())
assert error.match == "Invalid query parameter(s): startime, data_type"
# def test_filter_data_query(test_client):
# """test.api_test.ws_test.data_test.test_filter_data_query()"""
# response = test_client.get(
# "/algorithms/filter/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002&input_sampling_period=60&output_sampling_period=30"
# )
# filter_query = FilterDataApiQuery(**response.json())
# assert_equal(filter_query.id, "BOU")
# assert_equal(filter_query.starttime, UTCDateTime("2020-09-01T00:00:01"))
# assert_equal(filter_query.elements, ["X", "Y", "Z", "F"])
# assert_equal(filter_query.input_sampling_period, 60)
# assert_equal(filter_query.output_sampling_period, 30)
@@ -31,11 +31,10 @@ def test_get_filter_data_query(test_client):
assert_equal(query.starttime, UTCDateTime("2020-09-01T00:00:01"))
assert_equal(query.endtime, UTCDateTime("2020-09-02T00:00:00.999"))
assert_equal(query.elements, ["X", "Y", "Z", "F"])
assert_equal(query.sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.format, "iaga2002")
assert_equal(query.data_type, "variation")
assert_equal(query.input_sampling_period, SamplingPeriod.MINUTE)
assert_equal(query.output_sampling_period, SamplingPeriod.HOUR)
assert_equal(query.sampling_period, SamplingPeriod.HOUR)
def test_get_filter_data_query_no_starttime(test_client):
......
import io
from typing import List
import numpy
from numpy.testing import assert_equal, assert_array_equal
import numpy as np
from obspy.core import Stream, Trace, UTCDateTime
from obspy.core.inventory import Inventory, Network, Station, Channel, Site
import pytest
from geomagio.edge import FDSNFactory
from geomagio.metadata.instrument.InstrumentCalibrations import (
get_instrument_calibrations,
)
from .mseed_FDSN_test_clients import MockFDSNSeedClient
@pytest.fixture(scope="class")
def FDSN_factory() -> FDSNFactory:
"""instance of FDSNFactory with MockFDSNClient"""
factory = FDSNFactory()
factory.client = MockFDSNSeedClient()
yield factory
@pytest.fixture()
def anmo_u_metadata():
metadata = get_instrument_calibrations(observatory="ANMO")
instrument = metadata[0]["instrument"]
channels = instrument["channels"]
yield channels["X"]
def test__get_timeseries_add_empty_channels(FDSN_factory: FDSNFactory):
"""test.edge_test.FDSNFactory_test.test__get_timeseries_add_empty_channels()"""
FDSN_factory.client.return_empty = True
starttime = UTCDateTime("2024-09-07T00:00:00Z")
endtime = UTCDateTime("2024-09-07T00:10:00Z")
trace = FDSN_factory._get_timeseries(
starttime=starttime,
endtime=endtime,
observatory="ANMO",
channel="X",
type="variation",
interval="second",
add_empty_channels=True,
)[0]
assert_array_equal(trace.data, numpy.ones(trace.stats.npts) * numpy.nan)
assert trace.stats.starttime == starttime
assert trace.stats.endtime == endtime
with pytest.raises(IndexError):
trace = FDSN_factory._get_timeseries(
starttime=starttime,
endtime=endtime,
observatory="ANMO",
channel="X",
type="variation",
interval="second",
add_empty_channels=False,
)[0]
def test__set_metadata():
"""edge_test.FDSNFactory_test.test__set_metadata()"""
# Call _set_metadata with 2 traces, and make certain the stats get
# set for both traces.
trace1 = Trace()
trace2 = Trace()
stream = Stream(traces=[trace1, trace2])
FDSNFactory()._set_metadata(stream, "ANMO", "X", "variation", "second")
assert_equal(stream[0].stats["channel"], "X")
assert_equal(stream[1].stats["channel"], "X")
def test_get_timeseries(FDSN_factory):
"""edge_test.FDSNFactory_test.test_get_timeseries()"""
# Call get_timeseries, and test stats for confirmation that it came back.
# TODO, need to pass in host and port from a config file, or manually
# change for a single test.
timeseries = FDSN_factory.get_timeseries(
starttime=UTCDateTime(2024, 3, 1, 0, 0, 0),
endtime=UTCDateTime(2024, 3, 1, 1, 0, 0),
observatory="ANMO",
channels=("X"),
type="variation",
interval="second",
)
assert_equal(
timeseries.select(channel="X")[0].stats.station,
"ANMO",
"Expect timeseries to have stats",
)
assert_equal(
timeseries.select(channel="X")[0].stats.channel,
"X",
"Expect timeseries stats channel to be equal to X",
)
assert_equal(
timeseries.select(channel="X")[0].stats.data_type,
"variation",
"Expect timeseries stats data_type to be equal to variation",
)
def test_get_timeseries_by_location(FDSN_factory):
"""test.edge_test.FDSNFactory_test.test_get_timeseries_by_location()"""
timeseries = FDSN_factory.get_timeseries(
UTCDateTime(2024, 3, 1, 0, 0, 0),
UTCDateTime(2024, 3, 1, 1, 0, 0),
"ANMO",
("X"),
"R0",
"second",
)
assert_equal(
timeseries.select(channel="X")[0].stats.data_type,
"R0",
"Expect timeseries stats data_type to be equal to R0",
)
def test_rotate_trace():
# Initialize the factory
factory = FDSNFactory(
observatory="ANMO",
channels=["X", "Y", "Z"],
type="variation",
interval="second",
)
# Simulate input traces for X, Y, Z channels
starttime = UTCDateTime("2024-01-01T00:00:00")
endtime = UTCDateTime(2024, 1, 1, 0, 10)
data_x = Trace(
data=np.array([1, 2, 3, 4, 5]),
header={"channel": "X", "starttime": starttime, "delta": 60},
)
data_y = Trace(
data=np.array([6, 7, 8, 9, 10]),
header={"channel": "Y", "starttime": starttime, "delta": 60},
)
data_z = Trace(
data=np.array([11, 12, 13, 14, 15]),
header={"channel": "Z", "starttime": starttime, "delta": 60},
)
input_stream = Stream(traces=[data_x, data_y, data_z])
# Mock the Client.get_waveforms method to return the simulated stream
factory.Client.get_waveforms = lambda *args, **kwargs: input_stream
# Create a mock inventory object for get stations
mock_inventory = create_mock_inventory()
# Mock the Client.get_stations method to return dummy inventory (if required for rotation)
factory.Client.get_stations = lambda *args, **kwargs: mock_inventory
# Call get_timeseries with channel "X" to trigger rotation
rotated_stream = factory.get_timeseries(
starttime=starttime,
endtime=endtime,
observatory="ANMO",
channels=["X"], # Requesting any channel in [X, Y, Z] should trigger rotation
)
# Assertions
assert (
len(rotated_stream) == 1
), "Expected only the requested channel (X, Y, Z) after rotation"
assert rotated_stream[0].stats.channel in [
"X",
], "Unexpected channel names after rotation"
assert (
rotated_stream[0].stats.starttime == starttime
), "Start time mismatch in rotated data"
def create_mock_inventory():
"""Creates a mock inventory for testing purposes."""
# Create a dummy channel
channel = Channel(
code="X",
location_code="",
latitude=0.0,
longitude=0.0,
elevation=0.0,
depth=0.0,
azimuth=0.0,
dip=0.0,
sample_rate=1.0,
)
# Create a dummy station
station = Station(
code="ANMO",
latitude=0.0,
longitude=0.0,
elevation=0.0,
site=Site(name="TestSite"),
channels=[channel],
)
# Create a dummy network
network = Network(code="XX", stations=[station])
# Create an inventory
inventory = Inventory(networks=[network], source="MockInventory")
return inventory
import numpy
from obspy import Stream, UTCDateTime
from obspy.clients.neic.client import Client
from geomagio import TimeseriesUtility
from geomagio.edge import FDSNSNCL
class MockFDSNSeedClient(Client):
"""replaces default obspy miniseed client's get_waveforms method to return trace of ones
Note: includes 'return_empty' parameter to simulate situations where no data is received
"""
def __init__(self, return_empty: bool = False):
self.return_empty = return_empty
def get_waveforms(
self,
network: str,
station: str,
location: str,
channel: str,
starttime: UTCDateTime,
endtime: UTCDateTime,
):
if self.return_empty:
return Stream()
sncl = FDSNSNCL(
station=station,
network=network,
channel=channel,
location=location,
)
trace = TimeseriesUtility.create_empty_trace(
starttime=starttime,
endtime=endtime,
observatory=station,
channel=channel,
type=sncl.data_type,
interval=sncl.interval,
network=network,
station=station,
location=location,
)
trace.data = numpy.ones(trace.stats.npts)
return Stream([trace])