From dd4896c0286b480930d678e9fcf4dad5a9fcdcc2 Mon Sep 17 00:00:00 2001
From: Nicholas Shavers <nshavers@contractor.usgs.gov>
Date: Mon, 30 Dec 2024 14:40:32 -0800
Subject: [PATCH] xmlfactory poc, simple and working

---
 geomagio/Controller.py         |   8 ++
 geomagio/api/xml/XMLFactory.py | 212 +++++++++++++++++++++++++++++++++
 2 files changed, 220 insertions(+)
 create mode 100644 geomagio/api/xml/XMLFactory.py

diff --git a/geomagio/Controller.py b/geomagio/Controller.py
index 1f28a68c..bfdffb04 100644
--- a/geomagio/Controller.py
+++ b/geomagio/Controller.py
@@ -7,6 +7,8 @@ from typing import List, Optional, Tuple, Union
 
 from obspy.core import Stream, UTCDateTime
 
+from geomagio.api.xml.XMLFactory import XMLFactory
+
 from .algorithm import Algorithm, algorithms, AlgorithmException, FilterAlgorithm
 from .DerivedTimeseriesFactory import DerivedTimeseriesFactory
 from .PlotTimeseriesFactory import PlotTimeseriesFactory
@@ -559,6 +561,8 @@ def get_input_factory(args):
                 convert_channels=args.convert_voltbin,
                 **input_factory_args,
             )
+        elif input_type == "xml":
+            input_factory = XMLFactory(**input_factory_args)
         # wrap stream
         if input_stream is not None:
             input_factory = StreamTimeseriesFactory(
@@ -644,6 +648,8 @@ def get_output_factory(args):
                 locationCode=locationcode,
                 **output_factory_args,
             )
+        elif output_type == "xml":
+            output_factory = XMLFactory(**output_factory_args)
         # wrap stream
         if output_stream is not None:
             output_factory = StreamTimeseriesFactory(
@@ -814,6 +820,7 @@ def parse_args(args):
             "fdsn",
             "miniseed",
             "pcdcp",
+            "xml",
         ),
         default="edge",
         help='Input format (Default "edge")',
@@ -1018,6 +1025,7 @@ def parse_args(args):
             "plot",
             "temperature",
             "vbf",
+            "xml",
         ),
         # TODO: set default to 'iaga2002'
         help="Output format",
