Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion tket/src/serialize/pytket/decoder/wires.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use hugr::extension::prelude::{bool_t, qb_t};
use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::Value;
use hugr::std_extensions::arithmetic::float_types::{ConstF64, float64_type};
use hugr::types::Type;
use hugr::types::{Type, TypeEnum};
use hugr::{Hugr, IncomingPort, Node, Wire};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
Expand Down Expand Up @@ -715,6 +715,9 @@ impl WireTracker {
_ if ty == &bool_t() || ty == &bool_type() => {
self.initialize_bit_wire(builder, bit_args[0].clone())?
}
_ if matches!(ty.as_type_enum(), TypeEnum::Sum(sum) if sum.as_tuple().is_some()) => {
return self.find_tuple_wire(config, builder, ty, qubit_args, bit_args, params);
}
_ => {
return Err(PytketDecodeErrorInner::NoMatchingWire {
ty: ty.to_string(),
Expand Down Expand Up @@ -770,6 +773,95 @@ impl WireTracker {
}
}

/// Build a tuple wire from its tracked element wires when no matching
/// aggregate wire exists.
///
/// Pytket passes may preserve an opaque barrier while presenting its qubit
/// arguments in an order that does not match any aggregate tuple wire
/// already tracked from the original HUGR. In that case, we can rebuild the
/// tuple explicitly from the individual decoded wires.
fn find_tuple_wire(
&mut self,
config: &PytketDecoderConfig,
builder: &mut DFGBuilder<&mut Hugr>,
ty: &Type,
qubit_args: &mut &[TrackedQubit],
bit_args: &mut &[TrackedBit],
params: &mut &[LoadedParameter],
) -> Result<FoundWire, PytketDecodeError> {
let TypeEnum::Sum(sum) = ty.as_type_enum() else {
unreachable!("find_tuple_wire called with non-sum type");
};
let Some(tuple) = sum.as_tuple() else {
unreachable!("find_tuple_wire called with non-tuple sum type");
};

let mut tuple_qubits = *qubit_args;
let mut tuple_bits = *bit_args;
let mut tuple_params = *params;
let mut element_wires = Vec::with_capacity(tuple.len());
for elem_ty in tuple.iter() {
let elem_ty: Type = elem_ty.clone().try_into().map_err(|_| {
PytketDecodeErrorInner::NoMatchingWire {
ty: ty.to_string(),
qubit_args: qubit_args
.iter()
.map(|q| q.pytket_register().to_string())
.collect(),
bit_args: bit_args
.iter()
.map(|bit| bit.pytket_register().to_string())
.collect(),
}
.wrap()
})?;
let FoundWire::Register(wire) = self.find_typed_wire(
config,
builder,
&elem_ty,
&mut tuple_qubits,
&mut tuple_bits,
&mut tuple_params,
None,
)?
else {
return Err(PytketDecodeErrorInner::NoMatchingWire {
ty: ty.to_string(),
qubit_args: qubit_args
.iter()
.map(|q| q.pytket_register().to_string())
.collect(),
bit_args: bit_args
.iter()
.map(|bit| bit.pytket_register().to_string())
.collect(),
}
.wrap());
};
element_wires.push(wire.wire());
}

let reg_count = config
.type_to_pytket(ty)
.expect("tuple fallback requires a pytket-representable type");
let wire_qubits = qubit_args
.iter()
.take(reg_count.qubits)
.cloned()
.collect_vec();
let wire_bits = bit_args.iter().take(reg_count.bits).cloned().collect_vec();
let tuple_wire = builder
.make_tuple(element_wires)
.map_err(PytketDecodeError::custom)?;
self.track_wire(tuple_wire, Arc::new(ty.clone()), wire_qubits, wire_bits)?;

*qubit_args = tuple_qubits;
*bit_args = tuple_bits;
*params = tuple_params;

Ok(FoundWire::Register(self.wires[&tuple_wire].clone()))
}

/// Returns a new [TrackedWires] set for a list of [`TrackedQubit`]s,
/// [`TrackedBit`]s, and [`LoadedParameter`]s following the required types.
///
Expand Down
29 changes: 27 additions & 2 deletions tket/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod value_tracker;
use hugr::core::HugrNode;
use hugr_core::hugr::internal::PortgraphNodeMap;
use tket_json_rs::clexpr::InputClRegister;
use tket_json_rs::clexpr::operator::{ClArgument, ClOperator, ClTerminal, ClVariable};
use tket_json_rs::opbox::BoxID;
pub use value_tracker::{
TrackedBit, TrackedParam, TrackedQubit, TrackedValue, TrackedValues, ValueTracker,
Expand Down Expand Up @@ -1465,10 +1466,15 @@ pub fn make_tk1_classical_expression(
bit_count: usize,
output_bits: &[u32],
registers: &[InputClRegister],
expression: tket_json_rs::clexpr::operator::ClOperator,
expression: ClOperator,
) -> tket_json_rs::circuit_json::Operation {
let mut bit_vars = Vec::new();
collect_clexpr_bit_vars(&expression, &mut bit_vars);
bit_vars.sort_unstable();
bit_vars.dedup();

let mut clexpr = tket_json_rs::clexpr::ClExpr::default();
clexpr.bit_posn = (0..bit_count as u32).map(|i| (i, i)).collect();
clexpr.bit_posn = bit_vars.into_iter().map(|i| (i, i)).collect();
clexpr.reg_posn = registers.to_vec();
clexpr.output_posn = tket_json_rs::clexpr::ClRegisterBits(output_bits.to_vec());
clexpr.expr = expression;
Expand All @@ -1482,3 +1488,22 @@ pub fn make_tk1_classical_expression(
op.classical_expr = Some(clexpr);
op
}

/// Collect the local bit variables referenced by a classical expression.
///
/// `ClExpr::bit_posn` describes variables used by the expression tree. Fresh
/// output-only bits are listed in `output_posn`, but must not be declared as
/// expression variables.
fn collect_clexpr_bit_vars(expression: &ClOperator, bit_vars: &mut Vec<u32>) {
for arg in &expression.args {
match arg {
ClArgument::Terminal(ClTerminal::Variable(ClVariable::Bit { index })) => {
bit_vars.push(*index);
}
ClArgument::Expression(expression) => {
collect_clexpr_bit_vars(expression, bit_vars);
}
_ => {}
}
}
}
11 changes: 11 additions & 0 deletions tket/src/serialize/pytket/extension/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,17 @@ mod tests {
// Args layout for a BoolOp: [input_bits..., output_bit]
// The output bit must use a fresh register, not one of the input bits.
assert_eq!(clexpr_cmd.args.len(), num_inputs + 1);
let clexpr = clexpr_cmd
.op
.classical_expr
.as_ref()
.expect("ClExpr command must include expression data");
assert_eq!(
clexpr.bit_posn,
(0..num_inputs as u32).map(|i| (i, i)).collect_vec()
);
assert_eq!(clexpr.output_posn.0, vec![num_inputs as u32]);

let input_args = &clexpr_cmd.args[..num_inputs];
let output_args = &clexpr_cmd.args[num_inputs..];
for output_arg in output_args {
Expand Down
11 changes: 11 additions & 0 deletions tket/src/serialize/pytket/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ impl<H: HugrView> PytketEmitter<H> for PreludeEmitter {
return self.tuple_op_to_pytket(node, op, &tuple_op, hugr, encoder);
};
if let Ok(_barrier) = BarrierDef::from_extension_op(op) {
// Check if the barrier has encodable types in its signature.
// If not, fallback to marking it as unsupported.
if hugr.signature(node).is_none_or(|sig| {
sig.input()
.iter()
.chain(sig.output().iter())
.any(|ty| encoder.config().type_to_pytket(ty).is_none())
}) {
return Ok(EncodeStatus::Unsupported);
}

encoder.emit_node(
PytketOptype::Barrier,
node,
Expand Down
68 changes: 68 additions & 0 deletions tket/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,36 @@ fn circ_complex_param_type() -> Hugr {
h.finish_hugr_with_outputs([float_tuple]).unwrap()
}

/// A prelude barrier carrying one unsupported value next to a qubit.
///
/// The barrier must be encoded as an opaque subgraph; trying to emit it as a
/// native pytket barrier would require pytket register values for the
/// unsupported tuple wire.
#[fixture]
fn circ_barrier_with_unsupported_value() -> Hugr {
let tuple_float_t = Type::from(SumType::new_tuple(vec![float64_type()]));
let input_t = vec![qb_t()];
let output_t = vec![qb_t()];
let mut h = FunctionBuilder::new(
"barrier_with_unsupported_value",
Signature::new(input_t, output_t),
)
.unwrap();
let [q] = h.input_wires_arr();

let float = h.add_load_value(ConstF64::new(1.0));
let tuple = h.make_tuple([float]).unwrap();
let [q, _tuple] = h
.add_dataflow_op(
hugr::extension::prelude::Barrier::new([qb_t(), tuple_float_t]),
[q, tuple],
)
.unwrap()
.outputs_arr();

h.finish_hugr_with_outputs([q]).unwrap()
}

/// A circuit with an unsupported subgraph whose first output is not exposed as
/// a pytket parameter, followed by a float output used by a supported gate.
///
Expand Down Expand Up @@ -961,6 +991,29 @@ fn json_file_roundtrip(#[case] circ: impl AsRef<std::path::Path>) {
compare_serial_circs(&ser, &reser);
}

#[test]
fn decode_tuple_output_from_permuted_barrier_args() {
let ser: circuit_json::SerialCircuit = serde_json::from_str(
r#"{
"phase": "0",
"bits": [],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args": [["q", [1]], ["q", [0]]], "op": {"type": "Barrier"}}
],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#,
)
.unwrap();

let tuple_qubits = Type::from(SumType::new_tuple(vec![qb_t(), qb_t()]));
let hugr = ser
.decode(DecodeOptions::new().with_signature(Signature::new(vec![], vec![tuple_qubits])))
.unwrap();

hugr.validate().unwrap();
}

/// Test parameter to select which decoders/encoders to enable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitRoundtripTestConfig {
Expand Down Expand Up @@ -1111,6 +1164,21 @@ fn reject_standalone_complex_subgraphs(#[case] hugr: Hugr) {
);
}

#[rstest]
fn unsupported_prelude_barrier_is_encoded_as_opaque_subgraph(
circ_barrier_with_unsupported_value: Hugr,
) {
let ser =
SerialCircuit::encode(&circ_barrier_with_unsupported_value, EncodeOptions::new()).unwrap();

validate_serial_circ(&ser);
assert!(
ser.commands
.iter()
.any(|cmd| { cmd.op.op_type == optype::OpType::Barrier && cmd.op.data.is_some() })
);
}

/// Test that modifying the hugr before reassembling an EncodedCircuit fails.
#[rstest]
fn fail_on_modified_hugr(circ_tk1_ops: Hugr) {
Expand Down
Loading