@@ -22,6 +22,51 @@ def __init__(
2222 use_interface : bool = True ,
2323 use_jit : bool = True ,
2424 ):
25+ """
26+ PyTorch nn Module wrapper on quantum function ``f``.
27+
28+ :Example:
29+
30+ .. code-block:: python
31+
32+ K = tc.set_backend("tensorflow")
33+
34+ n = 6
35+ nlayers = 2
36+ batch = 2
37+
38+ def qpred(x, weights):
39+ c = tc.Circuit(n)
40+ for i in range(n):
41+ c.rx(i, theta=x[i])
42+ for j in range(nlayers):
43+ for i in range(n - 1):
44+ c.cnot(i, i + 1)
45+ for i in range(n):
46+ c.rx(i, theta=weights[2 * j, i])
47+ c.ry(i, theta=weights[2 * j + 1, i])
48+ ypred = K.stack([c.expectation_ps(x=[i]) for i in range(n)])
49+ ypred = K.real(ypred)
50+ return ypred
51+
52+ ql = tc.torchnn.QuantumNet(qpred, weights_shape=[2*nlayers, n])
53+
54+ ql(torch.ones([batch, n]))
55+
56+
57+ :param f: Quantum function with tensor in (input and weights) and tensor out.
58+ :type f: Callable[..., Any]
59+ :param weights_shape: list of shape tuple for different weights as the non-first parameters for ``f``
60+ :type weights_shape: Sequence[Tuple[int, ...]]
61+ :param initializer: function that gives the shape tuple returns torch tensor, defaults to None
62+ :type initializer: Union[Any, Sequence[Any]], optional
63+ :param use_vmap: whether apply vmap (batch input) on ``f``, defaults to True
64+ :type use_vmap: bool, optional
65+ :param use_interface: whether transform ``f`` with torch interface, defaults to True
66+ :type use_interface: bool, optional
67+ :param use_jit: whether jit ``f``, defaults to True
68+ :type use_jit: bool, optional
69+ """
2570 super ().__init__ ()
2671 if use_vmap :
2772 f = backend .vmap (f , vectorized_argnums = 0 )
0 commit comments