Skip to content

Add AdaMSS tuner with Adaptive Subspace Allocation (ASA)#2987

Open
LonglongaaaGo wants to merge 8 commits intohuggingface:mainfrom
LonglongaaaGo:adamss
Open

Add AdaMSS tuner with Adaptive Subspace Allocation (ASA)#2987
LonglongaaaGo wants to merge 8 commits intohuggingface:mainfrom
LonglongaaaGo:adamss

Conversation

@LonglongaaaGo
Copy link

Paper title: AdaMSS: Adaptive Multi-Subspace Approach for Parameter-Efficient Fine-Tuning
Paper: https://neurips.cc/virtual/2025/loc/san-diego/poster/119606
Github page: https://github.com/jzheng20/AdaMSS/tree/main

AdaMSS Fine-tuning

Introduction

AdaMSS (Adaptive Matrix Decomposition with Subspace Selection) is a parameter-efficient fine-tuning method that decomposes weight matrices using SVD into low-rank subspaces. It uses only ~0.07% of original trainable parameters (e.g., 59K for ViT-Base vs 86M full fine-tuning) while maintaining competitive performance.

The method optionally supports ASA (Adaptive Subspace Allocation) for dynamic subspace selection during training, further improving efficiency and performance.

See the paper for more details.

Installation & Quick Test

Install from local source:

cd peft-main && pip install -e .
pip install transformers datasets torch torchvision evaluate accelerate

Verify installation:

python -c "from peft import AdaMSSConfig, ASACallback; print('AdaMSS ready')"

Detailed Code Explanation

Core AdaMSS Configuration:

from peft import AdaMSSConfig, get_peft_model, ASACallback

# Configure AdaMSS with ASA
config = AdaMSSConfig(
    r=100,                          # SVD rank (full decomposition rank)
    num_subspaces=10,               # Number of subspaces (K) - initial capacity
    subspace_rank=3,                # Rank per subspace (ri) - use 1 for NLU, 3 for Vision
    target_modules=["query", "value"],  # Target attention layers
    use_asa=True,                   # Enable Adaptive Subspace Allocation
    target_kk=5,                    # Target active subspaces (ASA reduces K→5)
    modules_to_save=["classifier"], # Modules to train without decomposition
)
peft_model = get_peft_model(model, config)

ASA Callback Setup:

asa_callback = ASACallback(
    target_kk=5,            # Gradually mask to 5 active subspaces
    init_warmup=50,         # Start ASA after 50 steps (Vision) or 5 epochs (NLU)
    final_warmup=1000,      # Complete masking by step 1000 (Vision) or epoch 95 (NLU)
    mask_interval=100,      # Update mask every 100 steps (Vision) or 10 epochs (NLU)
    verbose=True,           # Print ASA progress
)

# Integrate with Trainer
trainer = Trainer(
    model=peft_model,
    callbacks=[asa_callback],  # Add ASA callback
    # ... other arguments
)

Key Points:

  • Parameterization: Total params = r × (d_in + d_out), split into K subspaces of rank ri each
  • ASA Mechanism: Dynamically selects target_kk most important subspaces from initial num_subspaces
  • Warmup Schedule: ASA gradually increases masking strength from init_warmup to final_warmup
  • Vision vs NLU: Use subspace_rank=3 for vision, subspace_rank=1 for NLU tasks

Use the training example scripts

Vision Tasks (Image Classification)

Run the provided script with your configuration:

python examples/adamss_finetuning/image_classification_adamss_asa.py \
    --model_name_or_path google/vit-base-patch16-224-in21k \
    --dataset_name cifar10 \
    --adamss_r 100 \
    --adamss_k 10 \
    --adamss_ri 3 \
    --use_asa \
    --target_kk 5 \
    --output_dir ./output

NLU Tasks (GLUE Benchmark)

Run GLUE tasks (e.g., CoLA) with ASA:

python examples/adamss_finetuning/glue_adamss_asa_example.py \
    --dataset_name cola \
    --adamss_r 100 \
    --adamss_k 10 \
    --adamss_ri 1 \
    --use_asa \
    --target_kk 5 \
    --num_epochs 100 \
    --batch_size 32 \
    --output_dir ./output_cola_asa

Without ASA (fixed K=10):

python examples/adamss_finetuning/glue_adamss_asa_example.py \
    --dataset_name cola \
    --adamss_r 100 \
    --adamss_k 10 \
    --adamss_ri 1 \
    --num_epochs 100 \
    --batch_size 32 \
    --output_dir ./output_cola_no_asa

AdaMSSConfig Parameters

