From 3e427dcfb77dbd9b871a49ef2a5848389cfe9189 Mon Sep 17 00:00:00 2001
From: pcain-usgs <pcain@usgs.gov>
Date: Thu, 21 Jan 2021 18:15:00 -0700
Subject: [PATCH] Refactor adjusted algorithm for AdjustedMatrix, test

---
 geomagio/algorithm/AdjustedAlgorithm.py       |  43 +++---
 test/algorithm_test/AdjustedAlgorithm_test.py | 136 +++++++++++++-----
 2 files changed, 124 insertions(+), 55 deletions(-)

diff --git a/geomagio/algorithm/AdjustedAlgorithm.py b/geomagio/algorithm/AdjustedAlgorithm.py
index 601bad244..e21220c54 100644
--- a/geomagio/algorithm/AdjustedAlgorithm.py
+++ b/geomagio/algorithm/AdjustedAlgorithm.py
@@ -1,23 +1,22 @@
-"""Algorithm that converts from one geomagnetic coordinate system to a
-    related geographic coordinate system, by using transformations generated
-    from absolute, baseline measurements.
-"""
-from __future__ import absolute_import
+import sys
 
-from .Algorithm import Algorithm
 import json
 import numpy as np
 from obspy.core import Stream, Stats
-import sys
+
+from ..adjusted.AdjustedMatrix import AdjustedMatrix
+from .Algorithm import Algorithm
 
 
 class AdjustedAlgorithm(Algorithm):
-    """Adjusted Data Algorithm"""
+    """Algorithm that converts from one geomagnetic coordinate system to a
+    related geographic coordinate system, by using transformations generated
+    from absolute, baseline measurements.
+    """
 
     def __init__(
         self,
-        matrix=None,
-        pier_correction=None,
+        matrix: AdjustedMatrix = None,
         statefile=None,
         data_type=None,
         location=None,
@@ -33,7 +32,6 @@ class AdjustedAlgorithm(Algorithm):
         )
         # state variables
         self.matrix = matrix
-        self.pier_correction = pier_correction
         self.statefile = statefile
         self.data_type = data_type
         self.location = location
@@ -45,12 +43,12 @@ class AdjustedAlgorithm(Algorithm):
         """Load algorithm state from a file.
         File name is self.statefile.
         """
-        # Adjusted matrix defaults to identity matrix
-        matrix_size = len([c for c in self.get_input_channels() if c != "F"]) + 1
-        self.matrix = np.eye(matrix_size)
-        self.pier_correction = 0
+        pier_correction = 0
         if self.statefile is None:
             return
+        # Adjusted matrix defaults to identity matrix
+        matrix_size = len([c for c in self.get_input_channels() if c != "F"]) + 1
+        matrix = np.eye(matrix_size)
         data = None
         try:
             with open(self.statefile, "r") as f:
@@ -62,8 +60,9 @@ class AdjustedAlgorithm(Algorithm):
             return
         for row in range(matrix_size):
             for col in range(matrix_size):
-                self.matrix[row, col] = np.float64(data[f"M{row+1}{col+1}"])
-        self.pier_correction = np.float64(data["PC"])
+                matrix[row, col] = np.float64(data[f"M{row+1}{col+1}"])
+        pier_correction = np.float64(data["PC"])
+        self.matrix = AdjustedMatrix(matrix=matrix, pier_correction=pier_correction)
 
     def save_state(self):
         """Save algorithm state to a file.
@@ -71,12 +70,12 @@ class AdjustedAlgorithm(Algorithm):
         """
         if self.statefile is None:
             return
-        data = {"PC": self.pier_correction}
-        length = len(self.matrix[0, :])
+        data = {"PC": self.matrix.pier_correction}
+        length = len(self.matrix.matrix[0, :])
         for i in range(0, length):
             for j in range(0, length):
                 key = "M" + str(i + 1) + str(j + 1)
-                data[key] = self.matrix[i, j]
+                data[key] = self.matrix.matrix[i, j]
         with open(self.statefile, "w") as f:
             f.write(json.dumps(data))
 
@@ -134,7 +133,7 @@ class AdjustedAlgorithm(Algorithm):
             ]
             + [np.ones_like(stream[0].data)]
         )
-        adjusted = np.matmul(self.matrix, raws)
+        adjusted = np.matmul(self.matrix.matrix, raws)
         out = Stream(
             [
                 self.create_trace(
@@ -147,7 +146,7 @@ class AdjustedAlgorithm(Algorithm):
         )
         if "F" in inchannels and "F" in outchannels:
             f = stream.select(channel="F")[0]
-            out += self.create_trace("F", f.stats, f.data + self.pier_correction)
+            out += self.create_trace("F", f.stats, f.data + self.matrix.pier_correction)
         return out
 
     def can_produce_data(self, starttime, endtime, stream):
diff --git a/test/algorithm_test/AdjustedAlgorithm_test.py b/test/algorithm_test/AdjustedAlgorithm_test.py
index f37f93f78..d44ae09e6 100644
--- a/test/algorithm_test/AdjustedAlgorithm_test.py
+++ b/test/algorithm_test/AdjustedAlgorithm_test.py
@@ -1,3 +1,4 @@
+from geomagio.adjusted.AdjustedMatrix import AdjustedMatrix
 from geomagio.algorithm import AdjustedAlgorithm as adj
 import geomagio.iaga2002 as i2
 from numpy.testing import assert_almost_equal, assert_equal
@@ -8,19 +9,53 @@ def test_construct():
     # load adjusted data transform matrix and pier correction
     a = adj(statefile="etc/adjusted/adjbou_state_.json")
 
-    assert_almost_equal(actual=a.matrix[0, 0], desired=9.83427577e-01, decimal=6)
+    assert_almost_equal(actual=a.matrix.matrix[0, 0], desired=9.83427577e-01, decimal=6)
 
-    assert_equal(actual=a.pier_correction, desired=-22)
+    assert_equal(actual=a.matrix.pier_correction, desired=-22)
 
 
-def test_process_XYZF():
+def assert_streams_almost_equal(adjusted, expected, channels):
+    for channel in channels:
+        assert_almost_equal(
+            actual=adjusted.select(channel=channel)[0].data,
+            desired=expected.select(channel=channel)[0].data,
+            decimal=2,
+        )
+
+
+def test_process_XYZF_AdjustedMatrix():
     """algorithm_test.AdjustedAlgorithm_test.test_process()
 
     Check adjusted data processing versus files generated from
     original script
     """
-    # load adjusted data transform matrix and pier correction
-    a = adj(statefile="etc/adjusted/adjbou_state_.json")
+    # Initiate algorithm with AdjustedMatrix object
+    a = adj(
+        matrix=AdjustedMatrix(
+            matrix=[
+                [
+                    0.9834275767090617,
+                    -0.15473074200902157,
+                    0.027384986324932026,
+                    -1276.164681191976,
+                ],
+                [
+                    0.16680172992706568,
+                    0.987916201012128,
+                    -0.0049868332295851525,
+                    -0.8458192581350419,
+                ],
+                [
+                    -0.006725053082782385,
+                    -0.011809351484171948,
+                    0.9961869012493976,
+                    905.3800885796844,
+                ],
+                [0, 0, 0, 1],
+            ],
+            pier_correction=-22,
+        )
+    )
 
     # load boulder Jan 16 files from /etc/ directory
     with open("etc/adjusted/BOU201601vmin.min") as f:
@@ -31,34 +66,77 @@ def test_process_XYZF():
     # process hezf (raw) channels with loaded transform
     adjusted = a.process(raw)
 
-    # compare channels from adjusted and expected streams
-    assert_almost_equal(
-        actual=adjusted.select(channel="X")[0].data,
-        desired=expected.select(channel="X")[0].data,
-        decimal=2,
+    assert_streams_almost_equal(
+        adjusted=adjusted, expected=expected, channels=["X", "Y", "Z", "F"]
     )
-    assert_almost_equal(
-        actual=adjusted.select(channel="Y")[0].data,
-        desired=expected.select(channel="Y")[0].data,
-        decimal=2,
+
+
+def test_process_reverse_polarity_AdjustedMatrix():
+    """algorithm_test.AdjustedAlgorithm_test.test_process()
+
+    Check adjusted data processing versus files generated from
+    original script. Tests reverse polarity martix.
+    """
+    # Initiate algorithm with AdjustedMatrix object
+    a = adj(
+        matrix=AdjustedMatrix(
+            matrix=[
+                [-1, 0, 0],
+                [0, -1, 0],
+                [0, 0, 1],
+            ],
+            pier_correction=-22,
+        ),
+        inchannels=["H", "E"],
+        outchannels=["H", "E"],
     )
-    assert_almost_equal(
-        actual=adjusted.select(channel="Z")[0].data,
-        desired=expected.select(channel="Z")[0].data,
-        decimal=2,
+
+    # load boulder May 20 files from /etc/ directory
+    with open("etc/adjusted/BOU202005vmin.min") as f:
+        raw = i2.IAGA2002Factory().parse_string(f.read())
+    with open("etc/adjusted/BOU202005adj.min") as f:
+        expected = i2.IAGA2002Factory().parse_string(f.read())
+
+    # process he(raw) channels with loaded transform
+    adjusted = a.process(raw)
+
+    assert_streams_almost_equal(
+        adjusted=adjusted, expected=expected, channels=["H", "E"]
     )
-    assert_almost_equal(
-        actual=adjusted.select(channel="F")[0].data,
-        desired=expected.select(channel="F")[0].data,
-        decimal=2,
+
+
+def test_process_XYZF_statefile():
+    """algorithm_test.AdjustedAlgorithm_test.test_process()
+
+    Check adjusted data processing versus files generated from
+    original script
+
+    Uses statefile to generate AdjustedMatrix
+    """
+    # load adjusted data transform matrix and pier correction
+    a = adj(statefile="etc/adjusted/adjbou_state_.json")
+
+    # load boulder Jan 16 files from /etc/ directory
+    with open("etc/adjusted/BOU201601vmin.min") as f:
+        raw = i2.IAGA2002Factory().parse_string(f.read())
+    with open("etc/adjusted/BOU201601adj.min") as f:
+        expected = i2.IAGA2002Factory().parse_string(f.read())
+
+    # process hezf (raw) channels with loaded transform
+    adjusted = a.process(raw)
+
+    assert_streams_almost_equal(
+        adjusted=adjusted, expected=expected, channels=["X", "Y", "Z", "F"]
     )
 
 
-def test_process_reverse_polarity():
+def test_process_reverse_polarity_statefile():
     """algorithm_test.AdjustedAlgorithm_test.test_process()
 
     Check adjusted data processing versus files generated from
     original script. Tests reverse polarity martix.
+
+    Uses statefile to generate AdjustedMatrix
     """
     # load adjusted data transform matrix and pier correction
     a = adj(
@@ -76,14 +154,6 @@ def test_process_reverse_polarity():
     # process he(raw) channels with loaded transform
     adjusted = a.process(raw)
 
-    # compare channels from adjusted and expected streams
-    assert_almost_equal(
-        actual=adjusted.select(channel="H")[0].data,
-        desired=expected.select(channel="H")[0].data,
-        decimal=2,
-    )
-    assert_almost_equal(
-        actual=adjusted.select(channel="E")[0].data,
-        desired=expected.select(channel="E")[0].data,
-        decimal=2,
+    assert_streams_almost_equal(
+        adjusted=adjusted, expected=expected, channels=["H", "E"]
     )
-- 
GitLab