-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi_server_https.py
More file actions
426 lines (357 loc) · 15.5 KB
/
Copy pathapi_server_https.py
File metadata and controls
426 lines (357 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
import os
import uuid
import joblib
import numpy as np
import uvicorn
import sys
import logging
import json
from datetime import datetime
from pathlib import Path
import ssl
from src.audio.feature_extraction import extract_features, AudioFeatureExtractor
from config.config import DEFAULT_CONFIG
# JSON 직렬화 문제 해결을 위한 클래스 정의
class NumpyEncoder(json.JSONEncoder):
"""NumPy 데이터 타입을 JSON으로 직렬화하기 위한 인코더"""
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
# 로깅 설정 - 더 상세한 포맷
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler('feature_extraction.log', encoding='utf-8')
]
)
logger = logging.getLogger(__name__)
app = FastAPI(title="수박 당도 예측 서버", description="오디오 파일을 분석하여 수박의 당도를 예측합니다")
# HTTPS 강제 리다이렉트 미들웨어
@app.middleware("http")
async def force_https(request: Request, call_next):
"""HTTP 요청을 HTTPS로 리다이렉트"""
if request.url.scheme == "http":
url = request.url.replace(scheme="https")
return RedirectResponse(url=str(url), status_code=301)
return await call_next(request)
# CORS 설정
app.add_middleware(
CORSMiddleware,
allow_origins=["https://*"], # HTTPS만 허용
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 모델 로드
MODEL_PATH = os.path.join(DEFAULT_CONFIG.model_output_dir, "pickle", "random_forest_model.pkl")
try:
model_bundle = joblib.load(MODEL_PATH)
if isinstance(model_bundle, dict):
model = model_bundle["model"]
logger.info(f"모델 로드 성공: {MODEL_PATH} (딕셔너리에서 'model' 키 사용)")
else:
model = model_bundle
logger.info(f"모델 로드 성공: {MODEL_PATH} (직접 모델 객체 사용)")
# 모델 정보 로깅
if hasattr(model, 'n_features_in_'):
logger.info(f"모델 특성 수: {model.n_features_in_}")
if hasattr(model, 'classes_'):
logger.info(f"모델 클래스: {model.classes_}")
if hasattr(model, 'n_estimators'):
logger.info(f"Random Forest 트리 수: {model.n_estimators}")
except Exception as e:
logger.error(f"모델 로드 실패: {e}")
model = None
def log_audio_info(y, sr, file_path):
"""오디오 파일의 기본 정보를 로깅"""
duration = len(y) / sr
max_amplitude = np.max(np.abs(y))
rms_energy = np.sqrt(np.mean(y ** 2))
logger.info(f"📁 파일 정보: {os.path.basename(file_path)}")
logger.info(f" - 샘플링 레이트: {sr} Hz")
logger.info(f" - 길이: {duration:.2f}초 ({len(y)} 샘플)")
logger.info(f" - 최대 진폭: {max_amplitude:.6f}")
logger.info(f" - RMS 에너지: {rms_energy:.6f}")
logger.info(f" - 다이나믹 레인지: {20 * np.log10(max_amplitude / (rms_energy + 1e-8)):.2f} dB")
def log_feature_details(feature_vector):
"""추출된 특성의 상세 정보를 로깅"""
feature_array = feature_vector.to_array()
feature_names = feature_vector.feature_names
logger.info(f"🔍 특성 분석:")
logger.info(f" - 특성 개수: {len(feature_array)}")
logger.info(f" - 값 범위: [{np.min(feature_array):.6f}, {np.max(feature_array):.6f}]")
logger.info(f" - 평균값: {np.mean(feature_array):.6f}")
logger.info(f" - 표준편차: {np.std(feature_array):.6f}")
# MFCC 상세 정보
logger.info(f" - MFCC 계수 상세:")
for i, val in enumerate(feature_vector.mfcc):
logger.info(f" [{i + 1:2d}] {val:10.6f}")
# Chroma 상세 정보
logger.info(f" - Chroma 계수 상세:")
for i, val in enumerate(feature_vector.chroma):
logger.info(f" [{i + 1:2d}] {val:10.6f}")
def save_features_to_json(feature_vector, file_path):
"""특성을 JSON 파일로 저장"""
try:
feature_array = feature_vector.to_array()
feature_names = feature_vector.feature_names
feature_data = {
"timestamp": datetime.now().isoformat(),
"source_file": os.path.basename(file_path),
"feature_count": len(feature_array),
"features": {
name: float(value) for name, value in zip(feature_names, feature_array)
},
"statistics": {
"min": float(np.min(feature_array)),
"max": float(np.max(feature_array)),
"mean": float(np.mean(feature_array)),
"std": float(np.std(feature_array)),
"median": float(np.median(feature_array))
}
}
json_filename = f"features_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(json_filename, 'w', encoding='utf-8') as f:
json.dump(feature_data, f, indent=2, ensure_ascii=False)
logger.info(f"💾 특성 데이터 저장됨: {json_filename}")
return json_filename
except Exception as e:
logger.warning(f"특성 JSON 저장 실패: {e}")
return None
# 서버 상태 확인
@app.get("/health")
def health_check():
return {
"status": "healthy",
"message": "수박 당도 예측 서버가 실행 중입니다",
"model_loaded": model is not None,
"timestamp": datetime.now().isoformat()
}
# 지원되는 파일 형식 정보
@app.get("/supported-formats")
def get_supported_formats():
return {
"formats": [".wav", ".m4a", ".mp3"],
"description": "지원되는 오디오 파일 형식"
}
# 특성 정보 API (디버깅용)
@app.get("/feature-info")
def get_feature_info():
return {
"total_features": DEFAULT_CONFIG.n_mfcc + 5 + DEFAULT_CONFIG.n_chroma, # 13 + 5 + 12 = 30
"feature_groups": {
"mfcc": {"count": DEFAULT_CONFIG.n_mfcc, "description": "Mel-frequency cepstral coefficients"},
"mel_spectrogram": {"count": 2, "description": "멜 스펙트로그램의 통계 (평균, 표준편차)"},
"spectral": {"count": 2, "description": "스펙트럼 특징 (중심, 롤오프)"},
"zcr": {"count": 1, "description": "Zero crossing rate"},
"chroma": {"count": DEFAULT_CONFIG.n_chroma, "description": "Chroma features"}
},
"model_type": model.__class__.__name__ if model else "모델 로드 실패",
"classes": [int(c) for c in model.classes_] if hasattr(model, 'classes_') else ["알 수 없음"]
}
# 디버깅용 특성 분석 API
@app.post("/debug-features")
async def debug_features(file: UploadFile = File(...)):
"""특성 추출 디버깅 전용 API"""
temp_path = None
try:
# 파일 확장자 검사
ext = os.path.splitext(file.filename)[-1].lower()
if ext not in [".wav", ".m4a", ".mp3"]:
return JSONResponse(
status_code=400,
content={"error": "지원하지 않는 파일 형식입니다. .wav, .m4a, .mp3 파일만 업로드 가능합니다."}
)
# 임시 파일 저장
temp_path = f"debug_{uuid.uuid4()}{ext}"
with open(temp_path, "wb") as f:
f.write(await file.read())
logger.info(f"디버깅용 파일 저장: {file.filename} -> {temp_path}")
# 특성 추출
feature_vector = extract_features(temp_path)
if feature_vector is None:
return JSONResponse(
status_code=400,
content={"error": "오디오 파일에서 특성 추출에 실패했습니다."}
)
# 특성 상세 정보 로깅
log_feature_details(feature_vector)
# 특성 저장
json_filename = save_features_to_json(feature_vector, temp_path)
# 결과 반환
feature_array = feature_vector.to_array()
feature_names = feature_vector.feature_names
result = {
"success": True,
"filename": file.filename,
"feature_count": len(feature_array),
"features": {
name: float(value) for name, value in zip(feature_names, feature_array)
},
"statistics": {
"min": float(np.min(feature_array)),
"max": float(np.max(feature_array)),
"mean": float(np.mean(feature_array)),
"std": float(np.std(feature_array)),
"median": float(np.median(feature_array))
},
"quality_check": {
"nan_count": int(np.sum(np.isnan(feature_array))),
"inf_count": int(np.sum(np.isinf(feature_array))),
"is_valid": bool(np.all(np.isfinite(feature_array)))
}
}
if json_filename:
result["saved_to"] = json_filename
return JSONResponse(content=result)
except Exception as e:
logger.error(f"디버깅 중 오류 발생: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": f"디버깅 중 오류가 발생했습니다: {str(e)}"}
)
finally:
# 임시 파일 삭제
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
logger.info(f"임시 파일 삭제: {temp_path}")
except Exception as e:
logger.warning(f"임시 파일 삭제 실패: {e}")
# 예측 API
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
"""오디오 파일을 분석하여 수박 당도를 예측"""
if model is None:
return JSONResponse(
status_code=500,
content={"success": False, "error": "모델이 로드되지 않았습니다."}
)
temp_path = None
try:
# 파일 확장자 검사
ext = os.path.splitext(file.filename)[-1].lower()
if ext not in [".wav", ".m4a", ".mp3"]:
return JSONResponse(
status_code=400,
content={
"success": False,
"error": "지원하지 않는 파일 형식입니다. .wav, .m4a, .mp3 파일만 업로드 가능합니다."
}
)
# 임시 파일 저장
temp_path = f"temp_{uuid.uuid4()}{ext}"
with open(temp_path, "wb") as f:
f.write(await file.read())
logger.info(f"파일 업로드 완료: {file.filename} -> {temp_path}")
# 특성 추출
feature_vector = extract_features(temp_path)
if feature_vector is None:
return JSONResponse(
status_code=400,
content={"success": False, "error": "오디오 파일에서 특성 추출에 실패했습니다."}
)
feature_array = feature_vector.to_array().reshape(1, -1)
# 디버깅: 모델이 기대하는 feature 개수, 실제 feature vector 정보 출력
logger.info(f"[predict] model.n_features_in_: {getattr(model, 'n_features_in_', None)}")
logger.info(f"[predict] feature_array.shape: {feature_array.shape}")
logger.info(f"[predict] feature_array dtype: {feature_array.dtype}")
# 예측
prediction = model.predict(feature_array)[0]
probability = model.predict_proba(feature_array)[0] if hasattr(model, 'predict_proba') else None
# 예측 클래스 이름 (모델 클래스에서 가져오기)
if hasattr(model, 'classes_'):
predicted_class = int(model.classes_[prediction])
else:
predicted_class = f"class_{prediction}"
logger.info(f"🎯 예측 결과: {prediction} (클래스: {predicted_class})")
if probability is not None:
logger.info(f" - 확률 분포: {probability}")
logger.info(f" - 최대 확률: {max(probability):.4f}")
# 결과 반환
result = {
"success": True,
"filename": file.filename,
"prediction": int(prediction),
"predicted_class": predicted_class,
"confidence": float(max(probability)) if probability is not None else None
}
# 확률 분포 정보 추가 (클래스별)
if probability is not None and hasattr(model, 'classes_'):
result["probabilities"] = {
str(int(cls)): float(prob) for cls, prob in zip(model.classes_, probability)
}
# NumPy 인코더를 사용하여 JSON 직렬화
return JSONResponse(content=json.loads(json.dumps(result, cls=NumpyEncoder)))
except Exception as e:
logger.error(f"예측 중 오류 발생: {str(e)}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"success": False,
"error": f"예측 중 오류가 발생했습니다: {str(e)}"
}
)
finally:
# 임시 파일 삭제
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
logger.info(f"임시 파일 삭제: {temp_path}")
except Exception as e:
logger.warning(f"임시 파일 삭제 실패: {e}")
if __name__ == "__main__":
import socket
def get_local_ip():
"""로컬 네트워크 IP 주소를 반환"""
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except:
return "127.0.0.1"
# SSL 인증서 경로 설정
SSL_KEYFILE = "/etc/letsencrypt/live/www.singyupark.cloud/privkey.pem" # 개인 키 파일 경로
SSL_CERTFILE = "/etc/letsencrypt/live/www.singyupark.cloud/cert.pem" # 인증서 파일 경로
# SSL 인증서 파일이 없으면 생성 스크립트 실행 안내
if not os.path.exists(SSL_KEYFILE) or not os.path.exists(SSL_CERTFILE):
print("⚠️ SSL 인증서가 없습니다. 다음 명령을 실행하여 생성하세요:")
print(" ./generate_ssl_cert.sh")
print()
print("또는 직접 생성:")
print(" openssl req -x509 -newkey rsa:4096 -keyout ssl/key.pem -out ssl/cert.pem -days 365 -nodes")
sys.exit(1)
# 실행 안내 정보 출력
local_ip = get_local_ip()
print(f"🍉 수박 당도 예측 서버 시작 (HTTPS)")
print(f" - 로컬: https://localhost:9001")
print(f" - 네트워크: https://{local_ip}:9001")
print(f" - 상태 확인: https://{local_ip}:9001/health")
print(f" - 특성 디버깅: https://{local_ip}:9001/debug-features")
print(f" - 예측 API: https://{local_ip}:9001/predict")
print()
print("⚠️ 자체 서명 인증서 사용 시 브라우저에서 보안 경고가 나타날 수 있습니다.")
# HTTPS로 서버 실행
uvicorn.run(
app,
host="0.0.0.0",
port=9001,
ssl_keyfile=SSL_KEYFILE,
ssl_certfile=SSL_CERTFILE,
ssl_version=ssl.PROTOCOL_TLS_SERVER,
ssl_cert_reqs=ssl.CERT_NONE,
ssl_ciphers="TLSv1.2:!aNULL:!eNULL:!EXPORT:!DES:!MD5:!PSK:!RC4"
)