Source code for gdgps_apps.client

import requests
from datetime import datetime
from base64 import b64decode
from httpsig.requests_auth import HTTPSignatureAuth
import re
import os
from requests_toolbelt import (
    MultipartEncoder,
    MultipartEncoderMonitor
)
import logging

import json
from .encode import PPPXJSONEncoder
from .exceptions import InvalidAuthenticationCredentials
from .defines import *
from urllib.parse import urlsplit

# Lookup from an API endpoint name to the defines class used for that endpoint
# (update_data falls back to Data when the endpoint is not listed here).
endpoint_to_defines = dict(
    data=Data,
    gipsydata=GIPSYData,
    gipsyxdata=GIPSYXData,
)

class PersistCredentialsAcrossRedirectsSession(requests.Session):
    """
    Requests library strips authentication information on redirected requests. This custom session allows us to persist
    auth credentials across redirects that fall within the TRUSTED_DOMAINS we define.

    I wouldn't be surprised if future updates to requests library breaks this. See:
        https://github.com/psf/requests/issues/2949
    """

    # Domains (last two host labels) for which credentials survive a redirect.
    TRUSTED_DOMAINS = ['gdgps.net']

    def rebuild_auth(self, prepared_request, response):
        """
        Keep the existing authentication when a redirect targets a trusted domain.

        :param prepared_request: the request about to be sent to the redirect target
        :param response: the redirect response that triggered the new request
        :return: None when credentials are kept as-is, otherwise whatever the parent implementation returns
        """
        if response.is_redirect and prepared_request.url != response.url:
            # Use the parsed hostname rather than slicing netloc by hand: netloc may carry userinfo
            # ("gdgps.net:x@evil.example"), which would otherwise be mistaken for the host and leak
            # credentials to an untrusted server. hostname is already lower-cased and excludes the port.
            host = urlsplit(prepared_request.url).hostname or ''
            domain = '.'.join(host.split('.')[-2:])
            if domain in self.TRUSTED_DOMAINS:
                return
        return super().rebuild_auth(prepared_request, response)


