Commit cb0c7dc

Added Claude 3 models, new invert_dict util

1 parent 0961289 commit cb0c7dc

File tree

5 files changed: +195 -43 lines changed

src/synthesizrr/base/algorithm/bedrock.py
src/synthesizrr/base/framework/mixins.py
src/synthesizrr/base/framework/task/text_generation.py
src/synthesizrr/base/util/language.py
src/synthesizrr/base/util/string.py


src/synthesizrr/base/algorithm/bedrock.py

Lines changed: 144 additions & 20 deletions
@@ -3,7 +3,8 @@
 import os, time, logging, sys, shutil, numpy as np, pandas as pd, gc, warnings, json
 from contextlib import contextmanager
 from synthesizrr.base.util import optional_dependency, set_param_from_alias, Parameters, get_default, safe_validate_arguments, \
-    accumulate, dispatch, str_format_args, format_exception_msg, any_item, retry, Log, remove_values, as_list
+    accumulate, dispatch, dispatch_executor, any_are_none, format_exception_msg, any_item, retry, Log, remove_values, as_list, \
+    stop_executor
 from synthesizrr.base.framework import Dataset
 from synthesizrr.base.framework.task.text_generation import GenerativeLM, Prompts, GENERATED_TEXTS_COL, TextGenerationParams, \
     TextGenerationParamsMapper
@@ -17,38 +18,150 @@
 import boto3
 
 
+def call_claude_v1_v2(
+        bedrock,
+        model_name: str,
+        prompt: str,
+        max_tokens_to_sample: int,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+        **kwargs,
+) -> str:
+    assert any_are_none(top_k, top_p), f'At least one of top_k, top_p must be None'
+    bedrock_params = {
+        "prompt": prompt,
+        "max_tokens_to_sample": max_tokens_to_sample,
+    }
+    if top_p is not None and temperature is not None:
+        raise ValueError(f'Cannot specify both top_p and temperature; at most one must be specified.')
+
+    if top_k is not None:
+        assert isinstance(top_k, int)
+        bedrock_params["top_k"] = top_k
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+
+    if stop_sequences is not None:
+        bedrock_params["stop_sequences"] = stop_sequences
+
+    response = bedrock.invoke_model(
+        body=json.dumps(bedrock_params),
+        modelId=model_name,
+        accept='application/json',
+        contentType='application/json',
+    )
+    response_body: Dict = json.loads(response.get('body').read())
+    return response_body.get('completion')
+
+
+def call_claude_v3(
+        bedrock,
+        *,
+        model_name: str,
+        prompt: str,
+        max_tokens_to_sample: int,
+        temperature: Optional[float] = None,
+        system: Optional[str] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        stop_sequences: Optional[List[str]] = None,
+        **kwargs,
+) -> str:
+    assert any_are_none(top_k, top_p), f'At least one of top_k, top_p must be None'
+    bedrock_params = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": max_tokens_to_sample,
+        "messages": [
+            {
+                "role": "user",
+                "content": prompt,
+            }
+        ],
+    }
+    if system is not None:
+        assert isinstance(system, str) and len(system) > 0
+        bedrock_params["system"] = system
+
+    if top_p is not None and temperature is not None:
+        raise ValueError(f'Cannot specify both top_p and temperature; at most one must be specified.')
+
+    if top_k is not None:
+        assert isinstance(top_k, int)
+        bedrock_params["top_k"] = top_k
+    elif top_p is not None:
+        assert isinstance(top_p, (float, int)) and 0 <= top_p <= 1
+        bedrock_params["top_p"] = top_p
+    elif temperature is not None:
+        assert isinstance(temperature, (float, int)) and 0 <= temperature <= 1
+        bedrock_params["temperature"] = temperature
+
+    if stop_sequences is not None:
+        bedrock_params["stop_sequences"] = stop_sequences
+
+    bedrock_params_json: str = json.dumps(bedrock_params)
+    # print(f'\n\nbedrock_params_json:\n{json.dumps(bedrock_params, indent=4)}')
+    response = bedrock.invoke_model(
+        body=bedrock_params_json,
+        modelId=model_name,
+        accept='application/json',
+        contentType='application/json',
+    )
+    response_body: Dict = json.loads(response.get('body').read())
+    return '\n'.join([d['text'] for d in response_body.get("content")])
+
+
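Per the body-construction logic above, a Claude 3 call with a system string and temperature set produces a Messages-API payload of this shape (the concrete values here are illustrative, not defaults from the commit):

{
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Hello"}],
    "system": "You are terse.",
    "temperature": 0.7
}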
 def call_bedrock(
         prompt: str,
         *,
         model_name: str,
         generation_params: Dict,
         region_name: List[str],
-) -> Dict:
-    start = time.perf_counter()
+) -> str:
     ## Note: creation of the bedrock client is fast.
     bedrock = boto3.client(
         service_name='bedrock-runtime',
         region_name=any_item(region_name),
-        # endpoint_url=f'https://bedrock.{region_name}.amazonaws.com',
+        # endpoint_url='https://bedrock.us-east-1.amazonaws.com',
     )
-    bedrock_invoke_model_params = {
-        "prompt": prompt,
-        **generation_params
-    }
-    response = bedrock.invoke_model(
-        body=json.dumps(bedrock_invoke_model_params),
-        modelId=model_name,
-        accept='application/json',
-        contentType='application/json'
-    )
-    response_body = json.loads(response.get('body').read())
-    end = time.perf_counter()
-    time_taken_sec: float = end - start
-    return response_body.get('completion')
+    if 'anthropic.claude-3' in model_name:
+        generated_text: str = call_claude_v3(
+            bedrock=bedrock,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params
+        )
+    elif 'claude' in model_name:
+        generated_text: str = call_claude_v1_v2(
+            bedrock=bedrock,
+            prompt=prompt,
+            model_name=model_name,
+            **generation_params
+        )
+    else:
+        bedrock_invoke_model_params = {
+            "prompt": prompt,
+            **generation_params
+        }
+        response = bedrock.invoke_model(
+            body=json.dumps(bedrock_invoke_model_params),
+            modelId=model_name,
+            accept='application/json',
+            contentType='application/json'
+        )
+        response_body = json.loads(response.get('body').read())
+        generated_text: str = response_body.get('completion')
+    return generated_text
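A hedged usage sketch of the new dispatch: model ids containing 'anthropic.claude-3' route through call_claude_v3, other Claude ids through call_claude_v1_v2. The model id and parameter values below are illustrative, not defaults from this commit:

generated_text: str = call_bedrock(
    prompt='Summarize the plot of Hamlet in one line.',
    model_name='anthropic.claude-3-sonnet-20240229-v1:0',  ## illustrative model id
    generation_params=dict(
        max_tokens_to_sample=256,
        temperature=0.7,
        system='Answer in a single sentence.',  ## only forwarded on the Claude 3 path
    ),
    region_name=['us-east-1'],
)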
 
 
 class BedrockPrompter(GenerativeLM):
     aliases = ['bedrock']
+    executor: Optional[Any] = None
 
     class Hyperparameters(GenerativeLM.Hyperparameters):
         ALLOWED_TEXT_GENERATION_PARAMS: ClassVar[List[str]] = [
@@ -59,6 +172,7 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
             'top_p',
             'max_new_tokens',
             'stop_sequences',
+            'system',
         ]
 
         region_name: List[str] = [
@@ -70,8 +184,9 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
         model_name: constr(min_length=1)
         retries: conint(ge=0) = 3
         retry_wait: confloat(ge=0) = 1.0
-        retry_jitter: confloat(ge=0) = 0.25
+        retry_jitter: confloat(ge=0) = 0.5
         parallelize: Parallelize = Parallelize.sync
+        max_workers: int = 1
         generation_params: Union[TextGenerationParams, Dict, str]
 
         @root_validator(pre=True)
@@ -105,7 +220,15 @@ def max_num_generated_tokens(self) -> int:
 
     def initialize(self, model_dir: Optional[FileMetadata] = None):
         ## Ignore the model_dir.
-        pass
+        if self.executor is None:
+            self.executor: Optional[Any] = dispatch_executor(
+                parallelize=self.hyperparams.parallelize,
+                max_workers=self.hyperparams.max_workers,
+            )
+
+    def cleanup(self):
+        super(self.__class__, self).cleanup()
+        stop_executor(self.executor)
 
     @property
     def bedrock_text_generation_params(self) -> Dict[str, Any]:
@@ -146,6 +269,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Any:
                 self.prompt_model_with_retries,
                 prompt,
                 parallelize=self.hyperparams.parallelize,
+                executor=self.executor,
             )
             generated_texts.append(generated_text)
         generated_texts: List[str] = accumulate(generated_texts)
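Taken together, the executor changes give each BedrockPrompter a worker pool that is reused across predict_step calls; a sketch of the lifecycle (assuming dispatch submits work to the executor it is handed):

prompter.initialize()   ## lazily creates the pool via dispatch_executor(parallelize=..., max_workers=...)
## predict_step(): each prompt is dispatched with executor=self.executor, reusing workers across batches
prompter.cleanup()      ## stop_executor(self.executor) shuts the pool down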

src/synthesizrr/base/framework/mixins.py

Lines changed: 6 additions & 9 deletions
@@ -939,19 +939,16 @@ def evaluate(
             self,
             metric: Optional[Union[Metric, Dict, str]] = None,
             *,
-            rolling: bool = False,
+            rolling: Optional[bool] = None,
+            inplace: Optional[bool] = None,
             **kwargs
     ) -> Metric:
         if metric is None:
-            return Metric.of(**kwargs).evaluate(self)
+            metric: Metric = Metric.of(**kwargs)
         if isinstance(metric, str):
-            return Metric.of(name=metric, **kwargs).evaluate(self)
-        if isinstance(metric, Metric):
-            if rolling:
-                return metric.evaluate(self, rolling=True)
-            else:
-                return metric.evaluate(self, inplace=False)
-        raise NotImplementedError(f'Unsupported value for input `metric`: {type(metric)} with value:\n{metric}')
+            metric: Metric = Metric.of(name=metric, **kwargs)
+        assert isinstance(metric, Metric)
+        return metric.evaluate(self, rolling=rolling, inplace=inplace)
 
     @safe_validate_arguments
     def columns(
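After the refactor, all three accepted input forms normalize to a Metric and funnel through a single evaluate() call; a hedged sketch (`dataset`, `my_metric`, and the metric name are placeholders):

dataset.evaluate('accuracy')               ## str -> Metric.of(name='accuracy', **kwargs)
dataset.evaluate(name='accuracy')          ## metric=None -> Metric.of(**kwargs)
dataset.evaluate(my_metric, rolling=True)  ## Metric instance passed through unchanged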

src/synthesizrr/base/framework/task/text_generation.py

Lines changed: 2 additions & 2 deletions
@@ -581,14 +581,14 @@ class BeamSearchParams(TextGenerationParams):
 
 class TopKSamplingParams(TextGenerationParams):
     strategy = 'TopKSampling'
-    temperature: confloat(gt=0.0, le=1.0)
+    temperature: confloat(gt=0.0, le=100.0) = 1.0
     do_sample: Literal[True] = True  ## When not doing greedy decoding, we should sample.
 
 
 class NucleusSamplingParams(TextGenerationParams):
     strategy = 'NucleusSampling'
     do_sample: Literal[True] = True  ## When not doing greedy decoding, we should sample.
-    temperature: confloat(gt=0.0, le=1.0)
+    temperature: confloat(gt=0.0, le=100.0) = 1.0
 
 
 class LogitsProcessorListParams(TextGenerationParams):
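The loosened bound now permits temperatures above 1.0 (flatter, more random sampling) and adds a default of 1.0; a sketch of what this allows (the top_p field is an assumption, since this hunk only shows temperature):

NucleusSamplingParams(top_p=0.9)                   ## temperature now defaults to 1.0
NucleusSamplingParams(top_p=0.9, temperature=2.0)  ## previously rejected by le=1.0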

src/synthesizrr/base/util/language.py

Lines changed: 41 additions & 10 deletions
@@ -1363,8 +1363,10 @@ def filter_keys(
     keys: Set = as_set(keys)
     if how == 'include':
         return keep_keys(d, keys)
-    else:
+    elif how == 'exclude':
         return remove_keys(d, keys)
+    else:
+        raise NotImplementedError(f'Invalid value for parameter `how`: "{how}"')
 
 
 def filter_values(
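A usage sketch of the tightened branching (argument order and keep_keys/remove_keys behavior are inferred from the names, as the full signature is not shown):

filter_keys({'a': 1, 'b': 2}, keys={'a'}, how='include')  ## -> {'a': 1}
filter_keys({'a': 1, 'b': 2}, keys={'a'}, how='exclude')  ## -> {'b': 2}
filter_keys({'a': 1, 'b': 2}, keys={'a'}, how='invert')   ## now raises NotImplementedError instead of silently excluding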
@@ -1488,6 +1490,15 @@ def eval_dict_values(params: Dict):
     return updated_dict
 
 
+def invert_dict(d: Dict) -> Dict:
+    if not isinstance(d, dict):
+        raise ValueError(f'{d} should be of type dict')
+    d_inv: Dict = {v: k for k, v in d.items()}
+    if len(d_inv) != len(d):
+        raise ValueError(f'Dict is not invertible as values are not unique.')
+    return d_inv
+
+
 ## ======================== NumPy utils ======================== ##
 def is_numpy_integer_array(data: Any) -> bool:
     if not isinstance(data, np.ndarray):
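Usage of the new util, per the definition above (values must also be hashable for the {v: k} comprehension):

invert_dict({'a': 1, 'b': 2})  ## -> {1: 'a', 2: 'b'}
invert_dict({'a': 1, 'b': 1})  ## raises ValueError: duplicate values collapse, so the dict is not invertible
invert_dict(['a'])             ## raises ValueError: input must be a dict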
@@ -1929,6 +1940,10 @@ def iter_batches(
         yield struct[i: min(i + batch_size, struct_len)]
 
 
+def mean(vals):
+    return sum(vals) / len(vals)
+
+
 def random_sample(
         data: Union[List, Tuple, np.ndarray],
         n: SampleSizeType,
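The new helper is the plain arithmetic mean; note it assumes a non-empty, sized container:

mean([1, 2, 3, 4])  ## -> 2.5
mean([])            ## raises ZeroDivisionError; generators also fail, since len() is called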
@@ -2827,18 +2842,34 @@ class Timeout1Week(Timeout):
 
 
 @contextmanager
-def pd_display(
-        max_rows: Optional[int] = None,
-        max_cols: Optional[int] = None,
-        max_colwidth: Optional[int] = None,
-        vertical_align: str = 'top',
-        text_align: str = 'left',
-        ignore_css: bool = False,
-):
+def pd_display(**kwargs):
+    """
+    Use pd.describe_option('display') to see all options.
+    """
     try:
         from IPython.display import display
     except ImportError:
         display = print
+    set_param_from_alias(params=kwargs, param='max_rows', alias=['num_rows', 'nrows', 'rows'], default=None)
+    set_param_from_alias(params=kwargs, param='max_cols', alias=['num_cols', 'ncols', 'cols'], default=None)
+    set_param_from_alias(params=kwargs, param='max_colwidth', alias=[
+        'max_col_width',
+        'max_columnwidth', 'max_column_width',
+        'columnwidth', 'column_width',
+        'colwidth', 'col_width',
+    ], default=None)
+    set_param_from_alias(params=kwargs, param='vertical_align', alias=['valign'], default='top')
+    set_param_from_alias(params=kwargs, param='text_align', alias=['textalign'], default='left')
+    set_param_from_alias(params=kwargs, param='ignore_css', alias=['css'], default=False)
+
+    max_rows: Optional[int] = kwargs.get('max_rows')
+    max_cols: Optional[int] = kwargs.get('max_cols')
+    max_colwidth: Optional[int] = kwargs.get('max_colwidth')
+    vertical_align: str = kwargs['vertical_align']
+    text_align: str = kwargs['text_align']
+    ignore_css: bool = kwargs['ignore_css']
+
+    # print(kwargs)
 
     def disp(df: pd.DataFrame):
         css = [
@@ -2851,7 +2882,7 @@ def disp(df: pd.DataFrame):
                 ('padding', '10px'),
             ]
         },
-        ## Align cell to top and left
+        ## Align cell to top and left/center
         {
             'selector': 'td',
             'props': [
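A hedged usage sketch of the kwargs-based signature, showing aliases resolving to max_rows / max_colwidth / vertical_align (this assumes the context manager yields the inner disp helper, which the excerpt does not show):

with pd_display(nrows=100, col_width=200, valign='top') as disp:
    disp(df)  ## df is a placeholder pd.DataFrame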

src/synthesizrr/base/util/string.py

Lines changed: 2 additions & 2 deletions
@@ -817,10 +817,10 @@ def is_fuzzy_match(cls, string: str, strings_to_match: List[str]) -> bool:
         return cls.fuzzy_match(string, strings_to_match) is not None
 
     @classmethod
-    def make_heading(cls, heading_text: str, width: int = 85, border: str = '=') -> str:
+    def header(cls, text: str, width: int = 65, border: str = '=') -> str:
         out = ''
         out += border * width + cls.NEWLINE
-        out += ('{:^' + str(width) + 's}').format(heading_text) + cls.NEWLINE
+        out += ('{:^' + str(width) + 's}').format(text) + cls.NEWLINE
         out += border * width + cls.NEWLINE
         return out
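The renamed helper centers text between border rules; for example (the enclosing class is not shown in this hunk, so StringUtil is a placeholder name):

print(StringUtil.header('Results', width=21))
## =====================
##        Results
## =====================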
