From 7c5440b6bcc105c03443882b50494681705bb731 Mon Sep 17 00:00:00 2001
From: pcain-usgs <pcain@usgs.gov>
Date: Fri, 30 Apr 2021 14:31:22 -0600
Subject: [PATCH] add residual entrypoint, raise errors for missing
 measurements

---
 geomagio/api/ws/algorithms.py    | 36 +++++++++++++++++++++++++++++++-
 geomagio/residual/Calculation.py |  9 +++++---
 geomagio/residual/Measurement.py |  3 ++-
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/geomagio/api/ws/algorithms.py b/geomagio/api/ws/algorithms.py
index 572045ba8..af957ee43 100644
--- a/geomagio/api/ws/algorithms.py
+++ b/geomagio/api/ws/algorithms.py
@@ -1,8 +1,17 @@
-from fastapi import APIRouter, Depends
+from typing import List
+
+from fastapi import APIRouter, Depends, HTTPException
 from starlette.responses import Response
 
 from ... import TimeseriesFactory
 from ...algorithm import DbDtAlgorithm
+from ...residual import (
+    calculate,
+    Reading,
+    MARK_TYPES,
+    INCLINATION_TYPES,
+    DECLINATION_TYPES,
+)
 from .DataApiQuery import DataApiQuery
 from .data import format_timeseries, get_data_factory, get_data_query, get_timeseries
 
@@ -25,3 +34,28 @@ def get_dbdt(
     return format_timeseries(
         timeseries=timeseries, format=query.format, elements=elements
     )
+
+
+@router.post("/algorithms/residual", response_model=Reading)
+def calculate_residual(reading: Reading, adjust_reference: bool = True):
+    missing_types = get_missing_measurement_types(reading=reading)
+    if len(missing_types) != 0:
+        error_message = ", ".join(t.value for t in missing_types)
+        raise HTTPException(
+            status_code=400,
+            detail=f"Missing {error_message} measurements in input reading",
+        )
+    return calculate(reading=reading, adjust_reference=adjust_reference)
+
+
+def get_missing_measurement_types(reading: Reading) -> List[str]:
+    measurement_types = [m.measurement_type for m in reading.measurements]
+    missing_types = []
+    missing_types.extend(
+        [type for type in DECLINATION_TYPES if type not in measurement_types]
+    )
+    missing_types.extend(
+        [type for type in INCLINATION_TYPES if type not in measurement_types]
+    )
+    missing_types.extend([type for type in MARK_TYPES if type not in measurement_types])
+    return missing_types
diff --git a/geomagio/residual/Calculation.py b/geomagio/residual/Calculation.py
index ddc108a4d..c18fe228e 100644
--- a/geomagio/residual/Calculation.py
+++ b/geomagio/residual/Calculation.py
@@ -28,7 +28,10 @@ def calculate(reading: Reading, adjust_reference: bool = True) -> Reading:
     NOTE: rest of reading object is shallow copy.
     """
     # reference measurement, used to adjust absolutes
-    reference = reading[mt.WEST_DOWN][0]
+    try:
+        reference = adjust_reference and reading[mt.WEST_DOWN][0] or None
+    except:
+        raise ValueError(f"Missing {mt.WEST_DOWN.value} measurement")
     # calculate inclination
     inclination, f, i_mean = calculate_I(
         hemisphere=reading.hemisphere, measurements=reading.measurements
@@ -39,13 +42,13 @@ def calculate(reading: Reading, adjust_reference: bool = True) -> Reading:
         corrected_f=corrected_f,
         inclination=inclination,
         mean=i_mean,
-        reference=adjust_reference and reference or None,
+        reference=reference,
     )
     absoluteD, meridian = calculate_D_absolute(
         azimuth=reading.azimuth,
         h_baseline=absoluteH.baseline,
         measurements=reading.measurements,
-        reference=adjust_reference and reference or None,
+        reference=reference,
     )
     # populate diagnostics object with averaged measurements
     diagnostics = Diagnostics(
diff --git a/geomagio/residual/Measurement.py b/geomagio/residual/Measurement.py
index f8be19437..bbf2f420a 100644
--- a/geomagio/residual/Measurement.py
+++ b/geomagio/residual/Measurement.py
@@ -54,7 +54,8 @@ def average_measurement(
         measurements = [m for m in measurements if m.measurement_type in types]
     if len(measurements) == 0:
         # no measurements to average
-        return None
+        error_message = ", ".join(t.value for t in types)
+        raise ValueError(f"Missing {error_message} measurements")
     starttime = safe_min([m.time.timestamp for m in measurements if m.time])
     endtime = safe_max([m.time.timestamp for m in measurements if m.time])
     measurement = AverageMeasurement(
-- 
GitLab