Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/gimkit/dsls.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def build_cfg(query: Query) -> str:
```python
query = '<|GIM_QUERY|>The capital of <|MASKED desc="single word" regex="中国|法国"|><|/MASKED|> is Beijing<|MASKED desc="punctuation mark" regex="\\."|><|/MASKED|><|/GIM_QUERY|>'
print(repr(build_cfg(Query(query))))
>>> '%llguidance {}\nstart: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|/GIM_RESPONSE|>"\nREGEX: /\\s*/\nm_0[capture, suffix="<|/MASKED|>"]: M_0\nM_0: /中国|法国/\nm_1[capture, suffix="<|/MASKED|>"]: M_1\nM_1: /\\./\n'
>>> '%llguidance {}\nstart: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|/GIM_RESPONSE|>"\nREGEX: /\\s*/\nm_0[capture, suffix="<|/MASKED|>"]: T_0\nm_1[capture, suffix="<|/MASKED|>"]: T_1\nT_0: /中国|法国/\nT_1: /\\./\n'
```
"""
num_tags = len(query.tags)
Expand Down Expand Up @@ -80,24 +80,34 @@ def build_cfg(query: Query) -> str:
# 3. Define the whitespace rule (named REGEX to match the examples; it is also commonly called WS)
lines.append(r"REGEX: /\s*/")

# 4. Generate specific rules for each tag
# 4. Collect unique patterns and create a mapping for terminal reuse
# This optimization avoids creating duplicate terminal rules for tags with the same regex
unique_pattern_terminals: dict[str, str] = {}
terminal_definitions: list[str] = []

for i, tag in enumerate(query.tags):
# Note: When combined with a suffix, using the greedy match /(?s:.*)/ instead of the lazy /(?s:.)*?/ is correct and valid.
pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.*)/"

# Get or create a shared terminal for this pattern
if pattern not in unique_pattern_terminals:
# Create a new terminal name for this unique pattern
terminal_name = f"T_{len(unique_pattern_terminals)}"
unique_pattern_terminals[pattern] = terminal_name
terminal_definitions.append(f"{terminal_name}: {pattern}")

terminal_name = unique_pattern_terminals[pattern]

# Rule m_i (logical layer):
# - capture: tells the engine to capture this part.
# - suffix: specifies the ending tag, the engine stops and consumes it when encountered.
# Note: Here we reference the TAG_END constant (i.e., "<|/MASKED|>")
lines.append(f'm_{i}[capture, suffix="{TAG_END}"]: M_{i}')

# Rule M_i (regex layer):
# Define the actual matching pattern for this tag.
lines.append(f"M_{i}: {pattern}")
lines.append(f'm_{i}[capture, suffix="{TAG_END}"]: {terminal_name}')

# TODO: There may be many tags with "/(?s:.*)/" pattern, which can be inefficient.
# 5. Add all unique terminal definitions
lines.extend(terminal_definitions)

# 5. Assemble final string
# 6. Assemble final string
grammar = "\n".join(lines) + "\n"

is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar))
Expand Down
31 changes: 27 additions & 4 deletions tests/test_dsls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_build_cfg():
"%llguidance {}\n"
'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n'
"REGEX: /\\s*/\n"
'm_0[capture, suffix="<|/MASKED|>"]: M_0\n'
"M_0: /(?s:.*)/\n"
'm_0[capture, suffix="<|/MASKED|>"]: T_0\n'
"T_0: /(?s:.*)/\n"
)
assert build_cfg(query) == grm

Expand All @@ -25,8 +25,8 @@ def test_build_cfg():
"%llguidance {}\n"
'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n'
"REGEX: /\\s*/\n"
'm_0[capture, suffix="<|/MASKED|>"]: M_0\n'
"M_0: /\\w+\\.com/\n"
'm_0[capture, suffix="<|/MASKED|>"]: T_0\n'
"T_0: /\\w+\\.com/\n"
)
assert build_cfg(query_with_regex) == whole_grammar_regex

Expand All @@ -37,6 +37,29 @@ def test_build_cfg():
):
build_cfg(Query(MaskedTag(regex="[[]]")))

# Test with various complex patterns including repeated regexes
query = Query(
"Date: ",
MaskedTag(id=0, regex=r"\d{4}-\d{2}-\d{2}"),
", AnotherDate: ",
MaskedTag(id=1, regex=r"\d{4}-\d{2}-\d{2}"), # same as id=0
", Time: ",
MaskedTag(id=2, regex=r"\d{2}:\d{2}:\d{2}"),
", AnotherTime: ",
MaskedTag(id=3, regex=r"\d{2}:\d{2}:\d{2}"), # same as id=2
)
assert build_cfg(query) == (
"%llguidance {}\n"
'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|MASKED id=\\"m_2\\"|>" m_2 REGEX "<|MASKED id=\\"m_3\\"|>" m_3 REGEX "<|/GIM_RESPONSE|>"\n'
"REGEX: /\\s*/\n"
'm_0[capture, suffix="<|/MASKED|>"]: T_0\n'
'm_1[capture, suffix="<|/MASKED|>"]: T_0\n'
'm_2[capture, suffix="<|/MASKED|>"]: T_1\n'
'm_3[capture, suffix="<|/MASKED|>"]: T_1\n'
"T_0: /\\d{4}-\\d{2}-\\d{2}/\n"
"T_1: /\\d{2}:\\d{2}:\\d{2}/\n"
)


def test_build_json_schema():
query = Query(
Expand Down
Loading