Skip to content

Commit c9367a4

Browse files
authored
get prompt with label (#22)
Change-Id: Ief6540e9b9bc12aa326a00841dbd79c842a85d04
1 parent baa9ffb commit c9367a4

File tree

13 files changed

+174
-31
lines changed

13 files changed

+174
-31
lines changed

cozeloop/_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,10 @@ def close(self):
259259
self._trace_provider.close_trace()
260260
self._closed = True
261261

262-
def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
262+
def get_prompt(self, prompt_key: str, version: str = '', label: str = '') -> Optional[Prompt]:
263263
if self._closed:
264264
raise ClientClosedError()
265-
return self._prompt_provider.get_prompt(prompt_key, version)
265+
return self._prompt_provider.get_prompt(prompt_key, version, label)
266266

267267
def prompt_format(self, prompt: Prompt, variables: Dict[str, PromptVariable]) -> List[Message]:
268268
if self._closed:
@@ -360,8 +360,8 @@ def close():
360360
return get_default_client().close()
361361

362362

363-
def get_prompt(prompt_key: str, version: str = '') -> Prompt:
364-
return get_default_client().get_prompt(prompt_key, version)
363+
def get_prompt(prompt_key: str, version: str = '', label: str = '') -> Prompt:
364+
return get_default_client().get_prompt(prompt_key, version, label)
365365

366366

367367
def prompt_format(prompt: Prompt, variables: Dict[str, Any]) -> List[Message]:

cozeloop/_noop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def workspace_id(self) -> str:
2727
def close(self):
2828
logger.warning(f"Noop client not supported. {self.new_exception}")
2929

30-
def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
30+
def get_prompt(self, prompt_key: str, version: str = '', label: str = '') -> Optional[Prompt]:
3131
logger.warning(f"Noop client not supported. {self.new_exception}")
3232
raise self.new_exception
3333

cozeloop/internal/prompt/cache.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,29 @@ def __init__(self, workspace_id: str,
4141
if auto_refresh and self.openapi_client is not None:
4242
self._start_refresh_task()
4343

44-
def get(self, prompt_key: str, version: str) -> Optional['Prompt']:
45-
cache_key = self._get_cache_key(prompt_key, version)
44+
def get(self, prompt_key: str, version: str, label: str = '') -> Optional['Prompt']:
45+
cache_key = self._get_cache_key(prompt_key, version, label)
4646
return self.cache.get(cache_key)
4747

48-
def set(self, prompt_key: str, version: str, value: 'Prompt') -> None:
49-
cache_key = self._get_cache_key(prompt_key, version)
48+
def set(self, prompt_key: str, version: str, label: str, value: 'Prompt') -> None:
49+
cache_key = self._get_cache_key(prompt_key, version, label)
5050
self.cache[cache_key] = value
5151

52-
def get_all_prompt_queries(self) -> List[Tuple[str, str]]:
52+
def get_all_prompt_queries(self) -> List[Tuple[str, str, str]]:
5353
result = []
5454
for cache_key in self.cache.keys():
5555
parsed = self._parse_cache_key(cache_key)
5656
if parsed:
5757
result.append(parsed)
5858
return result
5959

60-
def _get_cache_key(self, prompt_key: str, version: str) -> str:
61-
return f"prompt_hub:{prompt_key}:{version}"
60+
def _get_cache_key(self, prompt_key: str, version: str, label: str = '') -> str:
61+
return f"prompt_hub:{prompt_key}:{version}:{label}"
6262

63-
def _parse_cache_key(self, cache_key: str) -> Optional[Tuple[str, str]]:
63+
def _parse_cache_key(self, cache_key: str) -> Optional[Tuple[str, str, str]]:
6464
parts = cache_key.split(':')
65-
if len(parts) == 3:
66-
return parts[1], parts[2]
65+
if len(parts) == 4:
66+
return parts[1], parts[2], parts[3]
6767
return None
6868

6969
def _start_refresh_task(self):
@@ -91,13 +91,13 @@ def _refresh_all_prompts(self):
9191
"""Refresh all cached prompts"""
9292
try:
9393
# Get all cached prompt_keys and versions
94-
key_pairs = self.get_all_prompt_queries()
95-
queries = [PromptQuery(prompt_key=prompt_key, version=version) for prompt_key, version in key_pairs]
94+
key_tuples = self.get_all_prompt_queries()
95+
queries = [PromptQuery(prompt_key=prompt_key, version=version, label=label) for prompt_key, version, label in key_tuples]
9696
try:
9797
results = self.openapi_client.mpull_prompt(self.workspace_id, queries)
9898
for result in results:
99-
prompt_key, version = result.query.prompt_key, result.query.version
100-
self.set(prompt_key, version, _convert_prompt(result.prompt))
99+
prompt_key, version, label = result.query.prompt_key, result.query.version, result.query.label
100+
self.set(prompt_key, version, label, _convert_prompt(result.prompt))
101101
except Exception as e:
102102
logger.error(f"Error refreshing prompts: {e}")
103103

cozeloop/internal/prompt/openapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Prompt(BaseModel):
103103
class PromptQuery(BaseModel):
104104
prompt_key: str
105105
version: Optional[str] = None
106+
label: Optional[str] = None
106107

107108

108109
class MPullPromptRequest(BaseModel):

cozeloop/internal/prompt/prompt.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import json
55
from typing import Dict, Any, List, Optional
66

7-
from jinja2 import Environment, BaseLoader, Undefined
8-
from jinja2.utils import missing, object_type_repr
7+
from jinja2 import BaseLoader, Undefined
98
from jinja2.sandbox import SandboxedEnvironment
9+
from jinja2.utils import missing, object_type_repr
1010

11-
from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB
1211
from cozeloop.entities.prompt import (Prompt, Message, VariableDef, VariableType, TemplateType, Role,
1312
PromptVariable)
1413
from cozeloop.internal import consts
@@ -18,6 +17,7 @@
1817
from cozeloop.internal.prompt.converter import _convert_prompt, _to_span_prompt_input, _to_span_prompt_output
1918
from cozeloop.internal.prompt.openapi import OpenAPIClient, PromptQuery
2019
from cozeloop.internal.trace.trace import TraceProvider
20+
from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB, PROMPT_LABEL
2121

2222

2323
class PromptProvider:
@@ -39,18 +39,18 @@ def __init__(
3939
auto_refresh=True)
4040
self.prompt_trace = prompt_trace
4141

42-
def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
42+
def get_prompt(self, prompt_key: str, version: str = '', label: str = '') -> Optional[Prompt]:
4343
# Trace reporting
4444
if self.prompt_trace and self.trace_provider is not None:
4545
with self.trace_provider.start_span(consts.TRACE_PROMPT_HUB_SPAN_NAME,
4646
consts.TRACE_PROMPT_HUB_SPAN_TYPE,
4747
scene=V_SCENE_PROMPT_HUB) as prompt_hub_pan:
4848
prompt_hub_pan.set_tags({
4949
PROMPT_KEY: prompt_key,
50-
INPUT: json.dumps({PROMPT_KEY: prompt_key, PROMPT_VERSION: version})
50+
INPUT: json.dumps({PROMPT_KEY: prompt_key, PROMPT_VERSION: version, PROMPT_LABEL: label})
5151
})
5252
try:
53-
prompt = self._get_prompt(prompt_key, version)
53+
prompt = self._get_prompt(prompt_key, version, label)
5454
if prompt is not None:
5555
prompt_hub_pan.set_tags({
5656
PROMPT_VERSION: prompt.version,
@@ -65,20 +65,20 @@ def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
6565
prompt_hub_pan.set_error(str(e))
6666
raise e
6767
else:
68-
return self._get_prompt(prompt_key, version)
68+
return self._get_prompt(prompt_key, version, label)
6969

70-
def _get_prompt(self, prompt_key: str, version: str) -> Optional[Prompt]:
70+
def _get_prompt(self, prompt_key: str, version: str, label: str = '') -> Optional[Prompt]:
7171
"""
7272
Get Prompt, prioritize retrieving from cache, if not found then fetch from server
7373
"""
7474
# Try to get from cache
75-
prompt = self.cache.get(prompt_key, version)
75+
prompt = self.cache.get(prompt_key, version, label)
7676
# If not in cache, fetch from server and cache it
7777
if prompt is None:
78-
result = self.openapi_client.mpull_prompt(self.workspace_id, [PromptQuery(prompt_key=prompt_key, version=version)])
78+
result = self.openapi_client.mpull_prompt(self.workspace_id, [PromptQuery(prompt_key=prompt_key, version=version, label=label)])
7979
if result:
8080
prompt = _convert_prompt(result[0].prompt)
81-
self.cache.set(prompt_key, version, prompt)
81+
self.cache.set(prompt_key, version, label, prompt)
8282
# object cache item should be read only
8383
return prompt.copy(deep=True)
8484

cozeloop/prompt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ class PromptClient(ABC):
1313
"""
1414

1515
@abstractmethod
16-
def get_prompt(self, prompt_key: str, version: str = '') -> Optional[Prompt]:
16+
def get_prompt(self, prompt_key: str, version: str = '', label: str = '') -> Optional[Prompt]:
1717
"""
1818
Get a prompt by prompt key and version.
1919
2020
:param prompt_key: A unique key for retrieving the prompt.
2121
:param version: The version of the prompt. Defaults to empty, which represents fetching the latest version.
22+
:param label: The label of the prompt. Defaults to empty.
2223
:return: An instance of `entity.Prompt` if found, or None.
2324
"""
2425

cozeloop/spec/tracespec/span_key.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
PROMPT_PROVIDER = "prompt_provider" # Prompt providers, such as Loop, Langsmith, etc.
2020
PROMPT_KEY = "prompt_key"
2121
PROMPT_VERSION = "prompt_version"
22+
PROMPT_LABEL = "prompt_label"
2223

2324
# Internal experimental field.
2425
# It is not recommended to use for the time being. Instead, use the corresponding Set method.

examples/__init__.py

Whitespace-only changes.

examples/prompt/__init__.py

Whitespace-only changes.

examples/prompt/prompt_hub/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)