Skip to content

Commit 6e68540

Browse files
add torchnn module
1 parent 2f6a28f commit 6e68540

File tree

6 files changed

+107
-1
lines changed

6 files changed

+107
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Unreleased
44

5+
### Added
6+
7+
- Add PyTorch nn Module wrapper in `torchnn`
8+
59
## 0.1.3
610

711
### Added
@@ -32,7 +36,7 @@
3236

3337
- Fix installation issue with tensorflow requirements on MACOS with M1 chip
3438

35-
- Improve M1 macOS compayibility with unjit tensorflow ops
39+
- Improve M1 macOS compatibility with unjit tensorflow ops
3640

3741
- Fixed SVD backprop bug on jax backend of wide matrix
3842

docs/source/infras.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Overview of Modules
3434

3535
- :py:mod:`tensorcircuit.keras`: Provide TensorFlow Keras layers, as well as wrappers of jitted function, save/load from tf side.
3636

37+
- :py:mod:`tensorcircuit.torchnn`: Provide PyTorch nn Modules.
38+
3739
**MPS and MPO Utiliy Modules:**
3840

3941
- :py:mod:`tensorcircuit.quantum`: Provide definition and classes for Matrix Product States as well as Matrix Product Operators, we also include various quantum physics and quantum information quantities in this module.

tensorcircuit/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,10 @@
3131
except ModuleNotFoundError:
3232
pass # in case tf is not installed
3333

34+
try:
35+
from . import torchnn
36+
except ModuleNotFoundError:
37+
pass # in case torch is not installed
38+
3439
# just for fun
3540
from .asciiart import set_ascii

tensorcircuit/cons.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"tensorcircuit.densitymatrix2",
3232
"tensorcircuit.channels",
3333
"tensorcircuit.keras",
34+
"tensorcircuit.torchnn",
3435
"tensorcircuit.quantum",
3536
"tensorcircuit.simplify",
3637
"tensorcircuit.interfaces",

tensorcircuit/torchnn.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
PyTorch nn Module wrapper for quantum function
3+
"""
4+
5+
from typing import Any, Callable, Sequence, Tuple, Union
6+
7+
import torch
8+
9+
from .cons import backend
10+
from .interfaces import torch_interface, is_sequence
11+
12+
Tensor = Any
13+
14+
15+
class QuantumNet(torch.nn.Module): # type: ignore
16+
def __init__(
17+
self,
18+
f: Callable[..., Any],
19+
weights_shape: Sequence[Tuple[int, ...]],
20+
initializer: Union[Any, Sequence[Any]] = None,
21+
use_vmap: bool = True,
22+
use_interface: bool = True,
23+
use_jit: bool = True,
24+
):
25+
super().__init__()
26+
if use_vmap:
27+
f = backend.vmap(f, vectorized_argnums=0)
28+
if use_interface:
29+
f = torch_interface(f, jit=use_jit)
30+
self.f = f
31+
self.q_weights = []
32+
if isinstance(weights_shape[0], int):
33+
weights_shape = [weights_shape]
34+
if not is_sequence(initializer):
35+
initializer = [initializer]
36+
for ws, initf in zip(weights_shape, initializer):
37+
if initf is None:
38+
initf = torch.randn
39+
self.q_weights.append(torch.nn.Parameter(initf(ws))) # type: ignore
40+
41+
def forward(self, inputs: Tensor) -> Tensor:
42+
ypred = self.f(inputs, *self.q_weights)
43+
return ypred

tests/test_torchnn.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import sys
3+
import numpy as np
4+
import pytest
5+
from pytest_lazyfixture import lazy_fixture as lf
6+
7+
thisfile = os.path.abspath(__file__)
8+
modulepath = os.path.dirname(os.path.dirname(thisfile))
9+
10+
sys.path.insert(0, modulepath)
11+
12+
import tensorcircuit as tc
13+
14+
try:
15+
import torch
16+
except ImportError:
17+
pytest.skip("torch is not installed")
18+
19+
20+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
21+
def test_quantumnet(backend):
22+
23+
n = 6
24+
nlayers = 2
25+
26+
def qpred(x, weights):
27+
c = tc.Circuit(n)
28+
for i in range(n):
29+
c.rx(i, theta=x[i])
30+
for j in range(nlayers):
31+
for i in range(n - 1):
32+
c.cnot(i, i + 1)
33+
for i in range(n):
34+
c.rx(i, theta=weights[2 * j, i])
35+
c.ry(i, theta=weights[2 * j + 1, i])
36+
ypred = tc.backend.stack([c.expectation_ps(x=[i]) for i in range(n)])
37+
return tc.backend.real(ypred)
38+
39+
if tc.backend.name == "pytorch":
40+
use_interface = False
41+
else:
42+
use_interface = True
43+
44+
ql = tc.torchnn.QuantumNet(
45+
qpred, weights_shape=[2 * nlayers, n], use_interface=use_interface
46+
)
47+
48+
yp = ql(torch.ones([3, n]))
49+
print(yp)
50+
51+
np.testing.assert_allclose(yp.shape, np.array([3, n]))

0 commit comments

Comments
 (0)