Skip to content
Snippets Groups Projects
DataApiQuery.py 4.96 KiB
Newer Older
  • Learn to ignore specific revisions
    import datetime
    import enum
    import os
    from typing import Dict, List, Union

    from obspy import UTCDateTime
    from pydantic import BaseModel, root_validator, validator

    from .Element import ELEMENTS
    from .Observatory import OBSERVATORY_INDEX, ASL_OBSERVATORY_INDEX
    
    
    # elements used when a query names none: ["X", "Y", "Z", "F"]
    DEFAULT_ELEMENTS = list("XYZF")

    # upper bound on samples (number of elements x time steps) in one request;
    # raised to 10x the originally agreed limit of 345,600
    REQUEST_LIMIT = 10 * 345_600

    # ids of every element defined in the Element module
    VALID_ELEMENTS = [element.id for element in ELEMENTS]
    
    
    class DataType(str, enum.Enum):
        """Named data types accepted for a query's ``data_type``."""

        VARIATION = "variation"
        ADJUSTED = "adjusted"
        QUASI_DEFINITIVE = "quasi-definitive"
        DEFINITIVE = "definitive"

        @classmethod
        def values(cls) -> List[str]:
            """Return the string value of every member, in declaration order."""
            return [member.value for member in cls]
    
    
    
    class OutputFormat(str, enum.Enum):
        """Response formats a query may ask for."""

        # IAGA-2002 text format
        IAGA2002 = "iaga2002"
        JSON = "json"
    
    
    class SamplingPeriod(float, enum.Enum):
        """Supported sampling intervals, expressed in seconds between samples."""

        TEN_HERTZ = 1 / 10
        SECOND = 1.0
        MINUTE = 60.0
        HOUR = 60 * 60.0
        DAY = 24 * 60 * 60.0
    
    
    
    class DataHost(str, enum.Enum):
        """Recognized public Edge data hosts, plus one user-specified default.

        ``DEFAULT`` is taken from the ``DATA_HOST`` environment variable at
        import time, falling back to edgecwb.usgs.gov. When it equals one of
        the named hosts, Enum treats that named member as an alias of DEFAULT.
        """

        DEFAULT = os.environ.get("DATA_HOST", "edgecwb.usgs.gov")
        EDGECWB = "edgecwb.usgs.gov"
        CWBPUB = "cwbpub.cr.usgs.gov"
        CWBPUB2 = "cwbp2.cr.usgs.gov"

        @classmethod
        def values(cls) -> List[str]:
            """Return the distinct host names; aliases are not listed twice."""
            return [host.value for host in cls]
    
    
    class DataApiQuery(BaseModel):
        """Validated parameters of a data API request.

        Field validators check individual parameters and fill in defaults for
        the time window. ``validate_combinations`` then enforces rules that span
        several fields: time order, the element count allowed for IAGA2002
        output, and the total request size against ``REQUEST_LIMIT``.
        """

        # observatory code; must be a key of OBSERVATORY_INDEX or ASL_OBSERVATORY_INDEX
        id: str

        # None means "use the default" (see validate_starttime / validate_endtime)
        starttime: UTCDateTime = None
        endtime: UTCDateTime = None

        elements: List[str] = DEFAULT_ELEMENTS
        sampling_period: SamplingPeriod = SamplingPeriod.MINUTE
        data_type: Union[DataType, str] = DataType.VARIATION

        format: Union[OutputFormat, str] = OutputFormat.IAGA2002
        data_host: Union[DataHost, str] = DataHost.DEFAULT

        @validator("data_type")
        def validate_data_type(
            cls, data_type: Union[DataType, str]
        ) -> Union[DataType, str]:
            """Accept a named DataType value or any two-character code.

            Raises:
                ValueError: for any other value.
            """
            # NOTE(review): two-character values bypass the DataType check,
            # presumably so raw type codes can be requested directly — confirm
            if data_type not in DataType.values() and len(data_type) != 2:
                raise ValueError(
                    f"Bad data type value '{data_type}'."
                    f" Valid values are: {', '.join(DataType.values())}"
                )
            return data_type

        @validator("data_host")
        def validate_data_host(
            cls, data_host: Union[DataHost, str]
        ) -> Union[DataHost, str]:
            """Reject any host that is not one of the DataHost values.

            Raises:
                ValueError: for an unknown host; the message does not list the valid ones.
            """
            if data_host not in DataHost.values():
                raise ValueError(
                    # don't advertise acceptable hosts
                    f"Bad data_host value '{data_host}'."
                )
            return data_host

        @validator("elements", pre=True, always=True)
        def validate_elements(cls, elements: List[str]) -> List[str]:
            """Normalize and check the requested elements.

            An empty or missing value yields DEFAULT_ELEMENTS. A single
            comma-separated entry (e.g. ["X,Y,Z"]) is split and stripped. Each
            element must be a known id or any three-character code.

            Raises:
                ValueError: for an element that is neither.
            """
            if not elements:
                return DEFAULT_ELEMENTS
            if len(elements) == 1 and "," in elements[0]:
                elements = [e.strip() for e in elements[0].split(",")]
            for element in elements:
                # NOTE(review): three-character names are accepted unchecked,
                # presumably raw channel codes — confirm
                if element not in VALID_ELEMENTS and len(element) != 3:
                    raise ValueError(
                        f"Bad element '{element}'."
                        f" Valid values are: {', '.join(VALID_ELEMENTS)}."
                    )
            return elements

        @validator("id")
        def validate_id(cls, observatory_id: str) -> str:
            """Require a known observatory from either observatory index.

            Raises:
                ValueError: for an unknown id; the message lists all valid ids, sorted.
            """
            complete_observatory_index = {**OBSERVATORY_INDEX, **ASL_OBSERVATORY_INDEX}
            if observatory_id not in complete_observatory_index:
                raise ValueError(
                    f"Bad observatory id '{observatory_id}'."
                    f" Valid values are: {', '.join(sorted(complete_observatory_index.keys()))}."
                )
            return observatory_id

        @validator("starttime", always=True)
        def validate_starttime(cls, starttime: UTCDateTime) -> UTCDateTime:
            """Default a missing starttime to 00:00 UTC of the current day."""
            if not starttime:
                now = datetime.datetime.now(tz=datetime.timezone.utc)
                return UTCDateTime(year=now.year, month=now.month, day=now.day)
            return starttime

        @validator("endtime", always=True)
        def validate_endtime(
            cls, endtime: UTCDateTime, *, values: Dict, **kwargs
        ) -> UTCDateTime:
            """Default endtime is based on starttime.

            This method needs to be after validate_starttime: pydantic runs field
            validators in field order, so the validated starttime is in ``values``.
            """
            if not endtime:
                starttime = values.get("starttime")
                # starttime is absent only when it failed its own validation;
                # leave endtime unset so that failure is the error reported
                if starttime is not None:
                    # default to just under one day after starttime, so the
                    # window stops short of the following day's first instant
                    endtime = starttime + (86400 - 0.001)
            return endtime

        @root_validator
        def validate_combinations(cls, values):
            """Check constraints that involve several fields.

            Raises:
                ValueError: when IAGA2002 output is requested with more than
                    four elements, when starttime is after endtime, or when the
                    request would exceed REQUEST_LIMIT samples.
            """
            starttime = values.get("starttime")
            endtime = values.get("endtime")
            elements = values.get("elements")
            output_format = values.get("format")
            sampling_period = values.get("sampling_period")
            # a value missing here already failed (or was never filled in by)
            # its field validator; skip cross-field checks instead of raising
            # an unrelated TypeError on top of that error
            if any(
                value is None
                for value in (starttime, endtime, elements, sampling_period)
            ):
                return values
            if len(elements) > 4 and output_format == "iaga2002":
                raise ValueError("No more than four elements allowed for iaga2002 format.")
            if starttime > endtime:
                raise ValueError("Starttime must be before endtime.")
            # data volume: one sample per element per sampling period in the window
            samples = int(len(elements) * (endtime - starttime) / sampling_period)
            if samples > REQUEST_LIMIT:
                raise ValueError(f"Request exceeds limit ({samples} > {REQUEST_LIMIT})")
            return values