Skip to content
Snippets Groups Projects
XMLFactory.py 7.26 KiB
Newer Older
  • Learn to ignore specific revisions
  • import xml.etree.ElementTree as ET
    import numpy as np
    from obspy import Stream, Trace, UTCDateTime
    
    from geomagio.TimeseriesFactory import TimeseriesFactory
    
    
    class XMLFactory(TimeseriesFactory):
        """Factory for reading/writing XML format Geomagnetic Data.

        The XML layout handled by both directions is::

            <GeomagneticTimeSeriesData>
                <Description>
                    ...
                </Description>
                <Data>
                    <Sample>
                        <H></H>
                        <D></D>
                        ...
                    </Sample>
                    ...
                </Data>
            </GeomagneticTimeSeriesData>
        """

        # SamplePeriod (milliseconds) assumed when the XML description does not
        # carry a usable value.
        _DEFAULT_SAMPLE_PERIOD_MS = 60000.0

        # Document written when there is nothing to serialize.
        _EMPTY_DOCUMENT = b"<GeomagneticTimeSeriesData></GeomagneticTimeSeriesData>"

        def __init__(
            self,
            **kwargs,
        ):
            """Create the factory; keyword arguments go to TimeseriesFactory."""
            super().__init__(**kwargs)

        @staticmethod
        def _find_float(element, tag, default):
            """Return the text of child ``tag`` as a float, or ``default``.

            A missing child, an empty child, and non-numeric text all fall back
            to ``default`` instead of raising.
            """
            text = element.findtext(tag, default="") or ""
            try:
                return float(text.strip())
            except ValueError:
                return default

        @staticmethod
        def _as_text(value):
            """Return ``value`` as a string suitable for ``Element.text``.

            ElementTree can only serialize ``str`` text, so other types are
            converted; ``None`` becomes the empty string.
            """
            return "" if value is None else str(value)

        def parse_string(self, data: str, **kwargs):
            """
            Parse an XML string into an ObsPy Stream.

            The XML should follow the structure:
            <GeomagneticTimeSeriesData>
                <Description>
                    ...
                </Description>
                <Data>
                    <Sample>
                        <H></H>
                        <D></D>
                        ...
                    </Sample>
                    ...
                </Data>
            </GeomagneticTimeSeriesData>

            Each distinct child tag found under <Sample> becomes one Trace whose
            channel is the tag name. Empty or non-numeric sample values are read
            as 0.0. SamplePeriod is interpreted as milliseconds.

            Parameters
            ----------
            data : str
                XML document to parse.
            **kwargs
                Accepted for interface compatibility; not used.

            Returns
            -------
            obspy.core.Stream
                One Trace per channel. An empty Stream if the XML is malformed,
                or has no <Description>, <Data>, or <Sample> elements.
            """
            try:
                root = ET.fromstring(data)
            except ET.ParseError:
                return Stream()

            description = root.find("Description")
            data_element = root.find("Data")
            if description is None or data_element is None:
                return Stream()

            samples = data_element.findall("Sample")
            if not samples:
                return Stream()

            # Parse Description
            institute = description.findtext("InstituteName", default="")
            observatory = description.findtext("ObservatoryName", default="")
            observatory_code = description.findtext("ObservatoryCode", default="")
            latitude = self._find_float(description, "SensorLatitude", 0.0)
            longitude = self._find_float(description, "SensorLongitude", 0.0)
            elevation = self._find_float(description, "SensorElevation", 0.0)
            sensor_orientation = description.findtext("SensorOrientation", default="")
            original_sample_rate = description.findtext(
                "OriginalSampleRate", default=""
            )
            data_type = description.findtext("DataType", default="")
            start_time_str = description.findtext("StartTime", default="")
            sample_period = self._find_float(
                description, "SamplePeriod", self._DEFAULT_SAMPLE_PERIOD_MS
            )

            # Best effort: an unparseable StartTime falls back to the current time.
            try:
                starttime = UTCDateTime(start_time_str)
            except Exception:
                starttime = UTCDateTime()

            # Parse Data
            channel_data = {}
            for sample in samples:
                for channel in sample:
                    # Element.text is None for empty elements such as <H></H>.
                    text = (channel.text or "").strip()
                    try:
                        value = float(text) if text else 0.0
                    except ValueError:
                        value = 0.0
                    channel_data.setdefault(channel.tag, []).append(value)

            # Convert milliseconds to seconds
            delta = sample_period / 1000.0
            # Guard against zero, negative, and non-finite periods.
            if np.isfinite(delta) and delta > 0:
                sampling_rate = 1.0 / delta
            else:
                sampling_rate = 1.0

            # Create Stream and Traces
            stream = Stream()
            for ch_name, values in channel_data.items():
                data_array = np.array(values, dtype=float)
                stats = {
                    "network": "",
                    "station": observatory_code,
                    "channel": ch_name,
                    "starttime": starttime,
                    # Use the real length: a channel missing from some samples
                    # has fewer values than there are <Sample> elements.
                    "npts": len(data_array),
                    "sampling_rate": sampling_rate,
                    "geodetic_longitude": longitude,
                    "geodetic_latitude": latitude,
                    "elevation": elevation,
                    "station_name": observatory,
                    "data_type": data_type,
                    "data_interval_type": original_sample_rate,
                    # Kept so write_file can emit them again.
                    "agency_name": institute,
                    "sensor_orientation": sensor_orientation,
                }
                stream += Trace(data=data_array, header=stats)

            return stream

        def write_file(self, fh, timeseries: Stream, channels):
            """
            Write an ObsPy Stream as an XML document.

            The XML structure follows the example:
            <GeomagneticTimeSeriesData>
                <Description>
                    ...
                </Description>
                <Data>
                    <Sample>
                        <H></H>
                        <D></D>
                        ...
                    </Sample>
                    ...
                </Data>
            </GeomagneticTimeSeriesData>

            Description values and the number of samples come from the trace
            with the earliest start time. Traces with fewer points get empty
            values for the missing samples. SamplePeriod is written in
            milliseconds.

            Parameters
            ----------
            fh : binary file-like object
                Destination; bytes are written to it.
            timeseries : obspy.core.Stream
                Data to write. It is not modified; a copy is merged and sorted.
            channels : sequence of str
                Channels to write, in output order. If empty or None, all
                traces in ``timeseries`` are written.
            """
            if not timeseries:
                fh.write(self._EMPTY_DOCUMENT)
                return

            # merge() and sort() work in place; copy so the caller's Stream is
            # left untouched.
            timeseries = timeseries.copy()
            timeseries.merge()

            # TODO: CONSIDER PADDING TIMESERIES: use TimeseriesUtility.pad_timeseries(...)

            # Filter channels
            if channels:
                selected = Stream()
                for ch in channels:
                    selected += timeseries.select(channel=ch)
                timeseries = selected

            # None of the requested channels may be present.
            if not timeseries:
                fh.write(self._EMPTY_DOCUMENT)
                return

            timeseries.sort(keys=["starttime"])
            stats = timeseries[0].stats

            station = stats.station or ""
            lat = float(getattr(stats, "geodetic_latitude", 0.0))
            lon = float(getattr(stats, "geodetic_longitude", 0.0))
            alt = float(getattr(stats, "elevation", 0.0))
            # An explicit sampling_period wins; otherwise derive it from the
            # trace's sample spacing (seconds) so it matches the data.
            sample_period = getattr(stats, "sampling_period", None)
            if sample_period is None:
                sample_period = float(stats.delta) * 1000.0

            npts = stats.npts
            starttime = stats.starttime

            # Create root element
            root = ET.Element("GeomagneticTimeSeriesData")

            # Create Description element
            description = ET.SubElement(root, "Description")
            description_fields = (
                ("InstituteName", getattr(stats, "agency_name", "")),
                ("ObservatoryName", getattr(stats, "station_name", station)),
                ("ObservatoryCode", station),
                ("SensorLatitude", lat),
                ("SensorLongitude", lon),
                ("SensorElevation", alt),
                ("SensorOrientation", getattr(stats, "sensor_orientation", "")),
                ("OriginalSampleRate", getattr(stats, "digital_sampling_rate", "")),
                ("DataType", getattr(stats, "data_type", str(self.type))),
                ("StartTime", starttime.isoformat()),
                ("SamplePeriod", sample_period),
            )
            for tag, value in description_fields:
                ET.SubElement(description, tag).text = self._as_text(value)

            # Create Data element
            data_elem = ET.SubElement(root, "Data")

            # Look up each trace's channel name and data once, outside the
            # per-sample loop.
            columns = [(trace.stats.channel, trace.data) for trace in timeseries]

            # One <Sample> per time step, with one child element per channel.
            for i in range(npts):
                sample = ET.SubElement(data_elem, "Sample")
                for ch_name, values in columns:
                    value = values[i] if i < len(values) else ""
                    ET.SubElement(sample, ch_name).text = f"{value}"

            # Write to file handle with proper encoding
            tree = ET.ElementTree(root)
            tree.write(fh, encoding="utf-8", xml_declaration=True)