Skip to content
72 changes: 41 additions & 31 deletions src/gimkit/dsls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@
- `build_json_schema` constructs a JSON schema representing the response structure."""

from gimkit.contexts import Query
from gimkit.schemas import (
RESPONSE_PREFIX,
RESPONSE_SUFFIX,
TAG_END,
TAG_OPEN_LEFT,
TAG_OPEN_RIGHT,
)


def get_grammar_spec(grammar: str) -> str:
from llguidance import grammar_from
CFG_TAG_RULE_NAME_PREFIX = "masked_tag_"
LLGUIDANCE_CFG_DOCS_URL = "https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md"


def validate_cfg(cfg: str) -> tuple[bool, list[str]]:
from llguidance import LLMatcher, grammar_from

# Borrowed from outlines source code at https://github.com/dottxt-ai/outlines/blob/87234d202924acce84ead694f8d06748608fd5f9/outlines/backends/llguidance.py#L296-L299
# This turns the original LLGuidance CFG to a normal CFG
# We try both lark and ebnf
try:
grammar_spec = grammar_from("grammar", grammar)
grammar_spec = grammar_from("grammar", cfg)
except ValueError: # pragma: no cover
grammar_spec = grammar_from("lark", grammar)

return grammar_spec


def validate_grammar_spec(grammar_spec: str) -> tuple[bool, list[str]]:
from llguidance import LLMatcher
grammar_spec = grammar_from("lark", cfg)

# Validate the grammar spec
is_error, msgs = LLMatcher.validate_grammar_with_warnings(grammar_spec)
return is_error, msgs

Expand All @@ -38,27 +31,44 @@ def build_cfg(query: Query) -> str:

LLGuidance syntax reference: https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md
"""
num_tags = len(query.tags)
grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"tag{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"'''

grammar_rest_lines = []
for i, tag in enumerate(query.tags):
# `/(?s:.)*?/` is a non-greedy match for any character including newlines
content_pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.)*?/"
grammar_rest_lines.append(
f'tag{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {content_pattern} "{TAG_END}"'
)
# Avoid circular import
from gimkit.schemas import (
RESPONSE_PREFIX,
RESPONSE_SUFFIX,
TAG_END,
TAG_OPEN_LEFT,
TAG_OPEN_RIGHT,
)

grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines)
num_tags = len(query.tags)
cfg_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"{CFG_TAG_RULE_NAME_PREFIX}{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"'''

is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar))
cfg_rest_lines = []
for i, tag in enumerate(query.tags):
rule_template = f'{CFG_TAG_RULE_NAME_PREFIX}{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {{}} "{TAG_END}"'
if tag.regex:
rule = rule_template.format(f"/{tag.regex}/")
elif tag.cfg:
sub_rule_0 = rule_template.format(f"{CFG_TAG_RULE_NAME_PREFIX}{i}_start")
sub_rule_rest = f"{CFG_TAG_RULE_NAME_PREFIX}{i}_{tag.cfg}" # may be multiple lines
rule = sub_rule_0 + "\n" + sub_rule_rest
else:
# `/(?s:.)*?/` is a non-greedy match for any character including newlines
rule = rule_template.format("/(?s:.)*?/")
cfg_rest_lines.append(rule)

cfg = cfg_first_line + "\n" + "\n".join(cfg_rest_lines)

is_error, msgs = validate_cfg(cfg)
if is_error:
raise ValueError(
"Invalid CFG grammar constructed from the query object:\n"
"Invalid CFG constructed from the query object:\n"
+ "\n".join(msgs)
+ "\nWe recommend checking the syntax documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md"
+ "\nWe recommend checking the syntax documentation at "
+ LLGUIDANCE_CFG_DOCS_URL
)
return grammar
return cfg


def build_json_schema(query: Query) -> dict:
Expand Down
3 changes: 2 additions & 1 deletion src/gimkit/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ def __call__(
name: str | None = None,
desc: str | None = None,
regex: str | None = None,
cfg: str | None = None,
content: str | None = None,
) -> MaskedTag:
return MaskedTag(name=name, desc=desc, regex=regex, content=content)
return MaskedTag(name=name, desc=desc, regex=regex, cfg=cfg, content=content)


class FormMixin:
Expand Down
55 changes: 46 additions & 9 deletions src/gimkit/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@

# ─── Tag Fields Definitions ───────────────────────────────────────────────────

