Skip to content

Commit 1228519

Browse files
committed
Adding licensing information to cspnet.py
1 parent bf48f45 commit 1228519

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

timm/models/_hub.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from timm.models._pretrained import filter_pretrained_cfg
3131

3232
try:
33-
from huggingface_hub import HfApi, hf_hub_download
34-
from huggingface_hub.utils import EntryNotFoundError
33+
from huggingface_hub import HfApi, hf_hub_download, model_info
34+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
3535
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
3636
_has_hf_hub = True
3737
except ImportError:
@@ -533,3 +533,44 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
533533
yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
534534
if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
535535
yield filename[:-4] + ".safetensors"
536+
537+
538+
def _get_license_from_hf_hub(model_id: str | None, hf_hub_id: str | None) -> str | None:
539+
"""Retrieve license information for a model from Hugging Face Hub.
540+
541+
Fetches the license field from the model card metadata on Hugging Face Hub
542+
for the specified model. Returns None if the model is not found, if
543+
huggingface_hub is not installed, or if the model is marked as "untrained".
544+
545+
Args:
546+
model_id: The model identifier/name. In the case of None we assume an untrained model.
547+
hf_hub_id: The Hugging Face Hub organization/user ID. If it is None,
548+
we will return None as we cannot infer the license terms.
549+
550+
Returns:
551+
The license string in lowercase if found, None otherwise.
552+
553+
Note:
554+
Requires huggingface_hub package to be installed. Will log a warning
555+
and return None if the package is not available.
556+
"""
557+
if not has_hf_hub(True):
558+
msg = "For updated license information run `pip install huggingface_hub`."
559+
_logger.warning(msg=msg)
560+
return None
561+
562+
if not (model_id and hf_hub_id):
563+
return None
564+
565+
repo_id: str = hf_hub_id + model_id
566+
567+
try:
568+
info = model_info(repo_id=repo_id)
569+
570+
except RepositoryNotFoundError:
571+
# TODO: any wish what happens here? @rwightman
572+
print(repo_id)
573+
return None
574+
575+
license = info.card_data.get("license").lower() if info.card_data else None
576+
return license

timm/models/cspnet.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2323
from timm.layers import ClassifierHead, ConvNormAct, DropPath, calculate_drop_path_rates, get_attn, create_act_layer, make_divisible
2424
from ._builder import build_model_with_cfg
25+
from ._hub import _get_license_from_hf_hub
2526
from ._manipulate import named_apply, MATCH_PREV_GROUP
2627
from ._registry import register_model, generate_default_cfgs
2728

@@ -1011,82 +1012,97 @@ def _cfg(url='', **kwargs):
10111012
'crop_pct': 0.887, 'interpolation': 'bilinear',
10121013
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
10131014
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
1015+
'license': _get_license_from_hf_hub(kwargs.pop('model_id', None), kwargs.get('hf_hub_id')),
10141016
**kwargs
10151017
}
10161018

1017-
10181019
default_cfgs = generate_default_cfgs({
10191020
'cspresnet50.ra_in1k': _cfg(
10201021
hf_hub_id='timm/',
1022+
model_id='cspresnet50.ra_in1k',
10211023
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
10221024
'cspresnet50d.untrained': _cfg(),
10231025
'cspresnet50w.untrained': _cfg(),
10241026
'cspresnext50.ra_in1k': _cfg(
10251027
hf_hub_id='timm/',
1028+
model_id='cspresnext50.ra_in1k',
10261029
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
10271030
),
10281031
'cspdarknet53.ra_in1k': _cfg(
10291032
hf_hub_id='timm/',
1033+
model_id='cspdarknet53.ra_in1k',
10301034
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
10311035

10321036
'darknet17.untrained': _cfg(),
10331037
'darknet21.untrained': _cfg(),
10341038
'sedarknet21.untrained': _cfg(),
10351039
'darknet53.c2ns_in1k': _cfg(
10361040
hf_hub_id='timm/',
1041+
model_id='darknet53.c2ns_in1k',
10371042
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
10381043
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
10391044
'darknetaa53.c2ns_in1k': _cfg(
10401045
hf_hub_id='timm/',
1046+
model_id='darknetaa53.c2ns_in1k',
10411047
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
10421048
test_input_size=(3, 288, 288), test_crop_pct=1.0),
10431049

10441050
'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
10451051
'cs3darknet_m.c2ns_in1k': _cfg(
10461052
hf_hub_id='timm/',
1053+
model_id='cs3darknet_m.c2ns_in1k',
10471054
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
10481055
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
10491056
),
10501057
'cs3darknet_l.c2ns_in1k': _cfg(
10511058
hf_hub_id='timm/',
1059+
model_id='cs3darknet_l.c2ns_in1k',
10521060
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
10531061
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
10541062
'cs3darknet_x.c2ns_in1k': _cfg(
10551063
hf_hub_id='timm/',
1064+
model_id='cs3darknet_x.c2ns_in1k',
10561065
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
10571066
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
10581067

10591068
'cs3darknet_focus_s.ra4_e3600_r256_in1k': _cfg(
10601069
hf_hub_id='timm/',
1070+
model_id='cs3darknet_focus_s.ra4_e3600_r256_in1k',
10611071
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
10621072
interpolation='bicubic', test_input_size=(3, 320, 320), test_crop_pct=1.0),
10631073
'cs3darknet_focus_m.c2ns_in1k': _cfg(
10641074
hf_hub_id='timm/',
1075+
model_id='cs3darknet_focus_m.c2ns_in1k',
10651076
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
10661077
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
10671078
'cs3darknet_focus_l.c2ns_in1k': _cfg(
10681079
hf_hub_id='timm/',
1080+
model_id='cs3darknet_focus_l.c2ns_in1k',
10691081
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
10701082
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
10711083
'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),
10721084

10731085
'cs3sedarknet_l.c2ns_in1k': _cfg(
10741086
hf_hub_id='timm/',
1087+
model_id='cs3sedarknet_l.c2ns_in1k',
10751088
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
10761089
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
10771090
'cs3sedarknet_x.c2ns_in1k': _cfg(
10781091
hf_hub_id='timm/',
1092+
model_id='cs3sedarknet_x.c2ns_in1k',
10791093
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
10801094
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
10811095

10821096
'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),
10831097

10841098
'cs3edgenet_x.c2_in1k': _cfg(
10851099
hf_hub_id='timm/',
1100+
model_id='cs3edgenet_x.c2_in1k',
10861101
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
10871102
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
10881103
'cs3se_edgenet_x.c2ns_in1k': _cfg(
10891104
hf_hub_id='timm/',
1105+
model_id='cs3se_edgenet_x.c2ns_in1k',
10901106
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
10911107
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
10921108
})
@@ -1204,4 +1220,4 @@ def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
12041220

12051221
@register_model
12061222
def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
1207-
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
1223+
return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)

0 commit comments

Comments
 (0)