updating mse_ens and linting#2371
Conversation
|
I think 1.) should be the default since it is a well-defined, self-contained loss. 2.) requires another loss term to control the higher order moments. |
clessig
left a comment
There was a problem hiding this comment.
Thanks for the implementation. I think we can make the code simpler by using the lp_loss or mse_loss functions.
| weights_channels : shape (num_channels,) or None | ||
| weights_points : shape (num_data_points,) or None | ||
| """ | ||
| mask_nan = ~torch.isnan(target) |
There was a problem hiding this comment.
Please re-use the lp_loss or mse_loss functions below
There was a problem hiding this comment.
Sure, I will push an update
There was a problem hiding this comment.
what do you think the correct way to control the spread when using option 2?
clessig
left a comment
There was a problem hiding this comment.
Thanks, just some minor comments.
|
|
||
| losses, losses_chs = zip( | ||
| *[mse(target, member.unsqueeze(0), weights_channels, weights_points) for member in pred], | ||
| strict=False, |
There was a problem hiding this comment.
Why is strict=False necessary here?
| pred: torch.Tensor, | ||
| weights_channels: torch.Tensor | None, | ||
| weights_points: torch.Tensor | None, | ||
| use_ensemble_mean: bool = False, |
There was a problem hiding this comment.
Can we control this from the config?
Description
Implemented mse_ens with the updated signature with 2 options:
1- applying mse on each ensemble member vs target and then take the mean (default)
2- calculating the mean of all ensembles and then calculate the loss (configurable)
The default one to be decided after discussion
Issue Number
Closes #2250
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint./scripts/actions.sh unit-test./scripts/actions.sh integration-testlaunch-slurm.py --time 60