Skip to content
Snippets Groups Projects
FDSNFactory_test.py 6.33 KiB
Newer Older
  • Learn to ignore specific revisions
  • import io
    from typing import List
    
    import numpy
    from numpy.testing import assert_equal, assert_array_equal
    import numpy as np
    from obspy.core import Stream, Trace, UTCDateTime
    from obspy.core.inventory import Inventory, Network, Station, Channel, Site
    import pytest
    
    from geomagio.edge import FDSNFactory
    from geomagio.metadata.instrument.InstrumentCalibrations import (
        get_instrument_calibrations,
    )
    from .mseed_FDSN_test_clients import MockFDSNSeedClient
    
    
    @pytest.fixture()
    def FDSN_factory() -> FDSNFactory:
        """Return an FDSNFactory whose ``client`` is a MockFDSNSeedClient.

        Function-scoped on purpose: tests mutate the mock (e.g.
        ``client.return_empty = True``), so sharing one instance across tests
        could leak that state. No teardown is needed, so the fixture returns
        the factory instead of yielding it.
        """
        factory = FDSNFactory()
        factory.client = MockFDSNSeedClient()
        return factory
    
    
    @pytest.fixture()
    def anmo_u_metadata():
        """Yield the "X" channel entry of ANMO's instrument calibrations.

        Takes the first record returned by ``get_instrument_calibrations`` for
        observatory "ANMO" and drills into ``["instrument"]["channels"]["X"]``.
        """
        calibrations = get_instrument_calibrations(observatory="ANMO")
        yield calibrations[0]["instrument"]["channels"]["X"]
    
    
    def test__get_timeseries_add_empty_channels(FDSN_factory: FDSNFactory):
        """test.edge_test.FDSNFactory_test.test__get_timeseries_add_empty_channels()

        With the mock client set to return no data:
        - ``add_empty_channels=True`` must produce an all-NaN trace whose
          start/end times match the request;
        - ``add_empty_channels=False`` followed by ``[0]`` must raise
          IndexError (presumably because the returned stream is empty).
        """
        FDSN_factory.client.return_empty = True
        starttime = UTCDateTime("2024-09-07T00:00:00Z")
        endtime = UTCDateTime("2024-09-07T00:10:00Z")
        # Arguments shared by both requests; only add_empty_channels differs.
        request = dict(
            starttime=starttime,
            endtime=endtime,
            observatory="ANMO",
            channel="X",
            type="variation",
            interval="second",
        )

        trace = FDSN_factory._get_timeseries(**request, add_empty_channels=True)[0]

        # assert_array_equal treats NaN == NaN, so an all-NaN array compares equal.
        assert_array_equal(trace.data, numpy.full(trace.stats.npts, numpy.nan))
        assert trace.stats.starttime == starttime
        assert trace.stats.endtime == endtime

        with pytest.raises(IndexError):
            # The indexing is the point of this statement; the value is unused.
            FDSN_factory._get_timeseries(**request, add_empty_channels=False)[0]
    
    
    def test__set_metadata():
        """edge_test.FDSNFactory_test.test__set_metadata()"""
        # _set_metadata is applied to a stream holding two traces; both must
        # come back stamped with the requested channel.
        stream = Stream(traces=[Trace() for _ in range(2)])
        FDSNFactory()._set_metadata(stream, "ANMO", "X", "variation", "second")
        for index in (0, 1):
            assert_equal(stream[index].stats["channel"], "X")
    
    
    def test_get_timeseries(FDSN_factory):
        """edge_test.FDSNFactory_test.test_get_timeseries()"""
        # Call get_timeseries (served by the mock client installed by the
        # FDSN_factory fixture) and check the returned trace's stats to confirm
        # that the data came back with the expected metadata.
        timeseries = FDSN_factory.get_timeseries(
            starttime=UTCDateTime(2024, 3, 1, 0, 0, 0),
            endtime=UTCDateTime(2024, 3, 1, 1, 0, 0),
            observatory="ANMO",
            # The trailing comma matters: ("X") is just the string "X".
            channels=("X",),
            type="variation",
            interval="second",
        )
        stats = timeseries.select(channel="X")[0].stats

        assert_equal(stats.station, "ANMO", "Expect timeseries to have stats")
        assert_equal(
            stats.channel,
            "X",
            "Expect timeseries stats channel to be equal to X",
        )
        assert_equal(
            stats.data_type,
            "variation",
            "Expect timeseries stats data_type to be equal to variation",
        )
    
    
    def test_get_timeseries_by_location(FDSN_factory):
        """test.edge_test.FDSNFactory_test.test_get_timeseries_by_location()

        Requesting data type "R0" must produce traces whose ``data_type``
        stat is "R0".
        """
        timeseries = FDSN_factory.get_timeseries(
            starttime=UTCDateTime(2024, 3, 1, 0, 0, 0),
            endtime=UTCDateTime(2024, 3, 1, 1, 0, 0),
            observatory="ANMO",
            # The trailing comma matters: ("X") is just the string "X".
            channels=("X",),
            type="R0",
            interval="second",
        )
        assert_equal(
            timeseries.select(channel="X")[0].stats.data_type,
            "R0",
            "Expect timeseries stats data_type to be equal to R0",
        )
    
    
    def test_rotate_trace(monkeypatch):
        """test.edge_test.FDSNFactory_test.test_rotate_trace()

        Feed synthetic X/Y/Z traces and a mock inventory through a stubbed
        client. Then check that a request for channel "X" returns exactly one
        trace, named "X", that starts at the requested time.
        """
        factory = FDSNFactory(
            observatory="ANMO",
            channels=["X", "Y", "Z"],
            type="variation",
            interval="second",
        )

        # Simulated input: 5 samples per channel, 60 s apart.
        starttime = UTCDateTime("2024-01-01T00:00:00")
        endtime = UTCDateTime(2024, 1, 1, 0, 10)
        timing = {"starttime": starttime, "delta": 60}
        input_stream = Stream(
            traces=[
                Trace(data=np.array([1, 2, 3, 4, 5]), header={"channel": "X", **timing}),
                Trace(data=np.array([6, 7, 8, 9, 10]), header={"channel": "Y", **timing}),
                Trace(
                    data=np.array([11, 12, 13, 14, 15]), header={"channel": "Z", **timing}
                ),
            ]
        )
        mock_inventory = create_mock_inventory()

        # Stub the client through monkeypatch so the originals are restored
        # after the test. A plain assignment would leak the stubs into later
        # tests if factory.Client is shared or is a class (not visible here).
        # NOTE(review): the FDSN_factory fixture replaces ``factory.client``
        # (lowercase) while this test patches ``factory.Client``; confirm which
        # attribute FDSNFactory actually uses.
        monkeypatch.setattr(
            factory.Client,
            "get_waveforms",
            lambda *args, **kwargs: input_stream,
            raising=False,
        )
        monkeypatch.setattr(
            factory.Client,
            "get_stations",
            lambda *args, **kwargs: mock_inventory,
            raising=False,
        )

        # Requesting any one of X/Y/Z is expected to trigger rotation.
        rotated_stream = factory.get_timeseries(
            starttime=starttime,
            endtime=endtime,
            observatory="ANMO",
            channels=["X"],
        )

        assert (
            len(rotated_stream) == 1
        ), "Expected only the requested channel (X) after rotation"
        assert (
            rotated_stream[0].stats.channel == "X"
        ), "Unexpected channel name after rotation"
        assert (
            rotated_stream[0].stats.starttime == starttime
        ), "Start time mismatch in rotated data"
    
    
    def create_mock_inventory():
        """Build a minimal single-channel ObsPy Inventory for tests.

        The result has source "MockInventory" and holds one network "XX". That
        network contains one station "ANMO" at site "TestSite", which carries a
        single channel "X". The channel has an empty location code, azimuth and
        dip of 0.0, sample_rate 1.0, and all coordinates and depth at 0.0.
        """
        zero_position = dict(latitude=0.0, longitude=0.0, elevation=0.0)
        x_channel = Channel(
            code="X",
            location_code="",
            depth=0.0,
            azimuth=0.0,
            dip=0.0,
            sample_rate=1.0,
            **zero_position,
        )
        anmo_station = Station(
            code="ANMO",
            site=Site(name="TestSite"),
            channels=[x_channel],
            **zero_position,
        )
        return Inventory(
            networks=[Network(code="XX", stations=[anmo_station])],
            source="MockInventory",
        )