class CodeSimilarityAnalysisInput(BaseAnalysisInput):
    """
    Input container consumed by the code-similarity analysis nodes.

    Wraps one "generation" DataFrame in which every row pairs a target code
    snippet with the snippet a model generated for it, together with the
    ASTs parsed from both. The frame is forwarded to the base class as
    ``df_train_user``; an empty DataFrame is forwarded as ``df_test_user``.

    Required columns:
    - target_code_string: the original target code
    - model_generated_code_string: the model's generated code
    - target_ast: parsed AST (zss Node) for the target code
    - generated_ast: parsed AST (zss Node) for the generated code
    - target_parse_status: "success" or "partial" (error nodes filtered)
    - generated_parse_status: "success" or "partial" (error nodes filtered)

    Args:
        generation_df: DataFrame containing code strings and parsed ASTs

    Raises:
        ValueError: if any of ``REQUIRED_COLUMNS`` is absent from
            ``generation_df``.
    """

    REQUIRED_COLUMNS: list[str] = [
        "target_code_string",
        "model_generated_code_string",
        "target_ast",
        "generated_ast",
        "target_parse_status",
        "generated_parse_status",
    ]

    def __init__(self, generation_df: pd.DataFrame) -> None:
        # Validate the schema before handing the frame to the base class.
        required = set(self.REQUIRED_COLUMNS)
        available = set(generation_df.columns)
        absent = required.difference(available)
        if absent:
            raise ValueError(f"Missing required columns in generation_df: {absent}")

        # This analysis has no train/test split: the generation frame takes
        # the "train" slot and the "test" slot is left empty.
        placeholder = pd.DataFrame()
        super().__init__(df_train_user=generation_df, df_test_user=placeholder)

    @property
    def generation_df(self) -> pd.DataFrame:
        """Return the generation DataFrame given at construction time."""
        frame = self._df_train_user
        return frame
logger: logging.Logger = logging.getLogger(__name__)


def _count_nodes(node: ZSSNode) -> int:
    """Return the number of nodes in the zss tree rooted at *node*.

    The walk uses an explicit stack rather than recursion, so the depth of
    the tree is not bounded by Python's recursion limit.

    Args:
        node: root of the zss tree to measure.

    Returns:
        Total node count, including *node* itself (always >= 1).
    """
    total = 0
    pending = [node]
    while pending:
        current = pending.pop()
        total += 1
        pending.extend(current.children)
    return total


@dataclass
class TreeEditDistanceNodeOutput(BaseAnalysisOutput):
    """Output of :class:`TreeEditDistanceNode`.

    Attributes:
        num_samples: total number of sample rows.
        num_both_parsed: number of rows where both target and generated
            code produced an AST (always equals *num_samples* since the
            attack now returns partial ASTs for malformed code).
        per_sample_similarity: DataFrame with a ``similarity`` column
            (plus a ``language`` column when the input has one).
        avg_similarity: average similarity across all pairs; ``0.0`` when
            there are no rows.
        avg_similarity_by_language: per-language average similarity, or
            ``None`` when no ``language`` column is present.
    """

    num_samples: int
    num_both_parsed: int
    per_sample_similarity: pd.DataFrame = field(repr=False)
    avg_similarity: float
    avg_similarity_by_language: dict[str, float] | None


