From e3995ff6fbd8c71e17a038f54107fc912e989343 Mon Sep 17 00:00:00 2001 From: Andrea Orlandi Date: Fri, 26 Sep 2025 16:04:49 +0200 Subject: [PATCH 1/4] wip --- src/picterra/tracer_client.py | 4 + src/picterra/utils.py/oauth.py | 297 +++++++++++++++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 src/picterra/utils.py/oauth.py diff --git a/src/picterra/tracer_client.py b/src/picterra/tracer_client.py index bfd7c2f..76d3465 100644 --- a/src/picterra/tracer_client.py +++ b/src/picterra/tracer_client.py @@ -40,6 +40,10 @@ def _return_results_page( url = self._full_url("%s/" % resource_endpoint, params=params) return ResultsPage(url, self.sess.get) + def login(self): + pass + + def list_methodologies( self, search: Optional[str] = None, diff --git a/src/picterra/utils.py/oauth.py b/src/picterra/utils.py/oauth.py new file mode 100644 index 0000000..efc224d --- /dev/null +++ b/src/picterra/utils.py/oauth.py @@ -0,0 +1,297 @@ +import base64 +import hashlib +import multiprocessing +import os +import random +import string +import time +import urllib.parse as urlparse +import webbrowser +from base64 import urlsafe_b64encode +from datetime import datetime +from hashlib import sha256 +from http.server import BaseHTTPRequestHandler, HTTPServer +from random import random +from typing import Any, Dict, Optional, Type, no_type_check +from urllib.parse import parse_qs, urljoin, urlparse + +import requests + + +class OAuthError(Exception): + pass + +CLIENT_ID = "ggshield_oauth" +SCOPE = "scan" + +# potential port range to be used to run local server +# to handle authorization code callback +# this is the largest band of not commonly occupied ports +# https://stackoverflow.com/questions/10476987/best-tcp-port-number-range-for-internal-applications +LOCAL_SERVER_USABLE_PORT_RANGE = (29170, 29998) +LOCAL_SERVER_MAX_WAIT_S = 60 +LOCAL_SERVER_ADDRESS = "127.0.0.1" + + +def long_running_function_mp(secs): + for _ in range(secs): # Simulate a long running task + time.sleep(1) + + +def _generate_pkce_pair(): + code_verifier = ''.join( + random.choice(string.ascii_uppercase + string.digits) for _ in range( + random.randint(43, 128) + ) + ) + code_challenge = hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge).decode('utf-8').replace('=', '') + return code_verifier, code_challenge + + +def _get_error_param(parsed_url: urlparse.ParseResult) -> Optional[str]: + """ + extract the value of the 'error' url param. If not present, return None. + """ + params = urlparse.parse_qs(parsed_url.query) + if "error" in params: + return params["error"][0] + return None + + +def extract_qs_parameter(url_string, key) -> Optional[str]: + parsed_url = urlparse(url_string) + query_parameters = parse_qs(parsed_url.query) + return query_parameters.get(key, [None])[0] + + +class OAuthClient: + _port: Optional[int] = None + local_server: Optional[HTTPServer] = None + picterra_uri: str + pkce_pair: tuple[str, str] + """ + Helper class to handle the OAuth authentication flow + the logic is divided in 2 steps: + - open the browser on GitGuardian login screen and run a local server to wait for callback + - handle the oauth callback to exchange an authorization code against a valid access token + """ + + def __init__(self, client_id: str, server_uri: str) -> None: + self._client_id = client_id + self._state = "" # use the `state` property instead + self._lifetime: Optional[int] = None + self._login_path = "auth/login" + self._handler_wrapper = RequestHandlerWrapper(oauth_client=self) + self._access_token: Optional[str] = None + self.picterra_uri = server_uri + self.pkce_pair = _generate_pkce_pair() + + @property + def redirect_uri(self) -> str: + return "http://" + LOCAL_SERVER_ADDRESS + ":" + self._port + + @property + def authorize_uri(self) -> str: + return urljoin(self.picterra_uri, "o/authorize/") + + def start(self): + webbrowser.open_new_tab(self.authorize_uri) + + def process_callback(self, callback_url: str) -> None: + """ + This function runs within the request handler do_GET method. + It takes the url of the callback request as argument and does + - Extract the authorization code + - Exchange the code against an access token with GitGuardian's api + - Validate the new token against GitGuardian's api + - Save the token in configuration + Any error during this process will raise a OAuthError + """ + authorization_code = self._get_code(callback_url) + self._claim_token(authorization_code) + token_data = self._validate_access_token() + self._save_token(token_data) + + def _redirect_to_login(self) -> None: + """ + Open the user's browser to authentication page + """ + requests.get( + self.authorize_uri, + params={ + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": SCOPE, + "state": self.state, + "code_challenge": self.pkce_pair[1], + "code_challenge_method": "S256", + "client_id": self._client_id, + "utm_source": "picterra-python", + } + ) + + def _get_access_token(self, code: str): + requests.post( + "http://127.0.0.1:8000/o/token/", + json={ + "client_id": self._client_id, + "client_secret": self._client_secret, + "code": code, + "code_verifier": self.pkce_pair[0], + "redirect_uri": self.redirect_uri, + "grant_type": "authorization_code" + } + ) + + def _spawn_server(self) -> None: + for port in range(*LOCAL_SERVER_USABLE_PORT_RANGE): + try: + self.local_server = HTTPServer( + # only consider requests from localhost on the predetermined port + # When starting a local server for OAuth callbacks, binding to all + # interfaces (0.0.0.0) exposes your authentication server to other + # machines on the network. This could allow attackers to intercept + # OAuth codes or inject malicious responses. Always bind only to + # localhost (127.0.0.1) to ensure the callback server is accessible + # only from the local machine. + (LOCAL_SERVER_ADDRESS, port), + # attach the wrapped request handler + self._handler_wrapper.request_handler, + ) + self._port = port + p = multiprocessing.Process( + target=long_running_function_mp, + args=(LOCAL_SERVER_MAX_WAIT_S,) + ) + p.start() + p.join(LOCAL_SERVER_MAX_WAIT_S) + if p.is_alive(): + print("Function timed out, terminating process...") + p.terminate() + p.join() # Wait for the process to terminate + raise Exception("Function timed out!") + else: + print("Function finished within the time limit.") + break + except OSError: + continue + else: + raise OAuthError("Could not find unoccupied port.") + + def _wait_for_callback(self) -> None: + """ + Wait to receive and process the authorization callback on the local server. + This actually catches HTTP requests made on the previously opened server. + The callback processing logic is implemented in the request handler class + and the `process_callback` method + """ + try: + while not self._handler_wrapper.complete: + # Wait for callback on localserver including an authorization code + # any matching request will get processed by the request handler and + # the `process_callback` function + self.local_server.handle_request() + except KeyboardInterrupt: + raise OAuthError("User stopped login process.") + + if self._handler_wrapper.error_message is not None: + # if no error message is attached, the process is considered successful + raise OAuthError(self._handler_wrapper.error_message) + + def _get_code(self, uri: str) -> str: + """ + Extract the authorization from the incoming request uri and verify that the state from + the uri match the one stored internally. + if no code can be extracted or the state is invalid, raise an OAuthError + else return the extracted code + """ + authorization_code = extract_qs_parameter(uri, "code") + if authorization_code is None: + raise OAuthError("Invalid code or state received from the callback.") + return authorization_code + + @property + def default_token_lifetime(self) -> Optional[int]: + """ + return the default token lifetime saved in the instance config. + if None, this will be interpreted as no expiry. + """ + instance_lifetime = self.instance_config.default_token_lifetime + if instance_lifetime is not None: + return instance_lifetime + return self.config.auth_config.default_token_lifetime + + +class RequestHandlerWrapper: + """ + Utilitary class to link the server and the request handler. + This allows to kill the server from the request processing. + """ + + oauth_client: OAuthClient + # tells the server to stop listening to requests + complete: bool + # error encountered while processing the callback + # if None, the process is considered successful + error_message: Optional[str] = None + + def __init__(self, oauth_client: OAuthClient) -> None: + self.oauth_client = oauth_client + self.complete = False + self.error_message = None + + @property + def request_handler(self) -> Type[BaseHTTPRequestHandler]: + class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self_) -> None: + """ + This function process every GET request received by the server. + Non-root request are skipped. + If an authorization code can be extracted from the URI, attach it to the handler + so it can be retrieved after the request is processed, then kill the server. + """ + callback_url: str = self_.path + parsed_url = urlparse.urlparse(callback_url) + if parsed_url.path == "/": + error_string = _get_error_param(parsed_url) + if error_string is not None: + self_._end_request(200) + self.error_message = self.oauth_client.get_server_error_message( + error_string + ) + else: + try: + self.oauth_client.process_callback(callback_url) + except Exception as error: + self_._end_request(400) + # attach error message to the handler wrapper instance + self.error_message = error.message + else: + self_._end_request( + 301, + urljoin( + self.oauth_client.dashboard_url, "authenticated" + ), + ) + + # indicate to the server to stop + self.complete = True + else: + self_._end_request(404) + + def _end_request( + self_, status_code: int, redirect_url: Optional[str] = None + ) -> None: + """ + End the current request. If a redirect url is provided, + the response will be a redirection to this url. + If not the response will be a user error 400 + """ + self_.send_response(status_code) + + if redirect_url is not None: + self_.send_header("Location", redirect_url) + self_.end_headers() + + return RequestHandler From 3fd26356497026d06ba990ae1e3d1fcf1b2e2db1 Mon Sep 17 00:00:00 2001 From: Andrea Orlandi Date: Fri, 26 Sep 2025 23:45:45 +0200 Subject: [PATCH 2/4] wip2 --- src/picterra/base_client.py | 27 ++- src/picterra/tracer_client.py | 4 - src/picterra/utils/__init__.py | 0 src/picterra/{utils.py => utils}/oauth.py | 267 ++++++++++++---------- 4 files changed, 164 insertions(+), 134 deletions(-) create mode 100644 src/picterra/utils/__init__.py rename src/picterra/{utils.py => utils}/oauth.py (58%) diff --git a/src/picterra/base_client.py b/src/picterra/base_client.py index 86e0883..f3ca406 100644 --- a/src/picterra/base_client.py +++ b/src/picterra/base_client.py @@ -19,10 +19,18 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry +from .utils.oauth import OAuthClient, OAuthError + logger = logging.getLogger() CHUNK_SIZE_BYTES = 8192 # 8 KiB +# ANSI escape codes for colors +GREEN = "\033[92m" +RED = "\033[91m" +RESET = "\033[0m" # Resets the color to default + +CLIENT_ID = "Eya1oJleyYoo35I17w5WWP2oTbTLr89LTJXWBxDs" # allow injecting an non-existing package name to test the fallback behavior # of _get_ua in tests (see test_headers_user_agent_version__fallback) @@ -205,8 +213,6 @@ def __init__( "PICTERRA_BASE_URL", "https://app.picterra.ch/" ) api_key = os.environ.get("PICTERRA_API_KEY", None) - if not api_key: - raise APIError("PICTERRA_API_KEY environment variable is not defined") logger.info( "Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.", base_url, @@ -234,7 +240,8 @@ def __init__( self.sess.mount("https://", adapter) self.sess.mount("http://", adapter) # Authentication - self.sess.headers.update({"X-Api-Key": api_key}) + if api_key is not None: + self.sess.headers.update({"X-Api-Key": api_key}) def _full_url(self, path: str, params: dict[str, Any] | None = None): url = urljoin(self.base_url, path) @@ -295,3 +302,17 @@ def get_operation_results(self, operation_id: str) -> dict[str, Any]: self._full_url("operations/%s/" % operation_id), ) return resp.json()["results"] + + def login(self): + base_url = os.environ.get( + "PICTERRA_BASE_URL", "https://app.picterra.ch/" + ) + base_url = "http://100.81.123.76:8000" # TODO remove + cl = OAuthClient(CLIENT_ID, base_url) + try: + data = cl.start() + print(data) + print(f"{GREEN}Logged in at {base_url}.{RESET}") + except OAuthError as e: + print(f"{RED}Error during login: '{e}'{RESET}") + sys.exit(1) diff --git a/src/picterra/tracer_client.py b/src/picterra/tracer_client.py index 76d3465..bfd7c2f 100644 --- a/src/picterra/tracer_client.py +++ b/src/picterra/tracer_client.py @@ -40,10 +40,6 @@ def _return_results_page( url = self._full_url("%s/" % resource_endpoint, params=params) return ResultsPage(url, self.sess.get) - def login(self): - pass - - def list_methodologies( self, search: Optional[str] = None, diff --git a/src/picterra/utils/__init__.py b/src/picterra/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/picterra/utils.py/oauth.py b/src/picterra/utils/oauth.py similarity index 58% rename from src/picterra/utils.py/oauth.py rename to src/picterra/utils/oauth.py index efc224d..891ecba 100644 --- a/src/picterra/utils.py/oauth.py +++ b/src/picterra/utils/oauth.py @@ -1,27 +1,23 @@ import base64 import hashlib -import multiprocessing -import os +import logging import random import string import time -import urllib.parse as urlparse import webbrowser -from base64 import urlsafe_b64encode -from datetime import datetime -from hashlib import sha256 from http.server import BaseHTTPRequestHandler, HTTPServer -from random import random -from typing import Any, Dict, Optional, Type, no_type_check -from urllib.parse import parse_qs, urljoin, urlparse +from typing import Optional, Type +from urllib.parse import ParseResult, parse_qs, urlencode, urljoin, urlparse, urlunparse import requests +logger = logging.getLogger() + class OAuthError(Exception): pass -CLIENT_ID = "ggshield_oauth" + SCOPE = "scan" # potential port range to be used to run local server @@ -30,12 +26,21 @@ class OAuthError(Exception): # https://stackoverflow.com/questions/10476987/best-tcp-port-number-range-for-internal-applications LOCAL_SERVER_USABLE_PORT_RANGE = (29170, 29998) LOCAL_SERVER_MAX_WAIT_S = 60 +# only consider requests from localhost on the predetermined port +# When starting a local server for OAuth callbacks, binding to all +# interfaces (0.0.0.0) exposes your authentication server to other +# machines on the network. This could allow attackers to intercept +# OAuth codes or inject malicious responses. Always bind only to +# localhost (127.0.0.1) to ensure the callback server is accessible +# only from the local machine. LOCAL_SERVER_ADDRESS = "127.0.0.1" def long_running_function_mp(secs): - for _ in range(secs): # Simulate a long running task + print(f"Timout is {secs}s") + for _ in range(secs): # Simulate a long running task time.sleep(1) + print("Tick") def _generate_pkce_pair(): @@ -49,22 +54,33 @@ def _generate_pkce_pair(): return code_verifier, code_challenge -def _get_error_param(parsed_url: urlparse.ParseResult) -> Optional[str]: +def _get_error_param(parsed_url: ParseResult) -> Optional[str]: """ extract the value of the 'error' url param. If not present, return None. """ - params = urlparse.parse_qs(parsed_url.query) + params = parse_qs(parsed_url.query) if "error" in params: - return params["error"][0] + return params["error"][0] + ": " + params.get("error_description", [""])[0] return None -def extract_qs_parameter(url_string, key) -> Optional[str]: +def _extract_qs_parameter(url_string, key) -> Optional[str]: parsed_url = urlparse(url_string) query_parameters = parse_qs(parsed_url.query) return query_parameters.get(key, [None])[0] +def _add_query_params(url, params): + url_parts = urlparse(url) + query = url_parts.query + new_query_parts = urlencode(params) + if query: + new_query = f"{query}&{new_query_parts}" + else: + new_query = new_query_parts + return urlunparse(url_parts._replace(query=new_query)) + + class OAuthClient: _port: Optional[int] = None local_server: Optional[HTTPServer] = None @@ -72,16 +88,21 @@ class OAuthClient: pkce_pair: tuple[str, str] """ Helper class to handle the OAuth authentication flow - the logic is divided in 2 steps: - - open the browser on GitGuardian login screen and run a local server to wait for callback + + The logic is divided in 2 steps: + - open the browser on login screen and run a local server to wait for callback - handle the oauth callback to exchange an authorization code against a valid access token + + Some notes: + * OAuth application client type must be "public" because we cannot securely store secret + credentials, and we should NOT use the client secret anywhere; this is also the reason + of using the PKCE flow , designed for clients that cannot keep a secret """ def __init__(self, client_id: str, server_uri: str) -> None: self._client_id = client_id self._state = "" # use the `state` property instead - self._lifetime: Optional[int] = None - self._login_path = "auth/login" + self._lifetime_s: Optional[int] = None self._handler_wrapper = RequestHandlerWrapper(oauth_client=self) self._access_token: Optional[str] = None self.picterra_uri = server_uri @@ -89,41 +110,18 @@ def __init__(self, client_id: str, server_uri: str) -> None: @property def redirect_uri(self) -> str: - return "http://" + LOCAL_SERVER_ADDRESS + ":" + self._port + return f"http://{LOCAL_SERVER_ADDRESS}:{self._port}" @property def authorize_uri(self) -> str: - return urljoin(self.picterra_uri, "o/authorize/") - - def start(self): - webbrowser.open_new_tab(self.authorize_uri) - - def process_callback(self, callback_url: str) -> None: - """ - This function runs within the request handler do_GET method. - It takes the url of the callback request as argument and does - - Extract the authorization code - - Exchange the code against an access token with GitGuardian's api - - Validate the new token against GitGuardian's api - - Save the token in configuration - Any error during this process will raise a OAuthError - """ - authorization_code = self._get_code(callback_url) - self._claim_token(authorization_code) - token_data = self._validate_access_token() - self._save_token(token_data) - - def _redirect_to_login(self) -> None: - """ - Open the user's browser to authentication page - """ - requests.get( - self.authorize_uri, - params={ + base_url = urljoin(self.picterra_uri, "o/authorize/") + return _add_query_params( + base_url, + { "response_type": "code", "redirect_uri": self.redirect_uri, - "scope": SCOPE, - "state": self.state, + # "scope": SCOPE, + #"state": "foobar", "code_challenge": self.pkce_pair[1], "code_challenge_method": "S256", "client_id": self._client_id, @@ -131,48 +129,82 @@ def _redirect_to_login(self) -> None: } ) - def _get_access_token(self, code: str): - requests.post( - "http://127.0.0.1:8000/o/token/", - json={ + @property + def token_uri(self): + return urljoin(self.picterra_uri, "o/token/") + + def _get_access_token(self, code: str) -> None: + print(f"Getting access token with code {code[:4]} from {self.token_uri}...") + resp = requests.post( + self.token_uri, + data={ "client_id": self._client_id, - "client_secret": self._client_secret, + #"client_secret": "pbkdf2_sha256$870000$V13WjHqsdkeKiEVt1sEzWR$NolL9+bb5Mh7S/YyKxQhDbWZbRZw60w/6xdrOXqjO8Y=", "code": code, "code_verifier": self.pkce_pair[0], "redirect_uri": self.redirect_uri, "grant_type": "authorization_code" - } + }, + timeout=10, ) + if resp.ok is False: + raise OAuthError(f"Error getting access token: {resp.text}") + resp.raise_for_status() + self._access_token = resp.json()["access_token"] + self._lifetime_s = resp.json()["expires_in"] + + def _get_code(self, uri: str) -> str: + """ + Extract the authorization from the incoming request uri and verify that the state from + the uri match the one stored internally. + if no code can be extracted or the state is invalid, raise an OAuthError + else return the extracted code + """ + authorization_code = _extract_qs_parameter(uri, "code") + if authorization_code is None: + raise OAuthError("Invalid code or state received from the callback.") + return authorization_code + + def start(self): + self._spawn_server() + print(f"Login at {self.authorize_uri}") + webbrowser.open_new_tab(self.authorize_uri) + self._wait_for_callback() + # self.local_server.shutdown() + return { + "token": self._access_token, + "lifetime_s": self._lifetime_s, + } + + def stop(self, error_message: str) -> None: + #self.local_server.shutdown() + raise OAuthError(f"Error after login: {error_message}.") + + def process_callback(self, callback_url: str) -> None: + """ + This function runs within the request handler do_GET method. + It takes the url of the callback request as argument and does + - Extract the authorization code + - Exchange the code against an access token with GitGuardian's api + - Validate the new token against GitGuardian's api + - Save the token in configuration + Any error during this process will raise a OAuthError + """ + print(f"Getting token from {callback_url[:7]}...") + authorization_code = self._get_code(callback_url) + self._get_access_token(authorization_code) def _spawn_server(self) -> None: for port in range(*LOCAL_SERVER_USABLE_PORT_RANGE): try: self.local_server = HTTPServer( - # only consider requests from localhost on the predetermined port - # When starting a local server for OAuth callbacks, binding to all - # interfaces (0.0.0.0) exposes your authentication server to other - # machines on the network. This could allow attackers to intercept - # OAuth codes or inject malicious responses. Always bind only to - # localhost (127.0.0.1) to ensure the callback server is accessible - # only from the local machine. (LOCAL_SERVER_ADDRESS, port), # attach the wrapped request handler - self._handler_wrapper.request_handler, + self._handler_wrapper.request_handler, # TODO simplify ? ) + self.local_server.timeout = LOCAL_SERVER_MAX_WAIT_S self._port = port - p = multiprocessing.Process( - target=long_running_function_mp, - args=(LOCAL_SERVER_MAX_WAIT_S,) - ) - p.start() - p.join(LOCAL_SERVER_MAX_WAIT_S) - if p.is_alive(): - print("Function timed out, terminating process...") - p.terminate() - p.join() # Wait for the process to terminate - raise Exception("Function timed out!") - else: - print("Function finished within the time limit.") + print("Started local server on port %d" % port) break except OSError: continue @@ -186,46 +218,31 @@ def _wait_for_callback(self) -> None: The callback processing logic is implemented in the request handler class and the `process_callback` method """ + assert self.local_server is not None try: - while not self._handler_wrapper.complete: + print("Waiting for callback...") + start_time = time.time() + while self._handler_wrapper.complete is False: + print(449580593854096) # Wait for callback on localserver including an authorization code # any matching request will get processed by the request handler and # the `process_callback` function self.local_server.handle_request() + if time.time() - start_time > LOCAL_SERVER_MAX_WAIT_S: + raise OAuthError("Timeout waiting for callback.") + if self._handler_wrapper.error_message is not None: + raise OAuthError(self._handler_wrapper.error_message) + print("Callback received.") except KeyboardInterrupt: raise OAuthError("User stopped login process.") - if self._handler_wrapper.error_message is not None: # if no error message is attached, the process is considered successful raise OAuthError(self._handler_wrapper.error_message) - def _get_code(self, uri: str) -> str: - """ - Extract the authorization from the incoming request uri and verify that the state from - the uri match the one stored internally. - if no code can be extracted or the state is invalid, raise an OAuthError - else return the extracted code - """ - authorization_code = extract_qs_parameter(uri, "code") - if authorization_code is None: - raise OAuthError("Invalid code or state received from the callback.") - return authorization_code - - @property - def default_token_lifetime(self) -> Optional[int]: - """ - return the default token lifetime saved in the instance config. - if None, this will be interpreted as no expiry. - """ - instance_lifetime = self.instance_config.default_token_lifetime - if instance_lifetime is not None: - return instance_lifetime - return self.config.auth_config.default_token_lifetime - class RequestHandlerWrapper: """ - Utilitary class to link the server and the request handler. + Helper class to link the server and the request handler. This allows to kill the server from the request processing. """ @@ -252,46 +269,42 @@ def do_GET(self_) -> None: so it can be retrieved after the request is processed, then kill the server. """ callback_url: str = self_.path - parsed_url = urlparse.urlparse(callback_url) + parsed_url = urlparse(callback_url) if parsed_url.path == "/": error_string = _get_error_param(parsed_url) if error_string is not None: + print(455676778676) self_._end_request(200) - self.error_message = self.oauth_client.get_server_error_message( - error_string - ) + # self.oauth_client.local_server.shutdown() + self.error_message = error_string else: try: + print(4354456) self.oauth_client.process_callback(callback_url) + print(32222) + self_._end_request(200) except Exception as error: + print(77777, error) + self.error_message = str(error) self_._end_request(400) - # attach error message to the handler wrapper instance - self.error_message = error.message - else: - self_._end_request( - 301, - urljoin( - self.oauth_client.dashboard_url, "authenticated" - ), - ) - + # else: + # self_._end_request( # TODO ??? + # 301, + # urljoin( + # self.oauth_client.dashboard_url, "authenticated" + # ), + # ) + print(9999999999999999) # indicate to the server to stop self.complete = True + #self_._end_request(200) else: + print(657889) self_._end_request(404) - def _end_request( - self_, status_code: int, redirect_url: Optional[str] = None - ) -> None: - """ - End the current request. If a redirect url is provided, - the response will be a redirection to this url. - If not the response will be a user error 400 - """ + def _end_request(self_, status_code: int) -> None: + assert 100 <= status_code <= 599 self_.send_response(status_code) - - if redirect_url is not None: - self_.send_header("Location", redirect_url) self_.end_headers() return RequestHandler From ef611e7ed29f2c49965c54317f2b1f1c61a24b54 Mon Sep 17 00:00:00 2001 From: Andrea Orlandi Date: Mon, 29 Sep 2025 15:04:36 +0200 Subject: [PATCH 3/4] v3 --- src/picterra/base_client.py | 69 ++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/src/picterra/base_client.py b/src/picterra/base_client.py index f3ca406..8619a97 100644 --- a/src/picterra/base_client.py +++ b/src/picterra/base_client.py @@ -12,11 +12,12 @@ else: from typing_extensions import Literal, TypedDict -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Optional, TypeVar from urllib.parse import urlencode, urljoin import requests from requests.adapters import HTTPAdapter +from requests.auth import AuthBase from urllib3.util.retry import Retry from .utils.oauth import OAuthClient, OAuthError @@ -190,32 +191,74 @@ class FeatureCollection(TypedDict): features: list[Feature] +class AuthInitError(Exception): + pass + + +class ApiKeyAuth(AuthBase): + def __init__(self): + if os.environ.get("PICTERRA_API_KEY", None) is None: + raise AuthInitError("PICTERRA_API_KEY environment variable not set") + + def __call__(self, r): + r.headers['X-Api-Key'] = os.environ.get("PICTERRA_API_KEY", None) + return r + + +class Oauth2Auth(AuthBase): + oauth_token: dict + + def __init__(self, base_url: str): + base_url = "http://100.81.123.76:8000" # TODO remove + cl = OAuthClient(CLIENT_ID, base_url) + try: + data = cl.start() + self.oauth_token = data + print(333, self.oauth_token) + print(f"{GREEN}Logged in at {base_url}.{RESET}") + except OAuthError as e: + raise AuthInitError(f"{RED}Error during login: '{e}'{RESET}") + + def __call__(self, r): + r.headers['Authorization'] = "Bearer " + self.oauth_token["access_token"] + return r + + class BaseAPIClient: """ Base class for Picterra API clients. This is subclassed for the different products we have. """ + base_url: str + sess: _RequestsSession def __init__( - self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10 + self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10, auth: Literal["apikey", "oauth2"] = "apikey", ): """ Args: api_url: the api's base url. This is different based on the Picterra product used and is typically defined by implementations of this client timeout: number of seconds before the request times out - max_retries: max attempts when ecountering gateway issues or throttles; see + max_retries: max attempts when encountering gateway issues or throttles; see retry_strategy comment below backoff_factor: factor used nin the backoff algorithm; see retry_strategy comment below + auth: TODO """ base_url = os.environ.get( "PICTERRA_BASE_URL", "https://app.picterra.ch/" ) - api_key = os.environ.get("PICTERRA_API_KEY", None) + if auth == "apikey": + auth = ApiKeyAuth() + elif auth == "oauth2": + auth = Oauth2Auth(base_url) + else: + raise RuntimeError("Invalid authentication method. Must be 'apikey' or 'oauth2'.") logger.info( - "Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.", + "Using base_url=%s, auth=%s; api_url=%s; %d max retries, %d backoff and %s timeout.", base_url, + auth, api_url, max_retries, backoff_factor, @@ -239,9 +282,6 @@ def __init__( adapter = HTTPAdapter(max_retries=retry_strategy) self.sess.mount("https://", adapter) self.sess.mount("http://", adapter) - # Authentication - if api_key is not None: - self.sess.headers.update({"X-Api-Key": api_key}) def _full_url(self, path: str, params: dict[str, Any] | None = None): url = urljoin(self.base_url, path) @@ -303,16 +343,3 @@ def get_operation_results(self, operation_id: str) -> dict[str, Any]: ) return resp.json()["results"] - def login(self): - base_url = os.environ.get( - "PICTERRA_BASE_URL", "https://app.picterra.ch/" - ) - base_url = "http://100.81.123.76:8000" # TODO remove - cl = OAuthClient(CLIENT_ID, base_url) - try: - data = cl.start() - print(data) - print(f"{GREEN}Logged in at {base_url}.{RESET}") - except OAuthError as e: - print(f"{RED}Error during login: '{e}'{RESET}") - sys.exit(1) From f84a0bda330b35325023160fb6bdb6b51a5df8fd Mon Sep 17 00:00:00 2001 From: Andrea Orlandi Date: Tue, 30 Sep 2025 09:33:41 +0200 Subject: [PATCH 4/4] v4 --- src/picterra/base_client.py | 18 +++++++++------ src/picterra/utils/oauth.py | 45 +++++++------------------------------ 2 files changed, 19 insertions(+), 44 deletions(-) diff --git a/src/picterra/base_client.py b/src/picterra/base_client.py index 8619a97..6c71922 100644 --- a/src/picterra/base_client.py +++ b/src/picterra/base_client.py @@ -73,6 +73,7 @@ class _RequestsSession(requests.Session): def __init__(self, *args, **kwargs): self.timeout = kwargs.pop("timeout") + #self.auth = kwargs.pop("auth") super().__init__(*args, **kwargs) self.headers.update( { @@ -82,6 +83,7 @@ def __init__(self, *args, **kwargs): def request(self, *args, **kwargs): kwargs.setdefault("timeout", self.timeout) + #kwargs.setdefault("auth", self.auth) return super().request(*args, **kwargs) @@ -196,12 +198,15 @@ class AuthInitError(Exception): class ApiKeyAuth(AuthBase): + api_key: str + def __init__(self): if os.environ.get("PICTERRA_API_KEY", None) is None: raise AuthInitError("PICTERRA_API_KEY environment variable not set") + self.api_key = os.environ.get("PICTERRA_API_KEY", None) def __call__(self, r): - r.headers['X-Api-Key'] = os.environ.get("PICTERRA_API_KEY", None) + r.headers['X-Api-Key'] = self.api_key return r @@ -209,18 +214,16 @@ class Oauth2Auth(AuthBase): oauth_token: dict def __init__(self, base_url: str): - base_url = "http://100.81.123.76:8000" # TODO remove cl = OAuthClient(CLIENT_ID, base_url) try: data = cl.start() self.oauth_token = data - print(333, self.oauth_token) print(f"{GREEN}Logged in at {base_url}.{RESET}") except OAuthError as e: - raise AuthInitError(f"{RED}Error during login: '{e}'{RESET}") + raise SystemExit(f"{RED}Error during login: '{e}'{RESET}") def __call__(self, r): - r.headers['Authorization'] = "Bearer " + self.oauth_token["access_token"] + r.headers['Authorization'] = "Bearer " + self.oauth_token["token"] return r @@ -250,9 +253,9 @@ def __init__( "PICTERRA_BASE_URL", "https://app.picterra.ch/" ) if auth == "apikey": - auth = ApiKeyAuth() + auth_class = ApiKeyAuth() elif auth == "oauth2": - auth = Oauth2Auth(base_url) + auth_class = Oauth2Auth(base_url) else: raise RuntimeError("Invalid authentication method. Must be 'apikey' or 'oauth2'.") logger.info( @@ -268,6 +271,7 @@ def __init__( # Create the session with a default timeout (30 sec), that we can then # override on a per-endpoint basis (will be disabled for file uploads and downloads) self.sess = _RequestsSession(timeout=timeout) + self.sess.auth = auth_class # Retry: we set the HTTP codes for our throttle (429) plus possible gateway problems (50*), # and for polling methods (GET), as non-idempotent ones should be addressed via idempotency # key mechanism; given the algorithm is { * (2 **}, and we diff --git a/src/picterra/utils/oauth.py b/src/picterra/utils/oauth.py index 891ecba..2770983 100644 --- a/src/picterra/utils/oauth.py +++ b/src/picterra/utils/oauth.py @@ -18,7 +18,7 @@ class OAuthError(Exception): pass -SCOPE = "scan" +SCOPE = "scan" # TODO based on oauth config on platform (needs change) # potential port range to be used to run local server # to handle authorization code callback @@ -36,13 +36,6 @@ class OAuthError(Exception): LOCAL_SERVER_ADDRESS = "127.0.0.1" -def long_running_function_mp(secs): - print(f"Timout is {secs}s") - for _ in range(secs): # Simulate a long running task - time.sleep(1) - print("Tick") - - def _generate_pkce_pair(): code_verifier = ''.join( random.choice(string.ascii_uppercase + string.digits) for _ in range( @@ -133,13 +126,12 @@ def authorize_uri(self) -> str: def token_uri(self): return urljoin(self.picterra_uri, "o/token/") - def _get_access_token(self, code: str) -> None: - print(f"Getting access token with code {code[:4]} from {self.token_uri}...") + def _get_access_token(self, code: str) -> None: + logger.debug(f"Getting access token with code {code[:4]} from {self.token_uri}...") resp = requests.post( self.token_uri, data={ "client_id": self._client_id, - #"client_secret": "pbkdf2_sha256$870000$V13WjHqsdkeKiEVt1sEzWR$NolL9+bb5Mh7S/YyKxQhDbWZbRZw60w/6xdrOXqjO8Y=", "code": code, "code_verifier": self.pkce_pair[0], "redirect_uri": self.redirect_uri, @@ -167,7 +159,7 @@ def _get_code(self, uri: str) -> str: def start(self): self._spawn_server() - print(f"Login at {self.authorize_uri}") + logger.debug(f"Login at {self.authorize_uri}") webbrowser.open_new_tab(self.authorize_uri) self._wait_for_callback() # self.local_server.shutdown() @@ -176,10 +168,6 @@ def start(self): "lifetime_s": self._lifetime_s, } - def stop(self, error_message: str) -> None: - #self.local_server.shutdown() - raise OAuthError(f"Error after login: {error_message}.") - def process_callback(self, callback_url: str) -> None: """ This function runs within the request handler do_GET method. @@ -190,7 +178,7 @@ def process_callback(self, callback_url: str) -> None: - Save the token in configuration Any error during this process will raise a OAuthError """ - print(f"Getting token from {callback_url[:7]}...") + logger.debug(f"Getting token from {callback_url[:7]}...") authorization_code = self._get_code(callback_url) self._get_access_token(authorization_code) @@ -204,7 +192,7 @@ def _spawn_server(self) -> None: ) self.local_server.timeout = LOCAL_SERVER_MAX_WAIT_S self._port = port - print("Started local server on port %d" % port) + logger.debug("Started local server on port %d" % port) break except OSError: continue @@ -220,10 +208,9 @@ def _wait_for_callback(self) -> None: """ assert self.local_server is not None try: - print("Waiting for callback...") + logger.debug("Waiting for callback...") start_time = time.time() while self._handler_wrapper.complete is False: - print(449580593854096) # Wait for callback on localserver including an authorization code # any matching request will get processed by the request handler and # the `process_callback` function @@ -232,7 +219,7 @@ def _wait_for_callback(self) -> None: raise OAuthError("Timeout waiting for callback.") if self._handler_wrapper.error_message is not None: raise OAuthError(self._handler_wrapper.error_message) - print("Callback received.") + logger.debug("Callback received.") except KeyboardInterrupt: raise OAuthError("User stopped login process.") if self._handler_wrapper.error_message is not None: @@ -273,33 +260,17 @@ def do_GET(self_) -> None: if parsed_url.path == "/": error_string = _get_error_param(parsed_url) if error_string is not None: - print(455676778676) self_._end_request(200) - # self.oauth_client.local_server.shutdown() self.error_message = error_string else: try: - print(4354456) self.oauth_client.process_callback(callback_url) - print(32222) self_._end_request(200) except Exception as error: - print(77777, error) self.error_message = str(error) self_._end_request(400) - # else: - # self_._end_request( # TODO ??? - # 301, - # urljoin( - # self.oauth_client.dashboard_url, "authenticated" - # ), - # ) - print(9999999999999999) - # indicate to the server to stop self.complete = True - #self_._end_request(200) else: - print(657889) self_._end_request(404) def _end_request(self_, status_code: int) -> None: