From 35dd0ae549c5bbd1e8fb94871f65a4154eb4e138 Mon Sep 17 00:00:00 2001
From: Jeremy Fee <jmfee@usgs.gov>
Date: Tue, 6 Nov 2018 13:35:56 -0700
Subject: [PATCH] Refactor utility methods from EdgeFactory to
 TimeseriesUtility

---
 geomagio/TimeseriesUtility.py      | 116 ++++++++++++++++++++++++++
 geomagio/edge/EdgeFactory.py       | 125 ++---------------------------
 test/TimeseriesUtility_test.py     |  81 ++++++++++++++++++-
 test/edge_test/EdgeFactory_test.py |  84 +------------------
 4 files changed, 204 insertions(+), 202 deletions(-)

diff --git a/geomagio/TimeseriesUtility.py b/geomagio/TimeseriesUtility.py
index 4f0adeb13..385bb0b29 100644
--- a/geomagio/TimeseriesUtility.py
+++ b/geomagio/TimeseriesUtility.py
@@ -1,9 +1,89 @@
 """Timeseries Utilities"""
 from builtins import range
+from datetime import datetime
 import numpy
 import obspy.core
 
 
+def create_empty_trace(starttime, endtime, observatory,
+            channel, type, interval, network, station, location):
+    """create an empty trace filled with nans.
+
+    Parameters
+    ----------
+    starttime: obspy.core.UTCDateTime
+        the starttime of the requested data
+    endtime: obspy.core.UTCDateTime
+        the endtime of the requested data
+    observatory : str
+        observatory code (not used when building the trace)
+    channel : str
+        single character channel {H, E, D, Z, F}
+    type : str
+        data type {definitive, quasi-definitive, variation} (not used)
+    interval : str
+        interval length {daily, hourly, minute, second}
+    network: str
+        the network code
+    station: str
+        the observatory station code
+    location: str
+        the location code
+    Returns
+    -------
+    obspy.core.Trace
+        trace for the requested channel
+    Raises
+    ------
+    ValueError
+        if interval is not one of the supported values
+    """
+    deltas = {'second': 1., 'minute': 60., 'hourly': 3600., 'daily': 86400.}
+    if interval not in deltas:
+        raise ValueError('Unsupported interval "%s"' % interval)
+    delta = deltas[interval]
+    stats = obspy.core.Stats()
+    stats.network = network
+    stats.station = station
+    stats.location = location
+    stats.channel = channel
+    # Calculate first valid sample time based on interval
+    trace_starttime = obspy.core.UTCDateTime(
+        numpy.ceil(starttime.timestamp / delta) * delta)
+    stats.starttime = trace_starttime
+    stats.delta = delta
+    # Calculate number of valid samples up to or before endtime
+    length = int((endtime - trace_starttime) / delta)
+    stats.npts = length + 1
+    data = numpy.full(stats.npts, numpy.nan, dtype=numpy.float64)
+    return obspy.core.Trace(data, stats)
+
+
+def get_stream_start_end_times(timeseries):
+    """get start and end times from a stream.
+            Finds the earliest starttime and the latest endtime of all
+            traces; an empty stream returns (current time, 1970-01-01).
+    Parameters
+    ----------
+    timeseries: obspy.core.stream
+        The timeseries stream
+
+    Returns
+    -------
+    tuple: (starttime, endtime)
+        starttime: obspy.core.UTCDateTime
+        endtime: obspy.core.UTCDateTime
+    """
+    # fallbacks are returned only for an empty stream; seeding the search
+    # with them mis-reports traces after the naive "now" or before 1970
+    starttime = obspy.core.UTCDateTime(datetime.now())
+    endtime = obspy.core.UTCDateTime(0)
+    if len(timeseries) > 0:
+        starttime = min(trace.stats.starttime for trace in timeseries)
+        endtime = max(trace.stats.endtime for trace in timeseries)
+    return (starttime, endtime)
+
+
 def get_stream_gaps(stream, channels=None):
     """Get gaps in a given stream
     Parameters
@@ -230,3 +310,39 @@ def merge_streams(*streams):
     # convert back to NaN filled array
     merged = unmask_stream(split)
     return merged
+
+
+def pad_timeseries(timeseries, starttime, endtime):
+    """Pads each trace with nans so it covers starttime through endtime.
+        Only whole samples (multiples of the trace delta) are added, and
+        traces that already cover the requested range are not changed.
+
+    Parameters
+    ----------
+    timeseries: obspy.core.stream
+        The timeseries stream as returned by the call to getWaveform
+    starttime: obspy.core.UTCDateTime
+        the starttime of the requested data
+    endtime: obspy.core.UTCDateTime
+        the endtime of the requested data
+
+    Notes: the original timeseries object is changed.
+    """
+    for trace in timeseries:
+        trace_starttime = obspy.core.UTCDateTime(trace.stats.starttime)
+        trace_endtime = obspy.core.UTCDateTime(trace.stats.endtime)
+        trace_delta = trace.stats.delta
+        if trace_starttime > starttime:
+            cnt = int((trace_starttime - starttime) / trace_delta)
+            if cnt > 0:
+                trace.data = numpy.concatenate([
+                        numpy.full(cnt, numpy.nan, dtype=numpy.float64),
+                        trace.data])
+                trace_starttime = trace_starttime - trace_delta * cnt
+                trace.stats.starttime = trace_starttime
+        if trace_endtime < endtime:
+            cnt = int((endtime - trace_endtime) / trace.stats.delta)
+            if cnt > 0:
+                trace.data = numpy.concatenate([
+                        trace.data,
+                        numpy.full(cnt, numpy.nan, dtype=numpy.float64)])
diff --git a/geomagio/edge/EdgeFactory.py b/geomagio/edge/EdgeFactory.py
index 4f3811f7d..dd8e13866 100644
--- a/geomagio/edge/EdgeFactory.py
+++ b/geomagio/edge/EdgeFactory.py
@@ -191,7 +191,8 @@ class EdgeFactory(TimeseriesFactory):
         interval = interval or self.interval or stats.data_interval
 
         if (starttime is None or endtime is None):
-            starttime, endtime = self._get_stream_start_end_times(timeseries)
+            starttime, endtime = TimeseriesUtility.get_stream_start_end_times(
+                    timeseries)
         for channel in channels:
             if timeseries.select(channel=channel).count() == 0:
                 raise TimeseriesFactoryException(
@@ -201,41 +202,6 @@ class EdgeFactory(TimeseriesFactory):
             self._put_channel(timeseries, observatory, channel, type,
                     interval, starttime, endtime)
 
-    def _clean_timeseries(self, timeseries, starttime, endtime):
-        """Realigns timeseries data so the start and endtimes are the same
-            as what was originally asked for, even if the data was during
-            a gap.
-
-        Parameters
-        ----------
-        timeseries: obspy.core.stream
-            The timeseries stream as returned by the call to getWaveform
-        starttime: obspy.core.UTCDateTime
-            the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
-            the endtime of the requested data
-
-        Notes: the original timeseries object is changed.
-        """
-        for trace in timeseries:
-            trace_starttime = obspy.core.UTCDateTime(trace.stats.starttime)
-            trace_endtime = obspy.core.UTCDateTime(trace.stats.endtime)
-            trace_delta = trace.stats.delta
-            if trace_starttime > starttime:
-                cnt = int((trace_starttime - starttime) / trace_delta)
-                if cnt > 0:
-                    trace.data = numpy.concatenate([
-                            numpy.full(cnt, numpy.nan, dtype=numpy.float64),
-                            trace.data])
-                    trace_starttime = trace_starttime - trace_delta * cnt
-                    trace.stats.starttime = trace_starttime
-            if trace_endtime < endtime:
-                cnt = int((endtime - trace_endtime) / trace.stats.delta)
-                if cnt > 0:
-                    trace.data = numpy.concatenate([
-                            trace.data,
-                            numpy.full(cnt, numpy.nan, dtype=numpy.float64)])
-
     def _convert_timeseries_to_decimal(self, stream):
         """convert geomag edge timeseries data stored as ints, to decimal by
             dividing by 1000.00
@@ -293,59 +259,6 @@ class EdgeFactory(TimeseriesFactory):
             trace.data = numpy.ma.masked_invalid(trace.data)
         return stream
 
-    def _create_missing_channel(self, starttime, endtime, observatory,
-                channel, type, interval, network, station, location):
-        """fill a missing channel with nans.
-
-        Parameters
-        ----------
-        starttime: obspy.core.UTCDateTime
-            the starttime of the requested data
-        endtime: obspy.core.UTCDateTime
-            the endtime of the requested data
-        observatory : str
-            observatory code
-        channel : str
-            single character channel {H, E, D, Z, F}
-        type : str
-            data type {definitive, quasi-definitive, variation}
-        interval : str
-            interval length {minute, second}
-        network: str
-            the network code
-        station: str
-            the observatory station code
-        location: str
-            the location code
-        Returns
-        -------
-        obspy.core.Stream
-            stream of the requested channel data
-        """
-        if interval == 'second':
-            delta = 1.
-        elif interval == 'minute':
-            delta = 60.
-        elif interval == 'hourly':
-            delta = 3600.
-        elif interval == 'daily':
-            delta = 86400.
-        stats = obspy.core.Stats()
-        stats.network = network
-        stats.station = station
-        stats.location = location
-        stats.channel = channel
-        # Calculate first valid sample time based on interval
-        trace_starttime = obspy.core.UTCDateTime(
-            numpy.ceil(starttime.timestamp / delta) * delta)
-        stats.starttime = trace_starttime
-        stats.delta = delta
-        # Calculate number of valid samples up to or before endtime
-        length = int((endtime - trace_starttime) / delta)
-        stats.npts = length + 1
-        data = numpy.full(stats.npts, numpy.nan, dtype=numpy.float64)
-        return obspy.core.Stream(obspy.core.Trace(data, stats))
-
     def _get_edge_channel(self, observatory, channel, type, interval):
         """get edge channel.
 
@@ -563,42 +476,18 @@ class EdgeFactory(TimeseriesFactory):
                 edge_channel, starttime, endtime)
         data.merge()
         if data.count() == 0:
-            data = self._create_missing_channel(starttime, endtime,
-                observatory, channel, type, interval, network, station,
-                location)
+            data += TimeseriesUtility.create_empty_trace(
+                starttime, endtime, observatory, channel, type,
+                interval, network, station, location)
         self._set_metadata(data,
                 observatory, channel, type, interval)
         return data
 
-    def _get_stream_start_end_times(self, timeseries):
-        """get start and end times from a stream.
-                Traverses all traces, and find the earliest starttime, and
-                the latest endtime.
-        Parameters
-        ----------
-        timeseries: obspy.core.stream
-            The timeseries stream
-
-        Returns
-        -------
-        tuple: (starttime, endtime)
-            starttime: obspy.core.UTCDateTime
-            endtime: obspy.core.UTCDateTime
-        """
-        starttime = obspy.core.UTCDateTime(datetime.now())
-        endtime = obspy.core.UTCDateTime(0)
-        for trace in timeseries:
-            if trace.stats.starttime < starttime:
-                starttime = trace.stats.starttime
-            if trace.stats.endtime > endtime:
-                endtime = trace.stats.endtime
-        return (starttime, endtime)
-
     def _post_process(self, timeseries, starttime, endtime, channels):
         """Post process a timeseries stream after the raw data is
                 is fetched from a waveserver. Specifically changes
                 any MaskedArray to a ndarray with nans representing gaps.
