From 280633fc6443624f2020e2cf3700efc065e3f4e0 Mon Sep 17 00:00:00 2001 From: arigdon-usgs <arigdon@usgs.gov> Date: Mon, 30 Jul 2018 14:19:03 -0600 Subject: [PATCH] Fixed some linting errors and made the merge_streams() code cleaner as well as added more unit testing --- geomagio/TimeseriesUtility.py | 33 ++++++++++++++++------------- test/TimeseriesUtility_test.py | 38 ++++++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/geomagio/TimeseriesUtility.py b/geomagio/TimeseriesUtility.py index 4fc346895..7fe248863 100644 --- a/geomagio/TimeseriesUtility.py +++ b/geomagio/TimeseriesUtility.py @@ -195,31 +195,36 @@ def merge_streams(*streams): stream with contiguous traces merged, and gaps filled with numpy.nan """ merged = obspy.core.Stream() - masked = obspy.core.Stream() # sort out empty for stream in streams: - for trace in stream: - if numpy.isnan(trace.data).all(): - masked += trace - else: - merged += trace + merged += stream - merged = mask_stream(merged) + split = mask_stream(merged) # split traces that contain gaps - merged = merged.split() - - merged += masked + split = split.split() + + # Re-add any empty traces that were removed by split() + readd = obspy.core.Stream() + for trace in merged: + stats = trace.stats + split_stream = split.select( + channel=stats.channel, + station=stats.station, + network=stats.network, + location=stats.location) + if len(split_stream) == 0: + readd += trace + split += readd # merge data - merged.merge( + split.merge( # 1 = do not interpolate - interpolation_samples=1, - fill_value=numpy.NaN, + interpolation_samples=0, # 1 = when there is overlap, use data from trace with last endtime method=1) # convert back to NaN filled array - merged = unmask_stream(merged) + merged = unmask_stream(split) return merged diff --git a/test/TimeseriesUtility_test.py b/test/TimeseriesUtility_test.py index 4a7a25876..6d54432df 100644 --- a/test/TimeseriesUtility_test.py +++ b/test/TimeseriesUtility_test.py @@ -7,6 +7,8 @@ import numpy from geomagio import TimeseriesUtility from obspy.core import Stream, UTCDateTime +assert_almost_equal = numpy.testing.assert_almost_equal + def test_get_stream_gaps(): """TimeseriesUtility_test.test_get_stream_gaps() @@ -123,8 +125,8 @@ def test_merge_streams(): trace1 = __create_trace('H', [1, 1, 1, 1]) trace2 = __create_trace('E', [2, numpy.nan, numpy.nan, 2]) trace3 = __create_trace('F', [numpy.nan, numpy.nan, numpy.nan, numpy.nan]) - trace4 = __create_trace('H', [1, 1, 1, 1]) - trace5 = __create_trace('E', [2, numpy.nan, numpy.nan, 2]) + trace4 = __create_trace('H', [2, 2, 2, 2]) + trace5 = __create_trace('E', [3, numpy.nan, numpy.nan, 3]) trace6 = __create_trace('F', [numpy.nan, numpy.nan, numpy.nan, numpy.nan]) npts1 = len(trace1.data) npts2 = len(trace4.data) @@ -138,22 +140,40 @@ def test_merge_streams(): trace.stats.npts = npts2 merged_streams1 = TimeseriesUtility.merge_streams(timeseries1) # Make sure the empty 'F' was not removed from stream - assert_equals(1, len(merged_streams1.select(channel = 'F'))) + assert_equals(1, len(merged_streams1.select(channel='F'))) # Merge multiple streams with overlapping timestamps timeseries = timeseries1 + timeseries2 + merged_streams = TimeseriesUtility.merge_streams(timeseries) assert_equals(len(merged_streams), len(timeseries1)) - assert_equals(len(merged_streams[0].data), 6) + assert_equals(len(merged_streams[0]), 6) assert_equals(len(merged_streams[2]), 6) - - trace7 = __create_trace('H', [1,1,1,1]) + assert_almost_equal( + merged_streams.select(channel='H')[0].data, + [1, 1, 2, 2, 2, 2]) + assert_almost_equal( + merged_streams.select(channel='E')[0].data, + [2, numpy.nan, 3, 2, numpy.nan, 3]) + assert_almost_equal( + merged_streams.select(channel='F')[0].data, + [numpy.nan] * 6) + + trace7 = __create_trace('H', [1, 1, 1, 1]) trace8 = __create_trace('E', [numpy.nan, numpy.nan, numpy.nan, numpy.nan]) trace9 = __create_trace('F', [numpy.nan, numpy.nan, numpy.nan, numpy.nan]) - timeseries3 = Stream(traces=[trace7,trace8,trace9]) + timeseries3 = Stream(traces=[trace7, trace8, trace9]) npts3 = len(trace7.data) for trace in timeseries3: trace.stats.starttime = UTCDateTime('2018-01-01T00:00:00Z') trace.stats.npts = npts3 merged_streams3 = TimeseriesUtility.merge_streams(timeseries3) - assert_equals(len(timeseries3),len(merged_streams3)) - + assert_equals(len(timeseries3), len(merged_streams3)) + assert_almost_equal( + timeseries3.select(channel='H')[0].data, + [1, 1, 1, 1]) + assert_equals( + numpy.isnan(timeseries3.select(channel='E')[0].data).all(), + True) + assert_equals( + numpy.isnan(timeseries3.select(channel='F')[0].data).all(), + True) -- GitLab