# Convert to NetCDF
#
# Converts NSHMP ASCII probabilistic seismic hazard curves into a single
# NetCDF4 file. (Header reconstructed; original lines were web-page residue
# from a GitLab merge-request view.)
import json
import os
import shutil
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Union

import netCDF4 as netcdf
import numpy as np
from rich.progress import BarColumn, Progress, TimeElapsedColumn

from ..application_inputs import ApplicationInputs
from ..database.database_info import NetcdfInfo, NetcdfMetadata, ScienceBaseMetadata
from ..gmm.imt import Imt
from ..gmm.site_class import SiteClass
from ..nshm import Nshm
from ..utils.console import console
from ..utils.netcdf_dimensions import NetcdfDimensions
from ..utils.netcdf_keys import NetcdfKeys
from ..utils.netcdf_parameters import NetcdfParameters
from ..utils.netcdf_utils import NetcdfUtils
# Global attributes written to the root group of every generated NetCDF file.
_ROOT_ATTRIBUTES: dict = {
    "description": "NSHMP Probabilistic Seismic Hazard Curves.",
    # NOTE(review): datetime.now() is evaluated once at module import time,
    # not when the file is written — confirm this is intended.
    "history": f"ASCII data converted ({datetime.now()})",
    "nshmpNetcdfFormatDescription": (
        "hazard(siteClass,imt,lat,lon,iml); siteClass and imt are "
        + "defined as custom enum data types."
    ),
}
# Environment variable that, when set, overrides the output NetCDF file path.
_NETCDF_PATH_ENV = "NSHMP_NETCDF_FILE_PATH"
class Convert:
    """Converts NSHM ASCII hazard-curve data to a NetCDF4 file.

    Instantiating this class runs the entire conversion pipeline: the output
    NetCDF file is created, metadata and hazard curves are written, the file
    is closed, and the source ASCII files are optionally removed.
    """

    def __init__(self, inputs: ApplicationInputs, metadata: NetcdfMetadata):
        """Run the ASCII-to-NetCDF conversion.

        Args:
            inputs: Application inputs (output file name, clean-ASCII flag).
            metadata: Metadata describing the hazard database to convert.

        Raises:
            ValueError: If the NSHM is not the supported 2018A model.
        """
        self.metadata = metadata
        self.inputs = inputs
        nshm = self.metadata.database_info.nshm
        region = self.metadata.model_region

        # Only the 2018A model is currently supported.
        if nshm != Nshm.NSHM_2018A:
            raise ValueError(f"NSHM [{self.metadata.database_info.nshm.value}] not supported")

        console.print(f"\n[blue]Converting {nshm.label}")

        self._imt_indices: dict[Imt, int] = self._set_imt_indices()
        self._site_class_indices: dict[SiteClass, int] = self._set_site_class_indices()
        self._latitude_indices: dict[float, int] = {}
        self._longitude_indices: dict[float, int] = {}

        # netCDF4 variables are not thread safe; this lock serializes writes
        # to shared state from the worker threads in _get_hazard_data.
        self._write_lock = threading.Lock()

        # Output path precedence: environment variable, then CLI input, then
        # a default name derived from the NSHM and model region.
        env_path = os.getenv(_NETCDF_PATH_ENV)  # read once (was read twice)
        if env_path:
            self._netcdf_filename = Path(env_path)
        elif inputs.netcdf_filename is not None:
            self._netcdf_filename = self.inputs.netcdf_filename
        else:
            self._netcdf_filename = Path(
                metadata.database_info.database_directory.joinpath(
                    f"{metadata.database_info.nshm.value}_{metadata.model_region.name}.nc"
                )
            )

        self._check_file()
        self._root_group = netcdf.Dataset(
            filename=self._netcdf_filename, mode="w", format="NETCDF4", clobber=True
        )
        self._dataset_group: netcdf.Group = self._root_group.createGroup(
            f"{nshm.value}/{region.name}"
        )
        self._dimensions = self._create_dimensions(group=self._dataset_group)

        # Per grid node, counts how many curve files contained data for that
        # node; written to the grid-mask variable after conversion completes.
        self._site_class_mask_array = np.zeros(
            [self._dimensions.lat.size, self._dimensions.lon.size], int
        )
        self._progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeElapsedColumn(),
        )
        self._write_netcdf_file()
        self._root_group.close()
        self._clean_ascii()
        console.print(
            (
                f"NetCDF conversion for NSHM ({metadata.database_info.nshm.value}) completed\n"
                + f"NetCDF file: {self._netcdf_filename}"
            ),
            style="blue",
        )

    @property
    def netcdf_file(self):
        """Path to the generated NetCDF file."""
        return self._netcdf_filename

    def _check_file(self) -> None:
        """Remove any pre-existing output file so a fresh one is written."""
        if self._netcdf_filename.exists():
            os.remove(self._netcdf_filename)

    def _create_netcdf_variables(self):
        """Create all NetCDF variables in the dataset group.

        Returns:
            Tuple of (grid-mask variable, hazard variable); both are filled
            in later during conversion.
        """
        iml_values, imt_values = self._set_imt_values()
        imt_enum_type = self._imt_enum_type()
        geojson = self._get_map_file()
        grid_step = self.metadata.grid_step
        latitudes = self.metadata.locations.latitudes
        longitudes = self.metadata.locations.longitudes
        parameters = NetcdfParameters()
        site_class_values = self._set_site_class_values()
        site_class_enum_type = self._site_class_enum_type()

        # The bounds variable is only written when a map file is available.
        if geojson is not None:
            parameters.BoundsParameters().netcdf_variable(group=self._dataset_group, data=geojson)

        grid_mask_netcdf_var = parameters.GridMaskParameters().netcdf_variable(
            group=self._dataset_group
        )
        hazard_netcdf_var = parameters.HazardParameters().netcdf_variable(group=self._dataset_group)
        parameters.ImlParameters().netcdf_variable(group=self._dataset_group, data=iml_values)
        parameters.ImtParameters(datatype=imt_enum_type).netcdf_variable(
            group=self._dataset_group, data=imt_values
        )
        NetcdfUtils.create_coordinate_latitude(
            group=self._dataset_group, grid_step=grid_step, values=latitudes
        )
        NetcdfUtils.create_coordinate_longitude(
            group=self._dataset_group, grid_step=grid_step, values=longitudes
        )
        parameters.SiteClassParameters(datatype=site_class_enum_type).netcdf_variable(
            group=self._dataset_group, data=site_class_values
        )
        parameters.Vs30Parameters().netcdf_variable(
            group=self._dataset_group, data=self.metadata.vs30s
        )
        return grid_mask_netcdf_var, hazard_netcdf_var

    def _create_dimensions(self, group: netcdf.Group) -> NetcdfDimensions:
        """Create the NetCDF dimensions for the dataset group.

        The IML dimension is sized from the PGA IML list — this assumes every
        IMT has the same number of IMLs (TODO confirm upstream guarantee).
        """
        return NetcdfDimensions(
            iml=group.createDimension(
                dimname=NetcdfKeys.IML, size=len(self.metadata.imls.get(Imt.PGA))
            ),
            imt=group.createDimension(dimname=NetcdfKeys.IMT, size=len(self.metadata.imts)),
            lat=group.createDimension(
                dimname=NetcdfKeys.LAT, size=len(self.metadata.locations.latitudes)
            ),
            lon=group.createDimension(
                dimname=NetcdfKeys.LON, size=len(self.metadata.locations.longitudes)
            ),
            site_class=group.createDimension(
                dimname=NetcdfKeys.SITE_CLASS, size=len(self.metadata.site_classes)
            ),
        )

    def _clean_ascii(self) -> None:
        """Delete the source ASCII directory when the clean-ascii flag is set."""
        if self.inputs.clean_ascii is True:
            path = self.metadata.database_info.data_path
            console.print(f"\n Removing ASCII files in ({path})", style="yellow")
            shutil.rmtree(path)

    def _get_hazard_data(self, hazard_netcdf_var: netcdf.Variable) -> None:
        """Read every curve file on a thread pool and write its hazard data.

        Args:
            hazard_netcdf_var: Hazard variable to populate.
        """
        futures: list[Future] = []
        status_msg = f"[bold green]Converting {self.metadata.database_info.nshm.label} files"

        with self._progress, ThreadPoolExecutor() as executor:
            for info in self.metadata.netcdf_info:
                futures.append(
                    executor.submit(
                        self._read_curves_file,
                        hazard_netcdf_var=hazard_netcdf_var,
                        netcdf_info=info,
                    )
                )
            # Propagate worker exceptions; fail if any file takes > 2 minutes.
            for future in self._progress.track(futures, description=status_msg):
                future.result(timeout=120)

    def _get_imt_index(self, imt: Imt):
        """Return the hazard-array index for an IMT (None if not present)."""
        return self._imt_indices.get(imt)

    def _get_latitude_index(self, latitude: float):
        """Compute the grid row index for a latitude."""
        return NetcdfUtils.calc_index(
            start=self.metadata.locations.latitudes[0],
            delta=self.metadata.grid_step,
            target=latitude,
        )

    def _get_longitude_index(self, longitude: float):
        """Compute the grid column index for a longitude."""
        return NetcdfUtils.calc_index(
            start=self.metadata.locations.longitudes[0],
            delta=self.metadata.grid_step,
            target=longitude,
        )

    def _get_map_file(self) -> Union[str, None]:
        """Locate the configured map file and return its GeoJSON as a string.

        Returns:
            The GeoJSON content re-encoded as a JSON string, or None when no
            map file is configured or none is found under the data path.
        """
        if self.metadata.database_info.map_file is None:
            return None

        for root, _dirs, files in os.walk(self.metadata.database_info.data_path):
            for file in files:
                if self.metadata.database_info.map_file in file:
                    map_file = Path(root, file)
                    with open(map_file, "r") as json_reader:
                        geojson = json.load(json_reader)
                    return json.dumps(geojson)

        return None

    def _get_site_class_index(self, site_class: SiteClass):
        """Return the hazard-array index for a site class (None if not present)."""
        return self._site_class_indices.get(site_class)

    def _imt_enum_type(self) -> netcdf.EnumType:
        """Create the NetCDF enum type covering every Imt member."""
        imt_dict: dict[str, int] = {imt.name: index for index, imt in enumerate(Imt)}
        return self._dataset_group.createEnumType(
            datatype=np.uint8, datatype_name=NetcdfKeys.IMT_ENUM_TYPE, enum_dict=imt_dict
        )

    def _read_curves_file(
        self,
        netcdf_info: NetcdfInfo,
        hazard_netcdf_var: netcdf.Variable,
    ) -> None:
        """Parse one ASCII curve file and write its hazard values.

        Runs on a worker thread; writes to shared state are serialized with
        ``self._write_lock`` because netCDF4 is not thread safe.

        Args:
            netcdf_info: Site class, IMT, and curve-file path to convert.
            hazard_netcdf_var: Hazard variable to write into.

        Raises:
            FileNotFoundError: If the curve file does not exist.
        """
        curves_file = netcdf_info.curve_file
        # BUG FIX: the original tested the bound method `curves_file.exists`
        # (always truthy), so the missing-file check could never fire.
        if not curves_file.exists():
            raise FileNotFoundError(f"File ({curves_file}) not found")

        # Start from fill values so grid nodes with no data remain flagged.
        data_array = np.full(
            [
                self._dimensions.lat.size,
                self._dimensions.lon.size,
                self._dimensions.iml.size,
            ],
            NetcdfParameters.HazardParameters().fill_value,
            float,
        )
        imls = self.metadata.imls.get(netcdf_info.imt)
        imt_dir = curves_file.parent
        imt_index = self._get_imt_index(imt=netcdf_info.imt)
        imt_mask_array = np.zeros([self._dimensions.lat.size, self._dimensions.lon.size], int)
        site_class_index = self._get_site_class_index(site_class=netcdf_info.site_class)
        print(f"\t Converting [{imt_dir.parent.name}/{imt_dir.name}/{curves_file.name}]")

        with open(curves_file, "r") as curves_reader:
            # Skip header
            next(curves_reader)
            for line in curves_reader:
                if line.startswith(NetcdfKeys.FULL_GRID):
                    continue
                # First field is the site name; the remainder are numeric.
                data: list[str] = line.strip().rstrip(",").split(",")
                longitude, latitude, *values = [float(x) for x in data[1:]]
                # Pad short curves with zeros so every row has len(imls) values.
                if len(values) < len(imls):
                    values.extend([0.0] * (len(imls) - len(values)))
                latitude_index = self._latitude_indices.setdefault(
                    latitude, self._get_latitude_index(latitude=latitude)
                )
                longitude_index = self._longitude_indices.setdefault(
                    longitude, self._get_longitude_index(longitude=longitude)
                )
                data_array[latitude_index, longitude_index, :] = values
                imt_mask_array[latitude_index, longitude_index] = 1

        # BUG FIX: these touch shared state from concurrent worker threads;
        # the unsynchronized mask accumulation and netCDF write were racy.
        with self._write_lock:
            self._site_class_mask_array += imt_mask_array
            hazard_netcdf_var[site_class_index, imt_index, :, :, :] = data_array

    def _set_imt_indices(self) -> dict[Imt, int]:
        """Map each metadata IMT to its data index, in Imt declaration order."""
        imt_indices: dict[Imt, int] = {}
        data_index = 0
        for imt in Imt:
            if imt in self.metadata.imts:
                imt_indices[imt] = data_index
                data_index += 1
        return imt_indices

    def _set_imt_values(self):
        """Collect IMT enum values and the IML matrix for the metadata IMTs.

        Returns:
            Tuple of (IML array of shape [imt, iml], list of IMT enum indices).
        """
        imt_values: list[int] = []
        iml_values = np.full([self._dimensions.imt.size, self._dimensions.iml.size], 0, float)
        data_index = 0
        for index, imt in enumerate(Imt):
            if imt in self.metadata.imts:
                imt_values.append(index)
                iml_values[data_index, :] = self.metadata.imls.get(imt)
                data_index += 1
        return iml_values, imt_values

    def _set_site_class_indices(self) -> dict[SiteClass, int]:
        """Map each metadata site class to its data index, in declaration order."""
        indices: dict[SiteClass, int] = {}
        data_index = 0
        for site_class in SiteClass:
            if site_class in self.metadata.site_classes:
                indices[site_class] = data_index
                data_index += 1
        return indices

    def _set_site_class_values(self) -> list[int]:
        """Return the enum indices of site classes present in the metadata."""
        return [
            index
            for index, site_class in enumerate(SiteClass)
            if site_class in self.metadata.site_classes
        ]

    def _site_class_enum_type(self) -> netcdf.EnumType:
        """Create the NetCDF enum type covering every SiteClass member."""
        site_class_dict: dict[str, int] = {
            site_class.name: index for index, site_class in enumerate(SiteClass)
        }
        return self._dataset_group.createEnumType(
            datatype=np.uint8,
            datatype_name=NetcdfKeys.SITE_CLASS_ENUM_TYPE,
            enum_dict=site_class_dict,
        )

    def _write_netcdf_file(self) -> None:
        """Write all attributes and variables, then convert the hazard data."""
        self._root_group.setncatts(_ROOT_ATTRIBUTES)
        self._dataset_group.description = self.metadata.database_info.description

        # One URL/version attribute pair per ScienceBase metadata entry.
        for index, sb_metadata in enumerate(self.metadata.database_info.science_base_metadata):
            metadata: ScienceBaseMetadata = sb_metadata
            self._dataset_group.setncattr(name=f"science_base_url_{index}", value=metadata.url)
            self._dataset_group.setncattr(
                name=f"science_base_version_{index}", value=metadata.science_base_version
            )

        grid_mask_netcdf_var, hazard_netcdf_var = self._create_netcdf_variables()
        self._get_hazard_data(hazard_netcdf_var=hazard_netcdf_var)
        grid_mask_netcdf_var[:, :] = self._site_class_mask_array