From f3583b6e050a85f4b40886ed9c6ad8d38fe3780a Mon Sep 17 00:00:00 2001
From: Travis Rivers <travrivers88@gmail.com>
Date: Thu, 27 Feb 2020 17:39:13 -0700
Subject: [PATCH] updates based upon Github comments

---
 geomagio/TimeseriesUtility.py |  12 +-
 geomagio/webservice/data.py   | 216 +++++++++++++++++++---------------
 2 files changed, 127 insertions(+), 101 deletions(-)

diff --git a/geomagio/TimeseriesUtility.py b/geomagio/TimeseriesUtility.py
index e80ac952..f9dc3cf3 100644
--- a/geomagio/TimeseriesUtility.py
+++ b/geomagio/TimeseriesUtility.py
@@ -95,18 +95,18 @@ def get_interval_from_delta(delta):
     interval : str
         interval length {day, hour, minute, second, tenhertz}
     """
-    if delta == 0.1:
+    if delta == "0.1":
         data_interval = "tenhertz"
     elif delta == 1:
-       data_interval = "second"
-    elif delta == 60:
+        data_interval = "second"
+    elif delta == "60":
         data_interval = "minute"
-    elif delta == 3600:
+    elif delta == "3600":
         data_interval = "hour"
-    elif delta == 86400:
+    elif delta == "86400":
         data_interval = "day"
     else:
-        data_interval = None
+        data_interval = delta
     return data_interval
 
 
diff --git a/geomagio/webservice/data.py b/geomagio/webservice/data.py
index 62545e4d..58be0346 100644
--- a/geomagio/webservice/data.py
+++ b/geomagio/webservice/data.py
@@ -5,12 +5,16 @@ from json import dumps
 from obspy import UTCDateTime
 import os
 
-from geomagio.edge import EdgeFactory
-from geomagio.iaga2002 import IAGA2002Writer
-from geomagio.imfjson import IMFJSONWriter
-from geomagio.TimeseriesUtility import get_interval_from_delta
+from ..edge import EdgeFactory
+from ..iaga2002 import IAGA2002Writer
+from ..imfjson import IMFJSONWriter
+from ..TimeseriesUtility import get_interval_from_delta
 
 
+DEFAULT_DATA_TYPE = 'variation'
+DEFAULT_ELEMENTS = ('X', 'Y', 'Z', 'F')
+DEFAULT_OUTPUT_FORMAT = 'iaga2002'
+DEFAULT_SAMPLING_PERIOD = '60'
 ERROR_CODE_MESSAGES = {
     204: "No Data",
     400: "Bad Request",
@@ -21,18 +25,23 @@ ERROR_CODE_MESSAGES = {
     503: "Service Unavailable"
 }
 VALID_DATA_TYPES = ["variation", "adjusted", "quasi-definitive", "definitive"]
+VALID_INTERVALS = ["tenhertz", "second", "minute", "hour", "day"]
 VALID_OBSERVATORIES = ["BRT", "BRW", "DED", "DHT", "CMO", "CMT", "SIT", "SHU", "NEW",
 "BDT", "BOU", "TST", "USGS", "FDT", "FRD", "FRN", "TUC", "BSL", "HON", "SJG",
 "GUA", "SJT"]
 VALID_OUTPUT_FORMATS = ["iaga2002", "json"]
-VALID_SAMPLING_PERIODS = [0.1, 1, 60, 3600, 86400]
+VALID_SAMPLING_PERIODS = ["0.1", "1", "60", "3600", "86400"]
 
 
 blueprint = Blueprint("data", __name__)
+input_factory = None
+VERSION = 'version'
 
 
 def init_app(app: Flask):
     global blueprint
+    global input_factory
+    input_factory = get_input_factory()
 
     app.register_blueprint(blueprint)
 
@@ -107,12 +116,17 @@ class WebServiceQuery(object):
         self.output_format = output_format
 
 
-def format_error(status_code, exception, query, url):
+def format_error(status_code, exception, parsed_query, request):
     """Assign error_body value based on error format."""
-    error_body = http_error(status_code, exception, query, url)
-    status = str(status_code) + ' ' + ERROR_CODE_MESSAGES[status_code]
 
-    return Response(error_body, mimetype="text/plain")
+    if parsed_query.output_format == 'json':
+        return Response(
+            json_error(status_code, exception, request.url),
+            mimetype="application/json")
+    else:
+        return Response(
+            iaga2002_error(status_code, exception, request.query_string),
+            mimetype="text/plain")
 
 
 def format_timeseries(timeseries, query):
@@ -149,14 +163,18 @@ def get_input_factory():
     input_factory
         Edge or miniseed factory object
     """
