From 292cde235b6a4cb89fb0b24db8779029df24a97e Mon Sep 17 00:00:00 2001 From: Alexandra Hobbs <ahobbs@contractor.usgs.gov> Date: Tue, 26 Nov 2024 10:06:25 -0700 Subject: [PATCH] make same serialization changes for metadataquery --- geomagio/api/db/MetadataDatabaseFactory.py | 4 +-- geomagio/metadata/MetadataQuery.py | 33 +++++++++++++++++----- test/db/MetadataDatabaseFactory_test.py | 22 ++++++++++++++- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/geomagio/api/db/MetadataDatabaseFactory.py b/geomagio/api/db/MetadataDatabaseFactory.py index b2274e8c..7a92a91d 100644 --- a/geomagio/api/db/MetadataDatabaseFactory.py +++ b/geomagio/api/db/MetadataDatabaseFactory.py @@ -43,9 +43,9 @@ class MetadataDatabaseFactory(object): channel, location, data_valid, - status, metadata, - ) = params.datetime_dict().values() + status, + ) = params.model_dump().values() if id: query = query.where(table.c.id == id) if category: diff --git a/geomagio/metadata/MetadataQuery.py b/geomagio/metadata/MetadataQuery.py index 15e70544..fbf74a00 100644 --- a/geomagio/metadata/MetadataQuery.py +++ b/geomagio/metadata/MetadataQuery.py @@ -1,7 +1,7 @@ from datetime import timezone from obspy import UTCDateTime -from pydantic import BaseModel +from pydantic import field_serializer, BaseModel from typing import List, Optional, Dict, Any from .MetadataCategory import MetadataCategory @@ -23,9 +23,28 @@ class MetadataQuery(BaseModel): metadata: Optional[Dict[str, Any]] = None status: Optional[List[str]] = None - def datetime_dict(self, **kwargs): - values = self.model_dump(**kwargs) - for key in ["starttime", "endtime", "created_after", "created_before"]: - if key in values and values[key] is not None: - values[key] = values[key].datetime.replace(tzinfo=timezone.utc) - return values + # instructions for model_dump() to serialize pydantic CustomUTCDateTimeType into aware datetime.datetime type + # sqlalchemy is expecting aware datetime.datetime, not the string model_dump() creates by default + @field_serializer("created_after") + def serialize_created_after(self, created_after: UTCDateTime): + if created_after is not None: + created_after = created_after.datetime.replace(tzinfo=timezone.utc) + return created_after + + @field_serializer("created_before") + def serialize_created_before(self, created_before: UTCDateTime): + if created_before is not None: + created_before = created_before.datetime.replace(tzinfo=timezone.utc) + return created_before + + @field_serializer("starttime") + def serialize_starttime(self, starttime: UTCDateTime): + if starttime is not None: + starttime = starttime.datetime.replace(tzinfo=timezone.utc) + return starttime + + @field_serializer("endtime") + def serialize_endtime(self, endtime: UTCDateTime): + if endtime is not None: + endtime = endtime.datetime.replace(tzinfo=timezone.utc) + return endtime diff --git a/test/db/MetadataDatabaseFactory_test.py b/test/db/MetadataDatabaseFactory_test.py index f69c9a89..cc830080 100644 --- a/test/db/MetadataDatabaseFactory_test.py +++ b/test/db/MetadataDatabaseFactory_test.py @@ -7,7 +7,7 @@ from databases import Database from obspy import UTCDateTime from geomagio.api.db import MetadataDatabaseFactory -from geomagio.metadata import Metadata, MetadataCategory +from geomagio.metadata import Metadata, MetadataCategory, MetadataQuery class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase): @@ -419,3 +419,23 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase): assert second_called_params["updated_by"] == "test_user" assert second_called_params["updated_time"] is not None assert second_called_params["metadata"] == test_data.metadata + + @patch("databases.Database.fetch_all", new_callable=AsyncMock) + async def test_get_metadata(self, mock_fetch_all): + test_query = MetadataQuery( + category=MetadataCategory.INSTRUMENT, + station="BSL", + starttime=UTCDateTime(2020, 1, 20) + ) + + db = Database("sqlite:///:memory:") + + await MetadataDatabaseFactory(database=db).get_metadata(params=test_query) + + mock_fetch_all.assert_called_once() + + called_params = mock_fetch_all.call_args.args[0].compile().params + + assert called_params["category_1"] == "instrument" + assert called_params["station_1"] == "BSL" + assert called_params["endtime_1"] == datetime.datetime(2020, 1, 20, tzinfo=tz.tzutc()) -- GitLab