diff --git a/geomagio/API/data_api.py b/geomagio/API/data_api.py index a6aa289f91a56abf005302986d055b5664816962..e320e39846b5dfff3021566dfbd86f01f724a374 100644 --- a/geomagio/API/data_api.py +++ b/geomagio/API/data_api.py @@ -1,18 +1,25 @@ from datetime import datetime import enum -from fastapi import FastAPI, Query, Request from json import dumps -from obspy import UTCDateTime import os -from pydantic import BaseModel -from starlette.responses import Response from typing import List, Any, Union +from fastapi import FastAPI, Query, Request, Depends, HTTPException +from obspy import UTCDateTime +from pydantic import BaseModel, validator, ValidationError, root_validator +from starlette.responses import PlainTextResponse, Response +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from ..edge import EdgeFactory from ..iaga2002 import IAGA2002Writer from ..imfjson import IMFJSONWriter from ..TimeseriesUtility import get_interval_from_delta +from fastapi.exception_handlers import ( + http_exception_handler, + request_validation_exception_handler, +) DEFAULT_DATA_TYPE = "variation" @@ -29,6 +36,7 @@ ERROR_CODE_MESSAGES = { 503: "Service Unavailable", } REQUEST_LIMIT = 345600 +VALID_DATA_TYPES = ["variation", "adjusted", "quasi-definitive", "definitive"] VALID_ELEMENTS = [ "D", "DIST", @@ -78,114 +86,79 @@ VERSION = "version" app = FastAPI(docs_url="/data") -class SamplingPeriod(float, enum.Enum): - TEN_HERTZ = 0.1 - SECOND = 1.0 - MINUTE = 60 - HOUR = 3600 - DAY = 86400 - - class DataType(str, enum.Enum): - Variation = "variation" - Adjusted = "adjusted" - Quasi_Definitive = "quasi-definitive" - Definitive = "definitive" + VARIATION = "variation" + ADJUSTED = "adjusted" + QUASI_DEFINITIVE = "quasi-definitive" + DEFINITIVE = "definitive" class OutputFormat(str, enum.Enum): - Iaga2002 = "iaga2002" + IAGA2002 = "iaga2002" JSON = "json" -class WebServiceException(Exception): - """Base class for exceptions thrown by web services.""" - - pass +class SamplingPeriod(float, enum.Enum): + TEN_HERTZ = 0.1 + SECOND = 1.0 + MINUTE = 60 + HOUR = 3600 + DAY = 86400 class WebServiceQuery(BaseModel): - observatory_id: str + id: str starttime: Any endtime: Any elements: List[str] sampling_period: SamplingPeriod data_type: Union[DataType, str] - output_format: OutputFormat - - -@app.get("/data/") -def get_data( - request: Request, - id: str, - starttime: datetime = Query(None), - endtime: datetime = Query(None), - elements: List[str] = Query(DEFAULT_ELEMENTS), - sampling_period: SamplingPeriod = Query(DEFAULT_SAMPLING_PERIOD), - data_type: Union[DataType, str] = Query(DEFAULT_DATA_TYPE), - format: OutputFormat = Query(DEFAULT_OUTPUT_FORMAT), -): - if len(elements) == 1 and "," in elements[0]: - elements = [e.strip() for e in elements[0].split(",")] - - query = { - "observatory_id": id, - "starttime": starttime, - "endtime": endtime, - "elements": elements, - "sampling_period": sampling_period, - "data_type": data_type, - "output_format": format, - } - - if query["starttime"] == None: - now = datetime.now() - query["starttime"] = UTCDateTime(year=now.year, month=now.month, day=now.day) - - else: - try: - query["starttime"] = UTCDateTime(query["starttime"]) - - except Exception as e: - raise WebServiceException( - f"Bad starttime value '{query['starttime']}'." - " Valid values are ISO-8601 timestamps." - ) from e - - if query["endtime"] == None: - endtime = query["starttime"] + (24 * 60 * 60 - 1) - query["endtime"] = endtime - - else: - try: - query["endtime"] = UTCDateTime(query["endtime"]) - except Exception as e: - raise WebServiceException( - f"Bad endtime value '{query['endtime']}'." - " Valid values are ISO-8601 timestamps." - ) from e - - try: - params = WebServiceQuery(**query) - validate_query(params) - except Exception as e: - return format_error(400, e, format, request) - - try: - timeseries = get_timeseries(params) - return format_timeseries(timeseries, params) - except Exception as e: - return format_error(500, e, format, request) + format: OutputFormat + + @validator("data_type") + def validate_data_type(cls, data_type): + if len(data_type) != 2 and data_type not in DataType: + raise ValueError( + f"Bad data type value '{data_type}'. Valid values are: {', '.join(VALID_DATA_TYPES)}" + ) + + # @root_validator + # def validate_times(cls, values): + # print("**************************************") + # # print(values) + # # print("**************************************") + # # starttime, endtime = values["starttime"], values["endtime"] + # # print(starttime) + # # print(endtime) + # # if starttime > endtime: + # # raise ValueError("Starttime must be before endtime.") + + @validator("elements") + def validate_elements(cls, elements): + for element in elements: + if element not in VALID_ELEMENTS and len(element) != 3: + raise ValueError( + f"Bad element '{element}'." + f"Valid values are: {', '.join(VALID_ELEMENTS)}." + ) + + @validator("id") + def validate_id(cls, id): + if id not in VALID_OBSERVATORIES: + raise ValueError( + f"Bad observatory id '{id}'." + f" Valid values are: {', '.join(VALID_OBSERVATORIES)}." + ) def format_error(status_code, exception, format, request): if format == "json": - error = Response( - json_error(status_code, exception, request), media_type="application/json" - ) + + error = JSONResponse(json_error(status_code, exception, request.url)) + else: error = Response( - iaga2002_error(status_code, exception, request), media_type="text/plain" + iaga2002_error(status_code, exception, request.url), media_type="text/plain" ) return error @@ -207,7 +180,7 @@ def format_timeseries(timeseries, query): unicode IAGA2002 or JSON formatted string. """ - if query.output_format == "json": + if query.format == "json": return Response( IMFJSONWriter.format(timeseries, query.elements), media_type="application/json", @@ -253,26 +226,26 @@ def get_timeseries(query): data_factory = get_data_factory() timeseries = data_factory.get_timeseries( - query.starttime, - query.endtime, - query.observatory_id, - query.elements, - query.data_type, - get_interval_from_delta(query.sampling_period), + starttime=query.starttime, + endtime=query.endtime, + observatory=query.id, + channels=query.elements, + type=query.data_type, + interval=get_interval_from_delta(query.sampling_period), ) return timeseries -def iaga2002_error(status_code, error, request): - status_message = ERROR_CODE_MESSAGES[status_code] - error_body = f"""Error {status_code}: {status_message} +def iaga2002_error(code, error, url): + status_message = ERROR_CODE_MESSAGES[code] + error_body = f"""Error {code}: {status_message} {error} Usage details are available from Request: -{request.url} +{url} Request Submitted: {UTCDateTime().isoformat()}Z @@ -283,7 +256,7 @@ Service Version: return error_body -def json_error(code: int, error: Exception, request): +def json_error(code: int, error: Exception, url): """Format json error message. Returns @@ -292,7 +265,6 @@ def json_error(code: int, error: Exception, request): body of json error message. """ status_message = ERROR_CODE_MESSAGES[code] - url = request.url.__dict__ error_dict = { "type": "Error", "metadata": { @@ -300,33 +272,96 @@ def json_error(code: int, error: Exception, request): "status": code, "error": str(error), "generated": UTCDateTime().isoformat() + "Z", - "url": url, + "url": str(url), }, } - return dumps(error_dict).encode("utf8") + return error_dict + + +def parse_query( + id: str = Query(None), + starttime: Any = Query(None), + endtime: Any = Query(None), + elements: List[str] = Query(DEFAULT_ELEMENTS), + sampling_period: SamplingPeriod = Query(DEFAULT_SAMPLING_PERIOD), + data_type: Union[DataType, str] = Query(DEFAULT_DATA_TYPE), + format: OutputFormat = Query(DEFAULT_OUTPUT_FORMAT), +) -> WebServiceQuery: + + if len(elements) == 0: + elements = DEFAULT_ELEMENTS + if len(elements) == 1 and "," in elements[0]: + elements = [e.strip() for e in elements[0].split(",")] + + if starttime == None: + now = datetime.now() + starttime = UTCDateTime(year=now.year, month=now.month, day=now.day) + + else: + try: + starttime = UTCDateTime(starttime) + + except Exception as e: + raise ValueError( + f"Bad starttime value '{starttime}'." + " Valid values are ISO-8601 timestamps." + ) + + if endtime == None: + endtime = starttime + (24 * 60 * 60 - 1) + + else: + try: + endtime = UTCDateTime(endtime) + except Exception as e: + raise ValueError( + f"Bad endtime value '{endtime}'." + " Valid values are ISO-8601 timestamps." + ) from e + print(type(id)) + params = WebServiceQuery( + id=id, + starttime=starttime, + endtime=endtime, + elements=elements, + sampling_period=sampling_period, + data_type=data_type, + format=format, + ) + print(params) + params.id = id + params.elements = elements + params.data_type = data_type + print("***********") + print(params) + + return params def validate_query(query): - if len(query.data_type) > 2 and query.data_type not in DataType: - raise WebServiceException( - f"Bad data type value '{query.data_type}'." - f" Valid values are: {DataType.Adjusted}, {DataType.Variation}, {DataType.Definitive} and {DataType.Quasi_Definitive}." - ) - if query.observatory_id not in VALID_OBSERVATORIES: - raise WebServiceException( - f"Bad observatory id '{query.observatory_id}'." - f" Valid values are: {', '.join(VALID_OBSERVATORIES)}." - ) + # validate combinations - if len(query.elements) > 4 and query.output_format == "iaga2002": - raise WebServiceException( - "No more than four elements allowed for iaga2002 format." - ) - if query.starttime > query.endtime: - raise WebServiceException("Starttime must be before endtime.") + if len(query.elements) > 4 and query.format == "iaga2002": + raise ValueError("No more than four elements allowed for iaga2002 format.") + # check data volume samples = int( len(query.elements) * (query.endtime - query.starttime) / query.sampling_period ) if samples > REQUEST_LIMIT: - raise WebServiceException(f"Query exceeds request limit ({samples} > 345600)") + raise ValueError(f"Query exceeds request limit ({samples} > 345600)") + + +@app.exception_handler(ValueError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + data_format = str(request.query_params["format"]) + return format_error(400, str(exc), data_format, request) + + +@app.get("/data/") +def get_data(request: Request, query: WebServiceQuery = Depends(parse_query)): + try: + timeseries = get_timeseries(query) + return format_timeseries(timeseries, query) + except Exception as e: + return format_error(500, e, query.format, request)