class BaseClient(object):
    __logger__ = logging.getLogger(__name__ + '.BaseClient')

    def __init__(
            self,
            base_url=None,
            label=None,
            api=None,
            api_key=None,
            api_secret=None,
            raise_for_status=True,
            trust_env=False
    ):
        self.base_url_ = base_url.rstrip('/')
        self.label_ = label
        self.api_ = api.rstrip('/')
        self.api_url_ = self.base_url_ + '/' + self.api_.lstrip('/')
        self.algorithm_ = 'hmac-sha256'
        self.headers_ = ['date', ]
        self.trust_env_ = trust_env
        self.api_key_ = None
        self.api_secret_ = None

        self.authorization_ = None
        if api_key is not None and api_secret is not None:
            self.set_credentials(api_key, api_secret)

        if self.authorization_ is None:
            raise InvalidAuthenticationCredentials(
                'Invalid or missing authentication credentials for client %s: ( %s, %s )' % (
                    self.base_url_,
                    api_key,
                    api_secret
                )
            )

        self.raise_for_status_ = raise_for_status
        self.session_ = None
        # self.establish_session( )

    @property
    def url(self):
        """Base URL given to the constructor, with trailing slashes stripped."""
        return self.base_url_

    @property
    def label(self):
        """Label given to the constructor (None when no label was supplied)."""
        return self.label_

    def establish_session(self):
        meta = {}
        if self.session_ is None:
            self.session_ = PersistCredentialsAcrossRedirectsSession()
            self.session_.trust_env = self.trust_env_  # https://github.com/requests/requests/issues/2773
            pinged, meta = self.ping()
            if pinged:
                self.__logger__.info('Session established with %s', self.base_url_)
            else:
                self.session_ = None
                self.__logger__.warning('Unable to establish session with %s: %s', self.base_url_, meta)
                return False, meta
        return True, meta

    def ping(self):
        if self.session_ is None:
            return self.establish_session()
        try:
            resp = self.request(endpoint='load')
            resp.raise_for_status()
            return True, resp.json()
        except Exception as e:
            return False, e

    def set_credentials(self, api_key, api_secret):
        self.api_key_ = api_key
        self.api_secret_ = b64decode(api_secret)
        self.authorization_ = HTTPSignatureAuth(
            key_id=self.api_key_,
            secret=self.api_secret_,
            algorithm=self.algorithm_,
            headers=self.headers_
        )

    def request(
            self,
            verb='GET',
            endpoint=None,
            data=None,
            send_json=False,
            uri=None,
            files=None,
            content_type=None,
            stream=False,
            headers=None,
            params=None,
            allow_redirects=True
    ):
        """
        Send one signed HTTP request through the client's session.

        :param verb: HTTP method name
        :param endpoint: path appended to the API url when ``uri`` is not given
        :param data: request body; serialized to JSON when ``send_json`` is True
        :param send_json: serialize ``data`` as JSON and set the JSON Content-Type
        :param uri: full URL to use instead of API url + endpoint
        :param files: passed through to ``requests.Request``
        :param content_type: Content-Type to use when none has been set otherwise
        :param stream: passed to the session's ``send``; the response body is then left unread
        :param headers: extra headers; the mapping passed in is not modified
        :param params: query parameters
        :param allow_redirects: passed to the session's ``send``
        :return: the response object
        :raises Exception: whatever ``establish_session`` reported when no session could be established,
            and an HTTP error for bad status codes when the client was built with raise_for_status
        """
        established, reason = self.establish_session()
        if not established:
            raise reason

        if uri is None:
            # Parenthesised on purpose: without it the conditional swallows the whole expression
            # and a missing endpoint produces an empty URL instead of the API root.
            uri = self.api_url_ + (('/' + endpoint) if endpoint is not None else '')

        # Work on a copy so the caller's dict is never mutated.
        headers = dict(headers) if headers else {}
        headers['X-Api-Key'] = self.api_key_
        headers['date'] = datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S UTC')

        if send_json:
            headers['Content-Type'] = 'application/json'
        # Header names are case-insensitive; only fall back to content_type if none is set yet.
        has_content_type = any(str(name).lower() == 'content-type' for name in headers)
        if content_type is not None and not has_content_type:
            headers['Content-Type'] = content_type

        req = requests.Request(
            verb,
            uri,
            auth=self.authorization_,
            headers=headers,
            data=json.dumps(data, indent=1, cls=PPPXJSONEncoder) if send_json else data,
            files=files,
            params=params
        )

        prepared = req.prepare()

        if self.__logger__.isEnabledFor(logging.DEBUG):
            self.__logger__.debug(
                '\n####################\n{}\n{}\n\n{}####################\n'.format(
                    prepared.method + ' ' + prepared.url,
                    '\n'.join('{}: {}'.format(k, v) for k, v in prepared.headers.items()),
                    prepared.body,
                ))

        response = self.session_.send(prepared, allow_redirects=allow_redirects, stream=stream)

        if self.__logger__.isEnabledFor(logging.DEBUG):
            if stream:
                # Touching response.content would read the whole body and leave nothing for the
                # caller to stream, so the body is never logged for streamed responses.
                body = '<omitted>'
            else:
                body = response.text if len(response.content) < 2048 else '<omitted>'
            self.__logger__.debug(
                '\n####################\n{}\n{}\n\n{}####################\n'.format(
                    str(response.status_code) + ' ' + response.reason,
                    '\n'.join('{}: {}'.format(k, v) for k, v in response.headers.items()),
                    body,
                ))

        if self.raise_for_status_:
            # throws HTTPError if status code >=400
            response.raise_for_status()

        return response

    def update_data(self, dataid, endpoint='data', stream=False, progress_hook=None, **kwargs):
        """
        PATCH a data record, optionally uploading or deleting source files.

        :param dataid: identifier of the record, appended to the endpoint path
        :param endpoint: endpoint name; also selects the defines class used to label sources
        :param stream: when True, uploads are sent through a multipart encoder
        :param progress_hook: optional callable wrapped with ``self.CallbackWrapper`` for streamed uploads
            (NOTE(review): CallbackWrapper is not defined on BaseClient itself; it must come from a subclass)
        :param kwargs: fields to update; values are sent as strings, None as an empty string. The
            special key ``sources`` is an iterable of dicts with optional ``source_type``, ``file``
            and ``delete`` entries.
        :return: the response of the PATCH request
        """
        from contextlib import ExitStack

        defines = endpoint_to_defines.get(endpoint, Data)
        endpoint = '/'.join([endpoint.strip('/'), dataid]) + '/'
        sources = kwargs.pop('sources', None) or []
        updt = {
            # multipart encoder can't handle floats or Nones?
            str(opt): str(val) if val is not None else ''
            for opt, val in kwargs.items()
            if str(opt).lower() != 'sources'
        }

        # Every file opened below is closed once the request has been sent (or has failed).
        with ExitStack() as opened:
            uploads = {}
            for src in sources:
                src_typ = src.get('source_type', None)
                file_path = src.get('file', None)
                label = 'file%s' % (len(uploads) if len(uploads) > 0 else '',)
                if src_typ:
                    label = defines.get_verbose_source_type(src_typ).lower()
                if file_path:
                    handle = opened.enter_context(open(file_path, 'rb'))
                    if stream:
                        uploads[label] = (os.path.basename(file_path), handle, 'application/octet-stream')
                    else:
                        uploads[label] = handle
                elif src_typ and src.get('delete', False):
                    updt[label] = 'DELETE'

            if not uploads:
                return self.request(verb='PATCH', endpoint=endpoint, data=updt)

            if not stream:
                return self.request('PATCH', endpoint, files=uploads, data=updt, stream=stream)

            updt.update(uploads)
            monitor = MultipartEncoder(fields=updt)
            if progress_hook:
                monitor = MultipartEncoderMonitor(monitor, self.CallbackWrapper(progress_hook))
            return self.request('PATCH', endpoint, data=monitor, stream=stream, content_type=monitor.content_type)

    def sources(self, query=None):
        return self.request(endpoint='sources', params=query).json()

    def upload_source(self, dataid, file, source_type=None, stream=False, progress_hook=None, **kwargs):
        if progress_hook and not stream:
            stream = True

        f = file
        if isinstance(file, str):
            f = open(file, 'rb')
        filename = os.path.basename(f.name)
        files = {'file': f}
        fields = {'file': (filename, f, 'application/octet-stream')}

        data = dict(**kwargs)
        data['data'] = dataid
        if source_type:
            data['source_type'] = source_type

        if not stream:
            return self.request(
                'POST',
                'sources/',
                files=files,
                data=data,
                stream=stream
            ).json()

        for opt, val in kwargs.items():
            fields[str(opt)] = str(val) if val is not None else '' # multipart encoder can't handle floats or Nones?
        fields['data'] = dataid
        if source_type:
            fields['source_type'] = source_type
        monitor = MultipartEncoder(fields=fields)
        if progress_hook:
            monitor = MultipartEncoderMonitor(monitor, self.CallbackWrapper(progress_hook))
        return self.request(
            'POST',
            'sources/',
            data=monitor,
            stream=stream,
            content_type=monitor.content_type
        )

    def get_source(self, sid, query=None):
        return self.request(endpoint='sources/' + sid + '/').json()

    def delete_source(self, srcid):
        return self.request('DELETE', endpoint='sources/' + srcid + '/')

    def results(self, query=None):
        return self.request(endpoint='results', params=query).json()

    def delete_result(self, resid):
        return self.request('DELETE', endpoint='results/' + resid + '/')

    def download(self, uri, stream=True, write_to_disk=True, dr=None, query=None):

        resp = self.request(uri=uri, stream=stream, params=query)

        if write_to_disk and resp and resp.ok:

            if dr is None:
                dr = os.getcwd()
            elif not (os.path.exists(dr) and os.path.isdir(dr)):
                raise ValueError('%s is not a valid download directory.' % dr)

            filename = re.findall('filename=(.+)', resp.headers['content-disposition'])

            dl_path = os.path.normpath(dr + '/' + filename[0].replace('"', ''))
            chunk_size = self.chunk_size
            with open(dl_path, 'wb') as dout:
                while True:
                    chunk = resp.raw.read(chunk_size)
                    if not chunk:
                        break
                    dout.write(chunk)

            return dl_path

        return resp


