Skip to content

Conversation

@gusdlf93
Copy link

@gusdlf93 gusdlf93 commented Dec 6, 2025

Hello,
As mentioned in the related issue (#2622),
the CSATv2 model worked only in the HuggingFace environment and failed to load in a standard timm setup.

I updated the code, so that CSATv2 loads correctly through the timm registry.

Changed

  • Ensured that timm.create_model("csatv2") works without errors

Regarding model definition / pretrained_cfg / bits

  • I reviewed the maintainer’s comment and updated these parts to align with the timm API as best as I understood.

Validation

  • Model loads successfully in timm environment
  • train.py, validate.py works without any issues

Result (Imagenet 1K)

Model Acc@1 Acc@5 FLOPs#G MACs#G Params#M
csatv2 80.02% 94.9 2.77 1.38 11.1 M

If further adjustments are needed, I’m happy to revise the PR.
Thank you!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gusdlf93
Copy link
Author

gusdlf93 commented Dec 7, 2025

Fix JIT compilation failure by replacing axis with dim

The PyTorch JIT compiler does not support the axis argument alias for torch.cat (and other operations), causing CI tests to fail with Keyword argument axis unknown.

I have replaced all instances of axis with dim to ensure the model is scriptable and compatible with the test suite.

@gusdlf93
Copy link
Author

gusdlf93 commented Dec 8, 2025

[Fix] Enable JIT compilation support for CSATv2

Several RuntimeError and jit.script compatibility issues in the CSATv2 model have been fixed.
The model now successfully passes the torch.jit.script(model) test and produces the same output as before.

Note to Reviewers

Apologies for the multiple fixes required in this area. This is my first time working with TorchScript/JIT compatibility, so I missed some of the strict static analysis requirements in the initial implementation. I have verified the fix with a local test script.

Detailed Changes

TransformerBlock:

Fixed a logic error in init where nested if statements prevented the else block from executing.

Explicitly initialized unused attributes (self.proj, self.pool1, etc.) as nn.Identity() when downsample=False. This fixes the "Module has no attribute" error during static analysis.

LayerNorm:

Changed the elif block to else in forward. This guarantees that the function always returns a Tensor, resolving the Expected Tensor but found Optional[Tensor] error.

Block:

Replaced dynamic instantiation of nn.UpsamplingBilinear2d inside forward with F.interpolate. JIT does not support creating module instances within script functions.

PreNorm:

Removed **kwargs from forward to comply with JIT's strict argument typing.

Attention:

Replaced map and lambda with standard torch operations and explicit loops/reshaping, as JIT does not support Python lambdas.

@rwightman
Copy link
Collaborator

FWIW you can run the tests locally on just the models with

pytest -vv tests/test_models.py -k csatv2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants