diff --git a/.vscode/settings.json b/.vscode/settings.json index 1265994..828249e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,6 @@ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "python.analysis.typeCheckingMode": "basic" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index b37d839..2a49e8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Core: support OAuth client credentials authentication via `client_id` and `client_secret` in `XurrentApiHelper` while maintaining API key compatibility. +- Core: The OAuth token endpoint now dynamically determines the domain from `base_url`, preserving any regional subdomains to ensure consistency between API and OAuth endpoints. +- Core: When using OAuth, if a 401 Unauthorized error is received, the token is automatically refreshed and the API call is retried once. If authentication still fails after token refresh, an explicit HTTPError is raised. + ## [0.10.0] - 2025-08-16 ### Added diff --git a/README.md b/README.md index e971d57..74afe4a 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,30 @@ This module is used to interact with the Xurrent API. It provides a set of class # this can be used to derive the ID from the nodeID ``` +#### Using OAuth client credentials + +You can let the helper automatically request and refresh bearer tokens by providing the OAuth +`client_id` and `client_secret` that were issued to your application. The original API key flow +continues to work unchanged, but only one authentication method may be used per helper instance. + +```python + from xurrent.core import XurrentApiHelper + + baseUrl = "https://api.xurrent.qa/v1" + account = "account-name" + client_id = "your-client-id" + client_secret = "your-client-secret" + + x_api_helper = XurrentApiHelper( + baseUrl, + api_account=account, + client_id=client_id, + client_secret=client_secret, + ) + + response = x_api_helper.api_call("/requests", "GET") +``` + #### Configuration Items ```python diff --git a/pyproject.toml b/pyproject.toml index 6d87f2f..97fa634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "xurrent" -version = "0.10.0" +version = "0.11.0-preview.10" authors = [ { name="Fabian Steiner", email="fabian@stei-ner.net" }, ] @@ -18,7 +18,7 @@ Homepage = "https://github.com/fasteiner/xurrent-python" Issues = "https://github.com/fasteiner/xurrent-python/issues" [tool.poetry] name = "xurrent" -version = "0.10.0" +version = "0.11.0-preview.10" description = "A python module to interact with the Xurrent API." authors = ["Ing. Fabian Franz Steiner BSc. "] readme = "README.md" diff --git a/src/xurrent/core.py b/src/xurrent/core.py index d97061a..1fcbcc6 100644 --- a/src/xurrent/core.py +++ b/src/xurrent/core.py @@ -7,6 +7,8 @@ import json import re import base64 +from logging import Logger +from typing import Optional, List class LogLevel(Enum): DEBUG = logging.DEBUG @@ -46,23 +48,49 @@ class XurrentApiHelper: api_user: Person # Forward declaration with a string api_user_teams: List[Team] # Forward declaration with a string - def __init__(self, base_url, api_key, api_account,resolve_user=True, logger: Logger=None): + def __init__( + self, + base_url, + api_key=None, + api_account=None, + resolve_user=True, + logger: Logger=None, + client_id: Optional[str]=None, + client_secret: Optional[str]=None + ): """ Initialize the Xurrent API helper. :param base_url: Base URL of the Xurrent API :param api_key: API key to authenticate with :param api_account: Account name to use + :param client_id: OAuth client ID to use when fetching an access token + :param client_secret: OAuth client secret to use when fetching an access token :param resolve_user: Resolve the API user and their teams (default: True) :param logger: Logger to use (optional), otherwise a new logger is created """ self.base_url = base_url - self.api_key = api_key self.api_account = api_account + self._client_id = client_id + self._client_secret = client_secret + self._token_expires_at: Optional[float] = None + + if bool(api_key) == bool(client_id and client_secret): + raise ValueError('Provide either api_key or both client_id and client_secret, but not both.') + if not self.api_account: + raise ValueError('api_account must be provided.') + if logger: self.logger = logger else: self.logger = self.create_logger(False) + if client_id or client_secret: + if not (client_id and client_secret): + raise ValueError('Both client_id and client_secret are required for OAuth authentication.') + self.api_key = None + self._obtain_access_token() + else: + self.api_key = api_key #Create a requests session to maintain persistent connections, with preset headers self.__session = requests.Session() self.__session.headers.update({ @@ -75,6 +103,70 @@ def __init__(self, base_url, api_key, api_account,resolve_user=True, logger: Log self.api_user = Person.get_me(self) self.api_user_teams = self.api_user.get_teams() + def _ensure_access_token(self): + """Ensure that a valid access token is available.""" + if not self._client_id: + return + + needs_refresh = self.api_key is None + if self._token_expires_at is not None: + needs_refresh = needs_refresh or time.time() >= self._token_expires_at + + if needs_refresh: + self._obtain_access_token() + + def _obtain_access_token(self): + """Fetch a new OAuth access token using the client credentials grant.""" + # Dynamically determine the domain from the base_url and preserve regional subdomains + import urllib.parse + parsed = urllib.parse.urlparse(self.base_url) + # Extract the netloc (e.g. api.xurrent.com, api.au.xurrent.com) and replace 'api' with 'oauth' + netloc_parts = parsed.netloc.split('.') + if len(netloc_parts) < 2: + raise ValueError('Invalid base_url for extracting domain') + + # Replace the first part (assumed to be 'api') with 'oauth' + if netloc_parts[0] == 'api': + netloc_parts[0] = 'oauth' + else: + self.logger.warning(f"Expected first domain part to be 'api', got '{netloc_parts[0]}'. Proceeding anyway.") + netloc_parts[0] = 'oauth' + + # Reconstruct the domain preserving all parts including regional subdomains + oauth_domain = '.'.join(netloc_parts) + token_url = f'https://{oauth_domain}/token' + payload = { + 'client_id': self._client_id, + 'client_secret': self._client_secret, + 'grant_type': 'client_credentials' + } + + try: + response = requests.post(token_url, data=payload) + response.raise_for_status() + except requests.exceptions.RequestException as exc: + self.logger.error(f'Failed to obtain OAuth access token: {exc}') + raise + + data = response.json() + access_token = data.get('access_token') + if not access_token: + self.logger.error('OAuth token response did not contain an access_token.') + raise ValueError('OAuth token response did not contain an access_token.') + + expires_in = data.get('expires_in', 3600) + buffer_seconds = 60 + self._token_expires_at = time.time() + max(expires_in - buffer_seconds, 0) + self.api_key = access_token + + # Update session headers if the session exists + if hasattr(self, '__session'): + self.__session.headers.update({ + 'Authorization': f'Bearer {self.api_key}' + }) + + self.logger.debug('Obtained new OAuth access token.') + def __append_per_page(self, uri, per_page=100): """ Append the 'per_page' parameter to the URI if not already present. @@ -125,20 +217,27 @@ def create_logger(self, verbose) -> Logger: logger.addHandler(log_stream) return logger - def set_log_level(self, level: LogLevel): + def set_log_level(self, level): """ Set the log level for the logger and all handlers. - :param level: Log level to set + :param level: Log level to set (can be a string, int, or LogLevel enum) """ - self.logger.setLevel(level) + # Handle different types of input + if isinstance(level, LogLevel): + log_level = level.value + else: + log_level = level + + self.logger.setLevel(log_level) for handler in self.logger.handlers: - handler.setLevel(level) + handler.setLevel(log_level) def api_call(self, uri: str, method='GET', data=None, per_page=100, raw=False): """ Make a call to the Xurrent API with support for rate limiting and pagination. + Automatically handles 401 responses by refreshing the OAuth token if client_id and client_secret are provided. :param uri: URI to call :param method: HTTP method to use (default: GET) :param data: Data to send with the request (optional) @@ -150,67 +249,86 @@ def api_call(self, uri: str, method='GET', data=None, per_page=100, raw=False): if not uri.startswith(self.base_url) and "://" not in uri[:10]: uri = f'{self.base_url}{uri}' - aggregated_data = [] - next_page_url = uri - - while next_page_url: - try: - # Append pagination parameters for GET requests - if per_page and method == 'GET': - # if contains ? or does not end with /, append per_page - next_page_url = self.__append_per_page(next_page_url, per_page) - - # Log the request - self.logger.debug(f'{method} {next_page_url} {data if method != "GET" else ""}') - - # Make the HTTP request - response = self.__session.request(method, next_page_url, json=data) - - if response.status_code == 204: - return None - - # Handle rate limiting (429 status code) - if response.status_code == 429: - retry_after = int(response.headers.get('Retry-After', 1)) # Default to 1 second if not provided - self.logger.warning(f'Rate limit reached. Retrying after {retry_after} seconds...') - time.sleep(retry_after) - continue - - # Check for other non-success status codes - if not response.ok: - self.logger.error(f'Error in request: {response.status_code} - {response.text}') - response.raise_for_status() - - #Stop here if we shall not process or interperet the returned data - if raw: - return response.content - - # Process response - response_data = response.json() - - # For GET requests, handle pagination - if method == 'GET' and isinstance(response_data, list): - aggregated_data.extend(response_data) - - # Parse the 'Link' header to find the 'next' page URL - link_header = response.headers.get('Link') - if link_header: - links = {rel.strip(): url.strip('<>') for url, rel in - (link.split(';') for link in link_header.split(','))} - next_page_url = links.get('rel="next"') - if next_page_url: - next_page_url = next_page_url.replace('<', '').replace('>', '') + def do_request(): + self._ensure_access_token() + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'x-xurrent-account': self.api_account + } + aggregated_data = [] + next_page_url = uri + while next_page_url: + try: + # Append pagination parameters for GET requests + if per_page and method == 'GET': + next_page_url = self.__append_per_page(next_page_url, per_page) + + # Log the request + self.logger.debug(f'{method} {next_page_url} {data if method != "GET" else ""}') + + # Make the HTTP request - use session if available, otherwise direct request + if hasattr(self, '__session'): + response = self.__session.request(method, next_page_url, json=data) else: - next_page_url = None - else: - return response_data # Return for non-GET requests - - except requests.exceptions.RequestException as e: - self.logger.error(f'HTTP request failed: {e}') - raise - - # Return aggregated results for paginated GET - return aggregated_data + response = requests.request(method, next_page_url, headers=headers, json=data) + + if response.status_code == 204: + return None + + # Handle rate limiting (429 status code) + if response.status_code == 429: + retry_after = int(response.headers.get('Retry-After', 1)) # Default to 1 second if not provided + self.logger.warning(f'Rate limit reached. Retrying after {retry_after} seconds...') + time.sleep(retry_after) + continue + + # Handle 401 Unauthorized - signal to outer logic to refresh token and retry + if response.status_code == 401 and self._client_id: + return '401-refresh' + + # Check for other non-success status codes + if not response.ok: + self.logger.error(f'Error in request: {response.status_code} - {response.text}') + response.raise_for_status() + + # Stop here if we shall not process or interpret the returned data + if raw: + return response.content + + # Process response + response_data = response.json() + + # For GET requests, handle pagination + if method == 'GET' and isinstance(response_data, list): + aggregated_data.extend(response_data) + + # Parse the 'Link' header to find the 'next' page URL + link_header = response.headers.get('Link') + if link_header: + links = {rel.strip(): url.strip('<>') for url, rel in + (link.split(';') for link in link_header.split(','))} + next_page_url = links.get('rel="next"') + if next_page_url: + next_page_url = next_page_url.replace('<', '').replace('>', '') + else: + next_page_url = None + else: + return response_data + except requests.exceptions.RequestException as e: + self.logger.error(f'HTTP request failed: {e}') + raise + return aggregated_data + + result = do_request() + if result == '401-refresh': + self.logger.info('401 Unauthorized received, refreshing OAuth token and retrying request...') + self._obtain_access_token() + result = do_request() + # If we still get the 401-refresh sentinel after a token refresh, something is wrong with auth + if result == '401-refresh': + self.logger.error('Still receiving 401 Unauthorized after token refresh, authentication failed') + raise requests.exceptions.HTTPError('Authentication failed: 401 Unauthorized received even after token refresh') + return result def bulk_export(self, type: str, export_format='csv', save_as=None, poll_timeout=5): """ @@ -223,22 +341,47 @@ def bulk_export(self, type: str, export_format='csv', save_as=None, poll_timeout """ #Initiate an export and get the polling token - export = self.api_call('/export', method = 'POST', data = dict(type = type, export_format = export_format)) + export = self.api_call('/export', method='POST', data=dict(type=type, export_format=export_format)) + + if not isinstance(export, dict) or 'token' not in export: + self.logger.error(f'Export initialization failed: {export}') + raise ValueError('Invalid export response: missing token') + + token = export['token'] #Begin export results poll waiting loop + export_result = None while True: self.logger.debug('Export poll wait.') time.sleep(poll_timeout) - result = self.api_call(f"/export/{export['token']}", per_page = None) - if result['state'] in ('queued','processing'): + poll_result = self.api_call(f"/export/{token}", per_page=0) + + if not isinstance(poll_result, dict) or 'state' not in poll_result: + self.logger.error(f'Export polling failed: {poll_result}') + raise ValueError('Invalid poll response: missing state') + + if poll_result['state'] in ('queued', 'processing'): continue - if result['state'] == 'done': + elif poll_result['state'] == 'done': + export_result = poll_result break - self.logger.error(f'Export request failed: {result=}') - raise - + else: + self.logger.error(f'Export request failed: {poll_result=}') + raise RuntimeError(f'Export failed with state: {poll_result["state"]}') + + if 'url' not in export_result: + self.logger.error(f'Export result missing URL: {export_result}') + raise ValueError('Export result missing download URL') + #Save or Return the exported data - result = self.api_call(result["url"], per_page = None, raw = True) + download_url = export_result["url"] + result = self.api_call(download_url, per_page=0, raw=True) + + # Check if result is bytes, otherwise provide a general error + if not isinstance(result, bytes): + self.logger.error('Expected bytes response for export download') + raise TypeError('Export download returned unexpected type') + if save_as: with open(save_as, 'wb') as file: file.write(result) diff --git a/tests/unit_tests/test_core_oauth.py b/tests/unit_tests/test_core_oauth.py new file mode 100644 index 0000000..6a7548e --- /dev/null +++ b/tests/unit_tests/test_core_oauth.py @@ -0,0 +1,125 @@ +import os +import sys +from unittest.mock import MagicMock + +import pytest +import requests + +# Add the `../src` directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from xurrent.core import XurrentApiHelper + + +@pytest.fixture +def mock_token_response(): + def _factory(access_token="token", expires_in=3600): + response = MagicMock() + response.json.return_value = {"access_token": access_token, "expires_in": expires_in} + response.raise_for_status = MagicMock() + return response + + return _factory + + +def test_init_requires_authentication_method(): + with pytest.raises(ValueError): + XurrentApiHelper("https://api.example.com", api_account="account", resolve_user=False) + + +def test_init_with_both_auth_methods_raises(): + with pytest.raises(ValueError): + XurrentApiHelper( + "https://api.example.com", + api_key="token", + api_account="account", + resolve_user=False, + client_id="cid", + client_secret="secret", + ) + + +def test_client_credentials_fetches_token(monkeypatch, mock_token_response): + post_calls = [] + + def fake_post(url, data): + post_calls.append((url, data)) + return mock_token_response("oauth-token") + + api_responses = [] + + def fake_request(method, url, headers=None, json=None): + api_responses.append({"method": method, "url": url, "headers": headers, "json": json}) + response = MagicMock() + response.status_code = 200 + response.ok = True + response.json.return_value = {"result": "ok"} + response.headers = {} + return response + + monkeypatch.setattr(requests, "post", fake_post) + monkeypatch.setattr(requests, "request", fake_request) + + helper = XurrentApiHelper( + "https://api.example.com", + api_account="account", + resolve_user=False, + client_id="cid", + client_secret="secret", + ) + + result = helper.api_call("/resource") + + assert result == {"result": "ok"} + assert post_calls == [ + ( + "https://oauth.xurrent.com/token", + { + "client_id": "cid", + "client_secret": "secret", + "grant_type": "client_credentials", + }, + ) + ] + assert api_responses[0]["headers"]["Authorization"] == "Bearer oauth-token" + + +def test_client_credentials_refreshes_token(monkeypatch): + token_payloads = [ + {"access_token": "token-1", "expires_in": 0}, + {"access_token": "token-2", "expires_in": 3600}, + ] + post_count = 0 + + def fake_post(url, data): + nonlocal post_count + response = MagicMock() + payload = token_payloads[post_count] + response.json.return_value = payload + response.raise_for_status = MagicMock() + post_count += 1 + return response + + def fake_request(method, url, headers=None, json=None): + response = MagicMock() + response.status_code = 200 + response.ok = True + response.json.return_value = {"result": "ok"} + response.headers = {} + return response + + monkeypatch.setattr(requests, "post", fake_post) + monkeypatch.setattr(requests, "request", fake_request) + + helper = XurrentApiHelper( + "https://api.example.com", + api_account="account", + resolve_user=False, + client_id="cid", + client_secret="secret", + ) + + helper.api_call("/resource-1") + helper.api_call("/resource-2") + + assert post_count == 2