1919 delete_offload_module ,
2020 delete_offload_parameter ,
2121 disable_hf_hook ,
22+ disable_offloading ,
2223 get_execution_device ,
2324 has_offloaded_params ,
2425 offloaded_dispatch ,
@@ -397,29 +398,37 @@ def test_delete_offload_module(exec_device):
397398
398399@requires_gpu
399400@requires_accelerate ()
400- @pytest .mark .parametrize ("exec_device" , [torch .device ("cpu" ), torch .device ("cuda" )])
401- def test_offloaded_dispatch (exec_device ):
401+ @pytest .mark .parametrize (
402+ "exec_device,offload_device" ,
403+ [
404+ (torch .device ("cpu" ), torch .device ("cpu" )),
405+ (torch .device ("cpu" ), torch .device ("cuda:0" )),
406+ (torch .device ("cuda:0" ), torch .device ("cpu" )),
407+ (torch .device ("cuda:0" ), torch .device ("cuda:0" )),
408+ ],
409+ )
410+ def test_offloaded_dispatch (exec_device , offload_device ):
402411 # single module
403- module = torch .nn .Linear (1 , 2 )
404- module = offloaded_dispatch (module , exec_device )
412+ module = torch .nn .Linear (1 , 2 , device = offload_device )
413+ module = offloaded_dispatch (module , exec_device , offload_device )
405414 assert has_offloaded_params (module )
406415 assert module ._hf_hook .offload
407416 assert module .weight .device == torch .device ("meta" )
408- assert "weight" in module ._hf_hook .weights_map
417+ assert module ._hf_hook .weights_map [ "weight" ]. device == offload_device
409418 assert module ._hf_hook .tied_params_map is not None
410419
411420 # can run
412421 module (torch .empty (1 , device = exec_device ))
413422
414423 # model
415424 model = ExampleModel ()
416- model = offloaded_dispatch (model , exec_device )
425+ model = offloaded_dispatch (model , exec_device , offload_device )
417426 assert not has_offloaded_params (model )
418427
419428 assert has_offloaded_params (model .linear )
420429 assert model .linear ._hf_hook .offload
421430 assert model .linear .weight .device == torch .device ("meta" )
422- assert "weight" in model .linear ._hf_hook .weights_map
431+ assert model .linear ._hf_hook .weights_map [ "weight" ]. device == offload_device
423432 assert model .linear ._hf_hook .tied_params_map is not None
424433
425434 # can run
@@ -429,4 +438,43 @@ def test_offloaded_dispatch(exec_device):
429438 parameter = torch .nn .Parameter (torch .tensor (1.0 ))
430439 register_offload_parameter (module , "new_param" , parameter )
431440 assert module .new_param .device == torch .device ("meta" )
432- assert module ._hf_hook .weights_map ["new_param" ].device == torch .device ("cpu" )
441+ assert module ._hf_hook .weights_map ["new_param" ].device == offload_device
442+
443+
@requires_gpu
@requires_accelerate()
@pytest.mark.parametrize(
    "exec_device,offload_device",
    [
        (torch.device("cpu"), torch.device("cpu")),
        (torch.device("cpu"), torch.device("cuda:0")),
        (torch.device("cuda:0"), torch.device("cpu")),
        (torch.device("cuda:0"), torch.device("cuda:0")),
    ],
)
def test_disable_offloading(exec_device, offload_device):
    """Check that ``disable_offloading`` suspends weight offloading.

    A non-offloaded module must be unaffected by the context. An offloaded
    module's weights must stay on the meta device until first use inside the
    context, remain onloaded on ``exec_device`` for subsequent calls, and be
    returned to ``offload_device`` once the context exits.
    """
    module = torch.nn.Linear(1, 2, device=exec_device)

    # non-offloaded modules are unaffected
    with disable_offloading():
        output = module(torch.empty(1, device=exec_device))
        assert module.weight.device == exec_device
        assert output.device == exec_device

    # after dispatch, the weight lives in the hook's weights_map on the
    # offload device and the module itself only holds a meta tensor
    offloaded_dispatch(module, exec_device, offload_device)
    assert module.weight.device == torch.device("meta")
    assert module._hf_hook.weights_map["weight"].device == offload_device

    # offloaded modules stay on the execution device until context exit
    with disable_offloading():
        # weights are only onloaded lazily, on the first forward call
        assert module.weight.device == torch.device("meta")
        output = module(torch.empty(1, device=exec_device))
        assert module.weight.device == exec_device
        assert output.device == exec_device

        # a second call runs without re-offloading in between
        output = module(torch.empty(1, device=exec_device))
        assert module.weight.device == exec_device
        assert output.device == exec_device

    # leaving the context restores the offloaded state
    assert module.weight.device == torch.device("meta")
    assert module._hf_hook.weights_map["weight"].device == offload_device
0 commit comments