-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdraft2.py
More file actions
33 lines (24 loc) · 771 Bytes
/
draft2.py
File metadata and controls
33 lines (24 loc) · 771 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
# a = torch.arange(.5,1000.5,1.).reshape(100, 10).type(torch.float64)
# a = torch.arange(.5,1000.5,1.).reshape(100, 10)
# a = torch.arange(0,1000).reshape(100, 10).type(torch.float32)
# Random standard-normal input of shape (100, 10): 100 rows of 10 scores each.
a = torch.randn((100, 10))
def softmax(x):
    """Return the softmax of ``x`` taken along its last dimension.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``exp`` from overflowing to
    ``inf`` (which would turn the quotient into ``nan``) for large inputs.

    Args:
        x: floating-point tensor; normalisation is over the last dimension.

    Returns:
        Tensor with the same shape as ``x`` whose entries along the last
        dimension are non-negative and sum to 1.
    """
    if x.numel() == 0:
        # Nothing to normalise; reducing over an empty dim would raise.
        return x.exp()
    shifted = x - x.max(dim=-1, keepdim=True).values
    exps = shifted.exp()  # computed once, reused for numerator and denominator
    return exps / exps.sum(dim=-1, keepdim=True)
def nl(input, target):
    """Mean negative log of the entries of ``input`` selected by ``target``.

    Row ``i`` of ``input`` contributes ``input[i, target[i]]``. The result is
    the negated mean of the logs of those picked values.

    Args:
        input: tensor indexed as ``input[row, column]``. It presumably holds
            per-class probabilities, e.g. softmax output (confirm with caller).
        target: integer index tensor, one column index per row;
            its length sets how many rows are used.

    Returns:
        Scalar tensor holding the mean negative log-likelihood.
    """
    row_ids = range(target.shape[0])
    picked = input[row_ids, target]
    return picked.log().mean().neg()
# target: (100,), input: (100, 10)
# Class indices 0..9, each repeated 10 times. Floor division keeps `b` an
# integer (int64) tensor, which the indexing below and inside `nl` requires.
# In-place true division (`b /= 10`) on an integer tensor raises in current
# PyTorch, because a float result cannot be cast back to int64.
b = torch.div(torch.arange(100), 10, rounding_mode="floor")
print("a dtype: ", a.dtype)
print("b dtype: ", b.dtype)
print("b shape: ", b.shape)
smf = softmax(a)
loss = nl(smf, b)
print("loss: ", loss)
print("smf shape: ", smf.shape)
# Probability assigned to each row's target class.
print(smf[range(b.shape[0]), b])
print(smf)
import numpy as np
# 100 x 200 float32 array filled with ones.
mx = np.ones(shape=(100, 200), dtype=np.float32)
# Rebinds `a` (earlier the torch input tensor) to the integer 0;
# presumably an accumulator for the loop that follows. Confirm.
a = 0
for i in range(100):
for j in range(200):