Skip to content
Snippets Groups Projects
Commit 292cde23 authored by Hobbs, Alexandra (Contractor)'s avatar Hobbs, Alexandra (Contractor)
Browse files

make same serialization changes for metadataquery

parent 85e59ba8
No related branches found
No related tags found
1 merge request!361make same serialization changes for metadataquery
......@@ -43,9 +43,9 @@ class MetadataDatabaseFactory(object):
channel,
location,
data_valid,
status,
metadata,
) = params.datetime_dict().values()
status,
) = params.model_dump().values()
if id:
query = query.where(table.c.id == id)
if category:
......
from datetime import timezone
from obspy import UTCDateTime
from pydantic import BaseModel
from pydantic import field_serializer, BaseModel
from typing import List, Optional, Dict, Any
from .MetadataCategory import MetadataCategory
......@@ -23,9 +23,28 @@ class MetadataQuery(BaseModel):
metadata: Optional[Dict[str, Any]] = None
status: Optional[List[str]] = None
def datetime_dict(self, **kwargs):
values = self.model_dump(**kwargs)
for key in ["starttime", "endtime", "created_after", "created_before"]:
if key in values and values[key] is not None:
values[key] = values[key].datetime.replace(tzinfo=timezone.utc)
return values
# instructions for model_dump() to serialize pydantic CustomUTCDateTimeType into aware datetime.datetime type
# sqlalchemy is expecting aware datetime.datetime, not the string model_dump() creates by default
@field_serializer("created_after")
def serialize_created_after(self, created_after: UTCDateTime):
if created_after is not None:
created_after = created_after.datetime.replace(tzinfo=timezone.utc)
return created_after
@field_serializer("created_before")
def serialize_created_before(self, created_before: UTCDateTime):
if created_before is not None:
created_before = created_before.datetime.replace(tzinfo=timezone.utc)
return created_before
@field_serializer("starttime")
def serialize_starttime(self, starttime: UTCDateTime):
if starttime is not None:
starttime = starttime.datetime.replace(tzinfo=timezone.utc)
return starttime
@field_serializer("endtime")
def serialize_endtime(self, endtime: UTCDateTime):
if endtime is not None:
endtime = endtime.datetime.replace(tzinfo=timezone.utc)
return endtime
......@@ -7,7 +7,7 @@ from databases import Database
from obspy import UTCDateTime
from geomagio.api.db import MetadataDatabaseFactory
from geomagio.metadata import Metadata, MetadataCategory
from geomagio.metadata import Metadata, MetadataCategory, MetadataQuery
class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
......@@ -419,3 +419,23 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
assert second_called_params["updated_by"] == "test_user"
assert second_called_params["updated_time"] is not None
assert second_called_params["metadata"] == test_data.metadata
@patch("databases.Database.fetch_all", new_callable=AsyncMock)
async def test_get_metadata(self, mock_fetch_all):
test_query = MetadataQuery(
category=MetadataCategory.INSTRUMENT,
station="BSL",
starttime=UTCDateTime(2020, 1, 20)
)
db = Database("sqlite:///:memory:")
await MetadataDatabaseFactory(database=db).get_metadata(params=test_query)
mock_fetch_all.assert_called_once()
called_params = mock_fetch_all.call_args.args[0].compile().params
assert called_params["category_1"] == "instrument"
assert called_params["station_1"] == "BSL"
assert called_params["endtime_1"] == datetime.datetime(2020, 1, 20, tzinfo=tz.tzutc())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment