Source code for wolframclient.evaluation.cloud.asyncoauth

from __future__ import absolute_import, print_function, unicode_literals

import logging

from wolframclient.evaluation.cloud.base import OAuthAsyncSessionBase, UserIDPassword
from wolframclient.exception import AuthenticationException
from wolframclient.utils import six
from wolframclient.utils.api import aiohttp, oauth, ssl

logger = logging.getLogger(__name__)


class OAuthAIOHttpAsyncSessionBase(OAuthAsyncSessionBase):
    """Asynchronous OAuth authentication class using aiohttp library for requests.

    Holds the aiohttp session used to send requests and, when the server
    declares a certificate, an SSL context loaded with that certificate.
    """

    def __init__(
        self,
        http_session,
        server,
        consumer_key,
        consumer_secret,
        signature_method=None,
        client_class=None,
        ssl_context_class=None,
    ):
        """Initialize the session.

        :param http_session: object used to send requests; its ``request`` and
            ``post`` methods are called with an ``ssl`` keyword argument.
        :param server: server description; forwarded to the base class.
        :param consumer_key: OAuth consumer key, forwarded to the base class.
        :param consumer_secret: OAuth consumer secret, forwarded to the base class.
        :param signature_method: forwarded unchanged to the base class.
        :param client_class: class used to sign requests; defaults to
            ``oauth.Client`` when not provided.
        :param ssl_context_class: class instantiated to build the SSL context;
            defaults to ``ssl.SSLContext`` when not provided.
        """
        super().__init__(
            server,
            consumer_key,
            consumer_secret,
            signature_method=signature_method,
            client_class=client_class or oauth.Client,
        )
        self.http_session = http_session
        self.ssl_context_class = ssl_context_class or ssl.SSLContext
        # NOTE(review): `self.server` is presumably set by the base-class
        # initializer, which is not visible in this file.
        if self.server.certificate is not None:
            self._ssl_context = self.ssl_context_class()
            self._ssl_context.load_verify_locations(self.server.certificate)
        else:
            # No custom certificate: pass ssl=None to the HTTP session.
            self._ssl_context = None

    async def signed_request(self, uri, headers=None, data=None, method="POST"):
        """Construct a signed request and send it.

        Authenticates first if the session is not yet authorized. Only a
        non-multipart ``aiohttp.FormData`` body is serialized and passed to
        the signing client; any other payload is sent as provided and only
        the URI and headers are signed.

        :param uri: target URI of the request.
        :param headers: optional mapping of extra headers. It is copied, never
            modified. ``None`` (the default) means no extra headers.
        :param data: request payload, or ``None``.
        :param method: HTTP method name.
        :return: whatever ``http_session.request`` returns.
        """
        if not self.authorized():
            await self.authenticate()

        # Copy so that signing never alters the caller's mapping.
        req_headers = {key: value for key, value in (headers or {}).items()}

        # Payload instances are not encoded (e.g. octet stream); only
        # non-multipart FormData is url-encoded and therefore signed.
        form_encoded = isinstance(data, aiohttp.FormData) and not data.is_multipart

        body = None
        if form_encoded:
            # The signing client needs the encoded body, so render the form
            # into an in-memory buffer first.
            buffer = _AsyncBytesIO()
            await data().write(buffer)
            body = buffer.getvalue()
            req_headers["Content-Type"] = "application/x-www-form-urlencoded"

        uri, req_headers, signed_body = self._client.sign(
            uri,
            method,
            body=body,
            headers=req_headers,
            realm=self.server.cloudbase,
        )

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Signed uri: %s", uri)
            logger.debug("Signed header: %s", req_headers)
            logger.debug("Is body signed: %s", form_encoded)

        # A signed form body is sent as returned by the signing client; every
        # other payload (multipart forms included) is sent untouched.
        if form_encoded:
            payload = aiohttp.StringPayload(signed_body)
        else:
            payload = data

        return await self.http_session.request(
            method, uri, data=payload, headers=req_headers, ssl=self._ssl_context
        )

    async def _ensure_success_response(self, response):
        """Raise :class:`AuthenticationException` unless the status is 200.

        The exception message is the ``"message"`` entry of the JSON response
        body when the body can be read that way, otherwise a generic message
        containing the status code.

        :param response: response object exposing ``status`` and an awaitable
            ``json()``.
        :raises AuthenticationException: for any status other than 200.
        """
        if response.status == 200:
            return
        try:
            as_json = await response.json()
            msg = as_json.get("message", None)
        except Exception:
            # Deliberately not a bare `except:` so that BaseException
            # subclasses (KeyboardInterrupt, and task cancellation on Python
            # versions where CancelledError is a BaseException) propagate
            # instead of being reported as an authentication failure.
            raise AuthenticationException(
                response, "Request failed with status %i" % response.status
            )
        raise AuthenticationException(response, msg)
