From 659044d1d3329ffd03fea1e59a4e837eafd25ea9 Mon Sep 17 00:00:00 2001
From: Brandon Clayton <bclayton@usgs.gov>
Date: Wed, 10 Mar 2021 13:37:14 -0700
Subject: [PATCH] read in map file

---
 .../nshmp/netcdf/converters/convert_2018a.py  | 57 +++++++++++--------
 1 file changed, 33 insertions(+), 24 deletions(-)

diff --git a/src/main/python/gov/usgs/earthquake/nshmp/netcdf/converters/convert_2018a.py b/src/main/python/gov/usgs/earthquake/nshmp/netcdf/converters/convert_2018a.py
index 47034f1..f5e0ffd 100644
--- a/src/main/python/gov/usgs/earthquake/nshmp/netcdf/converters/convert_2018a.py
+++ b/src/main/python/gov/usgs/earthquake/nshmp/netcdf/converters/convert_2018a.py
@@ -1,3 +1,4 @@
+import json
 import os
 import shutil
 
@@ -5,6 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
+from typing import Union
 
 import netCDF4 as netcdf
 import numpy as np
@@ -45,6 +47,8 @@ class Convert2018A:
 
         self._imt_indices: dict[Imt, int] = self._set_imt_indices()
         self._site_class_indices: dict[SiteClass, int] = self._set_site_class_indices()
+        self._latitude_indices: dict[float, int] = {}
+        self._longitude_indices: dict[float, int] = {}
 
         if os.getenv(_NETCDF_PATH_ENV):
             self._netcdf_filename = Path(os.getenv(_NETCDF_PATH_ENV))
@@ -69,8 +73,10 @@ class Convert2018A:
             [self._dimensions.lat.size, self._dimensions.lon.size], int
         )
         self._imt_mask_array = np.zeros([self._dimensions.lat.size, self._dimensions.lon.size], int)
-        self._site_class_data_array = np.full(
+
+        self._data_array = np.full(
             [
+                self._dimensions.site_class.size,
                 self._dimensions.imt.size,
                 self._dimensions.lat.size,
                 self._dimensions.lon.size,
@@ -103,13 +109,16 @@ class Convert2018A:
     def _create_netcdf_variables(self):
         iml_values, imt_values = self._set_imt_values()
         imt_enum_type = self._imt_enum_type()
+        geojson = self._get_map_file()
         grid_step = self.metadata.grid_step
         latitudes = self.metadata.locations.latitudes
         longitudes = self.metadata.locations.longitudes
+        parameters = NetcdfParameters()
         site_class_values = self._set_site_class_values()
         site_class_enum_type = self._site_class_enum_type()
 
-        parameters = NetcdfParameters()
+        if geojson is not None:
+            parameters.BoundsParameters().netcdf_variable(group=self._dataset_group, data=geojson)
         grid_mask_netcdf_var = parameters.GridMaskParameters().netcdf_variable(
             group=self._dataset_group
         )
@@ -161,8 +170,6 @@ class Convert2018A:
     def _get_hazard_data(
         self,
         hazard_netcdf_var: netcdf.Variable,
-        latitude_indices: dict[float, int],
-        longitude_indices: dict[float, int],
         netcdf_info: list[NetcdfInfo],
     ):
         futures: list[Future] = []
@@ -174,8 +181,6 @@ class Convert2018A:
                     executor.submit(
                         self._read_curves_file,
                         hazard_netcdf_var=hazard_netcdf_var,
-                        latitude_indices=latitude_indices,
-                        longitude_indices=longitude_indices,
                         netcdf_info=info,
                     )
                 )
@@ -200,6 +205,20 @@ class Convert2018A:
             target=longitude,
         )
 
+    def _get_map_file(self) -> Union[str, None]:
+        """Return the first matching map file under data_path as a JSON string, else None."""
+        map_name = self.metadata.database_info.map_file
+        if map_name is None:
+            return None
+        # Substring match on the file name; the first hit in os.walk order wins.
+        for root, _dirs, files in os.walk(self.metadata.database_info.data_path):
+            for file in files:
+                if map_name in file:
+                    with open(Path(root, file), "r", encoding="utf-8") as json_reader:
+                        # json.load validates the file; json.dumps re-encodes it with defaults.
+                        return json.dumps(json.load(json_reader))
+        return None
+
     def _get_site_class_index(self, site_class: SiteClass):
         return self._site_class_indices.get(site_class)
 
@@ -217,8 +236,6 @@ class Convert2018A:
         self,
         netcdf_info: NetcdfInfo,
         hazard_netcdf_var: netcdf.Variable,
-        latitude_indices: dict[float, int],
-        longitude_indices: dict[float, int],
     ):
         curves_file = netcdf_info.curve_file
 
@@ -226,11 +243,12 @@ class Convert2018A:
             raise Exception(f"File ({curves_file}) not found")
 
         imt_dir = curves_file.parent
+        imt_index = self._get_imt_index(imt=netcdf_info.imt)
+        site_class_index = self._get_site_class_index(site_class=netcdf_info.site_class)
+        imls = self.metadata.imls.get(netcdf_info.imt)
         print(f"\t Converting [{imt_dir.parent.name}/{imt_dir.name}/{curves_file.name}]")
 
         with open(curves_file, "r") as curves_reader:
-            imls = self.metadata.imls.get(netcdf_info.imt)
-
             # Skip header
             next(curves_reader)
 
@@ -249,24 +267,19 @@ class Convert2018A:
                     while len(values) < len(imls):
                         values.append(0.0)
 
-                latitude_index = latitude_indices.setdefault(
+                latitude_index = self._latitude_indices.setdefault(
                     latitude, self._get_latitude_index(latitude=latitude)
                 )
 
-                longitude_index = longitude_indices.setdefault(
+                longitude_index = self._longitude_indices.setdefault(
                     longitude, self._get_longitude_index(longitude=longitude)
                 )
 
-                self._site_class_data_array[
-                    self._get_imt_index(imt=netcdf_info.imt), latitude_index, longitude_index, :
+                self._data_array[
+                    site_class_index, imt_index, latitude_index, longitude_index, :
                 ] = values
-
                 self._imt_mask_array[latitude_index, longitude_index] = 1
-
             self._site_class_mask_array += self._imt_mask_array
-            hazard_netcdf_var[
-                self._get_site_class_index(site_class=netcdf_info.site_class), :, :, :, :
-            ] = self._site_class_data_array
 
     def _set_imt_indices(self) -> dict[Imt, int]:
         imt_indices: dict[Imt, int] = dict()
@@ -333,16 +346,11 @@ class Convert2018A:
         )
 
     def _write_hazard_data(self, hazard_netcdf_var: netcdf.Variable):
-        latitude_indices: dict[float, int] = {}
-        longitude_indices: dict[float, int] = {}
-
         status_msg = f"[bold green]Converting {self.metadata.database_info.nshm.label} files ..."
         with console.status(status_msg, spinner="pong") as status:
             self._get_hazard_data(
                 netcdf_info=self.metadata.netcdf_info,
                 hazard_netcdf_var=hazard_netcdf_var,
-                latitude_indices=latitude_indices,
-                longitude_indices=longitude_indices,
             )
         status.stop()
 
@@ -359,4 +367,5 @@ class Convert2018A:
 
         grid_mask_netcdf_var, hazard_netcdf_var = self._create_netcdf_variables()
         self._write_hazard_data(hazard_netcdf_var=hazard_netcdf_var)
+        hazard_netcdf_var[:, :, :, :, :] = self._data_array
         grid_mask_netcdf_var[:, :] = self._site_class_mask_array
-- 
GitLab