From 9d5dc516176344d99a35969180b9c8753ded41fe Mon Sep 17 00:00:00 2001
From: Peter Powers <pmpowers@usgs.gov>
Date: Mon, 30 Sep 2024 13:22:58 -0600
Subject: [PATCH] updated model loader tests to cover decomposed curves

---
 .../earthquake/nshmp/model/LoaderTests.java   | 225 ++++++++++++------
 1 file changed, 158 insertions(+), 67 deletions(-)

diff --git a/src/test/java/gov/usgs/earthquake/nshmp/model/LoaderTests.java b/src/test/java/gov/usgs/earthquake/nshmp/model/LoaderTests.java
index 0a91d25b..9b8ab8bb 100644
--- a/src/test/java/gov/usgs/earthquake/nshmp/model/LoaderTests.java
+++ b/src/test/java/gov/usgs/earthquake/nshmp/model/LoaderTests.java
@@ -1,5 +1,9 @@
 package gov.usgs.earthquake.nshmp.model;
 
+import static gov.usgs.earthquake.nshmp.calc.HazardExport.CURVE_FILE;
+import static gov.usgs.earthquake.nshmp.calc.HazardExport.GMM_DIR;
+import static gov.usgs.earthquake.nshmp.calc.HazardExport.MAG_DIR;
+import static gov.usgs.earthquake.nshmp.calc.HazardExport.TYPE_DIR;
 import static java.lang.Math.abs;
 import static java.util.stream.Collectors.toMap;
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
@@ -11,10 +15,8 @@ import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.EnumMap;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.OptionalDouble;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -31,8 +33,10 @@ import gov.usgs.earthquake.nshmp.calc.HazardCalcs;
 import gov.usgs.earthquake.nshmp.calc.HazardExport;
 import gov.usgs.earthquake.nshmp.calc.Site;
 import gov.usgs.earthquake.nshmp.calc.Sites;
+import gov.usgs.earthquake.nshmp.data.MutableXySequence;
 import gov.usgs.earthquake.nshmp.data.XySequence;
 import gov.usgs.earthquake.nshmp.geo.Location;
