diff --git a/geomagio/api/db/MetadataDatabaseFactory.py b/geomagio/api/db/MetadataDatabaseFactory.py index fe538302135f122e47b3d002d1c2d670cf092b9a..7e7fb9ab47b8877b26d131e0a170597d9b0c7472 100644 --- a/geomagio/api/db/MetadataDatabaseFactory.py +++ b/geomagio/api/db/MetadataDatabaseFactory.py @@ -12,14 +12,21 @@ from .metadata_table import metadata as metadata_table class MetadataDatabaseFactory(object): def __init__(self, database: Database): + print("init") self.database = database async def create_metadata(self, meta: Metadata) -> Metadata: + print(1) query = metadata_table.insert() + print(2) meta.status = meta.status or "new" - values = meta.datetime_dict(exclude={"id", "metadata_id"}, exclude_none=True) + print(3) + values = meta.model_dump(exclude={"id", "metadata_id"}, exclude_none=True) + print(4) query = query.values(**values) + print(5) meta.id = await self.database.execute(query) + print(6) return meta async def get_metadata( @@ -118,7 +125,7 @@ class MetadataDatabaseFactory(object): # write current record to metadata history table original_metadata = await self.get_metadata_by_id(id=meta.id) original_metadata.metadata_id = original_metadata.id - values = original_metadata.datetime_dict(exclude={"id"}, exclude_none=True) + values = original_metadata.model_dump(exclude={"id"}, exclude_none=True) query = metadata_history.insert() query = query.values(**values) original_metadata.id = await self.database.execute(query) @@ -126,7 +133,7 @@ class MetadataDatabaseFactory(object): meta.updated_by = updated_by meta.updated_time = UTCDateTime() query = metadata_table.update().where(metadata_table.c.id == meta.id) - values = meta.datetime_dict(exclude={"id", "metadata_id"}) + values = meta.model_dump(exclude={"id", "metadata_id"}) query = query.values(**values) await self.database.execute(query) return await self.get_metadata_by_id(id=meta.id) diff --git a/geomagio/metadata/Metadata.py b/geomagio/metadata/Metadata.py index 770f0fd394d86ce6f3ee69542054747127c65ad7..7cae043007c9c4020d26a3b040c611fc6ec7a9d5 100644 --- a/geomagio/metadata/Metadata.py +++ b/geomagio/metadata/Metadata.py @@ -2,7 +2,7 @@ from datetime import timezone from typing import Dict, Optional from obspy import UTCDateTime -from pydantic import field_validator, BaseModel +from pydantic import field_validator, field_serializer, BaseModel from .MetadataCategory import MetadataCategory from ..pydantic_utcdatetime import CustomUTCDateTimeType @@ -79,12 +79,31 @@ class Metadata(BaseModel): # metadata status indicator status: Optional[str] = None - def datetime_dict(self, **kwargs): - values = self.model_dump(**kwargs) - for key in ["created_time", "updated_time", "starttime", "endtime"]: - if key in values and values[key] is not None: - values[key] = values[key].datetime.replace(tzinfo=timezone.utc) - return values + # serialize pydantic CustomUTCDateTimeType type into UTCDateTime for model_dump() and make sure + # the timezone is in utc + @field_serializer("created_time") + def serialize_created_time(self, created_time: UTCDateTime): + if self.created_time is not None: + self.created_time = self.created_time.datetime.replace(tzinfo=timezone.utc) + return self.created_time + + @field_serializer("updated_time") + def serialize_updated_time(self, updated_time: UTCDateTime): + if self.updated_time is not None: + self.updated_time = self.updated_time.datetime.replace(tzinfo=timezone.utc) + return self.updated_time + + @field_serializer("starttime") + def serialize_starttime(self, starttime: UTCDateTime): + if self.starttime is not None: + self.starttime = self.starttime.datetime.replace(tzinfo=timezone.utc) + return self.starttime + + @field_serializer("endtime") + def serialize_endtime(self, endtime: UTCDateTime): + if self.endtime is not None: + self.endtime = self.endtime.datetime.replace(tzinfo=timezone.utc) + return self.endtime @field_validator("created_time") @classmethod diff --git a/test/db/MetadataDatabaseFactory_test.py b/test/db/MetadataDatabaseFactory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..272304ea38e3f8471a6896640c33a6c760eb0346 --- /dev/null +++ b/test/db/MetadataDatabaseFactory_test.py @@ -0,0 +1,45 @@ +import unittest +from unittest.mock import patch + +from geomagio.api.db import MetadataDatabaseFactory +from geomagio.metadata import Metadata, MetadataCategory +from geomagio.api.db.metadata_table import metadata as metadata_table + +class TestMetadataDatabaseFactoryClass(unittest.IsolatedAsyncioTestCase): + + @patch('databases.Database.connect') + async def test_create_metadata(self, mock_connect_to_db): + mock_connection = mock_connect_to_db.return_value + mock_connection.execute.return_value = {"id": 1} + + test_data = Metadata( + category=MetadataCategory.INSTRUMENT, + created_by="test_metadata.py", + network="NT", + station="BDT", + metadata={ + "type": "FGE", + "channels": { + "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}], + "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}], + "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}], + }, + "electronics": { + "serial": "E0542", + "x-scale": 313.2, + "y-scale": 312.3, + "z-scale": 312.0, + "temperature-scale": 0.01, + }, + "sensor": { + "serial": "S0419", + "x-constant": 36958, + "y-constant": 36849, + "z-constant": 36811, + }, + }, + ) + + await MetadataDatabaseFactory(database=mock_connection).create_metadata(test_data) + + mock_connection.execute.assert_called_once()