-    DATA_TYPE = os.getenv('Type', 'edge')
+    data_type = os.getenv('DATA_TYPE', 'edge')
+    host = os.getenv('DATA_HOST', 'cwbpub.cr.usgs.gov')
+    port = os.getenv('DATA_PORT', 2060)
 
-    if DATA_TYPE == 'edge':
+    if data_type == 'edge':
         input_factory = EdgeFactory(
-        host=os.getenv('DATA_HOST', 'cwbpub.cr.usgs.gov'),
-        port=os.getenv('DATA_PORT', 2060)
-        )
+        host=host,
+        port=port,
+        type=data_type)
         return input_factory
+    else:
+        return None
 
 
 def get_timeseries(query):
@@ -172,9 +190,6 @@ def get_timeseries(query):
     obspy.core.Stream
         timeseries object with requested data
     """
-    query.sampling_period = get_interval_from_delta(query.sampling_period)
-    input_factory = get_input_factory()
-
     timeseries = input_factory.get_timeseries(
         query.starttime,
         query.endtime,
@@ -182,26 +197,9 @@ def get_timeseries(query):
         query.elements,
         query.data_type,
         query.sampling_period)
-
     return timeseries
 
 
-def http_error(code, message, query, request):
-    """Format http error message.
-
-    Returns
-    -------
-    http_error_body : str
-        body of http error message.
-    """
-    if query.output_format == 'json':
-        http_error_body = json_error(code, message, request.url)
-        return http_error_body
-    else:
-        http_error_body = iaga2002_error(code, message, request.query_string)
-        return http_error_body
-
-
 def iaga2002_error(code, message, request_args):
     """Format iaga2002 error message.
 
@@ -210,16 +208,22 @@ def iaga2002_error(code, message, request_args):
     error_body : str
         body of iaga2002 error message.
     """
-
     status_message = ERROR_CODE_MESSAGES[code]
-    error_body = 'Error ' + str(code) + ': ' \
-    + status_message + '\n\n' + message + '\n\n'\
-    + 'Usage details are available from '\
-    + 'http://geomag.usgs.gov/ws/edge/ \n\n'\
-    + 'Request:\n' + 'ws/edge/?' + str(request_args)[2:-1] + '\n\n' + 'Request Submitted:\n'\
-    + datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + '\n'
-
-    error_body = Response(error_body, mimetype="text/plain")
+    error_body = f"""Error {code}: {status_message}
+
+{message}
+
+Usage details are available from {request.base_url}
+
+Request:
+{request.url}
+
+Request Submitted:
+{UTCDateTime().isoformat()}Z
+
+Service Version:
+{VERSION}
+"""
     return error_body
 
 
@@ -231,20 +235,22 @@ def json_error(code, message, url):
     error_body : str
         body of json error message.
     """
-    error_dict = OrderedDict()
-    error_dict['type'] = "Error"
-    error_dict['metadata'] = OrderedDict()
-    error_dict['metadata']['status'] = code
     date = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
-    error_dict['metadata']['generated'] = date
-    error_dict['metadata']['url'] = url
     status_message = ERROR_CODE_MESSAGES[code]
-    error_dict['metadata']['title'] = status_message
-    error_dict['metadata']['error'] = message
-    error_body = dumps(error_dict,
-    ensure_ascii=True).encode('utf8')
 
-    return error_body
+    error_dict = {
+        "type": "Error",
+        "metadata": {
+            "status": code,
+            "generated": date,
+            "url": url,
+            "title": status_message,
+            "error": message
+        }
+
+    }
+    return dumps(error_dict,
+    sort_keys=True).encode('utf8')
 
 
 def parse_query(query):
