diff --git a/geomagio/api/ws/DataApiQuery.py b/geomagio/api/ws/DataApiQuery.py index 5d7cb3b2eabbf3867681bff082c4a6ed7d09cd86..9c894c46f47864e5384254d71d22c2c3738c9bfc 100644 --- a/geomagio/api/ws/DataApiQuery.py +++ b/geomagio/api/ws/DataApiQuery.py @@ -65,6 +65,7 @@ class DataApiQuery(BaseModel): elements: List[str] = DEFAULT_ELEMENTS sampling_period: SamplingPeriod = SamplingPeriod.MINUTE data_type: Union[DataType, str] = DataType.VARIATION + dbdt: list = [] format: OutputFormat = OutputFormat.IAGA2002 @validator("data_type") @@ -123,6 +124,16 @@ class DataApiQuery(BaseModel): endtime = starttime + (86400 - 0.001) return endtime + @validator("dbdt", always=True) + def validate_dbdt(cls, dbdt: list,) -> list: + """Default dbdt based on valid elements. + """ + for channel in dbdt: + if channel not in DEFAULT_ELEMENTS: + raise ValueError("Specified channel not found in valid elements.") + + return dbdt + @root_validator def validate_combinations(cls, values): starttime, endtime, elements, format, sampling_period = ( diff --git a/geomagio/api/ws/Element.py b/geomagio/api/ws/Element.py index 6d0eb7fd6f6f86fdd1e812a24a045330b15fb50d..eff80cca91d247ee2cd5f3cc664c5ab648ca54f7 100644 --- a/geomagio/api/ws/Element.py +++ b/geomagio/api/ws/Element.py @@ -36,7 +36,6 @@ ELEMENTS = [ id="UK3", abbreviation="T-Fluxgate", name="Fluxgate Temperature", units="°C" ), Element(id="UK4", abbreviation="T-Outside", name="Outside Temperature", units="°C"), - Element(id="DDT", abbreviation="DbDt", name="Time Derivative", units="1/s"), ] ELEMENT_INDEX = {e.id: e for e in ELEMENTS} diff --git a/geomagio/api/ws/data.py b/geomagio/api/ws/data.py index 345c47633cf710306d46a252f57021fba6c5ed71..4faeb92de1303881ab4b9114e8777005b43c5291 100644 --- a/geomagio/api/ws/data.py +++ b/geomagio/api/ws/data.py @@ -5,7 +5,6 @@ from fastapi import APIRouter, Depends, Query from obspy import UTCDateTime, Stream from starlette.responses import Response -from ...algorithm.DbDtAlgorithm import DbDtAlgorithm from ... import TimeseriesFactory, TimeseriesUtility from ...edge import EdgeFactory from ...iaga2002 import IAGA2002Writer @@ -66,42 +65,18 @@ def get_timeseries(data_factory: TimeseriesFactory, query: DataApiQuery) -> Stre query: parameters for the data to read """ - # gather non-dbdt elements first - base_elements = [element for element in query.elements if element[1::] != "_DDT"] - - # gather interval - interval = TimeseriesUtility.get_interval_from_delta(query.sampling_period) # get data - base_timeseries = data_factory.get_timeseries( + timeseries = data_factory.get_timeseries( starttime=query.starttime, endtime=query.endtime, observatory=query.id, - channels=base_elements, + channels=query.elements, type=query.data_type, - interval=interval, + dbdt=query.dbdt, + interval=TimeseriesUtility.get_interval_from_delta(query.sampling_period), ) - if "*_DDT" in query.elements: - dbdt_elements = [ - element[0:1] for element in query.elements if element[1::] == "_DDT" - ] - - timeseries = data_factory.get_timeseries( - starttime=query.starttime, - endtime=query.endtime, - observatory=query.id, - channels=dbdt_elements, - type=query.data_type, - interval=interval, - ) - - dbdt_timeseries = DbDtAlgorithm( - inchannels=dbdt_elements, outchannels=dbdt_elements + "_DDT" - ).process(timeseries) - - base_timeseries += (trace for trace in dbdt_timeseries) - - return base_timeseries + return timeseries router = APIRouter() diff --git a/geomagio/edge/EdgeFactory.py b/geomagio/edge/EdgeFactory.py index 4602b0b9898c5c520666200e13ee6dc2f5c57012..1ab734a5568073f93c5273c6c3aad168fb1de5f7 100644 --- a/geomagio/edge/EdgeFactory.py +++ b/geomagio/edge/EdgeFactory.py @@ -22,6 +22,7 @@ from ..TimeseriesFactory import TimeseriesFactory from ..TimeseriesFactoryException import TimeseriesFactoryException from ..ObservatoryMetadata import ObservatoryMetadata from .RawInputClient import RawInputClient +from ..algorithm.DbDtAlgorithm import DbDtAlgorithm class EdgeFactory(TimeseriesFactory): @@ -110,6 +111,7 @@ class EdgeFactory(TimeseriesFactory): channels=None, type=None, interval=None, + dbdt: list = None, ): """Get timeseries data @@ -127,6 +129,8 @@ class EdgeFactory(TimeseriesFactory): data type. interval: {'day', 'hour', 'minute', 'second', 'tenhertz'} data interval. + dbdt: list + list of channels to receive as time derivatives Returns ------- @@ -164,7 +168,7 @@ class EdgeFactory(TimeseriesFactory): finally: # restore stdout sys.stdout = original_stdout - self._post_process(timeseries, starttime, endtime, channels) + self._post_process(timeseries, starttime, endtime, channels, dbdt) return timeseries @@ -511,7 +515,7 @@ class EdgeFactory(TimeseriesFactory): self._set_metadata(data, observatory, channel, type, interval) return data - def _post_process(self, timeseries, starttime, endtime, channels): + def _post_process(self, timeseries, starttime, endtime, channels, dbdt): """Post process a timeseries stream after the raw data is is fetched from a waveserver. Specifically changes any MaskedArray to a ndarray with nans representing gaps. @@ -528,6 +532,8 @@ class EdgeFactory(TimeseriesFactory): the endtime of the requested data channels: array_like list of channels to load + dbdt: list + list of channels to receive time derivative Notes: the original timeseries object is changed. """ @@ -541,10 +547,30 @@ class EdgeFactory(TimeseriesFactory): for trace in timeseries.select(channel="D"): trace.data = ChannelConverter.get_radians_from_minutes(trace.data) - TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime) + if dbdt: + # find matching channels from dbdt list + dbdt_stream = obspy.core.Stream( + [ + timeseries.select(channel=channel)[0] + for channel in channels + if channel in dbdt + ] + ) + # process matching stream with DbDtAlgorithm + dbdt_stream = DbDtAlgorithm( + inchannels=dbdt, period=timeseries[0].stats.sampling_rate + ).process(dbdt_stream) + + # replace traces from stream in original timeseries + for i in range(len(timeseries)): + channel = timeseries[i].stats.channel + if channel in dbdt: + timeseries[i] = dbdt_stream.select(channel=channel + "_DDT")[0] + + TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime) def _put_channel( - self, timeseries, observatory, channel, type, interval, starttime, endtime + self, timeseries, observatory, channel, type, interval, starttime, endtime, ): """Put a channel worth of data