diff --git a/avocodo/lightning_module.py b/avocodo/lightning_module.py index fc7f7c2..86e3f12 100644 --- a/avocodo/lightning_module.py +++ b/avocodo/lightning_module.py @@ -11,7 +11,6 @@ from avocodo.models.losses import feature_loss from avocodo.models.losses import generator_loss from avocodo.models.losses import discriminator_loss -from avocodo.pqmf import PQMF class Avocodo(LightningModule): @@ -22,11 +21,8 @@ def __init__( super().__init__() self.save_hyperparameters(h) - self.pqmf_lv2 = PQMF(*self.hparams.pqmf_config["lv2"]) - self.pqmf_lv1 = PQMF(*self.hparams.pqmf_config["lv1"]) - self.generator = Generator(self.hparams.generator) - self.combd = CoMBD(self.hparams.combd, [self.pqmf_lv2, self.pqmf_lv1]) + self.combd = CoMBD(self.hparams.combd) self.sbd = SBD(self.hparams.sbd) def configure_optimizers(self): @@ -43,22 +39,12 @@ def forward(self, z): def training_step(self, batch, batch_idx, optimizer_idx): x, y, _, y_mel = batch y = y.unsqueeze(1) - ys = [ - self.pqmf_lv2.analysis( - y - )[:, :self.hparams.generator.projection_filters[1]], - self.pqmf_lv1.analysis( - y - )[:, :self.hparams.generator.projection_filters[2]], - y - ] - y_g_hats = self.generator(x) # train generator if optimizer_idx == 0: y_du_hat_r, y_du_hat_g, fmap_u_r, fmap_u_g = self.combd( - ys, y_g_hats) + y, y_g_hats) loss_fm_u, losses_fm_u = feature_loss(fmap_u_r, fmap_u_g) loss_gen_u, losses_gen_u = generator_loss(y_du_hat_g) @@ -91,7 +77,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): detached_y_g_hats = [x.detach() for x in y_g_hats] y_du_hat_r, y_du_hat_g, _, _ = self.combd( - ys, detached_y_g_hats) + y, detached_y_g_hats) loss_disc_u, losses_disc_u_r, losses_disc_u_g = discriminator_loss( y_du_hat_r, y_du_hat_g) diff --git a/avocodo/models/CoMBD.py b/avocodo/models/CoMBD.py index 680b303..15aa245 100644 --- a/avocodo/models/CoMBD.py +++ b/avocodo/models/CoMBD.py @@ -65,16 +65,16 @@ def forward(self, x): class CoMBD(torch.nn.Module): - def __init__(self, h, pqmf_list=None, use_spectral_norm=False): + def __init__(self, h, pqmf_list: List=None, use_spectral_norm=False): super(CoMBD, self).__init__() self.h = h if pqmf_list is not None: - self.pqmf = pqmf_list + self.pqmf = nn.ModuleList(pqmf_list) else: - self.pqmf = [ + self.pqmf = nn.ModuleList([ PQMF(*h.pqmf_config["lv2"]), PQMF(*h.pqmf_config["lv1"]) - ] + ]) self.blocks = nn.ModuleList() for _h_u, _d_k, _d_s, _d_d, _d_g, _d_p, _op_f, _op_k, _op_g in zip( @@ -107,18 +107,18 @@ def _block_forward(self, input, blocks, outs, f_maps): f_maps.append(f_map) return outs, f_maps - def _pqmf_forward(self, ys, ys_hat): + def _pqmf_forward(self, y, ys_hat): # preprocess for multi_scale forward - multi_scale_inputs = [] + ys = [] multi_scale_inputs_hat = [] for pqmf in self.pqmf: - multi_scale_inputs.append( - pqmf.to(ys[-1]).analysis(ys[-1])[:, :1, :] + ys.append( + pqmf.analysis(y)[:, :1, :] ) multi_scale_inputs_hat.append( - pqmf.to(ys[-1]).analysis(ys_hat[-1])[:, :1, :] + pqmf.analysis(ys_hat[-1])[:, :1, :] ) - + ys.append(y) outs_real = [] f_maps_real = [] # real @@ -126,8 +126,8 @@ def _pqmf_forward(self, ys, ys_hat): outs_real, f_maps_real = self._block_forward( ys, self.blocks, outs_real, f_maps_real) # for multi_scale forward - outs_real, f_maps_real = self._block_forward( - multi_scale_inputs, self.blocks[:-1], outs_real, f_maps_real) + outs_real.extend(outs_real[:-1]) + f_maps_real.extend(f_maps_real[:-1]) outs_fake = [] f_maps_fake = [] @@ -141,7 +141,7 @@ def _pqmf_forward(self, ys, ys_hat): return outs_real, outs_fake, f_maps_real, f_maps_fake - def forward(self, ys, ys_hat): + def forward(self, y, ys_hat): outs_real, outs_fake, f_maps_real, f_maps_fake = self._pqmf_forward( - ys, ys_hat) + y, ys_hat) return outs_real, outs_fake, f_maps_real, f_maps_fake