Skip to content
Snippets Groups Projects
Commit 8b8d730e authored by Rivers, Travis (Contractor) Creighton's avatar Rivers, Travis (Contractor) Creighton Committed by Jeremy M Fee
Browse files

combine parameter setters and validators. clean up

parent b49ff06a
No related branches found
No related tags found
No related merge requests found
......@@ -84,31 +84,24 @@ class DataApiQuery(BaseModel):
format: OutputFormat
@validator("data_type", pre=True, always=True)
def set_data_type(cls, data_type):
return data_type or DataType.VARIATION
def set_and_validate_data_type(cls, data_type):
if not data_type:
return DataType.VARIATION
@validator("elements", pre=True, always=True)
def set_elements(cls, elements):
return elements or DEFAULT_ELEMENTS
@validator("sampling_period", pre=True, always=True)
def set_sampling_period(cls, sampling_period):
return sampling_period or SamplingPeriod.HOUR
@validator("format", pre=True, always=True)
def set_format(cls, format):
return format or OutputFormat.IAGA2002
@validator("data_type")
def validate_data_type(cls, data_type):
if len(data_type) != 2 and data_type not in DataType:
raise ValueError(
f"Bad data type value '{data_type}'. Valid values are: {', '.join(VALID_DATA_TYPES)}"
)
return data_type
@validator("elements")
def validate_elements(cls, elements):
@validator("elements", pre=True, always=True)
def set_and_validate_elements(cls, elements):
if not elements:
return DEFAULT_ELEMENTS
if len(elements) == 1 and "," in elements[0]:
elements = [e.strip() for e in elements[0].split(",")]
for element in elements:
if element not in VALID_ELEMENTS and len(element) != 3:
raise ValueError(
......@@ -117,6 +110,14 @@ class DataApiQuery(BaseModel):
)
return elements
@validator("sampling_period", pre=True, always=True)
def set_sampling_period(cls, sampling_period):
return sampling_period or SamplingPeriod.HOUR
@validator("format", pre=True, always=True)
def set_format(cls, format):
return format or OutputFormat.IAGA2002
@validator("id")
def validate_id(cls, id):
if id not in VALID_OBSERVATORIES:
......
from datetime import datetime
import enum
import os
from typing import Any, List, Union
from typing import Any, Dict, List, Union
from fastapi import Depends, FastAPI, Query, Request
from fastapi.exceptions import RequestValidationError
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.responses import JSONResponse
from obspy import UTCDateTime
from obspy import UTCDateTime, Stream
from starlette.responses import Response
from .DataApiQuery import DataApiQuery, DataType, OutputFormat, SamplingPeriod
from ...edge import EdgeFactory
from ...iaga2002 import IAGA2002Writer
from ...imfjson import IMFJSONWriter
from ...TimeseriesUtility import get_interval_from_delta
from .DataApiQuery import DataApiQuery, DataType, OutputFormat, SamplingPeriod
ERROR_CODE_MESSAGES = {
......@@ -32,7 +32,7 @@ VERSION = "version"
def format_error(
status_code: int, exception: str, format: str, request: Request
) -> str or dict:
) -> Response:
"""Assign error_body value based on error format."""
if format == "json":
......@@ -46,7 +46,7 @@ def format_error(
return error
def format_timeseries(timeseries, query: DataApiQuery) -> str:
def format_timeseries(timeseries, query: DataApiQuery) -> Stream:
"""Formats timeseries into JSON or IAGA data
Parameters
......@@ -145,7 +145,7 @@ Service Version:
return error_body
def json_error(code: int, error: Exception, url: str) -> dict:
def json_error(code: int, error: Exception, url: str) -> Dict:
"""Format json error message.
Returns
......@@ -172,15 +172,11 @@ def parse_query(
starttime: datetime = Query(None),
endtime: datetime = Query(None),
elements: List[str] = Query(None),
sampling_period: SamplingPeriod = Query(None),
data_type: Union[DataType, str] = Query(None),
format: OutputFormat = Query(None),
sampling_period: SamplingPeriod = Query(SamplingPeriod.HOUR),
data_type: Union[DataType, str] = Query(DataType.VARIATION),
format: OutputFormat = Query(OutputFormat.IAGA2002),
) -> DataApiQuery:
if elements != None:
if len(elements) == 1 and "," in elements[0]:
elements = [e.strip() for e in elements[0].split(",")]
if starttime == None:
now = datetime.now()
starttime = UTCDateTime(year=now.year, month=now.month, day=now.day)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment