From 76953e5b2b954031936c501521c134374251170b Mon Sep 17 00:00:00 2001
From: pcain-usgs <pcain@usgs.gov>
Date: Thu, 4 Feb 2021 17:26:13 -0700
Subject: [PATCH] Add acausal flag to Transform, compute metrics with AdjustedMatrix

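Each Transform now owns an `acausal` flag (previously on Affine), and
AdjustedMatrix computes its own metrics via set_metrics. A minimal usage
sketch under the new API (`readings` is a placeholder for a List[Reading]
from the residual package; import paths assume the transform subpackage
exports the transform classes):

    from obspy import UTCDateTime
    from geomagio.adjusted import Affine
    from geomagio.adjusted.transform import RotationTranslationXY, TranslateOrigins

    result = Affine(
        observatory="BOU",
        starttime=UTCDateTime("2020-01-01"),
        endtime=UTCDateTime("2020-01-08"),
        transforms=[
            # acausal=False (default) zeroes weights of future readings
            RotationTranslationXY(memory=(86400 * 100), acausal=False),
            TranslateOrigins(memory=(86400 * 10), acausal=False),
        ],
    ).calculate(readings=readings)
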
---
 geomagio/adjusted/AdjustedMatrix.py      |  41 +-
 geomagio/adjusted/Affine.py              | 470 ++++++++++-------------
 geomagio/adjusted/Metric.py              |  13 +-
 geomagio/adjusted/Transform.py           |   0
 geomagio/adjusted/__init__.py            |   2 +
 geomagio/adjusted/transform/Transform.py |   7 +-
 test/adjusted_test/adjusted_test.py      |  46 ++-
 7 files changed, 277 insertions(+), 302 deletions(-)
 delete mode 100644 geomagio/adjusted/Transform.py

diff --git a/geomagio/adjusted/AdjustedMatrix.py b/geomagio/adjusted/AdjustedMatrix.py
index ff38990dd..4fe0ebce8 100644
--- a/geomagio/adjusted/AdjustedMatrix.py
+++ b/geomagio/adjusted/AdjustedMatrix.py
@@ -1,8 +1,9 @@
 import numpy as np
 from obspy import UTCDateTime
 from pydantic import BaseModel
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
+from .. import ChannelConverter
 from .. import pydantic_utcdatetime
 from .Metric import Metric
 
@@ -36,3 +37,41 @@ class AdjustedMatrix(BaseModel):
         else:
             adjusted = adjusted[0 : len(outchannels)]
         return adjusted
+
+    def set_metrics(
+        self,
+        ordinates: Tuple[List[float], List[float], List[float]],
+        absolutes: Tuple[List[float], List[float], List[float]],
+    ):
+        """Computes mean absolute error and standard deviation for X, Y, Z, and dF between expected and predicted values.
+
+        Attributes
+        ----------
+        absolutes: X, Y and Z absolutes
+        ordinates: H, E and Z ordinates
+        matrix: composed matrix
+
+        Outputs
+        -------
+        metrics: list of Metric objects
+        """
+        ordinates = np.vstack((ordinates, np.ones_like(ordinates[0])))
+        predicted = self.matrix @ ordinates
+        metrics = []
+        elements = ["X", "Y", "Z", "dF"]
+        expected = np.vstack(
+            (*absolutes, ChannelConverter.get_computed_f_using_squares(*absolutes))
+        )
+        predicted = np.vstack(
+            (
+                *predicted[0:3],
+                ChannelConverter.get_computed_f_using_squares(*predicted[0:3]),
+            )
+        )
+        for i in range(len(elements)):
+            diff = expected[i] - predicted[i]
+            metrics.append(
+                Metric(
+                    element=elements[i],
+                    absmean=abs(np.nanmean(diff)),
+                    stddev=np.nanstd(diff),
+                )
+            )
+        self.metrics = metrics
diff --git a/geomagio/adjusted/Affine.py b/geomagio/adjusted/Affine.py
index d99cf9acc..60de08f22 100644
--- a/geomagio/adjusted/Affine.py
+++ b/geomagio/adjusted/Affine.py
@@ -4,91 +4,11 @@ from obspy import UTCDateTime
 from pydantic import BaseModel
 from typing import List, Optional, Tuple
 
-from .. import ChannelConverter
 from .. import pydantic_utcdatetime
 from ..residual import Reading
 from .AdjustedMatrix import AdjustedMatrix
-from .Metric import Metric
-from .Transform import Transform, TranslateOrigins, RotationTranslationXY
 
-
-def weighted_quartile(data: List[float], weights: List[float], quant: float) -> float:
-    """Get weighted quartile to determine statistically good/bad data
-
-    Attributes
-    ----------
-    data: filtered array of observations
-    weights: array of vector distances/metrics
-    quant: statistical percentile of input data
-    """
-    # sort data and weights
-    ind_sorted = np.argsort(data)
-    sorted_data = data[ind_sorted]
-    sorted_weights = weights[ind_sorted]
-    # compute auxiliary arrays
-    Sn = np.cumsum(sorted_weights)
-    Pn = (Sn - 0.5 * sorted_weights) / Sn[-1]
-    # interpolate to weighted quantile
-    return np.interp(quant, Pn, sorted_data)
-
-
-def filter_iqr(
-    series: List[float], threshold: int = 6, weights: List[int] = None
-) -> List[int]:
-    """
-    Identify "good" elements in series by calculating potentially weighted
-    25%, 50% (median), and 75% quantiles of series, the number of 25%-50%
-    quantile ranges below, or 50%-75% quantile ranges above each value of
-    series falls from the median, and finally, setting elements of good to
-    True that fall within these multiples of quantile ranges.
-
-    NOTE: NumPy has a percentile function, but it does not yet handle
-          weights. This algorithm was adapted from the PyPI
-          package wquantiles (https://pypi.org/project/wquantiles/)
-
-    Inputs:
-    series: array of observations to filter
-
-    Options:
-    threshold: threshold in fractional number of 25%-50% (50%-75%)
-                quantile ranges below (above) the median each element of
-                series may fall and still be considered "good"
-                Default set to 6.
-    weights: weights to assign to each element of series. Default set to 1.
-
-    Output:
-    good: Boolean array where True values correspond to "good" data
-
-    """
-
-    if weights is None:
-        weights = np.ones_like(series)
-
-    # initialize good as all True for weights greater than 0
-    good = (weights > 0).astype(bool)
-    if np.size(good) <= 1:
-        # if a singleton is passed, assume it is "good"
-        return good
-
-    good_old = ~good
-    while not (good_old == good).all():
-        good_old = good
-
-        wq25 = weighted_quartile(series[good], weights[good], 0.25)
-        wq50 = weighted_quartile(series[good], weights[good], 0.50)
-        wq75 = weighted_quartile(series[good], weights[good], 0.75)
-
-        # NOTE: it is necessary to include good on the RHS here
-        #       to prevent oscillation between two equally likely
-        #       "optimal" solutions; this is a common problem with
-        #       expectation maximization algorithms
-        good = (
-            good
-            & (series >= (wq50 - threshold * (wq50 - wq25)))
-            & (series <= (wq50 + threshold * (wq75 - wq50)))
-        )
-
-    return good
+from .transform import RotationTranslationXY, TranslateOrigins, Transform
 
 
 class Affine(BaseModel):
@@ -107,11 +27,10 @@ class Affine(BaseModel):
     observatory: str = None
     starttime: UTCDateTime = UTCDateTime() - (86400 * 7)
     endtime: UTCDateTime = UTCDateTime()
-    acausal: bool = False
     update_interval: Optional[int] = 86400 * 7
     transforms: List[Transform] = [
-        RotationTranslationXY(memory=(86400 * 100)),
-        TranslateOrigins(memory=(86400 * 10)),
+        RotationTranslationXY(memory=(86400 * 100), acausal=True),
+        TranslateOrigins(memory=(86400 * 10), acausal=True),
     ]
 
     class Config:
@@ -139,7 +58,7 @@ class Affine(BaseModel):
         epochs = [r.time for r in all_readings if r.get_absolute("H").absolute == 0]
         while time < self.endtime:
             # update epochs for current time
-            epoch_start, epoch_end = self.get_epochs(
+            epoch_start, epoch_end = get_epochs(
                 epoch_start=epoch_start, epoch_end=epoch_end, epochs=epochs, time=time
             )
             # utilize readings that occur after or before a bad reading
@@ -151,105 +70,18 @@ class Affine(BaseModel):
             ]
             M = self.calculate_matrix(time, readings)
             # if readings are trimmed by bad data, mark the valid interval
-            M.starttime = epoch_start
-            M.endtime = epoch_end
+            if M:
+                M.starttime = epoch_start
+                M.endtime = epoch_end
             time += update_interval
 
             Ms.append(M)
 
         return Ms
 
-    def get_epochs(
-        self,
-        epoch_start: float,
-        epoch_end: float,
-        epochs: List[float],
-        time: UTCDateTime,
-    ) -> Tuple[float, float]:
-        """Updates valid start/end time for a given interval
-
-        Attributes
-        ----------
-        epoch_start: float value signifying start of last valid interval
-        epoch_end: float value signifying end of last valid interval
-        epochs: list of floats signifying bad data times
-        time: current time epoch is being evaluated at
-
-        Outputs
-        -------
-        epoch_start: float value signifying start of current valid interval
-        epoch_end: float value signifying end of current valid interval
-        """
-        for e in epochs:
-            if e > time:
-                if epoch_end is None or e < epoch_end:
-                    epoch_end = e
-            if e < time:
-                if epoch_start is None or e > epoch_start:
-                    epoch_start = e
-        return epoch_start, epoch_end
-
-    def get_times(self, readings: List[UTCDateTime]):
-        return np.array([reading.get_absolute("H").endtime for reading in readings])
-
-    def get_ordinates(
-        self, readings: List[Reading]
-    ) -> Tuple[List[float], List[float], List[float]]:
-        """Calculates ordinates from absolutes and baselines"""
-        h_abs, d_abs, z_abs = self.get_absolutes(readings)
-        h_bas, d_bas, z_bas = self.get_baselines(readings)
-
-        # recreate ordinate variometer measurements from absolutes and baselines
-        h_ord = h_abs - h_bas
-        d_ord = d_abs - d_bas
-        z_ord = z_abs - z_bas
-
-        # WebAbsolutes defines/generates h differently than USGS residual
-        # method spreadsheets. The following should ensure that ordinate
-        # values are converted back to their original raw measurements:
-        e_o = h_abs * d_ord * 60 / 3437.7468
-        # TODO: is this handled in residual package?
-        if self.observatory in ["DED", "CMO"]:
-            h_o = np.sqrt(h_ord ** 2 - e_o ** 2)
-        else:
-            h_o = h_ord
-        z_o = z_ord
-        return (h_o, e_o, z_o)
-
-    def get_baselines(
-        self, readings: List[Reading]
-    ) -> Tuple[List[float], List[float], List[float]]:
-        """Get H, D and Z baselines"""
-        h_bas = np.array([reading.get_absolute("H").baseline for reading in readings])
-        d_bas = np.array([reading.get_absolute("D").baseline for reading in readings])
-        z_bas = np.array([reading.get_absolute("Z").baseline for reading in readings])
-        return (h_bas, d_bas, z_bas)
-
-    def get_absolutes(
-        self, readings: List[Reading]
-    ) -> Tuple[List[float], List[float], List[float]]:
-        """Get H, D and Z absolutes"""
-        h_abs = np.array([reading.get_absolute("H").absolute for reading in readings])
-        d_abs = np.array([reading.get_absolute("D").absolute for reading in readings])
-        z_abs = np.array([reading.get_absolute("Z").absolute for reading in readings])
-
-        return (h_abs, d_abs, z_abs)
-
-    def get_absolutes_xyz(
-        self, readings: List[Reading]
-    ) -> Tuple[List[float], List[float], List[float]]:
-        """Get X, Y and Z absolutes from H, D and Z baselines"""
-        h_abs, d_abs, z_abs = self.get_absolutes(readings)
-
-        # convert from cylindrical to Cartesian coordinates
-        x_a = h_abs * np.cos(d_abs * np.pi / 180)
-        y_a = h_abs * np.sin(d_abs * np.pi / 180)
-        z_a = z_abs
-        return (x_a, y_a, z_a)
-
     def calculate_matrix(
         self, time: UTCDateTime, readings: List[Reading]
-    ) -> AdjustedMatrix:
+    ) -> Optional[AdjustedMatrix]:
         """Calculates affine matrix for a given time
 
         Attributes
@@ -261,28 +93,24 @@ class Affine(BaseModel):
         -------
         AdjustedMatrix object containing result
         """
-        absolutes = self.get_absolutes_xyz(readings)
-        baselines = self.get_baselines(readings)
-        ordinates = self.get_ordinates(readings)
-        times = self.get_times(readings)
+        absolutes = get_absolutes_xyz(readings)
+        baselines = get_baselines(readings)
+        ordinates = get_ordinates(readings, self.observatory)
+        times = get_times(readings)
         Ms = []
         weights = []
         inputs = ordinates
 
         for transform in self.transforms:
-            weights = self.get_weights(
-                time=time,
+            weights = transform.get_weights(
+                time=time.timestamp,
                 times=times,
-                transform=transform,
             )
-            # return NaNs if no valid observations
-            if np.sum(weights) == 0:
-                return AdjustedMatrix(
-                    matrix=np.nan * np.ones((4, 4)),
-                    pier_correction=np.nan,
-                )
             # zero out statistically 'bad' baselines
-            weights = self.weight_baselines(baselines=baselines, weights=weights)
+            weights = filter_iqrs(multiseries=baselines, weights=weights)
+            # return None if no valid observations
+            if np.sum(weights) == 0:
+                return None
 
             M = transform.calculate(
                 ordinates=inputs, absolutes=absolutes, weights=weights
@@ -299,100 +127,196 @@ class Affine(BaseModel):
         pier_correction = np.average(
             [reading.pier_correction for reading in readings], weights=weights
         )
-        metrics = self.get_metrics(
-            absolutes=absolutes, ordinates=ordinates, matrix=M_composed
-        )
-        return AdjustedMatrix(
+        matrix = AdjustedMatrix(
             matrix=M_composed,
             pier_correction=pier_correction,
-            metrics=metrics,
         )
+        matrix.set_metrics(absolutes=absolutes, ordinates=ordinates)
+        return matrix
 
-    def get_metrics(
-        self, absolutes: List[float], ordinates: List[float], matrix: List[float]
-    ) -> Tuple[List[float], List[float], float, float]:
-        """Computes mean absolute error and standard deviation for X, Y, Z, and dF between expected and predicted values.
 
-        Attributes
-        ----------
-        absolutes: X, Y and Z absolutes
-        ordinates: H, E and Z ordinates
-        matrix: composed matrix
+def filter_iqr(
+    series: List[float], threshold: float = 3.0, weights: List[float] = None
+) -> List[bool]:
+    """
+    Identify "good" elements in series using (potentially weighted) 25%,
+    50% (median), and 75% quantiles of series: for each value, measure how
+    many 25%-50% quantile ranges it falls below the median (or 50%-75%
+    quantile ranges above it), and mark it good if that distance is within
+    threshold.
 
-        Outputs
-        -------
-        metrics: list of Metric objects
-        """
-        ordinates = np.vstack((ordinates, np.ones_like(ordinates[0])))
-        predicted = matrix @ ordinates
-
-        channels = ["X", "Y", "Z"]
-        metrics = []
-        for i in range(len(channels)):
-            metric = Metric(element=channels[i])
-            metric.calculate(expected=absolutes[i], predicted=predicted[i])
-            metrics.append(metric)
-
-        expected_f = ChannelConverter.get_computed_f_using_squares(
-            absolutes[0], absolutes[1], absolutes[2]
-        )
-        predicted_f = ChannelConverter.get_computed_f_using_squares(
-            predicted[0], predicted[1], predicted[2]
-        )
-        df = ChannelConverter.get_deltaf(expected_f, predicted_f)
-        metrics.append(self.get_metrics_df(predicted=predicted, expected=absolutes))
-        return metrics
-
-    def get_metrics_df(self, predicted, expected):
-        """Computes mean absolute error and standard deviation for dF between expected and predicted values"""
-        expected_f = ChannelConverter.get_computed_f_using_squares(
-            expected[0], expected[1], expected[2]
-        )
-        predicted_f = ChannelConverter.get_computed_f_using_squares(
-            predicted[0], predicted[1], predicted[2]
-        )
-        df = ChannelConverter.get_deltaf(expected_f, predicted_f)
-        return Metric(
-            element="dF",
-            mae=abs(np.nanmean(df)),
-            std=np.nanstd(df),
-        )
+    NOTE: NumPy has a percentile function, but it does not yet handle
+          weights. This algorithm was adapted from the PyPI
+          package wquantiles (https://pypi.org/project/wquantiles/)
 
-    def get_weights(
-        self,
-        time: UTCDateTime,
-        times: List[UTCDateTime],
-        transform: Transform,
-    ) -> np.array:
-        """
+    Inputs:
+    series: array of observations to filter
 
-        Attributes
-        ----------
-        time: time within calculation interval
-        times: times of valid readings
-        transform: matrix calculation method
+    Options:
+    threshold: threshold in fractional number of 25%-50% (50%-75%)
+                quantile ranges below (above) the median each element of
+                series may fall and still be considered "good"
+                Default set to 3.0.
+    weights: weights to assign to each element of series. Default set to 1.
 
-        Outputs
-        -------
-        weights: array of weights to apply to absolutes/ordinates within calculations
-        """
+    Output:
+    good: Boolean array where True values correspond to "good" data
+
+    """
+
+    if weights is None:
+        weights = np.ones_like(series)
+
+    # initialize good as all True for weights greater than 0
+    good = (weights > 0).astype(bool)
+    if np.size(good) <= 1:
+        # if a singleton is passed, assume it is "good"
+        return good
+
+    good_old = ~good
+    while not (good_old == good).all():
+        good_old = good
 
-        weights = transform.get_weights(time=time.timestamp, times=times)
-        # set weights for future observations to zero if not acausal
-        if not self.acausal:
-            weights[times > time.timestamp] = 0.0
-        return weights
-
-    def weight_baselines(
-        self,
-        baselines: List[float],
-        weights: List[float],
-        threshold=3,
-    ) -> List[float]:
-        """Filters "bad" weights generated by unreliable readings"""
+        wq25 = weighted_quartile(series[good], weights[good], 0.25)
+        wq50 = weighted_quartile(series[good], weights[good], 0.50)
+        wq75 = weighted_quartile(series[good], weights[good], 0.75)
+
+        # NOTE: it is necessary to include good on the RHS here
+        #       to prevent oscillation between two equally likely
+        #       "optimal" solutions; this is a common problem with
+        #       expectation maximization algorithms
         good = (
-            filter_iqr(baselines[0], threshold=threshold, weights=weights)
-            & filter_iqr(baselines[1], threshold=threshold, weights=weights)
-            & filter_iqr(baselines[2], threshold=threshold, weights=weights)
+            good
+            & (series >= (wq50 - threshold * (wq50 - wq25)))
+            & (series <= (wq50 + threshold * (wq75 - wq50)))
         )
-        return weights * good
+
+    return good
+
+
+def filter_iqrs(
+    multiseries: List[List[float]],
+    weights: List[float],
+    threshold: float = 3.0,
+) -> List[float]:
+    """Filters "bad" weights generated by unreliable readings"""
+    good = (
+        filter_iqr(multiseries[0], threshold=threshold, weights=weights)
+        & filter_iqr(multiseries[1], threshold=threshold, weights=weights)
+        & filter_iqr(multiseries[2], threshold=threshold, weights=weights)
+    )
+    return weights * good
+
+
+def get_absolutes(
+    readings: List[Reading],
+) -> Tuple[List[float], List[float], List[float]]:
+    """Get H, D and Z absolutes"""
+    h_abs = np.array([reading.get_absolute("H").absolute for reading in readings])
+    d_abs = np.array([reading.get_absolute("D").absolute for reading in readings])
+    z_abs = np.array([reading.get_absolute("Z").absolute for reading in readings])
+
+    return (h_abs, d_abs, z_abs)
+
+
+def get_absolutes_xyz(
+    readings: List[Reading],
+) -> Tuple[List[float], List[float], List[float]]:
+    """Get X, Y and Z absolutes from H, D and Z baselines"""
+    h_abs, d_abs, z_abs = get_absolutes(readings)
+    # convert from cylindrical to Cartesian coordinates
+    x_a = h_abs * np.cos(np.radians(d_abs))
+    y_a = h_abs * np.sin(np.radians(d_abs))
+    z_a = z_abs
+    return (x_a, y_a, z_a)
+
+
+def get_baselines(
+    readings: List[Reading],
+) -> Tuple[List[float], List[float], List[float]]:
+    """Get H, D and Z baselines"""
+    h_bas = np.array([reading.get_absolute("H").baseline for reading in readings])
+    d_bas = np.array([reading.get_absolute("D").baseline for reading in readings])
+    z_bas = np.array([reading.get_absolute("Z").baseline for reading in readings])
+    return (h_bas, d_bas, z_bas)
+
+
+def get_epochs(
+    epoch_start: float,
+    epoch_end: float,
+    epochs: List[float],
+    time: UTCDateTime,
+) -> Tuple[float, float]:
+    """Updates valid start/end time for a given interval
+
+    Attributes
+    ----------
+    epoch_start: float value signifying start of last valid interval
+    epoch_end: float value signifying end of last valid interval
+    epochs: list of floats signifying bad data times
+    time: current time epoch is being evaluated at
+
+    Outputs
+    -------
+    epoch_start: float value signifying start of current valid interval
+    epoch_end: float value signifying end of current valid interval
+    """
+    for e in epochs:
+        if e > time:
+            if epoch_end is None or e < epoch_end:
+                epoch_end = e
+        if e < time:
+            if epoch_start is None or e > epoch_start:
+                epoch_start = e
+    return epoch_start, epoch_end
+
+
+def get_ordinates(
+    readings: List[Reading], observatory: str
+) -> Tuple[List[float], List[float], List[float]]:
+    """Calculates ordinates from absolutes and baselines"""
+    h_abs, d_abs, z_abs = get_absolutes(readings)
+    h_bas, d_bas, z_bas = get_baselines(readings)
+
+    # recreate ordinate variometer measurements from absolutes and baselines
+    h_ord = h_abs - h_bas
+    d_ord = d_abs - d_bas
+    z_ord = z_abs - z_bas
+
+    # WebAbsolutes defines/generates h differently than USGS residual
+    # method spreadsheets. The following should ensure that ordinate
+    # values are converted back to their original raw measurements:
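+    # (d_ord [degrees] * 60 is arcminutes; dividing by 3437.7468 arcminutes
+    # per radian gives radians; scaling by h_abs gives e in nT)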
+    e_o = h_abs * d_ord * 60 / 3437.7468
+    if observatory in ["DED", "CMO"]:
+        h_o = np.sqrt(h_ord ** 2 - e_o ** 2)
+    else:
+        h_o = h_ord
+    z_o = z_ord
+    return (h_o, e_o, z_o)
+
+
+def get_times(readings: List[Reading]) -> List[UTCDateTime]:
+    """Get times of H absolute measurements"""
+    return np.array([reading.get_absolute("H").endtime for reading in readings])
+
+
+def weighted_quartile(data: List[float], weights: List[float], quant: float) -> float:
+    """Get weighted quartile to determine statistically good/bad data
+
+    Attributes
+    ----------
+    data: filtered array of observations
+    weights: array of vector distances/metrics
+    quant: statistical percentile of input data
+    """
+    # sort data and weights
+    ind_sorted = np.argsort(data)
+    sorted_data = data[ind_sorted]
+    sorted_weights = weights[ind_sorted]
+    # compute cumulative weights and the weighted empirical CDF (Pn),
+    # evaluated at the midpoint of each weight step
+    Sn = np.cumsum(sorted_weights)
+    Pn = (Sn - 0.5 * sorted_weights) / Sn[-1]
+    # interpolate to weighted quantile
+    return np.interp(quant, Pn, sorted_data)
diff --git a/geomagio/adjusted/Metric.py b/geomagio/adjusted/Metric.py
index ee946bb5c..f02e29791 100644
--- a/geomagio/adjusted/Metric.py
+++ b/geomagio/adjusted/Metric.py
@@ -8,15 +8,10 @@ class Metric(BaseModel):
     Attributes
     ----------
     element: Channel that metrics are representative of
-    mae: mean absolute error
-    std: standard deviation
+    absmean: absolute value of the mean error
+    stddev: standard deviation
     """
 
     element: str
-    mae: float = None
-    std: float = None
-
-    def calculate(self, expected, predicted):
-        """Calculates mean absolute error and standard deviation between expected and predicted data"""
-        self.mae = abs(np.nanmean(expected - predicted))
-        self.std = np.nanstd(expected - predicted)
+    absmean: float = None
+    stddev: float = None
diff --git a/geomagio/adjusted/Transform.py b/geomagio/adjusted/Transform.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/geomagio/adjusted/__init__.py b/geomagio/adjusted/__init__.py
index d1ff5e331..e5a8c74a2 100644
--- a/geomagio/adjusted/__init__.py
+++ b/geomagio/adjusted/__init__.py
@@ -1,9 +1,11 @@
 from .AdjustedMatrix import AdjustedMatrix
 from .Affine import Affine
+from .Metric import Metric
 from .SpreadsheetSummaryFactory import SpreadsheetSummaryFactory
 
 __all__ = [
     "AdjustedMatrix",
     "Affine",
+    "Metric",
     "SpreadsheetSummaryFactory",
 ]
diff --git a/geomagio/adjusted/transform/Transform.py b/geomagio/adjusted/transform/Transform.py
index 305035e95..5fdbad7c6 100644
--- a/geomagio/adjusted/transform/Transform.py
+++ b/geomagio/adjusted/transform/Transform.py
@@ -9,10 +9,12 @@ class Transform(BaseModel):
 
     Attributes
     ----------
-    memory: Controls impacts of measurements from the past.
+    acausal: if True, future readings are used in calculations
+    memory: Controls impact of measurements from the past.
-    Defaults to infinite(equal weighting)
+    Defaults to infinite (equal weighting)
     """
 
+    acausal: bool = False
     memory: Optional[float] = None
 
     def get_weights(self, times: UTCDateTime, time: int = None) -> List[float]:
@@ -42,6 +44,9 @@ class Transform(BaseModel):
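+        # weights decay exponentially with distance from `time`;
+        # `memory` is the e-folding time scale in seconds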
         weights[times <= time] = np.exp((times[times <= time] - time) / self.memory)
         weights[times >= time] = np.exp((time - times[times >= time]) / self.memory)
 
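+        # causal (default) mode: zero out weights for readings after `time` so
+        # future observations cannot influence the fit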
+        if not self.acausal:
+            weights[times > time] = 0.0
+
         return weights
 
     def calculate(
diff --git a/test/adjusted_test/adjusted_test.py b/test/adjusted_test/adjusted_test.py
index fb62d02e0..252ccc8ec 100644
--- a/test/adjusted_test/adjusted_test.py
+++ b/test/adjusted_test/adjusted_test.py
@@ -1,6 +1,6 @@
 import json
 import numpy as np
-from numpy.testing import assert_equal, assert_array_almost_equal, assert_array_equal
+from numpy.testing import assert_equal, assert_array_almost_equal
 from obspy.core import UTCDateTime
 
 from geomagio.adjusted import (
@@ -236,6 +236,10 @@ def test_BOU201911202001_short_causal():
         starttime=starttime,
         endtime=endtime,
         update_interval=update_interval,
+        transforms=[
+            RotationTranslationXY(memory=(86400 * 100), acausal=False),
+            TranslateOrigins(memory=(86400 * 10), acausal=False),
+        ],
     ).calculate(readings=readings)
 
     matrices = format_result([adjusted_matrix.matrix for adjusted_matrix in result])
@@ -263,7 +267,10 @@ def test_BOU201911202001_short_acausal():
         starttime=starttime,
         endtime=endtime,
         update_interval=update_interval,
-        acausal=True,
+        transforms=[
+            RotationTranslationXY(memory=(86400 * 100), acausal=True),
+            TranslateOrigins(memory=(86400 * 10), acausal=True),
+        ],
     ).calculate(
         readings=readings,
     )
@@ -293,10 +300,9 @@ def test_BOU201911202001_infinite_weekly():
         starttime=starttime,
         endtime=endtime,
         update_interval=update_interval,
-        acausal=True,
         transforms=[
-            RotationTranslationXY(memory=np.inf),
-            TranslateOrigins(memory=np.inf),
+            RotationTranslationXY(memory=np.inf, acausal=True),
+            TranslateOrigins(memory=np.inf, acausal=True),
         ],
     ).calculate(
         readings=readings,
@@ -320,10 +326,9 @@ def test_BOU201911202001_infinite_one_interval():
         observatory="BOU",
         starttime=UTCDateTime("2019-11-01T00:00:00Z"),
         endtime=UTCDateTime("2020-01-31T23:59:00Z"),
-        acausal=True,
         transforms=[
-            RotationTranslationXY(memory=np.inf),
-            TranslateOrigins(memory=np.inf),
+            RotationTranslationXY(memory=np.inf, acausal=True),
+            TranslateOrigins(memory=np.inf, acausal=True),
         ],
         update_interval=None,
     ).calculate(
@@ -348,15 +353,13 @@ def test_BOU201911202001_invalid_readings():
         observatory="BOU",
         starttime=UTCDateTime("2019-11-01T00:00:00Z"),
         endtime=UTCDateTime("2020-01-31T23:59:00Z"),
-        acausal=True,
         transforms=[
-            RotationTranslationXY(memory=np.inf),
-            TranslateOrigins(memory=np.inf),
+            RotationTranslationXY(memory=np.inf, acausal=True),
+            TranslateOrigins(memory=np.inf, acausal=True),
         ],
         update_interval=None,
     ).calculate(readings=readings,)[0]
-    assert_array_equal(result.matrix, np.nan * np.ones((4, 4)))
-    assert_equal(result.pier_correction, np.nan)
+    assert result is None
 
 
 def test_CMO2015_causal():
@@ -378,6 +381,10 @@ def test_CMO2015_causal():
         starttime=starttime,
         endtime=endtime,
         update_interval=update_interval,
+        transforms=[
+            RotationTranslationXY(memory=(86400 * 100), acausal=False),
+            TranslateOrigins(memory=(86400 * 10), acausal=False),
+        ],
     ).calculate(
         readings=readings,
     )
@@ -413,7 +420,10 @@ def test_CMO2015_acausal():
         starttime=starttime,
         endtime=endtime,
         update_interval=update_interval,
-        acausal=True,
+        transforms=[
+            RotationTranslationXY(memory=(86400 * 100), acausal=True),
+            TranslateOrigins(memory=(86400 * 10), acausal=True),
+        ],
     ).calculate(
         readings=readings,
     )
@@ -449,8 +459,8 @@ def test_CMO2015_infinite_weekly():
         starttime=starttime,
         endtime=endtime,
         transforms=[
-            RotationTranslationXY(memory=np.inf),
-            TranslateOrigins(memory=np.inf),
+            RotationTranslationXY(memory=np.inf, acausal=True),
+            TranslateOrigins(memory=np.inf, acausal=True),
         ],
         update_interval=update_interval,
-        acausal=True,
@@ -485,8 +495,8 @@ def test_CMO2015_infinite_one_interval():
         starttime=UTCDateTime("2015-02-01T00:00:00Z"),
         endtime=UTCDateTime("2015-11-27T23:59:00Z"),
         transforms=[
-            RotationTranslationXY(memory=np.inf),
-            TranslateOrigins(memory=np.inf),
+            RotationTranslationXY(memory=np.inf, acausal=True),
+            TranslateOrigins(memory=np.inf, acausal=True),
         ],
-        acausal=True,
         update_interval=None,
-- 
GitLab