From 7d8c33375fe947b811302cb9ed3176e313af0053 Mon Sep 17 00:00:00 2001
From: Nicholas Shavers <nshavers@contractor.usgs.gov>
Date: Mon, 30 Dec 2024 17:30:14 -0800
Subject: [PATCH] Add NetCDF input/output factory (proof of concept)

---
 geomagio/Controller.py           |   8 +
 geomagio/netcdf/NetCDFFactory.py | 259 +++++++++++++++++++++++++++++++
 2 files changed, 267 insertions(+)
 create mode 100644 geomagio/netcdf/NetCDFFactory.py

diff --git a/geomagio/Controller.py b/geomagio/Controller.py
index dce6b919..6b7c336f 100644
--- a/geomagio/Controller.py
+++ b/geomagio/Controller.py
@@ -7,6 +7,8 @@ from typing import List, Optional, Tuple, Union
 
 from obspy.core import Stream, UTCDateTime
 
+from geomagio.netcdf.NetCDFFactory import NetCDFFactory
+
 from .algorithm import Algorithm, algorithms, AlgorithmException, FilterAlgorithm
 from .DerivedTimeseriesFactory import DerivedTimeseriesFactory
 from .PlotTimeseriesFactory import PlotTimeseriesFactory
@@ -547,6 +549,8 @@ def get_input_factory(args):
         # stream compatible factories
         if input_type == "iaga2002":
             input_factory = iaga2002.IAGA2002Factory(**input_factory_args)
+        elif input_type == "netcdf":
+            input_factory = NetCDFFactory(**input_factory_args)
         elif input_type == "imfv122":
             input_factory = imfv122.IMFV122Factory(**input_factory_args)
         elif input_type == "imfv283":
@@ -632,6 +636,8 @@ def get_output_factory(args):
             output_factory = binlog.BinLogFactory(**output_factory_args)
         elif output_type == "iaga2002":
             output_factory = iaga2002.IAGA2002Factory(**output_factory_args)
+        elif output_type == "netcdf":
+            output_factory = NetCDFFactory(**output_factory_args)
         elif output_type == "imfjson":
             output_factory = imfjson.IMFJSONFactory(**output_factory_args)
         elif output_type == "covjson":
@@ -826,6 +832,7 @@ def parse_args(args):
             "pcdcp",
             "xml",
             "covjson",
+            "netcdf",
         ),
         default="edge",
         help='Input format (Default "edge")',
@@ -1032,6 +1039,7 @@ def parse_args(args):
             "vbf",
             "xml",
             "covjson",
+            "netcdf",
         ),
         # TODO: set default to 'iaga2002'
         help="Output format",
