diff --git a/.gitignore b/.gitignore index c5e24b8..d050897 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ data/ # Misc .DS_Store -checkpoints/ \ No newline at end of file +checkpoints/ +results/ \ No newline at end of file diff --git a/src/common/binary_classifier_inference.py b/src/common/binary_classifier_inference.py new file mode 100644 index 0000000..5071cce --- /dev/null +++ b/src/common/binary_classifier_inference.py @@ -0,0 +1,281 @@ +import math +import os +import xml.etree.ElementTree as ET +from glob import glob +from typing import Generator + +import torch +from binary_classifier import BinaryClassifier +from PIL import Image, ImageDraw +from torchvision import transforms + +FOLDER = "data/VOC2012/VOC2012_test/JPEGImages" +MODEL_PATH = "checkpoints/best_model.pth" +THRESHOLD = 0.9 +OUTPUT_PATH = "results/collage.jpg" +WINDOW_SIZES = [256, 512] +STRIDE_RATIO = 0.5 +NMS_IOU_THRESHOLD = 0.2 +ANNOTATION_DIR = "data/VOC2012/VOC2012_test/Annotations" +TARGET_CLASSES = {"person", "cat"} + + +def get_transform() -> transforms.Compose: + return transforms.Compose( + [ + transforms.Resize((64, 64)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + +def sliding_window( + image_width: int, image_height: int, window_size: int, stride: int +) -> Generator[tuple[int, int, int, int], None, None]: + """Generate sliding window coordinates.""" + for y in range(0, image_height - window_size + 1, stride): + for x in range(0, image_width - window_size + 1, stride): + yield (x, y, x + window_size, y + window_size) + + # Handle right edge + if (image_width - window_size) % stride != 0: + for y in range(0, image_height - window_size + 1, stride): + x = image_width - window_size + yield (x, y, x + window_size, y + window_size) + + # Handle bottom edge + if (image_height - window_size) % stride != 0: + for x in range(0, image_width - window_size + 1, stride): + y = image_height - window_size + yield (x, y, x + window_size, y + window_size) + + # Handle bottom-right corner + if (image_width - window_size) % stride != 0 and ( + image_height - window_size + ) % stride != 0: + yield ( + image_width - window_size, + image_height - window_size, + image_width, + image_height, + ) + + +def compute_iou( + box1: tuple[int, int, int, int], box2: tuple[int, int, int, int] +) -> float: + """Compute IoU between two boxes (x1, y1, x2, y2).""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + intersection = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0 + + +def non_max_suppression( + detections: list[tuple[tuple[int, int, int, int], float]], iou_threshold: float +) -> list[tuple[tuple[int, int, int, int], float]]: + """Apply NMS to detections. Each detection is (box, confidence).""" + if not detections: + return [] + + # Sort by confidence (descending) + detections = sorted(detections, key=lambda x: x[1], reverse=True) + + keep = [] + while detections: + best = detections.pop(0) + keep.append(best) + + detections = [ + d for d in detections if compute_iou(best[0], d[0]) < iou_threshold + ] + + return keep + + +def detect_in_image( + image: Image.Image, + model: BinaryClassifier, + transform: transforms.Compose, + device: torch.device, +) -> list[tuple[tuple[int, int, int, int], float]]: + """Run sliding window detection on a single image.""" + width, height = image.size + all_detections = [] + + for window_size in WINDOW_SIZES: + if width < window_size or height < window_size: + continue + + stride = int(window_size * STRIDE_RATIO) + + # Collect all windows for this scale + windows = list(sliding_window(width, height, window_size, stride)) + crops = [] + coords = [] + + for x1, y1, x2, y2 in windows: + crop = image.crop((x1, y1, x2, y2)) + crops.append(transform(crop)) + coords.append((x1, y1, x2, y2)) + + # Batch inference + BATCH_SIZE = 64 + for batch_start in range(0, len(crops), BATCH_SIZE): + batch_end = min(batch_start + BATCH_SIZE, len(crops)) + batch_tensor = torch.stack(crops[batch_start:batch_end]).to(device) + + with torch.no_grad(): + outputs = model(batch_tensor) + confidences = torch.sigmoid(outputs).squeeze(-1) + + for j, conf in enumerate(confidences): + if conf.item() >= THRESHOLD: + all_detections.append((coords[batch_start + j], conf.item())) + + detections = non_max_suppression(all_detections, NMS_IOU_THRESHOLD) + return detections + + +def get_valid_image_ids() -> set[str]: + """Get image IDs that contain target classes.""" + valid_ids = set() + + for xml_file in os.listdir(ANNOTATION_DIR): + if not xml_file.endswith(".xml"): + continue + + tree = ET.parse(os.path.join(ANNOTATION_DIR, xml_file)) + root = tree.getroot() + + classes_in_image = set() + for obj in root.findall("object"): + name = obj.find("name") + if name is not None and name.text: + classes_in_image.add(name.text) + + if classes_in_image & TARGET_CLASSES: + image_id = xml_file.replace(".xml", "") + valid_ids.add(image_id) + + return valid_ids + + +def run_inference() -> None: + os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True) + + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"Using device: {device}") + + model = BinaryClassifier(device=str(device)) + model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) + model.eval() + + transform = get_transform() + + extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"] + image_paths = [] + for ext in extensions: + image_paths.extend(glob(os.path.join(FOLDER, ext))) + image_paths.extend(glob(os.path.join(FOLDER, ext.upper()))) + + print(f"Found {len(image_paths)} images") + + # Filter to only images with target classes + valid_ids = get_valid_image_ids() + image_paths = [ + p for p in image_paths if os.path.splitext(os.path.basename(p))[0] in valid_ids + ] + print(f"Filtered to {len(image_paths)} images containing {TARGET_CLASSES}") + + results = [] # (path, image_with_boxes, num_detections, max_confidence) + + for i, path in enumerate(image_paths): + img = Image.open(path).convert("RGB") + detections = detect_in_image(img, model, transform, device) + + if detections: + # Draw bounding boxes on image + img_with_boxes = img.copy() + draw = ImageDraw.Draw(img_with_boxes) + + max_conf = 0.0 + for box, conf in detections: + draw.rectangle(box, outline="red", width=3) + draw.text((box[0], box[1] - 15), f"{conf:.2f}", fill="red") + max_conf = max(max_conf, conf) + + results.append((path, img_with_boxes, len(detections), max_conf)) + print( + f"✓ {os.path.basename(path)}: {len(detections)} detections (max conf: {max_conf:.3f})" + ) + + if (i + 1) % 50 == 0: + print(f"Processed {i + 1}/{len(image_paths)} images...") + + print(f"\n{len(results)} images with detections") + + if not results: + print("No detections found.") + return + + # Sort by max confidence + results.sort(key=lambda x: x[3], reverse=True) + + # Create collage with bounding boxes + GRID_SIZE = 5 + THUMB_SIZE = 200 + PAGE_SIZE = GRID_SIZE * THUMB_SIZE + + num_pages = math.ceil(len(results) / 25) + + for page_idx in range(num_pages): + collage = Image.new("RGB", (PAGE_SIZE, PAGE_SIZE), (255, 255, 255)) + start = page_idx * 25 + end = min(start + 25, len(results)) + + for i, (path, img_with_boxes, num_det, max_conf) in enumerate( + results[start:end] + ): + thumb = img_with_boxes.copy() + thumb.thumbnail((THUMB_SIZE, THUMB_SIZE - 20)) + + cell = Image.new("RGB", (THUMB_SIZE, THUMB_SIZE), (255, 255, 255)) + x_offset = (THUMB_SIZE - thumb.width) // 2 + cell.paste(thumb, (x_offset, 0)) + + # Add text showing detections and confidence + draw = ImageDraw.Draw(cell) + text = f"{num_det} det | {max_conf:.2f}" + bbox = draw.textbbox((0, 0), text) + text_x = (THUMB_SIZE - (bbox[2] - bbox[0])) // 2 + draw.text((text_x, THUMB_SIZE - 16), text, fill=(255, 0, 0)) + + row, col = divmod(i, GRID_SIZE) + collage.paste(cell, (col * THUMB_SIZE, row * THUMB_SIZE)) + + if num_pages == 1: + save_path = OUTPUT_PATH + else: + base, ext = os.path.splitext(OUTPUT_PATH) + save_path = f"{base}_page{page_idx + 1}{ext}" + + collage.save(save_path) + print(f"Saved collage: {save_path}") + + +if __name__ == "__main__": + run_inference() diff --git a/src/common/voc_dataset.py b/src/common/voc_dataset.py index 639e2f6..25104b2 100644 --- a/src/common/voc_dataset.py +++ b/src/common/voc_dataset.py @@ -1,5 +1,6 @@ import os import xml.etree.ElementTree as ET +from collections import Counter from typing import Any, Dict, List, Tuple import torch @@ -76,81 +77,78 @@ def __init__( self.samples: List[Dict[str, Any]] = [] self._parse_annotations() - print( - f"Loaded {len(self.samples)} object instances from {len(self.image_ids)} images" - ) + label_counts = Counter(s["class_name"] for s in self.samples) + print(f"Loaded {len(self.samples)} samples: {dict(label_counts)}") def _parse_annotations(self) -> None: - """Parse all XML annotations and create sample list.""" + """Parse annotations - only include target class and one negative class.""" + TARGET_CLASS_IDX = 14 # person + NEGATIVE_CLASS_IDX = 7 # cat + for image_id in self.image_ids: annotation_path = os.path.join(self.annotation_dir, f"{image_id}.xml") tree = ET.parse(annotation_path) root = tree.getroot() - # Get image dimensions size = root.find("size") if size is None: - raise ValueError(f"Missing size element in annotation for {image_id}") + continue width_elem = size.find("width") - if width_elem is None or width_elem.text is None: - raise ValueError(f"Missing width in annotation for {image_id}") - img_width = int(width_elem.text) - height_elem = size.find("height") - if height_elem is None or height_elem.text is None: - raise ValueError(f"Missing height in annotation for {image_id}") + if width_elem is None or height_elem is None: + continue + if width_elem.text is None or height_elem.text is None: + continue + + img_width = int(width_elem.text) img_height = int(height_elem.text) - # Parse all objects in the image for obj in root.findall("object"): - difficult_elem = obj.find("difficult") - if difficult_elem is None or difficult_elem.text is None: - raise ValueError( - f"Missing difficult field in annotation for {image_id}" - ) - difficult = int(difficult_elem.text) - - # Skip difficult objects if specified - if difficult and not self.include_difficult: - continue + if not self.include_difficult: + difficult_elem = obj.find("difficult") + if difficult_elem is not None and difficult_elem.text == "1": + continue - # Get class label name_elem = obj.find("name") if name_elem is None or name_elem.text is None: - raise ValueError(f"Missing name field in annotation for {image_id}") + continue class_name = name_elem.text if class_name not in self.class_to_idx: continue label = self.class_to_idx[class_name] - # Get bounding box + # Only keep target class and negative class + if label != TARGET_CLASS_IDX and label != NEGATIVE_CLASS_IDX: + continue + bbox = obj.find("bndbox") if bbox is None: - raise ValueError(f"Missing bndbox in annotation for {image_id}") + continue xmin_elem = bbox.find("xmin") - if xmin_elem is None or xmin_elem.text is None: - raise ValueError(f"Missing xmin in bounding box for {image_id}") - xmin = int(xmin_elem.text) - ymin_elem = bbox.find("ymin") - if ymin_elem is None or ymin_elem.text is None: - raise ValueError(f"Missing ymin in bounding box for {image_id}") - ymin = int(ymin_elem.text) - xmax_elem = bbox.find("xmax") - if xmax_elem is None or xmax_elem.text is None: - raise ValueError(f"Missing xmax in bounding box for {image_id}") - xmax = int(xmax_elem.text) - ymax_elem = bbox.find("ymax") - if ymax_elem is None or ymax_elem.text is None: - raise ValueError(f"Missing ymax in bounding box for {image_id}") - ymax = int(ymax_elem.text) - # Store sample + if ( + xmin_elem is None + or xmin_elem.text is None + or ymin_elem is None + or ymin_elem.text is None + or xmax_elem is None + or xmax_elem.text is None + or ymax_elem is None + or ymax_elem.text is None + ): + continue + + xmin = int(float(xmin_elem.text)) + ymin = int(float(ymin_elem.text)) + xmax = int(float(xmax_elem.text)) + ymax = int(float(ymax_elem.text)) + self.samples.append( { "image_id": image_id,