diff --git a/src/gimkit/dsls.py b/src/gimkit/dsls.py index 4540116..6720f40 100644 --- a/src/gimkit/dsls.py +++ b/src/gimkit/dsls.py @@ -47,7 +47,7 @@ def build_cfg(query: Query) -> str: ```python query = '<|GIM_QUERY|>The capital of <|MASKED desc="single word" regex="中国|法国"|><|/MASKED|> is Beijing<|MASKED desc="punctuation mark" regex="\\."|><|/MASKED|><|/GIM_QUERY|>' print(repr(build_cfg(Query(query)))) - >>> '%llguidance {}\nstart: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|/GIM_RESPONSE|>"\nREGEX: /\\s*/\nm_0[capture, suffix="<|/MASKED|>"]: M_0\nM_0: /中国|法国/\nm_1[capture, suffix="<|/MASKED|>"]: M_1\nM_1: /\\./\n' + >>> '%llguidance {}\nstart: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|/GIM_RESPONSE|>"\nREGEX: /\\s*/\nm_0[capture, suffix="<|/MASKED|>"]: T_0\nm_1[capture, suffix="<|/MASKED|>"]: T_1\nT_0: /中国|法国/\nT_1: /\\./\n' ``` """ num_tags = len(query.tags) @@ -80,24 +80,34 @@ def build_cfg(query: Query) -> str: # 3. Define whitespace rule (named REGEX to match examples, usually can also be called WS) lines.append(r"REGEX: /\s*/") - # 4. Generate specific rules for each tag + # 4. Collect unique patterns and create a mapping for terminal reuse + # This optimization avoids creating duplicate terminal rules for tags with the same regex + unique_pattern_terminals: dict[str, str] = {} + terminal_definitions: list[str] = [] + for i, tag in enumerate(query.tags): # Note: When used with suffix, using greedy match /(?s:.*)/ instead of /(?s:.)*?/ is correct and legal. pattern = f"/{tag.regex}/" if tag.regex else "/(?s:.*)/" + # Get or create a shared terminal for this pattern + if pattern not in unique_pattern_terminals: + # Create a new terminal name for this unique pattern + terminal_name = f"T_{len(unique_pattern_terminals)}" + unique_pattern_terminals[pattern] = terminal_name + terminal_definitions.append(f"{terminal_name}: {pattern}") + + terminal_name = unique_pattern_terminals[pattern] + # Rule m_i (logical layer): # - capture: tells the engine to capture this part. # - suffix: specifies the ending tag, the engine stops and consumes it when encountered. # Note: Here we reference the TAG_END constant (i.e., "<|/MASKED|>") - lines.append(f'm_{i}[capture, suffix="{TAG_END}"]: M_{i}') - - # Rule M_i (regex layer): - # Define the actual matching pattern for this tag. - lines.append(f"M_{i}: {pattern}") + lines.append(f'm_{i}[capture, suffix="{TAG_END}"]: {terminal_name}') - # TODO: There may be many tags with "/(?s:.*)/" pattern, which can be inefficient. + # 5. Add all unique terminal definitions + lines.extend(terminal_definitions) - # 5. Assemble final string + # 6. Assemble final string grammar = "\n".join(lines) + "\n" is_error, msgs = validate_grammar_spec(get_grammar_spec(grammar)) diff --git a/tests/test_dsls.py b/tests/test_dsls.py index 10190e7..ee085bf 100644 --- a/tests/test_dsls.py +++ b/tests/test_dsls.py @@ -14,8 +14,8 @@ def test_build_cfg(): "%llguidance {}\n" 'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n' "REGEX: /\\s*/\n" - 'm_0[capture, suffix="<|/MASKED|>"]: M_0\n' - "M_0: /(?s:.*)/\n" + 'm_0[capture, suffix="<|/MASKED|>"]: T_0\n' + "T_0: /(?s:.*)/\n" ) assert build_cfg(query) == grm @@ -25,8 +25,8 @@ def test_build_cfg(): "%llguidance {}\n" 'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|/GIM_RESPONSE|>"\n' "REGEX: /\\s*/\n" - 'm_0[capture, suffix="<|/MASKED|>"]: M_0\n' - "M_0: /\\w+\\.com/\n" + 'm_0[capture, suffix="<|/MASKED|>"]: T_0\n' + "T_0: /\\w+\\.com/\n" ) assert build_cfg(query_with_regex) == whole_grammar_regex @@ -37,6 +37,29 @@ def test_build_cfg(): ): build_cfg(Query(MaskedTag(regex="[[]]"))) + # Test with various complex patterns including repeated regexes + query = Query( + "Date: ", + MaskedTag(id=0, regex=r"\d{4}-\d{2}-\d{2}"), + ", AnotherDate: ", + MaskedTag(id=1, regex=r"\d{4}-\d{2}-\d{2}"), # same as id=0 + ", Time: ", + MaskedTag(id=2, regex=r"\d{2}:\d{2}:\d{2}"), + ", AnotherTime: ", + MaskedTag(id=3, regex=r"\d{2}:\d{2}:\d{2}"), # same as id=2 + ) + assert build_cfg(query) == ( + "%llguidance {}\n" + 'start: "<|GIM_RESPONSE|>" REGEX "<|MASKED id=\\"m_0\\"|>" m_0 REGEX "<|MASKED id=\\"m_1\\"|>" m_1 REGEX "<|MASKED id=\\"m_2\\"|>" m_2 REGEX "<|MASKED id=\\"m_3\\"|>" m_3 REGEX "<|/GIM_RESPONSE|>"\n' + "REGEX: /\\s*/\n" + 'm_0[capture, suffix="<|/MASKED|>"]: T_0\n' + 'm_1[capture, suffix="<|/MASKED|>"]: T_0\n' + 'm_2[capture, suffix="<|/MASKED|>"]: T_1\n' + 'm_3[capture, suffix="<|/MASKED|>"]: T_1\n' + "T_0: /\\d{4}-\\d{2}-\\d{2}/\n" + "T_1: /\\d{2}:\\d{2}:\\d{2}/\n" + ) + def test_build_json_schema(): query = Query(