Skip to content

Commit b60faae

Browse files
JanLucaRoberto Losada
authored andcommitted
Refactor contraction to code to do not use custom_definitions anymore
1 parent bb3f2be commit b60faae

File tree

7 files changed

+94
-1356
lines changed

7 files changed

+94
-1356
lines changed

varipeps/contractions/apply.py

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,20 @@
22
Helpers to apply contractions.
33
"""
44

5-
import collections
65
from functools import partial
76

87
import jax
98
import jax.numpy as jnp
10-
import tensornetwork as tn
11-
from tensornetwork.ncon_interface import _jittable_ncon
129

1310
from varipeps.peps import PEPS_Tensor
14-
from varipeps import varipeps_config
15-
from varipeps.config import VariPEPS_Config
16-
from varipeps.utils.func_cache import Checkpointing_Cache
1711

1812
from .definitions import Definitions, Definition
1913

2014
from typing import Sequence, List, Tuple, Dict, Union, Optional
2115

2216

23-
class _Contraction_Cache:
24-
_cache = None
25-
26-
def __class_getitem__(cls, name: str) -> Checkpointing_Cache:
27-
name = f"_{name}"
28-
obj = getattr(cls, name)
29-
if obj is None:
30-
obj = Checkpointing_Cache(varipeps_config.checkpointing_ncon)
31-
setattr(cls, name, obj)
32-
return obj
33-
34-
35-
_ncon_jitted = jax.jit(_jittable_ncon, static_argnums=(1, 2, 3, 4, 5), inline=True)
36-
37-
3817
@partial(
39-
jax.jit, static_argnames=("name", "disable_identity_check", "custom_definition")
18+
jax.jit, static_argnames=("name", "disable_identity_check")
4019
)
4120
def apply_contraction(
4221
name: str,
@@ -45,8 +24,6 @@ def apply_contraction(
4524
additional_tensors: Sequence[jnp.ndarray],
4625
*,
4726
disable_identity_check: bool = True,
48-
custom_definition: Optional[Definition] = None,
49-
config: VariPEPS_Config = varipeps_config,
5027
) -> jnp.ndarray:
5128
"""
5229
Apply a contraction to a list of tensors.
@@ -69,12 +46,6 @@ def apply_contraction(
6946
disable_identity_check (:obj:`bool`):
7047
Disable the check if the tensor is identical to the one of the
7148
corresponding object.
72-
custom_definition (:obj:`~varipeps.contractions.apply.Definition`, optional):
73-
Use a custom definition for the contraction which is not defined in the
74-
:class:`varipeps.contractions.Definitions` class.
75-
config (:obj:`~varipeps.config.VariPEPS_Config`):
76-
Global configuration object of the variPEPS library. Please see its
77-
class definition for details.
7849
Returns:
7950
jax.numpy.ndarray:
8051
The contracted tensor.
@@ -97,10 +68,7 @@ class definition for details.
9768
"Sequence of PEPS tensors mismatch the objects sequence. Please check your code!"
9869
)
9970

100-
if custom_definition is not None:
101-
contraction = custom_definition
102-
else:
103-
contraction = getattr(Definitions, name)
71+
contraction = getattr(Definitions, name)
10472

