|
     "~/.oci/config", profile_name=profile
 )  # replace with the location of your oci config file
 
-model = os.environ.get("MODEL", "meta-llama/Llama-2-7b-chat-hf")
+model = os.environ.get("MODEL", "mistralai/Mistral-7B-Instruct-v0.1")
 template_file = app_config["models"][model].get("template")
 prompt_template = string.Template(
     open(template_file).read() if template_file else "$prompt"
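
The context lines above are the prompt-templating setup: each model entry in the app config can point at a template file, and the user prompt is spliced in through `string.Template`'s `$prompt` placeholder. A minimal sketch of that substitution, assuming a Mistral-style `[INST]` template (the actual template file is not shown in this diff, so its contents here are an assumption):

```python
import string

# Hypothetical template text; a real template file for Mistral-7B-Instruct
# would look roughly like this, but the exact contents are an assumption.
template_text = "<s>[INST] $prompt [/INST]"

prompt_template = string.Template(template_text)
print(prompt_template.substitute({"prompt": "Summarize OCI Data Science in one line."}))
# <s>[INST] Summarize OCI Data Science in one line. [/INST]
```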
@@ -94,11 +94,35 @@ def query(prompt, max_tokens=200, **kwargs): |
         },
     }
 
+    if os.environ.get("VLLM"):
+        if os.environ.get("API_SPEC") == "openai":
+            temperature = kwargs.get("temperature", 0.7)
+            top_p = kwargs.get("top_p", 0.8)
+            body = {
+                "prompt": prompt_template.substitute({"prompt": prompt}),
+                "max_tokens": max_tokens,
+                "model": model,
+                "temperature": temperature,
+                "top_p": top_p,
+            }
+        else:
+            body["parameters"].pop("watermark", None)
+            body["parameters"].pop("seed", None)
+            body["parameters"].pop("return_full_text", None)
+
     # create auth using one of the oci signers
     auth = create_default_signer()
     data = requests.post(endpoint, json=body, auth=auth, headers=headers).json()
     # return model generated response, or any error as a string
-    return str(data.get("generated_text", data))
+    if os.environ.get("VLLM"):
+        if os.environ.get("API_SPEC") == "openai":
+            # OpenAI-compatible responses nest the text under choices[0]["text"]
+            response = data.get("choices", [data])[0]
+            response = response.get("text", data)
+        else:
+            response = data.get("generated_text", data)
+    else:
+        # non-vLLM (TGI) deployments keep the original response shape
+        response = data.get("generated_text", data)
+    return str(response)
 
 
 if __name__ == "__main__":
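
The two branches added above handle two request/response shapes: the OpenAI-compatible spec takes a flat `{prompt, max_tokens, model, temperature, top_p}` body and nests the generated text under `choices[0]["text"]`, while the existing TGI-style path keeps the `inputs`/`parameters` body and returns `generated_text` at the top level. A self-contained sketch of the response parsing, using illustrative payloads rather than captured server output:

```python
# Illustrative payloads only; mirrors the branching added to query() above.
openai_style = {"choices": [{"text": " OCI offers compute, storage, and AI services."}]}
tgi_style = {"generated_text": " OCI offers compute, storage, and AI services."}

def extract_text(data, vllm=True, api_spec="openai"):
    """Pull the generated text out of either response shape."""
    if vllm and api_spec == "openai":
        choice = data.get("choices", [data])[0]
        return str(choice.get("text", data))
    return str(data.get("generated_text", data))

assert extract_text(openai_style) == extract_text(tgi_style, vllm=False)
```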
|