Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

  * models/fusion_model.py: choose the tensor type and the torch.load map_location
    from self.gpu_ids, drop strict=False from the load_state_dict calls, and retry
    the GF load without the "module." key prefix when no GPU ids are configured.
  * test_fusion.py: call .cuda() on the input images only when opt.gpu_ids is non-empty.

NOTE(review): line breaks, blank lines and indentation were rebuilt from a
whitespace-collapsed copy so that every hunk matches its header counts; confirm
them against the target tree (or apply with whitespace tolerance) before use.
NOTE(review): the "index" blob ids are those of the original patch and do not
cover the prefix-check / map_location keyword edits made in the last hunk of
models/fusion_model.py.

diff --git a/models/fusion_model.py b/models/fusion_model.py
old mode 100644
new mode 100755
index 0f6ef6b..81d17c1
--- a/models/fusion_model.py
+++ b/models/fusion_model.py
@@ -29,34 +29,33 @@ def initialize(self, opt):
 
         # load/define networks
         num_in = opt.input_nc + opt.output_nc + 1
-        
+
         self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                       'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                       use_tanh=True, classification=False)
         self.netG.eval()
-        
+
         self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf,
-                                      'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
-                                      use_tanh=True, classification=False)
+                                       'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
+                                       use_tanh=True, classification=False)
         self.netGF.eval()
 
         self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf,
-                                      'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
-                                      use_tanh=True, classification=opt.classification)
+                                          'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
+                                          use_tanh=True, classification=opt.classification)
         self.netGComp.eval()
 
-
     def set_input(self, input):
         AtoB = self.opt.which_direction == 'AtoB'
         self.real_A = input['A' if AtoB else 'B'].to(self.device)
         self.real_B = input['B' if AtoB else 'A'].to(self.device)
         self.hint_B = input['hint_B'].to(self.device)
-        
+
         self.mask_B = input['mask_B'].to(self.device)
         self.mask_B_nc = self.mask_B + self.opt.mask_cent
 
         self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
-    
+
     def set_fusion_input(self, input, box_info):
         AtoB = self.opt.which_direction == 'AtoB'
         self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
@@ -84,30 +83,52 @@ def set_forward_without_box(self, input):
 
     def forward(self):
         (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B)
-        self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list)
-        
+        self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map,
+                                     self.box_info_list)
+
     def save_current_imgs(self, path):
-        out_img = torch.clamp(util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0)
+        FloatTensorType = torch.cuda.FloatTensor if len(self.gpu_ids) != 0 else torch.FloatTensor
+        out_img = torch.clamp(util.lab2rgb(
+            torch.cat((self.full_real_A.type(FloatTensorType), self.fake_B_reg.type(FloatTensorType)),
+                      dim=1), self.opt), 0.0, 1.0)
         out_img = np.transpose(out_img.cpu().data.numpy()[0], (1, 2, 0))
         io.imsave(path, img_as_ubyte(out_img))
 
     def setup_to_test(self, fusion_weight_path):
         GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
         print('load Fusion model from %s' % GF_path)
-        GF_state_dict = torch.load(GF_path)
-        
+
+        target_device = torch.device('cpu') if len(self.gpu_ids) == 0 else None
+        GF_state_dict = torch.load(GF_path, map_location=target_device)
+
         # G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
         G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
-        G_state_dict = torch.load(G_path)
+        G_state_dict = torch.load(G_path, map_location=target_device)
 
         # GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
         # GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
         GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
-        GComp_state_dict = torch.load(GComp_path)
-
-        self.netGF.load_state_dict(GF_state_dict, strict=False)
-        self.netG.module.load_state_dict(G_state_dict, strict=False)
-        self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
+        GComp_state_dict = torch.load(GComp_path, map_location=target_device)
+
+        # It's bad to call load_state_dict() with strict=False
+        if len(self.gpu_ids) == 0:
+            try:
+                self.netGF.load_state_dict(GF_state_dict)
+            except RuntimeError as e1:
+                import sys
+                print(f"{e1}\nAre you using cuda when you training this model?", file=sys.stderr)
+                GF_state_dict_noparallel = OrderedDict()
+                for k, v in GF_state_dict.items():
+                    # Drop only a genuine 'module.' prefix; keys without it are kept unchanged.
+                    name = k[len('module.'):] if k.startswith('module.') else k
+                    GF_state_dict_noparallel[name] = v
+                self.netGF.load_state_dict(GF_state_dict_noparallel)
+            self.netG.load_state_dict(G_state_dict)
+            self.netGComp.load_state_dict(GComp_state_dict)
+        else:
+            self.netGF.load_state_dict(GF_state_dict)
+            self.netG.module.load_state_dict(G_state_dict)
+            self.netGComp.module.load_state_dict(GComp_state_dict)
         self.netGF.eval()
         self.netG.eval()
-        self.netGComp.eval()
\ No newline at end of file
+        self.netGComp.eval()
diff --git a/test_fusion.py b/test_fusion.py
old mode 100644
new mode 100755
index ea571e9..b51e393
--- a/test_fusion.py
+++ b/test_fusion.py
@@ -16,6 +16,8 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import numpy as np
 import multiprocessing
+
+
 multiprocessing.set_start_method('spawn', True)
 
 torch.backends.cudnn.benchmark = True
@@ -41,9 +43,11 @@
     for data_raw in tqdm(dataset_loader, dynamic_ncols=True):
         # if os.path.isfile(join(save_img_path, data_raw['file_id'][0] + '.png')) is True:
         #     continue
-        data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
+        if len(opt.gpu_ids) != 0:
+            data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
         if data_raw['empty_box'][0] == 0:
-            data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
+            if len(opt.gpu_ids) != 0:
+                data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
             box_info = data_raw['box_info'][0]
             box_info_2x = data_raw['box_info_2x'][0]
             box_info_4x = data_raw['box_info_4x'][0]