-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Add CSATv2 Models #2624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add CSATv2 Models #2624
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Fix JIT compilation failure by replacing axis with dimThe PyTorch JIT compiler does not support the axis argument alias for torch.cat (and other operations), causing CI tests to fail with Keyword argument axis unknown. I have replaced all instances of axis with dim to ensure the model is scriptable and compatible with the test suite. |
[Fix] Enable JIT compilation support for CSATv2Several RuntimeError and jit.script compatibility issues in the CSATv2 model have been fixed. Note to ReviewersApologies for the multiple fixes required in this area. This is my first time working with TorchScript/JIT compatibility, so I missed some of the strict static analysis requirements in the initial implementation. I have verified the fix with a local test script. Detailed ChangesTransformerBlock:Fixed a logic error in init where nested if statements prevented the else block from executing. Explicitly initialized unused attributes (self.proj, self.pool1, etc.) as nn.Identity() when downsample=False. This fixes the "Module has no attribute" error during static analysis. LayerNorm:Changed the elif block to else in forward. This guarantees that the function always returns a Tensor, resolving the Expected Tensor but found Optional[Tensor] error. Block:Replaced dynamic instantiation of nn.UpsamplingBilinear2d inside forward with F.interpolate. JIT does not support creating module instances within script functions. PreNorm:Removed **kwargs from forward to comply with JIT's strict argument typing. Attention:Replaced map and lambda with standard torch operations and explicit loops/reshaping, as JIT does not support Python lambdas. |
|
FWIW you can run the tests locally on just the models with
|
Hello,
As mentioned in the related issue (#2622),
the CSATv2 model worked only in the HuggingFace environment and failed to load in a standard timm setup.
I updated the code, so that CSATv2 loads correctly through the timm registry.
Changed
Regarding model definition / pretrained_cfg / bits
Validation
Result (Imagenet 1K)
If further adjustments are needed, I’m happy to revise the PR.
Thank you!