From 7cf1db07276a00bfe2db7553c01a43bad2ce9dc5 Mon Sep 17 00:00:00 2001 From: aravadikesh Date: Fri, 24 Oct 2025 14:36:45 -0400 Subject: [PATCH] feat: Add perceptual hashing functionality with support for multiple algorithms - Implemented PerceptualHasher class for generating perceptual hashes using various algorithms (pHash, AverageHash, DHash, WaveletHash, PDQ). - Added methods for computing hashes for individual files and directories, as well as batch processing. - Created main plugin file with endpoints for creating, finding, exporting, importing, listing, and deleting hash collections. - Integrated CLI parsing for user-friendly command line interactions. - Established a pyproject.toml for dependency management and project metadata. --- pyproject.toml | 4 + src/perceptual-hash/README.md | 631 +++++++++++ src/perceptual-hash/__init__.py | 1 + .../perceptual_hash/__init__.py | 1 + .../perceptual_hash/database.py | 788 ++++++++++++++ src/perceptual-hash/perceptual_hash/hasher.py | 310 ++++++ src/perceptual-hash/perceptual_hash/main.py | 983 ++++++++++++++++++ src/perceptual-hash/pyproject.toml | 20 + 8 files changed, 2738 insertions(+) create mode 100644 src/perceptual-hash/README.md create mode 100644 src/perceptual-hash/__init__.py create mode 100644 src/perceptual-hash/perceptual_hash/__init__.py create mode 100644 src/perceptual-hash/perceptual_hash/database.py create mode 100644 src/perceptual-hash/perceptual_hash/hasher.py create mode 100644 src/perceptual-hash/perceptual_hash/main.py create mode 100644 src/perceptual-hash/pyproject.toml diff --git a/pyproject.toml b/pyproject.toml index 030afb3..7e648fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,9 @@ chromadb = "^1.0.4" pandas = ">=2.2.3,<3.0.0" pillow = ">=11.2.1,<12.0.0" fusepy = ">=3.0.1,<4.0.0" +perception = "^0.5.1" +psycopg2-binary = "^2.9.10" +pgvector = "^0.3.6" rb-lib = { path = "src/rb-lib", develop = true } @@ -48,6 +51,7 @@ hello-world = { path = "src/hello-world", develop = true } image-details = { path = "src/image-details", develop = true } image-summary = { path = "src/image-summary", develop = true } image-caption-blip-onnx = { path = "src/image-caption-blip-onnx", develop = true } +perceptual-hash = { path = "src/perceptual-hash", develop = true } # Don't add new packages here, add them appropriately in the list above diff --git a/src/perceptual-hash/README.md b/src/perceptual-hash/README.md new file mode 100644 index 0000000..cce6e28 --- /dev/null +++ b/src/perceptual-hash/README.md @@ -0,0 +1,631 @@ +# Perceptual Hash Plugin + +A forensic investigation plugin for detecting duplicate and near-duplicate images using perceptual hashing algorithms. This plugin enables large-scale image similarity detection using PostgreSQL with the pgvector extension for efficient vector-based similarity search. + +## Table of Contents + +- [Overview](#overview) +- [Features](#features) +- [How It Works](#how-it-works) +- [Prerequisites](#prerequisites) +- [Setup Instructions](#setup-instructions) + - [1. PostgreSQL Installation](#1-postgresql-installation) + - [2. pgvector Extension](#2-pgvector-extension) + - [3. Database Setup](#3-database-setup) + - [4. 
Environment Variables](#4-environment-variables) +- [Usage](#usage) + - [CLI Usage](#cli-usage) + - [Web UI Usage](#web-ui-usage) +- [Supported Hash Algorithms](#supported-hash-algorithms) +- [Understanding Results](#understanding-results) +- [Demo and Testing](#demo-and-testing) +- [Architecture](#architecture) +- [Troubleshooting](#troubleshooting) +- [Performance Considerations](#performance-considerations) + +--- + +## Overview + +The Perceptual Hash plugin is designed for forensic investigators who need to: +- Detect duplicate and near-duplicate images in large datasets +- Find modified versions of images (cropped, resized, filtered, etc.) +- Build searchable databases of known image content +- Perform similarity searches across image collections + +Unlike cryptographic hashes (MD5, SHA), perceptual hashes are robust to image transformations, making them ideal for forensic scenarios where images may be modified to evade detection. + +## Features + +- **6 Endpoint Operations**: + 1. Create Hash Database - Index a directory of images + 2. Find Matches - Search for similar images + 3. Export Database - Export hashes to JSON + 4. Import Database - Import hashes from JSON + 5. List Collections - View all collections + 6. Delete Collection - Remove a collection + +- **5 Hash Algorithms**: + - pHash (Perceptual Hash) - Recommended for general use + - Average Hash - Fast, good for exact matches + - dHash (Difference Hash) - Robust to brightness changes + - Wavelet Hash - Good for complex transformations + - PDQ Hash - Facebook's production-grade algorithm + +- **PostgreSQL + pgvector Backend**: + - Scalable vector similarity search + - Hamming distance-based matching + - Efficient indexing for large datasets + - Support for multiple collections + +## How It Works + +### 1. Hash Generation + +The plugin uses the [perception](https://github.com/thorn-oss/perception) library to compute perceptual hashes: + +``` +Image → Perceptual Hash Algorithm → Binary Vector (64-256 dimensions) +``` + +Each algorithm converts an image into a fixed-length binary vector that captures visual features: +- **pHash**: Uses Discrete Cosine Transform (DCT) to capture frequency patterns +- **Average Hash**: Compares pixels to average brightness +- **dHash**: Captures gradients between adjacent pixels +- **Wavelet Hash**: Uses Haar wavelet decomposition +- **PDQ**: Facebook's robust hash optimized for scale + +### 2. Storage in PostgreSQL + +Hashes are stored as vectors in PostgreSQL using the pgvector extension: + +```sql +CREATE TABLE collection_algorithm ( + id SERIAL PRIMARY KEY, + file_path TEXT NOT NULL, + hash_string TEXT NOT NULL, + hash_vector vector(64) NOT NULL, -- Binary vector representation + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + UNIQUE(file_path) +); +``` + +### 3. Similarity Search + +When searching for matches, the plugin uses **Hamming distance** (L1 distance for binary vectors): + +```sql +SELECT file_path, hash_string, + hash_vector <+> query_vector AS hamming_distance +FROM collection_algorithm +WHERE hash_vector <+> query_vector <= threshold +ORDER BY hamming_distance +LIMIT n_results; +``` + +**Hamming distance** counts the number of differing bits: +- Distance 0 = Identical images +- Distance 1-5 = Near-identical (extremely similar) +- Distance 6-15 = Very similar (likely same content) +- Distance 16-30 = Similar (related content) +- Distance > 30 = Potentially different images + +### 4. 
Match Quality Scoring + +Results include: +- **hamming_distance**: Number of differing bits (lower = more similar) +- **similarity**: Normalized score from 0.0 (different) to 1.0 (identical) +- **quality**: Human-readable label (exact, very_similar, similar, somewhat_similar) + +--- + +## Prerequisites + +Before using the Perceptual Hash plugin, you need: + +1. **PostgreSQL** (version 12 or higher) +2. **pgvector extension** for PostgreSQL +3. **Python 3.11+** with Poetry +4. **RescueBox** development environment + +--- + +## Setup Instructions + +### 1. PostgreSQL Installation + +#### macOS (using Homebrew) + +```bash +# Install PostgreSQL +brew install postgresql@15 + +# Start PostgreSQL service +brew services start postgresql@15 + +# Verify installation +psql --version +``` + +#### Linux (Ubuntu/Debian) + +```bash +# Install PostgreSQL +sudo apt update +sudo apt install postgresql postgresql-contrib + +# Start PostgreSQL service +sudo systemctl start postgresql +sudo systemctl enable postgresql + +# Verify installation +psql --version +``` + +#### Docker (Alternative) + +```bash +# Run PostgreSQL in Docker +docker run -d \ + --name rescuebox-postgres \ + -e POSTGRES_PASSWORD=test \ + -e POSTGRES_USER=test \ + -e POSTGRES_DB=rescuebox \ + -p 5432:5432 \ + postgres:15 +``` + +### 2. pgvector Extension + +pgvector must be compiled and installed from source: + +```bash +# Clone pgvector repository +cd /tmp +git clone https://github.com/pgvector/pgvector.git +cd pgvector + +# Compile and install (requires PostgreSQL development headers) +# macOS: +brew install postgresql@15 # Includes dev headers + +# Linux: +sudo apt install postgresql-server-dev-15 + +# Compile +make +sudo make install + +# Verify installation +cd /Users/aravadikesh/Documents/GitHub/RescueBox +ls -la pgvector/ # Should see vector.so and other files +``` + +The RescueBox repository includes a pre-built `vector.so` in the `pgvector/` directory for convenience. + +### 3. Database Setup + +#### Create Database and User + +```bash +# Connect to PostgreSQL as superuser +psql postgres + +# Create user and database +CREATE USER test WITH PASSWORD 'test'; +CREATE DATABASE rescuebox OWNER test; + +# Grant privileges +GRANT ALL PRIVILEGES ON DATABASE rescuebox TO test; + +# Exit psql +\q +``` + +#### Enable pgvector Extension + +```bash +# Connect to the rescuebox database +psql -U test -d rescuebox + +# Enable pgvector extension +CREATE EXTENSION IF NOT EXISTS vector; + +# Verify installation +SELECT extversion FROM pg_extension WHERE extname = 'vector'; + +# You should see output like: +# extversion +# ------------ +# 0.7.4 +# (1 row) + +# Exit psql +\q +``` + +#### Test Connection + +```bash +# Test connection with environment variables +export POSTGRES_HOST=localhost +export POSTGRES_PORT=5432 +export POSTGRES_USER=test +export POSTGRES_PASSWORD=test +export POSTGRES_DB=rescuebox + +# Test query +psql -h "${POSTGRES_HOST}" -p "${POSTGRES_PORT}" -U "${POSTGRES_USER}" -d "${POSTGRES_DB}" \ + -c "SELECT extversion FROM pg_extension WHERE extname = 'vector';" +``` + +If successful, you should see the pgvector version number. + +### 4. Environment Variables + +The plugin uses environment variables for database configuration. 
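+
+The variables and their defaults are listed just below. Once they are exported, the following minimal sketch (a hypothetical `check_db.py`, not shipped with the plugin) confirms that the Python stack can connect with the same `POSTGRES_*` settings the plugin's `HashDatabase` class reads:
+
+```python
+import os
+
+import psycopg2
+from pgvector.psycopg2 import register_vector
+
+conn = psycopg2.connect(
+    host=os.environ.get("POSTGRES_HOST", "localhost"),
+    port=os.environ.get("POSTGRES_PORT", "5432"),
+    dbname=os.environ.get("POSTGRES_DB", "rescuebox"),
+    user=os.environ.get("POSTGRES_USER", "test"),
+    password=os.environ.get("POSTGRES_PASSWORD", "test"),
+)
+register_vector(conn)  # fails if the vector extension is not installed
+
+with conn.cursor() as cur:
+    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
+    print("pgvector version:", cur.fetchone()[0])
+conn.close()
+```
+
+If this prints a version number, the plugin will be able to connect with the same settings.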
Add these to your shell configuration file (`~/.bashrc`, `~/.zshrc`, or `.env`): + +```bash +# PostgreSQL Configuration +export POSTGRES_HOST=localhost # Database host +export POSTGRES_PORT=5432 # Database port +export POSTGRES_USER=test # Database user +export POSTGRES_PASSWORD=test # Database password +export POSTGRES_DB=rescuebox # Database name + +# Optional: Enable testing mode (uses separate rescuebox_test database) +# export IS_TESTING=true +``` + +**For persistent configuration**, add to your shell profile: + +```bash +# Add to ~/.bashrc or ~/.zshrc +echo 'export POSTGRES_HOST=localhost' >> ~/.bashrc +echo 'export POSTGRES_PORT=5432' >> ~/.bashrc +echo 'export POSTGRES_USER=test' >> ~/.bashrc +echo 'export POSTGRES_PASSWORD=test' >> ~/.bashrc +echo 'export POSTGRES_DB=rescuebox' >> ~/.bashrc + +# Reload shell configuration +source ~/.bashrc +``` + +**For VS Code Dev Container**, add to [.devcontainer/devcontainer.json](../../.devcontainer/devcontainer.json): + +```json +{ + "containerEnv": { + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + "POSTGRES_USER": "test", + "POSTGRES_PASSWORD": "test", + "POSTGRES_DB": "rescuebox" + } +} +``` + +--- + +## Usage + +### CLI Usage + +The plugin provides 6 endpoints accessible via CLI: + +#### 1. Create Hash Database + +Index a directory of images to create a searchable collection: + +```bash +poetry run rescuebox perceptual-hash/create_database \ + /path/to/images \ + --params "my_collection,phash,true" + +# Parameters format: collection_name,hash_algorithm,recursive +# - collection_name: Name for this collection (e.g., "evidence_photos") +# - hash_algorithm: phash|average|dhash|wavelet|pdq +# - recursive: true|false (search subdirectories) +``` + +**Example:** + +```bash +# Index all images in a directory recursively +poetry run rescuebox perceptual-hash/create_database \ + src/perceptual-hash/demo/original_images \ + --params "evidence_set1,phash,true" + +# Output: +# Successfully created collection 'evidence_set1' with 150 hashes using phash algorithm. +# Total hashes in collection: 150 +``` + +#### 2. Find Matches + +Search for similar images in a collection: + +```bash +poetry run rescuebox perceptual-hash/find_matches \ + /path/to/query_images,output.json \ + --params "my_collection,phash,10.0,10" + +# Parameters format: collection_name,hash_algorithm,max_distance,max_results +# - collection_name: Collection to search in +# - hash_algorithm: Must match the collection's algorithm +# - max_distance: Maximum Hamming distance threshold (0-100) +# - max_results: Max results per query (1-100) +``` + +**Example:** + +```bash +# Find similar images with Hamming distance <= 10 +poetry run rescuebox perceptual-hash/find_matches \ + src/perceptual-hash/demo/query_images,matches.json \ + --params "evidence_set1,phash,10.0,10" + +# Output: +# Successfully found 8 matches for 3 query files. Results saved to matches.json +``` + +**Output JSON format:** + +```json +{ + "metadata": { + "collection_name": "evidence_set1", + "hash_algorithm": "phash", + "max_distance": 10.0, + "max_results_per_query": 10, + "total_query_files": 3, + "total_matches_found": 8 + }, + "results": { + "/path/to/query1.jpg": { + "matches": [ + { + "file_path": "/path/to/original1.jpg", + "hamming_distance": 0, + "similarity": 1.0, + "quality": "exact" + }, + { + "file_path": "/path/to/original2.jpg", + "hamming_distance": 5, + "similarity": 0.921875, + "quality": "very_similar" + } + ], + "total_matches": 2 + } + } +} +``` + +#### 3. 
Export Database + +Export a collection to a portable JSON file: + +```bash +poetry run rescuebox perceptual-hash/export_database \ + export.json \ + --params "my_collection,phash" + +# Parameters format: collection_name,hash_algorithm +``` + +**Example:** + +```bash +poetry run rescuebox perceptual-hash/export_database \ + evidence_set1_backup.json \ + --params "evidence_set1,phash" + +# Output: +# Successfully exported collection 'evidence_set1' (phash) with 150 hashes to evidence_set1_backup.json +``` + +#### 4. Import Database + +Import a collection from JSON: + +```bash +poetry run rescuebox perceptual-hash/import_database \ + export.json \ + --params "new_collection_name" + +# Parameters: new_collection_name (optional - uses original name if empty) +``` + +**Example:** + +```bash +# Import with new name +poetry run rescuebox perceptual-hash/import_database \ + evidence_set1_backup.json \ + --params "evidence_set1_restored" + +# Import with original name +poetry run rescuebox perceptual-hash/import_database \ + evidence_set1_backup.json \ + --params "" +``` + +#### 5. List Collections + +View all available collections: + +```bash +poetry run rescuebox perceptual-hash/list_collections "" + +# Output: +# Available Collections: +# +# evidence_set1: +# - phash: 150 hashes (dimension: 64) +# - pdq: 150 hashes (dimension: 256) +# +# evidence_set2: +# - average: 85 hashes (dimension: 64) +``` + +#### 6. Delete Collection + +Remove a collection: + +```bash +poetry run rescuebox perceptual-hash/delete_collection \ + "" \ + --params "my_collection,phash" + +# Parameters format: collection_name,hash_algorithm +``` + +**Example:** + +```bash +poetry run rescuebox perceptual-hash/delete_collection \ + "" \ + --params "evidence_set1,phash" + +# Output: +# Successfully deleted collection 'evidence_set1' (phash) +``` + +### Web UI Usage + +The plugin is also accessible via the RescueBox AutoUI: + +1. **Start the backend server:** + +```bash +cd /Users/aravadikesh/Documents/GitHub/RescueBox +./run_server + +# Or manually: +poetry run python -m src.rb-api.rb.api.main +``` + +2. **Access the web interface:** + +Open your browser to [http://localhost:8000](http://localhost:8000) + +3. **Select "Perceptual Hash" from the plugin list** + +4. **Use the dynamically generated UI:** + - **Create Database**: Upload or select a directory, choose algorithm and collection name + - **Find Matches**: Select query directory, choose collection and search parameters + - **Export/Import**: Manage collection backups + - **List/Delete**: View and manage collections + +The UI automatically validates inputs and provides real-time feedback. + +--- + +## Supported Hash Algorithms + +### 1. pHash (Perceptual Hash) - **Recommended** + +```bash +--params "collection,phash,true" +``` + +- **Method**: Discrete Cosine Transform (DCT) +- **Dimensions**: 64 bits +- **Best for**: General-purpose similarity detection +- **Robust to**: Scaling, aspect ratio changes, minor compression, brightness/contrast adjustments +- **Use cases**: Finding modified images, duplicate detection, general forensic work + +### 2. Average Hash + +```bash +--params "collection,average,true" +``` + +- **Method**: Compare pixels to average brightness +- **Dimensions**: 64 bits +- **Best for**: Fast exact duplicate detection +- **Robust to**: Minor compression, small color changes +- **Use cases**: Quick deduplication, nearly-identical image detection + +### 3. 
dHash (Difference Hash) + +```bash +--params "collection,dhash,true" +``` + +- **Method**: Gradient-based (differences between adjacent pixels) +- **Dimensions**: 64 bits +- **Best for**: Images with different brightness/contrast +- **Robust to**: Brightness changes, contrast adjustments, gamma correction +- **Use cases**: Finding images with exposure/lighting differences + +### 4. Wavelet Hash + +```bash +--params "collection,wavelet,true" +``` + +- **Method**: Haar wavelet decomposition +- **Dimensions**: 64 bits +- **Best for**: Complex image transformations +- **Robust to**: Rotation, complex filtering, artistic modifications +- **Use cases**: Finding heavily modified images + +### 5. PDQ Hash + +```bash +--params "collection,pdq,true" +``` + +- **Method**: Facebook's production-grade hash (DCT-based with quality improvements) +- **Dimensions**: 256 bits +- **Best for**: Large-scale production deployments +- **Robust to**: All common transformations, optimized for CSAM detection +- **Use cases**: Large datasets, law enforcement applications, high-accuracy requirements + +--- + +## Understanding Results + +### Hamming Distance Interpretation + +The Hamming distance indicates how many bits differ between two hashes: + +| Distance | Quality | Interpretation | Example Scenarios | +|----------|---------|----------------|-------------------| +| 0 | Exact | Identical images | Exact duplicate, same file | +| 1-5 | Very Similar | Near-identical | Resized, re-compressed, watermarked | +| 6-15 | Similar | Likely same content | Cropped, filtered, color-adjusted | +| 16-30 | Somewhat Similar | Related content | Different angles, partial matches | +| 31+ | Different | Likely unrelated | Different images | + +### Similarity Score + +The similarity score is normalized: `similarity = 1.0 - (hamming_distance / vector_dimension)` + +- **1.0** = Perfect match (0 bits different) +- **0.9-0.99** = Near-identical (1-6 bits different for 64-bit hash) +- **0.75-0.89** = Very similar (7-16 bits different) +- **0.5-0.74** = Somewhat similar (17-32 bits different) +- **< 0.5** = Likely different images + +### Setting Thresholds + +**For high precision (fewer false positives):** +```bash +--params "collection,phash,5.0,10" # Very strict +``` + +**For high recall (catch more matches):** +```bash +--params "collection,phash,20.0,50" # More permissive +``` + +**Recommended thresholds by use case:** + +- **Exact duplicates**: max_distance = 5 +- **Near-duplicates** (resized, compressed): max_distance = 10 +- **Modified images** (cropped, filtered): max_distance = 15-20 +- **Related content** (same scene, different angle): max_distance = 25-30 diff --git a/src/perceptual-hash/__init__.py b/src/perceptual-hash/__init__.py new file mode 100644 index 0000000..779340c --- /dev/null +++ b/src/perceptual-hash/__init__.py @@ -0,0 +1 @@ +"""Perceptual Hash Plugin for RescueBox""" diff --git a/src/perceptual-hash/perceptual_hash/__init__.py b/src/perceptual-hash/perceptual_hash/__init__.py new file mode 100644 index 0000000..7b208a3 --- /dev/null +++ b/src/perceptual-hash/perceptual_hash/__init__.py @@ -0,0 +1 @@ +"""Perceptual Hash Plugin""" diff --git a/src/perceptual-hash/perceptual_hash/database.py b/src/perceptual-hash/perceptual_hash/database.py new file mode 100644 index 0000000..823a736 --- /dev/null +++ b/src/perceptual-hash/perceptual_hash/database.py @@ -0,0 +1,788 @@ +"""Database module for storing and querying perceptual hashes using PostgreSQL with pgvector. + +IMPROVED SCHEMA with better naming and structure. 
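+
+Typical usage (a minimal sketch; the collection name and hash dictionaries are
+illustrative and are produced by perceptual_hash.hasher.PerceptualHasher):
+
+    with HashDatabase() as db:
+        db.add_hashes("photos", "phash", hashes)
+        matches = db.query_hashes("photos", "phash", query_hashes,
+                                  n_results=5, threshold=10.0)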
+""" + +import os +import json +import math +from pathlib import Path +from typing import List, Dict, Any, Optional +import logging + +import numpy as np +import psycopg2 +from psycopg2.extras import execute_values +from pgvector.psycopg2 import register_vector + +logger = logging.getLogger(__name__) + + +class HashDatabase: + """PostgreSQL with pgvector-based database for storing and querying perceptual hashes. + + IMPROVED VERSION with: + - Better table naming (no redundant prefixes) + - Timestamp with timezone + - File path index + - Correct L1/Hamming distance indexes + - Optimal index parameters + + Supports context manager for automatic connection cleanup: + with HashDatabase() as db: + db.add_hashes(...) + """ + + def __init__(self): + """Initialize the hash database connection.""" + # Get database connection parameters from environment variables + self.db_host = os.environ.get("POSTGRES_HOST", "localhost") + self.db_port = os.environ.get("POSTGRES_PORT", "5432") + self.db_name = os.environ.get("POSTGRES_DB", "rescuebox") + self.db_user = os.environ.get("POSTGRES_USER", "test") + self.db_password = os.environ.get("POSTGRES_PASSWORD", "test") + + # Check testing flag with flexible values + testing = os.environ.get("IS_TESTING", "false").lower() + if testing in ("true", "1", "yes", "on"): + # Use separate test database + self.db_name = "rescuebox_test" + + self.conn = None + self._max_reconnect_attempts = 3 + self._connect() + self._ensure_extension() + self._ensure_base_table() + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensures connection is closed.""" + self.close() + return False # Don't suppress exceptions + + def _connect(self): + """Establish database connection with retry logic.""" + for attempt in range(self._max_reconnect_attempts): + try: + self.conn = psycopg2.connect( + host=self.db_host, + port=self.db_port, + database=self.db_name, + user=self.db_user, + password=self.db_password + ) + # Register pgvector extension + register_vector(self.conn) + logger.info(f"Connected to PostgreSQL database: {self.db_name}") + return + except psycopg2.OperationalError as e: + if attempt < self._max_reconnect_attempts - 1: + logger.warning(f"Connection attempt {attempt + 1} failed, retrying...") + else: + logger.error(f"Failed to connect to PostgreSQL after {self._max_reconnect_attempts} attempts: {e}") + logger.error( + "Make sure PostgreSQL is running and accessible. " + "Set POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD environment variables if needed." + ) + raise + + def _ensure_connected(self): + """Ensure connection is alive, reconnect if needed. + + Call this at the start of each public method to handle connection drops. 
+ """ + try: + if self.conn is None or self.conn.closed: + logger.warning("Connection was closed, reconnecting...") + self._connect() + # Re-initialize after reconnection + self._ensure_extension() + self._ensure_base_table() + else: + # Test connection with lightweight query + with self.conn.cursor() as cur: + cur.execute("SELECT 1") + except (psycopg2.OperationalError, psycopg2.InterfaceError) as e: + logger.warning(f"Connection test failed: {e}, reconnecting...") + self._connect() + # Re-initialize after reconnection + self._ensure_extension() + self._ensure_base_table() + + def _ensure_extension(self): + """Ensure pgvector extension is installed.""" + try: + with self.conn.cursor() as cur: + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + self.conn.commit() + except Exception as e: + self.conn.rollback() + logger.error(f"Failed to create pgvector extension: {e}") + raise + + def _ensure_base_table(self): + """Ensure the collections tracking table exists.""" + try: + with self.conn.cursor() as cur: + # Create table to track collections + # IMPROVED: Better naming, timezone-aware timestamps + cur.execute(""" + CREATE TABLE IF NOT EXISTS hash_collections ( + id SERIAL PRIMARY KEY, + collection_name VARCHAR(255) NOT NULL, + hash_algorithm VARCHAR(50) NOT NULL, + vector_dimension INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + total_hashes INTEGER DEFAULT 0, + UNIQUE(collection_name, hash_algorithm) + ) + """) + + # Create index on created_at for time-based queries + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_hash_collections_created_at + ON hash_collections (created_at DESC) + """) + + self.conn.commit() + logger.info("Ensured hash_collections table exists") + except Exception as e: + self.conn.rollback() + logger.error(f"Failed to create collections table: {e}") + raise + + def create_collection_name(self, base_name: str, hash_algorithm: str) -> str: + """ + Create a collection table name. + + IMPROVED: Simpler naming without redundant prefixes + Format: {collection}_{algorithm} (e.g., "photos_pdq", "videos_phash") + + Args: + base_name: Base collection name + hash_algorithm: Hash algorithm (e.g., 'pdq', 'phash', 'dhash') + + Returns: + Table name (sanitized for SQL) + + Raises: + ValueError: If names contain invalid characters + """ + import re + + # Sanitize: only allow alphanumeric and underscore + safe_base = re.sub(r'[^a-zA-Z0-9_]', '_', base_name).lower() + safe_algo = re.sub(r'[^a-zA-Z0-9_]', '_', hash_algorithm).lower() + + # Simple format: collection_algorithm + table_name = f"{safe_base}_{safe_algo}" + + # Ensure starts with letter + if table_name[0].isdigit(): + table_name = f"h_{table_name}" + + # Truncate to PostgreSQL identifier limit (63 chars) + if len(table_name) > 63: + table_name = table_name[:63] + logger.warning(f"Table name truncated to 63 characters: {table_name}") + + return table_name + + def get_available_collections(self) -> List[str]: + """ + Get list of available collection base names. 
+ + Returns: + List of unique collection base names (without hash algorithm suffix) + """ + self._ensure_connected() + + try: + with self.conn.cursor() as cur: + cur.execute(""" + SELECT DISTINCT collection_name + FROM hash_collections + ORDER BY collection_name + """) + results = cur.fetchall() + return [row[0] for row in results] + except Exception as e: + logger.error(f"Failed to get available collections: {e}") + return [] + + def get_or_create_collection( + self, + collection_name: str, + hash_algorithm: str, + vector_dimension: Optional[int] = None + ) -> str: + """ + Get or create a collection table for a specific hash algorithm. + + IMPROVED: Better schema with file_path index and timezone timestamps + + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm name + vector_dimension: Dimension of hash vectors (required for new tables) + + Returns: + Table name + + Raises: + ValueError: If vector_dimension is None for new collection + """ + self._ensure_connected() + + table_name = self.create_collection_name(collection_name, hash_algorithm) + + try: + with self.conn.cursor() as cur: + # Check if table exists + cur.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = %s AND table_schema = 'public' + ) + """, (table_name,)) + table_exists = cur.fetchone()[0] + + if not table_exists: + if vector_dimension is None: + raise ValueError( + f"vector_dimension is required when creating new collection {table_name}" + ) + + # Register collection in tracking table + cur.execute(""" + INSERT INTO hash_collections (collection_name, hash_algorithm, vector_dimension, total_hashes) + VALUES (%s, %s, %s, 0) + ON CONFLICT (collection_name, hash_algorithm) + DO UPDATE SET + vector_dimension = EXCLUDED.vector_dimension, + updated_at = CURRENT_TIMESTAMP + """, (collection_name, hash_algorithm, vector_dimension)) + + # Create the hash table with IMPROVED schema + cur.execute(f""" + CREATE TABLE {table_name} ( + id SERIAL PRIMARY KEY, + file_path TEXT NOT NULL, + hash_string TEXT NOT NULL, + hash_vector vector({vector_dimension}) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + UNIQUE(file_path) + ) + """) + + # Add index on file_path for faster lookups + # (UNIQUE constraint already creates index, but explicit is clearer) + + # Add index on created_at for time-based queries + cur.execute(f""" + CREATE INDEX IF NOT EXISTS idx_{table_name}_created_at + ON {table_name} (created_at DESC) + """) + + # Add table comment for documentation + cur.execute(f""" + COMMENT ON TABLE {table_name} IS + 'Perceptual hashes for collection {collection_name} using {hash_algorithm} algorithm ({vector_dimension}-dimensional vectors)' + """) + + logger.info(f"Created table {table_name} with vector dimension {vector_dimension}") + + self.conn.commit() + + return table_name + except Exception as e: + logger.error(f"Failed to create collection {table_name}: {e}") + self.conn.rollback() + raise + + def _validate_hashes(self, hashes: List[Dict[str, Any]]) -> int: + """ + Validate hash data and return vector dimension. 
+ + Args: + hashes: List of hash dictionaries + + Returns: + Vector dimension + + Raises: + ValueError: If hashes are invalid + """ + if not hashes: + raise ValueError("hashes list cannot be empty") + + # Validate first hash structure + first_hash = hashes[0] + required_keys = ["hash_vector", "file_path", "hash_string"] + for key in required_keys: + if key not in first_hash: + raise ValueError(f"Hash must contain '{key}' key") + + vector_dimension = len(first_hash["hash_vector"]) + + # Validate all hashes + for i, h in enumerate(hashes): + # Check required keys + for key in required_keys: + if key not in h: + raise ValueError(f"Hash {i} missing required key: {key}") + + # Check dimension consistency + if len(h["hash_vector"]) != vector_dimension: + raise ValueError( + f"Hash {i} has dimension {len(h['hash_vector'])}, " + f"expected {vector_dimension}" + ) + + # Check for invalid values + vec = np.array(h["hash_vector"]) + if np.any(np.isnan(vec)): + raise ValueError(f"Hash {i} contains NaN values") + if np.any(np.isinf(vec)): + raise ValueError(f"Hash {i} contains infinity values") + + return vector_dimension + + def add_hashes( + self, + collection_name: str, + hash_algorithm: str, + hashes: List[Dict[str, Any]] + ): + """ + Add perceptual hashes to the database. + + IMPROVED: Creates index with correct L1 distance operator and optimal parameters + + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm used + hashes: List of hash dictionaries with 'hash_vector', 'file_path', and 'hash_string' keys + + Raises: + ValueError: If hashes are invalid + """ + if not hashes: + logger.info("No hashes to add") + return + + self._ensure_connected() + + # Validate hashes + vector_dimension = self._validate_hashes(hashes) + + table_name = self.get_or_create_collection(collection_name, hash_algorithm, vector_dimension) + + try: + with self.conn.cursor() as cur: + # Prepare data for insertion + data = [ + ( + h["file_path"], + h["hash_string"], + np.array(h["hash_vector"]) + ) + for h in hashes + ] + + # Use ON CONFLICT to handle duplicates + execute_values( + cur, + f""" + INSERT INTO {table_name} (file_path, hash_string, hash_vector) + VALUES %s + ON CONFLICT (file_path) DO UPDATE + SET hash_string = EXCLUDED.hash_string, + hash_vector = EXCLUDED.hash_vector, + created_at = CURRENT_TIMESTAMP + """, + data + ) + + # Update total_hashes count + cur.execute(f"SELECT COUNT(*) FROM {table_name}") + total_count = cur.fetchone()[0] + + cur.execute(""" + UPDATE hash_collections + SET total_hashes = %s, updated_at = CURRENT_TIMESTAMP + WHERE collection_name = %s AND hash_algorithm = %s + """, (total_count, collection_name, hash_algorithm)) + + # Verify the UPDATE actually affected a row + if cur.rowcount == 0: + logger.warning( + f"UPDATE hash_collections affected 0 rows for " + f"collection_name={collection_name}, hash_algorithm={hash_algorithm}. " + f"This might indicate the tracking table row doesn't exist." + ) + + # Create index if we have enough data + # Note: IVFFlat does not support L1 distance (Hamming) operator class + # For Hamming distance on binary vectors, queries will use <+> operator without index acceleration + # We skip index creation to keep queries using true Hamming distance + if total_count >= 100: + logger.info( + f"Collection {table_name} has {total_count} hashes. " + f"Note: IVFFlat indexes do not support Hamming distance (L1). " + f"Queries will use exact Hamming distance calculation without index acceleration." 
+ ) + # Index creation skipped intentionally to preserve Hamming distance semantics + + self.conn.commit() + logger.info(f"Added {len(hashes)} hashes to {table_name} (total: {total_count})") + + except Exception as e: + logger.error(f"Failed to add hashes to {table_name}: {e}") + self.conn.rollback() + raise + + def query_hashes( + self, + collection_name: str, + hash_algorithm: str, + query_hashes: List[Dict[str, Any]], + n_results: int = 10, + threshold: Optional[float] = None, + ) -> List[List[Dict[str, Any]]]: + """ + Query the database for similar hashes using Hamming distance (L1). + + IMPROVED: Uses L1 distance operator for binary perceptual hashes + + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm to use + query_hashes: List of query hash dictionaries with 'hash_vector' key + n_results: Number of results to return per query (max 10000) + threshold: Maximum Hamming distance threshold (None = no filtering) + + Returns: + List of results for each query hash. Each result contains: + - file_path: Path to the file + - hash_string: String representation of the hash + - hamming_distance: Integer Hamming distance + - distance: Float distance value + - similarity: Normalized similarity (0.0 to 1.0) + + Raises: + ValueError: If parameters are invalid + """ + # Validate parameters + if n_results < 1: + raise ValueError(f"n_results must be positive, got {n_results}") + if n_results > 10000: + logger.warning(f"Large n_results ({n_results}) capped at 10000") + n_results = 10000 + + if threshold is not None and threshold < 0: + raise ValueError(f"threshold must be non-negative, got {threshold}") + + if not query_hashes: + logger.warning("No query hashes provided") + return [] + + self._ensure_connected() + + table_name = self.create_collection_name(collection_name, hash_algorithm) + + # Check if table exists + try: + with self.conn.cursor() as cur: + cur.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = %s AND table_schema = 'public' + ) + """, (table_name,)) + if not cur.fetchone()[0]: + logger.warning(f"Collection {collection_name} with algorithm {hash_algorithm} does not exist") + return [[] for _ in query_hashes] + except Exception as e: + logger.error(f"Failed to check if table exists: {e}") + return [[] for _ in query_hashes] + + all_results: List[List[Dict[str, Any]]] = [] + + # Process each query individually to handle errors gracefully + for idx, query_hash in enumerate(query_hashes): + try: + with self.conn.cursor() as cur: + # Ensure vector is a numpy array + query_vector = np.array(query_hash["hash_vector"], dtype=float) + vec_len = query_vector.size + + # Build query with L1 distance (Hamming for binary hashes) + if threshold is not None: + sql = f""" + SELECT file_path, hash_string, + hash_vector <+> %s::vector AS hamming_distance + FROM {table_name} + WHERE hash_vector <+> %s::vector <= %s + ORDER BY hamming_distance + LIMIT %s + """ + params = (query_vector, query_vector, threshold, n_results) + else: + sql = f""" + SELECT file_path, hash_string, + hash_vector <+> %s::vector AS hamming_distance + FROM {table_name} + ORDER BY hamming_distance + LIMIT %s + """ + params = (query_vector, n_results) + + cur.execute(sql, params) + results = cur.fetchall() + + query_results: List[Dict[str, Any]] = [] + for file_path, hash_string, distance in results: + # For binary hashes, distance is Hamming distance + hamming = int(distance) + # Normalized similarity: 0 bits different = 1.0, all bits different = 0.0 + similarity = 
1.0 - (hamming / float(vec_len)) if vec_len > 0 else 0.0 + + query_results.append({ + "file_path": file_path, + "hash_string": hash_string, + "hamming_distance": hamming, + "distance": float(distance), + "similarity": similarity, + }) + + all_results.append(query_results) + + except KeyError as e: + logger.error(f"Query hash {idx} missing required key: {e}") + all_results.append([]) + except Exception as e: + logger.error(f"Failed to query hash {idx} from {table_name}: {e}") + all_results.append([]) + + return all_results + + def export_collection( + self, + collection_name: str, + hash_algorithm: str, + output_path: str + ): + """ + Export a collection to a JSON file. + + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm + output_path: Path to output JSON file + + Raises: + ValueError: If collection doesn't exist + """ + self._ensure_connected() + + table_name = self.create_collection_name(collection_name, hash_algorithm) + + # Check if table exists + try: + with self.conn.cursor() as cur: + cur.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = %s AND table_schema = 'public' + ) + """, (table_name,)) + + if not cur.fetchone()[0]: + available = self.get_available_collections() + raise ValueError( + f"Collection '{collection_name}' with algorithm '{hash_algorithm}' does not exist. " + f"Available collections: {available}" + ) + except psycopg2.Error as e: + logger.error(f"Failed to check collection existence: {e}") + raise + + try: + with self.conn.cursor() as cur: + cur.execute(f""" + SELECT file_path, hash_string, hash_vector + FROM {table_name} + ORDER BY created_at + """) + + results = cur.fetchall() + + export_data = { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "hashes": [] + } + + for file_path, hash_string, hash_vector in results: + # Convert pgvector to list of Python floats (for JSON serialization) + if hash_vector is not None: + vector_list = [float(x) for x in hash_vector] + else: + vector_list = [] + export_data["hashes"].append({ + "hash_vector": vector_list, + "file_path": file_path, + "hash_string": hash_string, + }) + + # Write to file + with open(output_path, 'w') as f: + json.dump(export_data, f, indent=2) + + logger.info(f"Exported {len(results)} hashes from {table_name} to {output_path}") + + except Exception as e: + logger.error(f"Failed to export collection {table_name}: {e}") + raise + + def import_collection(self, input_path: str, new_collection_name: Optional[str] = None): + """ + Import a collection from a JSON file. + + Args: + input_path: Path to input JSON file + new_collection_name: Optional new name for the collection (uses original if None) + """ + self._ensure_connected() + + with open(input_path, 'r') as f: + import_data = json.load(f) + + collection_name = new_collection_name or import_data["collection_name"] + hash_algorithm = import_data["hash_algorithm"] + + # Add hashes to collection + self.add_hashes(collection_name, hash_algorithm, import_data["hashes"]) + logger.info(f"Imported {len(import_data['hashes'])} hashes to {collection_name}") + + def delete_collection(self, collection_name: str, hash_algorithm: str): + """ + Delete a collection. 
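+
+        This drops the collection's hash table (DROP TABLE ... CASCADE) and
+        removes its row from the hash_collections tracking table; the data
+        cannot be recovered afterwards, so export the collection first if it
+        may be needed again.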
+ + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm + """ + self._ensure_connected() + + table_name = self.create_collection_name(collection_name, hash_algorithm) + + try: + with self.conn.cursor() as cur: + # Remove from tracking table + cur.execute(""" + DELETE FROM hash_collections + WHERE collection_name = %s AND hash_algorithm = %s + """, (collection_name, hash_algorithm)) + + # Drop the table + cur.execute(f"DROP TABLE IF EXISTS {table_name} CASCADE") + + self.conn.commit() + logger.info(f"Deleted collection {table_name}") + + except Exception as e: + logger.error(f"Failed to delete collection {table_name}: {e}") + self.conn.rollback() + raise + + def get_collection_stats(self, collection_name: str, hash_algorithm: str) -> Dict[str, Any]: + """ + Get statistics about a collection. + + Args: + collection_name: Base collection name + hash_algorithm: Hash algorithm + + Returns: + Dictionary with collection statistics + """ + self._ensure_connected() + + table_name = self.create_collection_name(collection_name, hash_algorithm) + + try: + with self.conn.cursor() as cur: + # Check if table exists first + cur.execute(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = %s AND table_schema = 'public' + ) + """, (table_name,)) + + if not cur.fetchone()[0]: + logger.info(f"Collection {table_name} does not exist") + return { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "total_hashes": 0, + "vector_dimension": 0, + } + + # Get stats from tracking table + cur.execute(""" + SELECT total_hashes, vector_dimension, created_at, updated_at + FROM hash_collections + WHERE collection_name = %s AND hash_algorithm = %s + """, (collection_name, hash_algorithm)) + + result = cur.fetchone() + if result: + total_hashes, vector_dim, created_at, updated_at = result + return { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "total_hashes": total_hashes, + "vector_dimension": vector_dim, + "created_at": created_at.isoformat() if created_at else None, + "updated_at": updated_at.isoformat() if updated_at else None, + } + else: + # Fallback to counting + cur.execute(f"SELECT COUNT(*) FROM {table_name}") + count = cur.fetchone()[0] + return { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "total_hashes": count, + } + except Exception as e: + logger.error(f"Failed to get stats for {table_name}: {e}") + return { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "total_hashes": 0, + } + + def close(self): + """Close database connection. Safe to call multiple times.""" + if self.conn and not self.conn.closed: + try: + self.conn.close() + logger.info("Closed PostgreSQL connection") + except Exception as e: + logger.warning(f"Error closing connection: {e}") + self.conn = None + + def __del__(self): + """Cleanup on deletion. 
Note: __exit__ is preferred for guaranteed cleanup.""" + try: + self.close() + except Exception: + pass # Ignore errors in destructor \ No newline at end of file diff --git a/src/perceptual-hash/perceptual_hash/hasher.py b/src/perceptual-hash/perceptual_hash/hasher.py new file mode 100644 index 0000000..e23df67 --- /dev/null +++ b/src/perceptual-hash/perceptual_hash/hasher.py @@ -0,0 +1,310 @@ +"""Perceptual hashing module using the perception library.""" + +import os +from pathlib import Path +from typing import List, Dict, Any, Optional +import logging + +import numpy as np +from PIL import Image +from perception import hashers + +logger = logging.getLogger(__name__) + +# Supported image extensions +IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'} + +# Video extensions (for future support) +VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv'} + + +class PerceptualHasher: + """Perceptual hash generator using the Thorn perception library.""" + + def __init__(self, hash_algorithm: str = "phash"): + """ + Initialize the hasher with a specific algorithm. + + Args: + hash_algorithm: Hash algorithm to use ('phash', 'average', 'dhash', 'wavelet', 'pdq') + """ + self.hash_algorithm = hash_algorithm.lower() + + # Initialize the appropriate hasher based on algorithm + if self.hash_algorithm == "phash" or self.hash_algorithm == "perceptual": + self.hasher = hashers.PHash() + elif self.hash_algorithm == "average": + self.hasher = hashers.AverageHash() + elif self.hash_algorithm == "dhash" or self.hash_algorithm == "difference": + self.hasher = hashers.DHash() + elif self.hash_algorithm == "wavelet": + self.hasher = hashers.WaveletHash() + elif self.hash_algorithm == "pdq": + self.hasher = hashers.PDQ() + else: + raise ValueError( + f"Unsupported hash algorithm: {hash_algorithm}. " + f"Supported: phash, average, dhash, wavelet, pdq" + ) + + def compute_hash(self, file_path: str) -> Optional[Dict[str, Any]]: + """ + Compute perceptual hash for a single file. + + Args: + file_path: Path to the image/video file + + Returns: + Dictionary with 'hash_vector', 'hash_string', and 'file_path' or None if error + """ + try: + file_path = str(file_path) + ext = Path(file_path).suffix.lower() + + if ext in IMAGE_EXTENSIONS: + return self._compute_image_hash(file_path) + elif ext in VIDEO_EXTENSIONS: + # Video support can be added later + logger.warning(f"Video hashing not yet implemented for {file_path}") + return None + else: + logger.warning(f"Unsupported file type: {file_path}") + return None + + except Exception as e: + logger.error(f"Error computing hash for {file_path}: {e}") + return None + + def _compute_image_hash(self, file_path: str) -> Dict[str, Any]: + """ + Compute hash for an image file. + + Args: + file_path: Path to the image file + + Returns: + Dictionary with hash information + """ + # Compute the hash using the perception library + hash_value = self.hasher.compute(file_path) + + # Convert hash to string representation + hash_string = str(hash_value) + + # Convert to vector for ChromaDB + # The perception library returns different types for different hashers + # We need to convert to a consistent vector format + hash_vector = self._hash_to_vector(hash_value) + + return { + "hash_vector": hash_vector, + "hash_string": hash_string, + "file_path": file_path, + } + + def _hash_to_vector(self, hash_value) -> List[float]: + """ + Convert a hash value to a vector suitable for database storage. 
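+
+        For the 64-bit algorithms (phash, average, dhash, wavelet) the result
+        is a binary vector of 64 floats (0.0 or 1.0); PDQ produces a
+        256-element binary vector.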
+ + Args: + hash_value: Hash value from perception library + + Returns: + List of floats representing the hash as a vector + """ + import base64 + + # PDQ returns a numpy array directly + if isinstance(hash_value, np.ndarray): + # Flatten in case it's multidimensional + return hash_value.flatten().astype(float).tolist() + + # The perception library hash objects have a hash attribute that's an integer + # We need to convert this integer to a binary vector + if hasattr(hash_value, 'hash'): + hash_int = hash_value.hash + # Convert to 64-bit binary representation (standard for perceptual hashes) + binary_str = format(hash_int, '064b') + return [float(bit) for bit in binary_str] + + # Try to get the integer value directly if it's an integer-like object + try: + hash_int = int(hash_value) + # Convert to 64-bit binary representation + binary_str = format(hash_int, '064b') + return [float(bit) for bit in binary_str] + except (ValueError, TypeError): + pass + + # The perception library returns base64-encoded strings for hash values + # Try to decode as base64 first + hash_str = str(hash_value) + try: + # Decode base64 to get the raw hash bytes + hash_bytes = base64.b64decode(hash_str) + # Convert each byte to 8 bits (creates 64-bit vector for 8-byte hash) + binary_str = ''.join(format(byte, '08b') for byte in hash_bytes) + return [float(bit) for bit in binary_str] + except Exception: + # If base64 decoding fails, continue to other methods + pass + + # For string-based hashes (like average, difference, perceptual) + # Convert the binary string to a list of floats (0.0 or 1.0) + # Remove any non-binary characters + binary_str = ''.join(c for c in hash_str if c in '01') + + if binary_str: + return [float(bit) for bit in binary_str] + + # Fallback: if we can't parse it, try to convert to bytes and then to floats + try: + hash_bytes = hash_value if isinstance(hash_value, bytes) else str(hash_value).encode() + # Ensure we have at least 64 dimensions + byte_list = [float(b) / 255.0 for b in hash_bytes] + # Pad or truncate to 64 dimensions + if len(byte_list) < 64: + byte_list.extend([0.0] * (64 - len(byte_list))) + return byte_list[:64] + except: + # Last resort: create a 64-dimensional vector from the string representation + hash_str = str(hash_value) + vector = [float(ord(c)) / 255.0 for c in hash_str[:64]] + # Pad to 64 dimensions if needed + if len(vector) < 64: + vector.extend([0.0] * (64 - len(vector))) + return vector[:64] + + def compute_directory_hashes( + self, + directory_path: str, + recursive: bool = True, + progress_callback: Optional[callable] = None + ) -> List[Dict[str, Any]]: + """ + Compute hashes for all supported files in a directory. 
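+
+        Files are selected by extension (IMAGE_EXTENSIONS and VIDEO_EXTENSIONS);
+        files that cannot be hashed are logged and skipped rather than aborting
+        the run.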
+ + Args: + directory_path: Path to directory + recursive: Whether to search recursively + progress_callback: Optional callback function(current, total) for progress updates + + Returns: + List of hash dictionaries + """ + directory_path = Path(directory_path) + + # Find all supported files + supported_files = [] + if recursive: + for ext in IMAGE_EXTENSIONS | VIDEO_EXTENSIONS: + supported_files.extend(directory_path.rglob(f"*{ext}")) + else: + for ext in IMAGE_EXTENSIONS | VIDEO_EXTENSIONS: + supported_files.extend(directory_path.glob(f"*{ext}")) + + # Compute hashes + hashes = [] + total = len(supported_files) + + for idx, file_path in enumerate(supported_files): + hash_data = self.compute_hash(str(file_path)) + if hash_data: + hashes.append(hash_data) + + # Progress callback + if progress_callback: + progress_callback(idx + 1, total) + + logger.info( + f"Computed {len(hashes)} hashes from {total} files " + f"in {directory_path} using {self.hash_algorithm}" + ) + + return hashes + + def compute_batch_hashes( + self, + file_paths: List[str], + progress_callback: Optional[callable] = None + ) -> List[Dict[str, Any]]: + """ + Compute hashes for a batch of files. + + Args: + file_paths: List of file paths + progress_callback: Optional callback function(current, total) for progress updates + + Returns: + List of hash dictionaries + """ + hashes = [] + total = len(file_paths) + + for idx, file_path in enumerate(file_paths): + hash_data = self.compute_hash(file_path) + if hash_data: + hashes.append(hash_data) + + # Progress callback + if progress_callback: + progress_callback(idx + 1, total) + + logger.info(f"Computed {len(hashes)} hashes from {total} files using {self.hash_algorithm}") + + return hashes + + @staticmethod + def get_supported_algorithms() -> List[str]: + """Get list of supported hash algorithms.""" + return ["phash", "average", "dhash", "wavelet", "pdq"] + + @staticmethod + def get_algorithm_info(algorithm: str) -> Dict[str, str]: + """ + Get information about a hash algorithm. 
+ + Args: + algorithm: Hash algorithm name + + Returns: + Dictionary with algorithm information + """ + info = { + "phash": { + "name": "Perceptual Hash (pHash)", + "description": "DCT-based hash, robust to various transformations", + "best_for": "Finding perceptually similar images with modifications", + }, + "perceptual": { + "name": "Perceptual Hash (pHash)", + "description": "DCT-based hash, robust to various transformations", + "best_for": "Finding perceptually similar images with modifications", + }, + "average": { + "name": "Average Hash", + "description": "Simple and fast hash based on average pixel values", + "best_for": "Quick duplicate detection with minimal changes", + }, + "dhash": { + "name": "Difference Hash (dHash)", + "description": "Hash based on gradient/difference between adjacent pixels", + "best_for": "Detecting similar images with different brightness/contrast", + }, + "difference": { + "name": "Difference Hash (dHash)", + "description": "Hash based on gradient/difference between adjacent pixels", + "best_for": "Detecting similar images with different brightness/contrast", + }, + "wavelet": { + "name": "Wavelet Hash", + "description": "Haar wavelet-based hash for robust image comparison", + "best_for": "Finding similar images with complex transformations", + }, + "pdq": { + "name": "PDQ Hash", + "description": "Facebook's PDQ perceptual hash designed for large-scale image matching", + "best_for": "Production-grade duplicate and near-duplicate detection at scale", + }, + } + return info.get(algorithm.lower(), {"name": algorithm, "description": "Unknown algorithm"}) diff --git a/src/perceptual-hash/perceptual_hash/main.py b/src/perceptual-hash/perceptual_hash/main.py new file mode 100644 index 0000000..44f95d9 --- /dev/null +++ b/src/perceptual-hash/perceptual_hash/main.py @@ -0,0 +1,983 @@ +"""Main plugin file for Perceptual Hash plugin.""" + +import json +import logging +import os +from pathlib import Path +from typing import List, TypedDict, Optional +import typer + +from rb.lib.ml_service import MLService +from rb.api.models import ( + BatchFileInput, + BatchFileResponse, + DirectoryInput, + EnumParameterDescriptor, + EnumVal, + FileInput, + InputSchema, + InputType, + ParameterSchema, + ResponseBody, + TaskSchema, + TextParameterDescriptor, + TextResponse, + TextInput, + RangedFloatParameterDescriptor, + FloatRangeDescriptor, +) + +from perceptual_hash.database import HashDatabase +from perceptual_hash.hasher import PerceptualHasher + +logger = logging.getLogger(__name__) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) + +APP_NAME = "perceptual-hash" +ml_service = MLService(APP_NAME) + +# Load app info +script_dir = os.path.dirname(os.path.abspath(__file__)) +info_file_path = os.path.join(script_dir, "app-info.md") + +try: + with open(info_file_path, "r") as f: + info = f.read() +except FileNotFoundError: + info = "Perceptual hashing plugin for detecting duplicate and near-duplicate media." 
+ +ml_service.add_app_metadata( + plugin_name=APP_NAME, + name="Perceptual Hash", + author="RescueBox Team", + version="0.1.0", + info=info, +) + +# Supported algorithms +supported_algorithms = PerceptualHasher.get_supported_algorithms() + + +def get_available_collections(): + """Get available collections using context manager.""" + try: + with HashDatabase() as db: + return db.get_available_collections() + except Exception as e: + logger.error(f"Error getting available collections: {e}") + return [] + + +""" +****************************************************************************************************** +Endpoint 1: Create Database (Hash Directory) +****************************************************************************************************** +""" + + +class CreateDatabaseInputs(TypedDict): + """Inputs for creating a hash database from a directory.""" + media_directory: DirectoryInput + + +class CreateDatabaseParameters(TypedDict): + """Parameters for creating a hash database.""" + collection_name: str + hash_algorithm: str + recursive: str + + +def create_database_task_schema() -> TaskSchema: + """Task schema for creating a hash database.""" + return TaskSchema( + inputs=[ + InputSchema( + key="media_directory", + label="Media Directory", + input_type=InputType.DIRECTORY, + ) + ], + parameters=[ + ParameterSchema( + key="collection_name", + label="Collection Name", + value=TextParameterDescriptor( + default="my_collection", + placeholder="Enter collection name", + ), + ), + ParameterSchema( + key="hash_algorithm", + label="Hash Algorithm", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=algo, label=PerceptualHasher.get_algorithm_info(algo)["name"]) + for algo in supported_algorithms + ], + default="phash", + ), + ), + ParameterSchema( + key="recursive", + label="Recursive", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key="true", label="Yes"), + EnumVal(key="false", label="No"), + ], + default="true", + ), + ), + ], + ) + + +def create_database( + inputs: CreateDatabaseInputs, + parameters: CreateDatabaseParameters +) -> ResponseBody: + """ + Create a database of perceptual hashes from a directory of media files. + + Args: + inputs: Directory containing media files + parameters: Collection name, hash algorithm, and recursive flag + + Returns: + Response with success message and statistics + """ + try: + media_directory = inputs["media_directory"].path + collection_name = parameters["collection_name"] + hash_algorithm = parameters["hash_algorithm"] + recursive = parameters["recursive"] == "true" + + logger.info( + f"Creating hash database '{collection_name}' using {hash_algorithm} " + f"from directory: {media_directory}" + ) + + # Initialize hasher + hasher = PerceptualHasher(hash_algorithm) + + # Compute hashes for all files in directory + hashes = hasher.compute_directory_hashes( + str(media_directory), + recursive=recursive + ) + + if not hashes: + return ResponseBody( + root=TextResponse( + value=f"No supported media files found in {media_directory}" + ) + ) + + # Use context manager for database operations + with HashDatabase() as db: + # Add hashes to database + db.add_hashes(collection_name, hash_algorithm, hashes) + + # Get statistics + stats = db.get_collection_stats(collection_name, hash_algorithm) + + result = TextResponse( + value=f"Successfully created collection '{collection_name}' with " + f"{len(hashes)} hashes using {hash_algorithm} algorithm. 
" + f"Total hashes in collection: {stats['total_hashes']}" + ) + return ResponseBody(root=result) + + except Exception as e: + logger.error(f"Error creating database: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def create_database_cli_parser(value: str): + """Parse CLI input for create_database endpoint.""" + try: + parts = value.split(",") + media_directory = parts[0].strip() + return CreateDatabaseInputs( + media_directory=DirectoryInput(path=media_directory) + ) + except Exception as e: + logger.error(f"Error parsing CLI input: {e}") + raise typer.Abort() + + +def create_database_param_parser(value: str): + """Parse CLI parameters for create_database endpoint.""" + try: + parts = value.split(",") + collection_name = parts[0].strip() if len(parts) > 0 else "my_collection" + hash_algorithm = parts[1].strip() if len(parts) > 1 else "pdq" + recursive = parts[2].strip() if len(parts) > 2 else "true" + return CreateDatabaseParameters( + collection_name=collection_name, + hash_algorithm=hash_algorithm, + recursive=recursive, + ) + except Exception as e: + logger.error(f"Error parsing CLI parameters: {e}") + raise typer.Abort() + + +ml_service.add_ml_service( + rule="/create_database", + ml_function=create_database, + inputs_cli_parser=typer.Argument( + parser=create_database_cli_parser, + help="Media directory path", + ), + parameters_cli_parser=typer.Option( + None, + "--params", + parser=create_database_param_parser, + help="collection_name,hash_algorithm,recursive", + ), + task_schema_func=create_database_task_schema, + short_title="Create Hash Database", + order=0, +) + + +""" +****************************************************************************************************** +Endpoint 2: Find Matches (Query Database) +****************************************************************************************************** +""" + + +class FindMatchesInputs(TypedDict): + """Inputs for finding matches.""" + query_directory: DirectoryInput + output_file: FileInput + + +class FindMatchesParameters(TypedDict): + """Parameters for finding matches.""" + collection_name: str + hash_algorithm: str + max_distance: float + max_results: float + + +def find_matches_task_schema() -> TaskSchema: + """Task schema for finding matches.""" + # Get available collections + available_collections = ["Select a collection"] + get_available_collections() + + return TaskSchema( + inputs=[ + InputSchema( + key="query_directory", + label="Query Directory", + input_type=InputType.DIRECTORY, + ), + InputSchema( + key="output_file", + label="Output JSON File", + input_type=InputType.FILE, + ) + ], + parameters=[ + ParameterSchema( + key="collection_name", + label="Collection Name", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=name, label=name) + for name in available_collections[1:] + ], + message_when_empty="No collections found. 
Create one first.", + default=available_collections[0], + ), + ), + ParameterSchema( + key="hash_algorithm", + label="Hash Algorithm", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=algo, label=PerceptualHasher.get_algorithm_info(algo)["name"]) + for algo in supported_algorithms + ], + default="phash", + ), + ), + ParameterSchema( + key="max_distance", + label="Maximum Distance Threshold (Hamming)", + value=RangedFloatParameterDescriptor( + range=FloatRangeDescriptor(min=0.0, max=100.0), + default=10.0, + ), + ), + ParameterSchema( + key="max_results", + label="Maximum Results per Query", + value=RangedFloatParameterDescriptor( + range=FloatRangeDescriptor(min=1.0, max=100.0), + default=10.0, + ), + ), + ], + ) + + +def find_matches( + inputs: FindMatchesInputs, + parameters: FindMatchesParameters +) -> ResponseBody: + """ + Find matching media files in the database. + + Args: + inputs: Query directory containing files to search for + parameters: Collection name, hash algorithm, and search parameters + + Returns: + Response with success message and output file path + """ + try: + query_directory = inputs["query_directory"].path + output_file = inputs["output_file"].path + collection_name = parameters["collection_name"] + hash_algorithm = parameters["hash_algorithm"] + max_distance = float(parameters["max_distance"]) + max_results = int(float(parameters["max_results"])) + + logger.info( + f"Finding matches for files in {query_directory} in collection " + f"'{collection_name}' using {hash_algorithm}" + ) + + # Initialize hasher + hasher = PerceptualHasher(hash_algorithm) + + # Compute hashes for all files in query directory + query_hashes = hasher.compute_directory_hashes( + str(query_directory), + recursive=False # Only search immediate directory + ) + + if not query_hashes: + return ResponseBody( + root=TextResponse(value=f"No supported media files found in {query_directory}") + ) + + # Query database using context manager + with HashDatabase() as db: + results = db.query_hashes( + collection_name, + hash_algorithm, + query_hashes, + n_results=max_results, + threshold=max_distance, + ) + + # Build JSON output structure + json_output = { + "metadata": { + "collection_name": collection_name, + "hash_algorithm": hash_algorithm, + "max_distance": max_distance, + "max_results_per_query": max_results, + "total_query_files": len(query_hashes), + "total_matches_found": 0 + }, + "results": {} + } + + total_matches = 0 + for query_hash, matches in zip(query_hashes, results): + query_path = query_hash["file_path"] + + # Build matches list for this query + match_list = [] + for match in matches: + hamming_distance = match.get("hamming_distance", match["distance"]) + similarity = match["similarity"] + + # Determine match quality based on Hamming distance + if hamming_distance < 5: + quality = "exact" + elif hamming_distance < 10: + quality = "very_similar" + elif hamming_distance < 20: + quality = "similar" + else: + quality = "somewhat_similar" + + match_list.append({ + "file_path": match["file_path"], + "hamming_distance": int(hamming_distance), + "similarity": similarity, + "quality": quality + }) + total_matches += 1 + + json_output["results"][query_path] = { + "matches": match_list, + "total_matches": len(match_list) + } + + json_output["metadata"]["total_matches_found"] = total_matches + + # Write output to file + with open(output_file, 'w') as f: + json.dump(json_output, f, indent=2) + + result = TextResponse( + value=f"Successfully found {total_matches} matches for 
{len(query_hashes)} query files. " + f"Results saved to {output_file}" + ) + return ResponseBody(root=result) + + except Exception as e: + logger.error(f"Error finding matches: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def find_matches_cli_parser(value: str): + """Parse CLI input for find_matches endpoint.""" + try: + parts = value.split(",") + query_directory = parts[0].strip() + output_file = parts[1].strip() if len(parts) > 1 else "matches.json" + return FindMatchesInputs( + query_directory=DirectoryInput(path=query_directory), + output_file=FileInput(path=output_file) + ) + except Exception as e: + logger.error(f"Error parsing CLI input: {e}") + raise typer.Abort() + + +def find_matches_param_parser(value: str): + """Parse CLI parameters for find_matches endpoint.""" + try: + parts = value.split(",") + collection_name = parts[0].strip() if len(parts) > 0 else "my_collection" + hash_algorithm = parts[1].strip() if len(parts) > 1 else "phash" + max_distance = float(parts[2].strip()) if len(parts) > 2 else 10.0 + max_results = float(parts[3].strip()) if len(parts) > 3 else 10.0 + return FindMatchesParameters( + collection_name=collection_name, + hash_algorithm=hash_algorithm, + max_distance=max_distance, + max_results=max_results, + ) + except Exception as e: + logger.error(f"Error parsing CLI parameters: {e}") + raise typer.Abort() + + +ml_service.add_ml_service( + rule="/find_matches", + ml_function=find_matches, + inputs_cli_parser=typer.Argument( + parser=find_matches_cli_parser, + help="query_directory,output_file", + ), + parameters_cli_parser=typer.Option( + None, + "--params", + parser=find_matches_param_parser, + help="collection_name,hash_algorithm,max_distance,max_results", + ), + task_schema_func=find_matches_task_schema, + short_title="Find Matches", + order=1, +) + + +""" +****************************************************************************************************** +Endpoint 3: Export Database +****************************************************************************************************** +""" + + +class ExportDatabaseInputs(TypedDict): + """Inputs for exporting a database.""" + output_file: FileInput + + +class ExportDatabaseParameters(TypedDict): + """Parameters for exporting a database.""" + collection_name: str + hash_algorithm: str + + +def export_database_task_schema() -> TaskSchema: + """Task schema for exporting a database.""" + available_collections = ["Select a collection"] + get_available_collections() + + return TaskSchema( + inputs=[ + InputSchema( + key="output_file", + label="Output JSON File", + input_type=InputType.FILE, + ) + ], + parameters=[ + ParameterSchema( + key="collection_name", + label="Collection Name", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=name, label=name) + for name in available_collections[1:] + ], + message_when_empty="No collections found", + default=available_collections[0], + ), + ), + ParameterSchema( + key="hash_algorithm", + label="Hash Algorithm", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=algo, label=PerceptualHasher.get_algorithm_info(algo)["name"]) + for algo in supported_algorithms + ], + default="phash", + ), + ), + ], + ) + + +def export_database( + inputs: ExportDatabaseInputs, + parameters: ExportDatabaseParameters +) -> ResponseBody: + """ + Export a hash collection to a JSON file. 
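+
+    The exported JSON can later be re-imported through the Import Database
+    endpoint; the response reports the number of exported hashes.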
+ + Args: + inputs: Output file path + parameters: Collection name and hash algorithm + + Returns: + Response with success message + """ + try: + output_file = inputs["output_file"].path + collection_name = parameters["collection_name"] + hash_algorithm = parameters["hash_algorithm"] + + logger.info(f"Exporting collection '{collection_name}' ({hash_algorithm}) to {output_file}") + + # Export using context manager + with HashDatabase() as db: + db.export_collection(collection_name, hash_algorithm, str(output_file)) + stats = db.get_collection_stats(collection_name, hash_algorithm) + + result = TextResponse( + value=f"Successfully exported collection '{collection_name}' ({hash_algorithm}) " + f"with {stats['total_hashes']} hashes to {output_file}" + ) + return ResponseBody(root=result) + + except Exception as e: + logger.error(f"Error exporting database: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def export_database_cli_parser(value: str): + """Parse CLI input for export_database endpoint.""" + try: + return ExportDatabaseInputs( + output_file=FileInput(path=value.strip()) + ) + except Exception as e: + logger.error(f"Error parsing CLI input: {e}") + raise typer.Abort() + + +def export_database_param_parser(value: str): + """Parse CLI parameters for export_database endpoint.""" + try: + parts = value.split(",") + collection_name = parts[0].strip() if len(parts) > 0 else "my_collection" + hash_algorithm = parts[1].strip() if len(parts) > 1 else "phash" + return ExportDatabaseParameters( + collection_name=collection_name, + hash_algorithm=hash_algorithm, + ) + except Exception as e: + logger.error(f"Error parsing CLI parameters: {e}") + raise typer.Abort() + + +ml_service.add_ml_service( + rule="/export_database", + ml_function=export_database, + inputs_cli_parser=typer.Argument( + parser=export_database_cli_parser, + help="Output JSON file path", + ), + parameters_cli_parser=typer.Option( + None, + "--params", + parser=export_database_param_parser, + help="collection_name,hash_algorithm", + ), + task_schema_func=export_database_task_schema, + short_title="Export Hash Database", + order=2, +) + + +""" +****************************************************************************************************** +Endpoint 4: Import Database +****************************************************************************************************** +""" + + +class ImportDatabaseInputs(TypedDict): + """Inputs for importing a database.""" + input_file: FileInput + + +class ImportDatabaseParameters(TypedDict): + """Parameters for importing a database.""" + new_collection_name: str + + +def import_database_task_schema() -> TaskSchema: + """Task schema for importing a database.""" + return TaskSchema( + inputs=[ + InputSchema( + key="input_file", + label="Input JSON File", + input_type=InputType.FILE, + ) + ], + parameters=[ + ParameterSchema( + key="new_collection_name", + label="New Collection Name (optional)", + value=TextParameterDescriptor( + default="", + placeholder="Leave empty to use original name", + ), + ), + ], + ) + + +def import_database( + inputs: ImportDatabaseInputs, + parameters: ImportDatabaseParameters +) -> ResponseBody: + """ + Import a hash collection from a JSON file. 
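+
+    If a new collection name is supplied it is used in place of the name stored
+    in the export file; otherwise the original name is kept.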
+ + Args: + inputs: Input file path + parameters: Optional new collection name + + Returns: + Response with success message + """ + try: + input_file = inputs["input_file"].path + new_collection_name = parameters["new_collection_name"] if parameters["new_collection_name"] else None + + logger.info(f"Importing collection from {input_file}") + + # Import using context manager + with HashDatabase() as db: + db.import_collection(str(input_file), new_collection_name) + + collection_name = new_collection_name if new_collection_name else "original" + result = TextResponse( + value=f"Successfully imported collection '{collection_name}' from {input_file}" + ) + return ResponseBody(root=result) + + except Exception as e: + logger.error(f"Error importing database: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def import_database_cli_parser(value: str): + """Parse CLI input for import_database endpoint.""" + try: + return ImportDatabaseInputs( + input_file=FileInput(path=value.strip()) + ) + except Exception as e: + logger.error(f"Error parsing CLI input: {e}") + raise typer.Abort() + + +def import_database_param_parser(value: str): + """Parse CLI parameters for import_database endpoint.""" + try: + new_collection_name = value.strip() if value else "" + return ImportDatabaseParameters( + new_collection_name=new_collection_name, + ) + except Exception as e: + logger.error(f"Error parsing CLI parameters: {e}") + raise typer.Abort() + + +ml_service.add_ml_service( + rule="/import_database", + ml_function=import_database, + inputs_cli_parser=typer.Argument( + parser=import_database_cli_parser, + help="Input JSON file path", + ), + parameters_cli_parser=typer.Option( + None, + "--params", + parser=import_database_param_parser, + help="new_collection_name (optional)", + ), + task_schema_func=import_database_task_schema, + short_title="Import Hash Database", + order=3, +) + + +""" +****************************************************************************************************** +Endpoint 5: List Collections +****************************************************************************************************** +""" + + +class ListCollectionsInputs(TypedDict): + """Inputs for listing collections.""" + dummy: TextInput # Workaround for AutoUI - required to send request body + + +def list_collections_task_schema() -> TaskSchema: + """Task schema for listing collections.""" + return TaskSchema( + inputs=[ + InputSchema( + key="dummy", + label="List Collections", + input_type=InputType.TEXT, + ) + ], + parameters=[], + ) + + +def list_collections( + inputs: ListCollectionsInputs +) -> ResponseBody: + """ + List all available hash collections. 
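+
+    For every collection, the response lists the number of stored hashes (and the
+    vector dimension) for each supported algorithm that has data.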
+ + Args: + inputs: No inputs required + parameters: No parameters required + + Returns: + Response with list of collections and their statistics + """ + try: + logger.info("Listing all collections") + + with HashDatabase() as db: + available_collections = db.get_available_collections() + + if not available_collections: + return ResponseBody(root=TextResponse(value="No collections found.")) + + output_lines = ["Available Collections:"] + + for collection_name in available_collections: + output_lines.append(f"\n{collection_name}:") + for algo in supported_algorithms: + try: + stats = db.get_collection_stats(collection_name, algo) + if stats["total_hashes"] > 0: + vector_dim = stats.get("vector_dimension", "unknown") + output_lines.append( + f" - {algo}: {stats['total_hashes']} hashes (dimension: {vector_dim})" + ) + except Exception: + # Collection doesn't exist for this algorithm + pass + + result_text = "\n".join(output_lines) + return ResponseBody(root=TextResponse(value=result_text)) + + except Exception as e: + logger.error(f"Error listing collections: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def list_collections_cli_parser(value: str): + """Parse CLI input for list_collections endpoint.""" + return ListCollectionsInputs(dummy=TextInput(text="")) + + +ml_service.add_ml_service( + rule="/list_collections", + ml_function=list_collections, + inputs_cli_parser=typer.Argument( + default="", + parser=list_collections_cli_parser, + help="(no inputs required)", + ), + task_schema_func=list_collections_task_schema, + short_title="List Collections", + order=4, +) + + +""" +****************************************************************************************************** +Endpoint 6: Delete Collection +****************************************************************************************************** +""" + + +class DeleteCollectionInputs(TypedDict): + """Inputs for deleting a collection.""" + dummy: TextInput # Workaround for AutoUI - required to send request body + + +class DeleteCollectionParameters(TypedDict): + """Parameters for deleting a collection.""" + collection_name: str + hash_algorithm: str + + +def delete_collection_task_schema() -> TaskSchema: + """Task schema for deleting a collection.""" + available_collections = ["Select a collection"] + get_available_collections() + + return TaskSchema( + inputs=[ + InputSchema( + key="dummy", + label="Delete Collection", + input_type=InputType.TEXT, + ) + ], + parameters=[ + ParameterSchema( + key="collection_name", + label="Collection Name", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=name, label=name) + for name in available_collections[1:] + ], + message_when_empty="No collections found", + default=available_collections[0], + ), + ), + ParameterSchema( + key="hash_algorithm", + label="Hash Algorithm", + value=EnumParameterDescriptor( + enum_vals=[ + EnumVal(key=algo, label=PerceptualHasher.get_algorithm_info(algo)["name"]) + for algo in supported_algorithms + ], + default="phash", + ), + ), + ], + ) + + +def delete_collection( + inputs: DeleteCollectionInputs, + parameters: DeleteCollectionParameters +) -> ResponseBody: + """ + Delete a hash collection. 
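+
+    Removes all stored hashes for the selected collection/algorithm pair via
+    HashDatabase.delete_collection.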
+ + Args: + inputs: No inputs required + parameters: Collection name and hash algorithm + + Returns: + Response with success message + """ + try: + collection_name = parameters["collection_name"] + hash_algorithm = parameters["hash_algorithm"] + + logger.info(f"Deleting collection '{collection_name}' ({hash_algorithm})") + + # Delete using context manager + with HashDatabase() as db: + db.delete_collection(collection_name, hash_algorithm) + + result = TextResponse( + value=f"Successfully deleted collection '{collection_name}' ({hash_algorithm})" + ) + return ResponseBody(root=result) + + except Exception as e: + logger.error(f"Error deleting collection: {e}", exc_info=True) + return ResponseBody(root=TextResponse(value=f"Error: {str(e)}")) + + +def delete_collection_cli_parser(value: str): + """Parse CLI input for delete_collection endpoint.""" + return DeleteCollectionInputs(dummy=TextInput(text="")) + + +def delete_collection_param_parser(value: str): + """Parse CLI parameters for delete_collection endpoint.""" + try: + parts = value.split(",") + collection_name = parts[0].strip() if len(parts) > 0 else "my_collection" + hash_algorithm = parts[1].strip() if len(parts) > 1 else "phash" + return DeleteCollectionParameters( + collection_name=collection_name, + hash_algorithm=hash_algorithm, + ) + except Exception as e: + logger.error(f"Error parsing CLI parameters: {e}") + raise typer.Abort() + + +ml_service.add_ml_service( + rule="/delete_collection", + ml_function=delete_collection, + inputs_cli_parser=typer.Argument( + default="", + parser=delete_collection_cli_parser, + help="(no inputs required)", + ), + parameters_cli_parser=typer.Option( + None, + "--params", + parser=delete_collection_param_parser, + help="collection_name,hash_algorithm", + ), + task_schema_func=delete_collection_task_schema, + short_title="Delete Hash Collection", + order=5, +) + + +# Export the app +app = ml_service.app + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/src/perceptual-hash/pyproject.toml b/src/perceptual-hash/pyproject.toml new file mode 100644 index 0000000..b52fa84 --- /dev/null +++ b/src/perceptual-hash/pyproject.toml @@ -0,0 +1,20 @@ +[tool.poetry] +name = "perceptual-hash" +version = "0.1.0" +description = "Perceptual hashing plugin for detecting duplicate and near-duplicate media in forensic investigations" +authors = ["RescueBox Team"] + +[tool.poetry.dependencies] +python = ">=3.11,<3.13" +typer = "*" +pydantic = "*" +perception = "^0.5.1" +psycopg2-binary = "*" +pgvector = "*" +pillow = "*" +numpy = "*" +pandas = "*" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api"
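
For reference, a minimal sketch of consuming the report written by the `find_matches` endpoint. The key names mirror the `json_output` structure built in `main.py`; the file name (`matches.json`) and the quality filter are illustrative assumptions, not part of the plugin itself.

```python
# Sketch: reading a find_matches report, assuming it was saved as "matches.json".
import json

with open("matches.json") as f:
    report = json.load(f)

meta = report["metadata"]
print(
    f"{meta['hash_algorithm']} search of '{meta['collection_name']}': "
    f"{meta['total_matches_found']} matches across {meta['total_query_files']} query files"
)

for query_path, entry in report["results"].items():
    # Quality labels are assigned in find_matches from the Hamming distance.
    strong = [m for m in entry["matches"] if m["quality"] in ("exact", "very_similar")]
    for m in strong:
        print(
            f"{query_path} -> {m['file_path']} "
            f"(hamming={m['hamming_distance']}, similarity={m['similarity']})"
        )
```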