Better convolve1d implementation on jax and torch#52
Conversation
|
@mlz-EM can you review when you get a chance? Would be helpful if you could benchmark layers regularization on jax before and after, make sure it's not significantly slower |
Ealier benchmark was against main and speedup is culumative changes in the develop branch. Updated is against develop branch. And also layer blur only benchmark without full reconstruction, which shows the expected slowdown with larger sigma. The PR speedup is real, but the time for layer blur is neglible comparing to the iteration time
|




convolve1dusesconv_general_dilateddirectly, which should dispatch to a 1D convolution rather than a 3D convolution with kernel size (1, 1, N).xp.padhas a fallback for casestorch.nn.functional.paddoesn't supportconvolve1dnow usesxp.pad. Importantly, kernel size >= object size now worksxp.pad