diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 925ca7f..401ddd0 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -4,31 +4,24 @@ - `build_json_schema` constructs a JSON schema representing the response structure.""" from gimkit.contexts import Query -from gimkit.schemas import ( - RESPONSE_PREFIX, - RESPONSE_SUFFIX, - TAG_END, - TAG_OPEN_LEFT, - TAG_OPEN_RIGHT, -) -def get_grammar_spec(grammar: str) -> str: - from llguidance import grammar_from +CFG_TAG_RULE_NAME_PREFIX = "masked_tag_" +LLGUIDANCE_CFG_DOCS_URL = "https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" + + +def validate_cfg(cfg: str) -> tuple[bool, list[str]]: + from llguidance import LLMatcher, grammar_from # Borrowed from outlines source code at https://github.com/dottxt-ai/outlines/blob/87234d202924acce84ead694f8d06748608fd5f9/outlines/backends/llguidance.py#L296-L299 + # This turns the original LLGuidance CFG to a normal CFG # We try both lark and ebnf try: - grammar_spec = grammar_from("grammar", grammar) + grammar_spec = grammar_from("grammar", cfg) except ValueError: # pragma: no cover - grammar_spec = grammar_from("lark", grammar) - - return grammar_spec - - -def validate_grammar_spec(grammar_spec: str) -> tuple[bool, list[str]]: - from llguidance import LLMatcher + grammar_spec = grammar_from("lark", cfg) + # Validate the grammar spec is_error, msgs = LLMatcher.validate_grammar_with_warnings(grammar_spec) return is_error, msgs @@ -38,27 +31,44 @@ def build_cfg(query: Query) -> str: LLGuidance syntax reference: https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md """ - num_tags = len(query.tags) - grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"tag{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' - grammar_rest_lines = [] - for i, tag in enumerate(query.tags): - # `/(?s:.)*?/` is a non-greedy match for any character including newlines - content_pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.)*?/" - grammar_rest_lines.append( - f'tag{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {content_pattern} "{TAG_END}"' - ) + # Avoid circular import + from gimkit.schemas import ( + RESPONSE_PREFIX, + RESPONSE_SUFFIX, + TAG_END, + TAG_OPEN_LEFT, + TAG_OPEN_RIGHT, + ) - grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) + num_tags = len(query.tags) + cfg_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"{CFG_TAG_RULE_NAME_PREFIX}{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' - is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar)) + cfg_rest_lines = [] + for i, tag in enumerate(query.tags): + rule_template = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {{}} "{TAG_END}"' + if tag.regex: + rule = rule_template.format(f"/{tag.regex}/") + elif tag.cfg: + sub_rule_0 = rule_template.format(f"{CFG_TAG_RULE_NAME_PREFIX}{i}_start") + sub_rule_rest = f"{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.cfg}" # may be multiple lines + rule = sub_rule_0 + "\n" + sub_rule_rest + else: + # `/(?s:.)*?/` is a non-greedy match for any character including newlines + rule = rule_template.format("/(?s:.)*?/") + cfg_rest_lines.append(rule) + + cfg = cfg_first_line + "\n" + "\n".join(cfg_rest_lines) + + is_error, msgs = validate_cfg(cfg) if is_error: raise ValueError( - "Invalid CFG grammar constructed from the query object:\n" + "Invalid CFG constructed from the query object:\n" + "\n".join(msgs) - + "\nWe recommend checking the syntax documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" + + "\nWe recommend checking the syntax documentation at " + + LLGUIDANCE_CFG_DOCS_URL ) - return grammar + return cfg def build_json_schema(query: Query) -> dict: diff --git a/src/gimkit/guides.py b/src/gimkit/guides.py index 62371aa..91ee9b4 100644 --- a/src/gimkit/guides.py +++ b/src/gimkit/guides.py @@ -9,9 +9,10 @@ def __call__( name: str | None = None, desc: str | None = None, regex: str | None = None, + cfg: str | None = None, content: str | None = None, ) -> MaskedTag: - return MaskedTag(name=name, desc=desc, regex=regex, content=content) + return MaskedTag(name=name, desc=desc, regex=regex, cfg=cfg, content=content) class FormMixin: diff --git a/src/gimkit/schemas.py b/src/gimkit/schemas.py index 3ee4a4b..35ba707 100644 --- a/src/gimkit/schemas.py +++ b/src/gimkit/schemas.py @@ -35,11 +35,11 @@ # ─── Tag Fields Definitions ─────────────────────────────────────────────────── -COMMON_ATTRS = ("name", "desc", "regex") +COMMON_ATTRS = ("name", "desc", "regex", "cfg") ALL_ATTRS = ("id", *COMMON_ATTRS) ALL_FIELDS = ("id", *COMMON_ATTRS, "content") -TagField: TypeAlias = Literal["id", "name", "desc", "regex", "content"] +TagField: TypeAlias = Literal["id", "name", "desc", "regex", "cfg", "content"] # ─── Regex Patterns For Tag Parsing ─────────────────────────────────────────── @@ -90,10 +90,11 @@ class MaskedTag: name: str | None = None desc: str | None = None regex: str | None = None + cfg: str | None = None content: str | None = None # Read-only class variable for additional attribute escapes. These - # characters may appear in tag attributes such as `desc` or `grammar`. + # characters may appear in tag attributes such as `desc` or `cfg`. # Hexadecimal numeric character references are used for consistency and # compatibility with Python's built-in `html.escape` conventions. # Ref: https://www.w3.org/MarkUp/html-spec/html-spec_13.html @@ -117,7 +118,16 @@ def attr_unescape(cls, text: str) -> str: return html.unescape(text) def __post_init__(self): - # 1. Validate id + # Avoid circular imports + from gimkit.dsls import CFG_TAG_RULE_NAME_PREFIX, LLGUIDANCE_CFG_DOCS_URL, validate_cfg + + # ─── Ensure Only One Decoding Constraint Is Specified ───────── + + if sum([self.regex is not None, self.cfg is not None]) > 1: + raise ValueError("Only one of regex or cfg can be specified.") + + # ─── Validate Id ────────────────────────────────────────────── + if not ( self.id is None or isinstance(self.id, int) @@ -127,7 +137,8 @@ def __post_init__(self): if isinstance(self.id, str): self.id = int(self.id) - # 2. Validate common attributes + # ─── Validate Common Attributes ─────────────────────────────── + for attr in COMMON_ATTRS: attr_val = getattr(self, attr) if isinstance(attr_val, str): @@ -135,7 +146,8 @@ def __post_init__(self): elif attr_val is not None: raise ValueError(f"{type(attr_val)=}, {attr_val=}, should be str or None") - # 3. Validate content + # ─── Validate Content ───────────────────────────────────────── + if isinstance(self.content, str): # TAG_OPEN_RIGHT is common in text, so we allow it in content. # But other magic strings are not allowed. @@ -148,7 +160,8 @@ def __post_init__(self): elif self.content is not None: raise ValueError(f"{type(self.content)=}, {self.content=}, should be str or None") - # 4. Validate regex if provided + # ─── Validate Regex ─────────────────────────────────────────── + if isinstance(self.regex, str): if self.regex.startswith("^") or self.regex.endswith("$"): raise ValueError( @@ -157,8 +170,7 @@ def __post_init__(self): ) if self.regex.startswith("/") or self.regex.endswith("/"): raise ValueError( - "regex should not start or end with /, " - "as it will be wrapped with /.../ in CFG grammar." + "regex should not start or end with /, as it will be wrapped with /.../ in CFG." ) if self.regex == "": raise ValueError("regex should not be an empty string.") @@ -167,6 +179,31 @@ def __post_init__(self): except re.error as e: raise ValueError(f"Invalid regex pattern: {self.regex}") from e + # ─── Validate CFG ───────────────────────────────────────────── + + if isinstance(self.cfg, str): + self.cfg = self.cfg.strip() + if self.cfg == "": + raise ValueError("CFG should not be an empty string.") + if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.cfg): + raise ValueError( + "CFG should not contain reserved rule names like " + + " or ".join(f"`{x}`" for x in set(matches)) + ) + if not self.cfg.startswith("start:"): + raise ValueError( + "CFG should begin with a `start:` rule." + "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL + ) + is_error, msgs = validate_cfg(self.cfg) + if is_error: + raise ValueError( + "Invalid CFG constructed from the query object:\n" + + "\n".join(msgs) + + "\nWe recommend checking the syntax documentation at " + + LLGUIDANCE_CFG_DOCS_URL + ) + def to_string( self, fields: list[TagField] | Literal["all"] = "all", diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 5f46e05..9929be0 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -30,7 +30,7 @@ def test_transform_to_outlines(): model_input, output_type = transform_to_outlines(query, output_type="cfg", use_gim_prompt=False) assert isinstance(model_input, str) assert isinstance(output_type, CFG) - assert 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"' in output_type.definition + assert 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"' in output_type.definition # Test JSON output type model_input, output_type = transform_to_outlines( diff --git a/tests/test_dsls.py b/tests/test_dsls.py index fe3ffd1..ee40024 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -10,27 +10,36 @@ def test_build_cfg(): query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!') - grm = ( - 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n' - 'tag0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' + assert build_cfg(query) == ( + 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' + 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' ) - assert build_cfg(query) == grm # Test with regex query_with_regex = Query("Hello, ", MaskedTag(id=0, regex="[A-Za-z]{5}"), "!") - whole_grammar_regex = ( - 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n' - 'tag0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"' + assert build_cfg(query_with_regex) == ( + 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' + 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"' ) - assert build_cfg(query_with_regex) == whole_grammar_regex # Test with invalid regex with ( pytest.warns(FutureWarning, match="Possible nested set at position 1"), - pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"), + pytest.raises(ValueError, match="Invalid CFG constructed from the query object"), ): build_cfg(Query(MaskedTag(regex="[[]]"))) + # Test with cfg + cfg = 'start: obj1 ", " obj2\nobj1: "Hello" | "Hi"\nobj2: "World" | "Everyone"\n' + query_with_cfg = Query(MaskedTag(id=0, cfg=cfg), "!") + assert build_cfg(query_with_cfg) == ( + 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' + 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" masked_tag_0_start "<|/MASKED|>"\n' + 'masked_tag_0_start: obj1 ", " obj2\n' + 'obj1: "Hello" | "Hi"\n' + 'obj2: "World" | "Everyone"' + ) + def test_build_json_schema(): query = Query( diff --git a/tests/test_guides.py b/tests/test_guides.py index 77b7741..15bf372 100644 --- a/tests/test_guides.py +++ b/tests/test_guides.py @@ -1,8 +1,17 @@ +import inspect import re import pytest from gimkit.guides import guide as g +from gimkit.schemas import ALL_FIELDS + + +class TestBaseMixin: + def test_call_params(self): + sig = inspect.signature(g.__call__) + params = sig.parameters + assert list(params.keys()) == list(ALL_FIELDS[1:]) class TestFormMixin: diff --git a/tests/test_schemas.py b/tests/test_schemas.py index ebb4d10..68a3100 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -24,12 +24,12 @@ def test_global_variables(): - assert COMMON_ATTRS == ("name", "desc", "regex") - assert ALL_ATTRS == ("id", "name", "desc", "regex") - assert ALL_FIELDS == ("id", "name", "desc", "regex", "content") + assert COMMON_ATTRS == ("name", "desc", "regex", "cfg") + assert ALL_ATTRS == ("id", "name", "desc", "regex", "cfg") + assert ALL_FIELDS == ("id", "name", "desc", "regex", "cfg", "content") assert tuple(f.name for f in fields(MaskedTag)) == ALL_FIELDS assert len(set(ALL_FIELDS)) == len(ALL_FIELDS) - assert TagField.__args__ == ("id", "name", "desc", "regex", "content") + assert TagField.__args__ == ("id", "name", "desc", "regex", "cfg", "content") def test_regex_patterns(): @@ -72,6 +72,8 @@ def test_masked_tag_init_invalid(): ), ): MaskedTag(content="<|MASKED|>") + with pytest.raises(ValueError, match="Only one of regex or cfg can be specified"): + MaskedTag(regex="[a-z]+", cfg="start: 'test'") def test_masked_tag_init_with_regex(): @@ -85,6 +87,20 @@ def test_masked_tag_init_with_regex(): MaskedTag(regex="[") +def test_masked_tag_init_with_cfg(): + with pytest.raises(ValueError, match="CFG should not be an empty string"): + MaskedTag(cfg="") + with pytest.raises( + ValueError, match="CFG should not contain reserved rule names like `masked_tag_0`" + ): + MaskedTag(cfg="start: masked_tag_0") + with pytest.raises(ValueError, match="CFG should begin with a `start:` rule"): + MaskedTag(cfg='rule: "test"') + MaskedTag(cfg=' \nstart: "test"') + with pytest.raises(ValueError, match="Invalid CFG constructed from the query object"): + MaskedTag(cfg="start: invalid_syntax") + + def test_masked_tag_attr_escape(): original = "& < > \" ' \t \n \r" escaped = MaskedTag.attr_escape(original)