diff --git a/geomagio/api/db/MetadataDatabaseFactory.py b/geomagio/api/db/MetadataDatabaseFactory.py index b2274e8c4e1e0618fa532333144f377e802418b1..7a92a91d52b072e7522f8e4db2aff577cc39d276 100644 --- a/geomagio/api/db/MetadataDatabaseFactory.py +++ b/geomagio/api/db/MetadataDatabaseFactory.py @@ -43,9 +43,9 @@ class MetadataDatabaseFactory(object): channel, location, data_valid, - status, metadata, - ) = params.datetime_dict().values() + status, + ) = params.model_dump().values() if id: query = query.where(table.c.id == id) if category: diff --git a/geomagio/metadata/MetadataQuery.py b/geomagio/metadata/MetadataQuery.py index 15e7054487aac492ecf1b7001a6b4e6249431c0f..fbf74a008d6d9573c73b850d83bb802010cddd44 100644 --- a/geomagio/metadata/MetadataQuery.py +++ b/geomagio/metadata/MetadataQuery.py @@ -1,7 +1,7 @@ from datetime import timezone from obspy import UTCDateTime -from pydantic import BaseModel +from pydantic import field_serializer, BaseModel from typing import List, Optional, Dict, Any from .MetadataCategory import MetadataCategory @@ -23,9 +23,28 @@ class MetadataQuery(BaseModel): metadata: Optional[Dict[str, Any]] = None status: Optional[List[str]] = None - def datetime_dict(self, **kwargs): - values = self.model_dump(**kwargs) - for key in ["starttime", "endtime", "created_after", "created_before"]: - if key in values and values[key] is not None: - values[key] = values[key].datetime.replace(tzinfo=timezone.utc) - return values + # instructions for model_dump() to serialize pydantic CustomUTCDateTimeType into aware datetime.datetime type + # sqlalchemy is expecting aware datetime.datetime, not the string model_dump() creates by default + @field_serializer("created_after") + def serialize_created_after(self, created_after: UTCDateTime): + if created_after is not None: + created_after = created_after.datetime.replace(tzinfo=timezone.utc) + return created_after + + @field_serializer("created_before") + def serialize_created_before(self, created_before: UTCDateTime): + if created_before is not None: + created_before = created_before.datetime.replace(tzinfo=timezone.utc) + return created_before + + @field_serializer("starttime") + def serialize_starttime(self, starttime: UTCDateTime): + if starttime is not None: + starttime = starttime.datetime.replace(tzinfo=timezone.utc) + return starttime + + @field_serializer("endtime") + def serialize_endtime(self, endtime: UTCDateTime): + if endtime is not None: + endtime = endtime.datetime.replace(tzinfo=timezone.utc) + return endtime diff --git a/test/db/MetadataDatabaseFactory_test.py b/test/db/MetadataDatabaseFactory_test.py index f69c9a890d136cfeef82205620628f3670f31654..61e60299d56e345432327275b20cf047c721efff 100644 --- a/test/db/MetadataDatabaseFactory_test.py +++ b/test/db/MetadataDatabaseFactory_test.py @@ -7,7 +7,7 @@ from databases import Database from obspy import UTCDateTime from geomagio.api.db import MetadataDatabaseFactory -from geomagio.metadata import Metadata, MetadataCategory +from geomagio.metadata import Metadata, MetadataCategory, MetadataQuery class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase): @@ -419,3 +419,25 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase): assert second_called_params["updated_by"] == "test_user" assert second_called_params["updated_time"] is not None assert second_called_params["metadata"] == test_data.metadata + + @patch("databases.Database.fetch_all", new_callable=AsyncMock) + async def test_get_metadata(self, mock_fetch_all): + test_query = MetadataQuery( + category=MetadataCategory.INSTRUMENT, + station="BSL", + starttime=UTCDateTime(2020, 1, 20), + ) + + db = Database("sqlite:///:memory:") + + await MetadataDatabaseFactory(database=db).get_metadata(params=test_query) + + mock_fetch_all.assert_called_once() + + called_params = mock_fetch_all.call_args.args[0].compile().params + + assert called_params["category_1"] == "instrument" + assert called_params["station_1"] == "BSL" + assert called_params["endtime_1"] == datetime.datetime( + 2020, 1, 20, tzinfo=tz.tzutc() + )