Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/unroll_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=invalid-name
# pylint: disable=invalid-name, cyclic-import

"""
Script demonstrating how to unroll a QASM 3 program using pyqasm.
Expand Down
161 changes: 161 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import functools
from abc import ABC, abstractmethod
import re
from collections import Counter
from copy import deepcopy
from typing import Optional
Expand All @@ -36,6 +37,7 @@
from pyqasm.visitor import QasmVisitor, ScopeManager



def track_user_operation(func):
"""Decorator to track user operations on a QasmModule."""

Expand Down Expand Up @@ -761,3 +763,162 @@ def accept(self, visitor):
Args:
visitor (QasmVisitor): The visitor to accept
"""


def merge(self, other: "QasmModule", device_qubits: Optional[int] = None) -> "QasmModule":
"""Merge this module with another module into a single consolidated module.

Notes:
- Both modules are unrolled with consolidated qubit registers prior to merging.
- The resulting module has a single declaration: ``qubit[<total>] __PYQASM_QUBITS__``.
- All quantum operations from the second module are appended after the first, with
qubit indices offset by the size of the first module.

Args:
other (QasmModule): The module to merge with the current module.
device_qubits (int | None): Optional device qubit budget to use during unrolling.

Returns:
QasmModule: A new Qasm3Module representing the merged program.
"""

if not isinstance(other, QasmModule):
raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}")

# Normalize both modules to QASM3 form (without mutating originals)
from pyqasm.modules.qasm2 import Qasm2Module # pylint: disable=import-outside-toplevel
from pyqasm.modules.qasm3 import Qasm3Module # pylint: disable=import-outside-toplevel
left_mod = self.to_qasm3(as_str=False) if isinstance(self, Qasm2Module) else self.copy()
right_mod = other.to_qasm3(as_str=False) if isinstance(other, Qasm2Module) else other.copy()

# Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__
unroll_kwargs: dict[str, object] = {"consolidate_qubits": True}
if device_qubits is not None:
unroll_kwargs["device_qubits"] = device_qubits

left_mod.unroll(**unroll_kwargs)
right_mod.unroll(**unroll_kwargs)

# Determine sizes after consolidation
left_qubits = left_mod.num_qubits
right_qubits = right_mod.num_qubits
total_qubits = left_qubits + right_qubits

# Build a new Program. We'll add includes (unique) first, then declaration and ops
merged_program = Program(statements=[], version="3.0")

# gets unique include filenames from both modules
# added this because we get duplicate File 'stdgates.inc' errors
include_names: list[str] = []
for module in (left_mod, right_mod):
for stmt in module.unrolled_ast.statements:
if isinstance(stmt, qasm3_ast.Include):
if stmt.filename not in include_names:
include_names.append(stmt.filename)
for inc_name in include_names:
merged_program.statements.append(qasm3_ast.Include(filename=inc_name))

# single consolidated qubit declaration
merged_qubit_decl = qasm3_ast.QubitDeclaration(
size=qasm3_ast.IntegerLiteral(value=total_qubits),
qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"),
)
merged_program.statements.append(merged_qubit_decl)

# Append left (self) statements, skipping its consolidated qubit declaration
for stmt in left_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
merged_program.statements.append(deepcopy(stmt))

# Offsets indices inside a statement by a fixed amount to make sure we merge correctly
def _offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int):
if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement):
# Offset measured qubit source
bit = stmt.measure.qubit
if isinstance(bit, qasm3_ast.IndexedIdentifier):
for group in bit.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
# target is classical; leave untouched
return

if isinstance(stmt, qasm3_ast.QuantumGate):
# Offset all qubit operands
for q in stmt.qubits:
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumReset):
q = stmt.qubits
if isinstance(q, qasm3_ast.IndexedIdentifier):
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumBarrier):
# Barrier can be represented with IndexedIdentifier or a string slice on Identifier
qubits = stmt.qubits
if len(qubits) == 0:
return
first = qubits[0]
if isinstance(first, qasm3_ast.IndexedIdentifier):
for group in first.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
elif isinstance(first, qasm3_ast.Identifier):
# Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E]
name = first.name
if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"):
slice_str = name[len("__PYQASM_QUBITS__"):]
# Parse slice forms [S:E], [:E], or [S:] and capture optional start/end integers
m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str)
if m:
start_s, end_s = m.group(1), m.group(2)
if start_s is None and end_s is not None:
# [:E]
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[:{end_v}]"
elif start_s is not None and end_s is None:
# [S:]
start_v = int(start_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:]"
elif start_s is not None and end_s is not None:
# [S:E]
start_v = int(start_s) + offset
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]"
return

# Append statements with index offset, skipping its qubit declaration and include statements
for stmt in right_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
stmt_copy = deepcopy(stmt)
_offset_statement_qubits(stmt_copy, left_qubits)
merged_program.statements.append(stmt_copy)

# Build merged module
merged_module = Qasm3Module(
name=f"{left_mod.name}_merged_{right_mod.name}",
program=merged_program,
)

# inputs already unrolled, we can set the unrolled AST directly
merged_module.unrolled_ast = Program(
statements=list(merged_program.statements),
version="3.0",
)

# Combine metadata/history in a straightforward manner
merged_module._external_gates = list(
{*left_mod._external_gates, *right_mod._external_gates}
)
merged_module._user_operations = list(left_mod.history) + list(right_mod.history)
merged_module._user_operations.append(f"merge(other={right_mod.name})")
merged_module.validate()

return merged_module
118 changes: 118 additions & 0 deletions tests/qasm3/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for QasmModule.merge().
"""

from pyqasm.entrypoint import loads
from pyqasm.modules import QasmModule


def _qasm3(qasm: str) -> QasmModule:
return loads(qasm)


def test_merge_basic_gates_and_offsets():
qasm_a = (
"OPENQASM 3.0;\n"
"include \"stdgates.inc\";\n"
"qubit[2] q;\n"
"x q[0];\n"
"cx q[0], q[1];\n"
)
qasm_b = (
"OPENQASM 3.0;\n"
"include \"stdgates.inc\";\n"
"qubit[3] r;\n"
"h r[0];\n"
"cx r[1], r[2];\n"
)

mod_a = _qasm3(qasm_a)
mod_b = _qasm3(qasm_b)

merged = mod_a.merge(mod_b)

# Unrolled representation should have a single consolidated qubit declaration of size 5
text = str(merged)
assert "qubit[5] __PYQASM_QUBITS__;" in text

lines = [l.strip() for l in text.splitlines() if l.strip()]
# Keep only gate lines for comparison; skip version/includes/declarations
gate_lines = [
l
for l in lines
if l[0].isalpha()
and not l.startswith("include")
and not l.startswith("OPENQASM")
and not l.startswith("qubit")
]
assert gate_lines[0].startswith("x __PYQASM_QUBITS__[0]")
assert gate_lines[1].startswith("cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[1]")
assert any(l.startswith("h __PYQASM_QUBITS__[2]") for l in gate_lines)
assert any(l.startswith("cx __PYQASM_QUBITS__[3], __PYQASM_QUBITS__[4]") for l in gate_lines)


def test_merge_with_measurements_and_barriers():
# Module A: 1 qubit + classical 1; has barrier and measure
qasm_a = (
"OPENQASM 3.0;\n"
"include \"stdgates.inc\";\n"
"qubit[1] qa; bit[1] ca;\n"
"h qa[0];\n"
"barrier qa;\n"
"ca[0] = measure qa[0];\n"
)
# Module B: 2 qubits + classical 2
qasm_b = (
"OPENQASM 3.0;\n"
"include \"stdgates.inc\";\n"
"qubit[2] qb; bit[2] cb;\n"
"x qb[1];\n"
"cb[1] = measure qb[1];\n"
)

mod_a = _qasm3(qasm_a)
mod_b = _qasm3(qasm_b)

merged = mod_a.merge(mod_b)
merged_text = str(merged)

assert "qubit[3] __PYQASM_QUBITS__;" in merged_text
assert "measure __PYQASM_QUBITS__[2];" in merged_text
assert "barrier __PYQASM_QUBITS__" in merged_text


def test_merge_qasm2_with_qasm3():
qasm2 = (
"OPENQASM 2.0;\n"
"include \"qelib1.inc\";\n"
"qreg q[1];\n"
"h q[0];\n"
)
qasm3 = (
"OPENQASM 3.0;\n"
"include \"stdgates.inc\";\n"
"qubit[2] r;\n"
"x r[0];\n"
)

mod2 = loads(qasm2)
mod3 = loads(qasm3)

merged = mod2.merge(mod3)
text = str(merged)
assert "qubit[3] __PYQASM_QUBITS__;" in text
assert "x __PYQASM_QUBITS__[1];" in text
Loading