class PortalClient(BaseClient):
    """Client exposing the portal's endpoints (data, alerts, processors, profile, data flags)."""

    __logger__ = logging.getLogger(__name__ + '.PortalClient')

    def __init__(
            self,
            portal_url=None,
            label=None,
            api_key=None,
            api_secret=None,
            api=DEFAULT_USER_API_PATH,
            raise_for_status=True,
            trust_env=False
    ):
        """
        Build a client for the portal at ``portal_url``; all arguments are forwarded to the base class.
        """
        super(PortalClient, self).__init__(
            base_url=portal_url,
            label=label,
            api=api,
            api_key=api_key,
            api_secret=api_secret,
            raise_for_status=raise_for_status,
            trust_env=trust_env
        )

    def list_data(self, query=None):
        """Return the decoded JSON of the ``data`` endpoint, filtered by ``query``."""
        data = self.request(endpoint='data', params=query)
        return data.json()

    def alerts(self, query=None):
        """Return the decoded JSON of the ``alerts`` endpoint, filtered by ``query``."""
        return self.request(endpoint='alerts', params=query).json()

    def delete_alert(self, alertid):
        """Issue a DELETE for the alert ``alertid`` and return the response."""
        return self.request(verb='DELETE', endpoint='alerts/' + alertid + '/')

    def processors(self, query=None):
        """Return the decoded JSON of the ``processors`` endpoint, filtered by ``query``."""
        return self.request(endpoint='processors', params=query).json()

    def approve(self, dataid):
        """Set the state of record ``dataid`` to ``Data.APPROVED`` and return the response."""
        return self.update_data(dataid, state=Data.APPROVED)

    def detail(self, uuid, dtype='gipsyx', query=None):
        """Return the decoded JSON of ``<dtype>/<uuid>``, filtered by ``query``."""
        return self.request(endpoint=dtype + '/' + uuid, params=query).json()

    def delete_data(self, dataid):
        """Issue a DELETE for the data record ``dataid`` and return the response."""
        return self.request(verb='DELETE', endpoint='data/' + dataid)

    def profile(self, query=None):
        """Return the decoded JSON of the ``profile`` endpoint, filtered by ``query``."""
        return self.request(endpoint='profile', params=query).json()

    def update_profile(self, userid, **kwargs):
        """PATCH the profile of ``userid`` with the given fields and return the response."""
        return self.request(verb='PATCH', endpoint='profile/' + userid + '/', data=dict(**kwargs))

    def flags(self, query=None):
        """Return the decoded JSON of the ``dataflags`` endpoint, filtered by ``query``."""
        return self.request(endpoint='dataflags', params=query).json()

    def delete_flag(self, flagid):
        """Issue a DELETE for the data flag ``flagid`` and return the response."""
        return self.request(verb='DELETE', endpoint='dataflags/' + flagid + '/')