diff --git a/geomagio/api/xml/XMLFactory.py b/geomagio/api/xml/XMLFactory.py
new file mode 100644
index 00000000..4e6850de
--- /dev/null
+++ b/geomagio/api/xml/XMLFactory.py
@@ -0,0 +1,212 @@
+import xml.etree.ElementTree as ET
+import numpy as np
+from obspy import Stream, Trace, UTCDateTime
+
+from geomagio.TimeseriesFactory import TimeseriesFactory
+
+
+class XMLFactory(TimeseriesFactory):
+    """Factory for reading/writing XML format Geomagnetic Data."""
+
+    def __init__(
+        self,
+        time_format="iso8601",  # could be "iso8601" or "numeric"
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        time_format : {"iso8601", "numeric"}
+            - "iso8601": each time value is an ISO8601 string
+            - "numeric": each time value is numeric (epoch milliseconds)
+
+        """
+        super().__init__(**kwargs)
+        self.time_format = time_format
+
+    def parse_string(self, data: str, **kwargs):
+        """
+        Parse an XML string into an ObsPy Stream.
+
+        The XML should follow the structure:
+        <GeomagneticTimeSeriesData>
+            <Description>
+                ...
+            </Description>
+            <Data>
+                <Sample>
+                    <H></H>
+                    <D></D>
+                    ...
+                </Sample>
+                ...
+            </Data>
+        </GeomagneticTimeSeriesData>
+
+        Returns an empty Stream if no valid data is found.
+        """
+        try:
+            root = ET.fromstring(data)
+        except ET.ParseError:
+            return Stream()
+
+        # Parse Description
+        description = root.find("Description")
+        if description is None:
+            return Stream()
+
+        institute = description.findtext("InstituteName", default="")
+        observatory = description.findtext("ObservatoryName", default="")
+        observatory_code = description.findtext("ObservatoryCode", default="")
+        latitude = float(description.findtext("SensorLatitude", default="0.0"))
+        longitude = float(description.findtext("SensorLongitude", default="0.0"))
+        elevation = float(description.findtext("SensorElevation", default="0.0"))
+        sensor_orientation = description.findtext("SensorOrientation", default="")
+        original_sample_rate = description.findtext("OriginalSampleRate", default="")
+        data_type = description.findtext("DataType", default="")
+        start_time_str = description.findtext("StartTime", default="")
+        sample_period = float(description.findtext("SamplePeriod", default="60000"))
+
+        if self.time_format == "numeric":
+            starttime = UTCDateTime(float(start_time_str) / 1000.0)
+        else:
+            try:
+                starttime = UTCDateTime(start_time_str)
+            except Exception:
+                starttime = UTCDateTime()
+
+        # Parse Data
+        data_element = root.find("Data")
+        if data_element is None:
+            return Stream()
+
+        samples = data_element.findall("Sample")
+        if not samples:
+            return Stream()
+
+        channel_data = {}
+        for sample in samples:
+            for channel in sample:
+                ch_name = channel.tag
+                value = channel.text.strip()
+                if ch_name not in channel_data:
+                    channel_data[ch_name] = []
+                try:
+                    channel_data[ch_name].append(float(value) if value else 0.0)
+                except ValueError:
+                    channel_data[ch_name].append(0.0)
+
+        # Create Stream and Traces
+        stream = Stream()
+        npts = len(samples)
+        delta = sample_period / 1000.0  # Convert milliseconds to seconds
+
+        for ch_name, values in channel_data.items():
+            data_array = np.array(values, dtype=float)
+            stats = {
+                "network": "",
+                "station": observatory_code,
+                "channel": ch_name,
+                "starttime": starttime,
+                "npts": npts,
+                "sampling_rate": 1.0 / delta if delta != 0 else 1.0,
+                "geodetic_longitude": longitude,
+                "geodetic_latitude": latitude,
+                "elevation": elevation,
+                "station_name": observatory,
+                "data_type": data_type,
+                "data_interval_type": original_sample_rate,
+            }
+
+            trace = Trace(data=data_array, header=stats)
+            stream += trace
+
+        return stream
+
+    def write_file(self, fh, timeseries: Stream, channels):
+        """
+        Write an ObsPy Stream to an XML coverage.
+
+        The XML structure follows the example:
+        <GeomagneticTimeSeriesData>
+            <Description>
+                ...
+            </Description>
+            <Data>
+                <Sample>
+                    <H></H>
+                    <D></D>
+                    ...
+                </Sample>
+                ...
+            </Data>
+        </GeomagneticTimeSeriesData>
+        """
+        if not timeseries or len(timeseries) == 0:
+            fh.write(b"<GeomagneticTimeSeriesData></GeomagneticTimeSeriesData>")
+            return
+
+        timeseries.merge()
+
+        # Filter channels
+        if channels:
+            new_stream = Stream()
+            for ch in channels:
+                new_stream += timeseries.select(channel=ch)
+            timeseries = new_stream
+
+        timeseries.sort(keys=["starttime"])
+        tr = timeseries[0]
+        stats = tr.stats
+
+        station = stats.station or ""
+        lat = float(getattr(stats, "geodetic_latitude", 0.0))
+        lon = float(getattr(stats, "geodetic_longitude", 0.0))
+        alt = float(getattr(stats, "elevation", 0.0))
+        data_type = getattr(stats, "data_type", str(self.type))
+        data_interval_type = getattr(stats, "data_interval_type", str(self.interval))
+        station_name = getattr(stats, "station_name", station)
+        sensor_orientation = getattr(stats, "sensor_orientation", "")
+        original_sample_rate = getattr(stats, "digital_sampling_rate", "")
+        sample_period = getattr(stats, "sampling_period", 60000.0)
+
+        npts = tr.stats.npts
+        delta = tr.stats.delta
+        starttime = tr.stats.starttime
+
+        # Create root element
+        root = ET.Element("GeomagneticTimeSeriesData")
+
+        # Create Description element
+        description = ET.SubElement(root, "Description")
+        ET.SubElement(description, "InstituteName").text = getattr(
+            stats, "agency_name", ""
+        )
+        ET.SubElement(description, "ObservatoryName").text = station_name
+        ET.SubElement(description, "ObservatoryCode").text = station
+        ET.SubElement(description, "SensorLatitude").text = f"{lat}"
+        ET.SubElement(description, "SensorLongitude").text = f"{lon}"
+        ET.SubElement(description, "SensorElevation").text = f"{alt}"
+        ET.SubElement(description, "SensorOrientation").text = sensor_orientation
+        ET.SubElement(description, "OriginalSampleRate").text = str(
+            original_sample_rate
+        )
+        ET.SubElement(description, "DataType").text = data_type
+        ET.SubElement(description, "StartTime").text = starttime.isoformat()
+        ET.SubElement(description, "SamplePeriod").text = f"{sample_period}"
+
+        # Create Data element
+        data_elem = ET.SubElement(root, "Data")
+
+        # All traces should have the same starttime and delta per convention
+        for i in range(npts):
+            sample = ET.SubElement(data_elem, "Sample")
+            for trace in timeseries:
+                ch_name = trace.stats.channel
+                value = trace.data[i] if i < len(trace.data) else ""
+                ET.SubElement(sample, ch_name).text = f"{value}"
+
+        # Generate the XML string
+        tree = ET.ElementTree(root)
+        # Write to file handle with proper encoding
+        tree.write(fh, encoding="utf-8", xml_declaration=True)
-- 
GitLab