10573
if len(contraction["filter_peps_tensors"]) != len(peps_tensors):
10674
raise ValueError(
@@ -136,17 +104,11 @@ class definition for details.
136104

137105
tensor_shapes = tuple(tuple(e.shape) for e in tensors)
138106

139-
path = contraction["einsum_cache"].get(tensor_shapes)
140-
141-
if path is None:
142-
path, _ = jnp.einsum_path(
143-
contraction["einsum_network"],
144-
*tensors,
145-
optimize="optimal" if len(tensors) < 10 else "dp",
146-
)
147-
contraction["einsum_cache"][tensor_shapes] = path
148-
149-
return jnp.einsum(contraction["einsum_network"], *tensors, optimize=path)
107+
return jnp.einsum(
108+
contraction["einsum_network"],
109+
*tensors,
110+
optimize="optimal" if len(tensors) < 10 else "dp",
111+
)
150112

151113

152114
apply_contraction_jitted = apply_contraction

varipeps/contractions/definitions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,22 @@ def _process_def(cls, e, name):
176176
e["network_additional_tensors"] = network_additional_tensors
177177
e["ncon_network"] = ncon_network
178178
e["einsum_network"] = einsum_network
179-
e["einsum_cache"] = dict()
180179

181180
@classmethod
182181
def _prepare_defs(cls):
183182
for name in dir(cls):
184-
if name.startswith("_"):
183+
if name == "add_def" or name.startswith("_"):
185184
continue
186185

187186
e = getattr(cls, name)
188187

189188
cls._process_def(e, name)
190189

190+
@classmethod
191+
def add_def(cls, name, definition):
192+
cls._process_def(definition, name)
193+
setattr(cls, name, definition)
194+
191195
density_matrix_one_site: Definition = {
192196
"tensors": [
193197
["tensor", "tensor_conj", "C1", "T1", "C2", "T2", "C3", "T3", "C4", "T4"]

varipeps/expectation/helpers.py

Lines changed: 8 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,7 @@ def partially_traced_four_site_density_matrices(
8888
]
8989
],
9090
}
91-
Definitions._process_def(
92-
contraction_top_left,
93-
(
94-
f"partially_traced_four_site_density_matrices_top_left_"
95-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_left_i}"
96-
),
97-
)
98-
setattr(
99-
Definitions,
91+
Definitions.add_def(
10092
(
10193
f"partially_traced_four_site_density_matrices_top_left_"
10294
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_left_i}"
@@ -147,15 +139,7 @@ def partially_traced_four_site_density_matrices(
147139
# ]
148140
],
149141
}
150-
Definitions._process_def(
151-
contraction_top_right,
152-
(
153-
f"partially_traced_four_site_density_matrices_top_right_"
154-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_right_i}"
155-
),
156-
)
157-
setattr(
158-
Definitions,
142+
Definitions.add_def(
159143
(
160144
f"partially_traced_four_site_density_matrices_top_right_"
161145
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_right_i}"
@@ -218,15 +202,7 @@ def partially_traced_four_site_density_matrices(
218202
# ]
219203
],
220204
}
221-
Definitions._process_def(
222-
contraction_bottom_left,
223-
(
224-
f"partially_traced_four_site_density_matrices_bottom_left_"
225-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_left_i}"
226-
),
227-
)
228-
setattr(
229-
Definitions,
205+
Definitions.add_def(
230206
(
231207
f"partially_traced_four_site_density_matrices_bottom_left_"
232208
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_left_i}"
@@ -279,15 +255,7 @@ def partially_traced_four_site_density_matrices(
279255
# ]
280256
],
281257
}
282-
Definitions._process_def(
283-
contraction_bottom_right,
284-
(
285-
f"partially_traced_four_site_density_matrices_bottom_right_"
286-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_right_i}"
287-
),
288-
)
289-
setattr(
290-
Definitions,
258+
Definitions.add_def(
291259
(
292260
f"partially_traced_four_site_density_matrices_bottom_right_"
293261
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_right_i}"
@@ -470,15 +438,7 @@ def partially_traced_horizontal_two_site_density_matrices(
470438
]
471439
],
472440
}
473-
Definitions._process_def(
474-
contraction_left,
475-
(
476-
f"partially_traced_horizontal_two_site_density_matrices_left_"
477-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{left_i}"
478-
),
479-
)
480-
setattr(
481-
Definitions,
441+
Definitions.add_def(
482442
(
483443
f"partially_traced_horizontal_two_site_density_matrices_left_"
484444
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{left_i}"
@@ -512,15 +472,7 @@ def partially_traced_horizontal_two_site_density_matrices(
512472
]
513473
],
514474
}
515-
Definitions._process_def(
516-
contraction_right,
517-
(
518-
f"partially_traced_horizontal_two_site_density_matrices_right_"
519-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{right_i}"
520-
),
521-
)
522-
setattr(
523-
Definitions,
475+
Definitions.add_def(
524476
(
525477
f"partially_traced_horizontal_two_site_density_matrices_right_"
526478
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{right_i}"
@@ -654,15 +606,7 @@ def partially_traced_vertical_two_site_density_matrices(
654606
]
655607
],
656608
}
657-
Definitions._process_def(
658-
contraction_top,
659-
(
660-
f"partially_traced_vertical_two_site_density_matrices_top_"
661-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
662-
),
663-
)
664-
setattr(
665-
Definitions,
609+
Definitions.add_def(
666610
(
667611
f"partially_traced_vertical_two_site_density_matrices_top_"
668612
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
@@ -696,15 +640,7 @@ def partially_traced_vertical_two_site_density_matrices(
696640
]
697641
],
698642
}
699-
Definitions._process_def(
700-
contraction_bottom,
701-
(
702-
f"partially_traced_vertical_two_site_density_matrices_bottom_"
703-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"
704-
),
705-
)
706-
setattr(
707-
Definitions,
643+
Definitions.add_def(
708644
(
709645
f"partially_traced_vertical_two_site_density_matrices_bottom_"
710646
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"

varipeps/expectation/triangular_helpers.py

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,7 @@ def partially_traced_vertical_two_site_density_matrices_triangular(
8686
],
8787
],
8888
}
89-
Definitions._process_def(
90-
contraction_top,
91-
(
92-
f"partially_traced_vertical_two_site_density_matrices_triangular_top_"
93-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
94-
),
95-
)
96-
setattr(
97-
Definitions,
89+
Definitions.add_def(
9890
(
9991
f"partially_traced_vertical_two_site_density_matrices_triangular_top_"
10092
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
@@ -129,15 +121,7 @@ def partially_traced_vertical_two_site_density_matrices_triangular(
129121
],
130122
],
131123
}
132-
Definitions._process_def(
133-
contraction_bottom,
134-
(
135-
f"partially_traced_vertical_two_site_density_matrices_triangular_bottom_"
136-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"
137-
),
138-
)
139-
setattr(
140-
Definitions,
124+
Definitions.add_def(
141125
(
142126
f"partially_traced_vertical_two_site_density_matrices_triangular_bottom_"
143127
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"
@@ -271,15 +255,7 @@ def partially_traced_horizontal_two_site_density_matrices_triangular(
271255
],
272256
],
273257
}
274-
Definitions._process_def(
275-
contraction_left,
276-
(
277-
f"partially_traced_horizontal_two_site_density_matrices_triangular_left_"
278-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{left_i}"
279-
),
280-
)
281-
setattr(
282-
Definitions,
258+
Definitions.add_def(
283259
(
284260
f"partially_traced_horizontal_two_site_density_matrices_triangular_left_"
285261
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{left_i}"
@@ -314,15 +290,7 @@ def partially_traced_horizontal_two_site_density_matrices_triangular(
314290
],
315291
],
316292
}
317-
Definitions._process_def(
318-
contraction_right,
319-
(
320-
f"partially_traced_horizontal_two_site_density_matrices_triangular_right_"
321-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{right_i}"
322-
),
323-
)
324-
setattr(
325-
Definitions,
293+
Definitions.add_def(
326294
(
327295
f"partially_traced_horizontal_two_site_density_matrices_triangular_right_"
328296
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{right_i}"
@@ -456,15 +424,7 @@ def partially_traced_diagonal_two_site_density_matrices_triangular(
456424
],
457425
],
458426
}
459-
Definitions._process_def(
460-
contraction_top,
461-
(
462-
f"partially_traced_diagonal_two_site_density_matrices_triangular_top_"
463-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
464-
),
465-
)
466-
setattr(
467-
Definitions,
427+
Definitions.add_def(
468428
(
469429
f"partially_traced_diagonal_two_site_density_matrices_triangular_top_"
470430
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{top_i}"
@@ -499,15 +459,7 @@ def partially_traced_diagonal_two_site_density_matrices_triangular(
499459
],
500460
],
501461
}
502-
Definitions._process_def(
503-
contraction_bottom,
504-
(
505-
f"partially_traced_diagonal_two_site_density_matrices_triangular_bottom_"
506-
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"
507-
),
508-
)
509-
setattr(
510-
Definitions,
462+
Definitions.add_def(
511463
(
512464
f"partially_traced_diagonal_two_site_density_matrices_triangular_bottom_"
513465
f"{real_physical_dimension}_{num_coarse_grained_physical_indices}_{bottom_i}"

0 commit comments

Comments
 (0)