COMMON_ATTRS = ("name", "desc", "regex")
COMMON_ATTRS = ("name", "desc", "regex", "cfg")
ALL_ATTRS = ("id", *COMMON_ATTRS)
ALL_FIELDS = ("id", *COMMON_ATTRS, "content")

TagField: TypeAlias = Literal["id", "name", "desc", "regex", "content"]
TagField: TypeAlias = Literal["id", "name", "desc", "regex", "cfg", "content"]


# ─── Regex Patterns For Tag Parsing ───────────────────────────────────────────
Expand Down Expand Up @@ -90,10 +90,11 @@ class MaskedTag:
name: str | None = None
desc: str | None = None
regex: str | None = None
cfg: str | None = None
content: str | None = None

# Read-only class variable for additional attribute escapes. These
# characters may appear in tag attributes such as `desc` or `grammar`.
# characters may appear in tag attributes such as `desc` or `cfg`.
# Hexadecimal numeric character references are used for consistency and
# compatibility with Python's built-in `html.escape` conventions.
# Ref: https://www.w3.org/MarkUp/html-spec/html-spec_13.html
Expand All @@ -117,7 +118,16 @@ def attr_unescape(cls, text: str) -> str:
return html.unescape(text)

def __post_init__(self):
# 1. Validate id
# Avoid circular imports
from gimkit.dsls import CFG_TAG_RULE_NAME_PREFIX, LLGUIDANCE_CFG_DOCS_URL, validate_cfg

# ─── Ensure Only One Decoding Constraint Is Specified ─────────

if sum([self.regex is not None, self.cfg is not None]) > 1:
raise ValueError("Only one of regex or cfg can be specified.")

# ─── Validate Id ──────────────────────────────────────────────

