diff --git a/geomagio/api/data/DataApiQuery.py b/geomagio/api/data/DataApiQuery.py index a1fe0822f58f23f24c6977405ec818ad456b4854..053492932749fa3efb2ba0326d85e434ed1bfc42 100644 --- a/geomagio/api/data/DataApiQuery.py +++ b/geomagio/api/data/DataApiQuery.py @@ -1,8 +1,11 @@ +from datetime import datetime import enum from typing import Any, List, Union from pydantic import BaseModel, root_validator, validator + +DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"] REQUEST_LIMIT = 345600 VALID_DATA_TYPES = ["variation", "adjusted", "quasi-definitive", "definitive"] VALID_ELEMENTS = [ @@ -73,13 +76,29 @@ class SamplingPeriod(float, enum.Enum): class DataApiQuery(BaseModel): id: str - starttime: Any - endtime: Any + starttime: datetime + endtime: datetime elements: List[str] sampling_period: SamplingPeriod data_type: Union[DataType, str] format: OutputFormat + @validator("data_type", pre=True, always=True) + def set_data_type(cls, data_type): + return data_type or DataType.VARIATION + + @validator("elements", pre=True, always=True) + def set_elements(cls, elements): + return elements or DEFAULT_ELEMENTS + + @validator("sampling_period", pre=True, always=True) + def set_sampling_period(cls, sampling_period): + return sampling_period or SamplingPeriod.HOUR + + @validator("format", pre=True, always=True) + def set_format(cls, format): + return format or OutputFormat.IAGA2002 + @validator("data_type") def validate_data_type(cls, data_type): if len(data_type) != 2 and data_type not in DataType: @@ -122,7 +141,9 @@ class DataApiQuery(BaseModel): if len(elements) > 4 and format == "iaga2002": raise ValueError("No more than four elements allowed for iaga2002 format.") - samples = int(len(elements) * (endtime - starttime) / sampling_period) + samples = int( + len(elements) * (endtime - starttime).total_seconds() / sampling_period + ) # check data volume if samples > REQUEST_LIMIT: raise ValueError(f"Query exceeds request limit ({samples} > 345600)") diff --git a/geomagio/api/data/app.py b/geomagio/api/data/app.py new file mode 100644 index 0000000000000000000000000000000000000000..3b69c7b8f0a6755b68b596482fbb78919b0142de --- /dev/null +++ b/geomagio/api/data/app.py @@ -0,0 +1 @@ +from . import data_api diff --git a/geomagio/api/data/data_api.py b/geomagio/api/data/data_api.py index b2226254ca560d46b4d5309fe0a4473e74910ca4..40ce0a9a55bdd6d599578180eddadbdaf7f9de25 100644 --- a/geomagio/api/data/data_api.py +++ b/geomagio/api/data/data_api.py @@ -4,20 +4,19 @@ import os from typing import Any, List, Union from fastapi import Depends, FastAPI, Query, Request -from obspy import UTCDateTime from fastapi.exceptions import RequestValidationError from fastapi.exception_handlers import request_validation_exception_handler from fastapi.responses import JSONResponse +from obspy import UTCDateTime from starlette.responses import Response +from .DataApiQuery import DataApiQuery, DataType, OutputFormat, SamplingPeriod from ...edge import EdgeFactory from ...iaga2002 import IAGA2002Writer from ...imfjson import IMFJSONWriter from ...TimeseriesUtility import get_interval_from_delta -from .DataApiQuery import * -DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"] ERROR_CODE_MESSAGES = { 204: "No Data", 400: "Bad Request", @@ -30,10 +29,11 @@ ERROR_CODE_MESSAGES = { VERSION = "version" -app = FastAPI(docs_url="/data") - -def format_error(status_code, exception, format, request): +def format_error( + status_code: int, exception: str, format: str, request: Request +) -> str or dict: + """Assign error_body value based on error format.""" if format == "json": error = JSONResponse(json_error(status_code, exception, request.url)) @@ -46,7 +46,7 @@ def format_error(status_code, exception, format, request): return error -def format_timeseries(timeseries, query): +def format_timeseries(timeseries, query: DataApiQuery) -> str: """Formats timeseries into JSON or IAGA data Parameters @@ -54,7 +54,7 @@ def format_timeseries(timeseries, query): obspy.core.Stream timeseries object with requested data - WebServiceQuery + DataApiQuery parsed query object Returns @@ -92,12 +92,12 @@ def get_data_factory(): return None -def get_timeseries(query): +def get_timeseries(query: DataApiQuery): """Get timeseries data Parameters ---------- - WebServiceQuery + DataApiQuery parsed query object Returns @@ -108,8 +108,8 @@ def get_timeseries(query): data_factory = get_data_factory() timeseries = data_factory.get_timeseries( - starttime=query.starttime, - endtime=query.endtime, + starttime=UTCDateTime(query.starttime), + endtime=UTCDateTime(query.endtime), observatory=query.id, channels=query.elements, type=query.data_type, @@ -118,7 +118,14 @@ def get_timeseries(query): return timeseries -def iaga2002_error(code, error, url): +def iaga2002_error(code: int, error: Exception, url: str) -> str: + """Format iaga2002 error message. + + Returns + ------- + error_body : str + body of iaga2002 error message. + """ status_message = ERROR_CODE_MESSAGES[code] error_body = f"""Error {code}: {status_message} @@ -138,7 +145,7 @@ Service Version: return error_body -def json_error(code: int, error: Exception, url): +def json_error(code: int, error: Exception, url: str) -> dict: """Format json error message. Returns @@ -162,18 +169,17 @@ def json_error(code: int, error: Exception, url): def parse_query( id: str, - starttime: Any = Query(None), - endtime: Any = Query(None), - elements: List[str] = Query(DEFAULT_ELEMENTS), - sampling_period: SamplingPeriod = Query(SamplingPeriod.HOUR), - data_type: Union[DataType, str] = Query(DataType.VARIATION), - format: OutputFormat = Query(OutputFormat.IAGA2002), + starttime: datetime = Query(None), + endtime: datetime = Query(None), + elements: List[str] = Query(None), + sampling_period: SamplingPeriod = Query(None), + data_type: Union[DataType, str] = Query(None), + format: OutputFormat = Query(None), ) -> DataApiQuery: - if len(elements) == 0: - elements = DEFAULT_ELEMENTS - if len(elements) == 1 and "," in elements[0]: - elements = [e.strip() for e in elements[0].split(",")] + if elements != None: + if len(elements) == 1 and "," in elements[0]: + elements = [e.strip() for e in elements[0].split(",")] if starttime == None: now = datetime.now() @@ -213,9 +219,15 @@ def parse_query( return params +app = FastAPI(docs_url="/data") + + @app.exception_handler(ValueError) async def validation_exception_handler(request: Request, exc: RequestValidationError): - data_format = str(request.query_params["format"]) + if "format" in request.query_params: + data_format = str(request.query_params["format"]) + else: + data_format = "iaga2002" return format_error(400, str(exc), data_format, request)