-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualization.py
More file actions
50 lines (40 loc) · 1.57 KB
/
visualization.py
File metadata and controls
50 lines (40 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from google.cloud import aiplatform
from google.auth import default
# Initialize Vertex AI
def initialize_vertex_ai(project=None, location=None):
    """Configure the Vertex AI SDK with Application Default Credentials.

    Args:
        project: Optional Google Cloud project ID. When omitted, the project
            reported by ``google.auth.default()`` is used (which may be
            ``None``, leaving the SDK to resolve it itself).
        location: Optional Google Cloud region (e.g. ``'us-central1'``).
            When omitted, ``aiplatform.init`` keeps its own default.

    NOTE(review): clients built directly from ``aiplatform.gapic`` may not
    read this global SDK configuration; confirm for your usage.
    """
    # default() returns (credentials, project_id). Keep the project instead
    # of discarding it, so the SDK does not have to rediscover it.
    credentials, detected_project = default()
    aiplatform.init(
        project=project or detected_project,
        location=location,
        credentials=credentials,
    )
# Function to call the Gemini model
def call_gemini_model(
    prompt: str,
    *,
    model_id: str = 'your-gemini-model-id',
    project_id: str = 'your-project-id',
    region: str = 'your-region',
):
    """Send *prompt* to a Vertex AI endpoint, print and return the predictions.

    Args:
        prompt: Text placed under the ``"content"`` key of the single
            prediction instance.
        model_id: ID used as the endpoint segment of the resource name.
            Replace the placeholder with your own value.
        project_id: Google Cloud project that owns the endpoint.
        region: Google Cloud region of the endpoint (e.g. ``'us-central1'``).

    Returns:
        The ``predictions`` field of the response on success, or ``None`` if
        the ``predict`` call raised (the error is printed, not re-raised).

    Raises:
        Exceptions raised while constructing the client propagate; only
        errors from the ``predict`` call itself are caught.
    """
    # NOTE(review): this addresses `endpoints/{model_id}`, i.e. a model you
    # deployed to your own endpoint. Gemini is usually reached as a Google
    # publisher model via generateContent instead, and the {"content": ...}
    # instance schema depends on the deployed model; confirm both.
    endpoint = f"projects/{project_id}/locations/{region}/endpoints/{model_id}"
    # Endpoint resources are regional, so the client must talk to the
    # regional API host rather than the client's default global host.
    client_options = {"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    instances = [{"content": prompt}]
    parameters = {}

    # Using the client as a context manager closes its transport afterwards.
    with aiplatform.gapic.PredictionServiceClient(
        client_options=client_options
    ) as client:
        try:
            response = client.predict(
                endpoint=endpoint,
                instances=instances,
                parameters=parameters,
            )
        except Exception as e:
            # Deliberately best-effort, as before: report the failure and let
            # the caller continue.
            print("Error during prediction:", e)
            return None

    predictions = response.predictions
    print("Predictions:", predictions)
    return predictions
if __name__ == "__main__":
    # Demo run: configure the SDK first, then send a single sample question.
    initialize_vertex_ai()
    call_gemini_model("What is the capital of France?")