class OAuth1AIOHttpAsyncSession(OAuthAIOHttpAsyncSessionBase):
    """OAuth1 using aiohttp.

    Authentication fetches a request token, then an access token, and finally
    refreshes the signing client.
    """

    async def set_oauth_request_token(self):
        """Fetch a request token and store it on the session.

        Signs a POST to ``server.request_token_endpoint`` with the consumer
        credentials, sends it, and passes the response body to
        ``_update_token_from_request_body``.

        :raises AuthenticationException: when the response status is not 200.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Fetching oauth request token from: %s", self.server.request_token_endpoint
            )
        # DEBUG logging is switched off process-wide while the client is built
        # and the request is signed. The reset sits in `finally` so a failure
        # in either step cannot leave logging disabled.
        logging.disable(logging.DEBUG)
        try:
            token_client = self.client_class(
                self.consumer_key, client_secret=self.consumer_secret
            )
            uri, headers, body = token_client.sign(self.server.request_token_endpoint, "POST")
        finally:
            logging.disable(logging.NOTSET)

        async with self.http_session.post(
            uri, headers=headers, data=body, ssl=self._ssl_context
        ) as response:
            await self._ensure_success_response(response)
            self._update_token_from_request_body(await response.read())

    async def set_oauth_access_token(self):
        """Exchange the stored token for an access token.

        Signs a POST to ``server.access_token_endpoint`` using the consumer
        credentials together with ``_oauth_token`` / ``_oauth_token_secret``,
        sends it, and passes the response body to
        ``_update_token_from_request_body``.

        :raises AuthenticationException: when the response status is not 200.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Fetching oauth access token from %s", self.server.access_token_endpoint
            )
        access_client = self.client_class(
            self.consumer_key,
            client_secret=self.consumer_secret,
            resource_owner_key=self._oauth_token,
            resource_owner_secret=self._oauth_token_secret,
        )
        uri, headers, body = access_client.sign(self.server.access_token_endpoint, "POST")

        async with self.http_session.post(
            uri, headers=headers, data=body, ssl=self._ssl_context
        ) as response:
            await self._ensure_success_response(response)
            self._update_token_from_request_body(await response.read())

    async def authenticate(self):
        """Run the token exchange: request token, access token, client update."""
        await self.set_oauth_request_token()
        await self.set_oauth_access_token()
        self._update_client()
class XAuthAIOHttpAsyncSession(OAuthAIOHttpAsyncSessionBase):
    """XAuth using aiohttp.

    Authenticates with a user ID and password, using the xauth consumer key
    and secret declared by the server.
    """

    def __init__(
        self,
        userid_password,
        http_session,
        server,
        signature_method=None,
        client_class=oauth.Client,
    ):
        """Initialize the session and validate the credentials.

        :param userid_password: a 2-tuple ``(user, password)`` or a
            ``UserIDPassword`` instance.
        :param http_session: object used to send requests.
        :param server: server description; must provide
            ``xauth_consumer_key``, ``xauth_consumer_secret`` and ``is_xauth()``.
        :param signature_method: forwarded unchanged to the base class.
        :param client_class: class used to sign requests.
        :raises AuthenticationException: when ``server.is_xauth()`` is false.
        :raises ValueError: when ``userid_password`` has an unsupported type
            or length.
        """
        super().__init__(
            http_session,
            server,
            server.xauth_consumer_key,
            server.xauth_consumer_secret,
            signature_method=signature_method,
            client_class=client_class,
        )
        if not self.server.is_xauth():
            # Every other AuthenticationException in this module is built as
            # (response, message). No HTTP response exists at this point, so
            # None fills the response slot and the text stays in the message
            # slot. NOTE(review): confirm the exception class accepts a None
            # response.
            raise AuthenticationException(
                None,
                "XAuth is not configured for this server. Missing xauth consumer key and/or secret.",
            )
        if isinstance(userid_password, tuple) and len(userid_password) == 2:
            self.xauth_credentials = UserIDPassword(*userid_password)
        elif isinstance(userid_password, UserIDPassword):
            self.xauth_credentials = userid_password
        else:
            raise ValueError(
                "User ID and password must be specified as a tuple or a UserIDPassword instance."
            )

    async def authenticate(self):
        """Authenticate with the stored user ID and password.

        Signs a POST to ``server.access_token_endpoint`` whose body carries
        the credentials in ``client_auth`` mode, sends it, stores the token
        read from the response body, and refreshes the signing client.

        :raises AuthenticationException: when the response status is not 200.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("xauth authentication of user %s", self.xauth_credentials.user)
        client = self.client_class(self.consumer_key, self.consumer_secret)
        # Avoid dumping the password in log files: DEBUG logging is switched
        # off process-wide while signing. The reset sits in `finally` so a
        # signing failure cannot leave logging disabled.
        logging.disable(logging.DEBUG)
        try:
            uri, headers, body = client.sign(
                self.server.access_token_endpoint,
                "POST",
                headers=self.DEFAULT_CONTENT_TYPE,
                body={
                    "x_auth_username": self.xauth_credentials.user,
                    "x_auth_password": self.xauth_credentials.password,
                    "x_auth_mode": "client_auth",
                },
            )
        finally:
            logging.disable(logging.NOTSET)

        async with self.http_session.post(
            uri, headers=headers, data=body, ssl=self._ssl_context
        ) as response:
            await self._ensure_success_response(response)
            self._update_token_from_request_body(await response.read())
            self._update_client()
class _AsyncBytesIO:
    """In-memory bytes buffer whose ``write`` is a coroutine.

    Lets code that awaits ``writer.write(chunk)`` collect its output in
    memory; the accumulated bytes are read back with :meth:`getvalue`.
    """

    def __init__(self, initial_bytes=None):
        """Create the underlying buffer, optionally seeded with ``initial_bytes``."""
        self.buffer = six.BytesIO(initial_bytes)

    async def write(self, value):
        """Append ``value`` to the buffer; the coroutine itself yields no result."""
        stream = self.buffer
        stream.write(value)

    def getvalue(self):
        """Flush the buffer and return everything it holds."""
        stream = self.buffer
        stream.flush()
        content = stream.getvalue()
        return content