-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathweb_app.py
More file actions
235 lines (183 loc) · 6.85 KB
/
web_app.py
File metadata and controls
235 lines (183 loc) · 6.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
CodeSprite v2 Web 服务 — 框架无关 IR 架构版
安全加固:
- 安全响应头(CSP/HSTS/X-Frame-Options等)
- 请求速率限制
- 输入校验与长度限制
- 审计日志记录
API端点:
- POST /api/generate - 代码分析
- GET /api/info - 引擎信息
- GET /api/health - 健康检查
- POST /api/feedback - 用户反馈
- GET /api/learning-status - 学习状态
"""
import os
import sys
import time
import json
import logging
from datetime import datetime
from collections import defaultdict
from flask import Flask, request, jsonify, render_template, make_response
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ir.config import ModelConfig
from ir.transformer import TransformerModel
from inference.engine import InferenceEngine
from src.tokenizer import SimpleTokenizer
from src.compliance import SecurityHeaders, RateLimiter, AuditLogger
from backends.pytorch import PyTorchBackend, init_model_weights
from src.device import resolve_device, print_device_info
# --- Initialization ---
app = Flask(__name__)

# Security configuration: Flask secret key (used for session signing etc.).
# Replace it with a randomly generated strong key before deploying:
#   python -c "import secrets; print(secrets.token_hex(32))"
_DEFAULT_SECRET_KEY = 'codesprite-dev-key-change-in-production'
app.secret_key = os.environ.get('FLASK_SECRET_KEY', _DEFAULT_SECRET_KEY)
if app.secret_key == _DEFAULT_SECRET_KEY:
    # The fallback key is committed to the source tree, so anything signed
    # with it can be forged. Keep the fallback for local development, but
    # make its use impossible to overlook in the logs.
    logging.warning(
        "FLASK_SECRET_KEY is not set; using the built-in development key. "
        "Set FLASK_SECRET_KEY before deploying to production."
    )

# Compliance components: response security headers, a request rate limiter
# (max_requests=20, window_seconds=60) and an audit logger writing under "logs".
security_headers = SecurityHeaders()
rate_limiter = RateLimiter(max_requests=20, window_seconds=60)
audit_logger = AuditLogger(log_dir="logs")
# 加载配置
def load_config():
    """Read ``config/config.yaml`` and return its parsed contents.

    The path is relative to the current working directory. PyYAML is imported
    lazily, inside the function, and the file is parsed with ``safe_load``.
    """
    from yaml import safe_load

    with open('config/config.yaml', mode='r', encoding='utf-8') as stream:
        settings = safe_load(stream)
    return settings
# Parse the YAML configuration and derive the model configuration from it.
config_dict = load_config()
model_config = ModelConfig.from_yaml(config_dict)
# Load the model: build the tokenizer and the transformer from the config.
tokenizer = SimpleTokenizer(vocab_size=model_config.vocab_size)
model = TransformerModel(model_config)
# Device selection: the web service defaults to CPU (lightweight inference,
# does not occupy GPU memory).
# For GPU inference, set the environment variable CODESPRITE_WEB_DEVICE=cuda
web_device = os.environ.get("CODESPRITE_WEB_DEVICE", "cpu")
# The config must contain a 'system' section (KeyError otherwise);
# 'cpu_threads' inside it is optional and defaults to None.
cpu_threads = config_dict['system'].get('cpu_threads', None)
resolved_device = resolve_device(web_device, cpu_threads=cpu_threads)
print_device_info(resolved_device)
# Create the backend and initialize the weights (even without a pretrained model).
backend = PyTorchBackend(device=resolved_device)
init_model_weights(model, backend)
print("Model weights initialized (random initialization)")
# Create the inference engine.
engine = InferenceEngine(
    model,
    backend=backend,
    checkpoint_path=None,  # do not load a pretrained checkpoint
    tokenizer=tokenizer,
    device=resolved_device
)
# Default sampling parameters; /api/generate falls back to these when the
# request does not supply its own values.
engine.temperature = 0.8
engine.top_k = 50
engine.top_p = 0.9
# Startup summary printed once at import time.
print(f"\nCodeSprite v2 Web App ready")
print(f" Backend: {engine.backend.name} ({resolved_device})")
print(f" Parameters: {model.get_param_count():,}")
print(f" Note: Using randomly initialized weights (not pre-trained)\n")
# --- 安全装饰器 ---
@app.before_request
def before_request():
    """Pre-request hook: apply the rate limit to the calling address.

    Returns None when the request may proceed; otherwise records a
    RATE_LIMITED audit entry and returns a JSON response with status 429.
    """
    remote = request.remote_addr
    allowed = rate_limiter.check(remote)
    if allowed:
        return None

    audit_logger.log(remote, "RATE_LIMITED", request.path)
    rejection = jsonify({"error": "请求过于频繁,请稍后再试"})
    rejection.status_code = 429
    return rejection
@app.after_request
def after_request(response):
    """Post-response hook: pass the outgoing response through ``security_headers.apply``."""
    hardened = security_headers.apply(response)
    return hardened
# --- API 端点 ---
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/api/generate', methods=['POST'])
def generate():
    """Code analysis API: run the inference engine on a prompt.

    Request JSON fields:
        prompt: required string; must be non-empty after stripping and at
            most 2000 characters long.
        max_tokens: optional positive integer, default 100, capped at 500.
        temperature, top_k, top_p: optional sampling overrides. They default
            to the engine's current values and are passed through unchanged.

    Returns:
        200 with ``text``, ``tokens_generated``, ``elapsed_seconds`` and
        ``backend`` on success; 400 when the input is invalid; 500 when
        generation raises.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing/invalid JSON yields None, and valid JSON that is not an
        # object (list, string, number) has no .get(); treat both as empty.
        data = {}

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str):
        return jsonify({"error": "prompt 必须是字符串"}), 400
    prompt = prompt.strip()
    if not prompt:
        return jsonify({"error": "prompt 不能为空"}), 400
    if len(prompt) > 2000:
        return jsonify({"error": "prompt 长度不能超过 2000 字符"}), 400

    max_tokens = data.get('max_tokens', 100)
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens < 1:
        return jsonify({"error": "max_tokens 必须是正整数"}), 400
    max_tokens = min(max_tokens, 500)

    temperature = data.get('temperature', engine.temperature)
    top_k = data.get('top_k', engine.top_k)
    top_p = data.get('top_p', engine.top_p)

    start_time = time.time()
    client_ip = request.remote_addr

    # NOTE(review): the engine is a shared module-level object and the
    # overrides are written onto it for the duration of the call, so
    # concurrent requests could observe each other's settings. Confirm how
    # the server is deployed before relying on per-request isolation.
    saved_sampling = (engine.temperature, engine.top_k, engine.top_p)
    try:
        # Temporarily apply the per-request sampling parameters.
        engine.temperature, engine.top_k, engine.top_p = temperature, top_k, top_p
        try:
            generated = engine.generate(prompt, max_new_tokens=max_tokens)
        finally:
            # Always restore the defaults. Without this, a request whose
            # overrides make generate() raise would leave those values on the
            # shared engine and break every later request.
            engine.temperature, engine.top_k, engine.top_p = saved_sampling

        elapsed = time.time() - start_time
        audit_logger.log(client_ip, "GENERATE", f"len={len(prompt)}, tokens={max_tokens}, time={elapsed:.2f}s")
        return jsonify({
            "text": generated,
            "tokens_generated": len(tokenizer.encode(generated)) - len(tokenizer.encode(prompt)),
            "elapsed_seconds": round(elapsed, 2),
            "backend": engine.backend.name,
        })
    except Exception as e:
        # Top-level boundary for the endpoint: log with traceback, answer 500.
        logging.exception("Generation error: %s", e)
        return jsonify({"error": f"生成失败: {str(e)}"}), 500
