From fca6f671b15a1b09241a5b53a5f4994d7af2f712 Mon Sep 17 00:00:00 2001
From: pcain-usgs <pcain@usgs.gov>
Date: Thu, 28 Jan 2021 09:15:49 -0700
Subject: [PATCH] Hold X, Y, Z, and dF mae/std

---
 geomagio/adjusted/AdjustedMatrix.py |  9 ++++++--
 geomagio/adjusted/Affine.py         | 35 ++++++++++++++++++++++++++++-
 geomagio/adjusted/Metric.py         |  7 ++++++
 test/adjusted_test/adjusted_test.py |  8 -------
 4 files changed, 48 insertions(+), 11 deletions(-)
 create mode 100644 geomagio/adjusted/Metric.py

diff --git a/geomagio/adjusted/AdjustedMatrix.py b/geomagio/adjusted/AdjustedMatrix.py
index 5913d2a5a..c8b8dcc64 100644
--- a/geomagio/adjusted/AdjustedMatrix.py
+++ b/geomagio/adjusted/AdjustedMatrix.py
@@ -1,8 +1,13 @@
 from obspy import UTCDateTime
 from pydantic import BaseModel
-from typing import Optional, Any, List
+from typing import (
+    Any,
+    List,
+    Optional,
+)
 
 from .. import pydantic_utcdatetime
+from .Metric import Metric
 
 
 class AdjustedMatrix(BaseModel):
@@ -20,6 +25,6 @@ class AdjustedMatrix(BaseModel):
 
     matrix: Any
     pier_correction: float
-    metrics: List[float] = [0.0, 0.0]
+    metrics: Optional[List[Metric]] = None
     starttime: Optional[UTCDateTime] = None
     endtime: Optional[UTCDateTime] = None
diff --git a/geomagio/adjusted/Affine.py b/geomagio/adjusted/Affine.py
index ffd395afd..7c53cf5e0 100644
--- a/geomagio/adjusted/Affine.py
+++ b/geomagio/adjusted/Affine.py
@@ -5,7 +5,9 @@ from pydantic import BaseModel
 from typing import List, Optional, Tuple
 
 from .AdjustedMatrix import AdjustedMatrix
+from .. import ChannelConverter
 from .. import pydantic_utcdatetime
+from .Metric import Metric
 from ..residual import Reading
 from .Transform import Transform, TranslateOrigins, RotationTranslationXY
 
@@ -273,13 +275,44 @@ class Affine(BaseModel):
 
         # compose affine transform matrices using reverse ordered matrices
         M_composed = reduce(np.dot, np.flipud(Ms))
+        absolutes = np.vstack((absolutes, np.ones_like(absolutes[0])))
+        ordinates = np.vstack((ordinates, np.ones_like(ordinates[0])))
         pier_correction = np.average(
             [reading.pier_correction for reading in readings], weights=weights
         )
+        std, mae, mae_df, std_df = self.compute_metrics(
+            absolutes=absolutes, ordinates=ordinates, matrix=M_composed
+        )
 
         return AdjustedMatrix(
-            matrix=M_composed, pier_correction=pier_correction, metrics=metrics
+            matrix=M_composed,
+            pier_correction=pier_correction,
+            metrics=[
+                Metric(element="X", mae=mae[0], std=std[0]),
+                Metric(element="Y", mae=mae[1], std=std[1]),
+                Metric(element="Z", mae=mae[2], std=std[2]),
+                Metric(element="dF", mae=mae_df, std=std_df),
+            ],
+        )
+
+    def compute_metrics(
+        self, absolutes: List[float], ordinates: List[float], matrix: List[float]
+    ) -> Tuple[List[float], List[float], float, float]:
+        # expected values are absolutes
+        predicted = matrix @ ordinates
+        # mean absolute erros and standard deviations ignore the 4th row comprison, which is trivial
+        std = np.nanstd(predicted - absolutes, axis=1)[0:3]
+        mae = abs(np.nanmean(predicted - absolutes, axis=1))[0:3]
+        expected_f = ChannelConverter.get_computed_f_using_squares(
+            absolutes[0], absolutes[1], absolutes[2]
+        )
+        predicted_f = ChannelConverter.get_computed_f_using_squares(
+            predicted[0], predicted[1], predicted[2]
         )
+        df = ChannelConverter.get_deltaf(expected_f, predicted_f)
+        std_df = abs(np.nanstd(df))
+        mae_df = abs(np.nanmean(df))
+        return list(std), list(mae), std_df, mae_df
 
     def get_weights(
         self,
diff --git a/geomagio/adjusted/Metric.py b/geomagio/adjusted/Metric.py
new file mode 100644
index 000000000..19ba74ada
--- /dev/null
+++ b/geomagio/adjusted/Metric.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class Metric(BaseModel):
+    element: str
+    mae: float
+    std: float
diff --git a/test/adjusted_test/adjusted_test.py b/test/adjusted_test/adjusted_test.py
index eaebd4471..26047cccc 100644
--- a/test/adjusted_test/adjusted_test.py
+++ b/test/adjusted_test/adjusted_test.py
@@ -244,7 +244,6 @@ def test_BOU201911202001_short_causal():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -276,7 +275,6 @@ def test_BOU201911202001_short_acausal():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -312,7 +310,6 @@ def test_BOU201911202001_infinite_weekly():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -342,7 +339,6 @@ def test_BOU201911202001_infinite_one_interval():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), 1)
 
 
@@ -379,7 +375,6 @@ def test_CMO2015_causal():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -417,7 +412,6 @@ def test_CMO2015_acausal():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -459,7 +453,6 @@ def test_CMO2015_infinite_weekly():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
     assert_equal(len(matrices), ((endtime - starttime) // update_interval) + 1)
 
 
@@ -497,6 +490,5 @@ def test_CMO2015_infinite_one_interval():
             decimal=3,
             err_msg=f"Matrix {i} not equal",
         )
-        assert_array_less(metrics[i], 5.0)
 
     assert_equal(len(matrices), 1)
-- 
GitLab