66package org .springframework .ai .openai .samples .helloworld ;
77
88import org .springframework .ai .chat .client .ChatClient ;
9- import org .springframework .ai .chat .model . ChatResponse ;
9+ import org .springframework .ai .chat .client . ChatClient . ChatClientRequestSpec ;
1010import org .springframework .ai .chat .prompt .Prompt ;
1111import org .springframework .ai .chat .prompt .PromptTemplate ;
1212import org .springframework .ai .document .Document ;
1313import org .springframework .ai .embedding .EmbeddingModel ;
14+ //import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
1415import org .springframework .ai .reader .ExtractedTextFormatter ;
1516import org .springframework .ai .reader .pdf .PagePdfDocumentReader ;
1617import org .springframework .ai .reader .pdf .config .PdfDocumentReaderConfig ;
2526import org .springframework .web .bind .annotation .RequestBody ;
2627import org .springframework .web .bind .annotation .RequestParam ;
2728import org .springframework .web .bind .annotation .RestController ;
29+ import org .springframework .web .servlet .mvc .method .annotation .ResponseBodyEmitter ;
30+
31+ import com .fasterxml .jackson .databind .ObjectMapper ;
32+
2833import org .springframework .ai .vectorstore .oracle .OracleVectorStore ;
2934
3035import jakarta .annotation .PostConstruct ;
3136
3237import org .springframework .core .io .Resource ;
38+ import org .springframework .http .MediaType ;
3339import org .springframework .jdbc .core .JdbcTemplate ;
3440
3541import java .io .IOException ;
3844import java .util .ArrayList ;
3945import java .util .Map ;
4046import java .util .HashMap ;
47+ import java .security .SecureRandom ;
48+ import java .time .Instant ;
49+
4150
4251import java .util .Iterator ;
4352import org .slf4j .Logger ;
4453import org .slf4j .LoggerFactory ;
4554
55+
56+ import org .springframework .model .*;
57+
4658@ RestController
4759class AIController {
4860
61+ @ Value ("${spring.ai.openai.chat.options.model}" )
62+ private String modelOpenAI ;
63+
64+ @ Value ("${spring.ai.ollama.chat.options.model}" )
65+ private String modelOllamaAI ;
66+
4967 @ Autowired
5068 private final OracleVectorStore vectorStore ;
5169
@@ -71,7 +89,8 @@ class AIController {
7189 private JdbcTemplate jdbcTemplate ;
7290
7391 private static final Logger logger = LoggerFactory .getLogger (AIController .class );
74-
92+ private static final int SLEEP = 50 ; // Wait in streaming between chunks
93+ private static final int STREAM_SIZE = 5 ; // chars in each chunk
7594 AIController (ChatClient chatClient , EmbeddingModel embeddingModel , OracleVectorStore vectorStore ) {
7695
7796 this .chatClient = chatClient ;
@@ -169,14 +188,16 @@ public Prompt promptEngineering(String message, String contextInstr) {
169188 INSTRUCTIONS:""" ;
170189
171190 String default_Instr = """
172- Answer the users question using the DOCUMENTS text above.
191+ Answer the users question using the DOCUMENTS text above.
173192 Keep your answer ground in the facts of the DOCUMENTS.
174193 If the DOCUMENTS doesn’t contain the facts to answer the QUESTION, return:
175194 I'm sorry but I haven't enough information to answer.
176195 """ ;
177196
178- //This template doesn't work with agent pattern, but only via RAG
179- //The contextInstr coming from AI Optimizer can't be used here: default only
197+ //This template doesn't work with the re-phrasing/grading pattern, only with plain RAG.
198+ //The contextInstr coming from the Oracle AI Optimizer and Toolkit can't be used here: the default instructions are always applied.
199+ //Modify it to include re-phrasing/grading if you wish.
200+
180201 template = template + "\n " + default_Instr ;
181202
182203 List <Document > similarDocuments = this .vectorStore .similaritySearch (
@@ -208,25 +229,70 @@ StringBuilder createContext(List<Document> similarDocuments) {
208229 return context ;
209230 }
210231
211- @ PostMapping ("/chat/completions" )
212- Map <String , Object > completionRag (@ RequestBody Map <String , String > requestBody ) {
213-
214- String message = requestBody .getOrDefault ("message" , "Tell me a joke" );
215- Prompt prompt = promptEngineering (message , contextInstr );
216- logger .info (prompt .getContents ());
217- try {
218- String content = chatClient .prompt (prompt ).call ().content ();
219- Map <String , Object > messageMap = Map .of ("content" , content );
220- Map <String , Object > choicesMap = Map .of ("message" , messageMap );
221- List <Map <String , Object >> choicesList = List .of (choicesMap );
222232
223- return Map .of ("choices" , choicesList );
233+ @ PostMapping (value = "/chat/completions" , produces = MediaType .TEXT_EVENT_STREAM_VALUE )
234+ public ResponseBodyEmitter streamCompletions (@ RequestBody ChatRequest request ) {
235+ ResponseBodyEmitter bodyEmitter = new ResponseBodyEmitter ();
236+ String userMessageContent ;
224237
225- } catch (Exception e ) {
226- logger .error ("Error while fetching completion" , e );
227- return Map .of ("error" , "Failed to fetch completion" );
238+ for (Map <String , String > message : request .getMessages ()) {
239+ if ("user" .equals (message .get ("role" ))) {
240+
241+ String content = message .get ("content" );
242+ if (content != null && !content .trim ().isEmpty ()) {
243+ userMessageContent = content ;
244+ logger .info ("user message: " +userMessageContent );
245+ Prompt prompt = promptEngineering (userMessageContent , contextInstr );
246+ logger .info ("prompt message: " +prompt .getContents ());
247+ String contentResponse = chatClient .prompt (prompt ).call ().content ();
248+ logger .info ("-------------------------------------------------------" );
249+ logger .info ("- RAG RETURN -" );
250+ logger .info ("-------------------------------------------------------" );
251+ logger .info (contentResponse );
252+ new Thread (() -> {
253+ try {
254+ ObjectMapper mapper = new ObjectMapper ();
255+
256+ if (request .isStream ()) {
257+ logger .info ("Request is a Stream" );
258+ List <String > chunks = chunkString (contentResponse );
259+ for (String token : chunks ) {
260+
261+ ChatMessage messageAnswer = new ChatMessage ("assistant" , token );
262+ ChatChoice choice = new ChatChoice (messageAnswer );
263+ ChatStreamResponse chunk = new ChatStreamResponse ("chat.completion.chunk" , new ChatChoice []{choice });
264+
265+ bodyEmitter .send ("data: " + mapper .writeValueAsString (chunk ) + "\n \n " );
266+ Thread .sleep (SLEEP );
267+ }
268+
269+ bodyEmitter .send ("data: [DONE]\n \n " );
270+ } else {
271+ logger .info ("Request isn't a Stream" );
272+ String id ="chatcmpl-" +generateRandomToken (28 );
273+ String object ="chat.completion" ;
274+ String created =String .valueOf (Instant .now ().getEpochSecond ());
275+ String model =getModel ();
276+ ChatMessage messageAnswer = new ChatMessage ("assistant" , contentResponse );
277+ List <ChatChoice > choices = List .of (new ChatChoice (messageAnswer ));
278+ bodyEmitter .send (new ChatResponse (id , object ,created , model , choices ));
279+ }
280+ bodyEmitter .complete ();
281+ } catch (Exception e ) {
282+ bodyEmitter .completeWithError (e );
283+ }
284+ }).start ();
285+
286+ return bodyEmitter ;
287+
288+ }
289+ break ;
228290 }
229291 }
292+
293+
294+ return bodyEmitter ;
295+ }
230296
231297 @ GetMapping ("/service/search" )
232298 List <Map <String , Object >> search (@ RequestParam (value = "message" , defaultValue = "Tell me a joke" ) String query ,
@@ -247,4 +313,77 @@ List<Map<String, Object>> search(@RequestParam(value = "message", defaultValue =
247313 ;
248314 return resultList ;
249315 }
316+
317+ @ GetMapping ("/models" )
318+ Map <String , Object > models (@ RequestBody (required = false ) Map <String , String > requestBody ) {
319+ String modelId = "custom" ;
320+ logger .info ("models request" );
321+ if (!"" .equals (modelOpenAI )) {
322+ modelId = modelOpenAI ;
323+ } else if (!"" .equals (modelOllamaAI )) {
324+ modelId = modelOllamaAI ;
325+ }
326+ logger .info ("model" );
327+
328+
329+ logger .info (chatClient .prompt ().toString ());
330+ try {
331+ Map <String , Object > model = new HashMap <>();
332+ model .put ("id" , modelId );
333+ model .put ("object" , "model" );
334+ model .put ("created" , 0000000000L );
335+ model .put ("owned_by" , "no-info" );
336+
337+ List <Map <String , Object >> dataList = new ArrayList <>();
338+ dataList .add (model );
339+
340+ Map <String , Object > response = new HashMap <>();
341+ response .put ("object" , "list" );
342+ response .put ("data" , dataList );
343+
344+ return response ;
345+
346+ } catch (Exception e ) {
347+ logger .error ("Error while fetching completion" , e );
348+ return Map .of ("error" , "Failed to fetch completion" );
349+ }
350+ }
351+
352+
353+ public List <String > chunkString (String input ) {
354+ List <String > chunks = new ArrayList <>();
355+ int chunkSize = STREAM_SIZE ;
356+
357+ for (int i = 0 ; i < input .length (); i += chunkSize ) {
358+ int end = Math .min (input .length (), i + chunkSize );
359+ chunks .add (input .substring (i , end ));
360+ }
361+
362+ return chunks ;
363+ }
364+
365+ public String generateRandomToken (int length ) {
366+ String CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" ;
367+ SecureRandom random = new SecureRandom ();
368+ StringBuilder sb = new StringBuilder (length );
369+ for (int i = 0 ; i < length ; i ++) {
370+ int index = random .nextInt (CHARACTERS .length ());
371+ sb .append (CHARACTERS .charAt (index ));
372+ }
373+ return sb .toString ();
374+ }
375+
376+ public String getModel (){
377+ String modelId ="custom" ;
378+ if (!"" .equals (modelOpenAI )) {
379+ modelId = modelOpenAI ;
380+ } else if (!"" .equals (modelOllamaAI )) {
381+ modelId = modelOllamaAI ;
382+ }
383+ return modelId ;
384+ }
250385}
386+
387+
388+
389+
0 commit comments