diff --git a/acestep/apg_guidance.py b/acestep/apg_guidance.py index fbe9a4dc..e49ae547 100644 --- a/acestep/apg_guidance.py +++ b/acestep/apg_guidance.py @@ -25,8 +25,8 @@ def project( v1 = torch.nn.functional.normalize(v1, dim=dims) v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel - return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to( - device_type + return v0_parallel.to(dtype).to(v0.device), v0_orthogonal.to(dtype).to( + v0.device )