@app.route('/api/info', methods=['GET'])
def model_info():
    """Engine information: return whatever ``engine.info()`` reports, as JSON."""
    details = engine.info()
    return jsonify(details)
@app.route('/api/health', methods=['GET'])
def health():
    """Health check: report status, current local time, backend name and parameter count."""
    snapshot = dict(
        status="ok",
        timestamp=datetime.now().isoformat(),
        backend=engine.backend.name,
        parameters=model.get_param_count(),
    )
    return jsonify(snapshot)
@app.route('/api/feedback', methods=['POST'])
def feedback():
    """User feedback: record a rating for a prompt/response pair in the audit log.

    Request JSON fields:
        prompt: required, non-empty string.
        response: optional string (a missing or null value counts as empty).
        rating: required, non-empty value (e.g. 'up' or 'down').

    Returns:
        200 with a confirmation message, or 400 when a required field is
        missing or ``prompt``/``response`` is not a string.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing/invalid JSON yields None, and valid JSON that is not an
        # object has no .get(); treat both as an empty payload.
        data = {}

    prompt = data.get('prompt', '')
    response_text = data.get('response', '')
    rating = data.get('rating', '')  # 'up' or 'down'
    if response_text is None:
        response_text = ''

    if not prompt or not rating:
        return jsonify({"error": "缺少必要字段"}), 400
    if not isinstance(prompt, str) or not isinstance(response_text, str):
        # len() below would raise (or be meaningless) for non-string values.
        return jsonify({"error": "prompt 和 response 必须是字符串"}), 400

    # The rating comes from the client and is written into the audit log:
    # cap its length and drop non-printable characters (CR/LF included) so it
    # cannot inject extra log lines.
    safe_rating = ''.join(ch for ch in str(rating)[:32] if ch.isprintable())

    # Record the feedback (only lengths of the texts are logged).
    audit_logger.log(
        request.remote_addr, "FEEDBACK",
        f"rating={safe_rating}, prompt_len={len(prompt)}, response_len={len(response_text)}"
    )
    return jsonify({"status": "ok", "message": "反馈已记录"})
@app.route('/api/learning-status', methods=['GET'])
def learning_status():
    """Learning status: always reports auto-learning as disabled with zero feedback."""
    status = dict(
        auto_learning_enabled=False,
        total_feedback=0,
        message="自动学习功能在 v2 架构中暂时禁用",
    )
    return jsonify(status)
@app.errorhandler(404)
def not_found(e):
    """Answer 404 errors with a JSON body instead of the default HTML page."""
    body = jsonify(error="Not found")
    return body, 404
@app.errorhandler(500)
def server_error(e):
    """Answer 500 errors with a generic JSON body."""
    body = jsonify(error="Internal server error")
    return body, 500
# --- 启动 ---
if __name__ == '__main__':
    # Script entry point: print the startup banner, then serve on all
    # interfaces, port 5000, with debug mode off.
    rule = "=" * 50
    banner = (
        "\n" + rule,
        " CodeSprite v2 Web Server",
        " Architecture: Framework-Agnostic IR",
        f" Backend: {engine.backend.name}",
        " http://localhost:5000",
        rule + "\n",
    )
    for line in banner:
        print(line)
    app.run(host='0.0.0.0', port=5000, debug=False)