class TreeEditDistanceNode(BaseAnalysisNode):
    """Compute tree-edit-distance similarity between AST pairs.

    Uses the Zhang-Shasha algorithm (via ``zss.simple_distance``) to
    compute edit distance, then normalises to a 0-1 similarity score::

        similarity = max(1 - distance / max(n1, n2), 0)

    where *n1* and *n2* are the node counts of the two trees.

    Args:
        analysis_input: a :class:`CodeSimilarityAnalysisInput` produced
            by :class:`PyTreeSitterAttack`.
    """

    def __init__(self, analysis_input: CodeSimilarityAnalysisInput) -> None:
        super().__init__(analysis_input=analysis_input)

    # ------------------------------------------------------------------
    # Public static helper
    # ------------------------------------------------------------------

    @staticmethod
    def compute_similarity(tree1: ZSSNode, tree2: ZSSNode) -> float:
        """Compute normalised tree-edit-distance similarity.

        Args:
            tree1: first zss Node tree.
            tree2: second zss Node tree.

        Returns:
            Similarity in [0, 1] where 1.0 means identical trees.
        """
        dist: int = simple_distance(tree1, tree2)
        max_nodes = max(_count_nodes(tree1), _count_nodes(tree2))
        # Defensive only: every tree has at least its root node, so the
        # divisor is never zero in practice.
        if max_nodes == 0:
            return 1.0
        # Clamp at zero in case the distance exceeds the larger tree's size.
        return max(1.0 - dist / max_nodes, 0.0)

    # ------------------------------------------------------------------
    # BaseAnalysisNode interface
    # ------------------------------------------------------------------

    def run_analysis(self) -> TreeEditDistanceNodeOutput:
        """Score every (target, generated) AST pair of the input.

        Returns:
            A :class:`TreeEditDistanceNodeOutput` with the per-row
            similarities, their overall mean, and - when the input has a
            ``language`` column - the mean per language.
        """
        analysis_input = cast(CodeSimilarityAnalysisInput, self.analysis_input)
        df = analysis_input.generation_df

        # Build the Series directly from the two AST columns. A row-wise
        # ``DataFrame.apply`` does not produce a float Series for a
        # zero-row frame, so an empty input would fail before reaching the
        # zero-row guard below.
        similarities = pd.Series(
            [
                TreeEditDistanceNode.compute_similarity(target_ast, generated_ast)
                for target_ast, generated_ast in zip(
                    df["target_ast"], df["generated_ast"]
                )
            ],
            index=df.index,
            dtype="float64",
        )
        per_sample = pd.DataFrame({"similarity": similarities})

        num_both_parsed = len(similarities)
        # mean() of an empty Series is NaN; report 0.0 instead.
        avg_similarity = float(similarities.mean()) if num_both_parsed > 0 else 0.0

        avg_by_lang: dict[str, float] | None = None
        if "language" in df.columns:
            # .values sidesteps index alignment; rows are already in order.
            per_sample["language"] = df["language"].values
            grouped = per_sample.groupby("language")["similarity"].mean()
            avg_by_lang = grouped.to_dict()

        return TreeEditDistanceNodeOutput(
            num_samples=len(df),
            num_both_parsed=num_both_parsed,
            per_sample_similarity=per_sample,
            avg_similarity=avg_similarity,
            avg_similarity_by_language=avg_by_lang,
        )


# ----------------------------------------------------------------------
# Tests (privacy_guard/analysis/tests/test_tree_edit_distance_node.py)
# ----------------------------------------------------------------------
# NOTE(review): whitespace inside the code-string fixtures below is
# reproduced exactly as found; confirm the intended indentation against
# the original commit.


def _run_e2e(
    df: pd.DataFrame,
    default_language: str = "python",
) -> TreeEditDistanceNodeOutput:
    """Helper: run attack then analysis end-to-end."""
    attack = PyTreeSitterAttack(data=df, default_language=default_language)
    analysis_input = attack.run_attack()
    node = TreeEditDistanceNode(analysis_input=analysis_input)
    return node.run_analysis()


