From 81d90bd9e6ede40e89a5982b0cad7c23e033a878 Mon Sep 17 00:00:00 2001
From: Alex Wernle <awernle@usgs.gov>
Date: Tue, 26 Nov 2024 11:26:11 -0700
Subject: [PATCH] Pydantic changes

---
 geomagio/metadata/flag/Flag.py | 39 ++++++++++++----------------------
 1 file changed, 14 insertions(+), 25 deletions(-)

diff --git a/geomagio/metadata/flag/Flag.py b/geomagio/metadata/flag/Flag.py
index b9d4c98e..1fe5e194 100644
--- a/geomagio/metadata/flag/Flag.py
+++ b/geomagio/metadata/flag/Flag.py
@@ -2,9 +2,11 @@ from typing import Dict, Union, List, Optional
 from datetime import timedelta
 
 from obspy import UTCDateTime
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 from enum import Enum
 
+from ...pydantic_utcdatetime import CustomUTCDateTimeType
+
 
 class FlagCategory(str, Enum):
     ARTIFICIAL_DISTURBANCE = "ARTIFICIAL_DISTURBANCE"
@@ -68,27 +70,20 @@ class ArtificialDisturbance(Flag):
         The type of artificial disturbance(s).
     deviation: float
        Deviation of an offset in nt.
-    spikes: List[UTCDateTime]
+    spikes: List[CustomUTCDateTimeType]
         Array of timestamps as UTCDateTime. Can be a single spike or many spikes.
 
     """
 
     artificial_disturbance_type: ArtificialDisturbanceType
     deviation: Optional[float] = None
-    spikes: Optional[List[UTCDateTime]] = None
+    spikes: Optional[List[CustomUTCDateTimeType]] = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.flag_category = "ARTIFICIAL_DISTURBANCE"
-
-    @validator("spikes", always=True)
-    def validate_spikes(cls, spikes):
-        if len(spikes) == 0:
-            raise ValueError("Spikes list cannot be empty.")
-
-        return spikes
+        self.flag_category = FlagCategory.ARTIFICIAL_DISTURBANCE
 
-    @validator("spikes", always=True)
+    @field_validator("spikes")
     def check_spikes_duration(cls, spikes):
         if spikes is None or len(spikes) < 2:
             return spikes
@@ -131,12 +126,12 @@ class Gap(Flag):
         How the gap is being handled, e.g., backfilled.
     """
 
-    cause: str = None
-    handling: str = None
+    cause: Optional[str] = None
+    handling: Optional[str] = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.flag_category = "GAP"
+        self.flag_category = FlagCategory.GAP
 
 
 class EventType(str, Enum):
@@ -163,13 +158,13 @@ class Event(Flag):
     """
 
     event_type: EventType
-    index: int = None
-    scale: str = None
-    url: str = None
+    index: Optional[int] = None
+    scale: Optional[str] = None
+    url: Optional[str] = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.flag_category = "EVENT"
+        self.flag_category = FlagCategory.EVENT
 
 
 # More example usage:
@@ -199,9 +194,3 @@ geomagnetic_storm_data = {
     "index": 7,
     "url": "https://www.swpc.noaa.gov/products/planetary-k-index",
 }
-
-spike_instance = ArtificialDisturbance(**spikes_data)
-offset_instance = ArtificialDisturbance(**offset_data)
-
-print(spike_instance.model_dump())
-print(offset_instance.model_dump())
-- 
GitLab