class ProcessorClient(BaseClient):
    """Client for a processor service, adding GipsyX uploads on top of the base client."""

    __logger__ = logging.getLogger(__name__ + '.ProcessorClient')

    # Number of bytes read per chunk when downloading.
    chunk_size = 32 * 1024

    class CallbackWrapper(object):
        """Adapts a progress hook so it receives the bytes read since the previous call."""

        bytes_read = 0

        def __init__(self, callback):
            self.callback = callback

        def __call__(self, monitor):
            # monitor.bytes_read is cumulative; report only the increment.
            self.callback(monitor.bytes_read - self.bytes_read)
            self.bytes_read = monitor.bytes_read

    def __init__(
            self,
            processor_url,
            api_key,
            api_secret,
            api=DEFAULT_USER_API_PATH,
            label=None,
            raise_for_status=True,
            trust_env=False
    ):
        """
        Build a client for the processor at ``processor_url``; all arguments are forwarded to the base class.
        """
        super(ProcessorClient, self).__init__(
            base_url=processor_url,
            label=label,
            api=api,
            api_key=api_key,
            api_secret=api_secret,
            raise_for_status=raise_for_status,
            trust_env=trust_env
        )

    def approve(self, dataid):
        """Set the state of record ``dataid`` to ``Data.APPROVED`` and return the response."""
        return self.update_data(dataid, state=Data.APPROVED)

    def upload_gipsyx(self, file, antCal=None, pressure=None, attitude=None, stream=True, progress_hook=None,
                      **kwargs):
        """
        POST a data file, plus optional companion files, to the ``gipsyxdata/`` endpoint.

        Each file argument may be a path or an already opened binary file object with a ``name``.
        Files opened here from a path are closed before returning; caller-supplied objects are not.

        :param file: the main file, sent as the ``file`` field
        :param antCal: optional file sent as the ``antenna_calibration`` field
        :param pressure: optional file sent as the ``pressure`` field
        :param attitude: optional file sent as the ``attitude`` field
        :param stream: send the upload through a multipart encoder; forced on when a progress hook is given
        :param progress_hook: optional callable, called with the number of bytes read since its last call
        :param kwargs: additional form fields
        :return: decoded JSON when not streaming; the raw response object when streaming
        """
        from contextlib import ExitStack

        if progress_hook and not stream:
            stream = True

        with ExitStack() as opened:
            files = {}
            fields = {}

            def attach(field_name, source):
                # Paths are opened (and later closed) here; file objects are used as given.
                handle = opened.enter_context(open(source, 'rb')) if isinstance(source, str) else source
                files[field_name] = handle
                fields[field_name] = (os.path.basename(handle.name), handle, 'application/octet-stream')

            attach('file', file)
            if pressure:
                attach('pressure', pressure)
            if attitude:
                attach('attitude', attitude)
            if antCal:
                attach('antenna_calibration', antCal)

            if not stream:
                return self.request(
                    'POST',
                    'gipsyxdata/',
                    files=files,
                    data=dict(**kwargs),
                    stream=stream
                ).json()

            for opt, val in kwargs.items():
                # multipart encoder can't handle floats or Nones?
                fields[str(opt)] = str(val) if val is not None else ''
            monitor = MultipartEncoder(fields=fields)
            if progress_hook:
                monitor = MultipartEncoderMonitor(monitor, self.CallbackWrapper(progress_hook))
            return self.request(
                'POST',
                'gipsyxdata/',
                data=monitor,
                stream=stream,
                content_type=monitor.content_type
            )