Skip to content

Commit b08ad6c

Browse files
add torchnn doc
1 parent 6e68540 commit b08ad6c

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

docs/source/api/torchnn.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
tensorcircuit.torchnn
2+
==================================================
3+
.. automodule:: tensorcircuit.torchnn
4+
:members:
5+
:undoc-members:
6+
:show-inheritance:

docs/source/modules.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ tensorcircuit
1717
./api/quantum.rst
1818
./api/simplify.rst
1919
./api/templates.rst
20+
./api/torchnn.rst
2021
./api/translation.rst
2122
./api/utils.rst
2223
./api/vis.rst

tensorcircuit/torchnn.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,51 @@ def __init__(
2222
use_interface: bool = True,
2323
use_jit: bool = True,
2424
):
25+
"""
26+
PyTorch nn Module wrapper on quantum function ``f``.
27+
28+
:Example:
29+
30+
.. code-block:: python
31+
32+
K = tc.set_backend("tensorflow")
33+
34+
n = 6
35+
nlayers = 2
36+
batch = 2
37+
38+
def qpred(x, weights):
39+
c = tc.Circuit(n)
40+
for i in range(n):
41+
c.rx(i, theta=x[i])
42+
for j in range(nlayers):
43+
for i in range(n - 1):
44+
c.cnot(i, i + 1)
45+
for i in range(n):
46+
c.rx(i, theta=weights[2 * j, i])
47+
c.ry(i, theta=weights[2 * j + 1, i])
48+
ypred = K.stack([c.expectation_ps(x=[i]) for i in range(n)])
49+
ypred = K.real(ypred)
50+
return ypred
51+
52+
ql = tc.torchnn.QuantumNet(qpred, weights_shape=[2*nlayers, n])
53+
54+
ql(torch.ones([batch, n]))
55+
56+
57+
:param f: Quantum function with tensor in (input and weights) and tensor out.
58+
:type f: Callable[..., Any]
59+
:param weights_shape: list of shape tuple for different weights as the non-first parameters for ``f``
60+
:type weights_shape: Sequence[Tuple[int, ...]]
61+
:param initializer: function that gives the shape tuple returns torch tensor, defaults to None
62+
:type initializer: Union[Any, Sequence[Any]], optional
63+
:param use_vmap: whether apply vmap (batch input) on ``f``, defaults to True
64+
:type use_vmap: bool, optional
65+
:param use_interface: whether transform ``f`` with torch interface, defaults to True
66+
:type use_interface: bool, optional
67+
:param use_jit: whether jit ``f``, defaults to True
68+
:type use_jit: bool, optional
69+
"""
2570
super().__init__()
2671
if use_vmap:
2772
f = backend.vmap(f, vectorized_argnums=0)

0 commit comments

Comments
 (0)