Skip to content

Commit 325906e

Browse files
add sample benchmark
1 parent b08ad6c commit 325906e

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

examples/sample_benchmark.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
perfect sampling vs. state sampling
3+
the benchmark results show that only use perfect/tensor sampling when the wavefunction doesn't fit in memory
4+
"""
5+
6+
import time
7+
import numpy as np
8+
import tensorcircuit as tc
9+
10+
K = tc.set_backend("jax")
11+
# tf staging is too slow
12+
13+
14+
def construct_circuit(n, nlayers):
15+
c = tc.Circuit(n)
16+
for i in range(n):
17+
c.H(i)
18+
for _ in range(nlayers):
19+
for i in range(n):
20+
c.cnot(i, (i + nlayers - 1) % n)
21+
return c
22+
23+
24+
for n in [8, 10, 12, 14, 16]:
25+
for nlayers in [2, 6, 10]:
26+
print("n: ", n, " nlayers: ", nlayers)
27+
c = construct_circuit(n, nlayers)
28+
time0 = time.time()
29+
s = c.state()
30+
time1 = time.time()
31+
smp = bin(np.random.choice(range(2 ** n), p=np.abs(K.numpy(s)) ** 2))
32+
# print(smp)
33+
print("state sampling time: ", time1 - time0)
34+
time0 = time.time()
35+
smp = c.sample()
36+
# print(smp)
37+
time1 = time.time()
38+
print("nonjit tensor sampling time: ", time1 - time0)
39+
40+
@K.jit
41+
def f(key):
42+
K.set_random_state(key)
43+
return c.sample()
44+
45+
key = K.get_random_state(42)
46+
key1, key2 = K.random_split(key)
47+
time0 = time.time()
48+
smp = f(key1)
49+
time1 = time.time()
50+
for _ in range(5):
51+
key1, key2 = K.random_split(key2)
52+
smp = f(key1)
53+
# print(smp)
54+
time2 = time.time()
55+
56+
print("jittable tensor sampling staging time: ", time1 - time0)
57+
print("jittable tensor sampling running time: ", (time2 - time1) / 5)

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ def _rq_tf(
198198
phases = tf.math.sign(tf.linalg.diag_part(r))
199199
q = q * phases
200200
r = phases[:, None] * r
201-
r, q = tf.math.conj(tf.transpose(r)), tf.math.conj(
202-
tf.transpose(q)
201+
r, q = (
202+
tf.math.conj(tf.transpose(r)),
203+
tf.math.conj(tf.transpose(q)),
203204
) # M=r*q at this point
204205
center_dim = tf.shape(r)[1]
205206
r = tf.reshape(r, tf.concat([left_dims, [center_dim]], axis=-1))

0 commit comments

Comments
 (0)