stac_helpers.py

import numpy as np
import cfunits
    
    
def license_picker(license_text):
    print(f'license in dataset attrs: "{license_text}"')
    print('\nFor USGS data, we can use "\033[1mCC0-1.0\033[0m" as the license. For all other data we can use "\033[1mUnlicense\033[0m".')
    print('Ref: https://spdx.org/licenses/')
    license_mapper = {
        'Public domain': 'CC0-1.0',
        'Creative Commons CC0 1.0 Universal Dedication(http://creativecommons.org/publicdomain/zero/1.0/legalcode)': 'CC0-1.0',
        'Freely available': 'Unlicense',
        'Freely Available: Oregon State University retains rights to ownership of the data and information.': 'Unlicense',
        'No restrictions': 'Unlicense',
        'Creative Commons Attribution-ShareAlike 4.0 International License (http://creativecommons.org/licenses/by-sa/4.0/)': 'CC-BY-SA-4.0',
        'This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License (https://creativecommons.org/licenses/by-sa/4.0/).': 'CC-BY-SA-4.0',
    }
    try:
        license = license_mapper[license_text]
        print(f'\nlicense automatically chosen: \033[1m{license}\033[0m')
    except KeyError:
        # fall back to manual entry when the attribute text is not in the mapping
        license = str(input("What license would you like to use for this dataset?"))
        print(f'\nlicense input by user: \033[1m{license}\033[0m')
    return license
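
# Hypothetical usage sketch (not in the original file; assumes `ds` is an
# xarray.Dataset opened elsewhere, and that its license text lives in a
# 'license' attribute - attribute names vary by dataset):
#     license = license_picker(ds.attrs.get('license', ''))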
    
    
def print_attr(ds, attr_name):
    # attributes commonly worth printing include:
    # 'time_coverage_resolution',
    # 'time_coverage_start', 'time_coverage_end',
    # 'resolution', 'geospatial_lon_resolution', 'geospatial_lat_resolution',
    # 'geospatial_lon_min', 'geospatial_lon_max', 'geospatial_lat_min', 'geospatial_lat_max'
    try:
        attr = ds.attrs[attr_name]
        print(f'dataset attribute \033[1m{attr_name}\033[0m: {attr}')
    except KeyError:
        # silently skip attributes that are not present on the dataset
        pass
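
# Hypothetical usage sketch (attribute names follow common CF/ACDD conventions
# and are not guaranteed to exist on any given dataset):
#     for name in ('time_coverage_start', 'time_coverage_end', 'resolution'):
#         print_attr(ds, name)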
    
    
def extract_dim(ds, d):
    # use cf-xarray's axis mapping to look up the dimension name for axis d
    try:
        dim_list = ds.cf.axes[d]
        assert len(dim_list) == 1, f'There are too many {d} dimensions in this dataset.'
        dim = dim_list[0]
    except KeyError:
        print(f"Could not auto-extract {d} dimension name.")
        print("Look at the xarray output above showing the dataset dimensions.")
        dim = str(input(f"What is the name of the {d} dimension of this dataset?"))
    assert dim in ds.dims, "That is not a valid dimension name for this dataset"
    print(f"name of {d} dimension: {dim}\n")
    return dim
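
# Hypothetical usage sketch (axis keys 'X', 'Y', 'T' follow cf-xarray conventions):
#     x_dim = extract_dim(ds, 'X')
#     y_dim = extract_dim(ds, 'Y')
#     time_dim = extract_dim(ds, 'T')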
    
def get_step(ds, dim_name, time_dim=False, debug=False, step_ix=0, round_dec=None):
    dim_vals = ds[dim_name].values
    diffs = [d2 - d1 for d1, d2 in zip(dim_vals, dim_vals[1:])]
    # option to round to a given number of decimals:
    # sometimes different steps are calculated due to small rounding errors
    # coming out of the diff calculation, and rounding can correct for that
    if round_dec is not None:
        unique_steps = np.unique(np.array(diffs).round(decimals=round_dec), return_counts=True)
    else:
        unique_steps = np.unique(diffs, return_counts=True)
    step_list = unique_steps[0]
    # optional - for investigating uneven steps
    if debug:
        print(f'step_list: {step_list}')
        print(f'step_count: {unique_steps[1]}')
        indices = [i for i, x in enumerate(diffs) if x == step_list[step_ix]]
        print(f'index locations of step index {step_ix} in step_list: {indices}')
    # set step only if all steps are the same length;
    # the datacube spec specifies using null for irregularly spaced steps
    if len(step_list) == 1:
        if time_dim:
            # make sure time deltas are in np timedelta64[ns] format
            step_list = [np.array([step], dtype="timedelta64[ns]")[0] for step in step_list]
        step = step_list[0].astype(float).item()
    else:
        step = None
    return step
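
# Hypothetical usage sketch (dimension names are illustrative):
#     time_step = get_step(ds, 'time', time_dim=True)  # e.g. 86400000000000.0 ns for daily data
#     x_step = get_step(ds, 'x', round_dec=6)          # None when spacing is irregular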
    
def get_long_name(ds, v):
    # try to get long_name attribute from variable
    try:
        long_name = ds[v].attrs['long_name']
    # otherwise, leave empty
    except KeyError:
        long_name = None
    return long_name
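
# Hypothetical usage sketch (variable name is illustrative):
#     long_name = get_long_name(ds, 'prcp')  # None if the attribute is missing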
    
def get_unit(ds, v):
    # check if unit is defined for variable
    try:
        unit = ds[v].attrs['units']
    except KeyError:
        unit = None
    # check that the unit comes from https://docs.unidata.ucar.edu/udunits/current/#Database
    # the datacube extension specifies: "The unit of measurement for the data,
    # preferably compliant to UDUNITS-2 units (singular)."
    # gdptools expects this format as well
    try:
        valid = cfunits.Units(unit).isvalid
    except Exception:
        valid = False
    if not valid:
        print(f"Unit '{unit}' is not valid as a UD unit.")
        unit = str(input(f"Please enter a valid unit for {v} from here: https://docs.unidata.ucar.edu/udunits/current/#Database"))
        assert cfunits.Units(unit).isvalid
    return unit
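
# Hypothetical usage sketch (variable name is illustrative):
#     unit = get_unit(ds, 'prcp')  # e.g. 'mm/day'; prompts if missing or not UDUNITS-valid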
    
    
def get_var_type(ds, v, crs_var):
    if v in ds.coords or v == crs_var:
        # type = auxiliary for a variable that contains coordinate data, but isn't a dimension in cube:dimensions.
        # For example, the values of the datacube might be provided in the projected coordinate reference system,
        # but the datacube could have a variable lon with dimensions (y, x), giving the longitude at each point.
        var_type = 'auxiliary'
    else:
        # type = data for a variable indicating some measured value, for example "precipitation", "temperature", etc.
        var_type = 'data'
    return var_type
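
# Hypothetical usage sketch (variable and CRS variable names are illustrative):
#     get_var_type(ds, 'lon', 'crs')   # -> 'auxiliary' (coordinate data)
#     get_var_type(ds, 'prcp', 'crs')  # -> 'data' (measured value)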
    
def find_paths(nested_dict, prepath=()):
    # recursively yield the key path to every np.float64 leaf in a nested dict
    for k, v in nested_dict.items():
        path = prepath + (k,)
        if isinstance(v, np.float64):  # found a leaf value
            yield path
        elif hasattr(v, 'items'):  # v is a dict - recurse into it
            yield from find_paths(v, path)
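
# Hypothetical usage sketch:
#     stats = {'prcp': {'min': np.float64(0.0), 'max': np.float64(512.3)}}
#     list(find_paths(stats))  # -> [('prcp', 'min'), ('prcp', 'max')]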