Parameter Type Default Description
r int 100 SVD decomposition rank
num_subspaces int 10 Number of subspaces (K)
subspace_rank int 3 Rank per subspace (ri)
target_modules list - Modules to apply AdaMSS (e.g., ["query", "value"])
use_asa bool False Enable Adaptive Subspace Allocation
target_kk int None Target active subspaces when ASA enabled
modules_to_save list None Modules to train without decomposition

ASACallback Parameters

Parameter Type Default Description
target_kk int - Target number of active subspaces
init_warmup int 50 Steps before starting masking
final_warmup int 1000 Steps to reach target active subspaces
mask_interval int 100 Steps between subspace selection updates
beta1 float 0.85 EMA decay for importance tracking
beta2 float 0.85 EMA decay for uncertainty tracking

Experimental Results

NLU Tasks (GLUE Benchmark)

Results with AdaMSS + ASA (100 epochs, seed=0):

Task Model AdaMSS Params Metric Score
CoLA RoBERTa-base 27.0K (ASA K→5) Matthews 0.6466
CoLA RoBERTa-large 64.8K (ASA K→5) Matthews 0.7093
MRPC RoBERTa-base 27.2K (ASA K→5) Accuracy 0.8824
MRPC RoBERTa-large 66.7K (ASA K→5) Accuracy 0.9044

Notes:

  • Configuration: r=100, K=10→5 (ASA), ri=1
  • AdaMSS active params with ASA (5 out of 10 subspaces selected)
  • Full AdaMSS capacity: 97K (large) / 42K (base)
  • Training: 100 epochs, batch_size=32, warmup_ratio=0.06

Vision Tasks (Image Classification)

Results with AdaMSS on Stanford Cars (10 epochs, seed=0):

Model Method AdaMSS Params Test Accuracy
ViT-Base AdaMSS (no ASA) 121K (K=10) 82.15%
ViT-Base AdaMSS + ASA 75.0K (K→5) 80.45%

Notes:

  • Configuration: r=100, K=10, ri=3, 10 epochs, batch_size=32
  • ASA dynamically selects 5 out of 10 subspaces (75K active from 121K total)

Citation

If you use AdaMSS in your research, please cite:

@inproceedings{zheng2025adamss,
  title={AdaMSS: Adaptive Multi-Subspace Approach for Parameter-Efficient Fine-Tuning},
  author={Zheng, Jingjing and Lu, Wanglong and Dong, Yiming and Ji, Chaojie and Cao, Yankai and Lin, Zhouchen},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
}

Reference

@LonglongaaaGo
Copy link
Author

LonglongaaaGo commented Jan 10, 2026

Cleaned version of previous PR: #2967
Hey @BenjaminBossan, I was working on code cleaning for a while, and the previous one was a little messy, so could you help review this cleaned PR?
The code is ready for review!
Thank you so much!!

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reworking the PR to add AdaMSS to PEFT (for other readers, the paper can be found here: https://openreview.net/forum?id=1cjLvtFOmL).

My first comment is the same as in the previous PR: I would strongly suggest to rename all the classes from AdaMSS to Adamss, which is much easier to type and more consistent with the rest of PEFT (e.g. LoraLayer).

In this review, I focused on the PEFT integration mostly. There, I have found quite a few things we need to improve, especially around how we handle multiple adapters. Please check my comments.

Then, as a next step, we should set up the first couple of unit tests to ensure that the adapter works as expected. We can start with test_custom_models.py and add more tests later. For this, please check how other PEFT methods do this:

###########
# BD-LoRA #
###########
(
"BD-LoRA A only",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_a=["lin0"], nblocks=2, match_strict=False),
},
),
(
"BD-LoRA B only",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_b=["lin1"], nblocks=2, match_strict=False),
},
),
(
"BD-LoRA both A and B",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_a=["lin0"], target_modules_bd_b=["lin1"], nblocks=2),
},
),

This adds the PEFT method to a test matrix and should cover the majority of PEFT functionality. Then you can run pytest tests/test_custom_models.py -k 'adamss' to run all AdaMSS tests and ensure that they pass.

@@ -0,0 +1,38 @@
# Copyright 2024-present the HuggingFace Inc. team.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright 2024-present the HuggingFace Inc. team.
# Copyright 2026-present the HuggingFace Inc. team.

Here and in every newly added file.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

import torch


class ASACallback(TrainerCallback):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, I'd prefer AdamssCallback so that it's immediately obvious they belong together, but AsaCallback is also oky.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I will change to AdamssASACallback

self.verbose = verbose

