From 7319e1594ad4d80659af100d6beb5057a8c4f2f9 Mon Sep 17 00:00:00 2001
From: Alexandra Hobbs <ahobbs@contractor.usgs.gov>
Date: Mon, 25 Nov 2024 10:41:21 -0700
Subject: [PATCH] working tests and locally

---
 geomagio/api/secure/metadata.py         |   4 +-
 geomagio/metadata/Metadata.py           |  38 ++-
 geomagio/pydantic_utcdatetime.py        |   6 +
 test/db/MetadataDatabaseFactory_test.py | 409 +++++++++++++++---------
 test/pydantic_utcdatetime_test.py       |   8 +-
 5 files changed, 290 insertions(+), 175 deletions(-)

diff --git a/geomagio/api/secure/metadata.py b/geomagio/api/secure/metadata.py
index 235ca442..404b262b 100644
--- a/geomagio/api/secure/metadata.py
+++ b/geomagio/api/secure/metadata.py
@@ -72,7 +72,9 @@ async def create_metadata(
     metadata = await MetadataDatabaseFactory(database=database).create_metadata(
         meta=metadata
     )
-    return Response(metadata.json(), status_code=201, media_type="application/json")
+    return Response(
+        metadata.model_dump_json(), status_code=201, media_type="application/json"
+    )
 
 
 @router.get(
diff --git a/geomagio/metadata/Metadata.py b/geomagio/metadata/Metadata.py
index 53f605f0..e1a7bd18 100644
--- a/geomagio/metadata/Metadata.py
+++ b/geomagio/metadata/Metadata.py
@@ -2,7 +2,7 @@ from datetime import timezone
 from typing import Dict, Optional
 
 from obspy import UTCDateTime
-from pydantic import field_validator, field_serializer, BaseModel, Field
+from pydantic import field_serializer, BaseModel, Field
 
 from .MetadataCategory import MetadataCategory
 from ..pydantic_utcdatetime import CustomUTCDateTimeType
@@ -52,9 +52,7 @@ class Metadata(BaseModel):
     metadata_id: Optional[int] = None
     # author
     created_by: Optional[str] = None
-    created_time: CustomUTCDateTimeType = Field(
-        default_factory=lambda: UTCDateTime()
-    )
+    created_time: CustomUTCDateTimeType = Field(default_factory=lambda: UTCDateTime())
     # editor
     updated_by: Optional[str] = None
     updated_time: Optional[CustomUTCDateTimeType] = None
@@ -81,28 +79,28 @@ class Metadata(BaseModel):
     # metadata status indicator
     status: Optional[str] = None
 
-    # serialize pydantic CustomUTCDateTimeType type into UTCDateTime for model_dump() and make sure
-    # the timezone is in utc
+    # instructions for model_dump() to serialize pydantic CustomUTCDateTimeType into aware datetime.datetime type
+    # sqlalchemy is expecting aware datetime.datetime, not the string model_dump() creates by default
     @field_serializer("created_time")
     def serialize_created_time(self, created_time: UTCDateTime):
-        if self.created_time is not None:
-            self.created_time = self.created_time.datetime.replace(tzinfo=timezone.utc)
-        return self.created_time
+        if created_time is not None:
+            created_time = created_time.datetime.replace(tzinfo=timezone.utc)
+        return created_time
 
     @field_serializer("updated_time")
     def serialize_updated_time(self, updated_time: UTCDateTime):
-        if self.updated_time is not None:
-            self.updated_time = self.updated_time.datetime.replace(tzinfo=timezone.utc)
-        return self.updated_time
-    
+        if updated_time is not None:
+            updated_time = updated_time.datetime.replace(tzinfo=timezone.utc)
+        return updated_time
+
     @field_serializer("starttime")
     def serialize_starttime(self, starttime: UTCDateTime):
-        if self.starttime is not None:
-            self.starttime = self.starttime.datetime.replace(tzinfo=timezone.utc)
-        return self.starttime
-    
+        if starttime is not None:
+            starttime = starttime.datetime.replace(tzinfo=timezone.utc)
+        return starttime
+
     @field_serializer("endtime")
     def serialize_endtime(self, endtime: UTCDateTime):
-        if self.endtime is not None:
-            self.endtime = self.endtime.datetime.replace(tzinfo=timezone.utc)
-        return self.endtime
+        if endtime is not None:
+            endtime = endtime.datetime.replace(tzinfo=timezone.utc)
+        return endtime
diff --git a/geomagio/pydantic_utcdatetime.py b/geomagio/pydantic_utcdatetime.py
index d60950de..862cd138 100644
--- a/geomagio/pydantic_utcdatetime.py
+++ b/geomagio/pydantic_utcdatetime.py
@@ -2,6 +2,8 @@
 CustomUTCDateTimeType should be used in place of UTCDateTime on pydantic models.
 """
 
+import datetime
+from dateutil import tz
 from obspy import UTCDateTime
 from pydantic_core import CoreSchema, core_schema
 from typing import Annotated, Any
@@ -22,6 +24,10 @@ class CustomUTCDateTimeValidator:
         _handler: GetCoreSchemaHandler,
     ) -> CoreSchema:
         def UTCDateTime_validator(value: Any):
+            # if the user inputs an unaware datetime.datetime, make it aware
+            if isinstance(value, datetime.datetime):
+                if value.tzinfo is not tz.tzutc():
+                    value = value.replace(tzinfo=tz.tzutc())
             try:
                 time = UTCDateTime(value)
             except:
diff --git a/test/db/MetadataDatabaseFactory_test.py b/test/db/MetadataDatabaseFactory_test.py
index d7705001..045d297e 100644
--- a/test/db/MetadataDatabaseFactory_test.py
+++ b/test/db/MetadataDatabaseFactory_test.py
@@ -1,5 +1,6 @@
 import datetime
 import unittest
+from dateutil import tz
 from unittest.mock import AsyncMock, patch
 from databases import Database
 
@@ -8,11 +9,14 @@ from obspy import UTCDateTime
 from geomagio.api.db import MetadataDatabaseFactory
 from geomagio.metadata import Metadata, MetadataCategory
 
+
 class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
 
     @patch("databases.Database.execute", new_callable=AsyncMock)
     async def test_create_metadata_defaults(self, mock_execute):
+        now = UTCDateTime()
         test_data = Metadata(
+            created_time=now,
             category=MetadataCategory.INSTRUMENT,
             created_by="test_metadata.py",
             network="NT",
@@ -46,6 +50,101 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
 
         # assert data_valid, priority, and status are set to the correct defaults
         expected_values = {
+            "created_time": datetime.datetime(
+                year=now.year,
+                month=now.month,
+                day=now.day,
+                hour=now.hour,
+                minute=now.minute,
+                second=now.second,
+                microsecond=now.microsecond,
+                tzinfo=tz.tzutc(),
+            ),
+            "category": "instrument",
+            "created_by": "test_metadata.py",
+            "network": "NT",
+            "station": "BDT",
+            "metadata": {
+                "type": "FGE",
+                "channels": {
+                    "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
+                    "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
+                    "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
+                },
+                "electronics": {
+                    "serial": "E0542",
+                    "x-scale": 313.2,
+                    "y-scale": 312.3,
+                    "z-scale": 312.0,
+                    "temperature-scale": 0.01,
+                },
+                "sensor": {
+                    "serial": "S0419",
+                    "x-constant": 36958,
+                    "y-constant": 36849,
+                    "z-constant": 36811,
+                },
+            },
+            "data_valid": True,
+            "priority": 1,
+            "status": "new",
+        }
+
+        mock_execute.assert_called_once()
+        called_params = mock_execute.call_args.args[0].compile().params
+
+        assert called_params == expected_values
+
+    @patch("databases.Database.execute", new_callable=AsyncMock)
+    async def test_create_metadata_with_ids(self, mock_execute):
+        now = UTCDateTime()
+        test_data = Metadata(
+            id=1234,
+            created_time=now,
+            metadata_id=5678,
+            category=MetadataCategory.INSTRUMENT,
+            created_by="test_metadata.py",
+            network="NT",
+            station="BDT",
+            metadata={
+                "type": "FGE",
+                "channels": {
+                    "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
+                    "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
+                    "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
+                },
+                "electronics": {
+                    "serial": "E0542",
+                    "x-scale": 313.2,
+                    "y-scale": 312.3,
+                    "z-scale": 312.0,
+                    "temperature-scale": 0.01,
+                },
+                "sensor": {
+                    "serial": "S0419",
+                    "x-constant": 36958,
+                    "y-constant": 36849,
+                    "z-constant": 36811,
+                },
+            },
+        )
+
+        db = Database("sqlite:///:memory:")
+
+        await MetadataDatabaseFactory(database=db).create_metadata(test_data)
+
+        # assert id and metadata_id are removed
+        expected_values = {
+            "created_time": datetime.datetime(
+                year=now.year,
+                month=now.month,
+                day=now.day,
+                hour=now.hour,
+                minute=now.minute,
+                second=now.second,
+                microsecond=now.microsecond,
+                tzinfo=tz.tzutc(),
+            ),
             "category": "instrument",
             "created_by": "test_metadata.py",
             "network": "NT",
@@ -73,7 +172,7 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
             },
             "data_valid": True,
             "priority": 1,
-            "status": "new"
+            "status": "new",
         }
 
         mock_execute.assert_called_once()
@@ -81,155 +180,159 @@ class TestMetadataDatabaseFactory(unittest.IsolatedAsyncioTestCase):
 
         assert called_params == expected_values
 
-    # @patch("databases.Database.execute", new_callable=AsyncMock)
-    # async def test_create_metadata_with_ids(self, mock_execute):
-    #     now = datetime.datetime.now(tz=datetime.timezone.utc)
-    #     test_data = Metadata(
-    #         id=1234,
-    #         created_time=now,
-    #         metadata_id=5678,
-    #         category=MetadataCategory.INSTRUMENT,
-    #         created_by="test_metadata.py",
-    #         network="NT",
-    #         station="BDT",
-    #         metadata={
-    #             "type": "FGE",
-    #             "channels": {
-    #                 "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
-    #                 "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
-    #                 "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
-    #             },
-    #             "electronics": {
-    #                 "serial": "E0542",
-    #                 "x-scale": 313.2,
-    #                 "y-scale": 312.3,
-    #                 "z-scale": 312.0,
-    #                 "temperature-scale": 0.01,
-    #             },
-    #             "sensor": {
-    #                 "serial": "S0419",
-    #                 "x-constant": 36958,
-    #                 "y-constant": 36849,
-    #                 "z-constant": 36811,
-    #             },
-    #         },
-    #     )
-
-    #     db = Database("sqlite:///:memory:")
-
-    #     await MetadataDatabaseFactory(database=db).create_metadata(test_data)
-
-    #     # assert id is removed
-    #     expected_values = {
-    #         "created_time": now,
-    #         "category": "instrument",
-    #         "created_by": "test_metadata.py",
-    #         "network": "NT",
-    #         "station": "BDT",
-    #         "metadata": {
-    #             "type": "FGE",
-    #             "channels": {
-    #                 "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
-    #                 "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
-    #                 "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
-    #             },
-    #             "electronics": {
-    #                 "serial": "E0542",
-    #                 "x-scale": 313.2,
-    #                 "y-scale": 312.3,
-    #                 "z-scale": 312.0,
-    #                 "temperature-scale": 0.01,
-    #             },
-    #             "sensor": {
-    #                 "serial": "S0419",
-    #                 "x-constant": 36958,
-    #                 "y-constant": 36849,
-    #                 "z-constant": 36811,
-    #             },
-    #         },
-    #         "data_valid": True,
-    #         "priority": 1,
-    #         "status": "new"
-    #     }
-
-    #     mock_execute.assert_called_once()
-    #     called_params = mock_execute.call_args.args[0].compile().params
-
-    #     assert called_params == expected_values
-
-    # @patch("databases.Database.execute", new_callable=AsyncMock)
-    # async def test_create_metadata_with_metadata_id(self, mock_execute):
-    #     test_data = Metadata(
-    #         metadata_id=5678,
-    #         category=MetadataCategory.INSTRUMENT,
-    #         created_by="test_metadata.py",
-    #         network="NT",
-    #         station="BDT",
-    #         metadata={
-    #             "type": "FGE",
-    #             "channels": {
-    #                 "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
-    #                 "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
-    #                 "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
-    #             },
-    #             "electronics": {
-    #                 "serial": "E0542",
-    #                 "x-scale": 313.2,
-    #                 "y-scale": 312.3,
-    #                 "z-scale": 312.0,
-    #                 "temperature-scale": 0.01,
-    #             },
-    #             "sensor": {
-    #                 "serial": "S0419",
-    #                 "x-constant": 36958,
-    #                 "y-constant": 36849,
-    #                 "z-constant": 36811,
-    #             },
-    #         },
-    #     )
-
-    #     db = Database("sqlite:///:memory:")
-
-    #     await MetadataDatabaseFactory(database=db).create_metadata(test_data)
-
-    #     # assert metadata_id is removed on values
-    #     expected_values = {
-    #         "metadata_id": 5678,
-    #         "category": "instrument",
-    #         "created_by": "test_metadata.py",
-    #         "network": "NT",
-    #         "station": "BDT",
-    #         "metadata": {
-    #             "type": "FGE",
-    #             "channels": {
-    #                 "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
-    #                 "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
-    #                 "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
-    #             },
-    #             "electronics": {
-    #                 "serial": "E0542",
-    #                 "x-scale": 313.2,
-    #                 "y-scale": 312.3,
-    #                 "z-scale": 312.0,
-    #                 "temperature-scale": 0.01,
-    #             },
-    #             "sensor": {
-    #                 "serial": "S0419",
-    #                 "x-constant": 36958,
-    #                 "y-constant": 36849,
-    #                 "z-constant": 36811,
-    #             },
-    #         },
-    #         "data_valid": True,
-    #         "priority": 1,
-    #         "status": "new"
-    #     }
-
-    #     expected_insert = metadata_table.insert().values(**expected_values)
-    #     expected_params = expected_insert.compile().params
-
-
-    #     mock_execute.assert_called_once()
-    #     called_params = mock_execute.call_args.args[0].compile().params
-
-    #     assert called_params == expected_params
\ No newline at end of file
+    @patch("databases.Database.execute", new_callable=AsyncMock)
+    async def test_create_metadata_without_created_time(self, mock_execute):
+        test_data = Metadata(
+            metadata_id=5678,
+            category=MetadataCategory.INSTRUMENT,
+            created_by="test_metadata.py",
+            network="NT",
+            station="BDT",
+            metadata={
+                "type": "FGE",
+                "channels": {
+                    "U": [{"channel": "U_Volt", "offset": 0, "scale": 313.2}],
+                    "V": [{"channel": "V_Volt", "offset": 0, "scale": 312.3}],
+                    "W": [{"channel": "W_Volt", "offset": 0, "scale": 312.0}],
+                },
+                "electronics": {
+                    "serial": "E0542",
+                    "x-scale": 313.2,
+                    "y-scale": 312.3,
+                    "z-scale": 312.0,
+                    "temperature-scale": 0.01,
+                },
+                "sensor": {
+                    "serial": "S0419",
+                    "x-constant": 36958,
+                    "y-constant": 36849,
+                    "z-constant": 36811,
+                },
+            },
+        )
+
+        db = Database("sqlite:///:memory:")
+
+        await MetadataDatabaseFactory(database=db).create_metadata(test_data)
+
+        mock_execute.assert_called_once()
+        called_params = mock_execute.call_args.args[0].compile().params
+
+        assert called_params["created_time"] is not None
+
+    @patch("databases.Database.execute", new_callable=AsyncMock)
+    async def test_create_metadata_with_starttime_and_endtime(self, mock_execute):
+        now = UTCDateTime()
+        t = UTCDateTime(2020, 1, 3, 17, 24, 40)
+        test_data = Metadata(
+            created_by="test_metadata.py",
+            created_time=now,
+            starttime=t,
+            endtime=t,
+            network="NT",
+            station="BOU",
+            channel=None,
+            location=None,
+            category=MetadataCategory.READING,
+            priority=1,
+            data_valid=True,
+            metadata={},
+        )
+
+        db = Database("sqlite:///:memory:")
+
+        await MetadataDatabaseFactory(database=db).create_metadata(test_data)
+
+        # assert starttime and endtime are strings of expected UTCDateTime
+        expected_values = {
+            "category": "reading",
+            "created_time": datetime.datetime(
+                year=now.year,
+                month=now.month,
+                day=now.day,
+                hour=now.hour,
+                minute=now.minute,
+                second=now.second,
+                microsecond=now.microsecond,
+                tzinfo=tz.tzutc(),
+            ),
+            "created_by": "test_metadata.py",
+            "starttime": datetime.datetime(
+                year=t.year,
+                month=t.month,
+                day=t.day,
+                hour=t.hour,
+                minute=t.minute,
+                second=t.second,
+                microsecond=t.microsecond,
+                tzinfo=tz.tzutc(),
+            ),
+            "endtime": datetime.datetime(
+                year=t.year,
+                month=t.month,
+                day=t.day,
+                hour=t.hour,
+                minute=t.minute,
+                second=t.second,
+                microsecond=t.microsecond,
+                tzinfo=tz.tzutc(),
+            ),
+            "network": "NT",
+            "station": "BOU",
+            "metadata": {},
+            "data_valid": True,
+            "priority": 1,
+            "status": "new",
+        }
+
+        mock_execute.assert_called_once()
+        called_params = mock_execute.call_args.args[0].compile().params
+
+        assert called_params == expected_values
+
+    @patch("databases.Database.execute", new_callable=AsyncMock)
+    async def test_create_metadata_with_times_as_datetime(self, mock_execute):
+        # assert datetime is aware if not explicitly set by the user
+        s = datetime.datetime(2020, 1, 3, 17, 24, 40)
+        e = datetime.datetime(2020, 1, 3, 17, 24, 40, tzinfo=tz.tzutc())
+        test_data = Metadata(
+            created_by="test_metadata.py",
+            starttime=s,
+            endtime=e,
+            network="NT",
+            station="BOU",
+            channel=None,
+            location=None,
+            category=MetadataCategory.READING,
+            priority=1,
+            data_valid=True,
+            metadata={},
+        )
+
+        db = Database("sqlite:///:memory:")
+
+        await MetadataDatabaseFactory(database=db).create_metadata(test_data)
+
+        mock_execute.assert_called_once()
+        called_params = mock_execute.call_args.args[0].compile().params
+
+        assert called_params["starttime"] == datetime.datetime(
+            year=s.year,
+            month=s.month,
+            day=s.day,
+            hour=s.hour,
+            minute=s.minute,
+            second=s.second,
+            microsecond=s.microsecond,
+            tzinfo=tz.tzutc(),
+        )
+        assert called_params["endtime"] == datetime.datetime(
+            year=e.year,
+            month=e.month,
+            day=e.day,
+            hour=e.hour,
+            minute=e.minute,
+            second=e.second,
+            microsecond=e.microsecond,
+            tzinfo=tz.tzutc(),
+        )
diff --git a/test/pydantic_utcdatetime_test.py b/test/pydantic_utcdatetime_test.py
index 120bbabc..ef81615a 100644
--- a/test/pydantic_utcdatetime_test.py
+++ b/test/pydantic_utcdatetime_test.py
@@ -17,12 +17,18 @@ def test_UTCDateTime_string():
     assert_equal(t.starttime, UTCDateTime(2024, 11, 5, 0, 0))
 
 
-def test_UTCDateTime_timestamp():
+def test_UTCDateTime_datetime():
     t = TimeClass(starttime=datetime.datetime(2024, 11, 5, tzinfo=tz.tzutc()))
 
     assert_equal(t.starttime, UTCDateTime(2024, 11, 5, 0, 0))
 
 
+def test_UTCDateTime_datetime_unaware():
+    t = TimeClass(starttime=datetime.datetime(2024, 11, 5))
+
+    assert_equal(t.starttime, UTCDateTime(2024, 11, 5, 0, 0))
+
+
 def test_UTCDateTime_unix_timestamp():
     t = TimeClass(starttime=1730764800)
 
-- 
GitLab