Skip to content

Commit a0c82ec

Browse files
auto register channels on circuit class
1 parent 3f41c65 commit a0c82ec

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
- Add `tc.utils.arg_alias` which is a decorator that adds alias argument for function with the doc fixed accordingly
3232

33+
- Add quantum channel auto resgisteration method in `Circuit` class
34+
3335
### Changed
3436

3537
- `rxx`, `ryy`, `rzz` gates now has 1/2 factor before theta consitent with `rx`, `ry`, `rz` gates. (breaking change)
@@ -54,6 +56,8 @@
5456

5557
- Fix `arg_alias` when the docstring for each argument is in multiple lines
5658

59+
- Noise channel apply methods in `DMCircuit` can also absorb `status` keyword (directly omitting it) for a consistent API with `Circuit`
60+
5761
## 0.3.1
5862

5963
### Added

tensorcircuit/circuit.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tensornetwork as tn
1212

1313
from . import gates
14+
from . import channels
1415
from .cons import backend, contractor, dtypestr, npdtype
1516
from .quantum import QuOperator, identity
1617
from .simplify import _full_light_cone_cancel
@@ -255,6 +256,7 @@ def depolarizing2(
255256
# roughly benchmark shows that performance of two depolarizing in terms of
256257
# building time and running time are similar
257258

259+
# overwritten now, deprecated
258260
def depolarizing(
259261
self,
260262
index: int,
@@ -414,6 +416,7 @@ def step_function(x: Tensor) -> Tensor:
414416

415417
if status is None:
416418
status = backend.implicit_randu()[0]
419+
status = backend.convert_to_tensor(status)
417420
status = backend.real(status)
418421
prob_cumsum = backend.cast(prob_cumsum, dtype=status.dtype) # type: ignore
419422
r = step_function(status)
@@ -576,6 +579,46 @@ def general_kraus(
576579

577580
apply_general_kraus = general_kraus
578581

582+
@staticmethod
583+
def apply_general_kraus_delayed(
584+
krausf: Callable[..., Sequence[Gate]]
585+
) -> Callable[..., None]:
586+
def apply(
587+
self: "Circuit",
588+
*index: int,
589+
status: Optional[float] = None,
590+
name: Optional[str] = None,
591+
**vars: float,
592+
) -> None:
593+
kraus = krausf(**vars)
594+
self.apply_general_kraus(kraus, *index, status=status, name=name)
595+
596+
return apply
597+
598+
@classmethod
599+
def _meta_apply_channels(cls) -> None:
600+
for k in channels.channels:
601+
setattr(
602+
cls,
603+
k,
604+
cls.apply_general_kraus_delayed(getattr(channels, k + "channel")),
605+
)
606+
doc = """
607+
Apply %s quantum channel on the circuit.
608+
See :py:meth:`tensorcircuit.channels.%schannel`
609+
610+
:param index: Qubit number that the gate applies on.
611+
:type index: int.
612+
:param status: uniform external random number between 0 and 1
613+
:type status: Tensor
614+
:param vars: Parameters for the channel.
615+
:type vars: float.
616+
""" % (
617+
k,
618+
k,
619+
)
620+
getattr(cls, k).__doc__ = doc
621+
579622
def is_valid(self) -> bool:
580623
"""
581624
[WIP], check whether the circuit is legal.
@@ -757,6 +800,7 @@ def expectation( # type: ignore
757800

758801

759802
Circuit._meta_apply()
803+
Circuit._meta_apply_channels()
760804

761805

762806
def expectation(

tensorcircuit/densitymatrix.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def apply_general_kraus_delayed(
203203
krausf: Callable[..., Sequence[Gate]]
204204
) -> Callable[..., None]:
205205
def apply(self: "DMCircuit", *index: int, **vars: float) -> None:
206+
for key in ["status", "name"]:
207+
if key in vars:
208+
del vars[key]
209+
# compatibility with circuit API
206210
kraus = krausf(**vars)
207211
self.apply_general_kraus(kraus, [index])
208212

@@ -339,6 +343,9 @@ def apply_general_kraus_delayed(
339343
krausf: Callable[..., Sequence[Gate]]
340344
) -> Callable[..., None]:
341345
def apply(self: "DMCircuit2", *index: int, **vars: float) -> None:
346+
for key in ["status", "name"]:
347+
if key in vars:
348+
del vars[key]
342349
kraus = krausf(**vars)
343350
self.apply_general_kraus(kraus, *index)
344351

tests/test_circuit.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,3 +1152,12 @@ def test_sample_format(backend):
11521152
random_generator=key,
11531153
),
11541154
)
1155+
1156+
1157+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
1158+
def test_channel_auto_register(backend, highp):
1159+
c = tc.Circuit(2)
1160+
c.H(0)
1161+
c.reset(0, status=0.8)
1162+
s = c.state()
1163+
np.testing.assert_allclose(s[0], 1.0, atol=1e-9)

0 commit comments

Comments
 (0)