Skip to content
Snippets Groups Projects
DataApiQuery.py 4.04 KiB
Newer Older
import datetime
import enum
from typing import Any, Dict, List, Optional, Union

from obspy import UTCDateTime
from pydantic import BaseModel, root_validator, validator

from .Element import ELEMENTS, ELEMENT_INDEX
from .Observatory import OBSERVATORY_INDEX


# Elements used when a query does not name any (see DataApiQuery.elements).
DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"]
# Upper bound on samples (elements x time steps) one query may request;
# equal to four elements of 1-second data over 86400 seconds.
REQUEST_LIMIT = 4 * 86400
# Ids of every known element, accepted by DataApiQuery.validate_elements.
VALID_ELEMENTS = [element.id for element in ELEMENTS]


class DataType(str, enum.Enum):
    """Named data types accepted by ``DataApiQuery.data_type``.

    Members are also plain strings, so they compare equal to their values.
    """

    VARIATION = "variation"
    ADJUSTED = "adjusted"
    QUASI_DEFINITIVE = "quasi-definitive"
    DEFINITIVE = "definitive"

    @classmethod
    def values(cls) -> List[str]:
        """Return every member's string value, in declaration order."""
        return [member.value for member in cls]


class OutputFormat(str, enum.Enum):
    """Response formats a data API query may ask for."""

    # IAGA-2002 text output; limited to four elements by DataApiQuery
    IAGA2002 = "iaga2002"
    # JSON output
    JSON = "json"


class SamplingPeriod(float, enum.Enum):
    """Supported sampling intervals, expressed in seconds.

    Members are floats, so they can be used directly in arithmetic
    (e.g. dividing a duration in seconds to count samples).
    """

    TEN_HERTZ = 1 / 10
    SECOND = 1.0
    MINUTE = 60.0
    HOUR = 60.0 * 60
    DAY = 24 * 3600.0


class DataApiQuery(BaseModel):
    """Validated parameters of a data API request.

    Field validators fill in defaults (time window, elements) and reject
    unknown observatories, data types and elements. ``validate_combinations``
    then checks rules spanning several fields: the iaga2002 element count,
    time ordering, and the total sample limit.

    Fields are validated in declaration order, so ``starttime`` must stay
    declared before ``endtime`` (``validate_endtime`` reads it from ``values``).
    """

    id: str
    # None means "not supplied"; the validators below substitute defaults
    starttime: Optional[UTCDateTime] = None
    endtime: Optional[UTCDateTime] = None
    elements: List[str] = DEFAULT_ELEMENTS
    sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
    data_type: Union[DataType, str] = DataType.VARIATION
    format: OutputFormat = OutputFormat.IAGA2002

    @validator("data_type")
    def validate_data_type(
        cls, data_type: Union[DataType, str]
    ) -> Union[DataType, str]:
        """Accept a named DataType value or any two-character data type code.

        Raises:
            ValueError: for any other value.
        """
        if data_type not in DataType.values() and len(data_type) != 2:
            raise ValueError(
                f"Bad data type value '{data_type}'."
                f" Valid values are: {', '.join(DataType.values())}"
            )
        return data_type

    @validator("elements", pre=True, always=True)
    def validate_elements(cls, elements: Union[List[str], str]) -> List[str]:
        """Normalize and check the requested elements.

        A missing/empty value yields a fresh copy of DEFAULT_ELEMENTS. A single
        comma-separated string (bare, or as the only list item) is split into
        stripped element names. Each element must be in VALID_ELEMENTS or be
        exactly three characters long.

        Raises:
            ValueError: for an element that satisfies neither rule.
        """
        if not elements:
            # copy, so mutating one query's elements cannot alter the default
            return list(DEFAULT_ELEMENTS)
        if isinstance(elements, str):
            # pre=True sees raw input; don't iterate a string character by character
            elements = [elements]
        if len(elements) == 1 and "," in elements[0]:
            elements = [e.strip() for e in elements[0].split(",")]
        for element in elements:
            if element not in VALID_ELEMENTS and len(element) != 3:
                raise ValueError(
                    f"Bad element '{element}'."
                    f" Valid values are: {', '.join(VALID_ELEMENTS)}."
                )
        return elements

    @validator("id")
    def validate_id(cls, id: str) -> str:
        """Require a known observatory id.

        Raises:
            ValueError: if ``id`` is not a key of OBSERVATORY_INDEX.
        """
        if id not in OBSERVATORY_INDEX:
            raise ValueError(
                f"Bad observatory id '{id}'."
                f" Valid values are: {', '.join(sorted(OBSERVATORY_INDEX.keys()))}."
            )
        return id

    @validator("starttime", always=True)
    def validate_starttime(
        cls, starttime: Optional[UTCDateTime]
    ) -> UTCDateTime:
        """Default starttime to the start (00:00 UTC) of the current day."""
        if starttime is None:
            now = datetime.datetime.now(tz=datetime.timezone.utc)
            return UTCDateTime(year=now.year, month=now.month, day=now.day)
        return starttime

    @validator("endtime", always=True)
    def validate_endtime(
        cls, endtime: Optional[UTCDateTime], *, values: Dict, **kwargs
    ) -> Optional[UTCDateTime]:
        """Default endtime is based on starttime.

        Relies on starttime being declared (and so validated) before endtime,
        which makes the validated starttime available in ``values``.
        """
        if endtime is None:
            starttime = values.get("starttime")
            # starttime is missing from values if its own validation failed;
            # that error is reported separately, so don't raise a TypeError here
            if starttime is not None:
                # one day minus one millisecond, so an inclusive end bound
                # stays within the same day as starttime
                endtime = starttime + (86400 - 0.001)
        return endtime

    @root_validator
    def validate_combinations(cls, values):
        """Check rules that involve more than one field.

        Raises:
            ValueError: if iaga2002 output is requested with more than four
                elements, if starttime is after endtime, or if the request
                would return more than REQUEST_LIMIT samples.
        """
        starttime = values.get("starttime")
        endtime = values.get("endtime")
        elements = values.get("elements")
        output_format = values.get("format")
        sampling_period = values.get("sampling_period")
        if any(
            value is None
            for value in (starttime, endtime, elements, output_format, sampling_period)
        ):
            # a field failed validation (pydantic leaves it out of values) and
            # that error is already reported; cross-field checks can't run
            return values
        if len(elements) > 4 and output_format == OutputFormat.IAGA2002:
            raise ValueError("No more than four elements allowed for iaga2002 format.")
        if starttime > endtime:
            raise ValueError("Starttime must be before endtime.")
        # check data volume: one sample per element per sampling period
        samples = int(len(elements) * (endtime - starttime) / sampling_period)
        if samples > REQUEST_LIMIT:
            raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
        # otherwise okay
        return values