From d0f94b4dff63364cb35de23b50c3edb398c01838 Mon Sep 17 00:00:00 2001
From: spencer <swilbur@usgs.gov>
Date: Wed, 11 Dec 2024 10:38:06 -0700
Subject: [PATCH] Removed union method from FilterApiQuery and Filter.py, I
 also added a negatvie test to filter_test to account check that an
 innapropriate float would cause and error. I also removed the asyn functions
 from data_test.py because they were being skipped in the testing pipeline. I
 made them regular function and they are now passing.

---
 geomagio/api/ws/FilterApiQuery.py    |  2 +-
 geomagio/api/ws/filter.py            |  4 ++--
 poetry.lock                          |  2 +-
 test/api_test/ws_test/data_test.py   | 30 +++++-----------------------
 test/api_test/ws_test/filter_test.py |  1 +
 5 files changed, 10 insertions(+), 29 deletions(-)

diff --git a/geomagio/api/ws/FilterApiQuery.py b/geomagio/api/ws/FilterApiQuery.py
index 0aab1c78..ee2a80bc 100644
--- a/geomagio/api/ws/FilterApiQuery.py
+++ b/geomagio/api/ws/FilterApiQuery.py
@@ -13,7 +13,7 @@ the fields input_sampling_period and output_sampling_period."""
 class FilterApiQuery(DataApiQuery):
     model_config = ConfigDict(extra="forbid")
 
-    input_sampling_period: Optional[Union[SamplingPeriod, float]] = None
+    input_sampling_period: Optional[SamplingPeriod] = None
 
     @model_validator(mode="after")
     def validate_sample_size(self):
diff --git a/geomagio/api/ws/filter.py b/geomagio/api/ws/filter.py
index 11c24dd0..ef0b824a 100644
--- a/geomagio/api/ws/filter.py
+++ b/geomagio/api/ws/filter.py
@@ -33,12 +33,12 @@ def get_filter_data_query(
     format: Union[OutputFormat, str] = Query(
         OutputFormat.IAGA2002, title="Output Format"
     ),
-    input_sampling_period: Optional[Union[SamplingPeriod, float]] = Query(
+    input_sampling_period: Optional[SamplingPeriod] = Query(
         None,
         title="Initial Sampling Period",
         description="`--` dynamically determines a necessary sampling period.",
     ),
-    sampling_period: Union[SamplingPeriod, float] = Query(
+    sampling_period: SamplingPeriod = Query(
         SamplingPeriod.SECOND,
         alias="output_sampling_period",
         title="Output sampling period",
diff --git a/poetry.lock b/poetry.lock
index d2d17532..95155b3c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aiomysql"
diff --git a/test/api_test/ws_test/data_test.py b/test/api_test/ws_test/data_test.py
index 9a3da337..765fa407 100644
--- a/test/api_test/ws_test/data_test.py
+++ b/test/api_test/ws_test/data_test.py
@@ -57,38 +57,18 @@ def test_get_data_query_no_starttime(test_client):
 
 async def test_get_data_query_extra_params(test_client):
     with pytest.raises(ValueError) as error:
-        response = await test_client.get(
+        response = test_client.get(
             "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
         )
+        DataApiQuery(**response.json())
         assert error.match("Invalid query parameter(s): location, network")
 
 
-# def test_get_data_query_extra_params(test_client):
-#     """test.api_test.ws_test.data_test.test_get_data_query_extra_params()"""
-#     with pytest.raises(ValueError) as error:
-#         test_client.get(
-#             "/query/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=variation&sampling_period=60&format=iaga2002&location=R1&network=NT"
-#         )
-#         assert error.message == "Invalid query parameter(s): location, network"
-
-
-async def test_get_data_query_bad_params(test_client):
+def test_get_data_query_bad_params(test_client):
     """test.api_test.ws_test.data_test.test_get_data_query_bad_params()"""
     with pytest.raises(ValueError) as error:
-        response = await test_client.get(
+        response = test_client.get(
             "/query/?id=BOU&startime=2020-09-01T00:00:01&elements=X,Y,Z,F&data_type=variation&sampling_period=60&format=iaga2002"
         )
+        DataApiQuery(**response.json())
         assert error.match == "Invalid query parameter(s): startime, data_type"
-
-
-# def test_filter_data_query(test_client):
-#     """test.api_test.ws_test.data_test.test_filter_data_query()"""
-#     response = test_client.get(
-#         "/algorithms/filter/?id=BOU&starttime=2020-09-01T00:00:01&elements=X,Y,Z,F&type=R1&sampling_period=60&format=iaga2002&input_sampling_period=60&output_sampling_period=30"
-#     )
-#     filter_query = FilterDataApiQuery(**response.json())
-#     assert_equal(filter_query.id, "BOU")
-#     assert_equal(filter_query.starttime, UTCDateTime("2020-09-01T00:00:01"))
-#     assert_equal(filter_query.elements, ["X", "Y", "Z", "F"])
-#     assert_equal(filter_query.input_sampling_period, 60)
-#     assert_equal(filter_query.output_sampling_period, 30)
diff --git a/test/api_test/ws_test/filter_test.py b/test/api_test/ws_test/filter_test.py
index aa0a4d8f..d56b6352 100644
--- a/test/api_test/ws_test/filter_test.py
+++ b/test/api_test/ws_test/filter_test.py
@@ -3,6 +3,7 @@ from fastapi import Depends
 from fastapi.testclient import TestClient
 from numpy.testing import assert_equal
 from obspy import UTCDateTime
+from pydantic import ValidationError
 import pytest
 
 from geomagio.api.ws import app
-- 
GitLab