if not (
self.id is None
or isinstance(self.id, int)
Expand All @@ -127,15 +137,17 @@ def __post_init__(self):
if isinstance(self.id, str):
self.id = int(self.id)

# 2. Validate common attributes
# ─── Validate Common Attributes ───────────────────────────────

for attr in COMMON_ATTRS:
attr_val = getattr(self, attr)
if isinstance(attr_val, str):
setattr(self, attr, MaskedTag.attr_unescape(attr_val))
elif attr_val is not None:
raise ValueError(f"{type(attr_val)=}, {attr_val=}, should be str or None")

# 3. Validate content
# ─── Validate Content ─────────────────────────────────────────

if isinstance(self.content, str):
# TAG_OPEN_RIGHT is common in text, so we allow it in content.
# But other magic strings are not allowed.
Expand All @@ -148,7 +160,8 @@ def __post_init__(self):
elif self.content is not None:
raise ValueError(f"{type(self.content)=}, {self.content=}, should be str or None")

# 4. Validate regex if provided
# ─── Validate Regex ───────────────────────────────────────────

if isinstance(self.regex, str):
if self.regex.startswith("^") or self.regex.endswith("$"):
raise ValueError(
Expand All @@ -157,8 +170,7 @@ def __post_init__(self):
)
if self.regex.startswith("/") or self.regex.endswith("/"):
raise ValueError(
"regex should not start or end with /, "
"as it will be wrapped with /.../ in CFG grammar."
"regex should not start or end with /, as it will be wrapped with /.../ in CFG."
)
if self.regex == "":
raise ValueError("regex should not be an empty string.")
Expand All @@ -167,6 +179,31 @@ def __post_init__(self):
except re.error as e:
raise ValueError(f"Invalid regex pattern: {self.regex}") from e

# ─── Validate CFG ─────────────────────────────────────────────

if isinstance(self.cfg, str):
self.cfg = self.cfg.strip()
if self.cfg == "":
raise ValueError("CFG should not be an empty string.")
if matches := re.findall(CFG_TAG_RULE_NAME_PREFIX + r"\d+", self.cfg):
raise ValueError(
"CFG should not contain reserved rule names like "
+ " or ".join(f"`{x}`" for x in set(matches))
)
if not self.cfg.startswith("start:"):
raise ValueError(
"CFG should begin with a `start:` rule."
"\nWe recommend checking the syntax documentation at " + LLGUIDANCE_CFG_DOCS_URL
)
is_error, msgs = validate_cfg(self.cfg)
if is_error:
raise ValueError(
"Invalid CFG constructed from the query object:\n"
+ "\n".join(msgs)
+ "\nWe recommend checking the syntax documentation at "
+ LLGUIDANCE_CFG_DOCS_URL
)

def to_string(
self,
fields: list[TagField] | Literal["all"] = "all",
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_transform_to_outlines():
model_input, output_type = transform_to_outlines(query, output_type="cfg", use_gim_prompt=False)
assert isinstance(model_input, str)
assert isinstance(output_type, CFG)
assert 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"' in output_type.definition
assert 'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"' in output_type.definition

# Test JSON output type
model_input, output_type = transform_to_outlines(
Expand Down
27 changes: 18 additions & 9 deletions tests/test_dsls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,36 @@

def test_build_cfg():
query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!')
grm = (
'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n'
'tag0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"'
assert build_cfg(query) == (
'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n'
'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"'
)
assert build_cfg(query) == grm

# Test with regex
query_with_regex = Query("Hello, ", MaskedTag(id=0, regex="[A-Za-z]{5}"), "!")
whole_grammar_regex = (
'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n'
'tag0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"'
assert build_cfg(query_with_regex) == (
'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n'
'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"'
)
assert build_cfg(query_with_regex) == whole_grammar_regex

# Test with invalid regex
with (
pytest.warns(FutureWarning, match="Possible nested set at position 1"),
pytest.raises(ValueError, match="Invalid CFG grammar constructed from the query object"),
pytest.raises(ValueError, match="Invalid CFG constructed from the query object"),
):
build_cfg(Query(MaskedTag(regex="[[]]")))

# Test with cfg
cfg = 'start: obj1 ", " obj2\nobj1: "Hello" | "Hi"\nobj2: "World" | "Everyone"\n'
query_with_cfg = Query(MaskedTag(id=0, cfg=cfg), "!")
assert build_cfg(query_with_cfg) == (
'start: "<|GIM_RESPONSE|>" masked_tag_0 "<|/GIM_RESPONSE|>"\n'
'masked_tag_0: "<|MASKED id=\\"m_0\\"|>" masked_tag_0_start "<|/MASKED|>"\n'
'masked_tag_0_start: obj1 ", " obj2\n'
'obj1: "Hello" | "Hi"\n'
'obj2: "World" | "Everyone"'
)


def test_build_json_schema():
query = Query(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_guides.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import inspect
import re

import pytest

from gimkit.guides import guide as g
from gimkit.schemas import ALL_FIELDS


class TestBaseMixin:
def test_call_params(self):
sig = inspect.signature(g.__call__)
params = sig.parameters
assert list(params.keys()) == list(ALL_FIELDS[1:])


class TestFormMixin:
Expand Down
24 changes: 20 additions & 4 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@


def test_global_variables():
assert COMMON_ATTRS == ("name", "desc", "regex")
assert ALL_ATTRS == ("id", "name", "desc", "regex")
assert ALL_FIELDS == ("id", "name", "desc", "regex", "content")
assert COMMON_ATTRS == ("name", "desc", "regex", "cfg")
assert ALL_ATTRS == ("id", "name", "desc", "regex", "cfg")
assert ALL_FIELDS == ("id", "name", "desc", "regex", "cfg", "content")
assert tuple(f.name for f in fields(MaskedTag)) == ALL_FIELDS
assert len(set(ALL_FIELDS)) == len(ALL_FIELDS)
assert TagField.__args__ == ("id", "name", "desc", "regex", "content")
assert TagField.__args__ == ("id", "name", "desc", "regex", "cfg", "content")


def test_regex_patterns():
Expand Down Expand Up @@ -72,6 +72,8 @@ def test_masked_tag_init_invalid():
),
):
MaskedTag(content="<|MASKED|>")
with pytest.raises(ValueError, match="Only one of regex or cfg can be specified"):
MaskedTag(regex="[a-z]+", cfg="start: 'test'")


def test_masked_tag_init_with_regex():
Expand All @@ -85,6 +87,20 @@ def test_masked_tag_init_with_regex():
MaskedTag(regex="[")


def test_masked_tag_init_with_cfg():
with pytest.raises(ValueError, match="CFG should not be an empty string"):
MaskedTag(cfg="")
with pytest.raises(
ValueError, match="CFG should not contain reserved rule names like `masked_tag_0`"
):
MaskedTag(cfg="start: masked_tag_0")
with pytest.raises(ValueError, match="CFG should begin with a `start:` rule"):
MaskedTag(cfg='rule: "test"')
MaskedTag(cfg=' \nstart: "test"')
with pytest.raises(ValueError, match="Invalid CFG constructed from the query object"):
MaskedTag(cfg="start: invalid_syntax")


def test_masked_tag_attr_escape():
original = "& < > \" ' \t \n \r"
escaped = MaskedTag.attr_escape(original)
Expand Down