-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
43 lines (38 loc) · 1.45 KB
/
loss.py
File metadata and controls
43 lines (38 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch as t
import torch.nn as nn
class RateLoss(nn.Module):
    """Loss equal to the negated batch-mean of the per-sample sum rate.

    For each sample ``b`` and link ``j`` the rate is
    ``log2(1 + S[b, j] / (I[b, j] + var_noise))``. Here ``S[b, j]`` is
    ``csi[b, j, j]**2 * power[b, j]``. ``I[b, j]`` is the sum over ``i != j``
    of ``csi[b, j, i]**2 * power[b, i]``. Rates are summed over links,
    averaged over the batch, and negated, so minimizing the loss maximizes
    the sum rate.
    """

    def __init__(self, var_noise=1, k=15):
        """Store the loss configuration.

        Args:
            var_noise: Constant added to every link's interference term in the
                denominator (presumably the noise power; confirm units with
                the caller).
            k: Number of links. It sizes the ``k x k`` identity mask, so it
                must match the trailing dimensions of ``csi``.
        """
        super().__init__()
        self.var_noise = var_noise
        self.k = k

    def forward(self, csi, power):
        """Compute the negated mean sum rate.

        Args:
            csi: Rank-3 tensor (the ``permute(0, 2, 1)`` below requires it),
                broadcast-compatible with ``(batch, k, k)``. ``csi[b, j, i]``
                is squared and scaled by ``power[b, i]``. It presumably holds
                the coefficient from transmitter ``i`` to receiver ``j``;
                TODO confirm with the caller.
            power: Tensor of shape ``(batch, k)``.

        Returns:
            0-d tensor: ``-mean_b(sum_j rate[b, j])``.
        """
        # (batch, k, 1): each power[b, i] scales row i of the transposed gains.
        power = t.unsqueeze(power, -1)
        gain = t.pow(csi.permute(0, 2, 1), 2)
        # rx_power[b, i, j] = csi[b, j, i] ** 2 * power[b, i]
        rx_power = t.mul(gain, power)
        # Create the mask on the data's device; a CPU-only mask makes the
        # multiplication below fail for GPU inputs.
        mask = t.eye(self.k, device=rx_power.device)
        # Reducing over dim 1 (index i): the diagonal entry is the link's own
        # signal and the off-diagonal entries are interference.
        valid_rx_power = t.sum(t.mul(rx_power, mask), 1)
        interference = t.sum(t.mul(rx_power, 1 - mask), 1) + self.var_noise
        rate = t.log2(1 + t.div(valid_rx_power, interference))
        sum_rate = t.mean(t.sum(rate, 1))
        return t.neg(sum_rate)
class EELoss(nn.Module):
    """Loss equal to the negated batch-mean of summed per-link efficiency.

    For each sample ``b`` and link ``j`` the rate is
    ``log2(1 + S[b, j] / (I[b, j] + noise_power))``, computed exactly as in a
    sum-rate loss. That rate is divided by ``power[b, j] + pc``. The resulting
    ratios are summed over links, averaged over the batch, and negated.
    """

    def __init__(self, noise_power, user_num, pc):
        """Store the loss configuration.

        Args:
            noise_power: Constant added to every link's interference term in
                the rate denominator.
            user_num: Number of links. It sizes the identity mask, so it must
                match the trailing dimensions of ``csi``.
            pc: Constant added to each link's power in the efficiency
                denominator (presumably a fixed circuit power; confirm with
                the caller).
        """
        super().__init__()
        self.noise_power = noise_power
        self.user_num = user_num
        self.pc = pc

    def forward(self, csi, power):
        """Compute the negated mean summed efficiency.

        Args:
            csi: Rank-3 tensor broadcast-compatible with
                ``(batch, user_num, user_num)``. ``csi[b, j, i]`` is squared
                and scaled by ``power[b, i]``; TODO confirm the
                transmitter/receiver convention with the caller.
            power: Tensor of shape ``(batch, user_num)``.

        Returns:
            0-d tensor: ``-mean_b(sum_j rate[b, j] / (power[b, j] + pc))``.
        """
        # (batch, user_num, 1): each power[b, i] scales row i of the
        # transposed gains.
        power_col = t.unsqueeze(power, -1)
        gain = t.pow(csi.permute(0, 2, 1), 2)
        # rx_power[b, i, j] = csi[b, j, i] ** 2 * power[b, i]
        rx_power = t.mul(gain, power_col)
        # Build the mask on the data's device so GPU inputs work.
        mask = t.eye(self.user_num, device=rx_power.device)
        valid_rx_power = t.sum(t.mul(rx_power, mask), 1)
        interference = t.sum(t.mul(rx_power, 1 - mask), 1) + self.noise_power
        rate = t.log2(1 + t.div(valid_rx_power, interference))
        # Divide by the unmodified (batch, user_num) power. The previous
        # squeeze() dropped every size-1 dim, so with user_num == 1 the
        # denominator broadcast into a (batch, batch) matrix. pc is added as a
        # plain value, so it follows power's device instead of being a CPU
        # FloatTensor.
        ee = t.div(rate, power + self.pc)
        sum_ee = t.mean(t.sum(ee, 1))
        return t.neg(sum_ee)