From 3577fbf906a9d556cbf8b67e081c09dade3697b5 Mon Sep 17 00:00:00 2001
From: Peter Powers <pmpowers@usgs.gov>
Date: Wed, 30 Mar 2022 14:39:29 -0600
Subject: [PATCH] parameterized tests

---
 .../nshmp/model/NshmTestsLarge.java           | 101 +++++++++---------
 1 file changed, 53 insertions(+), 48 deletions(-)

diff --git a/src/test/java/gov/usgs/earthquake/nshmp/model/NshmTestsLarge.java b/src/test/java/gov/usgs/earthquake/nshmp/model/NshmTestsLarge.java
index 4427652ac..380c553a8 100644
--- a/src/test/java/gov/usgs/earthquake/nshmp/model/NshmTestsLarge.java
+++ b/src/test/java/gov/usgs/earthquake/nshmp/model/NshmTestsLarge.java
@@ -1,5 +1,9 @@
 package gov.usgs.earthquake.nshmp.model;
 
+import static gov.usgs.earthquake.nshmp.gmm.Imt.PGA;
+import static gov.usgs.earthquake.nshmp.gmm.Imt.SA0P2;
+import static gov.usgs.earthquake.nshmp.gmm.Imt.SA1P0;
+import static gov.usgs.earthquake.nshmp.gmm.Imt.SA5P0;
 import static gov.usgs.earthquake.nshmp.site.NshmpSite.BOSTON_MA;
 import static gov.usgs.earthquake.nshmp.site.NshmpSite.CHICAGO_IL;
 import static gov.usgs.earthquake.nshmp.site.NshmpSite.LOS_ANGELES_CA;
@@ -11,7 +15,6 @@ import static gov.usgs.earthquake.nshmp.site.NshmpSite.SAN_FRANCISCO_CA;
 import static gov.usgs.earthquake.nshmp.site.NshmpSite.SEATTLE_WA;
 import static java.lang.Math.abs;
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.io.BufferedReader;
@@ -20,16 +23,20 @@ import java.lang.reflect.Type;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
 
 import com.google.common.reflect.TypeToken;
 import com.google.gson.Gson;
@@ -38,10 +45,13 @@ import com.google.gson.JsonObject;
 import com.google.gson.JsonParser;
 
 import gov.usgs.earthquake.nshmp.NamedLocation;
+import gov.usgs.earthquake.nshmp.calc.CalcConfig;
 import gov.usgs.earthquake.nshmp.calc.Hazard;
 import gov.usgs.earthquake.nshmp.calc.HazardCalcs;
 import gov.usgs.earthquake.nshmp.calc.Site;
 import gov.usgs.earthquake.nshmp.data.XySequence;
+import gov.usgs.earthquake.nshmp.geo.Location;
+import gov.usgs.earthquake.nshmp.gmm.Imt;
 
 /**
  * Class for end-to-end tests of hazard calculations. These tests require
@@ -67,61 +77,59 @@ class NshmTestsLarge {
       NEW_YORK_NY,
       CHICAGO_IL);
 
-  /*
-   * These tests use project relative file paths to read/write directly to/from
-   * the source tree.
-   */
+  private static final Set<Imt> IMTS = EnumSet.of(PGA, SA0P2, SA1P0, SA5P0);
+
+  private static final String MODEL_NAME = "nshm-conus";
+  private static final int MODEL_YEAR = 2018;
+  private static final Path MODEL_PATH = Paths.get("../" + MODEL_NAME);
+  private static final Path DATA_PATH = Paths.get("src/test/resources/e2e");
+
   private static final Gson GSON = new GsonBuilder()
       .setPrettyPrinting()
       .create();
 
-  private static ExecutorService EXEC;
+  private static ExecutorService exec;
+  private static HazardModel model;
+  static Map<Location, Map<String, XySequence>> expecteds;
 
   @BeforeAll
   static void setUpBeforeClass() {
+    model = ModelLoader.load(MODEL_PATH);
     int cores = Runtime.getRuntime().availableProcessors();
-    EXEC = Executors.newFixedThreadPool(cores);
+    exec = Executors.newFixedThreadPool(cores);
   }
 
   @AfterAll
   static void tearDownAfterClass() {
-    EXEC.shutdown();
+    exec.shutdown();
   }
 