@@ -265,43 +271,67 @@ def parse_query(query):
     WebServiceException
         if any parameters are not supported.
     """
-    # Create web service query object
-    params = WebServiceQuery()
-
-    # Get and assign values
-    if not query.get("starttime"):
-        start_time = UTCDateTime(query.get("endtime")) - (24 * 60 * 60 - 1)
+    # Get values
+    observatory_id = query.get("observatory")
+    start_time = UTCDateTime(query.get("starttime"))
+    end_time = UTCDateTime(query.get("endtime"))
+    elements = query.get("channels", DEFAULT_ELEMENTS)
+    sampling_period = query.get("sampling_period", DEFAULT_SAMPLING_PERIOD)
+    data_type = query.get("type", DEFAULT_DATA_TYPE)
+    output_format = query.get("format", DEFAULT_OUTPUT_FORMAT)
+    # Assign values or defaults
+    if not output_format:
+        output_format = DEFAULT_OUTPUT_FORMAT
     else:
-        start_time = UTCDateTime(query.get("starttime"))
-    params.starttime = start_time
+        output_format = output_format.lower()
 
-    if not query.get("endtime"):
+    observatory_id = observatory_id.upper()
+    if observatory_id not in VALID_OBSERVATORIES:
+        raise WebServiceException(
+             f"""Bad observatory id "{query.observatory_id}".  Valid values are:  {', '.join(VALID_OBSERVATORIES)}."""
+            )
+    if not start_time:
         now = datetime.now()
-        today = UTCDateTime(now.year, now.month, now.day, 0)
-        end_time = today
+        today = UTCDateTime(
+                year=now.year,
+                month=now.month,
+                day=now.day,
+                hour=0)
+        start_time = today
+    else:
+        try:
+            start_time = UTCDateTime(start_time)
+        except Exception:
+            raise WebServiceException(
+                    'Bad start_time value "%s".'
+                    ' Valid values are ISO-8601 timestamps.' % start_time)
+    if not end_time:
+        end_time = start_time + (24 * 60 * 60 - 1)
     else:
-        end_time = UTCDateTime(query.get("endtime"))
+        try:
+            end_time = UTCDateTime(end_time)
+        except Exception:
+            raise WebServiceException(
+                    'Bad end_time value "%s".'
+                    ' Valid values are ISO-8601 timestamps.' % end_time)
+
+    if not sampling_period:
+        sampling_period = DEFAULT_SAMPLING_PERIOD
+    else:
+        sampling_period = sampling_period
+    if not data_type:
+        data_type = DEFAULT_DATA_TYPE
+    else:
+        data_type = data_type.lower()
+    # Create WebServiceQuery object and set properties
+    params = WebServiceQuery()
+    params.observatory_id = observatory_id
+    params.starttime = start_time
     params.endtime = end_time
-
-    if query.get("sampling_period"):
-        sampling_period = int(query.get("sampling_period"))
-        params.sampling_period =  sampling_period
-
-    if query.get("format"):
-        format = query.get("format")
-        params.output_format = format
-
-    observatory = query.get("observatory")
-    params.observatory_id = observatory
-
-    if query.get("channels"):
-        channels = query.get("channels").split(",")
-        params.elements = channels
-
-    if query.get("type"):
-        type = query.get("type")
-        params.data_type = type
-
+    params.elements = elements
+    params.sampling_period = str(get_interval_from_delta(sampling_period))
+    params.data_type = data_type
+    params.output_format = output_format
     return params
 
 
@@ -324,23 +354,19 @@ def validate_query(query):
         )
     if query.observatory_id not in VALID_OBSERVATORIES:
         raise WebServiceException(
-            'Bad observatory ID "%s".'
-            " Valid values are: %s" % (query.observatory_id, ', '.join(VALID_OBSERVATORIES) + '.')
+             f"""Bad observatory id "{query.observatory_id}".  Valid values are:  {', '.join(VALID_OBSERVATORIES)}."""
             )
     if query.starttime > query.endtime:
         raise WebServiceException("Starttime must be before endtime.")
     if query.data_type not in VALID_DATA_TYPES:
         raise WebServiceException(
-            'Bad type value "%s".'
-            " Valid values are: %s" % (query.data_type, ', '.join(VALID_DATA_TYPES) + '.')
+             f"""Bad data type value "{query.data_type}". Valid values are:  {', '.join(VALID_DATA_TYPES)}."""
             )
-    if query.sampling_period not in VALID_SAMPLING_PERIODS:
+    if query.sampling_period not in VALID_INTERVALS:
         raise WebServiceException(
-            'Bad sampling_period value "%s".'
-            " Valid values are: %s" % (query.sampling_period, ', '.join(VALID_SAMPLING_PERIODS) + '.')
+            f"""Bad sampling_period value {query.sampling_period}. Valid values are:  {', '.join(VALID_SAMPLING_PERIODS)}."""
             )
     if query.output_format not in VALID_OUTPUT_FORMATS:
         raise WebServiceException(
-            'Bad format value "%s".'
-            " Valid values are: %s" % (query.output_format, ', '.join(VALID_OUTPUT_FORMATS) + '.')
+             f"""Bad format value "{query.output_format}".  Valid values are:  {', '.join(VALID_OUTPUT_FORMATS)}."""
             )
\ No newline at end of file
-- 
GitLab