diff --git a/.env.example b/.env.example index e856b21..58e5515 100644 --- a/.env.example +++ b/.env.example @@ -98,6 +98,34 @@ QWEN_CODE_OAUTH_1="" # Path to your iFlow credential file (e.g., ~/.iflow/oauth_creds.json). IFLOW_OAUTH_1="" +# --- GitHub Copilot --- +# GitHub Copilot uses Device Flow OAuth authentication. +# After first-time setup, the proxy stores credentials in the 'oauth_creds/' directory. +# You can also pre-configure credentials via environment variables: +# +# COPILOT_GITHUB_TOKEN - Your GitHub OAuth token (long-lived, from Device Flow) +# COPILOT_ENTERPRISE_URL - Optional: GitHub Enterprise URL (e.g., company.ghe.com) +# +# For multiple Copilot accounts, use numbered variables: +# COPILOT_1_GITHUB_TOKEN="ghu_..." +# COPILOT_2_GITHUB_TOKEN="ghu_..." +COPILOT_GITHUB_TOKEN="" +COPILOT_ENTERPRISE_URL="" + +# --- Copilot X-Initiator Header Control --- +# Controls the X-Initiator header behavior (affects Copilot's response style): +# - COPILOT_FORCE_AGENT_HEADER: Always use "agent" mode (default: false) +# - COPILOT_AGENT_PERCENTAGE: For first messages, % chance of "agent" (0-100, default: 100) +# Set to 0 for always "user", 100 for always "agent", or a value in between for random. +# Based on: https://github.com/Tarquinen/dotfiles/tree/main/.config/opencode/plugin/copilot-force-agent-header +COPILOT_FORCE_AGENT_HEADER=false +COPILOT_AGENT_PERCENTAGE=100 + +# --- Copilot Available Models --- +# Comma-separated list of Copilot models to expose. Leave empty for defaults. 
+# Default models: gpt-4o, gpt-4.1, gpt-4.1-mini, claude-3.5-sonnet, claude-sonnet-4, o3-mini, o1, gemini-2.0-flash-001 +COPILOT_MODELS="" + # ------------------------------------------------------------------------------ # | [ADVANCED] Provider-Specific Settings | diff --git a/src/rotator_library/credential_manager.py b/src/rotator_library/credential_manager.py index 16be41c..c30cf92 100644 --- a/src/rotator_library/credential_manager.py +++ b/src/rotator_library/credential_manager.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Dict, List, Optional, Set -lib_logger = logging.getLogger('rotator_library') +lib_logger = logging.getLogger("rotator_library") OAUTH_BASE_DIR = Path.cwd() / "oauth_creds" OAUTH_BASE_DIR.mkdir(exist_ok=True) @@ -16,6 +16,7 @@ "qwen_code": Path.home() / ".qwen", "iflow": Path.home() / ".iflow", "antigravity": Path.home() / ".antigravity", + "copilot": Path.home() / ".copilot", # Add other providers like 'claude' here if they have a standard CLI path } @@ -26,6 +27,7 @@ "antigravity": "ANTIGRAVITY", "qwen_code": "QWEN_CODE", "iflow": "IFLOW", + "copilot": "COPILOT", } @@ -33,38 +35,39 @@ class CredentialManager: """ Discovers OAuth credential files from standard locations, copies them locally, and updates the configuration to use the local paths. - + Also discovers environment variable-based OAuth credentials for stateless deployments. Supports two env var formats: - + 1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN 2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc. - + When env-based credentials are detected, virtual paths like "env://provider/1" are created. """ + def __init__(self, env_vars: Dict[str, str]): self.env_vars = env_vars def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]: """ Discover OAuth credentials defined via environment variables. - + Supports two formats: 1. 
Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN 2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc. - + Returns: Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1") """ env_credentials: Dict[str, Set[str]] = {} - + for provider, env_prefix in ENV_OAUTH_PROVIDERS.items(): found_indices: Set[str] = set() - + # Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern) # Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc. numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$") - + for key in self.env_vars.keys(): match = numbered_pattern.match(key) if match: @@ -73,28 +76,34 @@ def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]: refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN" if refresh_key in self.env_vars and self.env_vars[refresh_key]: found_indices.add(index) - + # Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern) # Only use this if no numbered credentials exist if not found_indices: access_key = f"{env_prefix}_ACCESS_TOKEN" refresh_key = f"{env_prefix}_REFRESH_TOKEN" - if (access_key in self.env_vars and self.env_vars[access_key] and - refresh_key in self.env_vars and self.env_vars[refresh_key]): + if ( + access_key in self.env_vars + and self.env_vars[access_key] + and refresh_key in self.env_vars + and self.env_vars[refresh_key] + ): # Use "0" as the index for legacy single credential found_indices.add("0") - + if found_indices: env_credentials[provider] = found_indices - lib_logger.info(f"Found {len(found_indices)} env-based credential(s) for {provider}") - + lib_logger.info( + f"Found {len(found_indices)} env-based credential(s) for {provider}" + ) + # Convert to virtual paths result: Dict[str, List[str]] = {} for provider, indices in env_credentials.items(): # Sort indices numerically for consistent ordering sorted_indices = sorted(indices, key=lambda x: int(x)) result[provider] = 
[f"env://{provider}/{idx}" for idx in sorted_indices] - + return result def discover_and_prepare(self) -> Dict[str, List[str]]: @@ -105,7 +114,9 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: # These take priority for stateless deployments env_oauth_creds = self._discover_env_oauth_credentials() for provider, virtual_paths in env_oauth_creds.items(): - lib_logger.info(f"Using {len(virtual_paths)} env-based credential(s) for {provider}") + lib_logger.info( + f"Using {len(virtual_paths)} env-based credential(s) for {provider}" + ) final_config[provider] = virtual_paths # Extract OAuth file paths from environment variables @@ -115,21 +126,29 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: provider = key.split("_OAUTH_")[0].lower() if provider not in env_oauth_paths: env_oauth_paths[provider] = [] - if value: # Only consider non-empty values + if value: # Only consider non-empty values env_oauth_paths[provider].append(value) # PHASE 2: Discover file-based OAuth credentials for provider, default_dir in DEFAULT_OAUTH_DIRS.items(): # Skip if already discovered from environment variables if provider in final_config: - lib_logger.debug(f"Skipping file discovery for {provider} - using env-based credentials") + lib_logger.debug( + f"Skipping file discovery for {provider} - using env-based credentials" + ) continue - + # Check for existing local credentials first. If found, use them and skip discovery. - local_provider_creds = sorted(list(OAUTH_BASE_DIR.glob(f"{provider}_oauth_*.json"))) + local_provider_creds = sorted( + list(OAUTH_BASE_DIR.glob(f"{provider}_oauth_*.json")) + ) if local_provider_creds: - lib_logger.info(f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery.") - final_config[provider] = [str(p.resolve()) for p in local_provider_creds] + lib_logger.info( + f"Found {len(local_provider_creds)} existing local credential(s) for {provider}. Skipping discovery." 
+ ) + final_config[provider] = [ + str(p.resolve()) for p in local_provider_creds + ] continue # If no local credentials exist, proceed with a one-time discovery and copy. @@ -140,13 +159,13 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: path = Path(path_str).expanduser() if path.exists(): discovered_paths.add(path) - + # 2. If no overrides are provided via .env, scan the default directory # [MODIFIED] This logic is now disabled to prefer local-first credential management. # if not discovered_paths and default_dir.exists(): # for json_file in default_dir.glob('*.json'): # discovered_paths.add(json_file) - + if not discovered_paths: lib_logger.debug(f"No credential files found for provider: {provider}") continue @@ -161,13 +180,19 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: try: # Since we've established no local files exist, we can copy directly. shutil.copy(source_path, local_path) - lib_logger.info(f"Copied '{source_path.name}' to local pool at '{local_path}'.") + lib_logger.info( + f"Copied '{source_path.name}' to local pool at '{local_path}'." 
+ ) prepared_paths.append(str(local_path.resolve())) except Exception as e: - lib_logger.error(f"Failed to process OAuth file from '{source_path}': {e}") - + lib_logger.error( + f"Failed to process OAuth file from '{source_path}': {e}" + ) + if prepared_paths: - lib_logger.info(f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}") + lib_logger.info( + f"Discovered and prepared {len(prepared_paths)} credential(s) for provider: {provider}" + ) final_config[provider] = prepared_paths lib_logger.info("OAuth credential discovery complete.") diff --git a/src/rotator_library/provider_factory.py b/src/rotator_library/provider_factory.py index f13d16a..30ea06a 100644 --- a/src/rotator_library/provider_factory.py +++ b/src/rotator_library/provider_factory.py @@ -4,14 +4,17 @@ from .providers.qwen_auth_base import QwenAuthBase from .providers.iflow_auth_base import IFlowAuthBase from .providers.antigravity_auth_base import AntigravityAuthBase +from .providers.copilot_auth_base import CopilotAuthBase PROVIDER_MAP = { "gemini_cli": GeminiAuthBase, "qwen_code": QwenAuthBase, "iflow": IFlowAuthBase, "antigravity": AntigravityAuthBase, + "copilot": CopilotAuthBase, } + def get_provider_auth_class(provider_name: str): """ Returns the authentication class for a given provider. @@ -21,8 +24,9 @@ def get_provider_auth_class(provider_name: str): raise ValueError(f"Unknown provider: {provider_name}") return provider_class + def get_available_providers(): """ Returns a list of available provider names. 
""" - return list(PROVIDER_MAP.keys()) \ No newline at end of file + return list(PROVIDER_MAP.keys()) diff --git a/src/rotator_library/providers/__init__.py b/src/rotator_library/providers/__init__.py index c6bee07..ec0679a 100644 --- a/src/rotator_library/providers/__init__.py +++ b/src/rotator_library/providers/__init__.py @@ -89,7 +89,10 @@ def _register_providers(): provider_name = "nvidia_nim" PROVIDER_PLUGINS[provider_name] = attribute import logging - logging.getLogger('rotator_library').debug(f"Registered provider: {provider_name}") + + logging.getLogger("rotator_library").debug( + f"Registered provider: {provider_name}" + ) # Then, create dynamic plugins for custom OpenAI-compatible providers # Use environment variables directly (load_dotenv already called in main.py) @@ -114,6 +117,7 @@ def _register_providers(): "qwen_code", "gemini_cli", "antigravity", + "copilot", ]: continue @@ -129,7 +133,10 @@ def __init__(self): plugin_class = create_plugin_class(provider_name) PROVIDER_PLUGINS[provider_name] = plugin_class import logging - logging.getLogger('rotator_library').debug(f"Registered dynamic provider: {provider_name}") + + logging.getLogger("rotator_library").debug( + f"Registered dynamic provider: {provider_name}" + ) # Discover and register providers when the package is imported diff --git a/src/rotator_library/providers/copilot_auth_base.py b/src/rotator_library/providers/copilot_auth_base.py new file mode 100644 index 0000000..0b73926 --- /dev/null +++ b/src/rotator_library/providers/copilot_auth_base.py @@ -0,0 +1,631 @@ +# src/rotator_library/providers/copilot_auth_base.py +""" +GitHub Copilot OAuth2 authentication implementation using Device Flow. + +This is fundamentally different from Google OAuth providers: +- Uses GitHub's Device Flow instead of Authorization Code Flow +- Requires two-step token exchange: + 1. GitHub OAuth token (long-lived, used as "refresh token") + 2. 
class CopilotAuthBase:
    """
    GitHub Copilot OAuth2 authentication using Device Flow.

    This provider uses GitHub's Device Authorization Grant flow, which is
    more suitable for CLI applications than the web-based Authorization Code flow.

    Two tokens are involved:
    - the GitHub OAuth token, stored under the ``refresh_token`` key and used
      to request Copilot API tokens;
    - the Copilot API token, stored under ``access_token`` together with
      ``expiry_date`` (milliseconds since the epoch).

    Credentials are addressed by a "path": either a JSON file path or a
    virtual ``env://provider/N`` path backed by environment variables.
    ``initialize_token`` additionally accepts an in-memory credential dict,
    in which case nothing is written to disk.

    Subclasses may override:
    - ENV_PREFIX: Prefix for environment variables (default: "COPILOT")
    - REFRESH_EXPIRY_BUFFER_SECONDS: Time buffer before token expiry

    Supports both github.com and GitHub Enterprise deployments.
    """

    # GitHub Copilot OAuth Client ID (from VS Code Copilot extension)
    CLIENT_ID = "Iv1.b507a08c87ecfe98"

    # Headers that mimic the official Copilot client
    COPILOT_HEADERS = {
        "User-Agent": "GitHubCopilotChat/0.32.4",
        "Editor-Version": "vscode/1.105.1",
        "Editor-Plugin-Version": "copilot-chat/0.32.4",
        "Copilot-Integration-Id": "vscode-chat",
    }

    # Environment variable prefix
    ENV_PREFIX = "COPILOT"

    # Token refresh buffer (default: 5 minutes for short-lived Copilot tokens)
    REFRESH_EXPIRY_BUFFER_SECONDS = 5 * 60

    # Seconds added to the polling interval when the server answers
    # "slow_down" and does not supply a new interval (RFC 8628, section 3.5).
    SLOW_DOWN_INCREMENT_SECONDS = 5

    def __init__(self):
        # Loaded credentials keyed by credential path.
        self._credentials_cache: Dict[str, Dict[str, Any]] = {}
        # One asyncio.Lock per credential path; _locks_lock guards the dict.
        self._refresh_locks: Dict[str, asyncio.Lock] = {}
        self._locks_lock = asyncio.Lock()

        # [BACKOFF TRACKING] Track consecutive failures per credential.
        # NOTE(review): not read by any method of this class; presumably
        # reserved for subclasses or future use - confirm.
        self._refresh_failures: Dict[str, int] = {}
        self._next_refresh_after: Dict[str, float] = {}

        # [QUEUE SYSTEM] Sequential refresh processing.
        # NOTE(review): only _unavailable_credentials is read here
        # (by is_credential_available); the rest is unused in this class.
        self._refresh_queue: asyncio.Queue = asyncio.Queue()
        self._queued_credentials: set = set()
        self._unavailable_credentials: set = set()
        self._queue_tracking_lock = asyncio.Lock()
        self._queue_processor_task: Optional[asyncio.Task] = None

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------

    def _normalize_domain(self, url: str) -> str:
        """Strip whitespace, the URL scheme and trailing slashes from *url*."""
        return (
            url.strip().replace("https://", "").replace("http://", "").rstrip("/")
        )

    def _domain_for(self, enterprise_url: Optional[str]) -> str:
        """Return the normalized enterprise domain, or "github.com" if unset."""
        return self._normalize_domain(enterprise_url or "") or "github.com"

    def _get_urls(self, domain: str = "github.com") -> Dict[str, str]:
        """Get GitHub OAuth URLs for the specified domain."""
        return {
            "DEVICE_CODE_URL": f"https://{domain}/login/device/code",
            "ACCESS_TOKEN_URL": f"https://{domain}/login/oauth/access_token",
            "COPILOT_API_KEY_URL": f"https://api.{domain}/copilot_internal/v2/token",
        }

    @staticmethod
    def _display_name(path: Optional[str]) -> str:
        """Return a log-friendly name for a credential path (None = in-memory)."""
        return Path(path).name if path else "in-memory object"

    @staticmethod
    def _pick_user_identity(user_info: Dict[str, Any]) -> str:
        """
        Choose a display identity from a GitHub ``/user`` response.

        Falls back from ``email`` to ``login`` to "unknown". Uses ``or`` rather
        than ``dict.get`` defaults so that an explicit null email falls back
        as well.
        """
        return user_info.get("email") or user_info.get("login") or "unknown"

    def _parse_env_credential_path(self, path: str) -> Optional[str]:
        """
        Parse a virtual env:// path and return the credential index.

        Supported formats:
        - "env://provider/0" - Legacy single credential
        - "env://provider/1" - First numbered credential

        Returns None for anything that is not an env:// path, and "0" when the
        index segment is missing or empty.
        """
        if not path.startswith("env://"):
            return None

        parts = path[6:].split("/")
        if len(parts) >= 2:
            # "env://provider/" yields an empty segment; treat it as legacy.
            return parts[1] or "0"
        return "0"

    # ------------------------------------------------------------------
    # Loading / saving
    # ------------------------------------------------------------------

    def _load_from_env(
        self, credential_index: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Load OAuth credentials from environment variables.

        For Copilot, we need:
        - {PREFIX}_GITHUB_TOKEN or {PREFIX}_{N}_GITHUB_TOKEN (long-lived GitHub OAuth token)
          ({PREFIX}_REFRESH_TOKEN is accepted as a legacy alias)
        - Optionally: {PREFIX}_ENTERPRISE_URL for GitHub Enterprise

        Returns None when no GitHub token variable is set. The Copilot API
        token is left empty here and fetched on demand.
        """
        if credential_index and credential_index != "0":
            prefix = f"{self.ENV_PREFIX}_{credential_index}"
            default_email = f"env-user-{credential_index}"
        else:
            prefix = self.ENV_PREFIX
            default_email = "env-user"

        # For Copilot, the "refresh_token" is the GitHub OAuth token
        github_token = os.getenv(f"{prefix}_GITHUB_TOKEN")
        if not github_token:
            # Also check legacy naming
            github_token = os.getenv(f"{prefix}_REFRESH_TOKEN")

        if not github_token:
            return None

        lib_logger.debug(f"Loading {prefix} credentials from environment variables")

        # Check for enterprise URL
        enterprise_url = os.getenv(f"{prefix}_ENTERPRISE_URL", "")

        creds = {
            "refresh_token": github_token,  # GitHub OAuth token used as refresh token
            "access_token": "",  # Copilot API token (fetched on demand)
            "expiry_date": 0,  # Will be set when Copilot token is fetched
            "enterprise_url": enterprise_url,
            "_proxy_metadata": {
                "email": os.getenv(f"{prefix}_EMAIL", default_email),
                "last_check_timestamp": time.time(),
                "loaded_from_env": True,
                "env_credential_index": credential_index or "0",
            },
        }

        return creds

    async def _load_credentials(self, path: str) -> Dict[str, Any]:
        """
        Load credentials from cache, environment, or file.

        Raises:
            IOError: if an env:// path has no matching variables, or the
                credential file is missing or unreadable.
        """
        if path in self._credentials_cache:
            return self._credentials_cache[path]

        async with await self._get_lock(path):
            # Re-check: another task may have loaded it while we waited.
            if path in self._credentials_cache:
                return self._credentials_cache[path]

            # Check for virtual env:// path
            credential_index = self._parse_env_credential_path(path)
            if credential_index is not None:
                env_creds = self._load_from_env(credential_index)
                if env_creds:
                    lib_logger.info(
                        f"Using {self.ENV_PREFIX} credentials from environment variables (index: {credential_index})"
                    )
                    self._credentials_cache[path] = env_creds
                    return env_creds
                else:
                    raise IOError(
                        f"Environment variables for {self.ENV_PREFIX} credential index {credential_index} not found"
                    )

            # Try loading from legacy env vars
            env_creds = self._load_from_env()
            if env_creds:
                lib_logger.info(
                    f"Using {self.ENV_PREFIX} credentials from environment variables"
                )
                self._credentials_cache[path] = env_creds
                return env_creds

            # Fall back to file-based loading
            try:
                lib_logger.debug(
                    f"Loading {self.ENV_PREFIX} credentials from file: {path}"
                )
                with open(path, "r") as f:
                    creds = json.load(f)
                self._credentials_cache[path] = creds
                return creds
            except FileNotFoundError as e:
                raise IOError(
                    f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'"
                ) from e
            except Exception as e:
                raise IOError(
                    f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}"
                ) from e

    async def _save_credentials(self, path: str, creds: Dict[str, Any]):
        """
        Save credentials to file with atomic write.

        Credentials flagged ``loaded_from_env`` are only cached, never written.
        Otherwise the JSON is written to a temp file in the target directory
        and swapped into place with ``os.replace``. Any failure is logged, the
        temp file is cleaned up, and the exception is re-raised.
        """
        if creds.get("_proxy_metadata", {}).get("loaded_from_env"):
            lib_logger.debug("Credentials loaded from env, skipping file save")
            self._credentials_cache[path] = creds
            return

        parent_dir = os.path.dirname(os.path.abspath(path))
        os.makedirs(parent_dir, exist_ok=True)

        tmp_fd = None
        tmp_path = None
        try:
            tmp_fd, tmp_path = tempfile.mkstemp(
                dir=parent_dir, prefix=".tmp_", suffix=".json", text=True
            )

            with os.fdopen(tmp_fd, "w") as f:
                # The file object now owns the descriptor and will close it.
                tmp_fd = None
                json.dump(creds, f, indent=2)

            try:
                os.chmod(tmp_path, 0o600)
            except (OSError, AttributeError):
                pass  # best effort; not all platforms support chmod semantics

            # Same directory as the target, so this is a single rename that
            # overwrites any existing file.
            os.replace(tmp_path, path)
            tmp_path = None

            self._credentials_cache[path] = creds
            lib_logger.debug(
                f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}' (atomic write)."
            )

        except Exception as e:
            lib_logger.error(
                f"Failed to save updated {self.ENV_PREFIX} OAuth credentials to '{path}': {e}"
            )
            if tmp_fd is not None:
                try:
                    os.close(tmp_fd)
                except OSError:
                    pass
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
            raise

    def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
        """
        Check if the Copilot API token is expired.

        A token counts as expired when it expires within
        REFRESH_EXPIRY_BUFFER_SECONDS, or when ``expiry_date`` is missing,
        zero, or not a number.
        """
        expiry_timestamp = creds.get("expiry_date", 0)
        if isinstance(expiry_timestamp, (int, float)) and expiry_timestamp > 0:
            # expiry_date is stored in milliseconds (like gemini-cli format)
            return (expiry_timestamp / 1000) < (
                time.time() + self.REFRESH_EXPIRY_BUFFER_SECONDS
            )
        # If no expiry or zero, token is expired
        return True

    # ------------------------------------------------------------------
    # Copilot API token refresh
    # ------------------------------------------------------------------

    async def _refresh_copilot_token(
        self,
        path: Optional[str],
        creds: Dict[str, Any],
        force: bool = False,
        allow_reauth: bool = True,
    ) -> Dict[str, Any]:
        """
        Refresh the Copilot API token using the GitHub OAuth token.

        Args:
            path: Credential path, or None for in-memory credentials (then
                nothing is saved).
            creds: Credential dict; updated in place on success.
            force: Refresh even if the cached token is still valid.
            allow_reauth: If True, an HTTP 401 starts the interactive Device
                Flow. If False, the 401 is raised as ``httpx.HTTPStatusError``.

        Raises:
            ValueError: no GitHub token in *creds*, no token in the response,
                or re-authentication failed.
            httpx.HTTPStatusError / httpx.RequestError: HTTP or network failure.
        """
        display_name = self._display_name(path)

        async with await self._get_lock(path):
            # Skip if token is still valid (unless forced)
            cached_creds = self._credentials_cache.get(path, creds)
            if not force and not self._is_token_expired(cached_creds):
                return cached_creds

            github_token = creds.get("refresh_token")
            if not github_token:
                raise ValueError(
                    "No GitHub OAuth token (refresh_token) found in credentials."
                )

            urls = self._get_urls(self._domain_for(creds.get("enterprise_url")))

            lib_logger.debug(
                f"Refreshing {self.ENV_PREFIX} Copilot API token for '{display_name}' (forced: {force})..."
            )

            async with httpx.AsyncClient() as client:
                try:
                    response = await client.get(
                        urls["COPILOT_API_KEY_URL"],
                        headers={
                            "Accept": "application/json",
                            "Authorization": f"Bearer {github_token}",
                            **self.COPILOT_HEADERS,
                        },
                        timeout=30.0,
                    )

                    if response.status_code == 401 and allow_reauth:
                        lib_logger.warning(
                            f"GitHub token invalid for '{display_name}' (HTTP 401). "
                            f"Token may have been revoked. Starting re-authentication..."
                        )
                        # Fall through to the Device Flow below, after the
                        # per-path lock has been released.
                    else:
                        response.raise_for_status()
                        token_data = response.json()

                        copilot_token = token_data.get("token")
                        if not copilot_token:
                            raise ValueError(
                                "Copilot token endpoint response did not contain a token."
                            )

                        # Update credentials with new Copilot API token
                        creds["access_token"] = copilot_token
                        # expires_at is Unix timestamp in seconds
                        expires_at = token_data.get("expires_at") or 0
                        creds["expiry_date"] = expires_at * 1000  # to milliseconds

                        # Update metadata
                        creds.setdefault("_proxy_metadata", {})[
                            "last_check_timestamp"
                        ] = time.time()

                        if path:
                            await self._save_credentials(path, creds)
                        lib_logger.debug(
                            f"Successfully refreshed {self.ENV_PREFIX} Copilot API token for '{display_name}'."
                        )
                        return creds

                except httpx.HTTPStatusError as e:
                    lib_logger.error(
                        f"Failed to refresh Copilot token (HTTP {e.response.status_code}): {e}"
                    )
                    raise
                except httpx.RequestError as e:
                    lib_logger.error(f"Network error refreshing Copilot token: {e}")
                    raise

        # Only reached on HTTP 401 with allow_reauth=True. This must run
        # outside the lock: the Device Flow ends by calling this method again,
        # and asyncio.Lock is not reentrant.
        try:
            return await self._run_device_flow(path, creds, display_name)
        except Exception as e:
            raise ValueError(
                f"Failed to initialize {self.ENV_PREFIX} OAuth for '{path}': {e}"
            ) from e

    async def _get_lock(self, path: Optional[str]) -> asyncio.Lock:
        """Get or create a lock for the given credential path."""
        async with self._locks_lock:
            if path not in self._refresh_locks:
                self._refresh_locks[path] = asyncio.Lock()
            return self._refresh_locks[path]

    async def proactively_refresh(self, credential_path: str):
        """Proactively refresh a credential if it's nearing expiry."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            await self._refresh_copilot_token(credential_path, creds)

    def is_credential_available(self, path: str) -> bool:
        """Check if a credential is available for rotation."""
        return path not in self._unavailable_credentials

    # ------------------------------------------------------------------
    # Device Flow
    # ------------------------------------------------------------------

    async def initialize_token(
        self, creds_or_path: Union[Dict[str, Any], str]
    ) -> Dict[str, Any]:
        """
        Initialize or re-authenticate GitHub Copilot credentials.

        If a GitHub OAuth token is present, a non-interactive Copilot token
        refresh is tried first. The interactive Device Flow runs only when the
        GitHub token is missing or that refresh fails.

        Args:
            creds_or_path: A credential path, or an in-memory credential dict
                (in which case nothing is written to disk).

        Returns:
            The valid (possibly refreshed or newly created) credential dict.

        Raises:
            ValueError: wraps any failure during loading, refresh or login.
        """
        path = creds_or_path if isinstance(creds_or_path, str) else None

        if isinstance(creds_or_path, dict):
            display_name = creds_or_path.get("_proxy_metadata", {}).get(
                "display_name", "in-memory object"
            )
        else:
            display_name = self._display_name(path)

        lib_logger.debug(
            f"Initializing {self.ENV_PREFIX} token for '{display_name}'..."
        )

        try:
            creds = (
                await self._load_credentials(creds_or_path) if path else creds_or_path
            )
            needs_auth = False
            reason = ""

            if not creds.get("refresh_token"):
                needs_auth = True
                reason = "GitHub OAuth token is missing"
            elif self._is_token_expired(creds):
                # Try to refresh the Copilot API token. allow_reauth=False so
                # that a 401 surfaces here and is handled by the login below.
                try:
                    return await self._refresh_copilot_token(
                        path, creds, allow_reauth=False
                    )
                except Exception as e:
                    lib_logger.warning(
                        f"Automatic token refresh for '{display_name}' failed: {e}. "
                        f"Proceeding to interactive login."
                    )
                    needs_auth = True
                    reason = "Token refresh failed"

            if not needs_auth:
                lib_logger.info(
                    f"{self.ENV_PREFIX} OAuth token at '{display_name}' is valid."
                )
                return creds

            lib_logger.warning(
                f"{self.ENV_PREFIX} OAuth token for '{display_name}' needs setup: {reason}."
            )
            return await self._run_device_flow(path, creds, display_name)

        except Exception as e:
            raise ValueError(
                f"Failed to initialize {self.ENV_PREFIX} OAuth for '{path}': {e}"
            ) from e

    async def _run_device_flow(
        self, path: Optional[str], creds: Dict[str, Any], display_name: str
    ) -> Dict[str, Any]:
        """
        Run the interactive Device Flow and return fresh credentials.

        Steps: request a device code, print the verification URL and user code,
        poll until authorized, build new credentials, then fetch the first
        Copilot API token. Must be called without the per-path lock held.
        """
        # Check for enterprise URL in existing creds or environment
        enterprise_url = creds.get("enterprise_url") or os.getenv(
            f"{self.ENV_PREFIX}_ENTERPRISE_URL", ""
        )
        domain = self._domain_for(enterprise_url)
        urls = self._get_urls(domain)
        oauth_headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "User-Agent": "GitHubCopilotChat/0.35.0",
        }

        is_headless = is_headless_environment()

        async with httpx.AsyncClient() as client:
            # Step 1: Request device code
            device_response = await client.post(
                urls["DEVICE_CODE_URL"],
                headers=oauth_headers,
                json={
                    "client_id": self.CLIENT_ID,
                    "scope": "read:user",
                },
                timeout=30.0,
            )

            if not device_response.is_success:
                raise Exception(
                    f"Failed to initiate device authorization: {device_response.text}"
                )

            device_data = device_response.json()
            user_code = device_data.get("user_code", "")
            verification_uri = device_data.get("verification_uri", "")
            device_code = device_data.get("device_code", "")
            # Guard against a missing/zero interval so polling never spins.
            interval = max(1, int(device_data.get("interval") or 5))
            expires_in = int(device_data.get("expires_in") or 900)

            # Step 2: Display instructions
            if is_headless:
                auth_panel_text = Text.from_markup(
                    "Running in headless environment (no GUI detected).\n"
                    "Please open the URL below in a browser on another machine to authorize:\n"
                )
            else:
                auth_panel_text = Text.from_markup(
                    "Please visit the URL below and enter the code to authorize:\n"
                )

            console.print(
                Panel(
                    auth_panel_text,
                    title=f"{self.ENV_PREFIX} OAuth Setup for [bold yellow]{display_name}[/bold yellow]",
                    style="bold blue",
                )
            )
            console.print(f"[bold]URL:[/bold] {verification_uri}")
            console.print(
                f"[bold]Code:[/bold] [bold green]{user_code}[/bold green]\n"
            )

            # Step 3: Poll for authorization
            with console.status(
                f"[bold green]Waiting for you to complete authentication (code: {user_code})...[/bold green]",
                spinner="dots",
            ):
                github_token = await self._poll_for_github_token(
                    client,
                    urls["ACCESS_TOKEN_URL"],
                    oauth_headers,
                    device_code,
                    interval,
                    expires_in,
                )

            # Step 4: Build new credentials
            metadata: Dict[str, Any] = {"last_check_timestamp": time.time()}
            previous_metadata = creds.get("_proxy_metadata") or {}
            if previous_metadata.get("loaded_from_env"):
                # Keep env-backed credentials flagged so _save_credentials
                # never tries to write them to an "env://..." file path.
                metadata["loaded_from_env"] = True
                metadata["env_credential_index"] = previous_metadata.get(
                    "env_credential_index", "0"
                )

            new_creds = {
                "refresh_token": github_token,
                "access_token": "",  # Filled by the refresh call below
                "expiry_date": 0,
                "enterprise_url": enterprise_url,
                "_proxy_metadata": metadata,
            }

            # Fetch user info (best effort)
            try:
                user_response = await client.get(
                    f"https://api.{domain}/user",
                    headers={"Authorization": f"Bearer {github_token}"},
                    timeout=10.0,
                )
                if user_response.is_success:
                    metadata["email"] = self._pick_user_identity(
                        user_response.json()
                    )
            except Exception as e:
                lib_logger.warning(f"Failed to fetch user info: {e}")
                metadata["email"] = "unknown"

        if path:
            await self._save_credentials(path, new_creds)

        lib_logger.info(
            f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'."
        )

        # Now fetch the Copilot API token. allow_reauth=False: a 401 right
        # after a successful login is an error, not a reason to log in again.
        return await self._refresh_copilot_token(
            path, new_creds, force=True, allow_reauth=False
        )

    async def _poll_for_github_token(
        self,
        client: Any,
        token_url: str,
        headers: Dict[str, str],
        device_code: str,
        interval: int,
        expires_in: int,
    ) -> str:
        """
        Poll the token endpoint until the user authorizes the device code.

        "authorization_pending" and non-success HTTP responses are retried.
        "slow_down" increases the polling interval and is retried. Any other
        error, or reaching *expires_in* seconds, raises ``Exception``.
        """
        deadline = time.monotonic() + expires_in
        while time.monotonic() < deadline:
            await asyncio.sleep(interval)

            token_response = await client.post(
                token_url,
                headers=headers,
                json={
                    "client_id": self.CLIENT_ID,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
                timeout=30.0,
            )

            if not token_response.is_success:
                continue

            token_data = token_response.json()

            if "access_token" in token_data:
                return token_data["access_token"]

            error = token_data.get("error")
            if error == "authorization_pending":
                continue
            if error == "slow_down":
                # Use the server-provided interval when present, else back off.
                interval = int(
                    token_data.get("interval")
                    or interval + self.SLOW_DOWN_INCREMENT_SECONDS
                )
                continue
            if error:
                raise Exception(f"OAuth failed: {error}")

        raise Exception("OAuth flow timed out. Please try again.")

    # ------------------------------------------------------------------
    # Public accessors
    # ------------------------------------------------------------------

    async def get_auth_header(self, credential_path: str) -> Dict[str, str]:
        """Get Authorization header with fresh Copilot API token."""
        creds = await self._load_credentials(credential_path)
        if self._is_token_expired(creds):
            creds = await self._refresh_copilot_token(credential_path, creds)
        return {"Authorization": f"Bearer {creds['access_token']}"}

    async def get_user_info(
        self, creds_or_path: Union[Dict[str, Any], str]
    ) -> Dict[str, Any]:
        """
        Get user info from cached metadata or the GitHub API.

        Returns ``{"email": ...}``. The value comes from cached metadata if
        present, otherwise from ``/user``; it is "unknown" if neither works.
        A fetched value is stored in the existing ``_proxy_metadata`` and
        saved when a path was given.
        """
        path = creds_or_path if isinstance(creds_or_path, str) else None
        creds = await self._load_credentials(creds_or_path) if path else creds_or_path

        cached_email = (creds.get("_proxy_metadata") or {}).get("email")
        if cached_email:
            return {"email": cached_email}

        # Fetch from GitHub API
        github_token = creds.get("refresh_token")
        if github_token:
            domain = self._domain_for(creds.get("enterprise_url"))

            async with httpx.AsyncClient() as client:
                try:
                    response = await client.get(
                        f"https://api.{domain}/user",
                        headers={"Authorization": f"Bearer {github_token}"},
                        timeout=10.0,
                    )
                    if response.is_success:
                        email = self._pick_user_identity(response.json())
                        # Update in place: replacing the dict would drop flags
                        # such as loaded_from_env.
                        metadata = creds.get("_proxy_metadata")
                        if not isinstance(metadata, dict):
                            metadata = creds["_proxy_metadata"] = {}
                        metadata["email"] = email
                        metadata["last_check_timestamp"] = time.time()
                        if path:
                            await self._save_credentials(path, creds)
                        return {"email": email}
                except Exception as e:
                    lib_logger.warning(f"Failed to fetch user info: {e}")

        return {"email": "unknown"}
+1,484 @@ +# src/rotator_library/providers/copilot_provider.py +""" +GitHub Copilot Provider - Custom API integration for Copilot Chat. + +This provider implements the full Copilot Chat API integration including: +- Custom OAuth authentication (Device Flow) +- Direct API calls bypassing LiteLLM +- X-Initiator header control (user vs agent mode) +- Vision request support +- Both streaming and non-streaming responses + +Based on: +- https://github.com/sst/opencode-copilot-auth +- https://github.com/Tarquinen/dotfiles/tree/main/.config/opencode/plugin/copilot-force-agent-header +""" + +from __future__ import annotations + +import copy +import json +import logging +import os +import random +import time +import uuid +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Union + +import httpx +import litellm + +from .provider_interface import ProviderInterface +from .copilot_auth_base import CopilotAuthBase + +lib_logger = logging.getLogger("rotator_library") + + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +# Default Copilot API base URLs +COPILOT_API_BASE = "https://api.githubcopilot.com" + +# Available Copilot models (these may vary based on subscription) +DEFAULT_COPILOT_MODELS = [ + "gpt-4o", + "gpt-4.1", + "gpt-4.1-mini", + "claude-3.5-sonnet", + "claude-sonnet-4", + "o3-mini", + "o1", + "gemini-2.0-flash-001", +] + +# Responses API alternate input types for agent detection +RESPONSES_API_ALTERNATE_INPUT_TYPES = [ + "file_search_call", + "computer_call", + "computer_call_output", + "web_search_call", + "function_call", + "function_call_output", + "image_generation_call", + "code_interpreter_call", + "local_shell_call", + "local_shell_call_output", + "mcp_list_tools", + "mcp_approval_request", + "mcp_approval_response", + "mcp_call", + "reasoning", +] + + +def _env_bool(key: str, default: bool = 
False) -> bool: + """Get boolean from environment variable.""" + return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes") + + +def _env_int(key: str, default: int) -> int: + """Get integer from environment variable.""" + try: + return int(os.getenv(key, str(default))) + except ValueError: + return default + + +def _env_float(key: str, default: float) -> float: + """Get float from environment variable.""" + try: + return float(os.getenv(key, str(default))) + except ValueError: + return default + + +# ============================================================================= +# MAIN PROVIDER CLASS +# ============================================================================= + + +class CopilotProvider(CopilotAuthBase, ProviderInterface): + """ + GitHub Copilot provider with custom API integration. + + Features: + - Device Flow OAuth authentication + - Direct Copilot Chat API calls + - Configurable X-Initiator header (user vs agent mode) + - Configurable agent header percentage for first messages + - Vision request support + - Both streaming and non-streaming responses + + Environment Variables: + - COPILOT_FORCE_AGENT_HEADER: Always use "agent" initiator (default: false) + - COPILOT_AGENT_PERCENTAGE: For first messages, % chance of "agent" (0-100, default: 100) + - COPILOT_MODELS: Comma-separated list of available models + - COPILOT_ENTERPRISE_URL: GitHub Enterprise URL (optional) + """ + + skip_cost_calculation = True # Copilot uses subscription, not token billing + + def __init__(self): + super().__init__() + + # X-Initiator header configuration + # Based on https://github.com/Tarquinen/dotfiles/tree/main/.config/opencode/plugin/copilot-force-agent-header + self._force_agent_header = _env_bool("COPILOT_FORCE_AGENT_HEADER", False) + self._agent_percentage = _env_int("COPILOT_AGENT_PERCENTAGE", 100) + + # Model configuration + models_env = os.getenv("COPILOT_MODELS", "") + if models_env: + self._available_models = [ + m.strip() for m in 
models_env.split(",") if m.strip() + ] + else: + self._available_models = DEFAULT_COPILOT_MODELS + + lib_logger.debug( + f"CopilotProvider initialized: force_agent={self._force_agent_header}, " + f"agent_percentage={self._agent_percentage}%, models={len(self._available_models)}" + ) + + # ========================================================================= + # PROVIDER INTERFACE IMPLEMENTATION + # ========================================================================= + + def has_custom_logic(self) -> bool: + """Returns True - Copilot uses custom API calls, not LiteLLM.""" + return True + + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """ + Return available Copilot models. + + The api_key here is actually the credential path for OAuth providers. + For Copilot, models are configured via environment or defaults. + """ + return self._available_models + + def get_credential_priority(self, credential: str) -> Optional[int]: + """ + Returns priority for credential. Copilot doesn't have tiers. + All Copilot credentials are treated equally. + """ + return 1 # All credentials have same priority + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """ + Returns model tier requirement. Copilot doesn't restrict by tier. + """ + return None + + # ========================================================================= + # X-INITIATOR HEADER LOGIC + # ========================================================================= + + def _determine_initiator( + self, messages: List[Dict[str, Any]], is_responses_api: bool = False + ) -> str: + """ + Determine the X-Initiator header value based on conversation context. + + Logic (based on opencode-copilot-auth): + 1. If message contains tool/assistant roles → "agent" (ongoing conversation) + 2. If using Responses API with certain types → "agent" + 3. 
For first messages: + - If COPILOT_FORCE_AGENT_HEADER=true → "agent" + - Else: COPILOT_AGENT_PERCENTAGE% chance of "agent", else "user" + + Returns: + "agent" or "user" + """ + # Check for ongoing agent conversation (has assistant/tool messages) + if messages: + for msg in messages: + role = msg.get("role", "") + if role in ["tool", "assistant"]: + return "agent" + + # Check for vision content in messages + content = msg.get("content", []) + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + pass # Vision doesn't affect initiator + + # Check for Responses API alternate input types + if is_responses_api and messages: + last_input = messages[-1] if messages else {} + input_type = last_input.get("type", "") + if input_type in RESPONSES_API_ALTERNATE_INPUT_TYPES: + return "agent" + if last_input.get("role") == "assistant": + return "agent" + + # First message logic + if self._force_agent_header: + return "agent" + + if self._agent_percentage >= 100: + return "agent" + elif self._agent_percentage <= 0: + return "user" + else: + # Random based on percentage + return "agent" if random.random() * 100 < self._agent_percentage else "user" + + def _is_vision_request(self, messages: List[Dict[str, Any]]) -> bool: + """Check if request contains vision/image content.""" + for msg in messages: + content = msg.get("content", []) + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in [ + "image_url", + "input_image", + ]: + return True + return False + + # ========================================================================= + # API COMPLETION + # ========================================================================= + + async def acompletion( + self, client: httpx.AsyncClient, **kwargs + ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: + """ + Handle completion requests to Copilot API. + + This method: + 1. 
Gets fresh Copilot API token + 2. Builds request with proper headers (X-Initiator, etc.) + 3. Makes direct API call to Copilot + 4. Parses response into LiteLLM format + """ + credential_path = kwargs.get("api_key", "") + model = kwargs.get("model", "gpt-4o") + messages = kwargs.get("messages", []) + stream = kwargs.get("stream", False) + + # Strip provider prefix if present + if "/" in model: + model = model.split("/")[-1] + + # Get fresh credentials and token + creds = await self._load_credentials(credential_path) + if self._is_token_expired(creds): + creds = await self._refresh_copilot_token(credential_path, creds) + + access_token = creds.get("access_token", "") + enterprise_url = creds.get("enterprise_url", "") + + # Determine base URL + if enterprise_url: + base_url = f"https://copilot-api.{self._normalize_domain(enterprise_url)}" + else: + base_url = COPILOT_API_BASE + + # Determine headers + initiator = self._determine_initiator(messages) + is_vision = self._is_vision_request(messages) + + headers = { + **self.COPILOT_HEADERS, + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "Openai-Intent": "conversation-edits", + "X-Initiator": initiator, + } + + if is_vision: + headers["Copilot-Vision-Request"] = "true" + + # Build request body (OpenAI-compatible format) + body = { + "model": model, + "messages": messages, + "stream": stream, + } + + # Add optional parameters + if kwargs.get("temperature") is not None: + body["temperature"] = kwargs["temperature"] + if kwargs.get("max_tokens") is not None: + body["max_tokens"] = kwargs["max_tokens"] + if kwargs.get("top_p") is not None: + body["top_p"] = kwargs["top_p"] + if kwargs.get("stop") is not None: + body["stop"] = kwargs["stop"] + if kwargs.get("tools"): + body["tools"] = kwargs["tools"] + if kwargs.get("tool_choice"): + body["tool_choice"] = kwargs["tool_choice"] + + lib_logger.debug( + f"Copilot request: model={model}, initiator={initiator}, " + f"stream={stream}, 
vision={is_vision}" + ) + + if stream: + return self._handle_streaming_response( + client, base_url, headers, body, model + ) + else: + return await self._handle_non_streaming_response( + client, base_url, headers, body, model + ) + + async def _handle_non_streaming_response( + self, + client: httpx.AsyncClient, + base_url: str, + headers: Dict[str, str], + body: Dict[str, Any], + model: str, + ) -> litellm.ModelResponse: + """Handle non-streaming Copilot API response.""" + url = f"{base_url}/chat/completions" + + try: + response = await client.post( + url, + headers=headers, + json=body, + timeout=300.0, # 5 minute timeout for long completions + ) + response.raise_for_status() + data = response.json() + + # Convert to LiteLLM format + return self._convert_to_litellm_response(data, model) + + except httpx.HTTPStatusError as e: + lib_logger.error( + f"Copilot API error (HTTP {e.response.status_code}): {e.response.text}" + ) + raise + except Exception as e: + lib_logger.error(f"Copilot request failed: {e}") + raise + + async def _handle_streaming_response( + self, + client: httpx.AsyncClient, + base_url: str, + headers: Dict[str, str], + body: Dict[str, Any], + model: str, + ) -> AsyncGenerator[litellm.ModelResponse, None]: + """Handle streaming Copilot API response.""" + url = f"{base_url}/chat/completions" + + async def stream_generator(): + try: + async with client.stream( + "POST", + url, + headers=headers, + json=body, + timeout=300.0, + ) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): + continue + + data_str = line[6:] # Remove "data: " prefix + if data_str == "[DONE]": + break + + try: + chunk_data = json.loads(data_str) + yield self._convert_to_litellm_chunk(chunk_data, model) + except json.JSONDecodeError: + continue + + except httpx.HTTPStatusError as e: + lib_logger.error( + f"Copilot streaming error (HTTP {e.response.status_code}): {e.response.text}" + ) + raise + 
except Exception as e: + lib_logger.error(f"Copilot streaming failed: {e}") + raise + + return stream_generator() + + def _convert_to_litellm_response( + self, data: Dict[str, Any], model: str + ) -> litellm.ModelResponse: + """Convert Copilot API response to LiteLLM ModelResponse format.""" + choices = [] + for choice in data.get("choices", []): + message = choice.get("message", {}) + litellm_choice = litellm.Choices( + index=choice.get("index", 0), + message=litellm.Message( + role=message.get("role", "assistant"), + content=message.get("content", ""), + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + + # Handle tool calls + if message.get("tool_calls"): + litellm_choice.message.tool_calls = message["tool_calls"] + + choices.append(litellm_choice) + + usage = data.get("usage", {}) + return litellm.ModelResponse( + id=data.get("id", f"copilot-{uuid.uuid4()}"), + choices=choices, + created=data.get("created", int(time.time())), + model=f"copilot/{model}", + usage=litellm.Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + ) + + def _convert_to_litellm_chunk( + self, chunk_data: Dict[str, Any], model: str + ) -> litellm.ModelResponse: + """Convert Copilot streaming chunk to LiteLLM format.""" + choices = [] + for choice in chunk_data.get("choices", []): + delta = choice.get("delta", {}) + litellm_choice = litellm.Choices( + index=choice.get("index", 0), + delta=litellm.Delta( + role=delta.get("role"), + content=delta.get("content"), + ), + finish_reason=choice.get("finish_reason"), + ) + + # Handle tool call deltas + if delta.get("tool_calls"): + litellm_choice.delta.tool_calls = delta["tool_calls"] + + choices.append(litellm_choice) + + return litellm.ModelResponse( + id=chunk_data.get("id", f"copilot-{uuid.uuid4()}"), + choices=choices, + created=chunk_data.get("created", int(time.time())), + model=f"copilot/{model}", + ) + + async def 
aembedding( + self, client: httpx.AsyncClient, **kwargs + ) -> litellm.EmbeddingResponse: + """ + Copilot doesn't support embeddings API. + Raises NotImplementedError. + """ + raise NotImplementedError("Copilot does not support embeddings API")