diff --git a/efficient_sam/build_efficient_sam.py b/efficient_sam/build_efficient_sam.py
index c6dd40a..a5cefc0 100644
--- a/efficient_sam/build_efficient_sam.py
+++ b/efficient_sam/build_efficient_sam.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .efficient_sam import build_efficient_sam
+from efficient_sam import build_efficient_sam
 
 def build_efficient_sam_vitt():
     return build_efficient_sam(
diff --git a/efficient_sam/efficient_sam.py b/efficient_sam/efficient_sam.py
index cf2af8a..6a6a2c0 100644
--- a/efficient_sam/efficient_sam.py
+++ b/efficient_sam/efficient_sam.py
@@ -12,9 +12,11 @@
 from torch import nn, Tensor
 
-from .efficient_sam_decoder import MaskDecoder, PromptEncoder
-from .efficient_sam_encoder import ImageEncoderViT
-from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
+from efficient_sam_decoder import MaskDecoder, PromptEncoder
+from efficient_sam_encoder import ImageEncoderViT
+from two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
+
+from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 
 
 class EfficientSam(nn.Module):
     mask_threshold: float = 0.0
@@ -303,3 +305,102 @@ def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=N
             state_dict = torch.load(f, map_location="cpu")
         sam.load_state_dict(state_dict["model"])
     return sam
+
+
+class EfficientSAM(EfficientSam, PyTorchModelHubMixin):
+    def __init__(self, config):
+        assert config["activation"] in ["relu", "gelu"]
+        if config["activation"] == "relu":
+            activation_fn = nn.ReLU
+        else:
+            activation_fn = nn.GELU
+
+        image_encoder = ImageEncoderViT(
+            img_size=config["img_size"],
+            patch_size=config["encoder_patch_size"],
+            in_chans=3,
+            patch_embed_dim=config["encoder_patch_embed_dim"],
+            normalization_type=config["normalization_type"],
+            depth=config["encoder_depth"],
+            num_heads=config["encoder_num_heads"],
+            mlp_ratio=config["encoder_mlp_ratio"],
+            neck_dims=config["encoder_neck_dims"],
+            act_layer=activation_fn,
+        )
+
+        image_embedding_size = image_encoder.image_embedding_size
+        encoder_transformer_output_dim = image_encoder.transformer_output_dim
+
+        prompt_encoder = PromptEncoder(
+            embed_dim=encoder_transformer_output_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(config["img_size"], config["img_size"]),
+        )
+        decoder_max_num_input_points = config["decoder_max_num_input_points"]
+        mask_decoder = MaskDecoder(
+            transformer_dim=encoder_transformer_output_dim,
+            transformer=TwoWayTransformer(
+                depth=config["decoder_transformer_depth"],
+                embedding_dim=encoder_transformer_output_dim,
+                num_heads=config["decoder_num_heads"],
+                mlp_dim=config["decoder_transformer_mlp_dim"],
+                activation=activation_fn,
+                normalize_before_activation=config["normalize_before_activation"],
+            ),
+            num_multimask_outputs=config["num_multimask_outputs"],
+            activation=activation_fn,
+            normalization_type=config["normalization_type"],
+            normalize_before_activation=config["normalize_before_activation"],
+            iou_head_depth=config["iou_head_depth"] - 1,
+            iou_head_hidden_dim=config["iou_head_hidden_dim"],
+            upscaling_layer_dims=config["decoder_upscaling_layer_dims"],
+        )
+
+        super().__init__(
+            image_encoder=image_encoder,
+            prompt_encoder=prompt_encoder,
+            mask_decoder=mask_decoder,
+            decoder_max_num_input_points=decoder_max_num_input_points,
+            pixel_mean=config["pixel_mean"],
+            pixel_std=config["pixel_std"],
+        )
+
+
+config = dict(
+    img_size=1024,
+    encoder_patch_embed_dim=192,
+    encoder_num_heads=3,
+    encoder_patch_size=16,
+    encoder_depth=12,
+    encoder_mlp_ratio=4.0,
+    encoder_neck_dims=[256, 256],
+    decoder_max_num_input_points=6,
+    decoder_transformer_depth=2,
+    decoder_transformer_mlp_dim=2048,
+    decoder_num_heads=8,
+    decoder_upscaling_layer_dims=[64, 32],
+    num_multimask_outputs=3,
+    iou_head_depth=3,
+    iou_head_hidden_dim=256,
+    activation="gelu",
+    normalization_type="layer_norm",
+    normalize_before_activation=False,
+    pixel_mean=[0.485, 0.456, 0.406],
+    pixel_std=[0.229, 0.224, 0.225],
+)
+
+model = EfficientSAM(config)
+
+# load weights
+filepath = hf_hub_download("merve/EfficientSAM", filename="efficient_sam_vitt.pt", repo_type="model")
+state_dict = torch.load(filepath, map_location="cpu")
+
+# sanity check: list the checkpoint's parameter names and shapes
+for name, param in state_dict["model"].items():
+    print(name, param.shape)
+
+model.load_state_dict(state_dict["model"])
+
+# save locally
+model.save_pretrained("efficient_sam")
+
+# push to HF hub
+model.push_to_hub("nielsr/efficientsam-tiny", config=config)
+
+# reload
+model = EfficientSAM.from_pretrained("nielsr/efficientsam-tiny")
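+
+# Editor's sketch (an assumption, not part of the original script): smoke-test
+# the reloaded weights once. This assumes EfficientSam.forward keeps its
+# (batched_images [B, 3, H, W], batched_points [B, num_queries, num_points, 2],
+# batched_point_labels [B, num_queries, num_points]) signature and returns
+# predicted masks together with their IoU scores.
+dummy_image = torch.zeros(1, 3, 1024, 1024)
+dummy_points = torch.tensor([[[[512.0, 512.0]]]])  # one positive point prompt
+dummy_labels = torch.tensor([[[1]]])
+with torch.no_grad():
+    predicted_masks, predicted_iou = model(dummy_image, dummy_points, dummy_labels)
+print(predicted_masks.shape, predicted_iou.shape)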
diff --git a/efficient_sam/efficient_sam_decoder.py b/efficient_sam/efficient_sam_decoder.py
index 380f41c..3a24abd 100644
--- a/efficient_sam/efficient_sam_decoder.py
+++ b/efficient_sam/efficient_sam_decoder.py
@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .mlp import MLPBlock
+from mlp import MLPBlock
 
 
 class PromptEncoder(nn.Module):
diff --git a/efficient_sam/two_way_transformer.py b/efficient_sam/two_way_transformer.py
index b06e528..19385eb 100644
--- a/efficient_sam/two_way_transformer.py
+++ b/efficient_sam/two_way_transformer.py
@@ -2,7 +2,7 @@
 from typing import Tuple, Type
 import torch
 from torch import nn, Tensor
-from .mlp import MLPBlock
+from mlp import MLPBlock
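
Note on the import rewrites above (editor's addition; an assumption about how this patch is meant to be run): dropping the relative "." prefix means mlp, efficient_sam_decoder, efficient_sam_encoder and two_way_transformer resolve only when the efficient_sam/ directory itself is on sys.path. A minimal sketch of that assumption, run from the repository root:

    import sys
    sys.path.append("efficient_sam")  # make the bare module names importable
    from efficient_sam import build_efficient_sam  # resolves efficient_sam/efficient_sam.py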