# Sanity checks
assert 0 < beta1 < 1, f"beta1 must be in (0, 1), got {beta1}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's raise proper ValueErrors, asserts are restricted to tests.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! done!

adapter_name = list(module.KK.keys())[0]
self.total_kk = module.KK[adapter_name]
self._collected_total_kk = True
print(f"ASA: Detected total_kk = {self.total_kk} subspaces")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid printing info like this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

if actual_rank == 0:
actual_rank = 1

print(f" [INFO] Subspace {i}: dynamic rank = {actual_rank} (threshold {svd_threshold} from {len(S_row)} row singular values)")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

if actual_rank == 0:
actual_rank = 1

# print(f" [INFO] Subspace {i}: fixed-rank = {actual_rank} (seg_indices={len(seg_indices)})")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

# to match adamss_pkg behavior (using accumulated gradients).

# Compute newindex for forward pass
self.newindex[adapter_name] = np.concatenate(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use torch tensors

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

axis=0
)

def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should not be necessary to have this method. Either we use the parent class implementation (if more than 1 adapter is allowed) or else we already need to check when add_adapter or load_adapter is called that there is only a single adapter.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done! I directly apply the parent set_adapter function.

# Register A and B parameters
# A maps from r (full SVD rank) dimensions to actual_rank dimensions
# Shape: (actual_rank, r) - matches adamss_pkg structure
self.adamss_A[f"{adapter_name}_A_{i}"] = nn.Parameter(A_init.to(dtype))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, the "A" in the parameter name is redundant, we already know that this is self.adamss_A. Second, the i should not be part of the key, the key should just be the adapter_name. I think this should be refactored like so:

self.adamss_A[adapter_name] = nn.ParameterList()
for i in ...:
    ...
    self.adamss_A[adapter_name][i] = A_init.to(dtype)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, thanks for the guidance! I will finish the revision and pass the tests as soon as possible.

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, I have implemented your suggestions, except for the verbose prints. The test cases are passing now (feel free to run them locally if needed). Regarding the verbose prints, I will remove them once the logic is finalized. Thanks again for the help!!

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, any advice here? Thank you again for the help!

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, I have removed the print code as well. Could you pls take a review and if it meets the requirements, could you merge the code? Thank you so much for the help!!!

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the updates, the shape of the PR looks much better now. Still, I found a couple of places that I think can be improved. Overall, I really want to reduce the extra complexity and rely as much as possible of what's already present in PEFT. Please check.

Also, before pushing your changes, ensure to run make style.


# Validate warmup schedule
if self.total_steps and self.final_warmup > self.total_steps:
import warnings
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make the import global

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

self.beta2 = beta2
self.total_steps = total_steps
self.tt = tt
self.verbose = verbose
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove self.verbose = verbose.

set_seed,
)

from peft import AdaMSSConfig, get_peft_model
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from peft import AdaMSSConfig, get_peft_model
from peft import AdamssConfig, get_peft_model

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!


