Welcome to the Prompt-Guided Image Segmentation (PSEG) repository! PSEG is an end-to-end PyTorch implementation of a multimodal segmentation framework. By combining state-of-the-art vision and language foundation models, it moves beyond fixed-vocabulary segmentation: users can extract precise, pixel-level masks for any object simply by providing a descriptive free-text natural language prompt alongside an image.
⚠️ License Notice: This repository contains original source code only. It does not redistribute any dataset samples, dataset files, or pre-trained model weights. All third-party models and datasets must be obtained independently through their respective official sources. See LICENSE for full details.
At the core of PSEG is an architecture that pairs massive, pre-trained frozen backbones with an ultra-lightweight trainable segmentation head. It fuses dense, patch-level semantic representations from DINOv2 with globally aware textual embeddings from CLIP, and feeds both into a SAM-based (Segment Anything Model) decoder initialized from pre-trained weights. This paradigm yields high-quality referring segmentation with significantly reduced training overhead and faster convergence.
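The frozen-backbone / trainable-head split described above can be sketched in a few lines of PyTorch. This is a minimal illustration with placeholder linear layers standing in for DINOv2, CLIP, and the SAM-style decoder, not the repository's actual classes:

```python
import torch
import torch.nn as nn

class PromptSegSketch(nn.Module):
    """Illustrative skeleton: frozen encoders feed a small trainable decoder."""
    def __init__(self, vis_dim=768, txt_dim=512, fuse_dim=256):
        super().__init__()
        # Stand-ins for DINOv2 / CLIP; in practice these are pre-trained models.
        self.visual_backbone = nn.Linear(vis_dim, fuse_dim)
        self.text_backbone = nn.Linear(txt_dim, fuse_dim)
        # Freeze both backbones so only the decoder receives gradients.
        for p in self.visual_backbone.parameters():
            p.requires_grad = False
        for p in self.text_backbone.parameters():
            p.requires_grad = False
        # Lightweight trainable head (stands in for the SAM-style decoder).
        self.decoder = nn.Linear(fuse_dim, fuse_dim)

    def forward(self, patches, text):
        v = self.visual_backbone(patches)
        t = self.text_backbone(text)
        return self.decoder(v + t)  # naive additive fusion, for illustration only

model = PromptSegSketch()
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['decoder.weight', 'decoder.bias']
```

Because `requires_grad` is disabled on the backbones, the optimizer only ever updates the decoder, which is what keeps training cheap.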
```mermaid
graph TD
    A[Image Input] --> B
    C[Text Prompt] --> D
    subgraph "Visual Backbone (Frozen)"
        B[DINOv2 facebook/dinov2-base<br>Dim: 768, Resolution: 16x16]
        B -- "Multi-scale features<br>(Layers: 3, 6, 9, 11)" --> E[FPN Neck]
        E -- "Fused Visual Features<br>Dim: 256" --> G
    end
    subgraph "Text Backbone (Frozen)"
        D[CLIP openai/clip-vit-base-patch16<br>Max Length: 77]
        D -- "L2 Norm / CLS Token<br>Dim: 512" --> F[Text Projection Linear]
        F -- "Text Embedding<br>Dim: 256" --> G
    end
    subgraph "Mask Decoder (Trainable)"
        H[Mask Token] --> G
        I[IoU Token] --> G
        G[Two-Way Transformer Blocks<br>Init: SAM-Base, Depth: 3, Heads: 8]
        G -- "Mask Tokens &<br>Image Features" --> J[Dynamic Projection &<br>Progressive Upsampler]
        G -- "IoU Token" --> K[IoU Prediction Head]
    end
    J --> L[Segmentation Mask]
    K --> M[Predicted IoU Score]
```
| Model Component | Base Model | State | Extraction / Details | Output Dim |
|---|---|---|---|---|
| Visual Encoding | facebook/dinov2-base | Frozen | Patch Size: 14 (16×16 grid). Multi-scale hidden states from layers [3, 6, 9, 11]. FPN Neck fusion via convolutions, GroupNorm, and GELU. | 256 |
| Text Encoding | openai/clip-vit-base-patch16 | Frozen | CLS-token projection from the last hidden state, followed by L2 normalization. | 256 (projected) |
| Mask Decoder | facebook/sam-vit-base | Trainable | 3 layers of TwoWayBlock Transformers (dim=256, heads=8). Uses FPN features, projected CLIP point embeddings, learnable mask & IoU tokens. | Mask + IoU |
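The text-encoding row of the table (CLS token, L2 normalization, projection from 512 to 256 dims) can be sketched as follows. This uses a random dummy tensor in place of real CLIP hidden states, and the `text_proj` layer is a stand-in for the repository's projection module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dimensions from the table: CLIP text width 512, shared fusion width 256.
text_proj = nn.Linear(512, 256)

def encode_text(last_hidden_state: torch.Tensor) -> torch.Tensor:
    """CLS-token extraction -> L2 normalization -> projection to the fusion dim."""
    cls = last_hidden_state[:, 0, :]      # (B, 512) take the CLS token
    cls = F.normalize(cls, p=2, dim=-1)   # unit-norm text embedding
    return text_proj(cls)                 # (B, 256) projected embedding

dummy = torch.randn(2, 77, 512)           # (batch, max_length, hidden)
emb = encode_text(dummy)
print(emb.shape)  # torch.Size([2, 256])
```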
**Parameter Breakdown**

| Component | Parameter Count | Status |
|---|---|---|
| DINOv2 + CLIP | ~150M | Frozen |
| SAM Decoder | ~9.3M | Trainable |
| Total Pipeline | ~159.3M | Frozen + Trainable |
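The frozen/trainable split in the table can be verified on any `nn.Module` with a short helper. This is a generic utility sketch (the toy two-layer model below is hypothetical, not code from this repo):

```python
import torch.nn as nn

def count_params(model: nn.Module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen

# Toy example: freeze one of two linear layers.
backbone = nn.Linear(100, 100)   # 100*100 + 100 = 10100 params
head = nn.Linear(100, 10)        # 100*10 + 10 = 1010 params
for p in backbone.parameters():
    p.requires_grad = False
model = nn.Sequential(backbone, head)
print(count_params(model))  # (1010, 10100)
```

Running the same helper on the full PSEG pipeline should reproduce the ~9.3M trainable / ~150M frozen split above.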
The model is trained on the RefCOCO-m dataset (moondream/refcoco-m on Hugging Face), which provides paired images with free-text referring expressions for grounded segmentation.
Note: The dataset is not included in this repository. Users must download it independently. The `data/download_dataset.py` script automates the download via the Hugging Face `datasets` library. By using this script, you agree to the dataset's original license and terms of use.

The RefCOCO dataset is attributed to:
Yu, L., et al. "Modeling Context in Referring Expressions." ECCV 2016.
To get started with this project locally, clone the repository and configure your environment:
1. Clone the Repository

   ```bash
   git clone https://github.com/halim-cv1/PSEG_Lightweight.git
   cd PSEG_Lightweight
   ```

2. Create a Virtual Environment

   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows use `venv\Scripts\activate`
   ```

3. Install Dependencies

   Install all required libraries via the project's `setup.py`:

   ```bash
   pip install -e .
   ```

   Or install directly from `requirements.txt`:

   ```bash
   pip install -r requirements.txt
   ```
Full dependency list: `torch`, `torchvision`, `transformers`, `datasets`, `tqdm`, `numpy`, `matplotlib`, `opencv-python`, `Pillow`, `scipy`, `albumentations`, `pycocotools`
You can launch training using the provided command-line scripts. Training automatically manages the dataloaders and saves `prompt_seg_best.pt` and `prompt_seg_final.pt`.
1. Download the Dataset

   ```bash
   python data/download_dataset.py
   ```

2. Start Training

   ```bash
   python train.py --epochs 20 --batch_size 32
   ```

| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW (betas=(0.9, 0.999), weight_decay=1e-2) |
| FPN LR | lr × 0.5 = 1e-4 |
| Decoder LR | lr = 2e-4 |
| LR Schedule | Cosine annealing with 2-epoch linear warmup |
| Batch Size | 32 |
| Epochs | 20 |
| Steps per Epoch | 473 |
| Grad Clip | Max norm 1.0 |
| Mixed Precision | AMP (torch.cuda.amp) |
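The hyperparameter table can be translated into an optimizer/scheduler setup roughly like the following. This is a hedged sketch: the `fpn` and `decoder` modules are stand-in layers, and the warmup-plus-cosine schedule is implemented here with a `LambdaLR` rather than whatever `train.py` actually uses:

```python
import math
import torch
import torch.nn as nn

# Hyperparameters from the table above.
base_lr, fpn_scale = 2e-4, 0.5
epochs, steps_per_epoch, warmup_epochs = 20, 473, 2

fpn = nn.Linear(8, 8)      # stand-in for the FPN neck
decoder = nn.Linear(8, 8)  # stand-in for the mask decoder

optimizer = torch.optim.AdamW(
    [{"params": fpn.parameters(), "lr": base_lr * fpn_scale},   # 1e-4
     {"params": decoder.parameters(), "lr": base_lr}],          # 2e-4
    betas=(0.9, 0.999), weight_decay=1e-2,
)

total_steps = epochs * steps_per_epoch
warmup_steps = warmup_epochs * steps_per_epoch

def warmup_cosine(step: int) -> float:
    """Linear warmup for 2 epochs, then cosine annealing to zero."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
# Inside the training loop, gradients are clipped before optimizer.step():
# torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
```

After the two warmup epochs (946 steps at 473 steps/epoch) each group's learning rate peaks at its configured value before decaying along the cosine curve.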
Running inference on an arbitrary image with a text prompt:

```bash
python inference.py \
    --image path/to/your/image.jpg \
    --prompt "the prompt you want to predict" \
    --checkpoint prompt_seg_best.pt \
    --output output_mask.png
```

This script outputs the original image, the predicted boolean threshold mask with its estimated IoU confidence, and a final overlay.
The segmentation capabilities show strong adaptation of the SAM-based mask decoder to the DINOv2 visual features aligned with CLIP textual embeddings.
Validation Results Pipeline
- Best Validation IoU: 0.4167
- Mean IoU: 0.4166 (Median: 0.3899)
- IoU @ 0.5: 36.71%
- Mean Precision: 0.4603
- Mean Recall: 0.6622
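The metrics above follow the standard definitions for binary masks; a minimal NumPy sketch (not the repository's evaluation code) is:

```python
import numpy as np

def mask_metrics(pred: np.ndarray, gt: np.ndarray):
    """IoU, precision, recall between boolean masks (standard formulation)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    iou = inter / union if union else 1.0
    precision = inter / pred.sum() if pred.sum() else 0.0
    recall = inter / gt.sum() if gt.sum() else 0.0
    return float(iou), float(precision), float(recall)

# Tiny 2x2 example: one true positive, one false positive, one false negative.
pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [1, 0]])
print(mask_metrics(pred, gt))  # (0.3333333333333333, 0.5, 0.5)
```

"IoU @ 0.5" is then the fraction of samples whose per-image IoU exceeds 0.5.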
| Epoch | Train Loss | Val Loss | Val IoU | IoU@0.5 |
|---|---|---|---|---|
| 1 | 0.9286 | 0.8479 | 0.2928 | 0.00% |
| 5 | 0.6594 | 0.7225 | 0.3650 | 3.85% |
| 7 | 0.6141 | 0.7019 | 0.4063 | 19.23% |
| 10 | 0.5746 | 0.6981 | 0.4128 | 19.23% |
| 13 | 0.5434 | 0.7041 | 0.4167 | 25.00% |
| 14 | 0.5319 | 0.7624 | 0.4044 | 26.92% |
| 20 | 0.4944 | 0.8595 | 0.3936 | 17.31% |
- Peak IoU (0.4167): reached at Epoch 13, the optimal stopping point.
- Validation Divergence: val loss bottoms out at Epoch 10, then rises to `0.8595` by Epoch 20, indicating overfitting and the need for early stopping around Epoch 13–14.
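An early-stopping guard of the kind suggested above takes only a few lines. This is a generic sketch, not code from the repository's trainer; the patience value is an assumption:

```python
class EarlyStopper:
    """Stop when validation IoU has not improved for `patience` epochs."""
    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_iou: float) -> bool:
        """Record one epoch's validation IoU; return True when training should stop."""
        if val_iou > self.best:
            self.best, self.bad_epochs = val_iou, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Replay of the validation IoU curve from the table above.
stopper = EarlyStopper(patience=3)
history = [0.2928, 0.3650, 0.4063, 0.4128, 0.4167, 0.4044, 0.3936, 0.3900]
for epoch, iou in enumerate(history, start=1):
    if stopper.step(iou):
        print(f"stop at epoch {epoch}, best IoU {stopper.best}")  # stop at epoch 8, best IoU 0.4167
        break
```

Applied to this run, such a guard would have halted training shortly after the Epoch-13 peak instead of continuing to Epoch 20.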
*Figures (see repository assets): IoU distribution; prompt length vs. IoU.*
When a text prompt is unambiguous and the target object is clearly visible and well-sized, the cross-modal alignment between CLIP's global text embedding and DINOv2's patch-level spatial features is highly effective. The SAM decoder's boundary-aware upsampler produces clean, well-delineated masks.
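The cross-modal alignment described here can be probed directly: cosine similarity between the projected text embedding and each patch feature yields a coarse spatial relevance map. The sketch below uses random tensors in place of real DINOv2/CLIP features and assumes the 16×16 grid and 256-dim fusion space from the architecture table:

```python
import torch
import torch.nn.functional as F

def relevance_map(patch_feats: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between one text embedding and a 16x16 grid of patch features."""
    patches = F.normalize(patch_feats, dim=-1)  # (256, 256) = (16*16 patches, dim)
    text = F.normalize(text_emb, dim=-1)        # (256,) unit-norm text embedding
    sim = patches @ text                        # (256,) cosine similarities in [-1, 1]
    return sim.reshape(16, 16)                  # coarse spatial relevance map

sim_map = relevance_map(torch.randn(256, 256), torch.randn(256))
print(sim_map.shape)  # torch.Size([16, 16])
```

High-similarity cells in this map roughly mark where the decoder should place the mask, which is why clear prompts on well-sized objects work best.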
The following examples illustrate the model's failure modes, which arise from structural limitations in the architecture:
- Resolution Bottleneck: DINOv2 compresses images to a 16×16 feature grid, making small or thin objects difficult to segment.
- Global Text Pooling: Extracting only the `[CLS]` token from CLIP collapses spatial/relational nuances of the prompt (e.g., "to the left of").
- Trainable Capacity Gap: The ~9.3M decoder must bridge the entire cross-modal mapping gap, leading to coarse "blobby" masks for complex scenes.
- Resolution Constraints (DINOv2 `base`): Coarse 16×16 patch embeddings limit boundary precision on small objects.
- Global Prompt Collapse (CLIP `base`): The static `[CLS]` token loses positional nuance in text prompts.
- Cross-Modal Mapping Capacity: The small SAM decoder (~9.3M params) is the sole bridge between frozen visual and textual representations.
- Upscaling the Backbones: Upgrade to `dinov2-large`/`dinov2-giant` and richer text encoders (e.g., LLaMA embeddings).
- Multi-Dataset Pretraining: Pretrain across MS COCO, LVIS, and SA-1B for stronger zero-shot generalization.
- Unfreezing with PEFT/LoRA: Apply LoRA to deeper layers of both backbones to reduce the cross-modal mapping burden.
- Token-Level Text Fusion: Inject full token sequences into `TwoWayBlock` attention for dense word-to-patch spatial interactions.
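The LoRA direction above could, for instance, wrap a frozen linear layer with a trainable low-rank residual. This is the standard LoRA formulation sketched generically, not an implementation from this repository; rank and alpha values are illustrative:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus trainable low-rank update: y = Wx + (alpha/r) * B(A(x))."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # backbone weights stay frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # start as a no-op update
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

layer = LoRALinear(nn.Linear(768, 768), r=8)
x = torch.randn(2, 768)
# With B zero-initialized, the wrapped layer initially matches the frozen base exactly.
print(torch.allclose(layer(x), layer.base(x)))  # True
```

Only the two rank-8 matrices train, so the added parameter cost per wrapped layer is tiny compared with unfreezing the backbone itself.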
This project builds upon the following foundational models and datasets. No model weights or dataset files are included in this repository. Users must download them independently and comply with their respective licenses.
| Asset | Source | License |
|---|---|---|
| DINOv2 | facebookresearch/dinov2 | Apache 2.0 |
| CLIP | openai/CLIP | MIT |
| SAM (Segment Anything) | facebookresearch/segment-anything | Apache 2.0 |
| RefCOCO-m Dataset | moondream/refcoco-m | See dataset card |
If you use this code in academic work, please also cite the foundational models:
```bibtex
@article{oquab2023dinov2,
  title={DINOv2: Learning Robust Visual Features without Supervision},
  author={Oquab, Maxime and others},
  journal={TMLR},
  year={2023}
}

@inproceedings{radford2021learning,
  title={Learning Transferable Visual Models From Natural Language Supervision},
  author={Radford, Alec and others},
  booktitle={ICML},
  year={2021}
}

@inproceedings{kirillov2023segment,
  title={Segment Anything},
  author={Kirillov, Alexander and others},
  booktitle={ICCV},
  year={2023}
}

@inproceedings{yu2016modeling,
  title={Modeling Context in Referring Expressions},
  author={Yu, Licheng and others},
  booktitle={ECCV},
  year={2016}
}
```

This project's source code is released under an Academic and Research Use License.
This license applies only to the original code in this repository — it does not cover any third-party models, weights, or datasets. See LICENSE for full details.