Skip to content
Snippets Groups Projects
DataApiQuery.py 3.64 KiB
Newer Older
  • Learn to ignore specific revisions
    import datetime
    import enum
    from typing import Any, List, Union

    from pydantic import BaseModel, root_validator, validator
    
    
    
    # Elements used by DataApiQuery when a query supplies no (or an empty) element list.
    DEFAULT_ELEMENTS = ["X", "Y", "Z", "F"]
    
    # Maximum number of samples one query may request; DataApiQuery.validate_times
    # compares it with len(elements) * duration_in_seconds / sampling_period.
    # (345600 == 4 * 86400.)
    REQUEST_LIMIT = 345600
    # Data type names listed as valid in the data_type validation error message.
    VALID_DATA_TYPES = ["variation", "adjusted", "quasi-definitive", "definitive"]
    # Element names accepted by DataApiQuery's elements validator. That validator
    # additionally lets any three-character name through, so this list is not the
    # complete set of accepted values.
    VALID_ELEMENTS = [
        "D",
        "DIST",
        "DST",
        "E",
        "E-E",
        "E-N",
        "F",
        "G",
        "H",
        "SQ",
        "SV",
        "UK1",
        "UK2",
        "UK3",
        "UK4",
        "X",
        "Y",
        "Z",
    ]
    # Observatory ids accepted for DataApiQuery.id; any other id is rejected by
    # DataApiQuery.validate_id.
    VALID_OBSERVATORIES = [
        "BDT",
        "BOU",
        "BRT",
        "BRW",
        "BSL",
        "CMO",
        "CMT",
        "DED",
        "DHT",
        "FDT",
        "FRD",
        "FRN",
        "GUA",
        "HON",
        "NEW",
        "SHU",
        "SIT",
        "SJG",
        "SJT",
        "TST",
        "TUC",
        "USGS",
    ]
    
    
    class DataType(str, enum.Enum):
        """Named data types usable as DataApiQuery.data_type.

        Members are also ``str`` instances, so each compares equal to its
        string value (e.g. ``DataType.VARIATION == "variation"``).
        """

        VARIATION = "variation"
        ADJUSTED = "adjusted"
        QUASI_DEFINITIVE = "quasi-definitive"
        DEFINITIVE = "definitive"
    
    
    class OutputFormat(str, enum.Enum):
        """Output formats selectable through DataApiQuery.format.

        Members are also ``str`` instances and compare equal to their string
        values.
        """

        IAGA2002 = "iaga2002"
        JSON = "json"
    
    
    class SamplingPeriod(float, enum.Enum):
        """Sampling periods a query may select.

        Each member is also a ``float`` holding the period length in seconds
        (DataApiQuery.validate_times divides a duration in seconds by it), so
        members can be used directly in arithmetic.
        """

        TEN_HERTZ = 0.1
        SECOND = 1.0
        MINUTE = 60.0
        HOUR = 3600.0
        DAY = 86400.0
    
    
    class DataApiQuery(BaseModel):
        """Validated parameters of a single data request.

        Fields:
            id: observatory id; must be one of VALID_OBSERVATORIES.
            starttime: start of the requested interval.
            endtime: end of the requested interval; must not precede starttime.
            elements: element names. A one-item list holding a comma-separated
                string is split into separate names. Falls back to
                DEFAULT_ELEMENTS when omitted or empty.
            sampling_period: falls back to SamplingPeriod.HOUR when omitted or
                falsy.
            data_type: a DataType value, or any two-character string (passed
                through unchecked). Falls back to DataType.VARIATION.
            format: falls back to OutputFormat.IAGA2002 when omitted or falsy.

        Invalid input surfaces as a pydantic validation error.
        """

        id: str
        # NOTE(review): validate_times needs both bounds and nothing in this
        # class supplies a default for them, so they are declared as required
        # fields. Without a declaration they never reach the root validator.
        starttime: datetime.datetime
        endtime: datetime.datetime
        # Explicit defaults make these fields optional; the always=True
        # validators below only run for an omitted field when it is optional.
        elements: List[str] = DEFAULT_ELEMENTS
        sampling_period: SamplingPeriod = SamplingPeriod.HOUR
        data_type: Union[DataType, str] = DataType.VARIATION
        format: OutputFormat = OutputFormat.IAGA2002

        @validator("data_type", pre=True, always=True)
        def set_and_validate_data_type(cls, data_type):
            """Default a missing data type and reject unknown names.

            Returns DataType.VARIATION for a falsy value. Two-character values
            are returned unchecked (presumably raw type codes - TODO confirm).

            Raises:
                ValueError: the value is neither two characters long nor one
                    of VALID_DATA_TYPES.
            """
            if not data_type:
                return DataType.VARIATION

            # Compare against the plain value list: `x in DataType` raises
            # TypeError for ordinary strings on Python 3.8-3.11. DataType
            # members are str instances, so they match the list entries too.
            if len(data_type) != 2 and data_type not in VALID_DATA_TYPES:
                raise ValueError(
                    f"Bad data type value '{data_type}'."
                    f" Valid values are: {', '.join(VALID_DATA_TYPES)}"
                )
            return data_type

        @validator("elements", pre=True, always=True)
        def set_and_validate_elements(cls, elements):
            """Default, split and validate the requested elements.

            Returns a copy of DEFAULT_ELEMENTS for a falsy value. A single
            list item containing commas is split on them and stripped.

            Raises:
                ValueError: an element is not in VALID_ELEMENTS and is not
                    exactly three characters long.
            """
            if not elements:
                # Copy so callers can never mutate the module-level default.
                return list(DEFAULT_ELEMENTS)

            if len(elements) == 1 and "," in elements[0]:
                elements = [e.strip() for e in elements[0].split(",")]

            for element in elements:
                # Any three-character name is accepted in addition to the
                # names listed in VALID_ELEMENTS.
                if element not in VALID_ELEMENTS and len(element) != 3:
                    raise ValueError(
                        f"Bad element '{element}'."
                        f" Valid values are: {', '.join(VALID_ELEMENTS)}."
                    )
            return elements

        @validator("sampling_period", pre=True, always=True)
        def set_sampling_period(cls, sampling_period):
            """Return the given sampling period, or SamplingPeriod.HOUR if falsy."""
            return sampling_period or SamplingPeriod.HOUR

        @validator("format", pre=True, always=True)
        def set_format(cls, format):
            """Return the given output format, or OutputFormat.IAGA2002 if falsy."""
            return format or OutputFormat.IAGA2002

        @validator("id")
        def validate_id(cls, id):
            """Reject observatory ids that are not in VALID_OBSERVATORIES.

            Raises:
                ValueError: the id is not a known observatory.
            """
            if id not in VALID_OBSERVATORIES:
                raise ValueError(
                    f"Bad observatory id '{id}'."
                    f" Valid values are: {', '.join(VALID_OBSERVATORIES)}."
                )
            return id

        # skip_on_failure: a field that failed validation is absent from
        # `values`; running the checks below anyway would bury the real error
        # under a TypeError about None.
        @root_validator(skip_on_failure=True)
        def validate_times(cls, values):
            """Check the time range, element count and overall request size.

            Raises:
                ValueError: starttime is after endtime; more than four
                    elements were requested in iaga2002 format; or the number
                    of requested samples exceeds REQUEST_LIMIT.
            """
            starttime = values["starttime"]
            endtime = values["endtime"]
            elements = values["elements"]
            output_format = values["format"]
            sampling_period = values["sampling_period"]

            if starttime > endtime:
                raise ValueError("Starttime must be before endtime.")

            if len(elements) > 4 and output_format == "iaga2002":
                raise ValueError("No more than four elements allowed for iaga2002 format.")

            # Samples requested = elements * interval length / sampling period.
            samples = int(
                len(elements) * (endtime - starttime).total_seconds() / sampling_period
            )

            # check data volume
            if samples > REQUEST_LIMIT:
                raise ValueError(
                    f"Query exceeds request limit ({samples} > {REQUEST_LIMIT})"
                )

            return values