11import uuid
22from datetime import datetime
3- from azure . cosmos . aio import CosmosClient
3+
44from azure .cosmos import exceptions
5-
5+ from azure .cosmos .aio import CosmosClient
6+
7+
68class CosmosConversationClient ():
7-
9+
810 def __init__ (self , cosmosdb_endpoint : str , credential : any , database_name : str , container_name : str , enable_message_feedback : bool = False ):
911 self .cosmosdb_endpoint = cosmosdb_endpoint
1012 self .credential = credential
1113 self .database_name = database_name
1214 self .container_name = container_name
1315 self .enable_message_feedback = enable_message_feedback
1416 try :
15- self .cosmosdb_client = CosmosClient (self .cosmosdb_endpoint , credential = credential )
17+ self .cosmosdb_client = CosmosClient (
18+ self .cosmosdb_endpoint , credential = credential )
1619 except exceptions .CosmosHttpResponseError as e :
1720 if e .status_code == 401 :
1821 raise ValueError ("Invalid credentials" ) from e
1922 else :
2023 raise ValueError ("Invalid CosmosDB endpoint" ) from e
2124
2225 try :
23- self .database_client = self .cosmosdb_client .get_database_client (database_name )
26+ self .database_client = self .cosmosdb_client .get_database_client (
27+ database_name )
2428 except exceptions .CosmosResourceNotFoundError :
25- raise ValueError ("Invalid CosmosDB database name" )
26-
29+ raise ValueError ("Invalid CosmosDB database name" )
30+
2731 try :
28- self .container_client = self .database_client .get_container_client (container_name )
32+ self .container_client = self .database_client .get_container_client (
33+ container_name )
2934 except exceptions .CosmosResourceNotFoundError :
30- raise ValueError ("Invalid CosmosDB container name" )
31-
35+ raise ValueError ("Invalid CosmosDB container name" )
3236
3337 async def ensure (self ):
3438 if not self .cosmosdb_client or not self .database_client or not self .container_client :
@@ -37,30 +41,30 @@ async def ensure(self):
3741 database_info = await self .database_client .read ()
3842 except :
3943 return False , f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found"
40-
44+
4145 try :
4246 container_info = await self .container_client .read ()
4347 except :
4448 return False , f"CosmosDB container { self .container_name } not found"
45-
49+
4650 return True , "CosmosDB client initialized successfully"
4751
48- async def create_conversation (self , user_id , title = '' ):
52+ async def create_conversation (self , user_id , title = '' ):
4953 conversation = {
50- 'id' : str (uuid .uuid4 ()),
54+ 'id' : str (uuid .uuid4 ()),
5155 'type' : 'conversation' ,
52- 'createdAt' : datetime .utcnow ().isoformat (),
53- 'updatedAt' : datetime .utcnow ().isoformat (),
56+ 'createdAt' : datetime .utcnow ().isoformat (),
57+ 'updatedAt' : datetime .utcnow ().isoformat (),
5458 'userId' : user_id ,
5559 'title' : title
5660 }
57- ## TODO: add some error handling based on the output of the upsert_item call
58- resp = await self .container_client .upsert_item (conversation )
61+ # TODO: add some error handling based on the output of the upsert_item call
62+ resp = await self .container_client .upsert_item (conversation )
5963 if resp :
6064 return resp
6165 else :
6266 return False
63-
67+
6468 async def upsert_conversation (self , conversation ):
6569 resp = await self .container_client .upsert_item (conversation )
6670 if resp :
@@ -69,16 +73,15 @@ async def upsert_conversation(self, conversation):
6973 return False
7074
7175 async def delete_conversation (self , user_id , conversation_id ):
72- conversation = await self .container_client .read_item (item = conversation_id , partition_key = user_id )
76+ conversation = await self .container_client .read_item (item = conversation_id , partition_key = user_id )
7377 if conversation :
7478 resp = await self .container_client .delete_item (item = conversation_id , partition_key = user_id )
7579 return resp
7680 else :
7781 return True
7882
79-
8083 async def delete_messages (self , conversation_id , user_id ):
81- ## get a list of all the messages in the conversation
84+ # get a list of all the messages in the conversation
8285 messages = await self .get_messages (user_id , conversation_id )
8386 response_list = []
8487 if messages :
@@ -87,8 +90,7 @@ async def delete_messages(self, conversation_id, user_id):
8790 response_list .append (resp )
8891 return response_list
8992
90-
91- async def get_conversations (self , user_id , limit , sort_order = 'DESC' , offset = 0 ):
93+ async def get_conversations (self , user_id , limit , sort_order = 'DESC' , offset = 0 ):
9294 parameters = [
9395 {
9496 'name' : '@userId' ,
@@ -97,12 +99,12 @@ async def get_conversations(self, user_id, limit, sort_order = 'DESC', offset =
9799 ]
98100 query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt { sort_order } "
99101 if limit is not None :
100- query += f" offset { offset } limit { limit } "
101-
102+ query += f" offset { offset } limit { limit } "
103+
102104 conversations = []
103105 async for item in self .container_client .query_items (query = query , parameters = parameters ):
104106 conversations .append (item )
105-
107+
106108 return conversations
107109
108110 async def get_conversation (self , user_id , conversation_id ):
@@ -121,30 +123,30 @@ async def get_conversation(self, user_id, conversation_id):
121123 async for item in self .container_client .query_items (query = query , parameters = parameters ):
122124 conversations .append (item )
123125
124- ## if no conversations are found, return None
126+ # if no conversations are found, return None
125127 if len (conversations ) == 0 :
126128 return None
127129 else :
128130 return conversations [0 ]
129-
131+
130132 async def create_message (self , uuid , conversation_id , user_id , input_message : dict ):
131133 message = {
132134 'id' : uuid ,
133135 'type' : 'message' ,
134- 'userId' : user_id ,
136+ 'userId' : user_id ,
135137 'createdAt' : datetime .utcnow ().isoformat (),
136138 'updatedAt' : datetime .utcnow ().isoformat (),
137- 'conversationId' : conversation_id ,
139+ 'conversationId' : conversation_id ,
138140 'role' : input_message ['role' ],
139141 'content' : input_message ['content' ]
140142 }
141143
142144 if self .enable_message_feedback :
143145 message ['feedback' ] = ''
144-
145- resp = await self .container_client .upsert_item (message )
146+
147+ resp = await self .container_client .upsert_item (message )
146148 if resp :
147- ## update the parent conversations's updatedAt field with the current message's createdAt datetime value
149+ # update the parent conversations's updatedAt field with the current message's createdAt datetime value
148150 conversation = await self .get_conversation (user_id , conversation_id )
149151 if not conversation :
150152 return "Conversation not found"
@@ -153,7 +155,7 @@ async def create_message(self, uuid, conversation_id, user_id, input_message: di
153155 return resp
154156 else :
155157 return False
156-
158+
157159 async def update_message_feedback (self , user_id , message_id , feedback ):
158160 message = await self .container_client .read_item (item = message_id , partition_key = user_id )
159161 if message :
@@ -180,4 +182,3 @@ async def get_messages(self, user_id, conversation_id):
180182 messages .append (item )
181183
182184 return messages
183-
0 commit comments