diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 925ca7f..4540116 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -36,20 +36,69 @@ def validate_grammar_spec(grammar_spec: str) -> tuple[bool, list[str]]: def build_cfg(query: Query) -> str: """Build an LLGuidance context-free grammar (CFG) string based on the query object. - LLGuidance syntax reference: https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md + Constructs a flattened grammar structure compatible with LLGuidance's suffix/capture logic. + + Ref: + - https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md: Incomplete documentation of llguidance grammar syntax + - https://github.com/guidance-ai/guidance/blob/main/guidance/_ast.py: LarkSerializer implementation + - https://github.com/guidance-ai/llguidance: Source code + + Real-World Example: + ```python + query = '<|GIM_QUERY|>The capital of <|MASKED desc="single word" regex="中国|法国"|><|/MASKED|> is Beijing<|MASKED desc="punctuation mark" regex="\\."|><|/MASKED|><|/GIM_QUERY|>' + print(repr(build_cfg(Query(query)))) + >>> '%llguidance {}\nstart: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|/GIM_RESPONSE|>"\nREGEX: /\\s*/\nm_0[capture, suffix="<|/MASKED|>"]: M_0\nM_0: /中国|法国/\nm_1[capture, suffix="<|/MASKED|>"]: M_1\nM_1: /\\./\n' + ``` """ num_tags = len(query.tags) - grammar_first_line = f'''start: "{RESPONSE_PREFIX}" {" ".join(f"tag{i}" for i in range(num_tags))} "{RESPONSE_SUFFIX}"''' - grammar_rest_lines = [] + # 1. Header declaration + lines = ["%llguidance {}"] + + # 2. Build start rule + # Target format: start: "PREFIX" REGEX "OPEN_TAG_0" m_0 REGEX "OPEN_TAG_1" m_1 ... REGEX "SUFFIX" + start_parts = [f'"{RESPONSE_PREFIX}"'] + + for i in range(num_tags): + # Add whitespace rule reference + start_parts.append("REGEX") + + # Add opening tag literal, e.g.: "<|MASKED id=\"m_0\"|>" + # Note escaping: id=\"m_{i}\" + open_tag_str = f'"{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}"' + start_parts.append(open_tag_str) + + # Add content rule reference (lowercase m_i) + start_parts.append(f"m_{i}") + + # Add trailing whitespace and suffix + start_parts.append("REGEX") + start_parts.append(f'"{RESPONSE_SUFFIX}"') + + lines.append(f"start: {' '.join(start_parts)}") + + # 3. Define whitespace rule (named REGEX to match examples, usually can also be called WS) + lines.append(r"REGEX: /\s*/") + + # 4. Generate specific rules for each tag for i, tag in enumerate(query.tags): - # `/(?s:.)*?/` is a non-greedy match for any character including newlines - content_pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.)*?/" - grammar_rest_lines.append( - f'tag{i}: "{TAG_OPEN_LEFT} id=\\"m_{i}\\"{TAG_OPEN_RIGHT}" {content_pattern} "{TAG_END}"' - ) + # Note: When used with suffix, using greedy match /(?s:.*)/ instead of /(?s:.)*?/ is correct and legal. + pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.*)/" + + # Rule m_i (logical layer): + # - capture: tells the engine to capture this part. + # - suffix: specifies the ending tag, the engine stops and consumes it when encountered. + # Note: Here we reference the TAG_END constant (i.e., "<|/MASKED|>") + lines.append(f'm_{i}[capture, suffix="{TAG_END}"]: M_{i}') + + # Rule M_i (regex layer): + # Define the actual matching pattern for this tag. + lines.append(f"M_{i}: {pattern}") + + # TODO: There may be many tags with "/(?s:.*)/" pattern, which can be inefficient. - grammar = grammar_first_line + "\n" + "\n".join(grammar_rest_lines) + # 5. Assemble final string + grammar = "\n".join(lines) + "\n" is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar)) if is_error: diff --git a/tests/test_dsls.py b/tests/test_dsls.py index fe3ffd1..10190e7 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -11,16 +11,22 @@ def test_build_cfg(): query = Query('Hello, <|MASKED id="m_0"|>world<|/MASKED|>!') grm = ( - 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n' - 'tag0: "<|MASKED id=\\"m_0\\"|>" /(?s:.)*?/ "<|/MASKED|>"' + "%llguidance {}\n" + 'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n' + "REGEX: /\\s*/\n" + 'm_0[capture, suffix="<|/MASKED|>"]: M_0\n' + "M_0: /(?s:.*)/\n" ) assert build_cfg(query) == grm # Test with regex - query_with_regex = Query("Hello, ", MaskedTag(id=0, regex="[A-Za-z]{5}"), "!") + query_with_regex = Query("Hello, ", MaskedTag(id=0, regex=r"\w+\.com"), "!") whole_grammar_regex = ( - 'start: "<|GIM_RESPONSE|>" tag0 "<|/GIM_RESPONSE|>"\n' - 'tag0: "<|MASKED id=\\"m_0\\"|>" /[A-Za-z]{5}/ "<|/MASKED|>"' + "%llguidance {}\n" + 'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n' + "REGEX: /\\s*/\n" + 'm_0[capture, suffix="<|/MASKED|>"]: M_0\n' + "M_0: /\\w+\\.com/\n" ) assert build_cfg(query_with_regex) == whole_grammar_regex