# Critical: Rebuild optimizer to sync requires_grad changes
if self.trainer is not None:
self.trainer.create_optimizer_and_scheduler(self.trainer.num_training_steps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if recreating the optimizer and scheduler on each optimizer step is not asking for trouble. I'm not too familiar with Trainer, so I'm not totally sure, but I think it's not working as it should. First of all, this call doesn't actually rebuild the optimizer if the optimizer already exists:

https://github.com/huggingface/transformers/blob/f73a4db3a0bcf6523e9bfdaaf4afe81dffba4da8/src/transformers/trainer.py#L1023

Second, even if it did, would it be correct? Say we use Adam, the optimizer stores the update moments. If the optimizer is recreated, those are lost, meaning Adam no longer works as expected.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. You're absolutely right on both points:

  1. create_optimizer_and_scheduler is a no-op when the optimizer already exists (source), so this call was doing nothing.
    Even if it did recreate the optimizer, it would destroy the Adam momentum states, breaking the optimizer's behavior.
  2. I've removed the create_optimizer_and_scheduler call entirely and also removed the self.trainer reference to avoid circular references. The masking in _mask_model_to_target
    only sets requires_grad=False on pruned subspace parameters — the existing optimizer simply skips zero-grad parameters naturally, so no optimizer rebuild is needed.

from .layer import AdamssLayer


class AdamssASACallback(TrainerCallback):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class AdamssASACallback(TrainerCallback):
class AdamssAsaCallback(TrainerCallback):

To be consistent and for easier typing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! done!


# Copy state for modules to save
if hasattr(new_module, "base_layer"):
new_module.base_layer.load_state_dict(child.state_dict(), strict=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to what I wrote above, I'm not sure if this is needed. When I comment out this method, the unit tests still pass. This means that either it's not needed, or the unit tests are missing something. If the latter is true, let's add a unit test to show when it's needed.

# Create estimated seg_result for metadata
estimated_seg_size = max(1, out_features // num_subspaces)
self.seg_result[adapter_name] = {
i: np.arange(i * estimated_seg_size, min((i + 1) * estimated_seg_size, out_features))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use torch and not numpy.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Done. The _replace_module override has been removed entirely. The parent class BaseTuner._replace_module handles this correctly. All unit tests pass without it.
  2. Done. Replaced np.arange with torch.arange and removed the numpy import from layer.py.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that the old version was overly broad — I've removed the AdamssLayer override entirely and simplified the Linear one to a single concern:

Why it's needed: AdaMSS B parameter shapes depend on KMeans clustering of the weight matrix. When loading with low_cpu_mem_usage=True, update_layer runs inside init_empty_weights() on meta tensors, producing different clustering → different B shapes. The override detects these shape mismatches and replaces placeholders before the default load_state_dict runs.

Which test covers it: test_load_model_low_cpu_mem_usage — it fails without this override:
RuntimeError: size mismatch for base_model.model.lin0.adamss_B.other.0: copying a param with shape torch.Size([4, 1]) from checkpoint, the shape in current model is torch.Size([7, 1]).

kmeans = KMeans(n_clusters=effective_num_subspaces, init='random', n_init=1, max_iter=iternum, random_state=123456789)
idx = kmeans.fit_predict(vt)

return [idx], [effective_num_subspaces]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's return torch tensors here, also really no need for lists with a single item, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. clustering_Z now returns (torch.LongTensor, int) directly instead of single-item lists. Also simplified seg_locations and renamed get_trainable_subspaces_all → get_trainable_subspaces to remove unnecessary list wrapping throughout. All callers updated accordingly.

K = int(index[ii].max().item()) + 1
location = []
for i in range(K):
arr = np.where(index[ii] == i)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use torch.where.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. seg_locations now uses torch.where instead of np.where, and the numpy import has been removed from utils.py entirely.

# Special case for AdaMSS: A and B parameters may not get significant gradients with B=0 init
# The gradient dL/dA = (dL/dy) * B^T = 0 when B=0, so A stays unchanged initially
# Similarly, B gradients may be very small depending on layer configuration
# since the adapter output is 0 when B=0, affecting gradient magnitudes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of skipping, can the test be updated e.g. by increasing the learning rate for Adamss?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. The test now uses a higher learning rate (lr=1.0) for AdaMSS. However, due to the B=0 initialization, individual A/B parameters may still remain near-zero even with high LR after just 2 training steps (B updates in step 1, A only starts getting gradients in step 2). Rather than risking flaky assertions, we skip the strict allclose check for A/B and rely on other tests (merge/unmerge correctness, forward output changes, and the new ASA-specific tests in test_adamss_asa.py) to verify parameter updates.

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, I have revised the code based on your advice, and the make style command has been executed before sunmision, could you help take a look see if it meets the requirements? Thank you!

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for these improvements to the PR, the code is now simplified, more readable, and better tested.

I still found a couple of issues, please check my comments. As a more general comment, if you use a coding agent, please ensure to clean up after it (e.g. I saw comments and actual changes not corresponding, divergence from the existing coding practices of the project, unnecessary checks being added etc., which coding agents are prone to do).

param_after = params_after[name]
if (model.prefix in name) or ("modules_to_save" in name) or ("token_adapter.trainable_tokens" in name):
# target_modules, modules_to_save and modules of `NewTokensWrapper` _are_ updated
# Special case for AdaMSS: use a higher LR to overcome B=0 init issue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of skipping this check, we should ensure that adamss is trained sufficiently to pass the check. E.g. if the number of epochs is too low, we could increase it (but only for adamss so that other unit tests are not slowed down).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The skip was caused by two issues with AdaMSS's initialization:

B=0 initialization: AdaMSS initializes B to zeros ("orthogonal" mode), so ∂L/∂A = ∂L/∂output @ B = 0 — A never receives gradients. Fixed by calling

set_init_weights_false()
(same pattern as other tests) to give B small random values.

ReLU dead zones: The default test input torch.arange(90) is deterministic, and certain subspace scatter indices map to output dimensions that are always negative after the base linear layer, causing ReLU to zero the gradient for those subspaces. Fixed by using torch.randn inputs for AdaMSS.

Both fixes are AdaMSS-specific — other adapters are unaffected.

# Then update exp_avg_ipt
exp_avg_ipt[key].mul_(importance_beta).add_(ipt, alpha=1 - importance_beta)

def mask_to_target(self, adapter_name: str, asa_target_subspaces: int, verbose: bool = False) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def mask_to_target(self, adapter_name: str, asa_target_subspaces: int, verbose: bool = False) -> None:
def mask_to_target(self, adapter_name: str, asa_target_subspaces: int) -> None:

Also, is this function called at all? If not, please remove it completely.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted, thank you!

Comment on lines 58 to 59
self.exp_avg_ipt = {} # Exponential moving average of importance
self.exp_avg_unc = {} # Exponential moving average of uncertainty
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly prefer doing this change in this PR. I think the resulting code will be simpler, allowing this PR to be more and not less focused. It also doesn't make much sense to move a refactor of yet unmerged code to a separate PR.

x7 = x7.scatter(-1, index, x6)

# Add this adapter's delta
result = result + x7
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During training, I agree that it won't make much of a difference as intermediate tensors must be stored for the backwards pass. However, during inference, this is not the case. Keeping variables around prevents Python from decrementing the reference count. So this code:

def forward(self, x):
    x1 = foo(x)
    x2 = bar(x1)
    return x2

and this code:

def forward(self, x):
    x = foo(x)
    x = bar(x)
    return x

are not equivalent. So my suggestion is to reassign the variable (I didn't mean to explicitly del the intermediate variables). If you think this isn't good for readability, I would indeed prefer more explicit names as you suggested instead of just numbers.

"""Called after optimizer.step() – delegates to model.update_and_allocate()."""
model = kwargs.get("model")
if model is None:
return control
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we call super?

Suggested change
return control
return super().on_optimizer_step(args=args, state=state, control=control)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I strongly prefer doing this change in this PR. I think the resulting code will be simpler, allowing this PR to be more and not less focused. It also doesn't make much sense to move a refactor of yet unmerged code to a separate PR. this one is done!
  2. During training, I agree that it won't make much of a difference as intermediate tensors must be stored for the backwards pass. However, during inference, this is not the case. Keeping variables around prevents Python from decrementing the reference count. This one is done!
  3. Shouldn't we call super? Done!

Comment on lines 497 to 509
first_active_adapter = None
for adapter in self.active_adapters:
if adapter in self.adamss_A:
first_active_adapter = adapter
break

if first_active_adapter is None:
# No active adapters, return base layer output
return self.base_layer(x, *args, **kwargs)

# Compute base output from residual weight (frozen original weight)
resW = self.adamss_resW[first_active_adapter].to(self.dtype)
result = F.linear(newx, resW)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me why we need to treat the first active adapter differently. Also, below we might apply the same adapter again. Is that correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first active adapter is NOT treated differently for the trainable part. resW is the frozen original weight — it's identical for all adapters (stored per-adapter in BufferDict only for device management). The first_active_adapter lookup just finds any valid key to retrieve this shared weight.

The loop below applies every active adapter's trainable A/B delta (including the first), but does NOT re-apply resW. The computation is:

output = resW @ x + Σ_adapter scatter(B_i @ A_i @ newB @ x)
I've updated the comment to clarify:

resW is the frozen original weight — identical for all adapters,
just need any valid adapter key to retrieve it from the BufferDict.

Comment on lines 193 to 194
if adapter_name not in self.peft_config:
continue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When can this happen?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! done!


def _seed_b_params(model):
"""
Give B parameters small non-zero values so that gradients flow to A.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be covered by AdamssConfig(..., init_weights=False)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Replaced _seed_b_params() with init_weights=None in the test config. This is cleaner — the config handles non-zero B initialization natively instead of manually seeding after model creation.

for layer in layers:
assert len(layer.exp_avg_ipt["default"]) == 0, "No importance accumulation should happen outside warmup"

def test_all_params_trainable_initially(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should already be covered by existing tests.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done!

# -----------------------------------------------------------------------
# Test: update_importance populates EMA scores
# -----------------------------------------------------------------------
class TestUpdateImportance:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move all tests to a single test class, having multiple here is overkill IMO. Ensure to have "Adamss" in the name of the class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged TestUpdateImportance, TestResetImportance, and TestUpdateAndAllocate into a single TestAdamssAsa class. Also removed the redundant test_all_params_trainable_initially (covered by existing tests).

@LonglongaaaGo
Copy link
Author

Hey @BenjaminBossan, I have revised the code based on your suggestions. Let me know if you have more questions.
Thank you so much for your help!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants