diff --git a/lib/model_zoo/ddim.py b/lib/model_zoo/ddim.py index 13995df..c0785d6 100644 --- a/lib/model_zoo/ddim.py +++ b/lib/model_zoo/ddim.py @@ -92,7 +92,7 @@ def ddim_sampling(self, bs = shape[0] timesteps = self.ddim_timesteps if ('xt' in x_info) and (x_info['xt'] is not None): - xt = x_info['xt'].astype(dtype).to(device) + xt = x_info['xt'].type(dtype).to(device) x_info['x'] = xt elif ('x0' in x_info) and (x_info['x0'] is not None): x0 = x_info['x0'].type(dtype).to(device)