This repository was archived by the owner on Nov 7, 2024. It is now read-only.
In file https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, two places call self.backend.item:
line 319: res.append(self.backend.item(result.tensor))
line 479: return [self.backend.item(o) for o in c]
Converting the result to a Python scalar with self.backend.item is incompatible with autodiff in JAX (and possibly with other backends as well).
I haven't checked the other files, so they may have similar issues.
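A possible direction is to return the backend tensors themselves instead of calling item on them, so the traced values can flow through. A minimal sketch of that pattern in plain JAX, without TensorNetwork (measure_values is a hypothetical stand-in, not a library function):

```python
import jax
import jax.numpy as jnp

def measure_values(x):
    # Hypothetical stand-in for a routine like the one at line 479: it returns
    # a list of per-site values. Each entry stays a 0-d JAX array instead of
    # being converted with .item(), so jax.grad can trace through it.
    weights = jnp.arange(1.0, 4.0)  # [1.0, 2.0, 3.0]
    return [jnp.sum(weights * x ** k) for k in (1, 2)]

# Differentiate the second entry, 6 * x**2, with respect to x at x = 2.0.
value, grad = jax.value_and_grad(lambda x: measure_values(x)[1])(2.0)
print(value, grad)
```

Callers that really need Python numbers can still call float() on the returned arrays outside the differentiated function.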
Here's a simple example:
import tensornetwork as tn
import numpy as np
import jax

tn.set_default_backend('jax')
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))

def func(x):
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    e = mps.measure_local_operator([gate], [0])
    return e[0]

print(func(1.0))  # output: (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0))  # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
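Note that even with the item call removed, jax.value_and_grad requires a real-valued output, and func returns a complex number, so it would also need to take the real part (or use holomorphic=True with complex inputs). The sketch below performs the same measurement on a dense 4-site state vector instead of an MPS, so it does not use TensorNetwork at all. The fixed random state, the seed, and placing the gate on site 0 via jnp.kron are assumptions of this sketch; creating the state outside the differentiated function keeps it identical between calls.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm

# Pauli Z, as in the example above.
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)

# Fixed, normalized random state on 4 sites (2**4 amplitudes), built once.
rng = np.random.default_rng(0)
psi = rng.normal(size=16) + 1j * rng.normal(size=16)
psi = jnp.asarray(psi / np.linalg.norm(psi), dtype=jnp.complex64)

def local_expectation(x):
    # Gate acting on site 0 only: expm(Z * x) tensored with the identity on sites 1-3.
    gate = jnp.kron(expm(Z * x), jnp.eye(8, dtype=jnp.complex64))
    e = jnp.vdot(psi, gate @ psi)  # <psi|gate|psi>, kept as a 0-d complex array
    return jnp.real(e)             # jax.grad needs a real-valued output

vg = jax.value_and_grad(local_expectation)
print(vg(1.0))
```

Because Z is diagonal, the gate is diag(e^x, e^-x) on site 0, so the result should equal p0·e^x + p1·e^-x, where p0 is the total weight of the first eight amplitudes and p1 = 1 - p0. That gives an easy check on the gradient.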