diff --git a/geomagio/api/ws/FilterApiQuery.py b/geomagio/api/ws/FilterApiQuery.py index 0aab1c78bfd38c5b2d481ed3e8998539d6e926a2..ee2a80bc6d80260e49449805fd9f053e08d95839 100644 --- a/geomagio/api/ws/FilterApiQuery.py +++ b/geomagio/api/ws/FilterApiQuery.py @@ -13,7 +13,7 @@ the fields input_sampling_period and output_sampling_period.""" class FilterApiQuery(DataApiQuery): model_config = ConfigDict(extra="forbid") - input_sampling_period: Optional[Union[SamplingPeriod, float]] = None + input_sampling_period: Optional[SamplingPeriod] = None @model_validator(mode="after") def validate_sample_size(self): diff --git a/geomagio/api/ws/filter.py b/geomagio/api/ws/filter.py index 11c24dd0614549cda6ed1c0d82f830f0a59ba55a..ef0b824a15091f7c623b427fcac346e97e51fbdb 100644 --- a/geomagio/api/ws/filter.py +++ b/geomagio/api/ws/filter.py @@ -33,12 +33,12 @@ def get_filter_data_query( format: Union[OutputFormat, str] = Query( OutputFormat.IAGA2002, title="Output Format" ), - input_sampling_period: Optional[Union[SamplingPeriod, float]] = Query( + input_sampling_period: Optional[SamplingPeriod] = Query( None, title="Initial Sampling Period", description="`--` dynamically determines a necessary sampling period.", ), - sampling_period: Union[SamplingPeriod, float] = Query( + sampling_period: SamplingPeriod = Query( SamplingPeriod.SECOND, alias="output_sampling_period", title="Output sampling period", diff --git a/poetry.lock b/poetry.lock index d2d17532d28de155a8a215cf90f1dbc421d0a6b5..95155b3c5c80324252313b55fe400525ae0fcd76 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiomysql" diff --git a/test/api_test/ws_test/data_test.py b/test/api_test/ws_test/data_test.py index 9a3da337112643e66030b3d0ac3406af19741419..765fa407b013db3f9e87b87d04c56a3b576d6d99 100644 --- a/test/api_test/ws_test/data_test.py +++ b/test/api_test/ws_test/data_test.py @@ -57,38 +57,18 @@ def test_get_data_query_no_starttime(test_client): async def test_get_data_query_extra_params(test_client): with pytest.raises(ValueError) as error: - response = await test_client.get( + response = test_client.get( "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT" ) + DataApiQuery(**response.json()) assert error.match("Invalid query parameter(s): location, network") -# def test_get_data_query_extra_params(test_client): -# """test.api_test.ws_test.data_test.test_get_data_query_extra_params()""" -# with pytest.raises(ValueError) as error: -# test_client.get( -# "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT" -# ) -# assert error.message == "Invalid query parameter(s): location, network" - - -async def test_get_data_query_bad_params(test_client): +def test_get_data_query_bad_params(test_client): """test.api_test.ws_test.data_test.test_get_data_query_bad_params()""" with pytest.raises(ValueError) as error: - response = await test_client.get( + response = test_client.get( "/query/?id=BOU&startime=2020-09-01T00:00:01&elements=X,Y,Z,F&data_type=variation&sampling_period=60&format=iaga2002" ) + DataApiQuery(**response.json()) assert error.match == "Invalid query parameter(s): startime, data_type" - - -# def test_filter_data_query(test_client): -# """test.api_test.ws_test.data_test.test_filter_data_query()""" -# response = test_client.get( -# "/algorithms/filter/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002&input_sampling_period=60&output_sampling_period=30" -# ) -# filter_query = FilterDataApiQuery(**response.json()) -# assert_equal(filter_query.id, "BOU") -# assert_equal(filter_query.starttime, UTCDateTime("2020-09-01T00:00:01")) -# assert_equal(filter_query.elements, ["X", "Y", "Z", "F"]) -# assert_equal(filter_query.input_sampling_period, 60) -# assert_equal(filter_query.output_sampling_period, 30) diff --git a/test/api_test/ws_test/filter_test.py b/test/api_test/ws_test/filter_test.py index aa0a4d8f05ce533b2bd787e353c07a32a7fea065..d56b63525e021527fc742b46a7302d5220bb39be 100644 --- a/test/api_test/ws_test/filter_test.py +++ b/test/api_test/ws_test/filter_test.py @@ -3,6 +3,7 @@ from fastapi import Depends from fastapi.testclient import TestClient from numpy.testing import assert_equal from obspy import UTCDateTime +from pydantic import ValidationError import pytest from geomagio.api.ws import app