From 5526a41c781d4a04758567d3a0e6e09e295f4702 Mon Sep 17 00:00:00 2001
From: Jeremy Fee <jmfee@usgs.gov>
Date: Wed, 17 Jun 2020 19:51:45 -0600
Subject: [PATCH] Move dbdt webservice processing back to webservice

---
 geomagio/api/ws/data.py      | 30 +++++++++++++++++++++++++++---
 geomagio/edge/EdgeFactory.py | 30 ++----------------------------
 2 files changed, 29 insertions(+), 31 deletions(-)

diff --git a/geomagio/api/ws/data.py b/geomagio/api/ws/data.py
index 4faeb92de..88b0883e3 100644
--- a/geomagio/api/ws/data.py
+++ b/geomagio/api/ws/data.py
@@ -6,6 +6,7 @@ from obspy import UTCDateTime, Stream
 from starlette.responses import Response
 
 from ... import TimeseriesFactory, TimeseriesUtility
+from ...algorithm import DbDtAlgorithm
 from ...edge import EdgeFactory
 from ...iaga2002 import IAGA2002Writer
 from ...imfjson import IMFJSONWriter
@@ -64,7 +65,6 @@ def get_timeseries(data_factory: TimeseriesFactory, query: DataApiQuery) -> Stre
     data_factory: where to read data
     query: parameters for the data to read
     """
-
     # get data
     timeseries = data_factory.get_timeseries(
         starttime=query.starttime,
@@ -72,11 +72,33 @@ def get_timeseries(data_factory: TimeseriesFactory, query: DataApiQuery) -> Stre
         observatory=query.id,
         channels=query.elements,
         type=query.data_type,
-        dbdt=query.dbdt,
         interval=TimeseriesUtility.get_interval_from_delta(query.sampling_period),
     )
+    return post_process(query, timeseries)
+
 
-    return timeseries
+def post_process(query: DataApiQuery, timeseries: Stream) -> Stream:
+    """Process timeseries data before it is returned.
+
+    Parameters
+    ----------
+    query: parameters for the data to read
+    timeseries: data that was read
+    """
+    out = timeseries
+    if query.dbdt:
+        out = Stream()
+        dbdt = Stream()
+        for trace in timeseries:
+            if trace.stats.channel in query.dbdt:
+                dbdt += trace
+            else:
+                out += trace
+        out += DbDtAlgorithm().process(dbdt)
+        query.elements = [
+            el in query.dbdt and f"{el}_DT" or el for el in query.elements
+        ]
+    return out
 
 
 router = APIRouter()
@@ -91,6 +113,7 @@ def get_data(
     sampling_period: Union[SamplingPeriod, float] = Query(SamplingPeriod.MINUTE),
     data_type: Union[DataType, str] = Query(DataType.ADJUSTED, alias="type"),
     format: OutputFormat = Query(OutputFormat.IAGA2002),
+    dbdt: List[str] = Query([]),
     data_factory: TimeseriesFactory = Depends(get_data_factory),
 ) -> Response:
     # parse query
@@ -102,6 +125,7 @@ def get_data(
         sampling_period=sampling_period,
         data_type=data_type,
         format=format,
+        dbdt=dbdt,
     )
     # read data
     timeseries = get_timeseries(data_factory, query)
diff --git a/geomagio/edge/EdgeFactory.py b/geomagio/edge/EdgeFactory.py
index 3c2078664..83c278a6e 100644
--- a/geomagio/edge/EdgeFactory.py
+++ b/geomagio/edge/EdgeFactory.py
@@ -22,7 +22,6 @@ from ..TimeseriesFactory import TimeseriesFactory
 from ..TimeseriesFactoryException import TimeseriesFactoryException
 from ..ObservatoryMetadata import ObservatoryMetadata
 from .RawInputClient import RawInputClient
-from ..algorithm.DbDtAlgorithm import DbDtAlgorithm
 
 
 class EdgeFactory(TimeseriesFactory):
@@ -111,7 +110,6 @@ class EdgeFactory(TimeseriesFactory):
         channels=None,
         type=None,
         interval=None,
-        dbdt: list = None,
     ):
         """Get timeseries data
 
@@ -129,8 +127,6 @@ class EdgeFactory(TimeseriesFactory):
             data type.
         interval: {'day', 'hour', 'minute', 'second', 'tenhertz'}
             data interval.
-        dbdt: list
-            list of channels to receive as time derivatives
 
         Returns
         -------
@@ -168,7 +164,7 @@ class EdgeFactory(TimeseriesFactory):
         finally:
             # restore stdout
             sys.stdout = original_stdout
-        self._post_process(timeseries, starttime, endtime, channels, dbdt)
+        self._post_process(timeseries, starttime, endtime, channels)
 
         return timeseries
 
@@ -515,7 +511,7 @@ class EdgeFactory(TimeseriesFactory):
         self._set_metadata(data, observatory, channel, type, interval)
         return data
 
-    def _post_process(self, timeseries, starttime, endtime, channels, dbdt):
+    def _post_process(self, timeseries, starttime, endtime, channels):
         """Post process a timeseries stream after the raw data is
                 is fetched from a waveserver. Specifically changes
                 any MaskedArray to a ndarray with nans representing gaps.
@@ -532,8 +528,6 @@ class EdgeFactory(TimeseriesFactory):
             the endtime of the requested data
         channels: array_like
             list of channels to load
-        dbdt: list
-            list of channels to receive time derivative
 
         Notes: the original timeseries object is changed.
         """
@@ -547,26 +541,6 @@ class EdgeFactory(TimeseriesFactory):
             for trace in timeseries.select(channel="D"):
                 trace.data = ChannelConverter.get_radians_from_minutes(trace.data)
 
-        if dbdt:
-            # find matching channels from dbdt list
-            dbdt_stream = obspy.core.Stream(
-                [
-                    timeseries.select(channel=channel)[0]
-                    for channel in channels
-                    if channel in dbdt
-                ]
-            )
-            # process matching stream with DbDtAlgorithm
-            dbdt_stream = DbDtAlgorithm(
-                inchannels=dbdt, period=timeseries[0].stats.sampling_rate
-            ).process(dbdt_stream)
-
-            # replace traces from stream in original timeseries
-            for i in range(len(timeseries)):
-                channel = timeseries[i].stats.channel
-                if channel in dbdt:
-                    timeseries[i] = dbdt_stream.select(channel=channel + "_DDT")[0]
-
         TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime)
 
     def _put_channel(
-- 
GitLab