 import os, time, logging, sys, shutil, numpy as np, pandas as pd, gc, warnings, json
 from contextlib import contextmanager
 from synthesizrr.base.util import optional_dependency, set_param_from_alias, Parameters, get_default, safe_validate_arguments, \
-    accumulate, dispatch, str_format_args, format_exception_msg, any_item, retry, Log, remove_values, as_list
+    accumulate, dispatch, dispatch_executor, any_are_none, format_exception_msg, any_item, retry, Log, remove_values, as_list, \
+    stop_executor
 from synthesizrr.base.framework import Dataset
 from synthesizrr.base.framework.task.text_generation import GenerativeLM, Prompts, GENERATED_TEXTS_COL, TextGenerationParams, \
     TextGenerationParamsMapper
     import boto3


+def call_claude_v1_v2(
+        bedrock,
+        model_name: str,
+        prompt: str,
+        max_tokens_to_sample: int,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+        **kwargs,
+) -> str:
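+    ## Claude v1/v2 go through Bedrock's Text Completions API: the request body
+    ## carries the raw prompt plus sampling params, and the response JSON holds
+    ## a single 'completion' string.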
+    assert any_are_none(top_k, top_p), 'At least one of top_k, top_p must be None'
+    bedrock_params = {
+        "prompt": prompt,
+        "max_tokens_to_sample": max_tokens_to_sample,
+    }
+    if top_p is not None and temperature is not None:
+        raise ValueError('Cannot specify both top_p and temperature; at most one must be specified.')
+
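+    ## Forward exactly one sampling control: top_k takes precedence over
+    ## temperature, which takes precedence over top_p.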
+    if top_k is not None:
+        assert isinstance(top_k, int)
+        bedrock_params["top_k"] = top_k
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+
+    if stop_sequences is not None:
+        bedrock_params["stop_sequences"] = stop_sequences
+
+    response = bedrock.invoke_model(
+        body=json.dumps(bedrock_params),
+        modelId=model_name,
+        accept='application/json',
+        contentType='application/json',
+    )
+    response_body: Dict = json.loads(response.get('body').read())
+    return response_body.get('completion')
+
+
+def call_claude_v3(
+        bedrock,
+        *,
+        model_name: str,
+        prompt: str,
+        max_tokens_to_sample: int,
+        temperature: Optional[float] = None,
+        system: Optional[str] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+        **kwargs,
+) -> str:
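+    ## Claude v3 models use Bedrock's Messages API instead: the prompt is
+    ## wrapped in a 'user' message, an optional system prompt rides alongside,
+    ## and the response 'content' comes back as a list of text blocks.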
+    assert any_are_none(top_k, top_p), 'At least one of top_k, top_p must be None'
+    bedrock_params = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": max_tokens_to_sample,
+        "messages": [
+            {
+                "role": "user",
+                "content": prompt,
+            }
+        ],
+    }
+    if system is not None:
+        assert isinstance(system, str) and len(system) > 0
+        bedrock_params["system"] = system
+
+    if top_p is not None and temperature is not None:
+        raise ValueError('Cannot specify both top_p and temperature; at most one must be specified.')
+
+    if top_k is not None:
+        assert isinstance(top_k, int) and top_k >= 1
+        bedrock_params["top_k"] = top_k
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+
+    if stop_sequences is not None:
+        bedrock_params["stop_sequences"] = stop_sequences
+
+    bedrock_params_json: str = json.dumps(bedrock_params)
+    response = bedrock.invoke_model(
+        body=bedrock_params_json,
+        modelId=model_name,
+        accept='application/json',
+        contentType='application/json',
+    )
+    response_body: Dict = json.loads(response.get('body').read())
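+    ## v3 returns 'content' as a list of blocks; join the text of each block.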
+    return '\n'.join([d['text'] for d in response_body.get("content")])
+
+
 def call_bedrock(
         prompt: str,
         *,
         model_name: str,
         generation_params: Dict,
         region_name: List[str],
-) -> Dict:
-    start = time.perf_counter()
+) -> str:
     ## Note: creation of the bedrock client is fast.
     bedrock = boto3.client(
         service_name='bedrock-runtime',
         region_name=any_item(region_name),
-        # endpoint_url=f'https://bedrock.{region_name}.amazonaws.com',
+        # endpoint_url='https://bedrock.us-east-1.amazonaws.com',
     )
-    bedrock_invoke_model_params = {
-        "prompt": prompt,
-        **generation_params
-    }
-    response = bedrock.invoke_model(
-        body=json.dumps(bedrock_invoke_model_params),
-        modelId=model_name,
-        accept='application/json',
-        contentType='application/json'
-    )
-    response_body = json.loads(response.get('body').read())
-    end = time.perf_counter()
-    time_taken_sec: float = end - start
-    return response_body.get('completion')
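+    ## Route on the model id: Claude v3 goes through the Messages API, older
+    ## Claude models through Text Completions, and anything else falls back to
+    ## a generic {'prompt': ...} request body.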
+    if 'anthropic.claude-3' in model_name:
+        generated_text: str = call_claude_v3(
+            bedrock=bedrock,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params
+        )
+    elif 'claude' in model_name:
+        generated_text: str = call_claude_v1_v2(
+            bedrock=bedrock,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params
+        )
+    else:
+        bedrock_invoke_model_params = {
+            "prompt": prompt,
+            **generation_params
+        }
+        response = bedrock.invoke_model(
+            body=json.dumps(bedrock_invoke_model_params),
+            modelId=model_name,
+            accept='application/json',
+            contentType='application/json'
+        )
+        response_body = json.loads(response.get('body').read())
+        generated_text: str = response_body.get('completion')
+    return generated_text


 class BedrockPrompter(GenerativeLM):
     aliases = ['bedrock']
+    executor: Optional[Any] = None

     class Hyperparameters(GenerativeLM.Hyperparameters):
         ALLOWED_TEXT_GENERATION_PARAMS: ClassVar[List[str]] = [
@@ -59,6 +172,7 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
             'top_p',
             'max_new_tokens',
             'stop_sequences',
+            'system',
         ]

         region_name: List[str] = [
@@ -70,8 +184,9 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
         model_name: constr(min_length=1)
         retries: conint(ge=0) = 3
         retry_wait: confloat(ge=0) = 1.0
-        retry_jitter: confloat(ge=0) = 0.25
+        retry_jitter: confloat(ge=0) = 0.5
         parallelize: Parallelize = Parallelize.sync
+        max_workers: int = 1
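+        ## Worker count for the executor created in initialize().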
         generation_params: Union[TextGenerationParams, Dict, str]

         @root_validator(pre=True)
@@ -105,7 +220,15 @@ def max_num_generated_tokens(self) -> int:

     def initialize(self, model_dir: Optional[FileMetadata] = None):
         ## Ignore the model_dir.
-        pass
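+        ## Create the executor once and reuse it across repeated initialize() calls.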
+        if self.executor is None:
+            self.executor: Optional[Any] = dispatch_executor(
+                parallelize=self.hyperparams.parallelize,
+                max_workers=self.hyperparams.max_workers,
+            )
+
+    def cleanup(self):
+        super().cleanup()
+        stop_executor(self.executor)

     @property
     def bedrock_text_generation_params(self) -> Dict[str, Any]:
@@ -146,6 +269,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Any:
                 self.prompt_model_with_retries,
                 prompt,
                 parallelize=self.hyperparams.parallelize,
+                executor=self.executor,
             )
             generated_texts.append(generated_text)
         generated_texts: List[str] = accumulate(generated_texts)
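
A minimal usage sketch of the refactored entry point (not part of the commit); the model id, region, and parameter values here are illustrative assumptions:

## Hedged sketch: exercises call_bedrock() on the Claude v1/v2 path.
generated: str = call_bedrock(
    prompt='\n\nHuman: Summarize the plot of Hamlet in one sentence.\n\nAssistant:',
    model_name='anthropic.claude-v2',  ## hypothetical model id; 'claude' (not v3) selects call_claude_v1_v2
    region_name=['us-east-1'],         ## any_item() picks one region from this list
    generation_params=dict(
        max_tokens_to_sample=128,      ## required by call_claude_v1_v2
        temperature=0.5,               ## leaves top_k/top_p unset, satisfying the asserts
        stop_sequences=['\n\nHuman:'],
    ),
)
print(generated)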