Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import pandas as pd
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput


class CodeSimilarityAnalysisInput(BaseAnalysisInput):
"""
Analysis input for code similarity analysis.

Stores a generation DataFrame containing target and model-generated code strings
along with their parsed ASTs.

Required columns:
- target_code_string: the original target code
- model_generated_code_string: the model's generated code
- target_ast: parsed AST (zss Node) for the target code
- generated_ast: parsed AST (zss Node) for the generated code
- target_parse_status: "success" or "partial" (error nodes filtered)
- generated_parse_status: "success" or "partial" (error nodes filtered)

Args:
generation_df: DataFrame containing code strings and parsed ASTs
"""

REQUIRED_COLUMNS: list[str] = [
"target_code_string",
"model_generated_code_string",
"target_ast",
"generated_ast",
"target_parse_status",
"generated_parse_status",
]

def __init__(self, generation_df: pd.DataFrame) -> None:
missing = set(self.REQUIRED_COLUMNS) - set(generation_df.columns)
if missing:
raise ValueError(f"Missing required columns in generation_df: {missing}")

super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame())

@property
def generation_df(self) -> pd.DataFrame:
"""Property accessor for the generation DataFrame."""
return self._df_train_user
135 changes: 135 additions & 0 deletions privacy_guard/analysis/code_similarity/tree_edit_distance_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import logging
from dataclasses import dataclass, field
from typing import cast

import pandas as pd
from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode
from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
CodeSimilarityAnalysisInput,
)
from zss import Node as ZSSNode, simple_distance


logger: logging.Logger = logging.getLogger(__name__)


def _count_nodes(node: ZSSNode) -> int:
"""Recursively count the number of nodes in a zss tree."""
count = 1
for child in node.children:
count += _count_nodes(child)
return count


@dataclass
class TreeEditDistanceNodeOutput(BaseAnalysisOutput):
"""Output of :class:`TreeEditDistanceNode`.

Attributes:
num_samples: total number of sample rows.
num_both_parsed: number of rows where both target and generated
code produced an AST (always equals *num_samples* since the
attack now returns partial ASTs for malformed code).
per_sample_similarity: DataFrame with a ``similarity`` column.
avg_similarity: average similarity across all pairs.
avg_similarity_by_language: per-language average similarity, or
``None`` when no ``language`` column is present.
"""

num_samples: int
num_both_parsed: int
per_sample_similarity: pd.DataFrame = field(repr=False)
avg_similarity: float
avg_similarity_by_language: dict[str, float] | None


class TreeEditDistanceNode(BaseAnalysisNode):
"""Compute tree-edit-distance similarity between AST pairs.

Uses the Zhang-Shasha algorithm (via ``zss.simple_distance``) to
compute edit distance, then normalises to a 0-1 similarity score::

similarity = max(1 - distance / max(n1, n2), 0)

where *n1* and *n2* are the node counts of the two trees.

Args:
analysis_input: a :class:`CodeSimilarityAnalysisInput` produced
by :class:`PyTreeSitterAttack`.
"""

def __init__(self, analysis_input: CodeSimilarityAnalysisInput) -> None:
super().__init__(analysis_input=analysis_input)

# ------------------------------------------------------------------
# Public static helper
# ------------------------------------------------------------------

@staticmethod
def compute_similarity(tree1: ZSSNode, tree2: ZSSNode) -> float:
"""Compute normalised tree-edit-distance similarity.

Args:
tree1: first zss Node tree.
tree2: second zss Node tree.

Returns:
Similarity in [0, 1] where 1.0 means identical trees.
"""
dist: int = simple_distance(tree1, tree2)
n1 = _count_nodes(tree1)
n2 = _count_nodes(tree2)
max_nodes = max(n1, n2)
if max_nodes == 0:
return 1.0
return max(1.0 - dist / max_nodes, 0.0)

# ------------------------------------------------------------------
# BaseAnalysisNode interface
# ------------------------------------------------------------------

def run_analysis(self) -> TreeEditDistanceNodeOutput:
analysis_input = cast(CodeSimilarityAnalysisInput, self.analysis_input)
df = analysis_input.generation_df

def _row_similarity(row: pd.Series) -> float: # type: ignore[type-arg]
return TreeEditDistanceNode.compute_similarity(
row["target_ast"], row["generated_ast"]
)

similarities = df.apply(_row_similarity, axis=1)
per_sample = pd.DataFrame({"similarity": similarities})

num_both_parsed = len(similarities)
avg_similarity = float(similarities.mean()) if num_both_parsed > 0 else 0.0

avg_by_lang: dict[str, float] | None = None
if "language" in df.columns:
per_sample["language"] = df["language"].values
grouped = per_sample.groupby("language")["similarity"].mean()
avg_by_lang = grouped.to_dict()

return TreeEditDistanceNodeOutput(
num_samples=len(df),
num_both_parsed=num_both_parsed,
per_sample_similarity=per_sample,
avg_similarity=avg_similarity,
avg_similarity_by_language=avg_by_lang,
)
147 changes: 147 additions & 0 deletions privacy_guard/analysis/tests/test_tree_edit_distance_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import unittest

import pandas as pd
from privacy_guard.analysis.code_similarity.tree_edit_distance_node import (
TreeEditDistanceNode,
TreeEditDistanceNodeOutput,
)
from privacy_guard.attacks.code_similarity.py_tree_sitter_attack import (
PyTreeSitterAttack,
)


def _run_e2e(
df: pd.DataFrame,
default_language: str = "python",
) -> TreeEditDistanceNodeOutput:
"""Helper: run attack then analysis end-to-end."""
attack = PyTreeSitterAttack(data=df, default_language=default_language)
analysis_input = attack.run_attack()
node = TreeEditDistanceNode(analysis_input=analysis_input)
return node.run_analysis()


class TestTreeEditDistanceNode(unittest.TestCase):
def test_similarity_values(self) -> None:
"""Identical code should yield ~1.0; different code should be low."""
with self.subTest("identical_python"):
code = "def foo():\n return 1\n"
df = pd.DataFrame(
{
"target_code_string": [code],
"model_generated_code_string": [code],
}
)
output = _run_e2e(df)
self.assertIsInstance(output, TreeEditDistanceNodeOutput)
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
self.assertEqual(output.num_both_parsed, 1)

with self.subTest("different_python"):
df = pd.DataFrame(
{
"target_code_string": ["def foo():\n return 1\n"],
"model_generated_code_string": [
"class Bar:\n def __init__(self):\n"
" self.x = 1\n"
" def method(self, a, b):\n"
" return a + b\n"
],
}
)
output = _run_e2e(df)
self.assertLess(output.avg_similarity, 0.5)

with self.subTest("cpp_similarity"):
df = pd.DataFrame(
{
"target_code_string": ["int add(int a, int b) { return a + b; }"],
"model_generated_code_string": [
"int sum(int x, int y) { return x + y; }"
],
}
)
output = _run_e2e(df, default_language="cpp")
self.assertGreater(output.avg_similarity, 0.7)

with self.subTest("partial_parse_high_similarity"):
# Generated code contains the same function as the target
# but is surrounded by syntax errors. After error-node
# filtering the partial AST should still yield high
# similarity against the clean target.
target = "def foo():\n x = 1\n return x\n"
generated = "))))\ndef foo():\n x = 1\n @@@@\n return x\n(((\n"
df = pd.DataFrame(
{
"target_code_string": [target],
"model_generated_code_string": [generated],
}
)
output = _run_e2e(df)
# Partial parse still produces a similarity score (not NaN)
self.assertEqual(output.num_both_parsed, 1)
self.assertGreater(output.avg_similarity, 0.5)

with self.subTest("ast_equivalence_different_strings"):
# Two code snippets that are syntactically equivalent but
# differ in identifier names and string literals should
# yield similarity ≈ 1.0 because tree-sitter AST nodes are
# labelled by grammar category (e.g. "identifier", "string"),
# not by the actual text content.
target = 'def compute():\n result = "hello"\n return result\n'
generated = 'def process():\n output = "world"\n return output\n'
df = pd.DataFrame(
{
"target_code_string": [target],
"model_generated_code_string": [generated],
}
)
output = _run_e2e(df)
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)

def test_avg_similarity_by_language(self) -> None:
"""Mixed Python+C++ input produces per-language averages."""
df = pd.DataFrame(
{
"target_code_string": [
"def foo():\n return 1\n",
"int main() { return 0; }",
],
"model_generated_code_string": [
"def foo():\n return 1\n",
"int main() { return 0; }",
],
"language": ["python", "cpp"],
}
)
output = _run_e2e(df)
assert output.avg_similarity_by_language is not None
by_lang = output.avg_similarity_by_language
self.assertIn("python", by_lang)
self.assertIn("cpp", by_lang)
self.assertAlmostEqual(by_lang["python"], 1.0, places=5)
self.assertAlmostEqual(by_lang["cpp"], 1.0, places=5)

def test_compute_similarity_static_method(self) -> None:
"""TreeEditDistanceNode.compute_similarity works standalone."""
node1, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")
node2, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")

sim = TreeEditDistanceNode.compute_similarity(node1, node2)
self.assertAlmostEqual(sim, 1.0, places=5)
Loading
Loading