-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
64 lines (48 loc) · 1.99 KB
/
main.py
File metadata and controls
64 lines (48 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging
from fastapi import FastAPI, Depends, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import ngram
# Configure the root logger at INFO level; this runs at import time.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the endpoints in this file.
logger = logging.getLogger(__name__)
class GeneratedResponse(BaseModel):
    """Response body returned by the ``/rest/get-text`` endpoint."""

    # The generated continuation, or an error message when generation failed.
    generatedText: str
    # False when the endpoint substituted an error message for the text.
    success: bool
def load_model():
    """FastAPI dependency that opens the n-gram model for one request.

    Loads the model from ``model/3gram_model_bytes_ascii.cdb`` with
    ``ngram.load_model``, yields it to the endpoint, and closes it when the
    request has finished -- including when the endpoint raised.

    Yields:
        The object returned by ``ngram.load_model``.
    """
    # NOTE(review): the WebSocket endpoint loads "model/3gram_model_bytes.cdb"
    # (without the "_ascii" suffix); confirm which file both paths should use.
    # An alternative loader, ngram.load_pickle_model("model/3gram_model.pkl"),
    # was previously used here; if it is reinstated, only load pickle files
    # from a trusted source.
    model = ngram.load_model("model/3gram_model_bytes_ascii.cdb")
    try:
        yield model
    finally:
        # An exception raised by the endpoint is thrown into this generator at
        # the ``yield``; ``finally`` guarantees the model is still closed.
        model.close()
# ASGI application; the REST and WebSocket routes below are registered on it.
app = FastAPI()
@app.get("/rest/get-text")
def get_text(input_text: str, num_words: int = 50, model=Depends(load_model)) -> GeneratedResponse:
    """Generate text continuing ``input_text`` in a single response.

    Args:
        input_text: Seed text; ``ngram.get_gram(..., n=3, append=True)``
            turns it into the starting gram.
        num_words: Forwarded to ``ngram.generate_text`` as ``num_tokens``.
        model: Model injected by the ``load_model`` dependency.

    Returns:
        A ``GeneratedResponse`` with the generated text and ``success=True``,
        or an error message and ``success=False`` when the generated result
        is empty/falsy.
    """
    seed = ngram.get_gram(input_text, n=3, append=True)
    text = ngram.generate_text(seed, model, num_tokens=num_words)
    ok = bool(text)
    return GeneratedResponse(
        generatedText=text if ok else "[ERROR] I didn't understand that.",
        success=ok,
    )
# The route path must start with "/": request paths always do, so the
# previous "ws/get-text" compiled to a pattern that could never match.
@app.websocket("/ws/get-text")
async def get_text_ws(websocket: WebSocket):
    """Stream generated text to a WebSocket client, one token per message.

    For each text message received, a starting gram is built with
    ``ngram.get_gram`` and every token produced by ``ngram.token_generator``
    is sent back followed by a space. When the generator is exhausted,
    ``"[DONE]"`` is sent. If it yields ``None``, an error message is sent
    instead and generation for that message stops. The loop runs until the
    client disconnects.

    Args:
        websocket: The client connection; it is accepted before the model
            is loaded, and the model is closed when the handler exits.
    """
    await websocket.accept()
    # NOTE(review): load_model uses "model/3gram_model_bytes_ascii.cdb";
    # confirm which model file this endpoint is meant to use.
    # NOTE(review): ngram.load_model and ngram.token_generator are called
    # synchronously inside this coroutine and so block the event loop while
    # they run; consider a threadpool if they are slow.
    model = ngram.load_model("model/3gram_model_bytes.cdb")
    try:
        while True:
            data = await websocket.receive_text()
            current_gram = ngram.get_gram(data, n=3, append=True)
            logger.info("Current gram: %s", current_gram)
            logger.info("model type: %s, length: %d", type(model), len(model))
            for next_word in ngram.token_generator(current_gram, model):
                if next_word is None:
                    # A None token is treated as "cannot continue".
                    await websocket.send_text("[ERROR] I didn't understand that.")
                    break
                await websocket.send_text(next_word + " ")
            else:
                # for-else: only reached when the loop finished without break.
                await websocket.send_text("[DONE]")
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    finally:
        # Release the per-connection model however the handler exits.
        model.close()
# Serve the files in ./client/dist/ at the site root (presumably the built
# front-end). NOTE(review): this "/" mount matches every path and is added
# after the API routes, presumably so it does not shadow them -- keep it last.
app.mount("/", StaticFiles(directory="./client/dist/", html=True), name="client")