Support for "kl_divergence" #44

@ControllableGeneration

Description

I have implemented support for the Kullback-Leibler divergence as follows. Shall I open a pull request for it?

import torch

def pairwise_kl_divergence(data1, data2, device=torch.device('cpu')):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M: rows of data1 act as the "input" (approximating) distributions
    A = data1.unsqueeze(dim=1)

    # 1*N*M: rows of data2 act as the "target" (reference) distributions
    B = data2.unsqueeze(dim=0)

    # normalize each row to log-probabilities along the last dimension
    A_normalized = torch.nn.functional.log_softmax(A, dim=-1)
    B_normalized = torch.nn.functional.log_softmax(B, dim=-1)

    # element-wise KL terms; with log_target=True both arguments are in
    # log space, and the result broadcasts to N*N*M
    kl_div = torch.nn.functional.kl_div(A_normalized, B_normalized,
                                        reduction='none', log_target=True)

    # sum over the distribution dimension so each entry is the full KL
    # divergence; returns an N*N matrix whose (i, j) entry is
    # KL(data2[j] || data1[i])
    kl_div = kl_div.sum(-1)
    return kl_div
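
For reference, a minimal usage sketch; the random logits and the shapes below are illustrative assumptions, not anything taken from the repo:

# hypothetical inputs: rows are unnormalized logits over M=10 classes
logits_a = torch.randn(4, 10)   # shape N1*M
logits_b = torch.randn(6, 10)   # shape N2*M

dists = pairwise_kl_divergence(logits_a, logits_b)
print(dists.shape)  # torch.Size([4, 6])
# dists[i, j] = KL(softmax(logits_b[j]) || softmax(logits_a[i]))

Since KL divergence is asymmetric, swapping the two arguments generally gives a different matrix (the transpose of the reverse divergences), which may be worth documenting in the PR.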
