11import uuid
22from datetime import datetime
3- from azure . cosmos . aio import CosmosClient
3+
44from azure .cosmos import exceptions
5-
6- class CosmosConversationClient ():
7-
8- def __init__ (self , cosmosdb_endpoint : str , credential : any , database_name : str , container_name : str , enable_message_feedback : bool = False ):
5+ from azure .cosmos .aio import CosmosClient
6+
7+
8+ class CosmosConversationClient :
9+ def __init__ (
10+ self ,
11+ cosmosdb_endpoint : str ,
12+ credential : any ,
13+ database_name : str ,
14+ container_name : str ,
15+ enable_message_feedback : bool = False ,
16+ ):
917 self .cosmosdb_endpoint = cosmosdb_endpoint
1018 self .credential = credential
1119 self .database_name = database_name
1220 self .container_name = container_name
1321 self .enable_message_feedback = enable_message_feedback
1422 try :
15- self .cosmosdb_client = CosmosClient (self .cosmosdb_endpoint , credential = credential )
23+ self .cosmosdb_client = CosmosClient (
24+ self .cosmosdb_endpoint , credential = credential
25+ )
1626 except exceptions .CosmosHttpResponseError as e :
1727 if e .status_code == 401 :
1828 raise ValueError ("Invalid credentials" ) from e
1929 else :
2030 raise ValueError ("Invalid CosmosDB endpoint" ) from e
2131
2232 try :
23- self .database_client = self .cosmosdb_client .get_database_client (database_name )
33+ self .database_client = self .cosmosdb_client .get_database_client (
34+ database_name
35+ )
2436 except exceptions .CosmosResourceNotFoundError :
25- raise ValueError ("Invalid CosmosDB database name" )
26-
37+ raise ValueError ("Invalid CosmosDB database name" )
38+
2739 try :
28- self .container_client = self .database_client .get_container_client (container_name )
40+ self .container_client = self .database_client .get_container_client (
41+ container_name
42+ )
2943 except exceptions .CosmosResourceNotFoundError :
30- raise ValueError ("Invalid CosmosDB container name" )
31-
44+ raise ValueError ("Invalid CosmosDB container name" )
3245
3346 async def ensure (self ):
34- if not self .cosmosdb_client or not self .database_client or not self .container_client :
47+ if (
48+ not self .cosmosdb_client
49+ or not self .database_client
50+ or not self .container_client
51+ ):
3552 return False , "CosmosDB client not initialized correctly"
3653 try :
3754 database_info = await self .database_client .read ()
38- except :
39- return False , f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found"
40-
55+ except Exception :
56+ return (
57+ False ,
58+ f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found" ,
59+ )
60+
4161 try :
4262 container_info = await self .container_client .read ()
43- except :
63+ except Exception :
4464 return False , f"CosmosDB container { self .container_name } not found"
45-
65+
4666 return True , "CosmosDB client initialized successfully"
4767
48- async def create_conversation (self , user_id , title = '' ):
68+ async def create_conversation (self , user_id , title = "" ):
4969 conversation = {
50- 'id' : str (uuid .uuid4 ()),
51- ' type' : ' conversation' ,
52- ' createdAt' : datetime .utcnow ().isoformat (),
53- ' updatedAt' : datetime .utcnow ().isoformat (),
54- ' userId' : user_id ,
55- ' title' : title
70+ "id" : str (uuid .uuid4 ()),
71+ " type" : " conversation" ,
72+ " createdAt" : datetime .utcnow ().isoformat (),
73+ " updatedAt" : datetime .utcnow ().isoformat (),
74+ " userId" : user_id ,
75+ " title" : title ,
5676 }
57- ## TODO: add some error handling based on the output of the upsert_item call
58- resp = await self .container_client .upsert_item (conversation )
77+ # TODO: add some error handling based on the output of the upsert_item call
78+ resp = await self .container_client .upsert_item (conversation )
5979 if resp :
6080 return resp
6181 else :
6282 return False
63-
83+
6484 async def upsert_conversation (self , conversation ):
6585 resp = await self .container_client .upsert_item (conversation )
6686 if resp :
@@ -69,115 +89,109 @@ async def upsert_conversation(self, conversation):
6989 return False
7090
7191 async def delete_conversation (self , user_id , conversation_id ):
72- conversation = await self .container_client .read_item (item = conversation_id , partition_key = user_id )
92+ conversation = await self .container_client .read_item (
93+ item = conversation_id , partition_key = user_id
94+ )
7395 if conversation :
74- resp = await self .container_client .delete_item (item = conversation_id , partition_key = user_id )
96+ resp = await self .container_client .delete_item (
97+ item = conversation_id , partition_key = user_id
98+ )
7599 return resp
76100 else :
77101 return True
78102
79-
80103 async def delete_messages (self , conversation_id , user_id ):
81- ## get a list of all the messages in the conversation
104+ # get a list of all the messages in the conversation
82105 messages = await self .get_messages (user_id , conversation_id )
83106 response_list = []
84107 if messages :
85108 for message in messages :
86- resp = await self .container_client .delete_item (item = message ['id' ], partition_key = user_id )
109+ resp = await self .container_client .delete_item (
110+ item = message ["id" ], partition_key = user_id
111+ )
87112 response_list .append (resp )
88113 return response_list
89114
90-
91- async def get_conversations (self , user_id , limit , sort_order = 'DESC' , offset = 0 ):
92- parameters = [
93- {
94- 'name' : '@userId' ,
95- 'value' : user_id
96- }
97- ]
115+ async def get_conversations (self , user_id , limit , sort_order = "DESC" , offset = 0 ):
116+ parameters = [{"name" : "@userId" , "value" : user_id }]
98117 query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt { sort_order } "
99118 if limit is not None :
100- query += f" offset { offset } limit { limit } "
101-
119+ query += f" offset { offset } limit { limit } "
120+
102121 conversations = []
103- async for item in self .container_client .query_items (query = query , parameters = parameters ):
122+ async for item in self .container_client .query_items (
123+ query = query , parameters = parameters
124+ ):
104125 conversations .append (item )
105-
126+
106127 return conversations
107128
108129 async def get_conversation (self , user_id , conversation_id ):
109130 parameters = [
110- {
111- 'name' : '@conversationId' ,
112- 'value' : conversation_id
113- },
114- {
115- 'name' : '@userId' ,
116- 'value' : user_id
117- }
131+ {"name" : "@conversationId" , "value" : conversation_id },
132+ {"name" : "@userId" , "value" : user_id },
118133 ]
119- query = f "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
134+ query = "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
120135 conversations = []
121- async for item in self .container_client .query_items (query = query , parameters = parameters ):
136+ async for item in self .container_client .query_items (
137+ query = query , parameters = parameters
138+ ):
122139 conversations .append (item )
123140
124- ## if no conversations are found, return None
141+ # if no conversations are found, return None
125142 if len (conversations ) == 0 :
126143 return None
127144 else :
128145 return conversations [0 ]
129-
146+
130147 async def create_message (self , uuid , conversation_id , user_id , input_message : dict ):
131148 message = {
132- 'id' : uuid ,
133- ' type' : ' message' ,
134- ' userId' : user_id ,
135- ' createdAt' : datetime .utcnow ().isoformat (),
136- ' updatedAt' : datetime .utcnow ().isoformat (),
137- ' conversationId' : conversation_id ,
138- ' role' : input_message [' role' ],
139- ' content' : input_message [' content' ]
149+ "id" : uuid ,
150+ " type" : " message" ,
151+ " userId" : user_id ,
152+ " createdAt" : datetime .utcnow ().isoformat (),
153+ " updatedAt" : datetime .utcnow ().isoformat (),
154+ " conversationId" : conversation_id ,
155+ " role" : input_message [" role" ],
156+ " content" : input_message [" content" ],
140157 }
141158
142159 if self .enable_message_feedback :
143- message [' feedback' ] = ''
144-
145- resp = await self .container_client .upsert_item (message )
160+ message [" feedback" ] = ""
161+
162+ resp = await self .container_client .upsert_item (message )
146163 if resp :
147- ## update the parent conversations's updatedAt field with the current message's createdAt datetime value
164+ # update the parent conversations's updatedAt field with the current message's createdAt datetime value
148165 conversation = await self .get_conversation (user_id , conversation_id )
149166 if not conversation :
150167 return "Conversation not found"
151- conversation [' updatedAt' ] = message [' createdAt' ]
168+ conversation [" updatedAt" ] = message [" createdAt" ]
152169 await self .upsert_conversation (conversation )
153170 return resp
154171 else :
155172 return False
156-
173+
157174 async def update_message_feedback (self , user_id , message_id , feedback ):
158- message = await self .container_client .read_item (item = message_id , partition_key = user_id )
175+ message = await self .container_client .read_item (
176+ item = message_id , partition_key = user_id
177+ )
159178 if message :
160- message [' feedback' ] = feedback
179+ message [" feedback" ] = feedback
161180 resp = await self .container_client .upsert_item (message )
162181 return resp
163182 else :
164183 return False
165184
166185 async def get_messages (self , user_id , conversation_id ):
167186 parameters = [
168- {
169- 'name' : '@conversationId' ,
170- 'value' : conversation_id
171- },
172- {
173- 'name' : '@userId' ,
174- 'value' : user_id
175- }
187+ {"name" : "@conversationId" , "value" : conversation_id },
188+ {"name" : "@userId" , "value" : user_id },
176189 ]
177- query = f "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
190+ query = "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
178191 messages = []
179- async for item in self .container_client .query_items (query = query , parameters = parameters ):
192+ async for item in self .container_client .query_items (
193+ query = query , parameters = parameters
194+ ):
180195 messages .append (item )
181196
182197 return messages
183-
0 commit comments