diff --git a/geomagio/api/ws/DataApiQuery.py b/geomagio/api/ws/DataApiQuery.py index dc5dcb8ec37612f4e52d28e877bf0d6443756e3e..cba5d25f166a0c63cf8527521fa2833ac7808490 100644 --- a/geomagio/api/ws/DataApiQuery.py +++ b/geomagio/api/ws/DataApiQuery.py @@ -65,8 +65,9 @@ class DataApiQuery(BaseModel): elements: List[str] = DEFAULT_ELEMENTS sampling_period: SamplingPeriod = SamplingPeriod.MINUTE data_type: Union[DataType, str] = DataType.VARIATION - dbdt: List[str] = [] format: OutputFormat = OutputFormat.IAGA2002 + # extensions + dbdt: List[str] = [] @validator("data_type") def validate_data_type( @@ -79,6 +80,15 @@ class DataApiQuery(BaseModel): ) return data_type + @validator("dbdt", pre=True, always=True) + def validate_dbdt(cls, dbdt: List[str]) -> List[str]: + if not dbdt: + return [] + if len(dbdt) == 1 and "," in dbdt[0]: + dbdt = [e.strip() for e in dbdt[0].split(",")] + # values are validated below in validate_combinations + return dbdt + @validator("elements", pre=True, always=True) def validate_elements(cls, elements: List[str]) -> List[str]: if not elements: @@ -124,24 +134,15 @@ class DataApiQuery(BaseModel): endtime = starttime + (86400 - 0.001) return endtime - @validator("dbdt", always=True) - def validate_dbdt(cls, dbdt: List[str],) -> List[str]: - """Default dbdt based on valid elements. - """ - for channel in dbdt: - if channel not in ELEMENTS: - raise ValueError("Specified channel not found in valid elements.") - - return dbdt - @root_validator def validate_combinations(cls, values): - starttime, endtime, elements, format, sampling_period = ( + starttime, endtime, elements, format, sampling_period, dbdt = ( values.get("starttime"), values.get("endtime"), values.get("elements"), values.get("format"), values.get("sampling_period"), + values.get("dbdt"), ) if len(elements) > 4 and format == "iaga2002": raise ValueError("No more than four elements allowed for iaga2002 format.") @@ -151,5 +152,9 @@ class DataApiQuery(BaseModel): samples = int(len(elements) * (endtime - starttime) / sampling_period) if samples > REQUEST_LIMIT: raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})") + # check dbdt + for element in dbdt: + if element not in elements: + raise ValueError(f"dBdT element {element} not in {elements}") # otherwise okay return values