diff --git a/metrics.py b/metrics.py index 37d6d39..97cb03e 100644 --- a/metrics.py +++ b/metrics.py @@ -18,7 +18,7 @@ def compute_fid_(model, images1, images2): features2 = model.latent_vector.cpu().numpy() means1 = features1.mean(0) # (bs, num_features) --> (num_features) - means2 = features2.mean(0) + means2 = features2.mean(-1) # calculate mean and covariance statistics sigma1 = np.cov(features1, rowvar=False)