From 4ccf0508f63b46d3388935b2ff79d1755ea30a3f Mon Sep 17 00:00:00 2001
From: Peter Powers <pmpowers@usgs.gov>
Date: Fri, 17 Sep 2021 13:05:30 -0600
Subject: [PATCH] model data serialization and source tree refactors

---
 .../earthquake/nshmp/model/HazardModel.java   |  26 +--
 .../earthquake/nshmp/model/ModelLoader.java   |   4 +-
 .../usgs/earthquake/nshmp/model/Models.java   | 106 +++++++++++--
 .../earthquake/nshmp/model/SourceTree.java    | 150 ++++++++++--------
 .../usgs/earthquake/nshmp/model/Trees.java    |  12 ++
 5 files changed, 205 insertions(+), 93 deletions(-)

diff --git a/src/main/java/gov/usgs/earthquake/nshmp/model/HazardModel.java b/src/main/java/gov/usgs/earthquake/nshmp/model/HazardModel.java
index 4f8c98fc..b9d5a390 100644
--- a/src/main/java/gov/usgs/earthquake/nshmp/model/HazardModel.java
+++ b/src/main/java/gov/usgs/earthquake/nshmp/model/HazardModel.java
@@ -340,9 +340,9 @@ public final class HazardModel implements Iterable<SourceSet<? extends Source>>
 
       /* SYSTEM: each leaf is a SystemSourceSet */
       if (tree.type() == SourceType.FAULT_SYSTEM) {
-        for (Leaf leaf : tree.branchMap().keySet()) {
+        for (Leaf leaf : tree.branches().keySet()) {
           RuptureSet rs = leaf.ruptureSet;
-          double branchWeight = tree.branchWeights().get(leaf);
+          double branchWeight = tree.leaves().get(leaf);
           SystemRuptureSet srs = (SystemRuptureSet) rs;
           SystemSourceSet sss = srs.createSourceSet(branchWeight);
           addSourceSet(sss);
@@ -363,14 +363,14 @@ public final class HazardModel implements Iterable<SourceSet<? extends Source>>
           .weight(1.0)
           .gmms(tree.gmms());
 
-      for (Leaf leaf : tree.branchMap().keySet()) {
+      for (Leaf leaf : tree.branches().keySet()) {
 
         RuptureSet rs = leaf.ruptureSet;
-        double branchWeight = tree.branchWeights().get(leaf);
+        double branchWeight = tree.leaves().get(leaf);
 
         if (leaf.ruptureSet.type() == SourceType.FAULT) {
 
-          double leafWeight = tree.branchWeights().get(leaf);
+          double leafWeight = tree.leaves().get(leaf);
           FaultRuptureSet frs = (FaultRuptureSet) rs;
 
           /*
@@ -463,8 +463,8 @@ public final class HazardModel implements Iterable<SourceSet<? extends Source>>
           .weight(1.0)
           .gmms(tree.gmms());
 
-      for (Leaf leaf : tree.branchMap().keySet()) {
-        double leafWeight = tree.branchWeights().get(leaf);
+      for (Leaf leaf : tree.branches().keySet()) {
+        double leafWeight = tree.leaves().get(leaf);
         InterfaceRuptureSet irs = (InterfaceRuptureSet) leaf.ruptureSet;
         InterfaceSource is = interfaceRuptureSetToSource(irs, leafWeight);
         builder.source(is, leafWeight);
@@ -486,24 +486,24 @@ public final class HazardModel implements Iterable<SourceSet<? extends Source>>
     }
 
     private void zoneSourceSetsFromTree(SourceTree tree) {
-      for (Leaf leaf : tree.branchMap().keySet()) {
-        double leafWeight = tree.branchWeights().get(leaf);
+      for (Leaf leaf : tree.branches().keySet()) {
+        double leafWeight = tree.leaves().get(leaf);
         ZoneRuptureSet zrs = (ZoneRuptureSet) leaf.ruptureSet;
         addSourceSet(zrs.sourceSet(leafWeight));
       }
     }
 
     private void slabSourceSetsFromTree(SourceTree tree) {
-      for (Leaf leaf : tree.branchMap().keySet()) {
-        double leafWeight = tree.branchWeights().get(leaf);
+      for (Leaf leaf : tree.branches().keySet()) {
+        double leafWeight = tree.leaves().get(leaf);
         SlabRuptureSet srs = (SlabRuptureSet) leaf.ruptureSet;
         addSourceSet(srs.sourceSet(leafWeight));
       }
     }
 
     private void gridSourceSetsFromTree(SourceTree tree) {
-      for (Leaf leaf : tree.branchMap().keySet()) {
-        double leafWeight = tree.branchWeights().get(leaf);
+      for (Leaf leaf : tree.branches().keySet()) {
+        double leafWeight = tree.leaves().get(leaf);
         GridRuptureSet grs = (GridRuptureSet) leaf.ruptureSet;
         addSourceSet(grs.sourceSet(leafWeight));
       }
diff --git a/src/main/java/gov/usgs/earthquake/nshmp/model/ModelLoader.java b/src/main/java/gov/usgs/earthquake/nshmp/model/ModelLoader.java
index 7434e3ac..b3146f9c 100644
--- a/src/main/java/gov/usgs/earthquake/nshmp/model/ModelLoader.java
+++ b/src/main/java/gov/usgs/earthquake/nshmp/model/ModelLoader.java
@@ -101,8 +101,8 @@ abstract class ModelLoader {
    */
 
   public static void main(String[] args) {
-    // Path testModel = Paths.get("../nshm-conus-2018");
-    Path testModel = Paths.get("../nshm-hawaii");
+    Path testModel = Paths.get("../nshm-conus");
+    // Path testModel = Paths.get("../nshm-hawaii");
     HazardModel model = ModelLoader.load(testModel);
     System.out.println();
     System.out.println(model);
diff --git a/src/main/java/gov/usgs/earthquake/nshmp/model/Models.java b/src/main/java/gov/usgs/earthquake/nshmp/model/Models.java
index 4ee9ba8f..bf8dc998 100644
--- a/src/main/java/gov/usgs/earthquake/nshmp/model/Models.java
+++ b/src/main/java/gov/usgs/earthquake/nshmp/model/Models.java
@@ -8,11 +8,16 @@ import java.nio.file.Path;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map.Entry;
+import java.util.NoSuchElementException;
 import java.util.stream.Collectors;
 
 import com.google.gson.Gson;
 import com.google.gson.GsonBuilder;
 
+import gov.usgs.earthquake.nshmp.mfd.Mfd;
+import gov.usgs.earthquake.nshmp.tree.Branch;
+import gov.usgs.earthquake.nshmp.tree.LogicTree;
+
 /**
  * Factory class for querying source models.
  *
@@ -25,30 +30,52 @@ public class Models {
       .create();
 
   public static void main(String[] args) {
-    HazardModel model = ModelLoader.load(Path.of("../nshm-conus-2018"));
+    HazardModel model = ModelLoader.load(Path.of("../nshm-conus"));
     // HazardModel model = ModelLoader.load(Path.of("../nshm-hawaii"));
-    List<?> trees = trees(model);
-    String out = GSON.toJson(trees);
-    System.out.println(out);
+
+    // JSON of all tree info for a model
+    System.out.println(GSON.toJson(trees(model)));
+    System.out.println();
+
+    // Cascadia interface tree
+    SourceTree tree = model.tree(3199).orElseThrow();
+
+    // see what the list of nodes for each branch looks like
+    // we don't use this but it's good for illustrating
+    // internal structure of a SourceTree
+    tree.nodes().stream().forEach(System.out::println);
+    System.out.println();
+
+    // Serialized form of Object returned by tree() in this class
+    System.out.println(GSON.toJson(tree(model, 3199)));
+    System.out.println();
+
+    // 2799 is huge wasatch tree
   }
 
   /**
-   * Returns a serializable list of source logic tree groups.
+   * Returns an object for JSON serialization with the name and ID of all source
+   * logic tree groups in the supplied model organized by tectonic setting and
+   * source type.
    *
    * @param model to extract logic tree data from
    */
-  public static List<?> trees(HazardModel model) {
+  public static Object trees(HazardModel model) {
     return model.trees().asMap().entrySet().stream()
         .map(Models::toSettingGroup)
         .collect(toList());
   }
 
+  /**
+   * Returns an object for JSON serialization the details of a logic tree in the
+   * supplied model.
+   *
+   * @param model to query
+   * @param id of desired source tree
+   * @throws NoSuchElementException if the id is not present in the model
+   */
   public static Object tree(HazardModel model, int id) {
-    return model.tree(id).orElseThrow();
-  }
-
-  public static Object mfds(HazardModel model, int id) {
-    return null;
+    return model.tree(id).map(Tree::new).orElseThrow();
   }
 
   private static SettingGroup toSettingGroup(Entry<TectonicSetting, Collection<SourceTree>> entry) {
@@ -62,7 +89,7 @@ public class Models {
         .collect(groupingBy(
             SourceTree::type,
             mapping(
-                Tree::new,
+                TreeInfo::new,
                 Collectors.toList())))
         .entrySet().stream()
         .map(e -> new SourceGroup(e.getKey(), e.getValue()))
@@ -83,22 +110,73 @@ public class Models {
   static final class SourceGroup {
 
     final SourceType type;
-    final List<Tree> data;
+    final List<TreeInfo> data;
 
-    SourceGroup(SourceType type, List<Tree> data) {
+    SourceGroup(SourceType type, List<TreeInfo> data) {
       this.type = type;
       this.data = data;
     }
   }
 
+  static final class TreeInfo {
+
+    final int id;
+    final String name;
+
+    TreeInfo(SourceTree tree) {
+      this.id = tree.id();
+      this.name = tree.name();
+    }
+  }
+
   static final class Tree {
 
     final int id;
     final String name;
+    final TectonicSetting setting;
+    final SourceType type;
+    final List<SourceBranch> branches;
 
     Tree(SourceTree tree) {
       this.id = tree.id();
       this.name = tree.name();
+      this.setting = tree.setting();
+      this.type = tree.type();
+      this.branches = tree.branches().entrySet().stream()
+          .map(e -> new SourceBranch(
+              e.getValue().toString(),
+              tree.leaves().get(e.getKey()),
+              e.getKey().ruptureSet().mfdTree()))
+          .collect(toList());
+    }
+  }
+
+  static final class SourceBranch {
+
+    final String name;
+    final String path;
+    final double weight;
+    final List<MfdBranch> mfds;
+
+    SourceBranch(String path, double weight, LogicTree<Mfd> mfds) {
+      this.name = mfds.name();
+      this.path = path;
+      this.weight = weight;
+      this.mfds = mfds.stream()
+          .map(MfdBranch::new)
+          .collect(toList());
+    }
+  }
+
+  static final class MfdBranch {
+    final String id;
+    final double weight;
+    final Mfd mfd;
+
+    MfdBranch(Branch<Mfd> branch) {
+      this.id = branch.id();
+      this.weight = branch.weight();
+      this.mfd = branch.value();
     }
   }
 }
diff --git a/src/main/java/gov/usgs/earthquake/nshmp/model/SourceTree.java b/src/main/java/gov/usgs/earthquake/nshmp/model/SourceTree.java
index ceecbbed..1b94d8ae 100644
--- a/src/main/java/gov/usgs/earthquake/nshmp/model/SourceTree.java
+++ b/src/main/java/gov/usgs/earthquake/nshmp/model/SourceTree.java
@@ -4,14 +4,16 @@ import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 import static gov.usgs.earthquake.nshmp.Text.checkName;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toUnmodifiableMap;
 
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
-import java.util.stream.Collectors;
 
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
@@ -55,8 +57,9 @@ public class SourceTree {
   private final GmmSet gmms;
 
   private final Network<Node, Branch<Path>> tree;
-  private final Map<Leaf, List<Branch<Path>>> leafBranches;
-  private final Map<Leaf, Double> leafWeights;
+  private final List<List<Node>> nodes;
+  private final Map<Leaf, List<Branch<Path>>> branches;
+  private final Map<Leaf, Double> leaves;
 
   SourceTree(Builder builder) {
     this.name = builder.name;
@@ -64,15 +67,17 @@ public class SourceTree {
     this.setting = builder.setting;
     this.type = builder.type;
     this.gmms = builder.gmms;
-    this.tree = ImmutableNetwork.copyOf(builder.tree);
-
-    this.leafBranches = builder.leafBranches;
-    this.leafWeights = builder.leafWeights;
 
-    // TODO clean - we do need this kind of output though (verbose logging?)
-    // for (Leaf leaf : leafBranches.keySet()) {
-    // System.out.println(leafBranches.get(leaf) + " " + leafWeights.get(leaf));
-    // }
+    this.tree = ImmutableNetwork.copyOf(builder.tree);
+    this.nodes = builder.nodes;
+    this.branches = builder.branches;
+    this.leaves = builder.leaves;
+
+    // List<String> branchStrs = branches.keySet().stream()
+    // .map(this::branchString)
+    // .sorted()
+    // .collect(toList());
+    // branchStrs.forEach(System.out::println);
   }
 
   /** The name of this tree. */
@@ -100,12 +105,24 @@ public class SourceTree {
     return gmms;
   }
 
-  public Map<Leaf, List<Branch<Path>>> branchMap() {
-    return leafBranches;
+  /** The leaf nodes of this tree mapped to their branches. */
+  public Map<Leaf, List<Branch<Path>>> branches() {
+    return branches;
+  }
+
+  /** The leaf nodes of this tree mapped to the total branch weight. */
+  public Map<Leaf, Double> leaves() {
+    return leaves;
   }
 
-  public Map<Leaf, Double> branchWeights() {
-    return leafWeights;
+  /**
+   * The list of {@code root --> node --> ... --> node --> leaf} node mappings
+   * that reflects the internal structure of this tree.
+   */
+  public List<String> nodes() {
+    return nodes.stream()
+        .map(Object::toString)
+        .collect(toList());
   }
 
   static Builder builder() {
@@ -128,11 +145,12 @@ public class SourceTree {
     private GmmSet gmms;
 
     private MutableNetwork<Node, Branch<Path>> tree;
-    private List<Leaf> leaves = new ArrayList<>();
+    private List<Leaf> leafList = new ArrayList<>();
 
     /* Created on build. */
-    private Map<Leaf, List<Branch<Path>>> leafBranches;
-    private Map<Leaf, Double> leafWeights;
+    private List<List<Node>> nodes;
+    private Map<Leaf, List<Branch<Path>>> branches;
+    private Map<Leaf, Double> leaves;
 
     private Builder() {}
 
@@ -210,8 +228,7 @@ public class SourceTree {
       tree.removeNode(target);
       Leaf leaf = new Leaf(target.index, ruptureSet); // transfer assigned index
       tree.addEdge(source, leaf, parent);
-
-      leaves.add(leaf);
+      leafList.add(leaf);
       return this;
     }
 
@@ -223,60 +240,65 @@ public class SourceTree {
       checkNotNull(type);
       checkNotNull(gmms);
       checkState(tree.nodes().size() > 0, "Empty tree");
-      List<Leaf> leafList = List.copyOf(leaves);
+      // List<Leaf> leafList = List.copyOf(leaves);
       checkState(leafList.size() > 0, "Leafless tree");
-      leafBranches = branchLists(tree, leafList);
-      leafWeights = leafWeights(leafBranches);
+      buildBranchesAndNodes();
+      // leafBranches = branchLists(tree, leaves);
+      // leafWeights = leafWeights(branches);
+      buildLeafWeights();
       checkState(!built, "Single use builder");
       built = true;
       return new SourceTree(this);
     }
-  }
 
-  /*
-   * On build: Generate lists of logic tree branches (branch paths) from the
-   * root of a tree to each of the supplied leaf nodes. Because the tree is a
-   * directed graph, traversing a transposed view starting with each leaf yields
-   * the required node path.
-   */
-  private static Map<Leaf, List<Branch<Path>>> branchLists(
-      Network<Node, Branch<Path>> tree,
-      List<Leaf> leaves) {
-
-    Traverser<Node> traverser = Traverser.forTree(Graphs.transpose(tree));
-    Map<Leaf, List<Branch<Path>>> branchListMap = new HashMap<>();
-    for (Leaf leaf : leaves) {
-      List<Node> nodeList = Lists.newArrayList(traverser.depthFirstPostOrder(leaf));
-      checkState(nodeList.size() > 1); // 2 nodes minimum [root -> leaf]
-      List<Branch<Path>> branchList = new ArrayList<>();
-      Node source = nodeList.get(0);
-      for (Node target : Iterables.skip(nodeList, 1)) {
-        branchList.add(tree.edgeConnecting(source, target).orElseThrow());
-        source = target;
+    /*
+     * On build: Generate lists of logic tree branches (branch paths) from the
+     * root of a tree to each of the supplied leaf nodes. Because the tree is a
+     * directed graph, traversing a transposed view starting with each leaf
+     * yields the required node path.
+     */
+    private void buildBranchesAndNodes() {
+
+      Traverser<Node> traverser = Traverser.forTree(Graphs.transpose(tree));
+      Map<Leaf, List<Branch<Path>>> branchListsMap = new HashMap<>();
+      List<List<Node>> nodeLists = new ArrayList<>();
+
+      for (Leaf leaf : leafList) {
+        List<Node> nodeList = Lists.newArrayList(traverser.depthFirstPostOrder(leaf));
+        checkState(nodeList.size() > 1); // 2 nodes minimum [root -> leaf]
+        nodeLists.add(nodeList);
+
+        List<Branch<Path>> branchList = new ArrayList<>();
+        Node source = nodeList.get(0);
+        for (Node target : Iterables.skip(nodeList, 1)) {
+          branchList.add(tree.edgeConnecting(source, target).orElseThrow());
+          source = target;
+        }
+        branchListsMap.put(leaf, List.copyOf(branchList));
       }
-      branchListMap.put(leaf, List.copyOf(branchList));
+      nodes = List.copyOf(nodeLists);
+      branches = Map.copyOf(branchListsMap);
     }
-    return Map.copyOf(branchListMap);
-  }
 
-  /*
-   * On build: Create map of leaves and their weights. Note that the values of
-   * the returned map will not sum to one when the source tree contains one or
-   * more LogicGroups.
-   */
-  private static Map<Leaf, Double> leafWeights(Map<Leaf, List<Branch<Path>>> branchLists) {
-    return branchLists.entrySet().stream()
-        .collect(Collectors.toUnmodifiableMap(
-            e -> e.getKey(),
-            e -> leafWeight(e.getValue())));
-  }
+    /*
+     * On build: Create map of leaves and their weights. Note that the values of
+     * the returned map will not sum to one when the source tree contains one or
+     * more LogicGroups.
+     */
+    private void buildLeafWeights() {
+      leaves = branches.entrySet().stream()
+          .collect(toUnmodifiableMap(
+              Entry::getKey,
+              e -> leafWeight(e.getValue())));
+    }
 
-  /* Compute cumulative weight of a source branch from root to leaf. */
-  private static double leafWeight(List<Branch<Path>> branchList) {
-    double weight = branchList.stream()
-        .mapToDouble(Branch::weight)
-        .reduce(1, (a, b) -> a * b);
-    return Maths.round(weight, 8);
+    /* Compute cumulative weight of a source branch from root to leaf. */
+    private static double leafWeight(List<Branch<Path>> branchList) {
+      double weight = branchList.stream()
+          .mapToDouble(Branch::weight)
+          .reduce(1, (a, b) -> a * b);
+      return Maths.round(weight, 8);
+    }
   }
 
   static class Node {
diff --git a/src/main/java/gov/usgs/earthquake/nshmp/model/Trees.java b/src/main/java/gov/usgs/earthquake/nshmp/model/Trees.java
index 53929950..3876a361 100644
--- a/src/main/java/gov/usgs/earthquake/nshmp/model/Trees.java
+++ b/src/main/java/gov/usgs/earthquake/nshmp/model/Trees.java
@@ -203,6 +203,18 @@ class Trees {
     return Mfds.combine(scaledMfdList(tree));
   }
 
+  static LogicTree<Mfd> reduceMfdListTree(LogicTree<List<Mfd>> mfdsTree) {
+    // TODO what is metadata for combined, e.g. grid, mfd
+    // is a value sum of recomputed a; we really want a
+    // to be the value from the rate tree
+    LogicTree.Builder<Mfd> mfdTree = LogicTree.builder(mfdsTree.name());
+    mfdsTree.forEach(branch -> mfdTree.addBranch(
+        branch.id(),
+        Mfds.combine(branch.value()),
+        branch.weight()));
+    return mfdTree.build();
+  }
+
   /* LogicTree<MFD> --> List<MFD * branchWeight> */
   static List<Mfd> scaledMfdList(LogicTree<Mfd> tree) {
     return tree.stream()
-- 
GitLab