-                Then calls _clean_timeseries to deal with gaps at the
+                Then calls pad_timeseries to deal with gaps at the
                 beggining or end of the streams.
 
         Parameters
@@ -625,7 +514,7 @@ class EdgeFactory(TimeseriesFactory):
                 trace.data = ChannelConverter.get_radians_from_minutes(
                     trace.data)
 
-        self._clean_timeseries(timeseries, starttime, endtime)
+        TimeseriesUtility.pad_timeseries(timeseries, starttime, endtime)
 
     def _put_channel(self, timeseries, observatory, channel, type, interval,
                 starttime, endtime):
diff --git a/test/TimeseriesUtility_test.py b/test/TimeseriesUtility_test.py
index 04483dd9d..5fa48f1cb 100644
--- a/test/TimeseriesUtility_test.py
+++ b/test/TimeseriesUtility_test.py
@@ -5,11 +5,56 @@ from nose.tools import assert_equals
 from .StreamConverter_test import __create_trace
 import numpy
 from geomagio import TimeseriesUtility
-from obspy.core import Stream, UTCDateTime
+from obspy.core import Stream, Stats, Trace, UTCDateTime
 
 assert_almost_equal = numpy.testing.assert_almost_equal
 
 
+def test_create_empty_trace():
+    """TimeseriesUtility_test.test_create_empty_trace()
+    Empty trace npts matches its data and stays aligned after padding."""
+    trace1 = _create_trace([1, 1, 1, 1, 1], 'H', UTCDateTime("2018-01-01"))
+    trace2 = _create_trace([2, 2], 'E', UTCDateTime("2018-01-01"))
+    observatory = 'Test'
+    interval = 'minute'
+    network = 'NT'
+    location = 'R0'
+    trace3 = TimeseriesUtility.create_empty_trace(
+            starttime=trace1.stats.starttime,
+            endtime=trace1.stats.endtime,
+            observatory=observatory,
+            channel='F',
+            type='variation',
+            interval=interval,
+            network=network,
+            station=trace1.stats.station,
+            location=location)
+    timeseries = Stream(traces=[trace1, trace2])
+    # For continuity set stats to be same for all traces
+    for trace in timeseries:
+        trace.stats.observatory = observatory
+        trace.stats.type = 'variation'
+        trace.stats.interval = interval
+        trace.stats.network = network
+        trace.stats.station = trace1.stats.station
+        trace.stats.location = location
+    timeseries += trace3
+    assert_equals(len(trace3.data), trace3.stats.npts)
+    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
+    TimeseriesUtility.pad_timeseries(
+        timeseries=timeseries,
+        starttime=trace1.stats.starttime,
+        endtime=trace1.stats.endtime)
+    assert_equals(len(trace3.data), trace3.stats.npts)
+    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
+    # Pad by more than 1 delta at both ends so samples are really added
+    starttime = trace1.stats.starttime
+    endtime = trace1.stats.endtime
+    TimeseriesUtility.pad_timeseries(timeseries, starttime - 90, endtime + 90)
+    assert_equals(len(trace3.data), trace3.stats.npts)
+    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
+
+
 def test_get_stream_gaps():
     """TimeseriesUtility_test.test_get_stream_gaps()
 
@@ -188,3 +233,37 @@ def test_merge_streams():
     assert_almost_equal(
         merged4.select(channel='H')[0].data,
         [1, 2, 2, 2, 1, 1])
+
+
+def test_pad_timeseries():
+    """TimeseriesUtility_test.test_pad_timeseries()
+    Traces are nan-padded to the requested range in whole samples."""
+    trace1 = _create_trace([1, 1, 1, 1, 1], 'H', UTCDateTime("2018-01-01"))
+    trace2 = _create_trace([2, 2], 'E', UTCDateTime("2018-01-01"))
+    timeseries = Stream(traces=[trace1, trace2])
+    TimeseriesUtility.pad_timeseries(
+        timeseries=timeseries,
+        starttime=trace1.stats.starttime,
+        endtime=trace1.stats.endtime)
+    assert_equals(len(trace1.data), len(trace2.data))
+    assert_equals(trace1.stats.starttime, trace2.stats.starttime)
+    assert_equals(trace1.stats.endtime, trace2.stats.endtime)
+    # padding by less than 1 delta must not add any samples
+    starttime = trace1.stats.starttime
+    endtime = trace1.stats.endtime
+    TimeseriesUtility.pad_timeseries(timeseries, starttime - 30, endtime + 30)
+    assert_equals(trace1.stats.starttime, starttime)
+    # padding by 90s adds only whole samples: one 60s sample at the start
+    TimeseriesUtility.pad_timeseries(timeseries, starttime - 90, endtime + 90)
+    assert_equals(trace1.stats.starttime, starttime - 60)
+    assert numpy.isnan(trace1.data[0])
+
+
+def _create_trace(data, channel, starttime, delta=60.):
+    """Build a float64 Trace for channel from data, starttime and delta."""
+    stats = Stats()
+    stats.channel = channel
+    stats.delta = delta
+    stats.starttime = starttime
+    stats.npts = len(data)
+    return Trace(numpy.array(data, dtype=numpy.float64), stats)
diff --git a/test/edge_test/EdgeFactory_test.py b/test/edge_test/EdgeFactory_test.py
index d99052305..e0b14e0e8 100644
--- a/test/edge_test/EdgeFactory_test.py
+++ b/test/edge_test/EdgeFactory_test.py
@@ -1,9 +1,8 @@
 """Tests for EdgeFactory.py"""
 
-from obspy.core import Stats, Stream, Trace, UTCDateTime
+from obspy.core import Stream, Trace, UTCDateTime
 from geomagio.edge import EdgeFactory
 from nose.tools import assert_equals
-import numpy as np
 
 
 def test__get_edge_network():
@@ -92,84 +91,3 @@ def dont_get_timeseries():
         'BOU', 'Expect timeseries to have stats')
     assert_equals(timeseries.select(channel='H')[0].stats.channel,
         'H', 'Expect timeseries stats channel to be equal to H')
-
-
-def test_clean_timeseries():
-    """edge_test.EdgeFactory_test.test_clean_timeseries()
-    """
-    edge_factory = EdgeFactory()
-    trace1 = _create_trace([1, 1, 1, 1, 1], 'H', UTCDateTime("2018-01-01"))
-    trace2 = _create_trace([2, 2], 'E', UTCDateTime("2018-01-01"))
-    timeseries = Stream(traces=[trace1, trace2])
-    edge_factory._clean_timeseries(
-        timeseries=timeseries,
-        starttime=trace1.stats.starttime,
-        endtime=trace1.stats.endtime)
-    assert_equals(len(trace1.data), len(trace2.data))
-    assert_equals(trace1.stats.starttime, trace2.stats.starttime)
-    assert_equals(trace1.stats.endtime, trace2.stats.endtime)
-    # change starttime by less than 1 delta
-    starttime = trace1.stats.starttime
-    endtime = trace1.stats.endtime
-    edge_factory._clean_timeseries(timeseries, starttime - 30, endtime + 30)
-    assert_equals(trace1.stats.starttime, starttime)
-    # Change starttime by more than 1 delta
-    edge_factory._clean_timeseries(timeseries, starttime - 90, endtime + 90)
-    assert_equals(trace1.stats.starttime, starttime - 60)
-    assert_equals(np.isnan(trace1.data[0]), np.isnan(np.NaN))
-
-
-def test_create_missing_channel():
-    """edge_test.EdgeFactory_test.test_create_missing_channel()
-    """
-    edge_factory = EdgeFactory()
-    trace1 = _create_trace([1, 1, 1, 1, 1], 'H', UTCDateTime("2018-01-01"))
-    trace2 = _create_trace([2, 2], 'E', UTCDateTime("2018-01-01"))
-    observatory = 'Test'
-    interval = 'minute'
-    network = 'NT'
-    location = 'R0'
-    trace3 = edge_factory._create_missing_channel(
-        starttime=trace1.stats.starttime,
-        endtime=trace1.stats.endtime,
-        observatory=observatory,
-        channel='F',
-        type='variation',
-        interval=interval,
-        network=network,
-        station=trace1.stats.station,
-        location=location)
-    timeseries = Stream(traces=[trace1, trace2])
-    # For continuity set stats to be same for all traces
-    for trace in timeseries:
-        trace.stats.observatory = observatory
-        trace.stats.type = 'variation'
-        trace.stats.interval = interval
-        trace.stats.network = network
-        trace.stats.station = trace1.stats.station
-        trace.stats.location = location
-    timeseries += trace3
-    assert_equals(len(trace3[0].data), trace3[0].stats.npts)
-    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
-    edge_factory._clean_timeseries(
-        timeseries=timeseries,
-        starttime=trace1.stats.starttime,
-        endtime=trace1.stats.endtime)
-    assert_equals(len(trace3[0].data), trace3[0].stats.npts)
-    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
-    # Change starttime by more than 1 delta
-    starttime = trace1.stats.starttime
-    endtime = trace1.stats.endtime
-    edge_factory._clean_timeseries(timeseries, starttime - 90, endtime + 90)
-    assert_equals(len(trace3[0].data), trace3[0].stats.npts)
-    assert_equals(timeseries[0].stats.starttime, timeseries[2].stats.starttime)
-
-
-def _create_trace(data, channel, starttime, delta=60.):
-    stats = Stats()
-    stats.channel = channel
-    stats.delta = delta
-    stats.starttime = starttime
-    stats.npts = len(data)
-    data = np.array(data, dtype=np.float64)
-    return Trace(data, stats)
-- 
GitLab