-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalysis.py
More file actions
62 lines (45 loc) · 1.93 KB
/
analysis.py
File metadata and controls
62 lines (45 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from scipy.stats import mode
from sklearn.metrics import silhouette_score
from sklearn.manifold import trustworthiness
from scipy.spatial.distance import pdist
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
def get_knn_score(embeddings, labels, n_neighbors=6):
    """Return the leave-one-out k-nearest-neighbour label accuracy of an embedding.

    Each point's label is predicted as the most common label among its
    ``n_neighbors`` nearest neighbours in ``embeddings`` (the point itself is
    excluded). The fraction of points whose prediction matches their own label
    is returned.

    Args:
        embeddings: 2-D array of points, one row per sample.
        labels: Array-like of per-sample labels aligned with ``embeddings``.
            Ties in the neighbour vote resolve to the smallest label, which is
            how ``scipy.stats.mode`` breaks ties.
        n_neighbors: Number of neighbours that vote, not counting the point
            itself. Defaults to 6, the value that was previously hard-coded.

    Returns:
        Mean accuracy, a value in [0, 1].
    """
    labels = np.asarray(labels)
    # Ask for one extra neighbour because each fitted point is its own nearest
    # neighbour; column 0 of ``indices`` is dropped below.
    # NOTE(review): with exact duplicate points, column 0 may be a duplicate
    # rather than the point itself.
    knn = KNeighborsClassifier(n_neighbors=n_neighbors + 1)
    knn.fit(embeddings, labels)
    _, indices = knn.kneighbors(embeddings)
    neighbor_labels = labels[indices[:, 1:]]
    # Vectorised majority vote, one vote per row. Older SciPy returns an (n, 1)
    # array from ``mode(..., axis=1)[0]``; newer SciPy returns (n,). Ravel to
    # (n,) so the comparison below is elementwise. Without this, an (n, 1)
    # array would broadcast to (n, n) against ``labels`` and corrupt the mean.
    predicted_labels = np.ravel(mode(neighbor_labels, axis=1)[0])
    return np.mean(predicted_labels == labels)
def get_scores(file_path, trust_neighbors=5):
    """Print, and return, embedding-quality metrics for every epoch in an ``.npz`` archive.

    The archive is read for the keys ``high_dim_data``, ``embeddings`` and
    ``labels``. ``embeddings`` and ``labels`` are indexed along axis 0 by epoch.

    NOTE(review): every epoch's embedding is compared row-by-row with the same
    ``high_dim_data``. This presumably assumes each epoch embeds the same points
    in the same order; confirm against whatever writes the archive.

    For each epoch this prints:
      * the embedding and label shapes,
      * KNN accuracy (``get_knn_score``),
      * the silhouette score,
      * the trustworthiness score,
      * the Pearson correlation between original-space and embedded-space
        pairwise distances (the Shepard-diagram correlation).

    Args:
        file_path: Path to the ``.npz`` file.
        trust_neighbors: ``n_neighbors`` passed to ``trustworthiness``.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        A list with one dict per epoch, holding the keys ``"knn_accuracy"``,
        ``"silhouette"``, ``"trustworthiness"`` and ``"shepard_correlation"``.
    """
    # Close the NpzFile when done. Each ``data[key]`` access reads and
    # decompresses the whole array from the zip, so each array is extracted
    # exactly once here instead of on every epoch.
    with np.load(file_path) as data:
        high_dim_data = data["high_dim_data"]
        all_embeddings = data["embeddings"]
        all_labels = data["labels"]

    # Distances in the original space are the same for every epoch.
    original_distances = pdist(high_dim_data)

    results = []
    for epoch in range(all_embeddings.shape[0]):
        embeddings = all_embeddings[epoch]
        labels = all_labels[epoch]
        print(f"------------ {epoch} ------------")
        print(embeddings.shape)
        print(labels.shape)

        knn_accuracy = get_knn_score(embeddings, labels)
        print("KNN accuracy: ", knn_accuracy)

        sil_score = silhouette_score(embeddings, labels)
        print("silhouette score: ", sil_score)

        trust_score = trustworthiness(
            high_dim_data, embeddings, n_neighbors=trust_neighbors
        )
        print("trustworthiness score: ", trust_score)

        embedded_distances = pdist(embeddings)
        correlation, _ = pearsonr(original_distances, embedded_distances)
        print("shepard diagram correlation: ", correlation)

        results.append(
            {
                "knn_accuracy": knn_accuracy,
                "silhouette": sil_score,
                "trustworthiness": trust_score,
                "shepard_correlation": correlation,
            }
        )
    return results
if __name__ == "__main__":
    import sys

    # An archive path may be given as the first command-line argument. Without
    # one, fall back to the original hard-coded path so the default invocation
    # behaves as before.
    get_scores(sys.argv[1] if len(sys.argv) > 1 else "images/raw-data.npz")