+import gov.usgs.earthquake.nshmp.gmm.Gmm;
 import gov.usgs.earthquake.nshmp.gmm.Imt;
 
 class LoaderTests {
@@ -51,14 +55,14 @@ class LoaderTests {
 
   static HazardModel model;
   static List<Site> sites;
-  static Map<Location, Map<Imt, XySequence>> expecteds;
+  static Expecteds expecteds;
   static ExecutorService exec;
 
   @BeforeAll
   static void setUpBeforeClass() throws IOException {
     model = ModelLoader.load(MODEL_PATH);
     sites = Sites.fromCsv(SITES_PATH, model.siteData(), OptionalDouble.empty());
-    expecteds = loadExpecteds();
+    expecteds = new Expecteds();
     int cores = Runtime.getRuntime().availableProcessors();
     exec = Executors.newFixedThreadPool(cores);
   }
@@ -72,21 +76,67 @@ class LoaderTests {
   @MethodSource("siteStream")
   final void testLocation(Site site) {
     Location loc = site.location();
-    Map<Imt, XySequence> expected = expecteds.get(loc);
+    // Map<Imt, XySequence> totalExpected = expecteds.totalCurves.get(loc);
     Hazard actual = HazardCalcs.hazard(model, model.config(), site, exec);
-    assertCurvesEqual(expected, actual, TOLERANCE);
+    assertCurvesEqual(loc, expecteds, actual, TOLERANCE);
   }
 
   private static Stream<Site> siteStream() {
     return sites.stream();
   }
 
-  private static void assertCurvesEqual(Map<Imt, XySequence> expected, Hazard actual, double tol) {
-    expected.entrySet().forEach(
-        e -> assertCurveEquals(
-            e.getValue(),
-            actual.curves().get(e.getKey()),
-            tol));
+  private static void assertCurvesEqual(
+      Location loc,
+      Expecteds expecteds,
+      Hazard actuals,
+      double tol) {
+
+    var gmmActualsImtMap = HazardExport.curvesByGmm(actuals);
+    var typeActualsImtMap = HazardExport.curvesBySource(actuals);
+
+    for (Imt imt : expecteds.totalCurves.keySet()) {
+
+      // total curve
+      XySequence totalExpected = expecteds.totalCurves.get(imt).get(loc);
+      XySequence totalActual = actuals.curves().get(imt);
+      assertCurveEquals(totalExpected, totalActual, tol);
+
+      /*
+       * For decomposed curves we need to loop actuals because expecteds will
+       * have additional zero-values curves for missing GMMs, source types and
+       * magnitudes.
+       */
+
+      // GMM curves
+      Map<Gmm, Map<Location, XySequence>> gmmExpecteds = expecteds.gmmCurves.get(imt);
+      Map<Gmm, ? extends XySequence> gmmActuals = gmmActualsImtMap.get(imt);
+      for (Gmm gmm : gmmActuals.keySet()) {
+        XySequence gmmExpected = gmmExpecteds.get(gmm).get(loc);
+        XySequence gmmActual = gmmActuals.get(gmm);
+        assertCurveEquals(gmmExpected, gmmActual, tol);
+      }
+
+      // Source type curves
+      Map<SourceType, Map<Location, XySequence>> typeExpecteds = expecteds.typeCurves.get(imt);
+      Map<SourceType, ? extends XySequence> typeActuals = typeActualsImtMap.get(imt);
+      for (SourceType type : typeActuals.keySet()) {
+        XySequence typeExpected = typeExpecteds.get(type).get(loc);
+        XySequence typeActual = typeActuals.get(type);
+        assertCurveEquals(typeExpected, typeActual, tol);
+      }
+
+      // Magnitude curves
+      Map<Double, XySequence> magExpecteds = expecteds.magCurves.get(imt).get(loc);
+      Map<Double, MutableXySequence> magActuals = actuals.magCurves().get(imt);
+      for (Double m : magActuals.keySet()) {
+        // if mag bin not used actuals will be null
+        // and expecteds will be array of zeros
+        if (magActuals.get(m) == null) {
+          continue;
+        }
+        assertCurveEquals(magExpecteds.get(m), magActuals.get(m), tol);
+      }
+    }
   }
 
   private static void assertCurveEquals(XySequence expected, XySequence actual, double tol) {
@@ -103,7 +153,7 @@ class LoaderTests {
     // y-value difference relative to tolerance
     assertArrayEquals(expectedYs, actualYs, tol);
 
-    // y-value difference relative to tolerance
+    // y-value ratio relative to tolerance
     for (int i = 0; i < expectedYs.length; i++) {
       String message = String.format(
           "arrays differ at [%s] expected:<[%s]> but was:<[%s]>",
@@ -117,84 +167,125 @@ class LoaderTests {
         Double.valueOf(expected).equals(Double.valueOf(actual));
   }
 
-  /* read curves then transpose */
-  private static Map<Location, Map<Imt, XySequence>> loadExpecteds() throws IOException {
-    // consider centralized curve file processing
-    Map<Imt, Map<Location, XySequence>> curves = Files.walk(RESULTS_PATH)
-        .filter(LoaderTests::isCurveFile)
-        .collect(toMap(
-            LoaderTests::imtFromPath,
-            LoaderTests::readCurves));
-    return transpose(curves);
-  }
+  private static class Expecteds {
+
+    Map<Imt, Map<Location, XySequence>> totalCurves = new EnumMap<>(Imt.class);
+    Map<Imt, Map<Gmm, Map<Location, XySequence>>> gmmCurves = new EnumMap<>(Imt.class);
+    Map<Imt, Map<SourceType, Map<Location, XySequence>>> typeCurves = new EnumMap<>(Imt.class);
+    Map<Imt, Map<Location, Map<Double, XySequence>>> magCurves = new EnumMap<>(Imt.class);
+
+    Expecteds() {
+      try {
+        Map<Imt, Path> imtDirs = Files.list(RESULTS_PATH)
+            .filter(Files::isDirectory)
+            .collect(toMap(
+                p -> Imt.valueOf(p.getFileName().toString()),
+                p -> p));
+        imtDirs.forEach(this::loadImtDir);
+      } catch (IOException ioe) {
+        throw new RuntimeException(ioe);
+      }
+    }
 
-  private static boolean isCurveFile(Path path) {
-    return path.getFileName().toString().equals("curves.csv");
-  }
+    void loadImtDir(Imt imt, Path path) {
+      try {
+
+        totalCurves.put(imt, readLocationCurves(path.resolve(CURVE_FILE)));
+
+        var imtGmmCurves = Files.list(path.resolve(GMM_DIR))
+            .filter(Files::isDirectory)
+            .collect(toMap(
+                p -> Gmm.valueOf(p.getFileName().toString()),
+                p -> readLocationCurves(p.resolve(CURVE_FILE))));
+        gmmCurves.put(imt, imtGmmCurves);
+
+        var imtTypeCurves = Files.list(path.resolve(TYPE_DIR))
+            .filter(Files::isDirectory)
+            .collect(toMap(
+                p -> SourceType.valueOf(p.getFileName().toString()),
+                p -> readLocationCurves(p.resolve(CURVE_FILE))));
+        typeCurves.put(imt, imtTypeCurves);
+
+        var imtMagCurves = Files.list(path.resolve(MAG_DIR))
+            .filter(p -> p.getFileName().toString().endsWith(".csv"))
+            .collect(toMap(
+                p -> readLocation(p),
+                p -> readMagnitudeCurves(p)));
 
-  private static Imt imtFromPath(Path path) {
-    return Imt.valueOf(path.getParent().getFileName().toString());
+        magCurves.put(imt, imtMagCurves);
+
+      } catch (IOException ioe) {
+        throw new RuntimeException(ioe);
+      }
+    }
   }
 
-  private static Map<Location, XySequence> readCurves(Path path) {
+  private static Map<Location, XySequence> readLocationCurves(Path path) {
+    int offset = 3;
     try {
       List<String> lines = Files.readAllLines(path);
-      double[] imls = toValues(lines.get(0));
+      double[] imls = readValues(lines.get(0), offset);
       return lines.stream()
           .skip(1)
-          .map(line -> toCurve(line, imls))
           .collect(toMap(
-              Entry::getKey,
-              Entry::getValue));
+              line -> readLocation(line),
+              line -> XySequence.create(imls, readValues(line, 3))));
     } catch (IOException ioe) {
       throw new RuntimeException(ioe);
     }
   }
 
-  private static Entry<Location, XySequence> toCurve(String line, double[] xs) {
-    Site site = toSite(line);
-    double[] ys = toValues(line);
-    return Map.entry(site.location(), XySequence.create(xs, ys));
+  /* Read a location from a line string. */
+  private static Location readLocation(String line) {
+    String[] s = line.split(",");
+    return Location.create(
+        Double.parseDouble(s[1]),
+        Double.valueOf(s[2]));
+  }
+
+  /* Read a location from a CSV filename. */
+  private static Location readLocation(Path path) {
+    String f = path.getFileName().toString();
+    String[] s = f.substring(0, f.length() - 4).split(",");
+    return Location.create(
+        Double.parseDouble(s[1]),
+        Double.parseDouble(s[2]));
   }
 
-  private static Site toSite(String line) {
-    String[] s = Arrays.stream(line.split(","))
-        .map(String::trim)
-        .limit(3)
-        .toArray(String[]::new);
-    String name = s[0];
-    Location loc = Location.create(
-        Double.valueOf(s[1]),
-        Double.valueOf(s[2]));
-    return Site.builder()
-        .name(name)
-        .location(loc)
-        .build();
+  private static Map<Double, XySequence> readMagnitudeCurves(Path path) {
+    int offset = 1;
+    try {
+      List<String> lines = Files.readAllLines(path);
+      double[] imls = readValues(lines.get(0), offset);
+      Map<Double, XySequence> pp = lines.stream()
+          .skip(1)
+          .collect(toMap(
+              line -> Double.valueOf(line.substring(0, line.indexOf(","))),
+              line -> readMagnitudeCurve(line, imls, offset)));
+      return pp;
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    }
   }
 
-  private static double[] toValues(String line) {
+  private static XySequence readMagnitudeCurve(
+      String line,
+      double[] xs,
+      int offset) {
+
+    double[] ys = readValues(line, offset);
+    return XySequence.create(xs, ys);
+  }
+
+  /* Read values from a comma-delimited string. */
+  private static double[] readValues(String line, int offset) {
     return Arrays.stream(line.split(","))
         .map(String::trim)
-        .skip(3)
+        .skip(offset)
         .mapToDouble(Double::parseDouble)
         .toArray();
   }
 
-  private static Map<Location, Map<Imt, XySequence>> transpose(
-      Map<Imt, Map<Location, XySequence>> mapIn) {
-
-    Map<Location, Map<Imt, XySequence>> mapOut = new HashMap<>();
-    for (Entry<Imt, Map<Location, XySequence>> imtEntry : mapIn.entrySet()) {
-      Imt imt = imtEntry.getKey();
-      for (Entry<Location, XySequence> locEntry : imtEntry.getValue().entrySet()) {
-        Location loc = locEntry.getKey();
-        XySequence xy = locEntry.getValue();
-        mapOut.computeIfAbsent(loc, k -> new EnumMap<>(Imt.class)).put(imt, xy);
-      }
-    }
-    return mapOut;
-  }
-
   public static void main(String[] args) throws IOException {
     model = ModelLoader.load(MODEL_PATH);
     List<Site> sites = Sites.fromCsv(SITES_PATH, model.siteData(), OptionalDouble.empty());
-- 
GitLab