-  private static final Path MODEL_PATH = Paths.get("../");
-  private static final Path DATA_PATH = Paths.get("src/test/resources/e2e");
-
-  @Test
-  public void testConus2018() {
-    testModel("nshm-conus", 2018, CONUS_SITES);
+  @ParameterizedTest
+  @MethodSource("siteStream")
+  final void testLocation(NamedLocation site) {
+    compareCurves(site);
   }
 
-  private static void testModel(
-      String modelName,
-      int year,
-      List<NamedLocation> locations) {
-
-    Path modelPath = MODEL_PATH.resolve(modelName);
-    HazardModel model = ModelLoader.load(modelPath);
-    for (NamedLocation location : locations) {
-      compareCurves(modelName, year, model, location);
-    }
+  private static Stream<NamedLocation> siteStream() {
+    return CONUS_SITES.stream();
   }
 
-  private static void compareCurves(
-      String modelName,
-      int year,
-      HazardModel model,
-      NamedLocation location) {
+  private static void compareCurves(NamedLocation location) {
 
+    System.out.println(location);
     // String actual = generateActual(model, location);
-    Map<String, XySequence> actual = generateActual(model, location);
+    Map<String, XySequence> actual = generateActual(location);
     // String expected = readExpected(modelName, year, location);
-    Map<String, XySequence> expected = readExpected(modelName, year, location);
+    Map<String, XySequence> expected = readExpected(location);
     // assertEquals(expected, actual);
 
-    assertEquals(expected.keySet(), actual.keySet());
-    for (String key : expected.keySet()) {
+    // assertEquals(expected.keySet(), actual.keySet());
+    for (String key : actual.keySet()) {
+      System.out.println(key);
+      System.out.println(actual.get(key));
+    }
+    for (String key : actual.keySet()) {
+      System.out.println(key);
       assertCurveEquals(expected.get(key), actual.get(key), TOLERANCE);
     }
   }
@@ -153,17 +161,19 @@ class NshmTestsLarge {
         Double.valueOf(expected).equals(Double.valueOf(actual));
   }
 
-  private static Map<String, XySequence> generateActual(
-      HazardModel model,
-      NamedLocation location) {
+  private static Map<String, XySequence> generateActual(NamedLocation location) {
 
     Site site = Site.builder().location(location.location()).build();
 
+    CalcConfig config = CalcConfig.copyOf(model.config())
+        .imts(IMTS)
+        .build();
+
     Hazard hazard = HazardCalcs.hazard(
         model,
-        model.config(),
+        config,
         site,
-        EXEC);
+        exec);
 
     Map<String, XySequence> xyMap = hazard.curves().entrySet().stream()
         .collect(Collectors.toMap(
@@ -171,7 +181,6 @@ class NshmTestsLarge {
             Entry::getValue));
 
     return xyMap;
-    // return GSON.toJson(hazard.curves());
   }
 
   private static String resultFilename(
@@ -182,12 +191,9 @@ class NshmTestsLarge {
     return modelName + "-" + year + "-" + loc.name() + ".json";
   }
 
-  private static Map<String, XySequence> readExpected(
-      String modelName,
-      int year,
-      NamedLocation loc) {
+  private static Map<String, XySequence> readExpected(NamedLocation loc) {
 
-    String filename = resultFilename(modelName, year, loc);
+    String filename = resultFilename(MODEL_NAME, MODEL_YEAR, loc);
     Path resultPath = DATA_PATH.resolve(filename);
 
     JsonObject obj = null;
@@ -210,6 +216,7 @@ class NshmTestsLarge {
     double[] xs;
     double[] ys;
 
+    @SuppressWarnings("unused")
     Curve(double[] xs, double[] ys) {
       this.xs = xs;
       this.ys = ys;
@@ -221,11 +228,9 @@ class NshmTestsLarge {
       int year,
       List<NamedLocation> locations) throws IOException {
 
-    Path modelPath = MODEL_PATH.resolve(modelName);
-    HazardModel model = ModelLoader.load(modelPath);
     for (NamedLocation location : locations) {
       // String json = generateActual(model, location);
-      Map<String, XySequence> xyMap = generateActual(model, location);
+      Map<String, XySequence> xyMap = generateActual(location);
       String json = GSON.toJson(xyMap);
       writeExpected(modelName, year, location, json);
     }
-- 
GitLab