Skip to content

Commit 6082650

Browse files
committed
add login flow
1 parent 93990f5 commit 6082650

File tree

2 files changed

+445
-140
lines changed

2 files changed

+445
-140
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 36 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key
3232
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
3333
from codeflash.code_utils.github_utils import get_github_secrets_page_url
34+
from codeflash.code_utils.oauth_handler import perform_oauth_signin
3435
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
3536
from codeflash.either import is_successful
3637
from codeflash.lsp.helpers import is_LSP_enabled
@@ -1149,25 +1150,14 @@ def convert(self, value: str, param: click.Parameter | None, ctx: click.Context
11491150

11501151
# Returns True if the user entered a new API key, False if they used an existing one
11511152
def prompt_api_key() -> bool:
1152-
import threading
1153-
import socket
1154-
import http.server
1155-
import urllib.parse
1156-
import random
1157-
import string
1158-
import base64
1159-
import hashlib
1160-
import time
1161-
import json
1162-
import webbrowser
1163-
import requests
1164-
1165-
BASE_URL = "https://app.codeflash.ai/"
1153+
"""Prompt user for API key via OAuth or manual entry"""
11661154

1155+
# Check for existing API key
11671156
try:
11681157
existing_api_key = get_codeflash_api_key()
11691158
except OSError:
11701159
existing_api_key = None
1160+
11711161
if existing_api_key:
11721162
display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
11731163
api_key_panel = Panel(
@@ -1183,15 +1173,17 @@ def prompt_api_key() -> bool:
11831173
console.print(api_key_panel)
11841174
console.print()
11851175
return False
1176+
1177+
# Prompt for authentication method
11861178
auth_choices = [
1187-
"🔐 Sign in",
1188-
"🔑 Enter Api key"
1179+
"🔐 Login in with Codeflash",
1180+
"🔑 Use Codeflash API key"
11891181
]
1190-
name="auth_method"
1182+
11911183
questions = [
11921184
inquirer.List(
1193-
name,
1194-
message="How would you like to sign in?",
1185+
"auth_method",
1186+
message="How would you like to authenticate?",
11951187
choices=auth_choices,
11961188
default=auth_choices[0],
11971189
carousel=True,
@@ -1201,133 +1193,37 @@ def prompt_api_key() -> bool:
12011193
answers = inquirer.prompt(questions, theme=CodeflashTheme())
12021194
if not answers:
12031195
apologize_and_exit()
1204-
method = answers[name]
1205-
if method == "🔑 Enter Api key":
1196+
1197+
method = answers["auth_method"]
1198+
1199+
if method == auth_choices[1]:
12061200
enter_api_key_and_save_to_rc()
12071201
ph("cli-new-api-key-entered")
12081202
return True
1209-
# OAuth PKCE Flow for "🔐 Sign in"
1210-
# 1. Start a local server on available port
1211-
class OAuthCallbackHandler(http.server.BaseHTTPRequestHandler):
1212-
server_version = "CFHTTP"
1213-
code = None
1214-
state = None
1215-
error = None
1216-
def do_GET(self):
1217-
parsed = urllib.parse.urlparse(self.path)
1218-
if parsed.path != "/callback":
1219-
self.send_response(404)
1220-
self.end_headers()
1221-
return
1222-
params = urllib.parse.parse_qs(parsed.query)
1223-
OAuthCallbackHandler.code = params.get("code", [None])[0]
1224-
OAuthCallbackHandler.state = params.get("state", [None])[0]
1225-
OAuthCallbackHandler.error = params.get("error", [None])[0]
1226-
self.send_response(200)
1227-
self.send_header("Content-type", "text/html")
1228-
self.end_headers()
1229-
if OAuthCallbackHandler.code:
1230-
self.wfile.write(b"<html><body><h2>Sign-in successful!</h2>You may close this window.</body></html>")
1231-
elif OAuthCallbackHandler.error:
1232-
self.wfile.write(b"<html><body><h2>Sign-in failed.</h2></body></html>")
1233-
else:
1234-
self.wfile.write(b"<html><body><h2>Missing code.</h2></body></html>")
1235-
1236-
def log_message(self, format, *args):
1237-
# Silence HTTP logs
1238-
pass
1239-
1240-
# Find a free port
1241-
def get_free_port():
1242-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1243-
s.bind(("", 0))
1244-
return s.getsockname()[1]
1245-
1246-
port = get_free_port()
1247-
redirect_uri = f"http://localhost:{port}/callback"
1248-
# PKCE code_verifier and code_challenge
1249-
def random_string(length=64):
1250-
return ''.join(random.choices(string.ascii_letters + string.digits + "-._~", k=length))
1251-
code_verifier = random_string(64)
1252-
code_challenge = base64.urlsafe_b64encode(
1253-
hashlib.sha256(code_verifier.encode()).digest()
1254-
).rstrip(b'=').decode()
1255-
state = random_string(16)
1256-
1257-
# Compose auth URL
1258-
auth_url = (
1259-
f"{BASE_URL}codeflash/auth?"
1260-
f"response_type=code"
1261-
f"&client_id=cf_vscode_app"
1262-
f"&redirect_uri={urllib.parse.quote(redirect_uri)}"
1263-
f"&code_challenge={code_challenge}"
1264-
f"&code_challenge_method=sha256"
1265-
f"&state={state}"
1266-
)
12671203

1268-
# Start HTTP server in thread
1269-
handler_class = OAuthCallbackHandler
1270-
httpd = http.server.HTTPServer(("localhost", port), handler_class)
1271-
server_thread = threading.Thread(target=httpd.handle_request)
1272-
server_thread.daemon = True
1273-
server_thread.start()
1274-
click.echo(f"🌐 Opening browser to sign in to Codeflash…")
1275-
webbrowser.open(auth_url)
1276-
click.echo(f"If your browser did not open, visit:\n {auth_url}")
1277-
# Wait for callback (with timeout)
1278-
max_wait = 120 # seconds
1279-
waited = 0
1280-
while handler_class.code is None and handler_class.error is None and waited < max_wait:
1281-
time.sleep(0.5)
1282-
waited += 0.5
1283-
httpd.server_close()
1284-
if handler_class.error:
1285-
click.echo(f"❌ Sign-in failed: {handler_class.error}")
1286-
apologize_and_exit()
1287-
if not handler_class.code or not handler_class.state:
1288-
click.echo("❌ Did not receive code from sign-in. Please try again.")
1289-
apologize_and_exit()
1290-
if handler_class.state != state:
1291-
click.echo("❌ State mismatch in OAuth callback.")
1292-
apologize_and_exit()
1293-
code = handler_class.code
1294-
console.print(code)
1295-
# Exchange code for token
1296-
token_url = f"{BASE_URL}codeflash/auth/oauth/token"
1297-
data = {
1298-
"grant_type": "authorization_code",
1299-
"code": code,
1300-
"code_verifier": code_verifier,
1301-
"redirect_uri": redirect_uri,
1302-
"client_id": "cf_vscode_app"
1303-
}
1304-
try:
1305-
resp = requests.post(
1306-
token_url,
1307-
headers={"Content-Type": "application/json"},
1308-
data=json.dumps(data),
1309-
timeout=10,
1310-
)
1311-
resp.raise_for_status()
1312-
token_json = resp.json()
1313-
api_key = token_json.get("api_key") or token_json.get("access_token")
1314-
if not api_key:
1315-
click.echo("❌ Could not retrieve API key from response.")
1316-
apologize_and_exit()
1317-
result = save_api_key_to_rc(api_key)
1318-
if is_successful(result):
1319-
click.echo(result.unwrap())
1320-
click.echo("✅ Signed in successfully and API key saved!")
1321-
else:
1322-
click.echo(result.failure())
1323-
click.pause()
1324-
os.environ["CODEFLASH_API_KEY"] = api_key
1325-
ph("cli-new-api-key-entered")
1326-
return True
1327-
except Exception as e:
1328-
click.echo(f"❌ Failed to exchange code for API key: {e}")
1204+
# Perform OAuth sign-in
1205+
api_key = perform_oauth_signin()
1206+
1207+
if not api_key:
13291208
apologize_and_exit()
13301209

1210+
# Save API key
1211+
shell_rc_path = get_shell_rc_path()
1212+
if not shell_rc_path.exists() and os.name == "nt":
1213+
shell_rc_path.touch()
1214+
click.echo(f"✅ Created {shell_rc_path}")
1215+
1216+
result = save_api_key_to_rc(api_key)
1217+
if is_successful(result):
1218+
click.echo(result.unwrap())
1219+
click.echo("✅ Signed in successfully and API key saved!")
1220+
else:
1221+
click.echo(result.failure())
1222+
click.pause()
1223+
1224+
os.environ["CODEFLASH_API_KEY"] = api_key
1225+
ph("cli-oauth-signin-completed")
1226+
return True
13311227

13321228
def enter_api_key_and_save_to_rc() -> None:
13331229
browser_launched = False

0 commit comments

Comments
 (0)