Skip to content

Commit 35fd84b

Browse files
dg845, sayakpaul, yiyixuxu, and DN6
authored
Replace hardcoded values in SchedulerCommonTest with properties (#5479)
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent f275625 commit 35fd84b

File tree

1 file changed

+68
-9
lines changed

1 file changed

+68
-9
lines changed

tests/schedulers/test_schedulers.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
LMSDiscreteScheduler,
3737
UniPCMultistepScheduler,
3838
VQDiffusionScheduler,
39-
logging,
4039
)
4140
from diffusers.configuration_utils import ConfigMixin, register_to_config
4241
from diffusers.schedulers.scheduling_utils import SchedulerMixin
42+
from diffusers.utils import logging
4343
from diffusers.utils.testing_utils import CaptureLogger, torch_device
4444

4545
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -48,6 +48,9 @@
4848
torch.backends.cuda.matmul.allow_tf32 = False
4949

5050

51+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52+
53+
5154
class SchedulerObject(SchedulerMixin, ConfigMixin):
5255
config_name = "config.json"
5356

@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase):
253256
scheduler_classes = ()
254257
forward_default_kwargs = ()
255258

259+
@property
def default_num_inference_steps(self):
    """Default ``num_inference_steps`` used by scheduler tests that do not
    specify one in ``forward_default_kwargs``."""
    return 50
262+
263+
@property
def default_timestep(self):
    """Default timestep for scheduler tests.

    Instantiates the first scheduler class with the test's default config,
    calls ``set_timesteps`` with the default number of inference steps, and
    returns the first entry of ``scheduler.timesteps``. Falls back to ``1``
    (with a warning) when the test class does not implement
    ``get_scheduler_config``.
    """
    # `forward_default_kwargs` may override the number of inference steps.
    kwargs = dict(self.forward_default_kwargs)
    num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)

    try:
        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        timestep = scheduler.timesteps[0]
    except NotImplementedError:
        # Lazy %-formatting: the message is only rendered if the warning is emitted.
        logger.warning(
            "The scheduler %s does not implement a `get_scheduler_config` method."
            " `default_timestep` will be set to the default value of 1.",
            self.__class__.__name__,
        )
        timestep = 1

    return timestep
282+
283+
# NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
# default_timestep comes earlier in the timestep schedule than default_timestep_2)
@property
def default_timestep_2(self):
    """Second default timestep for scheduler tests (the second entry of
    ``scheduler.timesteps``).

    Falls back to ``0`` (with a warning) when the schedule has fewer than two
    timesteps or when the test class does not implement
    ``get_scheduler_config``.
    """
    # `forward_default_kwargs` may override the number of inference steps.
    kwargs = dict(self.forward_default_kwargs)
    num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)

    try:
        scheduler_config = self.get_scheduler_config()
        scheduler = self.scheduler_classes[0](**scheduler_config)

        scheduler.set_timesteps(num_inference_steps)
        if len(scheduler.timesteps) >= 2:
            timestep_2 = scheduler.timesteps[1]
        else:
            # Lazy %-formatting: len() is only interpolated if the warning is emitted.
            logger.warning(
                "Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
                " scheduler of length %d < 2. The default `default_timestep_2` value of 0"
                " will be used.",
                len(scheduler.timesteps),
            )
            timestep_2 = 0
    except NotImplementedError:
        logger.warning(
            "The scheduler %s does not implement a `get_scheduler_config` method."
            " `default_timestep_2` will be set to the default value of 0.",
            self.__class__.__name__,
        )
        timestep_2 = 0

    return timestep_2
312+
256313
@property
257314
def dummy_sample(self):
258315
batch_size = 4
@@ -313,6 +370,7 @@ def check_over_configs(self, time_step=0, **config):
313370
kwargs = dict(self.forward_default_kwargs)
314371

315372
num_inference_steps = kwargs.pop("num_inference_steps", None)
373+
time_step = time_step if time_step is not None else self.default_timestep
316374

317375
for scheduler_class in self.scheduler_classes:
318376
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
@@ -371,6 +429,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
371429
kwargs.update(forward_kwargs)
372430

373431
num_inference_steps = kwargs.pop("num_inference_steps", None)
432+
time_step = time_step if time_step is not None else self.default_timestep
374433

375434
for scheduler_class in self.scheduler_classes:
376435
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
@@ -411,10 +470,10 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
411470
def test_from_save_pretrained(self):
412471
kwargs = dict(self.forward_default_kwargs)
413472

414-
num_inference_steps = kwargs.pop("num_inference_steps", None)
473+
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
415474

416475
for scheduler_class in self.scheduler_classes:
417-
timestep = 1
476+
timestep = self.default_timestep
418477
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
419478
timestep = float(timestep)
420479

@@ -497,10 +556,10 @@ def test_from_pretrained(self):
497556
def test_step_shape(self):
498557
kwargs = dict(self.forward_default_kwargs)
499558

500-
num_inference_steps = kwargs.pop("num_inference_steps", None)
559+
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
501560

502-
timestep_0 = 1
503-
timestep_1 = 0
561+
timestep_0 = self.default_timestep
562+
timestep_1 = self.default_timestep_2
504563

505564
for scheduler_class in self.scheduler_classes:
506565
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
@@ -558,9 +617,9 @@ def recursive_check(tuple_object, dict_object):
558617
)
559618

560619
kwargs = dict(self.forward_default_kwargs)
561-
num_inference_steps = kwargs.pop("num_inference_steps", 50)
620+
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
562621

563-
timestep = 0
622+
timestep = self.default_timestep
564623
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
565624
timestep = 1
566625

@@ -644,7 +703,7 @@ def test_add_noise_device(self):
644703
continue
645704
scheduler_config = self.get_scheduler_config()
646705
scheduler = scheduler_class(**scheduler_config)
647-
scheduler.set_timesteps(100)
706+
scheduler.set_timesteps(self.default_num_inference_steps)
648707

649708
sample = self.dummy_sample.to(torch_device)
650709
if scheduler_class == CMStochasticIterativeScheduler:

0 commit comments

Comments
 (0)