11# Copyright (C) 2024 Intel Corporation
22# SPDX-License-Identifier: Apache-2.0
33
4+ import asyncio
45import os
5- from typing import List , Union
66
7- import torch
8- import torch .nn as nn
9- from einops import rearrange
10- from transformers import AutoProcessor , AutoTokenizer , CLIPModel
7+ import requests
118
129from comps import CustomLogger , OpeaComponent , OpeaComponentRegistry , ServiceType
1310from comps .cores .proto .api_protocol import EmbeddingRequest , EmbeddingResponse , EmbeddingResponseData
# LOGFLAG values that mean "logging off"; any other non-empty value enables it.
_LOGFLAG_OFF_VALUES = frozenset({"", "0", "false", "no", "off"})

# os.getenv returns a string, so a plain truthiness test would treat
# LOGFLAG=false (or LOGFLAG=0) as enabled. Normalise to a real bool instead.
logflag = os.getenv("LOGFLAG", "").strip().lower() not in _LOGFLAG_OFF_VALUES
1714
1815
# Checkpoint identifier shared by the model, the image processor and the tokenizer.
model_name = "openai/clip-vit-base-patch32"

# Load the three artifacts from the same checkpoint, in this order:
# model, processor, tokenizer.
clip, processor, tokenizer = (
    loader.from_pretrained(model_name) for loader in (CLIPModel, AutoProcessor, AutoTokenizer)
)
24-
25-
class vCLIP(nn.Module):
    """Text, image and video embedding helpers built on the module-level CLIP objects.

    All methods use the module-level ``clip``, ``processor`` and ``tokenizer``
    globals; this class holds only configuration values.
    """

    def __init__(self, cfg):
        """Store the configuration values.

        Args:
            cfg: Mapping that provides the keys ``"num_frm"`` and ``"model_name"``.
        """
        super().__init__()

        self.num_frm = cfg["num_frm"]
        self.model_name = cfg["model_name"]

    def embed_query(self, texts):
        """Return the CLIP text features for a list of texts."""
        text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
        return clip.get_text_features(**text_inputs)

    def get_embedding_length(self):
        """Return the text embedding width, measured by embedding one sample string."""
        # Pass a list so the call matches the documented input of embed_query.
        text_features = self.embed_query(["sample_text"])
        return text_features.shape[1]

    def get_image_embeddings(self, images):
        """Return the CLIP image features for a list of images."""
        image_inputs = processor(images=images, return_tensors="pt")
        return clip.get_image_features(**image_inputs)

    def get_video_embeddings(self, frames_batch):
        """Return one normalized embedding per video.

        Args:
            frames_batch: List of videos, each given as a list of frames.

        Returns:
            Tensor with one row per video in ``frames_batch``, obtained by
            concatenating the per-video embeddings along dim 0.
        """
        self.batch_size = len(frames_batch)
        vid_embs = []
        for frames in frames_batch:
            # Frame features of a single video; treated as 2-D (frame, feature).
            frame_embeddings = self.get_image_embeddings(frames)
            # Normalize each frame, mean-aggregate over the frames of THIS video
            # only, then re-normalize. The previous reshape used
            # b=len(frames_batch) on a single video's frames, which was only
            # correct for a batch of one video.
            frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
            video_embedding = frame_embeddings.mean(dim=0, keepdim=True)
            video_embedding = video_embedding / video_embedding.norm(dim=-1, keepdim=True)
            vid_embs.append(video_embedding)
        return torch.cat(vid_embs, dim=0)
63-
6416@OpeaComponentRegistry .register ("OPEA_CLIP_EMBEDDING" )
6517class OpeaClipEmbedding (OpeaComponent ):
6618 """A specialized embedding component derived from OpeaComponent for CLIP embedding services.
@@ -74,7 +26,7 @@ class OpeaClipEmbedding(OpeaComponent):
7426
7527 def __init__ (self , name : str , description : str , config : dict = None ):
7628 super ().__init__ (name , ServiceType .EMBEDDING .name .lower (), description , config )
77- self .embeddings = vCLIP ({ "model_name" : "openai/clip-vit-base-patch32" , "num_frm" : 4 } )
29+ self .base_url = os . getenv ( "CLIP_EMBEDDING_ENDPOINT" , "http://localhost:6990" )
7830
7931 health_status = self .check_health ()
8032 if not health_status :
@@ -89,46 +41,38 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
8941 Returns:
9042 EmbeddingResponse: The response in OpenAI embedding format, including embeddings, model, and usage information.
9143 """
92- # Parse input according to the EmbeddingRequest format
93- if isinstance (input .input , str ):
94- texts = [input .input .replace ("\n " , " " )]
95- elif isinstance (input .input , list ):
96- if all (isinstance (item , str ) for item in input .input ):
97- texts = [text .replace ("\n " , " " ) for text in input .input ]
98- else :
99- raise ValueError ("Invalid input format: Only string or list of strings are supported." )
100- else :
101- raise TypeError ("Unsupported input type: input must be a string or list of strings." )
102- embed_vector = self .get_embeddings (texts )
103- if input .dimensions is not None :
104- embed_vector = [embed_vector [i ][: input .dimensions ] for i in range (len (embed_vector ))]
105-
106- # for standard openai embedding format
107- res = EmbeddingResponse (
108- data = [EmbeddingResponseData (index = i , embedding = embed_vector [i ]) for i in range (len (embed_vector ))]
109- )
110- return res
44+ json_payload = input .model_dump ()
45+ try :
46+ response = await asyncio .to_thread (
47+ requests .post ,
48+ f"{ self .base_url } /v1/embeddings" ,
49+ headers = {"Content-Type" : "application/json" },
50+ json = json_payload ,
51+ )
52+ response .raise_for_status ()
53+ response_json = response .json ()
54+
55+ return EmbeddingResponse (
56+ data = [EmbeddingResponseData (** item ) for item in response_json .get ("data" , [])],
57+ model = response_json .get ("model" , input .model ),
58+ usage = response_json .get ("usage" , {}),
59+ )
60+ except requests .RequestException as e :
61+ raise RuntimeError (f"Failed to invoke embedding service: { str (e )} " )
11162
def check_health(self) -> bool:
    """Check whether the remote CLIP embedding service answers requests.

    Probes the service by POSTing a minimal embedding request to
    ``{self.base_url}/v1/embeddings``.

    Returns:
        bool: True if the service replied with an HTTP status below 400,
        False if it replied with an error status or the request failed
        (connection error, timeout, ...).
    """
    try:
        response = requests.post(
            f"{self.base_url}/v1/embeddings",
            headers={"Content-Type": "application/json"},
            json={"input": "health check"},
            # (connect, read) seconds. This method is called from __init__,
            # so an unbounded request would hang construction on a dead endpoint.
            timeout=(5, 30),
        )
    except requests.RequestException:
        return False
    # A 4xx/5xx reply means the service is reachable but not usable.
    return response.ok
122-
def get_embeddings(self, text: Union[str, List[str]]) -> List[List[float]]:
    """Embed one text or a list of texts.

    Args:
        text (Union[str, List[str]]): A single string or a list of strings.

    Returns:
        List[List[float]]: The result of ``self.embeddings.embed_query`` on the
        batch, converted with ``.tolist()``.
    """
    # A bare string is wrapped so embed_query always receives a list.
    if isinstance(text, str):
        batch = [text]
    else:
        batch = text
    features = self.embeddings.embed_query(batch)
    return features.tolist()
0 commit comments