# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import pandas as pd
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput


class CodeSimilarityAnalysisInput(BaseAnalysisInput):
    """
    Container handed to the code similarity analysis.

    Wraps a single "generation" DataFrame in which every row pairs a target
    code snippet with the snippet a model generated, together with the AST
    parsed from each of them.

    Columns that must be present:
        - target_code_string: the original target code
        - model_generated_code_string: the model's generated code
        - target_ast: parsed AST (zss Node) for the target code
        - generated_ast: parsed AST (zss Node) for the generated code
        - target_parse_status: "success" or "partial" (error nodes filtered)
        - generated_parse_status: "success" or "partial" (error nodes filtered)

    Args:
        generation_df: DataFrame containing code strings and parsed ASTs

    Raises:
        ValueError: if any of the required columns is absent.
    """

    REQUIRED_COLUMNS: list[str] = [
        "target_code_string",
        "model_generated_code_string",
        "target_ast",
        "generated_ast",
        "target_parse_status",
        "generated_parse_status",
    ]

    def __init__(self, generation_df: pd.DataFrame) -> None:
        required_names = set(self.REQUIRED_COLUMNS)
        available_names = set(generation_df.columns)
        absent_names = required_names - available_names
        if absent_names:
            raise ValueError(
                f"Missing required columns in generation_df: {absent_names}"
            )

        # The base class takes a train and a test frame; only the train slot
        # is used here, so the test slot receives an empty DataFrame.
        empty_test_frame = pd.DataFrame()
        super().__init__(df_train_user=generation_df, df_test_user=empty_test_frame)

    @property
    def generation_df(self) -> pd.DataFrame:
        """Return the generation DataFrame (read from ``_df_train_user``)."""
        return self._df_train_user
# pyre-strict

import logging
from types import ModuleType
from typing import Any

import pandas as pd
import tree_sitter_cpp  # @manual=fbsource//third-party/pypi/tree-sitter-cpp:tree-sitter-cpp
import tree_sitter_python  # @manual=fbsource//third-party/pypi/tree-sitter-python:tree-sitter-python
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
    CodeSimilarityAnalysisInput,
)
from privacy_guard.attacks.base_attack import BaseAttack
from tree_sitter import (  # @manual=fbsource//third-party/pypi/tree-sitter:tree-sitter
    Language,
    Parser,
)
from zss import Node as ZSSNode


logger: logging.Logger = logging.getLogger(__name__)

# Maps user-facing language strings to tree-sitter language modules.
_LANGUAGE_REGISTRY: dict[str, ModuleType] = {
    "python": tree_sitter_python,
    "py": tree_sitter_python,
    "c++": tree_sitter_cpp,
    "cpp": tree_sitter_cpp,
}


def _get_parser(language: str) -> Parser:
    """Create a tree-sitter Parser for the given language.

    The lookup is case-insensitive and ignores surrounding whitespace.

    Args:
        language: a key in _LANGUAGE_REGISTRY (e.g. "python", "cpp")

    Returns:
        A newly constructed tree-sitter Parser configured for the language.

    Raises:
        ValueError: if the language is not supported.
    """
    lang_key = language.strip().lower()
    ts_module = _LANGUAGE_REGISTRY.get(lang_key)
    if ts_module is None:
        raise ValueError(
            f"Unsupported language '{language}'. "
            f"Supported: {sorted(_LANGUAGE_REGISTRY.keys())}"
        )

    ts_language = Language(ts_module.language())  # type: ignore[attr-defined]
    return Parser(ts_language)


