|
| 1 | +""" |
| 2 | +perfect sampling vs. state sampling |
| 3 | +the benchmark results show that only use perfect/tensor sampling when the wavefunction doesn't fit in memory |
| 4 | +""" |
| 5 | + |
| 6 | +import time |
| 7 | +import numpy as np |
| 8 | +import tensorcircuit as tc |
| 9 | + |
| 10 | +K = tc.set_backend("jax") |
| 11 | +# tf staging is too slow |
| 12 | + |
| 13 | + |
| 14 | +def construct_circuit(n, nlayers): |
| 15 | + c = tc.Circuit(n) |
| 16 | + for i in range(n): |
| 17 | + c.H(i) |
| 18 | + for _ in range(nlayers): |
| 19 | + for i in range(n): |
| 20 | + c.cnot(i, (i + nlayers - 1) % n) |
| 21 | + return c |
| 22 | + |
| 23 | + |
| 24 | +for n in [8, 10, 12, 14, 16]: |
| 25 | + for nlayers in [2, 6, 10]: |
| 26 | + print("n: ", n, " nlayers: ", nlayers) |
| 27 | + c = construct_circuit(n, nlayers) |
| 28 | + time0 = time.time() |
| 29 | + s = c.state() |
| 30 | + time1 = time.time() |
| 31 | + smp = bin(np.random.choice(range(2 ** n), p=np.abs(K.numpy(s)) ** 2)) |
| 32 | + # print(smp) |
| 33 | + print("state sampling time: ", time1 - time0) |
| 34 | + time0 = time.time() |
| 35 | + smp = c.sample() |
| 36 | + # print(smp) |
| 37 | + time1 = time.time() |
| 38 | + print("nonjit tensor sampling time: ", time1 - time0) |
| 39 | + |
| 40 | + @K.jit |
| 41 | + def f(key): |
| 42 | + K.set_random_state(key) |
| 43 | + return c.sample() |
| 44 | + |
| 45 | + key = K.get_random_state(42) |
| 46 | + key1, key2 = K.random_split(key) |
| 47 | + time0 = time.time() |
| 48 | + smp = f(key1) |
| 49 | + time1 = time.time() |
| 50 | + for _ in range(5): |
| 51 | + key1, key2 = K.random_split(key2) |
| 52 | + smp = f(key1) |
| 53 | + # print(smp) |
| 54 | + time2 = time.time() |
| 55 | + |
| 56 | + print("jittable tensor sampling staging time: ", time1 - time0) |
| 57 | + print("jittable tensor sampling running time: ", (time2 - time1) / 5) |
0 commit comments