@@ -36,10 +36,10 @@
     LMSDiscreteScheduler,
     UniPCMultistepScheduler,
     VQDiffusionScheduler,
-    logging,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import CaptureLogger, torch_device
 
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -48,6 +48,9 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class SchedulerObject(SchedulerMixin, ConfigMixin):
     config_name = "config.json"
 
@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()
     forward_default_kwargs = ()
 
+    @property
+    def default_num_inference_steps(self):
+        return 50
+
+    @property
+    def default_timestep(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
+
+        try:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = self.scheduler_classes[0](**scheduler_config)
+
+            scheduler.set_timesteps(num_inference_steps)
+            timestep = scheduler.timesteps[0]
+        except NotImplementedError:
+            logger.warning(
+                f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
+                f" `default_timestep` will be set to the default value of 1."
+            )
+            timestep = 1
+
+        return timestep
+
+    # NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
+    # default_timestep comes earlier in the timestep schedule than default_timestep_2)
+    @property
+    def default_timestep_2(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
+
+        try:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = self.scheduler_classes[0](**scheduler_config)
+
+            scheduler.set_timesteps(num_inference_steps)
+            if len(scheduler.timesteps) >= 2:
+                timestep_2 = scheduler.timesteps[1]
+            else:
+                logger.warning(
+                    f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
+                    f" schedule of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
+                    f" will be used."
+                )
+                timestep_2 = 0
+        except NotImplementedError:
+            logger.warning(
+                f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
+                f" `default_timestep_2` will be set to the default value of 0."
+            )
+            timestep_2 = 0
+
+        return timestep_2
+
     @property
     def dummy_sample(self):
         batch_size = 4
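
The NOTE above encodes the usual descending-schedule convention: `timesteps[0]` is the latest training timestep (most noise) and `timesteps[1]` the next one down, so `default_timestep > default_timestep_2`. A quick standalone illustration of what the two properties resolve to (not part of the diff; DDIM with 4 inference steps is an arbitrary example):

from diffusers import DDIMScheduler

# With 1000 training timesteps and 4 inference steps, DDIM's schedule is
# [750, 500, 250, 0]: strictly descending, as the convention assumes.
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=4)
print(scheduler.timesteps)  # tensor([750, 500, 250,   0])

default_timestep = scheduler.timesteps[0]    # 750, the "first" test timestep
default_timestep_2 = scheduler.timesteps[1]  # 500, strictly smaller (later in the schedule)
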
@@ -313,6 +370,7 @@ def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
+        time_step = time_step if time_step is not None else self.default_timestep
 
         for scheduler_class in self.scheduler_classes:
             # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
@@ -371,6 +429,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs.update(forward_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
+        time_step = time_step if time_step is not None else self.default_timestep
 
         for scheduler_class in self.scheduler_classes:
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
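
Both `check_over_configs` and `check_over_forward` now use `None` as the opt-in sentinel: since their signature default is still `time_step=0`, the scheduler-derived fallback only fires when a caller passes `time_step=None` explicitly, and `0` keeps working as an ordinary explicit timestep. A contrived, self-contained demo of that sentinel pattern (the names and the 980 value are illustrative, not from this diff):

# Stand-ins for self.default_timestep and the test helper, to show the sentinel logic.
DEFAULT_TIMESTEP = 980


def resolve_time_step(time_step=0):
    # None means "use the derived default"; 0 stays a legitimate explicit value.
    return time_step if time_step is not None else DEFAULT_TIMESTEP


assert resolve_time_step() == 0                  # unchanged default behavior
assert resolve_time_step(time_step=None) == 980  # opt in to the derived default
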
@@ -411,10 +470,10 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
     def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
         for scheduler_class in self.scheduler_classes:
-            timestep = 1
+            timestep = self.default_timestep
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 timestep = float(timestep)
|
@@ -497,10 +556,10 @@ def test_from_pretrained(self):
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
 
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
-        timestep_0 = 1
-        timestep_1 = 0
+        timestep_0 = self.default_timestep
+        timestep_1 = self.default_timestep_2
 
         for scheduler_class in self.scheduler_classes:
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
@@ -558,9 +617,9 @@ def recursive_check(tuple_object, dict_object):
                 )
 
         kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", 50)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
-        timestep = 0
+        timestep = self.default_timestep
         if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
             timestep = 1
|
@@ -644,7 +703,7 @@ def test_add_noise_device(self):
                 continue
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
-            scheduler.set_timesteps(100)
+            scheduler.set_timesteps(self.default_num_inference_steps)
 
             sample = self.dummy_sample.to(torch_device)
             if scheduler_class == CMStochasticIterativeScheduler:
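
Taken together, the diff replaces hard-coded `timestep = 1` / `timestep = 0` / `set_timesteps(100)` literals with values derived from the scheduler under test. For a concrete sense of how a test class feeds the new properties, here is a minimal sketch modeled on the DDPM scheduler tests, assuming `SchedulerCommonTest` from this file is in scope (the config values are illustrative, not taken from this diff):

from diffusers import DDPMScheduler


class DDPMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        # Hook used by `default_timestep` / `default_timestep_2` above; if a
        # subclass omits it, those properties fall back to 1 and 0 and log a warning.
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }
        config.update(**kwargs)
        return config


# With 1000 training timesteps and the shared default of 50 inference steps,
# DDPM's schedule starts at 980 and descends by 20, so `default_timestep`
# resolves to 980 and `default_timestep_2` to 960 for this test class.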
|