From 03d7ade1ec920d72062097fd20dcf99006e9e0e7 Mon Sep 17 00:00:00 2001
From: Jeremy Fee <jmfee@usgs.gov>
Date: Thu, 28 Jan 2016 14:55:07 -0700
Subject: [PATCH] Add additional sqdist arguments, initial saved state support

---
 geomagio/algorithm/SqDistAlgorithm.py | 156 ++++++++++++++++++++++++--
 1 file changed, 145 insertions(+), 11 deletions(-)

diff --git a/geomagio/algorithm/SqDistAlgorithm.py b/geomagio/algorithm/SqDistAlgorithm.py
index 99e279728..a10b9991a 100644
--- a/geomagio/algorithm/SqDistAlgorithm.py
+++ b/geomagio/algorithm/SqDistAlgorithm.py
@@ -12,11 +12,14 @@
         https://gist.github.com/andrequeiroz/5888967
 """
 
+from .. import StreamConverter
 from Algorithm import Algorithm
 from AlgorithmException import AlgorithmException
+import json
 import numpy as np
-from obspy.core import Stream
+from obspy.core import Stream, UTCDateTime
 from scipy.optimize import fmin_l_bfgs_b
+import sys
 
 
 class SqDistAlgorithm(Algorithm):
@@ -24,21 +27,103 @@ class SqDistAlgorithm(Algorithm):
 
     def __init__(self, alpha=None, beta=None, gamma=None, phi=1, m=1,
                  yhat0=None, s0=None, l0=None, b0=None, sigma0=None,
-                 zthresh=6, fc=0, hstep=0):
+                 zthresh=6, fc=0, hstep=0, statefile=None, mag=False):
         Algorithm.__init__(self, inchannels=None, outchannels=None)
         self.alpha = alpha
         self.beta = beta
         self.gamma = gamma
         self.phi = phi
         self.m = m
+        self.zthresh = zthresh
+        self.fc = fc
+        self.hstep = hstep
+        self.statefile = statefile
+        self.mag = mag
+        # state variables
         self.yhat0 = yhat0
         self.s0 = s0
         self.l0 = l0
         self.b0 = b0
         self.sigma0 = sigma0
-        self.zthresh = zthresh
-        self.fc = fc
-        self.hstep = hstep
+        self.last_observatory = None
+        self.last_channel = None
+        self.next_starttime = None
+        self.load_state()
+
+    def get_input_interval(self, start, end, observatory=None, channels=None):
+        """Get Input Interval
+
+        start : UTCDateTime
+            start time of requested output.
+        end : UTCDateTime
+            end time of requested output.
+        observatory : string
+            observatory code.
+        channels : string
+            input channels.
+
+        Returns
+        -------
+        input_start : UTCDateTime
+            start of input required to generate requested output
+        input_end : UTCDateTime
+            end of input required to generate requested output.
+        """
+        if self.mag:
+            channels = ('H')
+        if observatory == self.last_observatory \
+                and len(channels) == 1 \
+                and channels[0] == self.last_channel \
+                and start == self.next_starttime:
+            # state is up to date, only need new data
+            return (start, end)
+        # state not up to date, need to prime
+        return (start - 3 * 30 * 24 * 60 * 60, end)
+
+    def load_state(self):
+        """Load algorithm state from a file.
+
+        File name is self.statefile.
+        """
+        if self.statefile is None:
+            return
+        data = None
+        try:
+            with open(self.statefile, 'r') as f:
+                data = f.read()
+                data = json.loads(data)
+        except Exception, e:
+            pass
+        if data is None or data == '':
+            return
+        self.yhat0 = data['yhat0']
+        self.s0 = data['s0']
+        self.l0 = data['l0']
+        self.b0 = data['b0']
+        self.sigma0 = data['sigma0']
+        self.last_observatory = data['last_observatory']
+        self.last_channel = data['last_channel']
+        self.next_starttime = UTCDateTime(data['next_starttime'])
+
+    def save_state(self):
+        """Save algorithm state to a file.
+
+        File name is self.statefile.
+        """
+        if self.statefile is None:
+            return
+        data = {
+            'yhat0': list(self.yhat0),
+            's0': list(self.s0),
+            'l0': self.l0,
+            'b0': self.b0,
+            'sigma0': list(self.sigma0),
+            'last_observatory': self.last_observatory,
+            'last_channel': self.last_channel,
+            'next_starttime': str(self.next_starttime)
+        }
+        with open(self.statefile, 'w') as f:
+            f.write(json.dumps(data))
 
     def process(self, stream):
         """Run algorithm for a stream.
