From 4e7dcfbfd272b0cdb4b66c9e937c0377de1bc688 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:16:44 +0800 Subject: [PATCH 1/7] feat: add grammar support to MaskedTag and update related functionality --- src/gimkit/dsls.py | 53 +++++++++++++++++++++----------------- src/gimkit/schemas.py | 49 ++++++++++++++++++++++++++++++----- tests/models/test_utils.py | 2 +- tests/test_dsls.py | 4 +-- tests/test_schemas.py | 8 +++--- 5 files changed, 80 insertions(+), 36 deletions(-) diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 250ac79..d93b77a 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -4,31 +4,24 @@ - `build_json_schema` constructs a JSON schema representing the response structure.""" from gimkit.contexts import Query -from gimkit.schemas import ( - RESPONSE_PREFIX, - RESPONSE_SUFFIX, - TAG_END, - TAG_OPEN_LEFT, - TAG_OPEN_RIGHT, -) -def get_grammar_spec(grammar: str) -> str: - from llguidance import grammar_from +CFG_TAG_RULE_NAME_PREFIX = "MASKED_TAG_" +LLGUIDANCE_CFG_DOCS_URL = "https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" + + +def validate_grammar(grammar: str) -> tuple[bool, list[str]]: + from llguidance import LLMatcher, grammar_from # Borrowed from outlines source code at https://github.com/dottxt-ai/outlines/blob/87234d202924acce84ead694f8d06748608fd5f9/outlines/backends/llguidance.py#L296-L299 + # This turns the original LLGuidance grammar to a normal grammar spec # We try both lark and ebnf try: grammar_spec = grammar_from("grammar", grammar) except ValueError: # pragma: no cover grammar_spec = grammar_from("lark", grammar) - return grammar_spec - - -def validate_grammar_spec(grammar_spec: str) -> tuple[bool, list[str]]: - from llguidance import LLMatcher - + # Validate the grammar spec is_error, msgs = LLMatcher.validate_grammar_with_warnings(grammar_spec) return is_error, msgs @@ -38,25 +31,39 @@ def build_cfg(query: Query) -> str: LLGuidance syntax reference: https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md """ + + # Avoid circular import + from gimkit.schemas import ( + RESPONSE_PREFIX, + RESPONSE_SUFFIX, + TAG_END, + TAG_OPEN_LEFT, + TAG_OPEN_RIGHT, + ) + num_tags = len(query.tags) - grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"tag{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' + grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"{CFG_TAG_RULE_NAME_PREFIX}{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' grammar_rest_lines = [] for i, tag in enumerate(query.tags): - # `/(?s:.)*?/` is a non-greedy match for any character including newlines - content_pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.)*?/" - grammar_rest_lines.append( - f'tag{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {content_pattern} "{TAG_END}"' - ) + if tag.regex: + rule = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" /{tag.regex}/ "{TAG_END}"' + elif tag.grammar: + rule = f"{CFG_TAG_RULE_NAME_PREFIX}{i}: {tag.grammar.strip().removeprefix('start: ')}" # may be multiple lines + else: + # `/(?s:.)*?/` is a non-greedy match for any character including newlines + rule = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" /(?s:.)*?/ "{TAG_END}"' + grammar_rest_lines.append(rule) grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) - is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar)) + is_error, msgs = validate_grammar(grammar) if is_error: raise ValueError( "Invalid CFG grammar constructed from the query object:\n" + "\n".join(msgs) - + "\nWe recommend checking the syntax documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" + + "\nWe recommend checking the syntax documentation at " + + LLGUIDANCE_CFG_DOCS_URL ) return grammar diff --git a/src/gimkit/schemas.py b/src/gimkit/schemas.py index 3ee4a4b..f4d1218 100644 --- a/src/gimkit/schemas.py +++ b/src/gimkit/schemas.py @@ -35,11 +35,11 @@ # ─── Tag Fields Definitions ─────────────────────────────────────────────────── -COMMON_ATTRS = ("name", "desc", "regex") +COMMON_ATTRS = ("name", "desc", "regex", "grammar") ALL_ATTRS = ("id", *COMMON_ATTRS) ALL_FIELDS = ("id", *COMMON_ATTRS, "content") -TagField: TypeAlias = Literal["id", "name", "desc", "regex", "content"] +TagField: TypeAlias = Literal["id", "name", "desc", "regex", "grammar", "content"] # ─── Regex Patterns For Tag Parsing ─────────────────────────────────────────── @@ -90,6 +90,7 @@ class MaskedTag: name: str | None = None desc: str | None = None regex: str | None = None + grammar: str | None = None content: str | None = None # Read-only class variable for additional attribute escapes. These @@ -117,7 +118,16 @@ def attr_unescape(cls, text: str) -> str: return html.unescape(text) def __post_init__(self): - # 1. Validate id + # Avoid circular imports + from gimkit.dsls import CFG_TAG_RULE_NAME_PREFIX, LLGUIDANCE_CFG_DOCS_URL, validate_grammar + + # ─── Ensure Only One Decoding Constraint Is Specified ───────── + + if sum([self.regex is not None, self.grammar is not None]) > 1: + raise ValueError("Only one of regex or grammar can be specified.") + + # ─── Validate Id ────────────────────────────────────────────── + if not ( self.id is None or isinstance(self.id, int) @@ -127,7 +137,8 @@ def __post_init__(self): if isinstance(self.id, str): self.id = int(self.id) - # 2. Validate common attributes + # ─── Validate Common Attributes ─────────────────────────────── + for attr in COMMON_ATTRS: attr_val = getattr(self, attr) if isinstance(attr_val, str): @@ -135,7 +146,8 @@ def __post_init__(self): elif attr_val is not None: raise ValueError(f"{type(attr_val)=}, {attr_val=}, should be str or None") - # 3. Validate content + # ─── Validate Content ───────────────────────────────────────── + if isinstance(self.content, str): # TAG_OPEN_RIGHT is common in text, so we allow it in content. # But other magic strings are not allowed. @@ -148,7 +160,8 @@ def __post_init__(self): elif self.content is not None: raise ValueError(f"{type(self.content)=}, {self.content=}, should be str or None") - # 4. Validate regex if provided + # ─── Validate Regex ─────────────────────────────────────────── + if isinstance(self.regex, str): if self.regex.startswith("^") or self.regex.endswith("$"): raise ValueError( @@ -167,6 +180,30 @@ def __post_init__(self): except re.error as e: raise ValueError(f"Invalid regex pattern: {self.regex}") from e + # ─── Validate Grammar ───────────────────────────────────────── + + if isinstance(self.grammar, str): + if self.grammar == "": + raise ValueError("grammar should not be an empty string.") + if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.grammar): + raise ValueError( + "grammar should not contain reserved rule names like " + + " or ".join(f"`{x}`" for x in set(matches)) + ) + if not self.grammar.startswith("start:"): + raise ValueError( + "Grammar should start with a `start:` rule." + "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL + ) + is_error, msgs = validate_grammar(self.grammar) + if is_error: + raise ValueError( + "Invalid CFG grammar constructed from the query object:\n" + + "\n".join(msgs) + + "\nWe recommend checking the syntax documentation at " + + LLGUIDANCE_CFG_DOCS_URL + ) + def to_string( self, fields: list[TagField] | Literal["all"] = "all", diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 5f46e05..f4ceb5a 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -30,7 +30,7 @@ def test_transform_to_outlines(): model_input, output_type = transform_to_outlines(query, output_type="cfg", use_gim_prompt=False) assert isinstance(model_input, str) assert isinstance(output_type, CFG) - assert 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"' in output_type.definition + assert 'start: "<|GIM_RESPONSE|>" MASKED_TAG_0 "<|/GIM_RESPONSE|>"' in output_type.definition # Test JSON output type model_input, output_type = transform_to_outlines( diff --git a/tests/test_dsls.py b/tests/test_dsls.py index 13bb958..1a5a29f 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -11,8 +11,8 @@ def test_build_cfg(): query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!') grm = ( - 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n' - 'tag0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' + 'start: "<|GIM_RESPONSE|>" MASKED_TAG_0 "<|/GIM_RESPONSE|>"\n' + 'MASKED_TAG_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' ) assert build_cfg(query) == grm diff --git a/tests/test_schemas.py b/tests/test_schemas.py index ebb4d10..2513e42 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -24,12 +24,12 @@ def test_global_variables(): - assert COMMON_ATTRS == ("name", "desc", "regex") - assert ALL_ATTRS == ("id", "name", "desc", "regex") - assert ALL_FIELDS == ("id", "name", "desc", "regex", "content") + assert COMMON_ATTRS == ("name", "desc", "regex", "grammar") + assert ALL_ATTRS == ("id", "name", "desc", "regex", "grammar") + assert ALL_FIELDS == ("id", "name", "desc", "regex", "grammar", "content") assert tuple(f.name for f in fields(MaskedTag)) == ALL_FIELDS assert len(set(ALL_FIELDS)) == len(ALL_FIELDS) - assert TagField.__args__ == ("id", "name", "desc", "regex", "content") + assert TagField.__args__ == ("id", "name", "desc", "regex", "grammar", "content") def test_regex_patterns(): From 0da2d42d66e9d68bc85610ee8abc51edfdff5822 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:23:00 +0800 Subject: [PATCH 2/7] test: add tests --- src/gimkit/dsls.py | 2 +- src/gimkit/schemas.py | 6 +++--- tests/models/test_utils.py | 2 +- tests/test_dsls.py | 20 ++++++++++++++++---- tests/test_schemas.py | 15 +++++++++++++++ 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index d93b77a..f24a83e 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -6,7 +6,7 @@ from gimkit.contexts import Query -CFG_TAG_RULE_NAME_PREFIX = "MASKED_TAG_" +CFG_TAG_RULE_NAME_PREFIX = "masked_tag_" LLGUIDANCE_CFG_DOCS_URL = "https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" diff --git a/src/gimkit/schemas.py b/src/gimkit/schemas.py index f4d1218..caee10e 100644 --- a/src/gimkit/schemas.py +++ b/src/gimkit/schemas.py @@ -184,15 +184,15 @@ def __post_init__(self): if isinstance(self.grammar, str): if self.grammar == "": - raise ValueError("grammar should not be an empty string.") + raise ValueError("Grammar should not be an empty string.") if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.grammar): raise ValueError( - "grammar should not contain reserved rule names like " + "Grammar should not contain reserved rule names like " + " or ".join(f"`{x}`" for x in set(matches)) ) if not self.grammar.startswith("start:"): raise ValueError( - "Grammar should start with a `start:` rule." + "Grammar should begin with a `start:` rule." "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL ) is_error, msgs = validate_grammar(self.grammar) diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index f4ceb5a..9929be0 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -30,7 +30,7 @@ def test_transform_to_outlines(): model_input, output_type = transform_to_outlines(query, output_type="cfg", use_gim_prompt=False) assert isinstance(model_input, str) assert isinstance(output_type, CFG) - assert 'start: "<|GIM_RESPONSE|>" MASKED_TAG_0 "<|/GIM_RESPONSE|>"' in output_type.definition + assert 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"' in output_type.definition # Test JSON output type model_input, output_type = transform_to_outlines( diff --git a/tests/test_dsls.py b/tests/test_dsls.py index 1a5a29f..758a238 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -10,18 +10,30 @@ def test_build_cfg(): query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!') - grm = ( - 'start: "<|GIM_RESPONSE|>" MASKED_TAG_0 "<|/GIM_RESPONSE|>"\n' - 'MASKED_TAG_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' + whole_grammar = ( + 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' + 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' ) - assert build_cfg(query) == grm + assert build_cfg(query) == whole_grammar + # Test with invalid regex with ( pytest.warns(FutureWarning, match="Possible nested set at position 1"), pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"), ): build_cfg(Query(MaskedTag(regex="[[]]"))) + # Test with cfg + cfg = 'start: obj1 ", " obj2\nobj1: "Hello" | "Hi"\nobj2: "World" | "Everyone"\n' + query_with_grammar = Query(MaskedTag(id=0, grammar=cfg), "!") + whole_grammar = ( + 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' + 'masked_tag_0: obj1 ", " obj2\n' + 'obj1: "Hello" | "Hi"\n' + 'obj2: "World" | "Everyone"' + ) + assert build_cfg(query_with_grammar) == whole_grammar + def test_build_json_schema(): query = Query( diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 2513e42..9104634 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -72,6 +72,8 @@ def test_masked_tag_init_invalid(): ), ): MaskedTag(content="<|MASKED|>") + with pytest.raises(ValueError, match="Only one of regex or grammar can be specified"): + MaskedTag(regex="[a-z]+", grammar="start: 'test'") def test_masked_tag_init_with_regex(): @@ -85,6 +87,19 @@ def test_masked_tag_init_with_regex(): MaskedTag(regex="[") +def test_masked_tag_init_with_grammar(): + with pytest.raises(ValueError, match="Grammar should not be an empty string"): + MaskedTag(grammar="") + with pytest.raises( + ValueError, match="Grammar should not contain reserved rule names like `masked_tag_0`" + ): + MaskedTag(grammar="start: masked_tag_0") + with pytest.raises(ValueError, match="Grammar should begin with a `start:` rule"): + MaskedTag(grammar="rule: 'test'") + with pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"): + MaskedTag(grammar="start: invalid_syntax") + + def test_masked_tag_attr_escape(): original = "& < > \" ' \t \n \r" escaped = MaskedTag.attr_escape(original) From 2df261d219d63cc1f3aaf121c8f82da9d0ff71b3 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:35:56 +0800 Subject: [PATCH 3/7] test: add test for `g.__call__` --- src/gimkit/guides.py | 3 ++- tests/test_guides.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/gimkit/guides.py b/src/gimkit/guides.py index 62371aa..af2a800 100644 --- a/src/gimkit/guides.py +++ b/src/gimkit/guides.py @@ -9,9 +9,10 @@ def __call__( name: str | None = None, desc: str | None = None, regex: str | None = None, + grammar: str | None = None, content: str | None = None, ) -> MaskedTag: - return MaskedTag(name=name, desc=desc, regex=regex, content=content) + return MaskedTag(name=name, desc=desc, regex=regex, grammar=grammar, content=content) class FormMixin: diff --git a/tests/test_guides.py b/tests/test_guides.py index 77b7741..15bf372 100644 --- a/tests/test_guides.py +++ b/tests/test_guides.py @@ -1,8 +1,17 @@ +import inspect import re import pytest from gimkit.guides import guide as g +from gimkit.schemas import ALL_FIELDS + + +class TestBaseMixin: + def test_call_params(self): + sig = inspect.signature(g.__call__) + params = sig.parameters + assert list(params.keys()) == list(ALL_FIELDS[1:]) class TestFormMixin: From 6a9f0b97d992db52da3531c04631f65680c56e62 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:54:58 +0800 Subject: [PATCH 4/7] fix: update CFG rule generation in build_cfg to include tags --- src/gimkit/dsls.py | 9 ++++++--- tests/test_dsls.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index f24a83e..4a33160 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -46,13 +46,16 @@ def build_cfg(query: Query) -> str: grammar_rest_lines = [] for i, tag in enumerate(query.tags): + rule_template = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {{}} "{TAG_END}"' if tag.regex: - rule = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" /{tag.regex}/ "{TAG_END}"' + rule = rule_template.format(f'/{tag.regex}/') elif tag.grammar: - rule = f"{CFG_TAG_RULE_NAME_PREFIX}{i}: {tag.grammar.strip().removeprefix('start: ')}" # may be multiple lines + sub_rule_0 = rule_template.format(f'{CFG_TAG_RULE_NAME_PREFIX}{i}_start') + sub_rule_rest = f'{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.grammar}' # may be multiple lines + rule = sub_rule_0 + "\n" + sub_rule_rest else: # `/(?s:.)*?/` is a non-greedy match for any character including newlines - rule = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" /(?s:.)*?/ "{TAG_END}"' + rule = rule_template.format('/(?s:.)*?/') grammar_rest_lines.append(rule) grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) diff --git a/tests/test_dsls.py b/tests/test_dsls.py index 758a238..c7f52a0 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -28,7 +28,8 @@ def test_build_cfg(): query_with_grammar = Query(MaskedTag(id=0, grammar=cfg), "!") whole_grammar = ( 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' - 'masked_tag_0: obj1 ", " obj2\n' + 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" masked_tag_0_start "<|/MASKED|>"\n' + 'masked_tag_0_start: obj1 ", " obj2\n' 'obj1: "Hello" | "Hi"\n' 'obj2: "World" | "Everyone"' ) From 542fe7d78303ce873efbc86e7a2ca9b038a99ba9 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:55:08 +0800 Subject: [PATCH 5/7] fix: trim whitespace from grammar in MaskedTag initialization --- src/gimkit/schemas.py | 1 + tests/test_schemas.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gimkit/schemas.py b/src/gimkit/schemas.py index caee10e..14e783e 100644 --- a/src/gimkit/schemas.py +++ b/src/gimkit/schemas.py @@ -183,6 +183,7 @@ def __post_init__(self): # ─── Validate Grammar ───────────────────────────────────────── if isinstance(self.grammar, str): + self.grammar = self.grammar.strip() if self.grammar == "": raise ValueError("Grammar should not be an empty string.") if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.grammar): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 9104634..5621746 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -95,7 +95,8 @@ def test_masked_tag_init_with_grammar(): ): MaskedTag(grammar="start: masked_tag_0") with pytest.raises(ValueError, match="Grammar should begin with a `start:` rule"): - MaskedTag(grammar="rule: 'test'") + MaskedTag(grammar='rule: "test"') + MaskedTag(grammar=' \nstart: "test"') with pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"): MaskedTag(grammar="start: invalid_syntax") From 914e8852359ef945fa7989bca294e6bd88969587 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:56:24 +0800 Subject: [PATCH 6/7] style: standardize string formatting in build_cfg function --- src/gimkit/dsls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 4a33160..ddc5a83 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -48,14 +48,14 @@ def build_cfg(query: Query) -> str: for i, tag in enumerate(query.tags): rule_template = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {{}} "{TAG_END}"' if tag.regex: - rule = rule_template.format(f'/{tag.regex}/') + rule = rule_template.format(f"/{tag.regex}/") elif tag.grammar: - sub_rule_0 = rule_template.format(f'{CFG_TAG_RULE_NAME_PREFIX}{i}_start') - sub_rule_rest = f'{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.grammar}' # may be multiple lines + sub_rule_0 = rule_template.format(f"{CFG_TAG_RULE_NAME_PREFIX}{i}_start") + sub_rule_rest = f"{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.grammar}" # may be multiple lines rule = sub_rule_0 + "\n" + sub_rule_rest else: # `/(?s:.)*?/` is a non-greedy match for any character including newlines - rule = rule_template.format('/(?s:.)*?/') + rule = rule_template.format("/(?s:.)*?/") grammar_rest_lines.append(rule) grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) From 9dde37398e02d65585a4ae03e36b96dbc65b1b12 Mon Sep 17 00:00:00 2001 From: Shichao Song <60967965+Ki-Seki@users.noreply.github.com> Date: Sun, 16 Nov 2025 14:45:36 +0800 Subject: [PATCH 7/7] refactor: rename grammar to cfg for consistency across DSL and schema components --- src/gimkit/dsls.py | 26 +++++++++++++------------- src/gimkit/guides.py | 4 ++-- src/gimkit/schemas.py | 39 +++++++++++++++++++-------------------- tests/test_dsls.py | 13 +++++-------- tests/test_schemas.py | 32 ++++++++++++++++---------------- 5 files changed, 55 insertions(+), 59 deletions(-) diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 42539a4..401ddd0 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -10,16 +10,16 @@ LLGUIDANCE_CFG_DOCS_URL = "https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md" -def validate_grammar(grammar: str) -> tuple[bool, list[str]]: +def validate_cfg(cfg: str) -> tuple[bool, list[str]]: from llguidance import LLMatcher, grammar_from # Borrowed from outlines source code at https://github.com/dottxt-ai/outlines/blob/87234d202924acce84ead694f8d06748608fd5f9/outlines/backends/llguidance.py#L296-L299 - # This turns the original LLGuidance grammar to a normal grammar spec + # This turns the original LLGuidance CFG to a normal CFG # We try both lark and ebnf try: - grammar_spec = grammar_from("grammar", grammar) + grammar_spec = grammar_from("grammar", cfg) except ValueError: # pragma: no cover - grammar_spec = grammar_from("lark", grammar) + grammar_spec = grammar_from("lark", cfg) # Validate the grammar spec is_error, msgs = LLMatcher.validate_grammar_with_warnings(grammar_spec) @@ -42,33 +42,33 @@ def build_cfg(query: Query) -> str: ) num_tags = len(query.tags) - grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"{CFG_TAG_RULE_NAME_PREFIX}{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' + cfg_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"{CFG_TAG_RULE_NAME_PREFIX}{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' - grammar_rest_lines = [] + cfg_rest_lines = [] for i, tag in enumerate(query.tags): rule_template = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {{}} "{TAG_END}"' if tag.regex: rule = rule_template.format(f"/{tag.regex}/") - elif tag.grammar: + elif tag.cfg: sub_rule_0 = rule_template.format(f"{CFG_TAG_RULE_NAME_PREFIX}{i}_start") - sub_rule_rest = f"{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.grammar}" # may be multiple lines + sub_rule_rest = f"{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.cfg}" # may be multiple lines rule = sub_rule_0 + "\n" + sub_rule_rest else: # `/(?s:.)*?/` is a non-greedy match for any character including newlines rule = rule_template.format("/(?s:.)*?/") - grammar_rest_lines.append(rule) + cfg_rest_lines.append(rule) - grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) + cfg = cfg_first_line + "\n" + "\n".join(cfg_rest_lines) - is_error, msgs = validate_grammar(grammar) + is_error, msgs = validate_cfg(cfg) if is_error: raise ValueError( - "Invalid CFG grammar constructed from the query object:\n" + "Invalid CFG constructed from the query object:\n" + "\n".join(msgs) + "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL ) - return grammar + return cfg def build_json_schema(query: Query) -> dict: diff --git a/src/gimkit/guides.py b/src/gimkit/guides.py index af2a800..91ee9b4 100644 --- a/src/gimkit/guides.py +++ b/src/gimkit/guides.py @@ -9,10 +9,10 @@ def __call__( name: str | None = None, desc: str | None = None, regex: str | None = None, - grammar: str | None = None, + cfg: str | None = None, content: str | None = None, ) -> MaskedTag: - return MaskedTag(name=name, desc=desc, regex=regex, grammar=grammar, content=content) + return MaskedTag(name=name, desc=desc, regex=regex, cfg=cfg, content=content) class FormMixin: diff --git a/src/gimkit/schemas.py b/src/gimkit/schemas.py index 14e783e..35ba707 100644 --- a/src/gimkit/schemas.py +++ b/src/gimkit/schemas.py @@ -35,11 +35,11 @@ # ─── Tag Fields Definitions ─────────────────────────────────────────────────── -COMMON_ATTRS = ("name", "desc", "regex", "grammar") +COMMON_ATTRS = ("name", "desc", "regex", "cfg") ALL_ATTRS = ("id", *COMMON_ATTRS) ALL_FIELDS = ("id", *COMMON_ATTRS, "content") -TagField: TypeAlias = Literal["id", "name", "desc", "regex", "grammar", "content"] +TagField: TypeAlias = Literal["id", "name", "desc", "regex", "cfg", "content"] # ─── Regex Patterns For Tag Parsing ─────────────────────────────────────────── @@ -90,11 +90,11 @@ class MaskedTag: name: str | None = None desc: str | None = None regex: str | None = None - grammar: str | None = None + cfg: str | None = None content: str | None = None # Read-only class variable for additional attribute escapes. These - # characters may appear in tag attributes such as `desc` or `grammar`. + # characters may appear in tag attributes such as `desc` or `cfg`. # Hexadecimal numeric character references are used for consistency and # compatibility with Python's built-in `html.escape` conventions. # Ref: https://www.w3.org/MarkUp/html-spec/html-spec_13.html @@ -119,12 +119,12 @@ def attr_unescape(cls, text: str) -> str: def __post_init__(self): # Avoid circular imports - from gimkit.dsls import CFG_TAG_RULE_NAME_PREFIX, LLGUIDANCE_CFG_DOCS_URL, validate_grammar + from gimkit.dsls import CFG_TAG_RULE_NAME_PREFIX, LLGUIDANCE_CFG_DOCS_URL, validate_cfg # ─── Ensure Only One Decoding Constraint Is Specified ───────── - if sum([self.regex is not None, self.grammar is not None]) > 1: - raise ValueError("Only one of regex or grammar can be specified.") + if sum([self.regex is not None, self.cfg is not None]) > 1: + raise ValueError("Only one of regex or cfg can be specified.") # ─── Validate Id ────────────────────────────────────────────── @@ -170,8 +170,7 @@ def __post_init__(self): ) if self.regex.startswith("/") or self.regex.endswith("/"): raise ValueError( - "regex should not start or end with /, " - "as it will be wrapped with /.../ in CFG grammar." + "regex should not start or end with /, as it will be wrapped with /.../ in CFG." ) if self.regex == "": raise ValueError("regex should not be an empty string.") @@ -180,26 +179,26 @@ def __post_init__(self): except re.error as e: raise ValueError(f"Invalid regex pattern: {self.regex}") from e - # ─── Validate Grammar ───────────────────────────────────────── + # ─── Validate CFG ───────────────────────────────────────────── - if isinstance(self.grammar, str): - self.grammar = self.grammar.strip() - if self.grammar == "": - raise ValueError("Grammar should not be an empty string.") - if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.grammar): + if isinstance(self.cfg, str): + self.cfg = self.cfg.strip() + if self.cfg == "": + raise ValueError("CFG should not be an empty string.") + if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.cfg): raise ValueError( - "Grammar should not contain reserved rule names like " + "CFG should not contain reserved rule names like " + " or ".join(f"`{x}`" for x in set(matches)) ) - if not self.grammar.startswith("start:"): + if not self.cfg.startswith("start:"): raise ValueError( - "Grammar should begin with a `start:` rule." + "CFG should begin with a `start:` rule." "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL ) - is_error, msgs = validate_grammar(self.grammar) + is_error, msgs = validate_cfg(self.cfg) if is_error: raise ValueError( - "Invalid CFG grammar constructed from the query object:\n" + "Invalid CFG constructed from the query object:\n" + "\n".join(msgs) + "\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL diff --git a/tests/test_dsls.py b/tests/test_dsls.py index abdb9b5..ee40024 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -10,38 +10,35 @@ def test_build_cfg(): query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!') - whole_grammar = ( + assert build_cfg(query) == ( 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' ) - assert build_cfg(query) == whole_grammar # Test with regex query_with_regex = Query("Hello, ", MaskedTag(id=0, regex="[A-Za-z]{5}"), "!") - whole_grammar_regex = ( + assert build_cfg(query_with_regex) == ( 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"' ) - assert build_cfg(query_with_regex) == whole_grammar_regex # Test with invalid regex with ( pytest.warns(FutureWarning, match="Possible nested set at position 1"), - pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"), + pytest.raises(ValueError, match="Invalid CFG constructed from the query object"), ): build_cfg(Query(MaskedTag(regex="[[]]"))) # Test with cfg cfg = 'start: obj1 ", " obj2\nobj1: "Hello" | "Hi"\nobj2: "World" | "Everyone"\n' - query_with_grammar = Query(MaskedTag(id=0, grammar=cfg), "!") - whole_grammar = ( + query_with_cfg = Query(MaskedTag(id=0, cfg=cfg), "!") + assert build_cfg(query_with_cfg) == ( 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n' 'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" masked_tag_0_start "<|/MASKED|>"\n' 'masked_tag_0_start: obj1 ", " obj2\n' 'obj1: "Hello" | "Hi"\n' 'obj2: "World" | "Everyone"' ) - assert build_cfg(query_with_grammar) == whole_grammar def test_build_json_schema(): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 5621746..68a3100 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -24,12 +24,12 @@ def test_global_variables(): - assert COMMON_ATTRS == ("name", "desc", "regex", "grammar") - assert ALL_ATTRS == ("id", "name", "desc", "regex", "grammar") - assert ALL_FIELDS == ("id", "name", "desc", "regex", "grammar", "content") + assert COMMON_ATTRS == ("name", "desc", "regex", "cfg") + assert ALL_ATTRS == ("id", "name", "desc", "regex", "cfg") + assert ALL_FIELDS == ("id", "name", "desc", "regex", "cfg", "content") assert tuple(f.name for f in fields(MaskedTag)) == ALL_FIELDS assert len(set(ALL_FIELDS)) == len(ALL_FIELDS) - assert TagField.__args__ == ("id", "name", "desc", "regex", "grammar", "content") + assert TagField.__args__ == ("id", "name", "desc", "regex", "cfg", "content") def test_regex_patterns(): @@ -72,8 +72,8 @@ def test_masked_tag_init_invalid(): ), ): MaskedTag(content="<|MASKED|>") - with pytest.raises(ValueError, match="Only one of regex or grammar can be specified"): - MaskedTag(regex="[a-z]+", grammar="start: 'test'") + with pytest.raises(ValueError, match="Only one of regex or cfg can be specified"): + MaskedTag(regex="[a-z]+", cfg="start: 'test'") def test_masked_tag_init_with_regex(): @@ -87,18 +87,18 @@ def test_masked_tag_init_with_regex(): MaskedTag(regex="[") -def test_masked_tag_init_with_grammar(): - with pytest.raises(ValueError, match="Grammar should not be an empty string"): - MaskedTag(grammar="") +def test_masked_tag_init_with_cfg(): + with pytest.raises(ValueError, match="CFG should not be an empty string"): + MaskedTag(cfg="") with pytest.raises( - ValueError, match="Grammar should not contain reserved rule names like `masked_tag_0`" + ValueError, match="CFG should not contain reserved rule names like `masked_tag_0`" ): - MaskedTag(grammar="start: masked_tag_0") - with pytest.raises(ValueError, match="Grammar should begin with a `start:` rule"): - MaskedTag(grammar='rule: "test"') - MaskedTag(grammar=' \nstart: "test"') - with pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"): - MaskedTag(grammar="start: invalid_syntax") + MaskedTag(cfg="start: masked_tag_0") + with pytest.raises(ValueError, match="CFG should begin with a `start:` rule"): + MaskedTag(cfg='rule: "test"') + MaskedTag(cfg=' \nstart: "test"') + with pytest.raises(ValueError, match="Invalid CFG constructed from the query object"): + MaskedTag(cfg="start: invalid_syntax") def test_masked_tag_attr_escape():