94 changes: 93 additions & 1 deletion server/reportmanager/clustering/ClusterBucketManager.py
@@ -19,6 +19,15 @@
from reportmanager.utils import preprocess_text


@dataclass
class DomainClusterData:
"""Domain-specific clustering data including distance threshold and embeddings."""

domain: str
distance_threshold: float
embeddings: dict[int, np.ndarray] # cluster_id -> embeddings


@dataclass
class ClusterReport:
id: int
@@ -117,6 +126,7 @@ def batch_delete_in_chunks(
class ClusterBucketManager:
def __init__(self) -> None:
self.clusterer = SBERTClusterer()
self.domain_data: dict[str, DomainClusterData] = {}

def build_cluster_report(self, report_data: dict) -> ClusterReport:
comments = report_data["comments_translated"] or report_data["comments"]
Expand Down Expand Up @@ -269,7 +279,7 @@ def cluster_domain_reports(
return []

        # Calculate if this is a high-volume domain
        # and if so, only use reports in the last N days
is_high_volume = self.is_high_volume_domain(reports)

if is_high_volume:
Expand Down Expand Up @@ -417,3 +427,85 @@ def create_buckets_from_clusters(self, all_clusters: list[ClusterData]) -> int:
buckets_created += 1

return buckets_created

def get_bucket_for_cluster(self, cluster_id: int) -> Bucket | None:
cluster = Cluster.objects.filter(id=cluster_id).first()
if not cluster:
return None

signature = self.build_cluster_bucket_signature(cluster.domain, cluster_id)
bucket = Bucket.objects.filter(signature=signature).first()
return bucket

def build_domain_data(
self,
all_reports: list[ClusterReport],
domains: set[str] | None = None,
) -> None:
"""Build domain data: threshold based on volume and cluster embeddings."""

reports_by_domain = self.group_reports_by_domain(all_reports, domains)

for domain, domain_reports in reports_by_domain.items():
is_high_volume = self.is_high_volume_domain(domain_reports)

distance_threshold = (
ClusteringConfig.HIGH_VOLUME_DISTANCE_THRESHOLD
if is_high_volume
else ClusteringConfig.NORMAL_VOLUME_DISTANCE_THRESHOLD
)

domain_clusters = Cluster.objects.filter(domain=domain).prefetch_related(
"reportentry_set"
)

cluster_embeddings = {}
for cluster in domain_clusters:
reports = list(
cluster.reportentry_set.exclude(comments="").values(
"id", "comments", "comments_translated"
)
)

if not reports:
continue

texts = []
for r in reports:
text = r["comments_translated"] or r["comments"]
preprocessed = preprocess_text(text)
if preprocessed:
texts.append(preprocessed)

if not texts:
continue

embeddings = self.clusterer.build_embeddings(texts)
cluster_embeddings[cluster.id] = embeddings

# Store domain data only if we have clusters
if cluster_embeddings:
self.domain_data[domain] = DomainClusterData(
domain=domain,
distance_threshold=distance_threshold,
embeddings=cluster_embeddings,
)

def get_closest_cluster(self, report: ClusterReport) -> int | None:
"""Find the closest cluster for a report."""

domain_data = self.domain_data.get(report.domain)

if not domain_data or not domain_data.embeddings:
return None

min_similarity = 1.0 - domain_data.distance_threshold

cluster_id = self.clusterer.assign_to_cluster_top_n_avg(
report.text,
domain_data.embeddings,
n=5,
min_similarity=min_similarity,
)

return cluster_id
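
Since the embeddings are unit-normalized, cosine distance and cosine similarity are complements, which is why get_closest_cluster turns the per-domain distance threshold into a similarity floor via 1.0 - distance_threshold. A minimal sketch of that conversion, assuming illustrative threshold values (the real constants live in ClusteringConfig):

    HIGH_VOLUME_DISTANCE_THRESHOLD = 0.3  # placeholder, not the real config value
    NORMAL_VOLUME_DISTANCE_THRESHOLD = 0.15  # placeholder, not the real config value

    for name, distance_threshold in [
        ("high volume", HIGH_VOLUME_DISTANCE_THRESHOLD),
        ("normal volume", NORMAL_VOLUME_DISTANCE_THRESHOLD),
    ]:
        # For unit vectors, cosine distance = 1 - cosine similarity.
        min_similarity = 1.0 - distance_threshold
        print(f"{name}: top-N average similarity must exceed {min_similarity:.2f}")

A smaller distance threshold means a higher similarity floor, so assignment strictness is tuned per domain volume.
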
64 changes: 64 additions & 0 deletions server/reportmanager/clustering/SBERTClusterer.py
@@ -59,3 +59,67 @@ def find_centroid_index(self, embeddings: np.ndarray) -> int:
closest_idx = int(np.argmin(squared_distances))

return closest_idx

def build_embeddings(self, texts: list[str]) -> np.ndarray:
embeddings = self.model.encode(
texts, show_progress_bar=False, normalize_embeddings=True
)
return np.asarray(embeddings, dtype=np.float32)

def assign_to_cluster_top_n_avg(
self,
report_text: str,
cluster_embeddings: dict[int, np.ndarray],
n: int,
min_similarity: float,
) -> int | None:
"""Assign report to cluster based on average similarity to top-N members.

This approach averages the similarity scores of the N most similar members
from each cluster, then assigns to the cluster with the highest average.

Args:
report_text: Text to classify
cluster_embeddings: Dict mapping cluster IDs to their member embeddings
            n: Number of top similar members to average
            min_similarity: Minimum average similarity threshold

Returns:
Cluster ID if match found, None otherwise
"""
if not cluster_embeddings:
return None

x = self.model.encode(
[report_text], show_progress_bar=False, normalize_embeddings=True
)[0]
x = np.asarray(x, dtype=np.float32)

best_cluster = None
best_avg_similarity = min_similarity

for cluster_id, embs in cluster_embeddings.items():
# Calculate similarities to all members
sims = embs @ x

# Take average of top-N most similar members
# For small clusters, use all members
top_n = min(n, len(sims))

if top_n == 0:
continue

# Use partition to get top-N
if len(sims) > top_n:
top_sims = np.partition(sims, -top_n)[-top_n:]
else:
top_sims = sims

avg_similarity = float(np.mean(top_sims))

# Update best cluster if this one has higher average similarity
if avg_similarity > best_avg_similarity:
best_avg_similarity = avg_similarity
best_cluster = cluster_id

return best_cluster
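
To make the top-N averaging concrete, here is a self-contained sketch of the same logic using random unit vectors in place of SBERT embeddings (the toy data, the unit helper, and the example dimensions are assumptions for illustration):

    import numpy as np

    rng = np.random.default_rng(0)

    def unit(v: np.ndarray) -> np.ndarray:
        # Normalize rows to unit length, mimicking normalize_embeddings=True.
        return v / np.linalg.norm(v, axis=-1, keepdims=True)

    # Two toy "clusters" of member embeddings, keyed by cluster ID.
    cluster_embeddings = {
        1: unit(rng.normal(size=(8, 16)).astype(np.float32)),
        2: unit(rng.normal(size=(3, 16)).astype(np.float32)),
    }
    x = unit(rng.normal(size=16).astype(np.float32))  # the "report" embedding

    n, min_similarity = 5, 0.5
    best_cluster, best_avg_similarity = None, min_similarity

    for cluster_id, embs in cluster_embeddings.items():
        sims = embs @ x  # cosine similarities, since all vectors are unit length
        top_n = min(n, len(sims))  # small clusters contribute all their members
        top_sims = np.partition(sims, -top_n)[-top_n:] if len(sims) > top_n else sims
        avg_similarity = float(np.mean(top_sims))
        if avg_similarity > best_avg_similarity:
            best_avg_similarity, best_cluster = avg_similarity, cluster_id

    # Random vectors rarely clear a 0.5 floor, so this usually prints None.
    print(best_cluster, best_avg_similarity)

Averaging over the N most similar members, rather than comparing against a single centroid, tolerates clusters with internal spread; the similarity floor keeps genuinely dissimilar reports unassigned.
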
(third file: a Django management command; filename not shown in this view)
@@ -2,6 +2,7 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from contextlib import suppress
from datetime import UTC
from logging import getLogger
from urllib.parse import urlsplit

@@ -10,10 +11,13 @@
from django.core.management import BaseCommand
from django.db.models import Q
from django.db.utils import IntegrityError
from google.cloud import bigquery
from google.oauth2 import service_account

from reportmanager.clustering.ClusterBucketManager import (
ClusterBucketManager,
ClusteringConfig,
)
from reportmanager.models import Bucket, ReportEntry
from webcompat.models import Report

@@ -24,13 +28,17 @@
KNOWN_BUCKET_IDS: dict[str, int] = {}


# This only returns a bucket for low-quality reports; all other reports
# are left unbucketed for triaging. It returns a matching domain bucket
# if one exists, or creates a new domain bucket if none does. Cluster
# buckets are excluded.
def find_bucket_for_report(report_info: Report) -> int | None:
# Only apply domain bucketing for low-quality reports
if ClusterBucketManager.ok_to_cluster(
report_info.comments, report_info.ml_valid_probability
):
return None

hostname = report_info.url.hostname

if hostname is None:
@@ -39,21 +47,25 @@ def find_bucket_for_report(report_info: Report) -> int | None:
if (known_bucket := KNOWN_BUCKET_IDS.get(hostname)) is not None:
return known_bucket

# Find domain buckets, excluding cluster buckets
bucket_id = (
Bucket.objects.filter(Q(domain=hostname))
.exclude(description__contains=ClusteringConfig.CLUSTER_BUCKET_IDENTIFIER)
.values_list("id", flat=True)
.first()
)

if bucket_id:
KNOWN_BUCKET_IDS[hostname] = bucket_id
return bucket_id

# No existing bucket found, create new one
bucket = Bucket.objects.create(
description=f"domain is {report_info.url.hostname}",
signature=report_info.create_signature().raw_signature,
)
KNOWN_BUCKET_IDS[hostname] = bucket.pk
return bucket.pk


class Command(BaseCommand):
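
For context, the KNOWN_BUCKET_IDS dict above memoizes hostname-to-bucket lookups for the lifetime of the process. A minimal sketch of that caching pattern, with the ORM query swapped for a hypothetical stand-in (lookup_bucket_id and the sample data are not part of the codebase):

    KNOWN_BUCKET_IDS: dict[str, int] = {}

    def lookup_bucket_id(hostname: str) -> int | None:
        # Stand-in for the Bucket.objects.filter(...).first() query.
        return {"example.com": 42}.get(hostname)

    def bucket_for_hostname(hostname: str) -> int | None:
        # Serve repeated lookups from the in-process cache.
        if (known_bucket := KNOWN_BUCKET_IDS.get(hostname)) is not None:
            return known_bucket
        bucket_id = lookup_bucket_id(hostname)
        if bucket_id is not None:
            KNOWN_BUCKET_IDS[hostname] = bucket_id
        return bucket_id

    assert bucket_for_hostname("example.com") == 42  # cache miss, then stored
    assert KNOWN_BUCKET_IDS["example.com"] == 42     # served from cache afterwards

Because the dict is module-level, cached IDs persist for the life of the process and are shared across calls within it.
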