Skip to content

Fix operator precedence bug in lvlb_weights and mutable default arguments#108

Open
Mr-Neutr0n wants to merge 1 commit intoAILab-CVC:mainfrom
Mr-Neutr0n:fix/operator-precedence-and-mutable-defaults
Open

Fix operator precedence bug in lvlb_weights and mutable default arguments#108
Mr-Neutr0n wants to merge 1 commit intoAILab-CVC:mainfrom
Mr-Neutr0n:fix/operator-precedence-and-mutable-defaults

Conversation

@Mr-Neutr0n
Copy link

Summary

  • Operator precedence fix in lvdm/models/ddpm3d.py: The x0-parameterization branch of lvlb_weights computes (2. * 1 - torch.Tensor(alphas_cumprod)), which due to Python's operator precedence evaluates as (2.0 - tensor) instead of the intended (2.0 * (1 - tensor)). This produces incorrect loss weighting. Fixed by adding explicit parentheses.

  • NaN assertion fix in lvdm/models/ddpm3d.py: assert not torch.isnan(self.lvlb_weights).all() only triggers when every element is NaN, silently allowing partial NaN corruption. Changed to .any() so any NaN value is caught.

  • Mutable default argument fix in DDPM.__init__ (ignore_keys=[]), DDPM.init_from_ckpt (ignore_keys=list()), and AutoencoderKL.init_from_ckpt (ignore_keys=list()): Mutable defaults are shared across all calls, which can lead to subtle bugs if the list is ever modified. Replaced with None and a guard clause.

Files changed

  • lvdm/models/ddpm3d.py
  • lvdm/models/autoencoder.py

Test plan

  • Verify x0-parameterization lvlb_weights values are numerically correct after the parenthesization fix
  • Confirm NaN assertion now catches partial NaN in lvlb_weights
  • Verify ignore_keys default behavior is unchanged (empty list when not provided)

…ents

- Fix operator precedence in x0-parameterization lvlb_weights calculation:
  `(2. * 1 - torch.Tensor(alphas_cumprod))` incorrectly evaluates as
  `(2.0 - tensor)` due to multiplication binding tighter than subtraction.
  Changed to `(2. * (1 - torch.Tensor(alphas_cumprod)))` for correct scaling.

- Fix NaN assertion using `.all()` instead of `.any()`: the original check
  `assert not isnan(...).all()` only fails when every element is NaN,
  silently allowing partial NaN corruption. Changed to `.any()` to catch
  any NaN values.

- Replace mutable default arguments (`ignore_keys=[]` and
  `ignore_keys=list()`) with `None` + guard clause in DDPM.__init__,
  DDPM.init_from_ckpt, and AutoencoderKL.init_from_ckpt to prevent
  shared state across calls.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant