3131from codeflash .code_utils .env_utils import check_formatter_installed , get_codeflash_api_key
3232from codeflash .code_utils .git_utils import get_git_remotes , get_repo_owner_and_name
3333from codeflash .code_utils .github_utils import get_github_secrets_page_url
34+ from codeflash .code_utils .oauth_handler import perform_oauth_signin
3435from codeflash .code_utils .shell_utils import get_shell_rc_path , save_api_key_to_rc
3536from codeflash .either import is_successful
3637from codeflash .lsp .helpers import is_LSP_enabled
@@ -1149,25 +1150,14 @@ def convert(self, value: str, param: click.Parameter | None, ctx: click.Context
11491150
11501151# Returns True if the user entered a new API key, False if they used an existing one
11511152def prompt_api_key () -> bool :
1152- import threading
1153- import socket
1154- import http .server
1155- import urllib .parse
1156- import random
1157- import string
1158- import base64
1159- import hashlib
1160- import time
1161- import json
1162- import webbrowser
1163- import requests
1164-
1165- BASE_URL = "https://app.codeflash.ai/"
1153+ """Prompt user for API key via OAuth or manual entry"""
11661154
1155+ # Check for existing API key
11671156 try :
11681157 existing_api_key = get_codeflash_api_key ()
11691158 except OSError :
11701159 existing_api_key = None
1160+
11711161 if existing_api_key :
11721162 display_key = f"{ existing_api_key [:3 ]} ****{ existing_api_key [- 4 :]} "
11731163 api_key_panel = Panel (
@@ -1183,15 +1173,17 @@ def prompt_api_key() -> bool:
11831173 console .print (api_key_panel )
11841174 console .print ()
11851175 return False
1176+
1177+ # Prompt for authentication method
11861178 auth_choices = [
1187- "🔐 Sign in" ,
1188- "🔑 Enter Api key"
1179+ "🔐 Login in with Codeflash " ,
1180+ "🔑 Use Codeflash API key"
11891181 ]
1190- name = "auth_method"
1182+
11911183 questions = [
11921184 inquirer .List (
1193- name ,
1194- message = "How would you like to sign in ?" ,
1185+ "auth_method" ,
1186+ message = "How would you like to authenticate ?" ,
11951187 choices = auth_choices ,
11961188 default = auth_choices [0 ],
11971189 carousel = True ,
@@ -1201,133 +1193,37 @@ def prompt_api_key() -> bool:
12011193 answers = inquirer .prompt (questions , theme = CodeflashTheme ())
12021194 if not answers :
12031195 apologize_and_exit ()
1204- method = answers [name ]
1205- if method == "🔑 Enter Api key" :
1196+
1197+ method = answers ["auth_method" ]
1198+
1199+ if method == auth_choices [1 ]:
12061200 enter_api_key_and_save_to_rc ()
12071201 ph ("cli-new-api-key-entered" )
12081202 return True
1209- # OAuth PKCE Flow for "🔐 Sign in"
1210- # 1. Start a local server on available port
1211- class OAuthCallbackHandler (http .server .BaseHTTPRequestHandler ):
1212- server_version = "CFHTTP"
1213- code = None
1214- state = None
1215- error = None
1216- def do_GET (self ):
1217- parsed = urllib .parse .urlparse (self .path )
1218- if parsed .path != "/callback" :
1219- self .send_response (404 )
1220- self .end_headers ()
1221- return
1222- params = urllib .parse .parse_qs (parsed .query )
1223- OAuthCallbackHandler .code = params .get ("code" , [None ])[0 ]
1224- OAuthCallbackHandler .state = params .get ("state" , [None ])[0 ]
1225- OAuthCallbackHandler .error = params .get ("error" , [None ])[0 ]
1226- self .send_response (200 )
1227- self .send_header ("Content-type" , "text/html" )
1228- self .end_headers ()
1229- if OAuthCallbackHandler .code :
1230- self .wfile .write (b"<html><body><h2>Sign-in successful!</h2>You may close this window.</body></html>" )
1231- elif OAuthCallbackHandler .error :
1232- self .wfile .write (b"<html><body><h2>Sign-in failed.</h2></body></html>" )
1233- else :
1234- self .wfile .write (b"<html><body><h2>Missing code.</h2></body></html>" )
1235-
1236- def log_message (self , format , * args ):
1237- # Silence HTTP logs
1238- pass
1239-
1240- # Find a free port
1241- def get_free_port ():
1242- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
1243- s .bind (("" , 0 ))
1244- return s .getsockname ()[1 ]
1245-
1246- port = get_free_port ()
1247- redirect_uri = f"http://localhost:{ port } /callback"
1248- # PKCE code_verifier and code_challenge
1249- def random_string (length = 64 ):
1250- return '' .join (random .choices (string .ascii_letters + string .digits + "-._~" , k = length ))
1251- code_verifier = random_string (64 )
1252- code_challenge = base64 .urlsafe_b64encode (
1253- hashlib .sha256 (code_verifier .encode ()).digest ()
1254- ).rstrip (b'=' ).decode ()
1255- state = random_string (16 )
1256-
1257- # Compose auth URL
1258- auth_url = (
1259- f"{ BASE_URL } codeflash/auth?"
1260- f"response_type=code"
1261- f"&client_id=cf_vscode_app"
1262- f"&redirect_uri={ urllib .parse .quote (redirect_uri )} "
1263- f"&code_challenge={ code_challenge } "
1264- f"&code_challenge_method=sha256"
1265- f"&state={ state } "
1266- )
12671203
1268- # Start HTTP server in thread
1269- handler_class = OAuthCallbackHandler
1270- httpd = http .server .HTTPServer (("localhost" , port ), handler_class )
1271- server_thread = threading .Thread (target = httpd .handle_request )
1272- server_thread .daemon = True
1273- server_thread .start ()
1274- click .echo (f"🌐 Opening browser to sign in to Codeflash…" )
1275- webbrowser .open (auth_url )
1276- click .echo (f"If your browser did not open, visit:\n { auth_url } " )
1277- # Wait for callback (with timeout)
1278- max_wait = 120 # seconds
1279- waited = 0
1280- while handler_class .code is None and handler_class .error is None and waited < max_wait :
1281- time .sleep (0.5 )
1282- waited += 0.5
1283- httpd .server_close ()
1284- if handler_class .error :
1285- click .echo (f"❌ Sign-in failed: { handler_class .error } " )
1286- apologize_and_exit ()
1287- if not handler_class .code or not handler_class .state :
1288- click .echo ("❌ Did not receive code from sign-in. Please try again." )
1289- apologize_and_exit ()
1290- if handler_class .state != state :
1291- click .echo ("❌ State mismatch in OAuth callback." )
1292- apologize_and_exit ()
1293- code = handler_class .code
1294- console .print (code )
1295- # Exchange code for token
1296- token_url = f"{ BASE_URL } codeflash/auth/oauth/token"
1297- data = {
1298- "grant_type" : "authorization_code" ,
1299- "code" : code ,
1300- "code_verifier" : code_verifier ,
1301- "redirect_uri" : redirect_uri ,
1302- "client_id" : "cf_vscode_app"
1303- }
1304- try :
1305- resp = requests .post (
1306- token_url ,
1307- headers = {"Content-Type" : "application/json" },
1308- data = json .dumps (data ),
1309- timeout = 10 ,
1310- )
1311- resp .raise_for_status ()
1312- token_json = resp .json ()
1313- api_key = token_json .get ("api_key" ) or token_json .get ("access_token" )
1314- if not api_key :
1315- click .echo ("❌ Could not retrieve API key from response." )
1316- apologize_and_exit ()
1317- result = save_api_key_to_rc (api_key )
1318- if is_successful (result ):
1319- click .echo (result .unwrap ())
1320- click .echo ("✅ Signed in successfully and API key saved!" )
1321- else :
1322- click .echo (result .failure ())
1323- click .pause ()
1324- os .environ ["CODEFLASH_API_KEY" ] = api_key
1325- ph ("cli-new-api-key-entered" )
1326- return True
1327- except Exception as e :
1328- click .echo (f"❌ Failed to exchange code for API key: { e } " )
1204+ # Perform OAuth sign-in
1205+ api_key = perform_oauth_signin ()
1206+
1207+ if not api_key :
13291208 apologize_and_exit ()
13301209
1210+ # Save API key
1211+ shell_rc_path = get_shell_rc_path ()
1212+ if not shell_rc_path .exists () and os .name == "nt" :
1213+ shell_rc_path .touch ()
1214+ click .echo (f"✅ Created { shell_rc_path } " )
1215+
1216+ result = save_api_key_to_rc (api_key )
1217+ if is_successful (result ):
1218+ click .echo (result .unwrap ())
1219+ click .echo ("✅ Signed in successfully and API key saved!" )
1220+ else :
1221+ click .echo (result .failure ())
1222+ click .pause ()
1223+
1224+ os .environ ["CODEFLASH_API_KEY" ] = api_key
1225+ ph ("cli-oauth-signin-completed" )
1226+ return True
13311227
13321228def enter_api_key_and_save_to_rc () -> None :
13331229 browser_launched = False
0 commit comments