class PyTreeSitterAttack(BaseAttack):
    """Parse target and generated code into ASTs using tree-sitter.

    Expects a DataFrame with ``target_code_string`` and
    ``model_generated_code_string`` columns.  Produces a
    :class:`CodeSimilarityAnalysisInput` with additional AST columns
    ready for downstream similarity analysis.

    Args:
        data: DataFrame with code string columns.  A private copy is taken,
            so the caller's DataFrame is never modified.
        default_language: default language for parsing (e.g. "python", "cpp").
            Rows may override this via a ``language`` column; rows whose
            ``language`` value is null (None / NaN / NA) use the default.

    Raises:
        ValueError: if a required column is missing from ``data``.
    """

    REQUIRED_COLUMNS: list[str] = [
        "target_code_string",
        "model_generated_code_string",
    ]

    def __init__(
        self,
        data: pd.DataFrame,
        default_language: str = "python",
    ) -> None:
        missing = set(self.REQUIRED_COLUMNS) - set(data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        self._data: pd.DataFrame = data.copy()
        self._default_language: str = default_language

    # ------------------------------------------------------------------
    # Static helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ts_node_to_zss_node(ts_node: Any, filter_errors: bool = False) -> ZSSNode:
        """Convert a tree-sitter Node (and its whole subtree) into a zss Node.

        Each zss node is labelled with the tree-sitter node's ``type``
        string (e.g. ``"function_definition"``, ``"identifier"``).

        The walk uses an explicit stack instead of recursion so that very
        deeply nested code cannot exhaust Python's recursion limit.

        Args:
            ts_node: tree-sitter Node to convert.
            filter_errors: when True, skip children that are ERROR or
                MISSING nodes (tree-sitter error-recovery artefacts),
                together with everything below them.

        Returns:
            The zss Node corresponding to ``ts_node``.
        """
        root = ZSSNode(ts_node.type)
        # Each entry pairs a tree-sitter node with the zss node built for it.
        pending: list[tuple[Any, ZSSNode]] = [(ts_node, root)]
        while pending:
            ts_parent, zss_parent = pending.pop()
            for child in ts_parent.children:
                if filter_errors and (child.is_error or child.is_missing):
                    continue
                zss_child = ZSSNode(child.type)
                # Children are attached in source order here, so the order in
                # which subtrees are later expanded does not affect the result.
                zss_parent.addkid(zss_child)
                pending.append((child, zss_child))
        return root

    @staticmethod
    def _parse_with_parser(parser: Parser, code: str) -> tuple[ZSSNode, str]:
        """Parse ``code`` with an existing parser; see :meth:`parse_code`."""
        tree = parser.parse(code.encode("utf-8"))
        root = tree.root_node
        if root.has_error:
            return (
                PyTreeSitterAttack._ts_node_to_zss_node(root, filter_errors=True),
                "partial",
            )
        return PyTreeSitterAttack._ts_node_to_zss_node(root), "success"

    @staticmethod
    def parse_code(code: str, language: str = "python") -> tuple[ZSSNode, str]:
        """Parse a single code snippet and return a zss Node tree.

        Tree-sitter always produces a parse tree, even for malformed
        code.  When syntax errors are present the parser inserts ERROR
        and MISSING nodes.  This method filters those nodes out and
        returns the valid portion of the AST so that downstream
        similarity analysis can still operate on partially-correct code.

        Args:
            code: source code string.
            language: language identifier (see ``_LANGUAGE_REGISTRY``).

        Returns:
            Tuple of ``(root_node, parse_status)`` where *root_node* is
            the root :class:`zss.Node` and *parse_status* is
            ``"success"`` when the code parsed without errors or
            ``"partial"`` when error/missing nodes were filtered out.

        Raises:
            ValueError: if ``language`` is not supported.
        """
        return PyTreeSitterAttack._parse_with_parser(_get_parser(language), code)

    # ------------------------------------------------------------------
    # BaseAttack interface
    # ------------------------------------------------------------------

    def _resolve_language(self, value: object) -> str:
        """Return the language for one row, using the default for null cells."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return self._default_language
        return str(value)

    def run_attack(self) -> CodeSimilarityAnalysisInput:
        """Parse every row's code strings into ASTs.

        Adds the following columns to the DataFrame:
        - ``target_ast``: zss Node (always present)
        - ``generated_ast``: zss Node (always present)
        - ``target_parse_status``: ``"success"`` or ``"partial"``
        - ``generated_parse_status``: ``"success"`` or ``"partial"``

        When a ``language`` column exists, each row is parsed with its own
        language; null cells in that column fall back to the default
        language.  One parser is created per distinct language and reused
        for all rows of that language.

        Returns:
            A :class:`CodeSimilarityAnalysisInput` wrapping the
            augmented DataFrame.

        Raises:
            ValueError: if a row requests an unsupported language.
        """
        df = self._data

        languages: list[str]
        if "language" in df.columns:
            languages = [self._resolve_language(value) for value in df["language"]]
        else:
            languages = [self._default_language] * len(df)

        # Building a Language/Parser is loop-invariant per language.
        parsers: dict[str, Parser] = {}

        target_asts: list[ZSSNode] = []
        generated_asts: list[ZSSNode] = []
        target_parse_statuses: list[str] = []
        generated_parse_statuses: list[str] = []

        for target_code, generated_code, lang in zip(
            df["target_code_string"], df["model_generated_code_string"], languages
        ):
            parser = parsers.get(lang)
            if parser is None:
                parser = _get_parser(lang)
                parsers[lang] = parser

            t_ast, t_status = self._parse_with_parser(parser, str(target_code))
            target_asts.append(t_ast)
            target_parse_statuses.append(t_status)

            g_ast, g_status = self._parse_with_parser(parser, str(generated_code))
            generated_asts.append(g_ast)
            generated_parse_statuses.append(g_status)

        partial_rows = sum(
            1
            for t_status, g_status in zip(
                target_parse_statuses, generated_parse_statuses
            )
            if "partial" in (t_status, g_status)
        )
        if partial_rows:
            logger.debug(
                "%d of %d rows contained syntax errors and were partially parsed",
                partial_rows,
                len(df),
            )

        df["target_ast"] = target_asts
        df["generated_ast"] = generated_asts
        df["target_parse_status"] = target_parse_statuses
        df["generated_parse_status"] = generated_parse_statuses

        return CodeSimilarityAnalysisInput(generation_df=df)
# pyre-strict

import unittest

import pandas as pd
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
    CodeSimilarityAnalysisInput,
)
from privacy_guard.attacks.code_similarity.py_tree_sitter_attack import (
    PyTreeSitterAttack,
)
from zss import Node as ZSSNode, simple_distance


def _trees_identical(a: ZSSNode, b: ZSSNode) -> bool:
    """Report whether the zss edit distance between two trees is zero."""
    edit_distance = simple_distance(a, b)
    return edit_distance == 0


def _attack_single_pair(
    target: str, generated: str, language: str
) -> CodeSimilarityAnalysisInput:
    """Run the attack on a one-row DataFrame built from the two snippets."""
    frame = pd.DataFrame(
        {
            "target_code_string": [target],
            "model_generated_code_string": [generated],
        }
    )
    attack = PyTreeSitterAttack(data=frame, default_language=language)
    return attack.run_attack()


class TestPyTreeSitterAttack(unittest.TestCase):
    def test_run_attack_and_languages(self) -> None:
        """Exercise run_attack on Python, C++, and malformed inputs."""
        with self.subTest("python"):
            result = _attack_single_pair(
                "def foo():\n return 1\n",
                "def bar():\n return 2\n",
                "python",
            )

            self.assertIsInstance(result, CodeSimilarityAnalysisInput)
            gen_df = result.generation_df
            self.assertEqual(gen_df["target_parse_status"].iloc[0], "success")
            self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success")
            self.assertIsNotNone(gen_df["target_ast"].iloc[0])
            self.assertIsNotNone(gen_df["generated_ast"].iloc[0])

        with self.subTest("cpp"):
            result = _attack_single_pair(
                "int main() { return 0; }",
                "int add(int a, int b) { return a + b; }",
                "cpp",
            )

            gen_df = result.generation_df
            self.assertEqual(gen_df["target_parse_status"].iloc[0], "success")
            self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success")

        with self.subTest("malformed_code_partial_parse"):
            result = _attack_single_pair(
                "def foo(:\n return",
                "def bar():\n return 1\n",
                "python",
            )

            gen_df = result.generation_df
            self.assertEqual(gen_df["target_parse_status"].iloc[0], "partial")
            # Even a broken snippet yields an AST object rather than None.
            self.assertIsNotNone(gen_df["target_ast"].iloc[0])
            # The syntactically valid generated snippet parses without errors.
            self.assertEqual(gen_df["generated_parse_status"].iloc[0], "success")

        with self.subTest("malformed_similar_errors_around_valid_code"):
            # The generated snippet repeats the target but has junk tokens
            # placed before it, inside it, and after it.  Once error nodes
            # are filtered, the remaining AST should stay near the target's.
            target = "def foo():\n x = 1\n return x\n"
            generated = (
                "))))\n"  # junk ahead of the function
                "def foo():\n"
                " x = 1\n"
                " @@@@\n"  # junk between the statements
                " return x\n"
                "(((\n"  # junk behind the function
            )
            result = _attack_single_pair(target, generated, "python")

            gen_df = result.generation_df
            self.assertEqual(gen_df["target_parse_status"].iloc[0], "success")
            self.assertEqual(gen_df["generated_parse_status"].iloc[0], "partial")
            # Compare the two trees structurally.
            t_ast: ZSSNode = gen_df["target_ast"].iloc[0]
            g_ast: ZSSNode = gen_df["generated_ast"].iloc[0]
            dist = simple_distance(t_ast, g_ast)
            # Only a handful of edits should separate the filtered tree
            # from the target tree.
            self.assertLessEqual(dist, 5)

        with self.subTest("ast_equivalence_different_identifiers"):
            # Node labels are grammar categories rather than source text, so
            # snippets differing only in names and string contents are
            # expected to yield the same tree.
            code_a = 'def compute():\n result = "hello"\n return result\n'
            code_b = 'def process():\n output = "world"\n return output\n'
            ast_a, status_a = PyTreeSitterAttack.parse_code(code_a)
            ast_b, status_b = PyTreeSitterAttack.parse_code(code_b)
            self.assertEqual(status_a, "success")
            self.assertEqual(status_b, "success")
            self.assertTrue(
                _trees_identical(ast_a, ast_b),
                "ASTs should be identical when code differs only in "
                "identifier names and string literals",
            )

    def test_missing_column_raises(self) -> None:
        """Constructing the attack without a required column raises ValueError."""
        incomplete_inputs = [
            ("missing_target", {"model_generated_code_string": ["x"]}),
            ("missing_generated", {"target_code_string": ["x"]}),
        ]
        for case_name, columns in incomplete_inputs:
            with self.subTest(case_name):
                df = pd.DataFrame(columns)
                with self.assertRaises(ValueError):
                    PyTreeSitterAttack(data=df)

    def test_parse_code_static_method(self) -> None:
        """parse_code() can be called directly on the class."""
        # (subtest name, code, language, expected status, expected root label)
        cases = [
            ("python_parse", "x = 1\n", "python", "success", "module"),
            ("cpp_parse", "int x = 1;", "cpp", "success", "translation_unit"),
            # A root node is still produced for malformed input.
            ("malformed_returns_partial", "def foo(:", "python", "partial", "module"),
        ]
        for case_name, code, language, expected_status, expected_label in cases:
            with self.subTest(case_name):
                node, status = PyTreeSitterAttack.parse_code(code, language=language)
                self.assertEqual(status, expected_status)
                self.assertEqual(node.label, expected_label)