class TestTreeEditDistanceNode(unittest.TestCase):
    def test_similarity_values(self) -> None:
        """Identical code should yield ~1.0; different code should be low."""
        with self.subTest("identical_python"):
            code = "def foo():\n return 1\n"
            df = pd.DataFrame(
                {
                    "target_code_string": [code],
                    "model_generated_code_string": [code],
                }
            )
            output = _run_e2e(df)
            self.assertIsInstance(output, TreeEditDistanceNodeOutput)
            self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
            self.assertEqual(output.num_both_parsed, 1)

        with self.subTest("different_python"):
            df = pd.DataFrame(
                {
                    "target_code_string": ["def foo():\n return 1\n"],
                    "model_generated_code_string": [
                        "class Bar:\n def __init__(self):\n"
                        " self.x = 1\n"
                        " def method(self, a, b):\n"
                        " return a + b\n"
                    ],
                }
            )
            output = _run_e2e(df)
            self.assertLess(output.avg_similarity, 0.5)

        with self.subTest("cpp_similarity"):
            df = pd.DataFrame(
                {
                    "target_code_string": ["int add(int a, int b) { return a + b; }"],
                    "model_generated_code_string": [
                        "int sum(int x, int y) { return x + y; }"
                    ],
                }
            )
            output = _run_e2e(df, default_language="cpp")
            self.assertGreater(output.avg_similarity, 0.7)

        with self.subTest("partial_parse_high_similarity"):
            # Generated code contains the same function as the target
            # but is surrounded by syntax errors. After error-node
            # filtering the partial AST should still yield high
            # similarity against the clean target.
            target = "def foo():\n x = 1\n return x\n"
            generated = "))))\ndef foo():\n x = 1\n @@@@\n return x\n(((\n"
            df = pd.DataFrame(
                {
                    "target_code_string": [target],
                    "model_generated_code_string": [generated],
                }
            )
            output = _run_e2e(df)
            # Partial parse still produces a similarity score (not NaN)
            self.assertEqual(output.num_both_parsed, 1)
            self.assertGreater(output.avg_similarity, 0.5)

        with self.subTest("ast_equivalence_different_strings"):
            # Two code snippets that are syntactically equivalent but
            # differ in identifier names and string literals should
            # yield similarity ≈ 1.0 because tree-sitter AST nodes are
            # labelled by grammar category (e.g. "identifier", "string"),
            # not by the actual text content.
            target = 'def compute():\n result = "hello"\n return result\n'
            generated = 'def process():\n output = "world"\n return output\n'
            df = pd.DataFrame(
                {
                    "target_code_string": [target],
                    "model_generated_code_string": [generated],
                }
            )
            output = _run_e2e(df)
            self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)

    def test_avg_similarity_by_language(self) -> None:
        """Mixed Python+C++ input produces per-language averages."""
        df = pd.DataFrame(
            {
                "target_code_string": [
                    "def foo():\n return 1\n",
                    "int main() { return 0; }",
                ],
                "model_generated_code_string": [
                    "def foo():\n return 1\n",
                    "int main() { return 0; }",
                ],
                "language": ["python", "cpp"],
            }
        )
        output = _run_e2e(df)
        assert output.avg_similarity_by_language is not None
        by_lang = output.avg_similarity_by_language
        self.assertIn("python", by_lang)
        self.assertIn("cpp", by_lang)
        self.assertAlmostEqual(by_lang["python"], 1.0, places=5)
        self.assertAlmostEqual(by_lang["cpp"], 1.0, places=5)

    def test_compute_similarity_static_method(self) -> None:
        """TreeEditDistanceNode.compute_similarity works standalone."""
        node1, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")
        node2, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")

        sim = TreeEditDistanceNode.compute_similarity(node1, node2)
        self.assertAlmostEqual(sim, 1.0, places=5)
logger: logging.Logger = logging.getLogger(__name__)

# Maps user-facing language strings to tree-sitter language modules.
_LANGUAGE_REGISTRY: dict[str, ModuleType] = {
    "python": tree_sitter_python,
    "py": tree_sitter_python,
    "c++": tree_sitter_cpp,
    "cpp": tree_sitter_cpp,
}


def _get_parser(language: str) -> Parser:
    """Create a tree-sitter Parser for the given language.

    The lookup is case-insensitive. A new ``Language`` and ``Parser`` are
    built on every call.

    Args:
        language: a key in _LANGUAGE_REGISTRY (e.g. "python", "cpp")

    Returns:
        A configured tree-sitter Parser instance.

    Raises:
        ValueError: if the language is not supported.
    """
    ts_module = _LANGUAGE_REGISTRY.get(language.lower())
    if ts_module is None:
        raise ValueError(
            f"Unsupported language '{language}'. "
            f"Supported: {sorted(_LANGUAGE_REGISTRY)}"
        )

    ts_language = Language(ts_module.language())  # type: ignore[attr-defined]
    return Parser(ts_language)


class PyTreeSitterAttack(BaseAttack):
    """Parse target and generated code into ASTs using tree-sitter.

    Expects a DataFrame with ``target_code_string`` and
    ``model_generated_code_string`` columns. Produces a
    :class:`CodeSimilarityAnalysisInput` with additional AST columns
    ready for downstream similarity analysis.

    Args:
        data: DataFrame with code string columns.
        default_language: default language for parsing (e.g. "python", "cpp").
            Rows may override this via a ``language`` column; rows whose
            ``language`` cell is missing (None/NaN) use the default.

    Raises:
        ValueError: if a required column is missing from ``data``.
    """

    REQUIRED_COLUMNS: list[str] = [
        "target_code_string",
        "model_generated_code_string",
    ]

    def __init__(
        self,
        data: pd.DataFrame,
        default_language: str = "python",
    ) -> None:
        # NOTE(review): BaseAttack.__init__ is not invoked here - confirm
        # that is intended.
        missing = set(self.REQUIRED_COLUMNS) - set(data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Work on a copy so run_attack() never mutates the caller's frame.
        self._data: pd.DataFrame = data.copy()
        self._default_language: str = default_language

    # ------------------------------------------------------------------
    # Public static helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ts_node_to_zss_node(ts_node: Any, filter_errors: bool = False) -> ZSSNode:
        """Convert a tree-sitter Node (and its subtree) into a zss Node.

        Each zss node is labelled with the tree-sitter node's ``type``
        string (e.g. ``"function_definition"``, ``"identifier"``).

        The conversion walks the tree with an explicit stack rather than
        recursion, so the depth of the parse tree is not bounded by
        Python's recursion limit.

        Args:
            ts_node: tree-sitter Node to convert.
            filter_errors: when True, skip children that are ERROR or
                MISSING nodes (tree-sitter error-recovery artefacts).
                A skipped child is dropped together with its whole
                subtree. The root itself is never filtered.

        Returns:
            The zss Node mirroring ``ts_node``.
        """
        root = ZSSNode(ts_node.type)
        # Each entry pairs a tree-sitter node with its already-created zss
        # counterpart whose children still have to be attached.
        pending = [(ts_node, root)]
        while pending:
            ts_parent, zss_parent = pending.pop()
            for child in ts_parent.children:
                if filter_errors and (child.is_error or child.is_missing):
                    continue
                zss_child = ZSSNode(child.type)
                # Children are attached in source order within this loop,
                # so the stack's pop order does not affect sibling order.
                zss_parent.addkid(zss_child)
                pending.append((child, zss_child))
        return root

    @staticmethod
    def parse_code(code: str, language: str = "python") -> tuple[ZSSNode, str]:
        """Parse a single code snippet and return a zss Node tree.

        Tree-sitter always produces a parse tree, even for malformed
        code. When syntax errors are present the parser inserts ERROR
        and MISSING nodes. This method filters those nodes out and
        returns the valid portion of the AST so that downstream
        similarity analysis can still operate on partially-correct code.

        Args:
            code: source code string.
            language: language identifier (see ``_LANGUAGE_REGISTRY``).

        Returns:
            Tuple of ``(root_node, parse_status)`` where *root_node* is
            the root :class:`zss.Node` and *parse_status* is
            ``"success"`` when the code parsed without errors or
            ``"partial"`` when error/missing nodes were filtered out.

        Raises:
            ValueError: if ``language`` is not supported.
        """
        parser = _get_parser(language)
        tree = parser.parse(code.encode("utf-8"))
        has_error = tree.root_node.has_error
        # Filtering is only needed (and only applied) when the parse
        # reported at least one error somewhere in the tree.
        root = PyTreeSitterAttack._ts_node_to_zss_node(
            tree.root_node, filter_errors=has_error
        )
        return root, "partial" if has_error else "success"

    # ------------------------------------------------------------------
    # BaseAttack interface
    # ------------------------------------------------------------------

    def run_attack(self) -> CodeSimilarityAnalysisInput:
        """Parse every row's code strings into ASTs.

        Adds the following columns to the DataFrame:
        - ``target_ast``: zss Node (always present)
        - ``generated_ast``: zss Node (always present)
        - ``target_parse_status``: ``"success"`` or ``"partial"``
        - ``generated_parse_status``: ``"success"`` or ``"partial"``

        Returns:
            A :class:`CodeSimilarityAnalysisInput` wrapping the
            augmented DataFrame.

        Raises:
            ValueError: if a row's language is not supported.
        """
        df = self._data

        if "language" in df.columns:
            # A missing cell would stringify to "nan"/"None" and be
            # rejected as an unsupported language; use the default instead.
            languages = [
                self._default_language if pd.isna(value) else str(value)
                for value in df["language"]
            ]
        else:
            languages = [self._default_language] * len(df)

        target_asts: list[ZSSNode] = []
        generated_asts: list[ZSSNode] = []
        target_parse_statuses: list[str] = []
        generated_parse_statuses: list[str] = []

        for target_code, generated_code, lang in zip(
            df["target_code_string"], df["model_generated_code_string"], languages
        ):
            t_ast, t_status = self.parse_code(str(target_code), lang)
            target_asts.append(t_ast)
            target_parse_statuses.append(t_status)

            g_ast, g_status = self.parse_code(str(generated_code), lang)
            generated_asts.append(g_ast)
            generated_parse_statuses.append(g_status)

        df["target_ast"] = target_asts
        df["generated_ast"] = generated_asts
        df["target_parse_status"] = target_parse_statuses
        df["generated_parse_status"] = generated_parse_statuses

        return CodeSimilarityAnalysisInput(generation_df=df)
def _trees_identical(a: ZSSNode, b: ZSSNode) -> bool:
    """Tell whether the zss edit distance between *a* and *b* is zero."""
    distance = simple_distance(a, b)
    return distance == 0


def _pair_frame(target: str, generated: str) -> pd.DataFrame:
    """Build a one-row input frame from a target/generated snippet pair."""
    return pd.DataFrame(
        {
            "target_code_string": [target],
            "model_generated_code_string": [generated],
        }
    )


class TestPyTreeSitterAttack(unittest.TestCase):
    def test_run_attack_and_languages(self) -> None:
        """Exercise run_attack on Python, C++ and malformed inputs."""
        with self.subTest("python"):
            frame = _pair_frame(
                "def foo():\n return 1\n",
                "def bar():\n return 2\n",
            )
            outcome = PyTreeSitterAttack(
                data=frame, default_language="python"
            ).run_attack()

            self.assertIsInstance(outcome, CodeSimilarityAnalysisInput)
            parsed = outcome.generation_df
            self.assertEqual(parsed["target_parse_status"].iloc[0], "success")
            self.assertEqual(parsed["generated_parse_status"].iloc[0], "success")
            self.assertIsNotNone(parsed["target_ast"].iloc[0])
            self.assertIsNotNone(parsed["generated_ast"].iloc[0])

        with self.subTest("cpp"):
            frame = _pair_frame(
                "int main() { return 0; }",
                "int add(int a, int b) { return a + b; }",
            )
            outcome = PyTreeSitterAttack(
                data=frame, default_language="cpp"
            ).run_attack()

            parsed = outcome.generation_df
            self.assertEqual(parsed["target_parse_status"].iloc[0], "success")
            self.assertEqual(parsed["generated_parse_status"].iloc[0], "success")

        with self.subTest("malformed_code_partial_parse"):
            frame = _pair_frame(
                "def foo(:\n return",
                "def bar():\n return 1\n",
            )
            outcome = PyTreeSitterAttack(
                data=frame, default_language="python"
            ).run_attack()

            parsed = outcome.generation_df
            self.assertEqual(parsed["target_parse_status"].iloc[0], "partial")
            # Even a broken snippet must come back with an AST object.
            self.assertIsNotNone(parsed["target_ast"].iloc[0])
            # The syntactically valid side is unaffected.
            self.assertEqual(parsed["generated_parse_status"].iloc[0], "success")

        with self.subTest("malformed_similar_errors_around_valid_code"):
            # The generated snippet is the target with junk tokens placed
            # before it, after it, and between two of its statements. Once
            # error nodes are filtered, the remaining AST is expected to
            # stay close to the target's AST.
            clean = "def foo():\n x = 1\n return x\n"
            noisy = "))))\ndef foo():\n x = 1\n @@@@\n return x\n(((\n"
            outcome = PyTreeSitterAttack(
                data=_pair_frame(clean, noisy), default_language="python"
            ).run_attack()

            parsed = outcome.generation_df
            self.assertEqual(parsed["target_parse_status"].iloc[0], "success")
            self.assertEqual(parsed["generated_parse_status"].iloc[0], "partial")
            # Compare the two trees directly via their edit distance.
            clean_ast: ZSSNode = parsed["target_ast"].iloc[0]
            noisy_ast: ZSSNode = parsed["generated_ast"].iloc[0]
            edit_distance = simple_distance(clean_ast, noisy_ast)
            # Only a handful of edits should separate the filtered tree
            # from the clean one.
            self.assertLessEqual(edit_distance, 5)

        with self.subTest("ast_equivalence_different_identifiers"):
            # Node labels are grammar categories rather than source text,
            # so renaming identifiers and changing string contents is
            # expected to leave the tree unchanged.
            first_src = 'def compute():\n result = "hello"\n return result\n'
            second_src = 'def process():\n output = "world"\n return output\n'
            first_ast, first_status = PyTreeSitterAttack.parse_code(first_src)
            second_ast, second_status = PyTreeSitterAttack.parse_code(second_src)
            self.assertEqual(first_status, "success")
            self.assertEqual(second_status, "success")
            failure_note = (
                "ASTs should be identical when code differs only in "
                "identifier names and string literals"
            )
            self.assertTrue(_trees_identical(first_ast, second_ast), failure_note)

    def test_missing_column_raises(self) -> None:
        """Constructing the attack without a required column raises ValueError."""
        with self.subTest("missing_target"):
            frame = pd.DataFrame({"model_generated_code_string": ["x"]})
            with self.assertRaises(ValueError):
                PyTreeSitterAttack(data=frame)

        with self.subTest("missing_generated"):
            frame = pd.DataFrame({"target_code_string": ["x"]})
            with self.assertRaises(ValueError):
                PyTreeSitterAttack(data=frame)

    def test_parse_code_static_method(self) -> None:
        """parse_code() can be called directly on the class."""
        with self.subTest("python_parse"):
            root, status = PyTreeSitterAttack.parse_code("x = 1\n", language="python")
            self.assertEqual(status, "success")
            self.assertEqual(root.label, "module")

        with self.subTest("cpp_parse"):
            root, status = PyTreeSitterAttack.parse_code("int x = 1;", language="cpp")
            self.assertEqual(status, "success")
            self.assertEqual(root.label, "translation_unit")

        with self.subTest("malformed_returns_partial"):
            root, status = PyTreeSitterAttack.parse_code("def foo(:", language="python")
            self.assertEqual(status, "partial")
            # The root of the partial tree is still the module node.
            self.assertEqual(root.label, "module")
a/pyproject.toml b/pyproject.toml index 4d42e65..9c64fe6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,10 @@ dependencies = [ 'later', 'torchvision', 'matplotlib', + 'tree-sitter', + 'tree-sitter-python', + 'tree-sitter-cpp', + 'zss', ] [project.optional-dependencies]