@@ -56,6 +141,19 @@ class SqDistAlgorithm(Algorithm):
             stream containing 3 traces per original trace.
         """
         out = Stream()
+
+        if self.mag:
+            # convert stream to mag
+            if stream.select(channel='H').count() > 0 \
+                    and stream.select(channel='E').count() > 0:
+                stream = StreamConverter.get_mag_from_obs(stream)
+            elif stream.select(channel='X').count() > 0 \
+                    and stream.select(channel='Y').count() > 0:
+                stream = StreamConverter.get_mag_from_geo(stream)
+            else:
+                raise AlgorithmException('Unable to convert to magnetic H')
+            stream = stream.select(channel='H')
+
         for trace in stream.traces:
             out += self.process_one(trace)
         return out
@@ -82,6 +180,20 @@ class SqDistAlgorithm(Algorithm):
                 channel_SV
         """
         out = Stream()
+        # check state
+        if self.last_observatory != None \
+                and self.last_channel != None \
+                and self.next_starttime != None:
+            # have state, verify okay to proceed
+            if trace.stats.station != self.last_observatory \
+                    or trace.stats.channel != self.last_channel \
+                    or trace.stats.starttime != self.next_starttime:
+                # state not correct, clear to be safe
+                self.yhat0 = None
+                self.s0 = None
+                self.l0 = None
+                self.b0 = None
+                self.sigma0 = None
         # process
         yhat, shat, sigmahat, yhat0, s0, l0, b0, sigma0 = self.additive(
                 yobs=trace.data,
@@ -104,6 +216,10 @@ class SqDistAlgorithm(Algorithm):
         self.l0 = l0
         self.b0 = b0
         self.sigma0 = sigma0
+        self.last_observatory = trace.stats.station
+        self.last_channel = trace.stats.channel
+        self.next_starttime = trace.stats.endtime + trace.stats.delta
+        self.save_state()
         # create updated traces
         channel = trace.stats.channel
         raw = trace.data
@@ -469,14 +585,28 @@ class SqDistAlgorithm(Algorithm):
             command line argument parser
         """
         parser.add_argument('--sqdist-alpha',
-                type=float,
-                help='SqDist alpha parameter')
+                default=1.0 / 1440.0 / 30,
+                help='SqDist alpha parameter',
+                type=float)
         parser.add_argument('--sqdist-beta',
-                type=float,
-                help='SqDist alpha parameter')
+                default=1.0 / 1440.0 / 30,
+                help='SqDist beta parameter',
+                type=float)
         parser.add_argument('--sqdist-gamma',
-                type=float,
-                help='SqDist alpha parameter')
+                default=1.0 / 30,
+                help='SqDist gamma parameter',
+                type=float)
+        parser.add_argument('--sqdist-m',
+                default=1440,
+                help='SqDist m parameter',
+                type=int)
+        parser.add_argument('--sqdist-mag',
+                action='store_true',
+                default=False,
+                help='Generate sqdist based on magnetic H component')
+        parser.add_argument('--sqdist-statefile',
+                default=None,
+                help='File to store state between calls to algorithm');
 
     def configure(self, arguments):
         """Configure algorithm using comand line arguments.
@@ -490,3 +620,7 @@ class SqDistAlgorithm(Algorithm):
         self.alpha = arguments.sqdist_alpha
         self.beta = arguments.sqdist_beta
         self.gamma = arguments.sqdist_gamma
+        self.m = arguments.sqdist_m
+        self.mag = arguments.sqdist_mag
+        self.statefile = arguments.sqdist_statefile
+        self.load_state()
-- 
GitLab