diff --git a/geomagio/api/db/MetadataDatabaseFactory.py b/geomagio/api/db/MetadataDatabaseFactory.py index b9fa85cc208dccf732e6b5442a30d2a69c7911a0..c7676bc7352d0e3db0cb9a96f0b95cbf38ed9ae1 100644 --- a/geomagio/api/db/MetadataDatabaseFactory.py +++ b/geomagio/api/db/MetadataDatabaseFactory.py @@ -3,9 +3,9 @@ from typing import List, Optional from databases import Database from obspy import UTCDateTime -from sqlalchemy import or_, Table +from sqlalchemy import or_ -from ...metadata import Metadata, MetadataCategory +from ...metadata import Metadata, MetadataQuery from .metadata_history_table import metadata_history from .metadata_table import metadata @@ -24,102 +24,73 @@ class MetadataDatabaseFactory(object): async def get_metadata( self, - *, # make all params keyword - id: int = None, - network: str = None, - station: str = None, - channel: str = None, - location: str = None, - category: MetadataCategory = None, - starttime: datetime = None, - endtime: datetime = None, - created_after: datetime = None, - created_before: datetime = None, - data_valid: bool = None, - metadata: Table = metadata, - metadata_valid: bool = None, - status: List[str] = None, + params: MetadataQuery, + history: bool = False, ) -> List[Metadata]: - query = metadata.select() + table = metadata + if history: + table = metadata_history + query = table.select() + ( + id, + category, + network, + station, + channel, + location, + starttime, + endtime, + created_after, + created_before, + data_valid, + metadata_valid, + status, + ) = params.dict().values() if id: - query = query.where(metadata.c.id == id) + query = query.where(table.c.id == id) if category: - query = query.where(metadata.c.category == category) + query = query.where(table.c.category == category) if network: - query = query.where(metadata.c.network == network) + query = query.where(table.c.network == network) if station: - query = query.where(metadata.c.station == station) + query = query.where(table.c.station == station) if channel: - query = query.where(metadata.c.channel.like(channel)) + query = query.where(table.c.channel.like(channel)) if location: - query = query.where(metadata.c.location.like(location)) + query = query.where(table.c.location.like(location)) if starttime: query = query.where( or_( - metadata.c.endtime == None, - metadata.c.endtime > starttime, + table.c.endtime == None, + table.c.endtime > starttime, ) ) if endtime: query = query.where( or_( - metadata.c.starttime == None, - metadata.c.starttime < endtime, + table.c.starttime == None, + table.c.starttime < endtime, ) ) if created_after: - query = query.where(metadata.c.created_time > created_after) + query = query.where(table.c.created_time > created_after) if created_before: - query = query.where(metadata.c.created_time < created_before) + query = query.where(table.c.created_time < created_before) if data_valid is not None: - query = query.where(metadata.c.data_valid == data_valid) + query = query.where(table.c.data_valid == data_valid) if metadata_valid is not None: - query = query.where(metadata.c.metadata_valid == metadata_valid) + query = query.where(table.c.metadata_valid == metadata_valid) if status is not None: - query = query.where(metadata.c.status.in_(status)) + query = query.where(table.c.status.in_(status)) rows = await self.database.fetch_all(query) return [Metadata(**row) for row in rows] async def get_metadata_by_id(self, id: int): - meta = await self.get_metadata(id=id) + meta = await self.get_metadata(MetadataQuery(id=id)) if len(meta) != 1: raise ValueError(f"{len(meta)} records found") return meta[0] - async def get_metadata_history( - self, - *, # make all params keyword - id: int = None, - network: str = None, - station: str = None, - channel: str = None, - location: str = None, - category: MetadataCategory = None, - starttime: datetime = None, - endtime: datetime = None, - created_after: datetime = None, - created_before: datetime = None, - data_valid: bool = None, - metadata_valid: bool = None, - status: List[str] = None, - ) -> List[Metadata]: - return await self.get_metadata( - id=id, - network=network, - station=station, - channel=channel, - location=location, - category=category, - starttime=starttime, - endtime=endtime, - created_after=created_after, - created_before=created_before, - data_valid=data_valid, - metadata=metadata_history, - metadata_valid=metadata_valid, - status=status, - ) - async def get_metadata_history_by_id(self, id: int) -> Optional[Metadata]: query = metadata_history.select() query = query.where(metadata_history.c.id == id) diff --git a/geomagio/api/secure/metadata.py b/geomagio/api/secure/metadata.py index 4821c83bebf124dbed108a8a52f44f2b974660eb..755948d12afa590d377ee034ebd0b77bf3bddc11 100644 --- a/geomagio/api/secure/metadata.py +++ b/geomagio/api/secure/metadata.py @@ -73,16 +73,14 @@ async def create_metadata( @router.get("/metadata", response_model=List[Metadata]) async def get_metadata(query: MetadataQuery = Depends(get_metadata_query)): - metas = await MetadataDatabaseFactory(database=database).get_metadata( - **query.datetime_dict(exclude={"id", "metadata_id"}) - ) + metas = await MetadataDatabaseFactory(database=database).get_metadata(params=query) return metas @router.get("/metadata/history", response_model=List[Metadata]) async def get_metadata_history(query: MetadataQuery = Depends(get_metadata_query)): - metas = await MetadataDatabaseFactory(database=database).get_metadata_history( - **query.datetime_dict(exclude={"id", "metadata_id"}) + metas = await MetadataDatabaseFactory(database=database).get_metadata( + params=query, history=True ) return metas