diff --git a/geomagio/netcdf/NetCDFFactory.py b/geomagio/netcdf/NetCDFFactory.py
new file mode 100644
index 00000000..4c877e7b
--- /dev/null
+++ b/geomagio/netcdf/NetCDFFactory.py
@@ -0,0 +1,240 @@
+import netCDF4
+import numpy as np
+from datetime import datetime, timezone
+
+from obspy import Stream, Trace, UTCDateTime
+
+from geomagio.TimeseriesFactory import TimeseriesFactory
+
+
+class NetCDFFactory(TimeseriesFactory):
+    """Factory for reading/writing geomagnetic data in NetCDF format.
+
+    A ``time`` variable holds numeric epoch seconds (seconds since
+    1970-01-01T00:00:00Z); every other variable is one geomagnetic
+    component sharing the ``time`` dimension; observatory metadata is
+    carried in global attributes.
+    """
+
+    def __init__(self, **kwargs):
+        """Initialize the NetCDFFactory.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments for the base TimeseriesFactory.
+        """
+        super().__init__(**kwargs)
+        # Only numeric epoch times are supported by this format.
+        self.time_format = "numeric"
+
+    def parse_string(self, data: bytes, **kwargs):
+        """Parse a NetCDF byte string into an ObsPy Stream.
+
+        Parameters
+        ----------
+        data : bytes
+            Byte content of the NetCDF file.
+
+        Returns
+        -------
+        Stream
+            Parsed data, or an empty Stream if parsing fails.
+        """
+        try:
+            nc_dataset = netCDF4.Dataset("inmemory", mode="r", memory=data)
+        except Exception as e:
+            print(f"Failed to parse NetCDF data: {e}")
+            return Stream()
+
+        try:
+            # Observatory metadata from global attributes.
+            observatory = getattr(nc_dataset, "observatory_name", "")
+            observatory_code = getattr(nc_dataset, "observatory_code", "")
+            latitude = getattr(nc_dataset, "sensor_latitude", 0.0)
+            longitude = getattr(nc_dataset, "sensor_longitude", 0.0)
+            elevation = getattr(nc_dataset, "sensor_elevation", 0.0)
+            original_sample_rate = getattr(nc_dataset, "original_sample_rate", "")
+            data_type = getattr(nc_dataset, "data_type", "")
+
+            if "time" not in nc_dataset.variables:
+                print("No 'time' variable found in NetCDF data.")
+                return Stream()
+
+            # Numeric epoch times in seconds.
+            times_epoch = nc_dataset.variables["time"][:]
+
+            # Every non-time variable is a channel; flatten to 1D.
+            channel_data = {
+                name: var[:].flatten()
+                for name, var in nc_dataset.variables.items()
+                if name != "time"
+            }
+        except Exception as e:
+            print(f"Error while extracting data from NetCDF: {e}")
+            return Stream()
+        finally:
+            # Release the in-memory dataset on success and error paths alike.
+            nc_dataset.close()
+
+        stream = Stream()
+        npts = len(times_epoch)
+        if npts == 0:
+            print("No time points found in NetCDF data.")
+            return stream
+
+        # Average sampling interval; assumes evenly spaced samples.
+        if npts > 1:
+            delta = (times_epoch[-1] - times_epoch[0]) / (npts - 1)
+        else:
+            delta = 1.0  # Default to 1 second if only one point
+
+        for ch_name, values in channel_data.items():
+            if len(values) != npts:
+                print(
+                    f"Channel '{ch_name}' has mismatched data length. Expected {npts}, got {len(values)}. Skipping."
+                )
+                continue
+            data_array = np.array(values, dtype=float)
+            try:
+                trace_starttime = UTCDateTime(times_epoch[0])
+            except Exception:
+                trace_starttime = UTCDateTime()
+
+            stats = {
+                "network": "",
+                "station": observatory_code,
+                "channel": ch_name,
+                "starttime": trace_starttime,
+                "npts": npts,
+                "sampling_rate": 1.0 / delta if delta != 0 else 1.0,
+                "geodetic_longitude": longitude,
+                "geodetic_latitude": latitude,
+                "elevation": elevation,
+                "station_name": observatory,
+                "data_type": data_type,
+                "data_interval_type": original_sample_rate,
+            }
+
+            stream += Trace(data=data_array, header=stats)
+
+        return stream
+
+    def write_file_from_buffer(self, fh, timeseries: Stream, channels=None):
+        """Write an ObsPy Stream to a NetCDF file.
+
+        Parameters
+        ----------
+        fh : str or file-like object
+            Target accepted by ``netCDF4.Dataset`` (a file path in practice).
+        timeseries : Stream
+            ObsPy Stream object containing the data to write.
+        channels : list, optional
+            Channel names to include; if None, all channels are written.
+        """
+        if not timeseries or len(timeseries) == 0:
+            # Still emit a (labeled) empty NetCDF structure.
+            with netCDF4.Dataset(fh, "w", format="NETCDF4") as nc_dataset:
+                nc_dataset.description = "Empty Geomagnetic Time Series Data"
+            return
+
+        timeseries.merge()
+
+        if channels:
+            # Case-insensitive channel filtering.
+            wanted = {ch.upper() for ch in channels}
+            timeseries = Stream(
+                [tr for tr in timeseries if tr.stats.channel.upper() in wanted]
+            )
+
+        if len(timeseries) == 0:
+            print("No matching channels found after filtering.")
+            with netCDF4.Dataset(fh, "w", format="NETCDF4") as nc_dataset:
+                nc_dataset.description = (
+                    "Empty Geomagnetic Time Series Data after channel filtering"
+                )
+            return
+
+        timeseries.sort(keys=["starttime"])
+        # Metadata and sampling come from the earliest trace; all traces
+        # are assumed to share observatory metadata and time base.
+        stats = timeseries[0].stats
+
+        station = stats.station or ""
+        lat = float(getattr(stats, "geodetic_latitude", 0.0))
+        lon = float(getattr(stats, "geodetic_longitude", 0.0))
+        alt = float(getattr(stats, "elevation", 0.0))
+        data_type = getattr(stats, "data_type", "unknown")
+        data_interval_type = getattr(stats, "data_interval_type", "unknown")
+        station_name = getattr(stats, "station_name", station)
+        sensor_orientation = getattr(stats, "sensor_orientation", "")
+        original_sample_rate = getattr(stats, "original_sample_rate", "")
+        # NOTE(review): obspy Stats has no "sampling_period" attribute, so
+        # this almost always falls back to 1.0 — confirm intended units (ms).
+        sample_period = getattr(stats, "sampling_period", 1.0)
+
+        npts = stats.npts
+        delta = stats.delta
+        starttime = stats.starttime
+
+        # Time values as epoch seconds, evenly spaced by delta.
+        times_epoch = starttime.timestamp + np.arange(npts) * delta
+
+        with netCDF4.Dataset(fh, "w", format="NETCDF4") as nc_dataset:
+            nc_dataset.createDimension("time", npts)
+
+            time_var = nc_dataset.createVariable("time", "f8", ("time",))
+            time_var.units = "seconds since 1970-01-01T00:00:00Z"
+            time_var.calendar = "standard"
+            time_var[:] = times_epoch
+
+            for trace in timeseries:
+                ch_name = trace.stats.channel.upper()
+                # Sanitize to a valid NetCDF variable name.
+                ch_name_nc = ch_name.replace("-", "_").replace(" ", "_")
+                var = nc_dataset.createVariable(ch_name_nc, "f8", ("time",))
+                # Non-standard elements such as temperature not considered.
+                var.units = "nT"
+                var.long_name = f"Geomagnetic component {ch_name}"
+                var[:] = trace.data
+
+            # Global metadata attributes; parse_string reads these back.
+            nc_dataset.institute_name = getattr(stats, "agency_name", "")
+            nc_dataset.observatory_name = station_name
+            nc_dataset.observatory_code = station
+            nc_dataset.sensor_latitude = lat
+            nc_dataset.sensor_longitude = lon
+            nc_dataset.sensor_elevation = alt
+            nc_dataset.sensor_orientation = sensor_orientation
+            nc_dataset.original_sample_rate = original_sample_rate
+            nc_dataset.data_type = data_type
+            nc_dataset.data_interval_type = data_interval_type
+            nc_dataset.start_time = starttime.timestamp  # Numeric epoch time
+            nc_dataset.sample_period = sample_period  # In milliseconds
+
+            nc_dataset.history = f"Created on {datetime.now(timezone.utc).isoformat()}"
+
+    def write_file(self, fh, timeseries: Stream, channels=None):
+        """Write NetCDF bytes to an already-open file handle.
+
+        netCDF4 needs a real file path, so the data is staged through a
+        temporary file and then copied into ``fh``.
+        """
+        import os
+        import shutil
+        import tempfile
+
+        # Create (and immediately close) a named temporary file.
+        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+            tmp_filepath = tmp.name
+
+        try:
+            # Write NetCDF data to the temporary path.
+            self.write_file_from_buffer(tmp_filepath, timeseries, channels)
+
+            # Copy the finished file's bytes into the caller's handle.
+            with open(tmp_filepath, "rb") as tmp:
+                shutil.copyfileobj(tmp, fh)
+        finally:
+            # Clean up the temporary file.
+            os.remove(tmp_filepath)
-- 
GitLab