@@ -1419,7 +1419,70 @@ def perfect_sampling(self) -> Tuple[str, float]:
14191419 """
14201420 return self .measure_jit (* [i for i in range (self ._nqubits )], with_prob = True )
14211421
1422- sample = perfect_sampling
1422+ # sample = perfect_sampling
1423+
1424+ def sample (
1425+ self ,
1426+ batch : Optional [int ] = None ,
1427+ allow_state : bool = False ,
1428+ status : Optional [Tensor ] = None ,
1429+ ) -> Any :
1430+ """
1431+ batched sampling from state or circuit tensor network directly
1432+
1433+ :param batch: number of samples, defaults to None
1434+ :type batch: Optional[int], optional
1435+ :param allow_state: if true, we sample from the final state
1436+ if memory allsows, True is prefered, defaults to False
1437+ :type allow_state: bool, optional
1438+ :param status: random generator, defaults to None
1439+ :type status: Optional[Tensor], optional
1440+ :return: List (if batch) of tuple (binary configuration tensor and correponding probability)
1441+ :rtype: Any
1442+ """
1443+ # allow_state = False is compatibility issue
1444+ if not allow_state :
1445+ if batch is None :
1446+ return self .perfect_sampling ()
1447+
1448+ @backend .jit # type: ignore
1449+ def perfect_sampling (key : Any ) -> Any :
1450+ backend .set_random_state (key )
1451+ return self .perfect_sampling ()
1452+
1453+ r = []
1454+ if status is None :
1455+ status = backend .get_random_state ()
1456+ subkey = status
1457+ for _ in range (batch ):
1458+ key , subkey = backend .random_split (subkey )
1459+ r .append (perfect_sampling (key ))
1460+
1461+ return r
1462+
1463+ if batch is None :
1464+ nbatch = 1
1465+ else :
1466+ nbatch = batch
1467+ s = self .state ()
1468+ p = backend .abs (s ) ** 2
1469+ if status is None :
1470+ ch = backend .implicit_randc (a = 2 ** self ._nqubits , shape = [nbatch ], p = p )
1471+ else :
1472+ ch = backend .stateful_randc (
1473+ status , a = 2 ** self ._nqubits , shape = [nbatch ], p = p
1474+ )
1475+ prob = backend .gather1d (p , ch )
1476+ confg = backend .mod (
1477+ backend .right_shift (
1478+ ch [..., None ], backend .reverse (backend .arange (self ._nqubits ))
1479+ ),
1480+ 2 ,
1481+ )
1482+ r = list (zip (confg , prob ))
1483+ if batch is None :
1484+ r = r [0 ]
1485+ return r
14231486
14241487 # TODO(@refraction-ray): more _before function like state_before? and better API?
14251488
0 commit comments