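@JunrQ Put `lat` in the loss list, e.g. `loss = [..., stop_grad(lat)]`, and print it with `print(outputs[-1].asnumpy())`. Because `stop_grad(lat)` is the last entry, it comes back as the last output.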