diff --git a/geomagio/api/data/DataApiQuery.py b/geomagio/api/data/DataApiQuery.py index 053492932749fa3efb2ba0326d85e434ed1bfc42..a253490fadd35323a5c1e8659651a90f309cc823 100644 --- a/geomagio/api/data/DataApiQuery.py +++ b/geomagio/api/data/DataApiQuery.py @@ -84,31 +84,24 @@ class DataApiQuery(BaseModel): format: OutputFormat @validator("data_type", pre=True, always=True) - def set_data_type(cls, data_type): - return data_type or DataType.VARIATION + def set_and_validate_data_type(cls, data_type): + if not data_type: + return DataType.VARIATION - @validator("elements", pre=True, always=True) - def set_elements(cls, elements): - return elements or DEFAULT_ELEMENTS - - @validator("sampling_period", pre=True, always=True) - def set_sampling_period(cls, sampling_period): - return sampling_period or SamplingPeriod.HOUR - - @validator("format", pre=True, always=True) - def set_format(cls, format): - return format or OutputFormat.IAGA2002 - - @validator("data_type") - def validate_data_type(cls, data_type): if len(data_type) != 2 and data_type not in DataType: raise ValueError( f"Bad data type value '{data_type}'. Valid values are: {', '.join(VALID_DATA_TYPES)}" ) return data_type - @validator("elements") - def validate_elements(cls, elements): + @validator("elements", pre=True, always=True) + def set_and_validate_elements(cls, elements): + if not elements: + return DEFAULT_ELEMENTS + + if len(elements) == 1 and "," in elements[0]: + elements = [e.strip() for e in elements[0].split(",")] + for element in elements: if element not in VALID_ELEMENTS and len(element) != 3: raise ValueError( @@ -117,6 +110,14 @@ class DataApiQuery(BaseModel): ) return elements + @validator("sampling_period", pre=True, always=True) + def set_sampling_period(cls, sampling_period): + return sampling_period or SamplingPeriod.HOUR + + @validator("format", pre=True, always=True) + def set_format(cls, format): + return format or OutputFormat.IAGA2002 + @validator("id") def validate_id(cls, id): if id not in VALID_OBSERVATORIES: diff --git a/geomagio/api/data/data_api.py b/geomagio/api/data/data_api.py index 40ce0a9a55bdd6d599578180eddadbdaf7f9de25..c2be73f4b4f7e08918ae22c2a65af81830023770 100644 --- a/geomagio/api/data/data_api.py +++ b/geomagio/api/data/data_api.py @@ -1,20 +1,20 @@ from datetime import datetime import enum import os -from typing import Any, List, Union +from typing import Any, Dict, List, Union from fastapi import Depends, FastAPI, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.exception_handlers import request_validation_exception_handler from fastapi.responses import JSONResponse -from obspy import UTCDateTime +from obspy import UTCDateTime, Stream from starlette.responses import Response -from .DataApiQuery import DataApiQuery, DataType, OutputFormat, SamplingPeriod from ...edge import EdgeFactory from ...iaga2002 import IAGA2002Writer from ...imfjson import IMFJSONWriter from ...TimeseriesUtility import get_interval_from_delta +from .DataApiQuery import DataApiQuery, DataType, OutputFormat, SamplingPeriod ERROR_CODE_MESSAGES = { @@ -32,7 +32,7 @@ VERSION = "version" def format_error( status_code: int, exception: str, format: str, request: Request -) -> str or dict: +) -> Response: """Assign error_body value based on error format.""" if format == "json": @@ -46,7 +46,7 @@ def format_error( return error -def format_timeseries(timeseries, query: DataApiQuery) -> str: +def format_timeseries(timeseries, query: DataApiQuery) -> Stream: """Formats timeseries into JSON or IAGA data Parameters @@ -145,7 +145,7 @@ Service Version: return error_body -def json_error(code: int, error: Exception, url: str) -> dict: +def json_error(code: int, error: Exception, url: str) -> Dict: """Format json error message. Returns @@ -172,15 +172,11 @@ def parse_query( starttime: datetime = Query(None), endtime: datetime = Query(None), elements: List[str] = Query(None), - sampling_period: SamplingPeriod = Query(None), - data_type: Union[DataType, str] = Query(None), - format: OutputFormat = Query(None), + sampling_period: SamplingPeriod = Query(SamplingPeriod.HOUR), + data_type: Union[DataType, str] = Query(DataType.VARIATION), + format: OutputFormat = Query(OutputFormat.IAGA2002), ) -> DataApiQuery: - if elements != None: - if len(elements) == 1 and "," in elements[0]: - elements = [e.strip() for e in elements[0].split(",")] - if starttime == None: now = datetime.now() starttime = UTCDateTime(year=now.year, month=now.month, day=now.day)