diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 31493a0548..d3ff039eb4 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -481,8 +481,8 @@ pub(crate) mod test { use crate::metadata::Metadata; use crate::ops::{FuncDecl, FuncDefn, OpParent, OpTag, OpTrait, Value, handle::NodeHandle}; use crate::std_extensions::logic::test::and_op; - use crate::types::type_param::TypeParam; - use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; + use crate::types::type_param::{TermTypeError, TypeParam}; + use crate::types::{EdgeKind, FuncValueType, Signature, Term, Type, TypeBound, TypeRowRV}; use crate::utils::test_quantum_extension::h_gate; use crate::{Wire, builder::test::n_identity, type_row}; @@ -926,7 +926,7 @@ pub(crate) mod test { #[test] fn no_outer_row_variables() -> Result<(), BuildError> { let e = crate::hugr::validate::test::extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Copyable); + let rv = TypeRowRV::new_var_use(0, TypeBound::Copyable); // Can *declare* a function that takes a function-value of unknown #args FunctionBuilder::new( "bad_eval", @@ -935,23 +935,34 @@ pub(crate) mod test { Signature::new( [Type::new_function(FuncValueType::new( [usize_t()], - [tv.clone()], + rv.clone(), ))], [], ), ), )?; - + let rv: Term = rv.into(); // But cannot eval it... + let ev = e.instantiate_extension_op("eval", [Term::new_list([usize_t()]), rv.clone()]); + assert_eq!( + ev, + Err(SignatureError::TypeArgMismatch( + TermTypeError::InvalidValue(Box::new(rv.clone())) + )) + ); + let ev = e.instantiate_extension_op( "eval", - [vec![usize_t().into()].into(), vec![tv.into()].into()], + [Term::new_list([usize_t()]), Term::new_list([rv.clone()])], ); assert_eq!( ev, - Err(SignatureError::RowVarWhereTypeExpected { - var: RowVariable(0, TypeBound::Copyable) - }) + Err(SignatureError::TypeArgMismatch( + TermTypeError::TypeMismatch { + term: Box::new(rv), + type_: Box::new(TypeBound::Linear.into()) + } + )) ); Ok(()) } diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 5131378071..ffd82997ce 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -306,7 +306,7 @@ mod test { use cool_asserts::assert_matches; use crate::builder::test::dfg_calling_defn_decl; - use crate::builder::{Dataflow, DataflowSubContainer, test::n_identity}; + use crate::builder::{Dataflow, DataflowSubContainer}; use crate::extension::prelude::usize_t; use crate::{hugr::linking::NodeLinkingDirective, ops::OpType, types::Signature}; @@ -331,29 +331,6 @@ mod test { Ok(()) } - #[test] - #[ignore] // https://github.com/Quantinuum/hugr/issues/2828 - fn simple_alias() -> Result<(), BuildError> { - let build_result = { - let mut module_builder = ModuleBuilder::new(); - - let qubit_state_type = - module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?; - - let f_build = module_builder.define_function( - "main", - Signature::new( - vec![qubit_state_type.get_alias_type()], - vec![qubit_state_type.get_alias_type()], - ), - )?; - n_identity(f_build)?; - module_builder.finish_hugr() - }; - assert_matches!(build_result, Ok(_)); - Ok(()) - } - #[test] fn builder_from_existing() -> Result<(), BuildError> { let hugr = Hugr::new(); diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index d76bccbc8b..7266322a74 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -2,7 +2,7 @@ use crate::Visibility; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; -use crate::types::type_param::Term; +use crate::types::{FuncTypeBase, PolyFuncTypeBase, TypeRowLike}; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -14,10 +14,8 @@ use crate::{ arithmetic::{float_types::ConstF64, int_types::ConstInt}, collections::array::ArrayValue, }, - types::{ - CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, - }, + types::type_param::{Term, TermVar}, + types::{CustomType, EdgeKind, SumType, Type, TypeBound, TypeRow}, }; use hugr_model::v0::bumpalo; @@ -342,10 +340,12 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); + let sig = func.signature(); let symbol = this.export_poly_func_type( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig, + Self::export_type_row, ); regions = this.bump.alloc_slice_copy(&[this.export_dfg( node, @@ -358,11 +358,12 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let symbol_name = this.export_func_name(node, &mut meta); - + let sig = func.signature(); let symbol = this.export_poly_func_type( symbol_name, Some(func.visibility().clone().into()), - func.signature(), + sig, + Self::export_type_row, ); table::Operation::DeclareFunc(symbol) }), @@ -507,7 +508,7 @@ impl<'a> Context<'a> { Some(signature) => { let num_inputs = signature.input_types().len(); let num_outputs = signature.output_types().len(); - let signature = self.export_func_type(signature); + let signature = self.export_func_type(signature, Self::export_type_row); (Some(signature), num_inputs, num_outputs) } None => (None, 0, 0), @@ -559,7 +560,9 @@ impl<'a> Context<'a> { let symbol = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); - this.export_poly_func_type(name, None, poly_func_type) + this.export_poly_func_type(name, None, poly_func_type, |this, trv| { + this.export_term(trv, None) + }) }); let meta = { @@ -816,11 +819,12 @@ impl<'a> Context<'a> { } /// Exports a polymorphic function type. - pub fn export_poly_func_type( + pub fn export_poly_func_type( &mut self, name: &'a str, visibility: Option, - t: &PolyFuncTypeBase, + t: &PolyFuncTypeBase, + export_io: impl FnMut(&mut Self, &T) -> table::TermId, ) -> &'a table::Symbol<'a> { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); let scope = self @@ -835,7 +839,7 @@ impl<'a> Context<'a> { } let constraints = self.bump.alloc_slice_copy(&self.local_constraints); - let body = self.export_func_type(t.body()); + let body = self.export_func_type(t.body(), export_io); self.bump.alloc(table::Symbol { visibility, @@ -846,30 +850,18 @@ impl<'a> Context<'a> { }) } - pub fn export_type(&mut self, t: &TypeBase) -> table::TermId { - self.export_type_enum(t.as_type_enum()) - } - - pub fn export_type_enum(&mut self, t: &TypeEnum) -> table::TermId { - match t { - TypeEnum::Extension(ext) => self.export_custom_type(ext), - TypeEnum::Alias(alias) => { - let symbol = self.resolve_symbol(self.bump.alloc_str(alias.name())); - self.make_term(table::Term::Apply(symbol, &[])) - } - TypeEnum::Function(func) => self.export_func_type(func), - TypeEnum::Variable(index, _) => { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, *index as _))) - } - TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), - TypeEnum::Sum(sum) => self.export_sum_type(sum), - } + pub fn export_type(&mut self, t: &Type) -> table::TermId { + self.export_term(t, None) } - pub fn export_func_type(&mut self, t: &FuncTypeBase) -> table::TermId { - let inputs = self.export_type_row(t.input()); - let outputs = self.export_type_row(t.output()); + pub fn export_func_type( + &mut self, + t: &FuncTypeBase, + mut export_io: impl FnMut(&mut Self, &T) -> table::TermId, + ) -> table::TermId { + let inputs = export_io(self, t.input()); + let outputs = export_io(self, t.output()); + // To use CORE_FN here, the input/output should each be a core List or ListConcat self.make_term_apply(model::CORE_FN, &[inputs, outputs]) } @@ -888,11 +880,6 @@ impl<'a> Context<'a> { self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } - pub fn export_row_var(&mut self, t: &RowVariable) -> table::TermId { - let node = self.local_scope.expect("local variable out of scope"); - self.make_term(table::Term::Var(table::VarId(node, t.0 as _))) - } - pub fn export_sum_variants(&mut self, t: &SumType) -> table::TermId { match t { SumType::Unit { size } => { @@ -905,7 +892,7 @@ impl<'a> Context<'a> { SumType::General { rows } => { let parts = self.bump.alloc_slice_fill_iter( rows.iter() - .map(|row| table::SeqPart::Item(self.export_type_row(row))), + .map(|row| table::SeqPart::Item(self.export_term(row, None))), ); self.make_term(table::Term::List(parts)) } @@ -918,27 +905,20 @@ impl<'a> Context<'a> { } #[inline] - pub fn export_type_row(&mut self, row: &TypeRowBase) -> table::TermId { + pub fn export_type_row(&mut self, row: &TypeRow) -> table::TermId { self.export_type_row_with_tail(row, None) } - pub fn export_type_row_with_tail( + pub fn export_type_row_with_tail( &mut self, - row: &TypeRowBase, + row: &TypeRow, tail: Option, ) -> table::TermId { let mut parts = BumpVec::with_capacity_in(row.len() + usize::from(tail.is_some()), self.bump); for t in row.iter() { - match t.as_type_enum() { - TypeEnum::RowVar(var) => { - parts.push(table::SeqPart::Splice(self.export_row_var(var.as_rv()))); - } - _ => { - parts.push(table::SeqPart::Item(self.export_type(t))); - } - } + parts.push(table::SeqPart::Item(self.export_type(t))); } if let Some(tail) = tail { @@ -982,7 +962,12 @@ impl<'a> Context<'a> { let item_types = self.export_term(item_types, None); self.make_term_apply(model::CORE_TUPLE_TYPE, &[item_types]) } - Term::Runtime(ty) => self.export_type(ty), + Term::RuntimeExtension(ext) => self.export_custom_type(ext), + Term::RuntimeFunction(func) => { + self.export_func_type(func, |this, trv| this.export_term(trv, None)) + } + Term::RuntimeSum(sum) => self.export_sum_type(sum), + Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 3ba9650b25..504135ce8f 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -125,7 +125,6 @@ use thiserror::Error; use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; -use crate::types::RowVariable; use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; @@ -518,9 +517,6 @@ pub enum SignatureError { /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] FreeTypeVar { idx: usize, num_decls: usize }, - /// A row variable was found outside of a variable-length row - #[error("Expected a single type, but found row variable {var}")] - RowVarWhereTypeExpected { var: RowVariable }, /// The result of the type application stored in a [Call] /// is not what we get by applying the type-args to the polymorphic function /// diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 0cd7f508fd..540bc2dea2 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -614,7 +614,7 @@ pub(super) mod test { use crate::package::Package; use crate::std_extensions::collections::list; use crate::types::type_param::{TermTypeError, TypeParam}; - use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; + use crate::types::{PolyFuncTypeRV, Signature, Term, Type, TypeArg, TypeBound}; use crate::{Extension, const_extension_ids}; const_extension_ids! { @@ -862,7 +862,7 @@ pub(super) mod test { def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv]))); // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + let arg: TypeArg = Term::new_row_var_use(0, TypeBound::Copyable); assert_eq!( def.compute_signature(std::slice::from_ref(&arg)), Err(SignatureError::TypeArgMismatch( diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index cfd09b30d1..0f793247f9 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -18,7 +18,7 @@ use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, - TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, + TypeBound, TypeName, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; use crate::{Extension, type_row}; @@ -107,23 +107,22 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { extension_ref, ) .unwrap(); + let panic_exit_sig = PolyFuncTypeRV::new( + [ + TypeParam::new_list_type(TypeBound::Linear), + TypeParam::new_list_type(TypeBound::Linear), + ], + FuncValueType::new( + TypeRowRV::from([Type::new_extension(error_type.clone())]) + .concat(TypeRowRV::new_var_use(0, TypeBound::Linear)), + TypeRowRV::new_var_use(1, TypeBound::Linear), + ), + ); prelude .add_op( PANIC_OP_ID, "Panic with input error".to_string(), - PolyFuncTypeRV::new( - [ - TypeParam::new_list_type(TypeBound::Linear), - TypeParam::new_list_type(TypeBound::Linear), - ], - FuncValueType::new( - vec![ - TypeRV::new_extension(error_type.clone()), - TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], - ), - ), + panic_exit_sig.clone(), extension_ref, ) .unwrap(); @@ -131,19 +130,7 @@ pub static PRELUDE: LazyLock> = LazyLock::new(|| { .add_op( EXIT_OP_ID, "Exit with input error".to_string(), - PolyFuncTypeRV::new( - [ - TypeParam::new_list_type(TypeBound::Linear), - TypeParam::new_list_type(TypeBound::Linear), - ], - FuncValueType::new( - vec![ - TypeRV::new_extension(error_type), - TypeRV::new_row_var_use(0, TypeBound::Linear), - ], - vec![TypeRV::new_row_var_use(1, TypeBound::Linear)], - ), - ), + panic_exit_sig, extension_ref, ) .unwrap(); @@ -334,7 +321,7 @@ pub fn either_type(ty_left: impl Into, ty_right: impl Into /// A constant optional value with a given value. /// -/// See [`option_type`]. +/// See [`SumType::new_option`]. #[must_use] pub fn const_some(value: Value) -> Value { const_some_tuple([value]) @@ -344,7 +331,7 @@ pub fn const_some(value: Value) -> Value { /// /// For single values, use [`const_some`]. /// -/// See [`option_type`]. +/// See [`SumType::new_option`]. pub fn const_some_tuple(values: impl IntoIterator) -> Value { const_right_tuple(TypeRow::new(), values) } @@ -375,11 +362,7 @@ pub fn const_left_tuple( ty_right: impl Into, ) -> Value { let values = values.into_iter().collect_vec(); - let types: TypeRowRV = values - .iter() - .map(|v| TypeRV::from(v.get_type())) - .collect_vec() - .into(); + let types: TypeRowRV = values.iter().map(Value::get_type).collect(); let typ = either_type(types, ty_right); Value::sum(0, values, typ).unwrap() } @@ -403,11 +386,7 @@ pub fn const_right_tuple( values: impl IntoIterator, ) -> Value { let values = values.into_iter().collect_vec(); - let types: TypeRowRV = values - .iter() - .map(|v| TypeRV::from(v.get_type())) - .collect_vec() - .into(); + let types: TypeRowRV = values.iter().map(Value::get_type).collect(); let typ = either_type(ty_left, types); Value::sum(1, values, typ).unwrap() } @@ -642,16 +621,16 @@ impl MakeOpDef for TupleOpDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - let rv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); + let rv = TypeRowRV::new_var_use(0, TypeBound::Linear); + let tuple_type = Type::new_tuple(rv.clone()); let param = TypeParam::new_list_type(TypeBound::Linear); match self { TupleOpDef::MakeTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([rv], [tuple_type])) + PolyFuncTypeRV::new([param], FuncValueType::new(rv, [tuple_type])) } TupleOpDef::UnpackTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], [rv])) + PolyFuncTypeRV::new([param], FuncValueType::new([tuple_type], rv)) } } .into() @@ -711,18 +690,11 @@ impl MakeExtensionOp for MakeTuple { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + Ok(Self(elems.clone().try_into()?)) } fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + vec![self.0.clone().into()] } } @@ -766,18 +738,11 @@ impl MakeExtensionOp for UnpackTuple { let [Term::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - Term::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self(tys?.into())) + Ok(Self(elems.clone().try_into()?)) } fn type_args(&self) -> Vec { - vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] + vec![self.0.clone().into()] } } @@ -881,10 +846,10 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Runtime(ty)] = ext_op.args() else { + let [t] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - Ok(Self(ty.clone())) + Ok(Self(t.clone().try_into()?)) } fn type_args(&self) -> Vec { @@ -929,7 +894,7 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( vec![TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)]), + FuncValueType::new_endo(TypeRowRV::new_var_use(0, TypeBound::Linear)), ) .into() } @@ -990,22 +955,12 @@ impl MakeExtensionOp for Barrier { let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; - let tys: Result, _> = elems - .iter() - .map(|a| match a { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs), - }) - .collect(); - Ok(Self { - type_row: tys?.into(), - }) + let type_row = elems.clone().try_into()?; + Ok(Self { type_row }) } fn type_args(&self) -> Vec { - vec![TypeArg::new_list( - self.type_row.iter().map(|t| t.clone().into()), - )] + vec![self.type_row.clone().into()] } } @@ -1150,7 +1105,7 @@ mod test { let err = b.add_load_value(error_val); let op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) + .instantiate_extension_op(&EXIT_OP_ID, [Term::EMPTY_LIST, Term::EMPTY_LIST]) .unwrap(); b.add_dataflow_op(op, [err]).unwrap(); @@ -1165,7 +1120,7 @@ mod test { .instantiate_extension_op(&MAKE_ERROR_OP_ID, []) .unwrap(); let panic_op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) + .instantiate_extension_op(&EXIT_OP_ID, [Term::EMPTY_LIST, Term::EMPTY_LIST]) .unwrap(); let mut b = diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index f73b5ce600..4b78b5734e 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -3,11 +3,14 @@ use std::iter; use crate::{ Wire, builder::{BuildError, BuildHandle, Dataflow, DataflowSubContainer, SubContainer}, - extension::prelude::{ConstError, PANIC_OP_ID}, + extension::{ + SignatureError, + prelude::{ConstError, PANIC_OP_ID}, + }, ops::handle::DataflowOpID, types::{SumType, Type, TypeArg, TypeRow}, }; -use itertools::{Itertools as _, zip_eq}; +use itertools::zip_eq; use super::PRELUDE; @@ -21,16 +24,8 @@ pub trait UnwrapBuilder: Dataflow { inputs: impl IntoIterator, ) -> Result, BuildError> { let (input_wires, input_types): (Vec<_>, Vec<_>) = inputs.into_iter().unzip(); - let input_arg: TypeArg = input_types - .into_iter() - .map(>::from) - .collect_vec() - .into(); - let output_arg: TypeArg = output_row - .into_iter() - .map(>::from) - .collect_vec() - .into(); + let input_arg = TypeArg::new_list(input_types); + let output_arg = TypeArg::new_list(output_row); let op = PRELUDE.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?; let err = self.add_load_value(err); self.add_dataflow_op(op, iter::once(err).chain(input_wires)) @@ -74,7 +69,8 @@ pub trait UnwrapBuilder: Dataflow { let tr_rv = sum_type.get_variant(i).unwrap().to_owned(); TypeRow::try_from(tr_rv) }) - .collect::>()?; + .collect::>() + .map_err(SignatureError::from)?; // TODO don't panic if tag >= num_variants let output_row = variants.get(tag).unwrap(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 9fa5ad0772..3f852fd789 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -25,7 +25,7 @@ mod weak_registry; pub use weak_registry::WeakExtensionRegistry; pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; -pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; +pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_term_exts}; pub(crate) use types_mut::resolve_op_types_extensions; use types_mut::{ resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, @@ -39,11 +39,11 @@ use crate::core::HugrNode; use crate::ops::constant::ValueName; use crate::ops::custom::OpaqueOpError; use crate::ops::{NamedOp, OpName, OpType, Value}; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, TypeArg, TypeBase, TypeName}; +use crate::types::{CustomType, Signature, Type, TypeArg, TypeName}; /// Update all weak Extension pointers inside a type. -pub fn resolve_type_extensions( - typ: &mut TypeBase, +pub fn resolve_type_extensions( + typ: &mut Type, extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); @@ -257,8 +257,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_signature( - signature: &FuncTypeBase, + pub fn dropped_signature( + signature: &Signature, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedSignatureExtensions { @@ -268,8 +268,8 @@ impl ExtensionCollectionError { } /// Create a new error when signature extensions have been dropped. - pub fn dropped_type( - typ: &TypeBase, + pub fn dropped_type( + typ: &Type, missing_extension: impl IntoIterator, ) -> Self { Self::DroppedTypeExtensions { diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 01fdf109ac..edae047f48 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -7,12 +7,12 @@ use std::mem; use std::sync::Arc; +use crate::extension::resolution::types::collect_func_type_exts; use crate::extension::{ Extension, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, TypeDef, }; -use super::types::collect_signature_exts; -use super::types_mut::resolve_signature_exts; +use super::types_mut::resolve_func_type_exts; use super::{ExtensionCollectionError, ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -76,7 +76,7 @@ fn collect_extension_deps( for (_, op_def) in extension.operations() { if let Some(signature) = op_def.signature_func().poly_func_type() { let mut local_missing = ExtensionSet::new(); - collect_signature_exts(signature.body(), &mut used, &mut local_missing); + collect_func_type_exts(signature.body(), &mut used, &mut local_missing); for ext in local_missing { missing.insert(ext); } @@ -207,5 +207,5 @@ pub(super) fn resolve_signature_func_exts( return Ok(()); } }; - resolve_signature_exts(None, signature_body, extensions, used_extensions) + resolve_func_type_exts(None, signature_body, extensions, used_extensions) } diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index ceb8590f6d..040898aa55 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -10,8 +10,7 @@ use super::{ExtensionCollectionError, WeakExtensionRegistry}; use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{FuncValueType, Signature, SumType, Term, TypeRow}; /// Collects every extension used to define the types in an operation. /// @@ -59,7 +58,7 @@ pub(crate) fn collect_op_types_extensions( } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), - OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing), + OpType::LoadConstant(lc) => collect_term_exts(&lc.datatype, &mut used, &mut missing), OpType::LoadFunction(lf) => { collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); @@ -121,7 +120,7 @@ pub(crate) fn collect_op_types_extensions( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a signature. +/// Collect the Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// # Attributes /// @@ -129,8 +128,8 @@ pub(crate) fn collect_op_types_extensions( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_signature_exts( - signature: &FuncTypeBase, +pub(crate) fn collect_signature_exts( + signature: &Signature, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { @@ -138,6 +137,23 @@ pub(crate) fn collect_signature_exts( collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } +/// Collect the Extension pointers in the [`CustomType`]s inside a [FuncValueType]. +/// +/// # Attributes +/// +/// - `func_ty`: The function type to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +pub(crate) fn collect_func_type_exts( + func_ty: &FuncValueType, + used_extensions: &mut WeakExtensionRegistry, + missing_extensions: &mut ExtensionSet, +) { + collect_term_exts(&func_ty.input, used_extensions, missing_extensions); + collect_term_exts(&func_ty.output, used_extensions, missing_extensions); +} + /// Collect the Extension pointers in the [`CustomType`]s inside a type row. /// /// # Attributes @@ -146,31 +162,31 @@ pub(crate) fn collect_signature_exts( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -fn collect_type_row_exts( - row: &TypeRowBase, +fn collect_type_row_exts( + row: &TypeRow, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { for ty in row.iter() { - collect_type_exts(ty, used_extensions, missing_extensions); + collect_term_exts(ty, used_extensions, missing_extensions); } } -/// Collect the Extension pointers in the [`CustomType`]s inside a type. +/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// # Attributes /// -/// - `typ`: The type to collect the extensions from. +/// - `term`: The term argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(crate) fn collect_type_exts( - typ: &TypeBase, +pub(crate) fn collect_term_exts( + term: &Term, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match typ.as_type_enum() { - TypeEnum::Extension(custom) => { + match term { + Term::RuntimeExtension(custom) => { for arg in custom.args() { collect_term_exts(arg, used_extensions, missing_extensions); } @@ -185,39 +201,16 @@ pub(crate) fn collect_type_exts( } } } - TypeEnum::Function(f) => { - collect_type_row_exts(&f.input, used_extensions, missing_extensions); - collect_type_row_exts(&f.output, used_extensions, missing_extensions); + Term::RuntimeFunction(f) => { + collect_term_exts(&f.input, used_extensions, missing_extensions); + collect_term_exts(&f.output, used_extensions, missing_extensions); } - TypeEnum::Sum(SumType::General { rows }) => { + Term::RuntimeSum(SumType::General { rows }) => { for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + collect_term_exts(row, used_extensions, missing_extensions); } } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} - } -} - -/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. -/// -/// # Attributes -/// -/// - `term`: The term argument to collect the extensions from. -/// - `used_extensions`: A The registry where to store the used extensions. -/// - `missing_extensions`: A set of `ExtensionId`s of which the -/// `Weak` pointer has been invalidated. -pub(super) fn collect_term_exts( - term: &Term, - used_extensions: &mut WeakExtensionRegistry, - missing_extensions: &mut ExtensionSet, -) { - match term { - Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), - Term::ConstType(ty) => collect_type_exts(ty, used_extensions, missing_extensions), + Term::ConstType(ty) => collect_term_exts(ty, used_extensions, missing_extensions), Term::List(elems) => { for elem in elems.iter() { collect_term_exts(elem, used_extensions, missing_extensions); @@ -254,7 +247,8 @@ pub(super) fn collect_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } } @@ -274,12 +268,12 @@ fn collect_value_exts( match value { Value::Extension { e } => { let typ = e.get_type(); - collect_type_exts(&typ, used_extensions, missing_extensions); + collect_term_exts(&typ, used_extensions, missing_extensions); } Value::Sum(s) => { if let SumType::General { rows } = &s.sum_type { for row in rows { - collect_type_row_exts(row, used_extensions, missing_extensions); + collect_term_exts(row, used_extensions, missing_extensions); } } s.values diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 56faeec129..1916e48203 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -5,12 +5,11 @@ use std::sync::Weak; -use super::types::collect_type_exts; +use super::types::collect_term_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; -use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncValueType, Signature, SumType, Term, Type, TypeRow, TypeRowRV}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -125,12 +124,12 @@ pub fn resolve_op_types_extensions( Ok(used.into_iter()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [Signature]. /// /// Adds the extensions used in the signature to the `used_extensions` registry. -pub(super) fn resolve_signature_exts( +pub(super) fn resolve_signature_exts( node: Option, - signature: &mut FuncTypeBase, + signature: &mut Signature, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { @@ -139,12 +138,26 @@ pub(super) fn resolve_signature_exts( Ok(()) } +/// Update all weak Extension pointers in the [`CustomType`]s inside a [FuncValueType]. +/// +/// Adds the extensions used in the signature to the `used_extensions` registry. +pub(super) fn resolve_func_type_exts( + node: Option, + signature: &mut FuncValueType, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + resolve_typerow_rv_exts(node, &mut signature.input, extensions, used_extensions)?; + resolve_typerow_rv_exts(node, &mut signature.output, extensions, used_extensions)?; + Ok(()) +} + /// Update all weak Extension pointers in the [`CustomType`]s inside a type row. /// /// Adds the extensions used in the row to the `used_extensions` registry. -pub(super) fn resolve_type_row_exts( +pub(super) fn resolve_type_row_exts( node: Option, - row: &mut TypeRowBase, + row: &mut TypeRow, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { @@ -154,34 +167,19 @@ pub(super) fn resolve_type_row_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. /// -/// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_type_exts( +/// Adds the extensions used in the row to the `used_extensions` registry. +fn resolve_typerow_rv_exts( node: Option, - typ: &mut TypeBase, + row: &mut TypeRowRV, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match typ.as_type_enum_mut() { - TypeEnum::Extension(custom) => { - resolve_custom_type_exts(node, custom, extensions, used_extensions)?; - } - TypeEnum::Function(f) => { - resolve_type_row_exts(node, &mut f.input, extensions, used_extensions)?; - resolve_type_row_exts(node, &mut f.output, extensions, used_extensions)?; - } - TypeEnum::Sum(SumType::General { rows }) => { - for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; - } - } - // Other types do not store extensions. - TypeEnum::Alias(_) - | TypeEnum::RowVar(_) - | TypeEnum::Variable(_, _) - | TypeEnum::Sum(SumType::Unit { .. }) => {} - } + let mut t = Term::from(std::mem::take(row)); + resolve_term_exts(node, &mut t, extensions, used_extensions)?; + *row = TypeRowRV::try_from(t) + .expect("Resolving extensions cannot change kind from ListType(RuntimeType)"); Ok(()) } @@ -211,9 +209,27 @@ pub(super) fn resolve_custom_type_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Type`]. /// /// Adds the extensions used in the type to the `used_extensions` registry. +pub(super) fn resolve_type_exts( + node: Option, + typ: &mut Type, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + const EMPTY: Type = Type::new_unit_sum(0); // as no Type::default() + let mut tm = std::mem::replace(typ, EMPTY).into(); + let r = resolve_term_exts(node, &mut tm, extensions, used_extensions); + *typ = tm + .try_into() + .expect("Resolving extensions cannot change kind from RuntimeType"); + r +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. +/// +/// Adds the extensions used in the term to the `used_extensions` registry. pub(super) fn resolve_term_exts( node: Option, term: &mut Term, @@ -221,7 +237,17 @@ pub(super) fn resolve_term_exts( used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match term { - Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, + Term::RuntimeExtension(custom) => { + resolve_custom_type_exts(node, custom, extensions, used_extensions)?; + } + Term::RuntimeFunction(f) => { + resolve_func_type_exts(node, &mut *f, extensions, used_extensions)?; + } + Term::RuntimeSum(SumType::General { rows }) => { + for row in rows.iter_mut() { + resolve_typerow_rv_exts(node, row, extensions, used_extensions)?; + } + } Term::ConstType(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, Term::List(children) | Term::ListConcat(children) @@ -247,7 +273,8 @@ pub(super) fn resolve_term_exts( | Term::BoundedNat(_) | Term::String(_) | Term::Bytes(_) - | Term::Float(_) => {} + | Term::Float(_) + | Term::RuntimeSum(SumType::Unit { .. }) => {} } Ok(()) } @@ -269,7 +296,7 @@ pub(super) fn resolve_value_exts( // return types with valid extensions after we call `update_extensions`. let typ = e.get_type(); let mut missing = ExtensionSet::new(); - collect_type_exts(&typ, used_extensions, &mut missing); + collect_term_exts(&typ, used_extensions, &mut missing); if !missing.is_empty() { return Err(ExtensionResolutionError::InvalidConstTypes { value: e.name(), @@ -280,7 +307,7 @@ pub(super) fn resolve_value_exts( Value::Sum(s) => { if let SumType::General { rows } = &mut s.sum_type { for row in rows.iter_mut() { - resolve_type_row_exts(node, row, extensions, used_extensions)?; + resolve_typerow_rv_exts(node, row, extensions, used_extensions)?; } } s.values diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index 1aff907721..9a4092db3e 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Weak}; use strum::IntoEnumIterator; use crate::ops::{ExtensionOp, OpName, OpNameRef}; +use crate::types::type_param::TermTypeError; use crate::{Extension, ops::OpType, types::TypeArg}; use super::{ExtensionBuildError, ExtensionId, OpDef, SignatureError, op_def::SignatureFunc}; @@ -26,6 +27,12 @@ pub enum OpLoadError { WrongExtension(ExtensionId, ExtensionId), } +impl From for OpLoadError { + fn from(value: TermTypeError) -> Self { + SignatureError::from(value).into() + } +} + /// Traits implemented by types which can add themselves to [`Extension`]s as /// [`OpDef`]s or load themselves from an [`OpDef`]. /// diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index b848c7528f..f2a9dd6956 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -79,7 +79,7 @@ pub struct TypeDef { impl TypeDef { /// Check provided type arguments are valid against parameters. pub fn check_args(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - check_term_types(args, &self.params).map_err(SignatureError::TypeArgMismatch) + Ok(check_term_types(args, &self.params)?) } /// Check [`CustomType`] is a valid instantiation of this definition. @@ -146,10 +146,9 @@ impl TypeDef { } least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); - match ta { - Some(TypeArg::Runtime(s)) => s.least_upper_bound(), - _ => panic!("TypeArg index does not refer to a type."), - } + ta.copied() + .and_then(TypeArg::least_upper_bound) + .expect("TypeArg index does not refer to a type.") })) } } diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 1aa4384742..f47a717006 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -60,7 +60,7 @@ impl SimpleReplacement { node: replacement.entrypoint(), op: Box::new(replacement.get_optype(replacement.entrypoint()).to_owned()), })?; - if subgraph_sig != repl_sig { + if &subgraph_sig != repl_sig.as_ref() { return Err(InvalidReplacement::InvalidSignature { expected: Box::new(subgraph_sig), actual: Some(Box::new(repl_sig.into_owned())), diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index a6a1103bc6..59ba257848 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -27,7 +27,7 @@ use crate::test_file; use crate::types::type_param::TypeParam; use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, - TypeRV, + TypeRowRV, }; use crate::{OutgoingPort, Visibility, type_row}; use std::fs::File; @@ -54,7 +54,7 @@ pub(super) struct HugrDeser(#[serde(deserialize_with = "Hugr::serde_deserialize" /// Version 1 of the Testing HUGR serialization format, see `testing_hugr.py`. #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingLatest { - typ: Option, + typ: Option, sum_type: Option, poly_func_type: Option, value: Option, @@ -144,7 +144,7 @@ macro_rules! impl_sertesting_from { }; } -impl_sertesting_from!(crate::types::TypeRV, typ); +impl_sertesting_from!(crate::types::Type, typ); impl_sertesting_from!(crate::types::SumType, sum_type); impl_sertesting_from!(crate::types::PolyFuncTypeRV, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); @@ -158,13 +158,6 @@ impl From for SerTestingLatest { } } -impl From for SerTestingLatest { - fn from(v: Type) -> Self { - let t: TypeRV = v.into(); - t.into() - } -} - #[test] fn empty_hugr_serialize() { check_hugr_json_roundtrip(&Hugr::default(), true); @@ -541,7 +534,7 @@ fn serialize_types_roundtrip() { check_testing_roundtrip(t); // A Classic sum - let t = TypeRV::new_sum([vec![usize_t()], vec![float64_type()]]); + let t = Type::new_sum([vec![usize_t()], vec![float64_type()]]); check_testing_roundtrip(t); let t = Type::new_unit_sum(4); @@ -552,7 +545,6 @@ fn serialize_types_roundtrip() { #[case(bool_t())] #[case(usize_t())] #[case(INT_TYPES[2].clone())] -#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Linear)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] #[case(Type::new_tuple(vec![bool_t(),qb_t()]))] #[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] @@ -585,14 +577,15 @@ fn polyfunctype1() -> PolyFuncType { } fn polyfunctype2() -> PolyFuncTypeRV { - let tv0 = TypeRV::new_row_var_use(0, TypeBound::Linear); - let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); + let tv0 = TypeRowRV::new_var_use(0, TypeBound::Linear); + let tv1 = TypeRowRV::new_var_use(1, TypeBound::Copyable); let params = [TypeBound::Linear, TypeBound::Copyable].map(TypeParam::new_list_type); - let inputs = vec![ - TypeRV::new_function(FuncValueType::new([tv0.clone()], [tv1.clone()])), - tv0, - ]; - let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, [tv1])); + let inputs = TypeRowRV::from([Type::new_function(FuncValueType::new( + tv0.clone(), + tv1.clone(), + ))]) + .concat(tv0); + let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) res.validate().unwrap(); @@ -608,7 +601,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[case(PolyFuncType::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( [TypeParam::new_list_type(TypeBound::Linear)], - Signature::new_endo([Type::new_tuple([TypeRV::new_row_var_use(0, TypeBound::Linear)])])))] + Signature::new_endo([Type::new_tuple(TypeRowRV::new_var_use(0, TypeBound::Linear))])))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); } @@ -621,7 +614,7 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[case(PolyFuncTypeRV::new([TypeParam::new_tuple_type([TypeBound::Linear.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())])], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( [TypeParam::new_list_type(TypeBound::Linear)], - FuncValueType::new_endo([TypeRV::new_row_var_use(0, TypeBound::Linear)])))] + FuncValueType::new_endo(TypeRowRV::new_var_use(0, TypeBound::Linear))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type); diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 6648d4fd0c..4df7fe0fb4 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -18,8 +18,7 @@ use crate::ops::validate::{ ChildrenEdgeData, ChildrenValidationError, EdgeValidationError, OpValidityFlags, }; use crate::ops::{NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; -use crate::types::EdgeKind; -use crate::types::type_param::TypeParam; +use crate::types::{EdgeKind, type_param::TypeParam}; use crate::{Direction, Port, Visibility}; use super::internal::PortgraphNodeMap; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 7cec7d6e6d..c8070b962d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -25,7 +25,7 @@ use crate::std_extensions::logic::test::{and_op, or_op}; use crate::types::type_param::{TermTypeError, TypeArg}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Term, Type, TypeBound, - TypeRV, TypeRow, + TypeRow, TypeRowRV, }; use crate::{Direction, Hugr, IncomingPort, Node, const_extension_ids, test_file, type_row}; @@ -495,28 +495,27 @@ fn no_polymorphic_consts() -> Result<(), Box> { pub(crate) fn extension_with_eval_parallel() -> Arc { let rowp = TypeParam::new_list_type(TypeBound::Linear); Extension::new_test_arc(EXT_ID, |ext, extension_ref| { - let inputs = TypeRV::new_row_var_use(0, TypeBound::Linear); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Linear); - let evaled_fn = - TypeRV::new_function(FuncValueType::new([inputs.clone()], [outputs.clone()])); + let inputs = TypeRowRV::new_var_use(0, TypeBound::Linear); + let outputs = TypeRowRV::new_var_use(1, TypeBound::Linear); + let evaled_fn = Type::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone()], - FuncValueType::new([evaled_fn, inputs], [outputs]), + FuncValueType::new(TypeRowRV::from([evaled_fn]).concat(inputs), outputs), ); ext.add_op("eval".into(), String::new(), pf, extension_ref) .unwrap(); - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Linear); + let rv = |idx| TypeRowRV::new_var_use(idx, TypeBound::Linear); let pf = PolyFuncTypeRV::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( - vec![ - Type::new_function(FuncValueType::new([rv(0)], [rv(2)])), - Type::new_function(FuncValueType::new([rv(1)], [rv(3)])), + [ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), ], [Type::new_function(FuncValueType::new( - [rv(0), rv(1)], - [rv(2), rv(3)], + rv(0).concat(rv(1)), + rv(2).concat(rv(3)), ))], ), ); @@ -528,7 +527,7 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { fn uint_seq(i: usize) -> Term { - vec![usize_t().into(); i].into() + Term::new_list(vec![usize_t(); i]) } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -552,16 +551,13 @@ fn instantiate_row_variables() -> Result<(), Box> { Ok(()) } -fn list1ty(t: TypeRV) -> Term { - Term::new_list([t.into()]) -} - #[test] fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); - let tv = TypeRV::new_row_var_use(0, TypeBound::Linear); - let inner_ft = Type::new_function(FuncValueType::new_endo([tv.clone()])); - let ft_usz = Type::new_function(FuncValueType::new_endo([tv.clone(), usize_t().into()])); + let tv_row = TypeRowRV::new_var_use(0, TypeBound::Linear); + let tv = Term::from(tv_row.clone()); + let inner_ft = Type::new_function(FuncValueType::new_endo(tv_row.clone())); + let ft_usz = Type::new_function(FuncValueType::new_endo(tv_row.concat([usize_t()]))); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( @@ -580,7 +576,12 @@ fn row_variables() -> Result<(), Box> { }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(list1ty), + [ + tv.clone(), + Term::new_list([usize_t()]), + tv.clone(), + Term::new_list([usize_t()]), + ], )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; diff --git a/hugr-core/src/hugr/views/root_checked/dfg.rs b/hugr-core/src/hugr/views/root_checked/dfg.rs index abdbacf749..eb68a8ac9c 100644 --- a/hugr-core/src/hugr/views/root_checked/dfg.rs +++ b/hugr-core/src/hugr/views/root_checked/dfg.rs @@ -12,7 +12,7 @@ use crate::{ OpParent, OpTrait, OpType, handle::{DataflowParentID, DfgID}, }, - types::{NoRV, Signature, Type, TypeBase}, + types::{Signature, Type}, }; use super::RootChecked; @@ -262,7 +262,7 @@ fn update_signature(hugr: &mut H, node: H::Node, new_sig: &Signature fn check_valid_inputs( old_ports: &[Vec], - old_sig: &[TypeBase], + old_sig: &[Type], map_sig: &[usize], ) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig @@ -291,10 +291,7 @@ fn check_valid_inputs( Ok(()) } -fn check_valid_outputs( - old_sig: &[TypeBase], - map_sig: &[usize], -) -> Result<(), InvalidSignature> { +fn check_valid_outputs(old_sig: &[Type], map_sig: &[usize]) -> Result<(), InvalidSignature> { if let Some(old_pos) = map_sig .iter() .find_map(|&old_pos| (old_pos >= old_sig.len()).then_some(old_pos)) @@ -684,8 +681,8 @@ mod test { let new_inputs = vec![bool_t(), float64_type()]; dfg_view.extend_inputs(&new_inputs).unwrap(); assert_eq!( - dfg_view.hugr().inner_function_type().unwrap(), - Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) + dfg_view.hugr().inner_function_type().unwrap().as_ref(), + &Signature::new(vec![qb_t(), bool_t(), float64_type()], vec![qb_t()]) ); let new_inputs_fail = vec![qb_t()]; diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 1e62bc699e..9f775f0b98 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -5,16 +5,19 @@ //! the core and model to converge incrementally. use std::sync::Arc; -use crate::envelope::description::GeneratorDesc; -use crate::metadata::{self, Metadata}; +use crate::envelope::description::{ExtensionDesc, GeneratorDesc, ModuleDesc}; +use crate::metadata::{self, Metadata, RawMetadataValue}; +use crate::types::type_param::{SeqPart, TermTypeError, TypeParam}; +use crate::types::{ + CustomType, FuncTypeBase, PolyFuncType, Signature, Term, Type, TypeArg, TypeBound, TypeName, + TypeRow, TypeRowLike, TypeRowRV, +}; use crate::{ Direction, Hugr, HugrView, Node, Port, - envelope::description::{ExtensionDesc, ModuleDesc}, extension::{ ExtensionId, ExtensionRegistry, SignatureError, resolution::ExtensionResolutionError, }, hugr::HugrMut, - metadata::RawMetadataValue, ops::{ AliasDecl, AliasDefn, CFG, Call, CallIndirect, Case, Conditional, Const, DFG, DataflowBlock, ExitBlock, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, @@ -22,16 +25,8 @@ use crate::{ constant::{CustomConst, CustomSerialized, OpaqueValue}, }, package::Package, - std_extensions::{ - arithmetic::{float_types::ConstF64, int_types::ConstInt}, - collections::array::ArrayValue, - }, - types::{ - CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, - type_param::{SeqPart, TypeParam}, - type_row::TypeRowBase, - }, + std_extensions::arithmetic::{float_types::ConstF64, int_types::ConstInt}, + std_extensions::collections::array::ArrayValue, }; use hugr_model::v0::table; use hugr_model::v0::{self as model}; @@ -117,6 +112,12 @@ enum ImportErrorInner { ExtensionResolution(#[from] ExtensionResolutionError), } +impl From for ImportErrorInner { + fn from(err: TermTypeError) -> Self { + SignatureError::from(err).into() + } +} + #[derive(Debug, Clone, Error)] enum ExtensionError { /// An extension is missing. @@ -314,7 +315,7 @@ impl<'a> Context<'a> { let signature = node_data .signature .ok_or_else(|| error_uninferred!("node signature"))?; - self.import_func_type(signature) + self.import_func_type(signature, Self::import_type_row) } /// Get the node with the given `NodeId`, or return an error if it does not exist. @@ -691,6 +692,7 @@ impl<'a> Context<'a> { region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, + Self::import_type_row, ) .map_err(|err| error_context!(err, "signature of dfg region with id {}", region))?; @@ -839,7 +841,10 @@ impl<'a> Context<'a> { let sum_rows: Vec<_> = { let [variants] = self.expect_symbol(*first, model::CORE_ADT)?; - self.import_type_rows(variants)? + self.import_closed_list(variants)? + .into_iter() + .map(|term_id| self.import_type_row(term_id)) + .collect::>()? }; let rest = rest @@ -937,6 +942,7 @@ impl<'a> Context<'a> { region_data .signature .ok_or_else(|| error_uninferred!("region signature"))?, + Self::import_type_row, )?; let case_node = self @@ -1378,11 +1384,11 @@ impl<'a> Context<'a> { Ok(node) } - fn import_poly_func_type( + fn import_poly_func_type( &mut self, node: table::NodeId, symbol: table::Symbol<'a>, - in_scope: impl FnOnce(&mut Self, PolyFuncTypeBase) -> Result, + in_scope: impl FnOnce(&mut Self, PolyFuncType) -> Result, ) -> Result { (|| { let mut imported_params = Vec::with_capacity(symbol.params.len()); @@ -1425,8 +1431,8 @@ impl<'a> Context<'a> { ); } - let body = self.import_func_type::(symbol.signature)?; - in_scope(self, PolyFuncTypeBase::new(imported_params, body)) + let body = self.import_func_type(symbol.signature, Self::import_type_row)?; + in_scope(self, PolyFuncType::new(imported_params, body)) })() .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) } @@ -1436,6 +1442,10 @@ impl<'a> Context<'a> { self.import_term_with_bound(term_id, TypeBound::Linear) } + fn import_type(&mut self, term_id: table::TermId) -> Result { + Ok(Type::try_from(self.import_term(term_id)?)?) + } + fn import_term_with_bound( &mut self, term_id: table::TermId, @@ -1495,6 +1505,30 @@ impl<'a> Context<'a> { return Ok(TypeParam::new_tuple_type(item_types)); } + if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { + let func_type = self.import_func_type(term_id, |this, term_id| { + Ok(TypeRowRV::try_from(this.import_term(term_id)?)?) + })?; + return Ok(Type::new_function(func_type).into()); + } + + if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { + let variants = (|| { + self.import_closed_list(variants)? + .iter() + .map(|variant| { + self.import_term(*variant).and_then(|tm| { + TypeRowRV::try_from(tm) + .map_err(|e| ImportErrorInner::Signature(e.into())) + }) + }) + .collect::, _>>() + })() + .map_err(|err| error_context!(err, "adt variants"))?; + + return Ok(Type::new_sum(variants).into()); + } + match self.get_term(term_id)? { table::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -1539,51 +1573,6 @@ impl<'a> Context<'a> { table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), table::Term::Func { .. } => Err(error_unsupported!("function constant")), - table::Term::Apply { .. } => { - let ty: Type = self.import_type(term_id)?; - Ok(ty.into()) - } - } - })() - .map_err(|err| error_context!(err, "term {}", term_id)) - } - - fn import_seq_part( - &mut self, - seq_part: &'a table::SeqPart, - ) -> Result, ImportErrorInner> { - Ok(match seq_part { - table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), - table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), - }) - } - - /// Import a `Type` from a term that represents a runtime type. - fn import_type( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - (|| { - if let Some([_, _]) = self.match_symbol(term_id, model::CORE_FN)? { - let func_type = self.import_func_type::(term_id)?; - return Ok(TypeBase::new_function(func_type)); - } - - if let Some([variants]) = self.match_symbol(term_id, model::CORE_ADT)? { - let variants = (|| { - self.import_closed_list(variants)? - .iter() - .map(|variant| self.import_type_row::(*variant)) - .collect::, _>>() - })() - .map_err(|err| error_context!(err, "adt variants"))?; - - return Ok(TypeBase::new_sum(variants)); - } - - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), - table::Term::Apply(symbol, args) => { let name = self.get_symbol_name(*symbol)?; @@ -1615,32 +1604,28 @@ impl<'a> Context<'a> { let bound = ext_type.bound(&args); - Ok(TypeBase::new_extension(CustomType::new( + Ok(Type::new_extension(CustomType::new( id, args, extension, bound, &Arc::downgrade(extension_ref), - ))) + )) + .into()) } - - table::Term::Var(var @ table::VarId(_, index)) => { - let local_var = self - .local_vars - .get(var) - .ok_or(error_invalid!("unknown var {}", var))?; - Ok(TypeBase::new_var_use(*index as _, local_var.bound)) - } - - // The following terms are not runtime types, but the core `Type` only contains runtime types. - // We therefore report a type error here. - table::Term::Literal(_) - | table::Term::List { .. } - | table::Term::Tuple { .. } - | table::Term::Func { .. } => Err(error_invalid!("expected a runtime type")), } })() - .map_err(|err| error_context!(err, "term {} as `Type`", term_id)) + .map_err(|err| error_context!(err, "term {}", term_id)) + } + + fn import_seq_part( + &mut self, + seq_part: &'a table::SeqPart, + ) -> Result, ImportErrorInner> { + Ok(match seq_part { + table::SeqPart::Item(term_id) => SeqPart::Item(self.import_term(*term_id)?), + table::SeqPart::Splice(term_id) => SeqPart::Splice(self.import_term(*term_id)?), + }) } fn get_func_type( @@ -1667,18 +1652,18 @@ impl<'a> Context<'a> { /// /// Function types are not special-cased in `hugr-model` but are represented /// via the `core.fn` term constructor. - fn import_func_type( + fn import_func_type( &mut self, term_id: table::TermId, - ) -> Result, ImportErrorInner> { + import_io: impl Fn(&mut Self, table::TermId) -> Result, + ) -> Result, ImportErrorInner> { (|| { let [inputs, outputs] = self.get_func_type(term_id)?; - let inputs = self - .import_type_row(inputs) - .map_err(|err| error_context!(err, "function inputs"))?; - let outputs = self - .import_type_row(outputs) - .map_err(|err| error_context!(err, "function outputs"))?; + let inputs = + import_io(self, inputs).map_err(|err| error_context!(err, "function inputs"))?; + let outputs = + import_io(self, outputs).map_err(|err| error_context!(err, "function outputs"))?; + Ok(FuncTypeBase::new(inputs, outputs)) })() .map_err(|err| error_context!(err, "function type")) @@ -1778,64 +1763,18 @@ impl<'a> Context<'a> { Ok(types) } - /// Imports a list of lists as a vector of type rows. - /// - /// See [`Self::import_type_row`]. - fn import_type_rows( - &mut self, - term_id: table::TermId, - ) -> Result>, ImportErrorInner> { - self.import_closed_list(term_id)? - .into_iter() - .map(|term_id| self.import_type_row::(term_id)) - .collect() - } - - /// Imports a list as a type row. + /// Imports a closed list as a type row. /// /// This method works to produce a [`TypeRow`] or a [`TypeRowRV`], depending /// on the `RV` type argument. For [`TypeRow`] a closed list is expected. /// For [`TypeRowRV`] we import spliced variables as row variables. - fn import_type_row( - &mut self, - term_id: table::TermId, - ) -> Result, ImportErrorInner> { - fn import_into( - ctx: &mut Context, - term_id: table::TermId, - types: &mut Vec>, - ) -> Result<(), ImportErrorInner> { - match ctx.get_term(term_id)? { - table::Term::List(parts) => { - types.reserve(parts.len()); - - for item in *parts { - match item { - table::SeqPart::Item(term_id) => { - types.push(ctx.import_type::(*term_id)?); - } - table::SeqPart::Splice(term_id) => { - import_into(ctx, *term_id, types)?; - } - } - } - } - table::Term::Var(table::VarId(_, index)) => { - let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Linear)) - .map_err(|_| { - error_invalid!("Expected a closed list.\n{}", CLOSED_LIST_HINT) - })?; - types.push(TypeBase::new(TypeEnum::RowVar(var))); - } - _ => return Err(error_invalid!("expected a list")), - } - - Ok(()) - } - - let mut types = Vec::new(); - import_into(self, term_id, &mut types)?; - Ok(types.into()) + fn import_type_row(&mut self, term_id: table::TermId) -> Result { + let elems = self.import_closed_list(term_id)?; + Ok(elems + .into_iter() + .map(|id| self.import_term(id)) + .collect::, _>>()? + .try_into()?) } fn import_custom_name( @@ -1987,13 +1926,8 @@ impl<'a> Context<'a> { .map(|(value, ty)| self.import_value(*value, *ty)) .collect::, _>>()?; - let ty = { - // TODO: Import as a `SumType` directly and avoid the copy. - let ty: Type = self.import_type(type_id)?; - match ty.as_type_enum() { - TypeEnum::Sum(sum) => sum.clone(), - _ => unreachable!(), - } + let Term::RuntimeSum(ty) = self.import_term(type_id)? else { + unreachable!() }; return Ok(Value::sum(*tag as _, items, ty).unwrap()); @@ -2138,7 +2072,7 @@ impl<'a> Context<'a> { struct LocalVar { /// The type of the variable. r#type: table::TermId, - /// The type bound of the variable. + /// The type bound of the variable. Overwritten if a constraint is seen. bound: TypeBound, } diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 1bf360c544..23c115d950 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -70,8 +70,8 @@ //! // Create a conditional operation //! let conditional = Conditional { //! sum_rows: vec![[usize_t()].into(), [bool_t()].into()], -//! other_inputs: vec![usize_t().into()].into(), -//! outputs: vec![bool_t()].into(), +//! other_inputs: [usize_t()].into(), +//! outputs: [bool_t()].into(), //! }; //! let cond_op: OpType = conditional.into(); //! assert!(cond_op.is_container()); // Contains case branches diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 45a06b16f5..5838495d60 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use crate::Direction; -use crate::types::{EdgeKind, Signature, Type, TypeRow}; +use crate::types::{EdgeKind, Signature, Type, TypeRow, TypeRowLike}; use super::OpTag; use super::dataflow::{DataflowOpTrait, DataflowParent}; @@ -351,7 +351,7 @@ mod test { use crate::{ extension::prelude::{qb_t, usize_t}, ops::{Conditional, DataflowOpTrait, DataflowParent}, - types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, + types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRowRV}, }; use super::{DataflowBlock, TailLoop}; @@ -368,31 +368,28 @@ mod test { let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( - dfb2.inner_signature(), - Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) + dfb2.inner_signature().as_ref(), + &Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) ); } #[test] fn test_subst_conditional() { let tv1 = Type::new_var_use(1, TypeBound::Linear); + let tup_ty = Type::new_tuple(TypeRowRV::new_var_use(0, TypeBound::Linear)); let cond = Conditional { sum_rows: vec![[usize_t()].into(), [tv1.clone()].into()], - other_inputs: vec![Type::new_tuple([TypeRV::new_row_var_use( - 0, - TypeBound::Linear, - )])] - .into(), + other_inputs: vec![tup_ty].into(), outputs: vec![usize_t(), tv1].into(), }; let cond2 = cond.substitute(&Substitution::new(&[ - TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]), + TypeArg::new_list([usize_t(), usize_t(), usize_t()]), qb_t().into(), ])); let st = Type::new_sum([[usize_t()], [qb_t()]]); assert_eq!( - cond2.signature(), - Signature::new( + cond2.signature().as_ref(), + &Signature::new( [st, Type::new_tuple(vec![usize_t(); 3])], [usize_t(), qb_t()] ) @@ -409,8 +406,8 @@ mod test { }; let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); assert_eq!( - tail2.signature(), - Signature::new( + tail2.signature().as_ref(), + &Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 9e46764728..e370fb2936 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -6,7 +6,9 @@ use super::{OpTag, OpTrait, impl_op_name}; use crate::extension::SignatureError; use crate::ops::StaticTag; -use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; +use crate::types::{ + EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow, TypeRowLike, +}; use crate::{IncomingPort, type_row}; #[cfg(test)] diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index 71955bdc1b..1561123276 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,12 +1,12 @@ //! Handles to nodes in HUGR. use crate::Node; use crate::core::HugrNode; -use crate::types::{Type, TypeBound}; +use crate::types::TypeBound; use derive_more::From as DerFrom; use smol_str::SmolStr; -use super::{AliasDecl, OpTag}; +use super::OpTag; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. @@ -89,10 +89,6 @@ impl AliasID { Self { node, name, bound } } - /// Construct new `AliasID` - pub fn get_alias_type(&self) -> Type { - Type::new_alias(AliasDecl::new(self.name.clone(), self.bound)) - } /// Retrieve the underlying core type pub fn get_name(&self) -> &SmolStr { &self.name diff --git a/hugr-core/src/ops/sum.rs b/hugr-core/src/ops/sum.rs index 1c535683fc..3ceb88d9a0 100644 --- a/hugr-core/src/ops/sum.rs +++ b/hugr-core/src/ops/sum.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::dataflow::DataflowOpTrait; use super::{OpTag, impl_op_name}; -use crate::types::{EdgeKind, Signature, Type, TypeRow}; +use crate::types::{EdgeKind, Signature, Type, TypeRow, TypeRowLike}; /// An operation that creates a tagged sum value from one of its variants. #[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] diff --git a/hugr-core/src/proptest.rs b/hugr-core/src/proptest.rs index 051e300bf8..c569f923de 100644 --- a/hugr-core/src/proptest.rs +++ b/hugr-core/src/proptest.rs @@ -8,8 +8,8 @@ use std::sync::LazyLock; use crate::Hugr; #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] -/// The types [Type], [`TypeEnum`], [`SumType`], [`FunctionType`], [`TypeArg`], -/// [`TypeParam`], as well as several others, form a mutually recursive hierarchy. +/// The types [Type], [`Term`], [`SumType`], [`FunctionType`], [`CustomType`], +/// as well as several others, form a mutually recursive hierarchy. /// /// The proptest [`proptest::strategy::Strategy::prop_recursive`] is inadequate to /// generate values for these types. Instead, the Arbitrary instances take a diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index 3b26a27404..d9903ed0b8 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -13,7 +13,7 @@ use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc}; use crate::ops::{ExtensionOp, OpName}; use crate::std_extensions::arithmetic::int_ops::int_polytype; use crate::std_extensions::arithmetic::int_types::int_type; -use crate::types::{TypeArg, TypeRV}; +use crate::types::TypeArg; use super::float_types::float64_type; use super::int_types::{get_log_width, int_tv}; @@ -62,11 +62,9 @@ impl MakeOpDef for ConvertOpDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { use ConvertOpDef::*; match self { - trunc_s | trunc_u => int_polytype( - 1, - [float64_type()], - [TypeRV::from(sum_with_error([int_tv(0)]))], - ), + trunc_s | trunc_u => { + int_polytype(1, [float64_type()], [sum_with_error([int_tv(0)]).into()]) + } convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]), itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]), ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]), diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 11c16b14a8..05512383a9 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -10,7 +10,7 @@ use crate::extension::simple_op::{ use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs}; use crate::ops::OpName; use crate::ops::custom::ExtensionOp; -use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; +use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRow, TypeRowRV}; use crate::utils::collect_array; use crate::{ @@ -136,7 +136,7 @@ impl MakeOpDef for IntOpDef { } ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(), idivmod_checked_u | idivmod_checked_s => { - let intpair: TypeRowRV = vec![tv0; 2].into(); + let intpair: TypeRow = vec![tv0; 2].into(); int_polytype( 1, intpair.clone(), diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 566ee12c70..742c40f549 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -180,8 +180,9 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { - Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { + let ty = Type::try_from(ty.clone()).unwrap(); // succeeds as copyable + Ok(GenericArrayClone::new(ty, *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 61b013a062..baf8442621 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -231,8 +231,8 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { - Ok(GenericArrayConvert::new(ty.clone(), *n)) + [TypeArg::BoundedNat(n), ty] => { + Ok(GenericArrayConvert::new(ty.clone().try_into()?, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 17e2be1577..be3021bafd 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -164,8 +164,9 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { - Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) + [TypeArg::BoundedNat(n), ty] if ty.copyable() => { + let ty = ty.clone().try_into().unwrap(); // succeeds as copyable + Ok(GenericArrayDiscard::new(ty, *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 26ebb5b5f4..84dbd8ce3c 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -326,12 +326,12 @@ impl HasConcrete for GenericArrayOpDef { fn instantiate(&self, type_args: &[Term]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), - (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [ty]) => (ty.clone(), 0), + (_, [Term::BoundedNat(n), ty]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(self.to_concrete(ty.clone(), size)) + Ok(self.to_concrete(ty.try_into()?, size)) } } diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 3fb121980f..59ef8a7508 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -170,8 +170,9 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { - Ok(GenericArrayRepeat::new(ty.clone(), *n)) + [TypeArg::BoundedNat(n), ty] => { + let ty = Type::try_from(ty.clone())?; + Ok(GenericArrayRepeat::new(ty, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 5bd62466c2..003c12b449 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -4,8 +4,6 @@ use std::marker::PhantomData; use std::str::FromStr; use std::sync::{Arc, Weak}; -use itertools::Itertools; - use crate::Extension; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, @@ -13,7 +11,7 @@ use crate::extension::simple_op::{ use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound, TypeRowRV}; use super::array_kind::ArrayKind; @@ -64,27 +62,23 @@ impl GenericArrayScanDef { let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let t1 = Type::new_var_use(1, TypeBound::Linear); let t2 = Type::new_var_use(2, TypeBound::Linear); - let s = TypeRV::new_row_var_use(3, TypeBound::Linear); + let with_rest = |tys: Vec| { + TypeRowRV::from(tys).concat(TypeRowRV::new_var_use(3, TypeBound::Linear)) + }; PolyFuncTypeRV::new( params, - FuncTypeBase::::new( - vec![ + FuncValueType::new( + with_rest(vec![ AK::instantiate_ty(array_def, n.clone(), t1.clone()) - .expect("Array type instantiation failed") - .into(), - Type::new_function(FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - )) - .into(), - s.clone(), - ], - vec![ - AK::instantiate_ty(array_def, n, t2) - .expect("Array type instantiation failed") - .into(), - s, - ], + .expect("Array type instantiation failed"), + Type::new_function(FuncValueType::new( + with_rest(vec![t1]), + with_rest(vec![t2.clone()]), + )), + ]), + with_rest(vec![ + AK::instantiate_ty(array_def, n, t2).expect("Array type instantiation failed"), + ]), ), ) .into() @@ -188,7 +182,7 @@ impl MakeExtensionOp for GenericArrayScan { self.size.into(), self.src_ty.clone().into(), self.tgt_ty.clone().into(), - TypeArg::new_list(self.acc_tys.clone().into_iter().map_into()), + TypeArg::new_list(self.acc_tys.clone()), ] } } @@ -214,23 +208,17 @@ impl HasConcrete for GenericArrayScanDef { match type_args { [ TypeArg::BoundedNat(n), - TypeArg::Runtime(src_ty), - TypeArg::Runtime(tgt_ty), + src_elem_ty, + tgt_elem_ty, TypeArg::List(acc_tys), ] => { - let acc_tys: Result<_, OpLoadError> = acc_tys + let src_elem_ty = Type::try_from(src_elem_ty.clone())?; + let tgt_elem_ty = Type::try_from(tgt_elem_ty.clone())?; + let acc_tys = acc_tys .iter() - .map(|acc_ty| match acc_ty { - TypeArg::Runtime(ty) => Ok(ty.clone()), - _ => Err(SignatureError::InvalidTypeArgs.into()), - }) - .collect(); - Ok(GenericArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, - *n, - )) + .map(|tm| Type::try_from(tm.clone())) + .collect::, _>>()?; + Ok(GenericArrayScan::new(src_elem_ty, tgt_elem_ty, acc_tys, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 33828d9e0d..a05091154a 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -94,8 +94,8 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { - ty + [TypeArg::BoundedNat(n), ty] if *n as usize == self.values.len() && ty.copyable() => { + Type::try_from(ty.clone()).unwrap() // succeeds as copyable } _ => { return Err(CustomCheckFailure::Message(format!( @@ -107,7 +107,7 @@ impl GenericArrayValue { // check all values are instances of the element type for v in &self.values { - if v.get_type() != *ty { + if v.get_type() != ty { return Err(CustomCheckFailure::Message(format!( "Array element {v:?} is not of expected type {ty}" ))); diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs index 534a28655c..6a93067228 100644 --- a/hugr-core/src/std_extensions/collections/borrow_array.rs +++ b/hugr-core/src/std_extensions/collections/borrow_array.rs @@ -280,7 +280,10 @@ impl HasConcrete for BArrayUnsafeOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [Term::BoundedNat(n), Term::Runtime(ty)] => Ok(self.to_concrete(ty.clone(), *n)), + [Term::BoundedNat(n), ty] => { + let ty = Type::try_from(ty.clone())?; + Ok(self.to_concrete(ty, *n)) + } _ => Err(SignatureError::InvalidTypeArgs.into()), } } diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 4dd80742f9..20d6d35f39 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -111,13 +111,14 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Runtime(ty)] = typ.args() else { - return Err(error()); + let ty = match typ.args() { + [ty] if ty.least_upper_bound().is_some() => Type::try_from(ty.clone()).unwrap(), // succeeds as has l-u-b + _ => return Err(error()), }; // check all values are instances of the element type for v in &self.0 { - if v.get_type() != *ty { + if v.get_type() != ty { return Err(error()); } } @@ -349,18 +350,16 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [Term::Runtime(ty)] = ext_op.args() else { + let [ty] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs.into()); }; + let elem_type = ty.clone().try_into()?; let name = ext_op.unqualified_id(); let Ok(op) = ListOp::from_str(name) else { return Err(OpLoadError::NotMember(name.to_string())); }; - Ok(Self { - elem_type: ty.clone(), - op, - }) + Ok(Self { elem_type, op }) } fn type_args(&self) -> Vec { diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 007be5ecc4..a097f4b568 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -310,15 +310,13 @@ impl HasConcrete for StaticArrayOpDef { use TypeBound::Copyable; match type_args { [arg] => { - let elem_ty = arg - .as_runtime() - .filter(|t| Copyable.contains(t.least_upper_bound())) - .ok_or(SignatureError::TypeArgMismatch( - TermTypeError::TypeMismatch { - type_: Box::new(Copyable.into()), - term: Box::new(arg.clone()), - }, - ))?; + if !arg.copyable() { + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: Box::new(Copyable.into()), + term: Box::new(arg.clone()), + }))? + } + let elem_ty = Type::try_from(arg.clone()).unwrap(); // succeeds as copyable Ok(StaticArrayOp { def: *self, diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 7816e9b03c..8757af2f9f 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -204,7 +204,7 @@ impl HasConcrete for PtrOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { let ty = match type_args { - [TypeArg::Runtime(ty)] => ty.clone(), + [ty] => Type::try_from(ty.clone())?, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 1f67edf594..19741c748f 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -3,41 +3,36 @@ mod check; pub mod custom; mod poly_func; -mod row_var; pub(crate) mod serialize; mod signature; pub mod type_param; pub mod type_row; -pub(crate) use row_var::MaybeRV; -pub use row_var::{NoRV, RowVariable}; use crate::extension::resolution::{ - ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, + ExtensionCollectionError, WeakExtensionRegistry, collect_term_exts, }; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; -use crate::types::type_param::check_term_type; +use crate::types::type_param::{TermTypeError, check_term_type}; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; -pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub use poly_func::{PolyFuncType, PolyFuncTypeBase, PolyFuncTypeRV}; pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::{Term, TypeArg}; +pub(crate) use type_row::TypeRowLike; pub use type_row::{TypeRow, TypeRowRV}; -pub(crate) use poly_func::PolyFuncTypeBase; - use itertools::FoldWhile::{Continue, Done}; use itertools::{Either, Itertools as _}; #[cfg(test)] use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; +use std::ops::Deref; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; -use crate::ops::AliasDecl; use self::type_param::TypeParam; -use self::type_row::TypeRowBase; /// A unique identifier for a type. pub type TypeName = SmolStr; @@ -225,13 +220,12 @@ impl SumType { where V: Into, { - let rows = variants.into_iter().map(Into::into).collect_vec(); - - let len: usize = rows.len(); - if u8::try_from(len).is_ok() && rows.iter().all(TypeRowRV::is_empty) { + let variants = variants.into_iter().map(V::into).collect_vec(); + let len = variants.len(); + if u8::try_from(len).is_ok() && variants.iter().all(TypeRowRV::is_empty) { Self::new_unary(len as u8) } else { - Self::General { rows } + Self::General { rows: variants } } } @@ -255,7 +249,7 @@ impl SumType { #[must_use] pub fn get_variant(&self, tag: usize) -> Option<&TypeRowRV> { match self { - SumType::Unit { size } if tag < (*size as usize) => Some(TypeRV::EMPTY_TYPEROW_REF), + SumType::Unit { size } if tag < (*size as usize) => Some(TypeRowRV::EMPTY_REF), SumType::General { rows } => rows.get(tag), _ => None, } @@ -274,39 +268,48 @@ impl SumType { #[must_use] pub fn as_tuple(&self) -> Option<&TypeRowRV> { match self { - SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF), + SumType::Unit { size } if *size == 1 => Some(TypeRowRV::EMPTY_REF), SumType::General { rows } if rows.len() == 1 => Some(&rows[0]), _ => None, } } - /// If the sum matches the convention of `Option[row]`, return the row. + /// If the sum matches the convention of `Option[row]`, return the row + /// (an instance of [Term::ListType]([Term::RuntimeType]). #[must_use] pub fn as_option(&self) -> Option<&TypeRowRV> { match self { - SumType::Unit { size } if *size == 2 => Some(TypeRV::EMPTY_TYPEROW_REF), + SumType::Unit { size } if *size == 2 => Some(TypeRowRV::EMPTY_REF), SumType::General { rows } if rows.len() == 2 && rows[0].is_empty() => Some(&rows[1]), _ => None, } } - /// If a sum is an option of a single type, return the type. - #[must_use] - pub fn as_unary_option(&self) -> Option<&TypeRV> { - self.as_option() - .and_then(|row| row.iter().exactly_one().ok()) - } - - /// Returns an iterator over the variants. + /// Returns an iterator over the variants pub fn variants(&self) -> impl Iterator { match self { - SumType::Unit { size } => Either::Left(itertools::repeat_n( - TypeRV::EMPTY_TYPEROW_REF, - *size as usize, - )), + SumType::Unit { size } => { + Either::Left(itertools::repeat_n(TypeRowRV::EMPTY_REF, *size as usize)) + } SumType::General { rows } => Either::Right(rows.iter()), } } + + fn bound(&self) -> TypeBound { + match self { + SumType::Unit { .. } => TypeBound::Copyable, + SumType::General { rows } => { + if rows + .iter() + .all(|t| check_term_type(t, &Term::new_list_type(TypeBound::Copyable)).is_ok()) + { + TypeBound::Copyable + } else { + TypeBound::Linear + } + } + } + } } impl Transformable for SumType { @@ -318,72 +321,29 @@ impl Transformable for SumType { } } -impl From for TypeBase { +impl From for Type { fn from(sum: SumType) -> Self { match sum { - SumType::Unit { size } => TypeBase::new_unit_sum(size), - SumType::General { rows } => TypeBase::new_sum(rows), + SumType::Unit { size } => Type::new_unit_sum(size), + SumType::General { rows } => Type::new_sum(rows), } } } -#[derive(Clone, Debug, Eq, Hash, derive_more::Display)] -/// Core types -pub enum TypeEnum { - /// An extension type. - // - // TODO optimise with `Box`? - // or some static version of this? - Extension(CustomType), - /// An alias of a type. - #[display("Alias({})", _0.name())] - Alias(AliasDecl), - /// A function type. - #[display("{_0}")] - Function(Box), - /// A type variable, defined by an index into a list of type parameters. - // - // We cache the TypeBound here (checked in validation) - #[display("#{_0}")] - Variable(usize, TypeBound), - /// `RowVariable`. Of course, this requires that `RV` has instances, [`NoRV`] doesn't. - #[display("RowVar({_0})")] - RowVar(RV), - /// Sum of types. - #[display("{_0}")] - Sum(SumType), -} - -impl TypeEnum { - /// The smallest type bound that covers the whole type. - fn least_upper_bound(&self) -> TypeBound { - match self { - TypeEnum::Extension(c) => c.bound(), - TypeEnum::Alias(a) => a.bound, - TypeEnum::Function(_) => TypeBound::Copyable, - TypeEnum::Variable(_, b) => *b, - TypeEnum::RowVar(b) => b.bound(), - TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Copyable, - TypeEnum::Sum(SumType::General { rows }) => least_upper_bound( - rows.iter() - .flat_map(TypeRowRV::iter) - .map(TypeRV::least_upper_bound), - ), - } - } -} - -#[derive(Clone, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize)] +#[derive( + Clone, Debug, Eq, Hash, PartialEq, derive_more::Display, serde::Serialize, serde::Deserialize, +)] #[display("{_0}")] #[serde( into = "serialize::SerSimpleType", try_from = "serialize::SerSimpleType" )] -/// A HUGR type - the valid types of [`EdgeKind::Value`] and [`EdgeKind::Const`] edges. +/// A HUGR type - a single value, that can be sent down a wire. /// +/// The valid types of [`EdgeKind::Value`] and [`EdgeKind::Const`] edges. /// Such an edge is valid if the ports on either end agree on the [Type]. -/// Types have an optional [`TypeBound`] which places limits on the valid -/// operations on a type. +/// Types have a [`TypeBound`] which specifies the number of inports +/// to which a particular outport (of that type) may be connected. /// /// Examples: /// ``` @@ -400,58 +360,32 @@ impl TypeEnum { /// let func_type: Type = Type::new_function(Signature::new_endo([])); /// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable); /// ``` -pub struct TypeBase(TypeEnum, TypeBound); - -/// The type of a single value, that can be sent down a wire -pub type Type = TypeBase; - -/// One or more types - either a single type, or a row variable -/// standing for multiple types. -pub type TypeRV = TypeBase; - -impl PartialEq> for TypeEnum { - fn eq(&self, other: &TypeEnum) -> bool { - match (self, other) { - (TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2, - (TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2, - (TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1 == f2, - (TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1 == i2 && b1 == b2, - (TypeEnum::RowVar(v1), TypeEnum::RowVar(v2)) => v1.as_rv() == v2.as_rv(), - (TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2, - _ => false, - } - } -} - -impl PartialEq> for TypeBase { - fn eq(&self, other: &TypeBase) -> bool { - self.0 == other.0 && self.1 == other.1 - } -} +pub struct Type(Term, TypeBound); -impl TypeBase { +impl Type { /// An empty `TypeRow` or `TypeRowRV`. Provided here for convenience - pub const EMPTY_TYPEROW: TypeRowBase = TypeRowBase::::new(); + pub const EMPTY_TYPEROW: TypeRow = TypeRow::new(); /// Unit type (empty tuple). pub const UNIT: Self = Self( - TypeEnum::Sum(SumType::Unit { size: 1 }), + Term::RuntimeSum(SumType::Unit { size: 1 }), TypeBound::Copyable, ); - const EMPTY_TYPEROW_REF: &'static TypeRowBase = &Self::EMPTY_TYPEROW; - /// Initialize a new function type. pub fn new_function(fun_ty: impl Into) -> Self { - Self::new(TypeEnum::Function(Box::new(fun_ty.into()))) + Self( + Term::RuntimeFunction(Box::new(fun_ty.into())), + TypeBound::Copyable, + ) } /// Initialize a new tuple type by providing the elements. #[inline(always)] pub fn new_tuple(types: impl Into) -> Self { let row = types.into(); - match row.len() { - 0 => Self::UNIT, - _ => Self::new_sum([row]), + match row.is_empty() { + true => Self::UNIT, + false => Self::new_sum([row]), } } @@ -461,7 +395,9 @@ impl TypeBase { where R: Into, { - Self::new(TypeEnum::Sum(SumType::new(variants))) + let st = SumType::new(variants); + let b = st.bound(); + Self(Term::RuntimeSum(st), b) } /// Initialize a new custom type. @@ -469,25 +405,17 @@ impl TypeBase { #[must_use] pub const fn new_extension(opaque: CustomType) -> Self { let bound = opaque.bound(); - TypeBase(TypeEnum::Extension(opaque), bound) - } - - /// Initialize a new alias. - #[must_use] - pub fn new_alias(alias: AliasDecl) -> Self { - Self::new(TypeEnum::Alias(alias)) - } - - pub(crate) fn new(type_e: TypeEnum) -> Self { - let bound = type_e.least_upper_bound(); - Self(type_e, bound) + Self(Term::RuntimeExtension(opaque), bound) } /// New `UnitSum` with empty Tuple variants #[must_use] pub const fn new_unit_sum(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::new_unary(size)), TypeBound::Copyable) + Self( + Term::RuntimeSum(SumType::new_unary(size)), + TypeBound::Copyable, + ) } /// New use (occurrence) of the type variable with specified index. @@ -495,8 +423,8 @@ impl TypeBase { /// (i.e. as a [`Term::RuntimeType`]`(bound)`), which may be narrower /// than required for the use. #[must_use] - pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::Variable(idx, bound), bound) + pub fn new_var_use(idx: usize, bound: TypeBound) -> Self { + Self(Term::new_var_use(idx, bound), bound) } /// Report the least upper [`TypeBound`] @@ -505,34 +433,6 @@ impl TypeBase { self.1 } - /// Report the component `TypeEnum`. - #[inline(always)] - pub const fn as_type_enum(&self) -> &TypeEnum { - &self.0 - } - - /// Report a mutable reference to the component `TypeEnum`. - #[inline(always)] - pub fn as_type_enum_mut(&mut self) -> &mut TypeEnum { - &mut self.0 - } - - /// Returns the inner [`SumType`] if the type is a sum. - pub fn as_sum(&self) -> Option<&SumType> { - match &self.0 { - TypeEnum::Sum(s) => Some(s), - _ => None, - } - } - - /// Returns the inner [`CustomType`] if the type is from an extension. - pub fn as_extension(&self) -> Option<&CustomType> { - match &self.0 { - TypeEnum::Extension(ct) => Some(ct), - _ => None, - } - } - /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { @@ -540,54 +440,33 @@ impl TypeBase { } /// Checks all variables used in the type are in the provided list - /// of bound variables, rejecting any [`RowVariable`]s if `allow_row_vars` is False; - /// and that for each [`CustomType`] the corresponding + /// of bound variables, and that for each [`CustomType`] the corresponding /// [`TypeDef`] is in the [`ExtensionRegistry`] and the type arguments /// [validate] and fit into the def's declared parameters. /// - /// [RowVariable]: TypeEnum::RowVariable /// [validate]: crate::types::type_param::TypeArg::validate /// [TypeDef]: crate::extension::TypeDef pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - // There is no need to check the components against the bound, - // that is guaranteed by construction (even for deserialization) - match &self.0 { - TypeEnum::Sum(SumType::General { rows }) => { - rows.iter().try_for_each(|row| row.validate(var_decls)) - } - TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there - TypeEnum::Alias(_) => Ok(()), - TypeEnum::Extension(custy) => custy.validate(var_decls), - // Function values may be passed around without knowing their arity - // (i.e. with row vars) as long as they are not called: - TypeEnum::Function(ft) => ft.validate(var_decls), - TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), - TypeEnum::RowVar(rv) => rv.validate(var_decls), - } + self.0.validate(var_decls)?; + // ALAN even this should be only a debug-assert really: + // we have no unchecked access from outside crate::types + // so it must be a bug in our caching logic if this is wrong: + check_term_type(&self.0, &self.1.into())?; + debug_assert!( + self.1 == TypeBound::Copyable + || check_term_type(&self.0, &TypeBound::Copyable.into()).is_err() + ); + Ok(()) } /// Applies a substitution to a type. - /// This may result in a row of types, if this [Type] is not really a single type but actually a row variable - /// Invariants may be confirmed by validation: - /// * If [`Type::validate`]`(false)` returns successfully, this method will return a Vec containing exactly one type - /// * If [`Type::validate`]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) - /// return a Vec containing any number of [Type]s. These may (or not) pass [`Type::validate`] - fn substitute(&self, t: &Substitution) -> Vec { - match &self.0 { - TypeEnum::RowVar(rv) => rv.substitute(t), - TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], - TypeEnum::Variable(idx, bound) => { - let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { - panic!("Variable was not a type - try validate() first") - }; - vec![ty.into_()] - } - TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t))], - TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t))], - TypeEnum::Sum(SumType::General { rows }) => { - vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t)))] - } - } + /// + /// Always produces exactly one type, but may narrow the bound (from + /// [TypeBound::Linear] to [TypeBound::Copyable]). + fn substitute(&self, s: &Substitution) -> Self { + let t = self.0.substitute(s); + let b = t.least_upper_bound().unwrap(); // Recompute. + Self(t, b) } /// Returns a registry with the concrete extensions used by this type. @@ -598,7 +477,7 @@ impl TypeBase { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); - collect_type_exts(self, &mut used, &mut missing); + collect_term_exts(self, &mut used, &mut missing); if missing.is_empty() { Ok(used.try_into().expect("all extensions are present")) @@ -608,114 +487,41 @@ impl TypeBase { } } -impl Transformable for TypeBase { +impl Transformable for Type { fn transform(&mut self, tr: &T) -> Result { - match &mut self.0 { - TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false), - TypeEnum::Extension(custom_type) => { - if let Some(nt) = tr.apply_custom(custom_type)? { - *self = nt.into_(); - Ok(true) - } else { - let args_changed = custom_type.args_mut().transform(tr)?; - if args_changed { - *self = Self::new_extension( - custom_type - .get_type_def(&custom_type.get_extension()?)? - .instantiate(custom_type.args())?, - ); - } - Ok(args_changed) - } - } - TypeEnum::Function(fty) => fty.transform(tr), - TypeEnum::Sum(sum_type) => { - let ch = sum_type.transform(tr)?; - self.1 = self.0.least_upper_bound(); - Ok(ch) - } + let res = self.0.transform(tr)?; + if res { + self.1 = self.0.least_upper_bound().unwrap() } + Ok(res) } } -impl Type { - fn substitute1(&self, s: &Substitution) -> Self { - let v = self.substitute(s); - let [r] = v.try_into().unwrap(); // No row vars, so every Type produces exactly one - r - } -} - -impl TypeRV { - /// Tells if this Type is a row variable, i.e. could stand for any number >=0 of Types - #[must_use] - pub fn is_row_var(&self) -> bool { - matches!(self.0, TypeEnum::RowVar(_)) - } - - /// New use (occurrence) of the row variable with specified index. - /// `bound` must match that with which the variable was declared - /// (i.e. as a list of runtime types of that bound). - /// For use in [OpDef], not [FuncDefn], type schemes only. - /// - /// [OpDef]: crate::extension::OpDef - /// [FuncDefn]: crate::ops::FuncDefn - #[must_use] - pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::RowVar(RowVariable(idx, bound)), bound) - } -} +impl Deref for Type { + type Target = Term; -// ====== Conversions ====== -impl TypeBase { - /// (Fallibly) converts a `TypeBase` (parameterized, so may or may not be able - /// to contain [`RowVariable`]s) into a [Type] that definitely does not. - pub fn try_into_type(self) -> Result { - Ok(TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => Err(rv.as_rv().clone())?, - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - )) + fn deref(&self) -> &Self::Target { + &self.0 } } -impl TryFrom for Type { - type Error = RowVariable; - fn try_from(value: TypeRV) -> Result { - value.try_into_type() - } -} +impl TryFrom for Type { + type Error = TermTypeError; -impl TypeBase { - /// A swiss-army-knife for any safe conversion of the type argument `RV1` - /// to/from [`NoRV`]/RowVariable/rust-type-variable. - fn into_(self) -> TypeBase - where - RV1: Into, - { - TypeBase( - match self.0 { - TypeEnum::Extension(e) => TypeEnum::Extension(e), - TypeEnum::Alias(a) => TypeEnum::Alias(a), - TypeEnum::Function(f) => TypeEnum::Function(f), - TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), - TypeEnum::RowVar(rv) => TypeEnum::RowVar(rv.into()), - TypeEnum::Sum(s) => TypeEnum::Sum(s), - }, - self.1, - ) + fn try_from(t: Term) -> Result { + match t.least_upper_bound() { + Some(b) => Ok(Self(t, b)), + None => Err(TermTypeError::TypeMismatch { + term: Box::new(t), + type_: Box::new(TypeBound::Linear.into()), + }), + } } } -impl From for TypeRV { - fn from(value: Type) -> Self { - value.into_() +impl From for Term { + fn from(t: Type) -> Self { + t.0 } } @@ -745,36 +551,6 @@ impl<'a> Substitution<'a> { debug_assert_eq!(check_term_type(arg, decl), Ok(())); arg.clone() } - - fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec { - let arg = self - .0 - .get(idx) - .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); - match arg { - TypeArg::List(elems) => elems - .iter() - .map(|ta| { - match ta { - Term::Runtime(ty) => return ty.clone().into(), - Term::Variable(v) => { - if let Some(b) = v.bound_if_row_var() { - return TypeRV::new_row_var_use(v.index(), b); - } - } - _ => (), - } - panic!("Not a list of types - call validate() ?") - }) - .collect(), - Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { - // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type - vec![ty.clone().into()] - } - _ => panic!("Not a type or list of types - call validate() ?"), - } - } } /// A transformation that can be applied to a [Type] or [`TypeArg`]. @@ -825,31 +601,6 @@ impl Transformable for [E] { } } -pub(crate) fn check_typevar_decl( - decls: &[TypeParam], - idx: usize, - cached_decl: &TypeParam, -) -> Result<(), SignatureError> { - match decls.get(idx) { - None => Err(SignatureError::FreeTypeVar { - idx, - num_decls: decls.len(), - }), - Some(actual) => { - // The cache here just mirrors the declaration. The typevar can be used - // anywhere expecting a kind *containing* the decl - see `check_type_arg`. - if actual == cached_decl { - Ok(()) - } else { - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - cached: Box::new(cached_decl.clone()), - actual: Box::new(actual.clone()), - }) - } - } - } -} - #[cfg(test)] pub(crate) mod test { use std::hash::{Hash, Hasher}; @@ -876,18 +627,14 @@ pub(crate) mod test { // Dummy extension reference. &Weak::default(), )), - Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)), ]); - assert_eq!( - &t.to_string(), - "[usize, [] -> [], my_custom, Alias(my_alias)]" - ); + assert_eq!(&t.to_string(), "[usize, [] -> [], my_custom]"); } #[rstest::rstest] fn sum_construct() { let pred1 = Type::new_sum([type_row![], type_row![]]); - let pred2 = TypeRV::new_unit_sum(2); + let pred2 = Type::new_unit_sum(2); assert_eq!(pred1, pred2); @@ -905,9 +652,21 @@ pub(crate) mod test { fn as_option() { let opt = option_type([usize_t()]); - assert_eq!(opt.as_unary_option().unwrap().clone(), usize_t()); assert_eq!( - Type::new_unit_sum(2).as_sum().unwrap().as_unary_option(), + opt.as_option().unwrap().clone(), + TypeRowRV::from([usize_t()]) + ); + // Two empty variants is like an option of empty. + assert_eq!( + Type::new_unit_sum(2).as_sum().unwrap().as_option(), + Some(TypeRowRV::EMPTY_REF) + ); + + assert_eq!( + Type::new_sum(vec![[usize_t()]; 2]) + .as_sum() + .unwrap() + .as_option(), None ); @@ -932,18 +691,16 @@ pub(crate) mod test { #[test] fn sum_variants() { let variants: Vec = vec![ - [TypeRV::UNIT].into(), - vec![TypeRV::new_row_var_use(0, TypeBound::Linear)].into(), + [Type::UNIT].into(), + TypeRowRV::new_var_use(0, TypeBound::Linear), ]; let t = SumType::new(variants.clone()); assert_eq!(variants, t.variants().cloned().collect_vec()); - let empty_rows = vec![TypeRV::EMPTY_TYPEROW; 3]; + let empty_rows = vec![TypeRowRV::new(); 3]; let sum_unary = SumType::new_unary(3); - let sum_general = SumType::General { - rows: empty_rows.clone(), - }; - assert_eq!(&empty_rows, &sum_unary.variants().cloned().collect_vec()); + assert_eq!(empty_rows, sum_unary.variants().cloned().collect_vec()); + let sum_general = SumType::General { rows: empty_rows }; assert_eq!(sum_general, sum_unary); let mut hasher_general = std::hash::DefaultHasher::new(); @@ -1023,7 +780,7 @@ pub(crate) mod test { let coln = e.get_type(&COLN).unwrap(); let c_of_cpy = coln - .instantiate([Term::new_list([Type::from(cpy.clone()).into()])]) + .instantiate([Term::new_list([Type::from(cpy.clone())])]) .unwrap(); let mut t = Type::new_extension(c_of_cpy.clone()); @@ -1036,7 +793,7 @@ pub(crate) mod test { ); let mut t = Type::new_extension( - coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone())).into()])]) + coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone()))])]) .unwrap(), ); assert_eq!( @@ -1053,14 +810,14 @@ pub(crate) mod test { (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( - coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()).into(); 2])]) + coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()); 2])]) .unwrap(), ); assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); assert_eq!( t, Type::new_extension( - coln.instantiate([Term::new_list([usize_t().into(), usize_t().into()])]) + coln.instantiate([Term::new_list([usize_t(), usize_t()])]) .unwrap() ) ); @@ -1070,11 +827,10 @@ pub(crate) mod test { use crate::proptest::RecursionDepth; - use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum}; - use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV}; - use proptest::prelude::*; + use crate::types::{CustomType, FuncValueType, SumType, Term, Type, TypeBound, TypeRowRV}; + use proptest::{prelude::*, strategy::Union}; - impl Arbitrary for super::SumType { + impl Arbitrary for SumType { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { @@ -1089,25 +845,38 @@ pub(crate) mod test { } } - impl Arbitrary for TypeBase { + impl Arbitrary for Type { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - // We descend here, because a TypeEnum may contain a Type + let strat = Union::new([ + (any::(), any::()) + .prop_map(|(i, b)| Type::new_var_use(i, b)) + .boxed(), + any_with::(depth.into()) + .prop_map(Type::new_extension) + .boxed(), + ]); + if depth.leaf() { + return strat.boxed(); + } let depth = depth.descend(); - prop_oneof![ - 1 => any::().prop_map(TypeBase::new_alias), - 1 => any_with::(depth.into()).prop_map(TypeBase::new_extension), - 1 => any_with::(depth).prop_map(TypeBase::new_function), - 1 => any_with::(depth).prop_map(TypeBase::from), - 1 => (any::(), any::()).prop_map(|(i,b)| TypeBase::new_var_use(i,b)), - // proptest_derive::Arbitrary's weight attribute requires a constant, - // rather than this expression, hence the manual impl: - RV::weight() => RV::arb().prop_map(|rv| TypeBase::new(TypeEnum::RowVar(rv))) - ] + strat + .or(any_with::(depth) + .prop_map(Type::new_function) + .boxed()) + .or(any_with::(depth).prop_map(Type::from).boxed()) .boxed() } } + + proptest! { + #[test] + fn type_term_roundtrip(t: Type) { + let tm = Term::from(t.clone()); + assert_eq!(Type::try_from(tm), Ok(t)); + } + } } } @@ -1116,11 +885,10 @@ pub(super) mod proptest_utils { use proptest::collection::vec; use proptest::prelude::{Strategy, any_with}; - use super::serialize::{TermSer, TypeArgSer, TypeParamSer}; - use super::type_param::Term; - use crate::proptest::RecursionDepth; - use crate::types::serialize::ArrayOrTermSer; + + use super::serialize::{ArrayOrTermSer, TermSer, TypeArgSer, TypeParamSer}; + use super::type_param::Term; fn term_is_serde_type_arg(t: &Term) -> bool { let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else { @@ -1138,7 +906,7 @@ pub(super) mod proptest_utils { } else { true } - } // Do we need to inspect inside function types? sum types? + } TypeArgSer::BoundedNat { .. } | TypeArgSer::String { .. } | TypeArgSer::Bytes { .. } diff --git a/hugr-core/src/types/check.rs b/hugr-core/src/types/check.rs index 072da5884e..9ef8dc783b 100644 --- a/hugr-core/src/types/check.rs +++ b/hugr-core/src/types/check.rs @@ -3,7 +3,10 @@ use thiserror::Error; use super::{Type, TypeRow}; -use crate::{extension::SignatureError, ops::Value}; +use crate::{ + ops::Value, + types::{Term, type_param::TermTypeError}, +}; /// Errors that arise from typechecking constants #[derive(Clone, Debug, PartialEq, Error)] @@ -69,10 +72,16 @@ impl super::SumType { num_variants: self.num_variants(), })?; let variant: TypeRow = variant.clone().try_into().map_err(|e| { - let SignatureError::RowVarWhereTypeExpected { var } = e else { - panic!("Unexpected error") + let TermTypeError::TypeMismatch { term, .. } = e else { + panic!("Unexpected error {e}") }; - SumTypeError::VariantNotConcrete { tag, varidx: var.0 } + let Term::Variable(tv) = &*term else { + panic!("Unexpected term {term}"); + }; + SumTypeError::VariantNotConcrete { + tag, + varidx: tv.index(), + } })?; if variant.len() != val.len() { diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 248e0f6253..e4010271fe 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -8,10 +8,9 @@ use crate::Extension; use crate::extension::{ExtensionId, SignatureError, TypeDef}; use super::{ - Substitution, TypeBound, + Substitution, Type, TypeBound, TypeName, type_param::{TypeArg, TypeParam}, }; -use super::{Type, TypeName}; /// An opaque type element. Contains the unique identifier of its definition. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -27,8 +26,6 @@ pub struct CustomType { /// [`TypeDef`]: crate::extension::TypeDef id: TypeName, /// Arguments that fit the [`TypeParam`]s declared by the typedef - /// - /// [`TypeParam`]: super::type_param::TypeParam args: Vec, /// The [`TypeBound`] describing what can be done to instances of this type bound: TypeBound, diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index ea16ab958b..988a88b00a 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -4,18 +4,14 @@ use std::borrow::Cow; use itertools::Itertools; -use crate::extension::SignatureError; -#[cfg(test)] -use { - super::proptest_utils::any_serde_type_param, - crate::proptest::RecursionDepth, - ::proptest::{collection::vec, prelude::*}, - proptest_derive::Arbitrary, +use crate::{ + extension::SignatureError, + types::{TypeRow, TypeRowLike, TypeRowRV}, }; use super::Substitution; +use super::signature::FuncTypeBase; use super::type_param::{TypeArg, TypeParam, check_term_types}; -use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. /// (Nodes/operations in the Hugr are not polymorphic.) @@ -24,19 +20,24 @@ use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// [`FuncDefn`]: crate::ops::module::FuncDefn /// [`OpDef`]: crate::extension::OpDef #[derive( - Clone, PartialEq, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, + PartialEq, + Debug, + Default, + Eq, + Hash, + derive_more::Display, + serde::Serialize, + serde::Deserialize, )] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] #[display("{}{body}", self.display_params())] -pub struct PolyFuncTypeBase { +pub struct PolyFuncTypeBase { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [`TypeArg`]s before the function can be called. This /// defines the indices used by variables inside the body. - #[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))] params: Vec, /// Template for the function. May contain variables up to length of [`Self::params`] - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - body: FuncTypeBase, + body: FuncTypeBase, } /// The polymorphic type of a [`Call`]-able function ([`FuncDecl`] or [`FuncDefn`]). @@ -45,26 +46,16 @@ pub struct PolyFuncTypeBase { /// [`Call`]: crate::ops::Call /// [`FuncDefn`]: crate::ops::FuncDefn /// [`FuncDecl`]: crate::ops::FuncDecl -pub type PolyFuncType = PolyFuncTypeBase; +pub type PolyFuncType = PolyFuncTypeBase; /// The polymorphic type of an [`OpDef`], whose number of input and outputs -/// may vary according to how [`RowVariable`]s therein are instantiated. +/// may vary according to how row variables therein are instantiated. /// /// [`OpDef`]: crate::extension::OpDef -pub type PolyFuncTypeRV = PolyFuncTypeBase; - -// deriving Default leads to an impl that only applies for RV: Default -impl Default for PolyFuncTypeBase { - fn default() -> Self { - Self { - params: Default::default(), - body: Default::default(), - } - } -} +pub type PolyFuncTypeRV = PolyFuncTypeBase; -impl From> for PolyFuncTypeBase { - fn from(body: FuncTypeBase) -> Self { +impl From> for PolyFuncTypeBase { + fn from(body: FuncTypeBase) -> Self { Self { params: vec![], body, @@ -81,11 +72,11 @@ impl From for PolyFuncTypeRV { } } -impl TryFrom> for FuncTypeBase { +impl TryFrom> for FuncTypeBase { /// If the `PolyFuncTypeBase` is not monomorphic, fail with its binders type Error = Vec; - fn try_from(value: PolyFuncTypeBase) -> Result { + fn try_from(value: PolyFuncTypeBase) -> Result { if value.params.is_empty() { Ok(value.body) } else { @@ -94,20 +85,36 @@ impl TryFrom> for FuncTypeBase { } } -impl PolyFuncTypeBase { +impl PolyFuncTypeBase { + /// Helper function for the Display implementation + fn display_params(&self) -> Cow<'static, str> { + if self.params.is_empty() { + return Cow::Borrowed(""); + } + let params_list = self + .params + .iter() + .enumerate() + .map(|(i, param)| format!("(#{i} : {param})")) + .join(" "); + Cow::Owned(format!("∀ {params_list}. ",)) + } +} + +impl PolyFuncTypeBase { /// The type parameters, aka binders, over which this type is polymorphic pub fn params(&self) -> &[TypeParam] { &self.params } /// The body of the type, a function type. - pub fn body(&self) -> &FuncTypeBase { + pub fn body(&self) -> &FuncTypeBase { &self.body } /// Create a new `PolyFuncTypeBase` given the kinds of the variables it declares /// and the underlying [`FuncTypeBase`]. - pub fn new(params: impl Into>, body: impl Into>) -> Self { + pub fn new(params: impl Into>, body: impl Into>) -> Self { Self { params: params.into(), body: body.into(), @@ -120,7 +127,7 @@ impl PolyFuncTypeBase { /// # Errors /// If there is not exactly one [`TypeArg`] for each binder ([`Self::params`]), /// or an arg does not fit into its corresponding [`TypeParam`] - pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { + pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. check_term_types(args, &self.params)?; @@ -133,26 +140,15 @@ impl PolyFuncTypeBase { self.body.validate(&self.params) } - /// Helper function for the Display implementation - fn display_params(&self) -> Cow<'static, str> { - if self.params.is_empty() { - return Cow::Borrowed(""); - } - let params_list = self - .params - .iter() - .enumerate() - .map(|(i, param)| format!("(#{i} : {param})")) - .join(" "); - Cow::Owned(format!("∀ {params_list}. ",)) - } - /// Returns a mutable reference to the body of the function type. - pub fn body_mut(&mut self) -> &mut FuncTypeBase { + pub fn body_mut(&mut self) -> &mut FuncTypeBase { &mut self.body } } +// Do not implement Substitutable: we never need to substitute into a PolyFuncType +// (i.e. under a binder). + #[cfg(test)] pub(crate) mod test { use std::num::NonZeroU64; @@ -168,15 +164,41 @@ pub(crate) mod test { use crate::types::signature::FuncTypeBase; use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, PolyFuncTypeBase, Signature, Term, Type, TypeBound, TypeName, + TypeRowLike, TypeRowRV, }; - use super::PolyFuncTypeBase; - - impl PolyFuncTypeBase { + mod proptest { + use proptest::collection::vec; + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any_with}; + + use super::PolyFuncTypeBase; + use crate::proptest::RecursionDepth; + use crate::types::proptest_utils::any_serde_type_param; + use crate::types::{TypeRowLike, signature::FuncTypeBase}; + + impl + 'static> Arbitrary + for PolyFuncTypeBase + { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + // We want to generate a random number of type parameters, and then generate a body that can refer to those parameters. + // To do this, we first generate the type parameters, and then pass them as parameters to the body strategy. + ( + vec(any_serde_type_param(params), 0..3), + any_with::>(params), + ) + .prop_map(|(params, body)| Self::new(params, body)) + .boxed() + } + } + } + impl PolyFuncTypeBase { fn new_validated( params: impl Into>, - body: FuncTypeBase, + body: FuncTypeBase, ) -> Result { let res = Self::new(params, body); res.validate()?; @@ -187,7 +209,7 @@ pub(crate) mod test { #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); - let tyvar = TypeArg::new_var_use(0, TypeBound::Linear.into()); + let tyvar = TypeArg::new_var_use(0, TypeBound::Linear); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( [TypeBound::Linear.into()], @@ -211,7 +233,7 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); - let ty_var = TypeArg::new_var_use(1, TypeBound::Linear.into()); + let ty_var = TypeArg::new_var_use(1, TypeBound::Linear); let type_params = [TypeParam::max_nat_type(), TypeBound::Linear.into()]; // Valid schema... @@ -262,7 +284,7 @@ pub(crate) mod test { #[test] fn test_misused_variables() -> Result<(), SignatureError> { // Variables in args have different bounds from variable declaration - let tv = TypeArg::new_var_use(0, TypeBound::Copyable.into()); + let tv = TypeArg::new_var_use(0, TypeBound::Copyable); let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo([Type::new_extension(list_def.instantiate([tv])?)]); for decl in [ @@ -377,10 +399,7 @@ pub(crate) mod test { let decl = Term::new_list_type(TP_ANY); let e = PolyFuncTypeBase::new_validated( [decl.clone()], - FuncValueType::new( - vec![usize_t()], - vec![TypeRV::new_row_var_use(0, TypeBound::Copyable)], - ), + FuncValueType::new([usize_t()], TypeRowRV::new_var_use(0, TypeBound::Copyable)), ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { @@ -401,10 +420,13 @@ pub(crate) mod test { #[test] fn row_variables() { - let rty = TypeRV::new_row_var_use(0, TypeBound::Linear); + let rty = TypeRowRV::new_var_use(0, TypeBound::Linear); let pf = PolyFuncTypeBase::new_validated( [TypeParam::new_list_type(TP_ANY)], - FuncValueType::new([usize_t().into(), rty.clone()], [TypeRV::new_tuple([rty])]), + FuncValueType::new( + TypeRowRV::from([usize_t()]).concat(rty.clone()), + [Type::new_tuple(rty)], + ), ) .unwrap(); @@ -417,20 +439,20 @@ pub(crate) mod test { let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); assert_eq!( - t2, Signature::new( vec![usize_t(), usize_t(), bool_t()], vec![Type::new_tuple(vec![usize_t(), bool_t()])] - ) + ), + t2 ); } #[test] fn row_variables_inner() { - let inner_fty = Type::new_function(FuncValueType::new_endo([TypeRV::new_row_var_use( + let inner_fty = Type::new_function(FuncValueType::new_endo(TypeRowRV::new_var_use( 0, TypeBound::Copyable, - )])); + ))); let pf = PolyFuncTypeBase::new_validated( [Term::new_list_type(TypeBound::Copyable)], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), @@ -439,11 +461,7 @@ pub(crate) mod test { let inner3 = Type::new_function(Signature::new_endo([usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[Term::new_list([ - usize_t().into(), - bool_t().into(), - usize_t().into(), - ])]) + .instantiate(&[Term::new_list([usize_t(), bool_t(), usize_t()])]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs deleted file mode 100644 index 086ab7b076..0000000000 --- a/hugr-core/src/types/row_var.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! Classes for row variables (i.e. Type variables that can stand for multiple types) - -use super::type_param::TypeParam; -use super::{Substitution, TypeBase, TypeBound, check_typevar_decl}; -use crate::extension::SignatureError; - -#[cfg(test)] -use proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a list of runtime types -/// of the specified bound (checked in validation) -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, Hash, PartialEq, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -#[display("{_0}")] -pub struct RowVariable(pub usize, pub TypeBound); - -// Note that whilst 'pub' this is not re-exported outside private module `row_var` -// so is effectively sealed. -pub trait MaybeRV: - Clone - + std::fmt::Debug - + std::fmt::Display - + From - + Into - + Eq - + PartialEq - + 'static -{ - fn as_rv(&self) -> &RowVariable; - fn try_from_rv(rv: RowVariable) -> Result; - fn bound(&self) -> TypeBound; - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError>; - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec>; - #[cfg(test)] - fn weight() -> u32 { - 1 - } - #[cfg(test)] - fn arb() -> BoxedStrategy; -} - -/// Has no instances - used as parameter to [`Type`] to rule out the possibility -/// of there being any [`TypeEnum::RowVar`]s -/// -/// [`TypeEnum::RowVar`]: super::TypeEnum::RowVar -/// [`Type`]: super::Type -// The serde derives here are not used except as markers -// so that other types containing this can also #derive-serde the same way. -#[derive( - Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -pub enum NoRV {} - -impl From for RowVariable { - fn from(value: NoRV) -> Self { - match value {} - } -} - -impl MaybeRV for RowVariable { - fn as_rv(&self) -> &RowVariable { - self - } - - fn bound(&self) -> TypeBound { - self.1 - } - - fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) - } - - #[allow(private_interfaces)] - fn substitute(&self, s: &Substitution) -> Vec> { - s.apply_rowvar(self.0, self.1) - } - - fn try_from_rv(rv: RowVariable) -> Result { - Ok(rv) - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - (any::(), any::()) - .prop_map(|(i, b)| Self(i, b)) - .boxed() - } -} - -impl MaybeRV for NoRV { - fn as_rv(&self) -> &RowVariable { - match *self {} - } - - fn bound(&self) -> TypeBound { - match *self {} - } - - fn validate(&self, _var_decls: &[TypeParam]) -> Result<(), SignatureError> { - match *self {} - } - - #[allow(private_interfaces)] - fn substitute(&self, _s: &Substitution) -> Vec> { - match *self {} - } - - fn try_from_rv(rv: RowVariable) -> Result { - Err(rv) - } - - #[cfg(test)] - fn weight() -> u32 { - 0 - } - - #[cfg(test)] - fn arb() -> BoxedStrategy { - any::() - .prop_map(|_| panic!("Should be ruled out by weight==0")) - .boxed() - } -} diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index eeff6f2e14..4fd1565029 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use ordered_float::OrderedFloat; +use serde::Serialize; -use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; +use super::{FuncValueType, SumType, Term, Type, TypeBound}; use super::custom::CustomType; -use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; -use crate::types::type_param::{TermVar, UpperBound}; -use crate::types::{Term, Type}; +use crate::types::TypeRowRV; +use crate::types::type_param::{SeqPart, TermTypeError, TermVar, UpperBound}; -#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] +#[derive(Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] pub(crate) enum SerSimpleType { Q, @@ -25,45 +25,68 @@ pub(crate) enum SerSimpleType { R { i: usize, b: TypeBound }, } -impl From> for SerSimpleType { - fn from(value: TypeBase) -> Self { +impl From for SerSimpleType { + fn from(value: Type) -> Self { if value == qb_t() { return SerSimpleType::Q; } if value == usize_t() { return SerSimpleType::I; } - match value.0 { - TypeEnum::Extension(o) => SerSimpleType::Opaque(o), - TypeEnum::Alias(a) => SerSimpleType::Alias(a), - TypeEnum::Function(sig) => SerSimpleType::G(sig), - TypeEnum::Variable(i, b) => SerSimpleType::V { i, b }, - TypeEnum::RowVar(rv) => { - let RowVariable(idx, bound) = rv.as_rv(); - SerSimpleType::R { i: *idx, b: *bound } + match value.into() { + Term::RuntimeExtension(o) => SerSimpleType::Opaque(o), + Term::RuntimeFunction(sig) => SerSimpleType::G(sig), + Term::Variable(tv) => { + let i = tv.index(); + let Term::RuntimeType(b) = &*tv.cached_decl else { + panic!("Variable with bound {} is not a valid Type", tv.cached_decl); + }; + SerSimpleType::V { i, b: *b } } - TypeEnum::Sum(st) => SerSimpleType::Sum(st), + Term::RuntimeSum(st) => SerSimpleType::Sum(st), + v => panic!("{} was not a valid Type", v), } } } -impl TryFrom for TypeBase { - type Error = SignatureError; +// Row Variables can also be serialized as "simple types" +impl TryFrom for SerSimpleType { + type Error = TermTypeError; + + fn try_from(value: Term) -> Result { + if let Term::Variable(tv) = &value + && let Term::ListType(t) = &*tv.cached_decl + && let Term::RuntimeType(b) = &**t + { + return Ok(SerSimpleType::R { + i: tv.index(), + b: *b, + }); + } + Type::try_from(value).map(SerSimpleType::from) + } +} + +impl From for Term { + fn from(value: SerSimpleType) -> Self { + match value { + SerSimpleType::Q => qb_t().into(), + SerSimpleType::I => usize_t().into(), + SerSimpleType::G(sig) => Type::new_function(*sig).into(), + SerSimpleType::Sum(st) => Type::from(st).into(), + SerSimpleType::Opaque(o) => Type::new_extension(o).into(), + SerSimpleType::Alias(_) => unimplemented!("Aliases are currently not supported"), + SerSimpleType::V { i, b } => Type::new_var_use(i, b).into(), + SerSimpleType::R { i, b } => Term::new_row_var_use(i, b), + } + } +} + +impl TryFrom for Type { + type Error = TermTypeError; + fn try_from(value: SerSimpleType) -> Result { - Ok(match value { - SerSimpleType::Q => qb_t().into_(), - SerSimpleType::I => usize_t().into_(), - SerSimpleType::G(sig) => TypeBase::new_function(*sig), - SerSimpleType::Sum(st) => st.into(), - SerSimpleType::Opaque(o) => TypeBase::new_extension(o), - SerSimpleType::Alias(a) => TypeBase::new_alias(a), - SerSimpleType::V { i, b } => TypeBase::new_var_use(i, b), - // We can't use new_row_var because that returns TypeRV not TypeBase. - SerSimpleType::R { i, b } => TypeBase::new(TypeEnum::RowVar( - RV::try_from_rv(RowVariable(i, b)) - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )), - }) + Term::from(value).try_into() } } @@ -138,7 +161,11 @@ impl From for TermSer { Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), Term::ConstType(ty) => TermSer::TypeParam(TypeParamSer::ConstType { ty: *ty }), - Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::RuntimeFunction(_) | Term::RuntimeExtension(_) | Term::RuntimeSum(_) => { + TermSer::TypeArg(TypeArgSer::Type { + ty: value.try_into().unwrap(), + }) + } Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { params: (*params).into(), }), @@ -148,7 +175,13 @@ impl From for TermSer { Term::Float(value) => TermSer::TypeArg(TypeArgSer::Float { value }), Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems }), Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), + Term::Variable(ref v) if matches!(&*v.cached_decl, Term::RuntimeType(_)) => { + TermSer::TypeArg(TypeArgSer::Type { + ty: value.try_into().unwrap(), + }) + } Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), + Term::ListConcat(lists) => TermSer::TypeArg(TypeArgSer::ListConcat { lists }), Term::TupleConcat(tuples) => TermSer::TypeArg(TypeArgSer::TupleConcat { tuples }), } @@ -170,7 +203,7 @@ impl From for Term { TypeParamSer::ConstType { ty } => Term::ConstType(Box::new(ty)), }, TermSer::TypeArg(arg) => match arg { - TypeArgSer::Type { ty } => Term::Runtime(ty), + TypeArgSer::Type { ty } => Term::from(ty), TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), TypeArgSer::String { arg } => Term::String(arg), TypeArgSer::Bytes { value } => Term::Bytes(value), @@ -233,3 +266,38 @@ mod base64 { .map_err(serde::de::Error::custom) } } + +impl serde::Serialize for TypeRowRV { + fn serialize(&self, serializer: S) -> Result { + let items: Vec = self + .0 + .clone() + .into_list_parts() + .map(|part| match part { + SeqPart::Item(t) => { + let t = Type::try_from(t).unwrap(); + let s = SerSimpleType::from(t); + assert!(!matches!(s, SerSimpleType::R { .. })); + s + } + SeqPart::Splice(t) => { + let s = SerSimpleType::try_from(t).unwrap(); + assert!(matches!(s, SerSimpleType::R { .. })); + s + } + }) + .collect(); + items.serialize(serializer) + } +} + +impl<'de> serde::Deserialize<'de> for TypeRowRV { + fn deserialize>(deser: D) -> Result { + let items: Vec = serde::Deserialize::deserialize(deser)?; + let list_parts = items.into_iter().map(|s| match s { + SerSimpleType::R { i, b } => SeqPart::Splice(Term::new_row_var_use(i, b)), + s => SeqPart::Item(Term::from(s)), + }); + Ok(TypeRowRV::try_from(Term::new_list_from_parts(list_parts)).unwrap()) + } +} diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 0a7200ed92..47e98d0529 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -2,107 +2,80 @@ use itertools::Either; -use std::borrow::Cow; use std::fmt::{self, Display}; use super::type_param::TypeParam; -use super::type_row::TypeRowBase; -use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeRow, TypeTransformer, -}; +use super::{Substitution, Transformable, Type, TypeRow, TypeTransformer}; use crate::core::PortIndex; use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_signature_exts, }; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::types::{TypeRowLike, TypeRowRV}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; -#[cfg(test)] -use {crate::proptest::RecursionDepth, proptest::prelude::*, proptest_derive::Arbitrary}; - -#[derive(Clone, Debug, Eq, Hash, serde::Serialize, serde::Deserialize)] -#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] /// Base type for listing inputs and output types. /// -/// The exact semantics depend on the use case: -/// - If `ROWVARS=`[`NoRV`], describes the edges required to/from a node or inside a [`FuncDefn`]. -/// - If `ROWVARS=`[`RowVariable`], describes the type of the inputs/outputs from an `OpDef`. -/// -/// `ROWVARS` specifies whether the type lists may contain [`RowVariable`]s or not. -/// -/// [`FuncDefn`]: crate::ops::FuncDefn -pub struct FuncTypeBase { +/// Parametrized by the type used to list the inputs and outputs. Exactly two +/// instantiations are used: [Signature] and [FuncValueType]. +pub struct FuncTypeBase { /// Value inputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub input: TypeRowBase, + pub input: T, /// Value outputs of the function. - #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] - pub output: TypeRowBase, + pub output: T, } /// The concept of "signature" in the spec - the edges required to/from a node /// or within a [`FuncDefn`], also the target (value) of a call (static). /// +/// Thus, contains a statically-known number of types. +/// /// [`FuncDefn`]: crate::ops::FuncDefn -pub type Signature = FuncTypeBase; +pub type Signature = FuncTypeBase; -/// A function that may contain [`RowVariable`]s and thus has potentially-unknown arity; +/// A function that may contain row variables and thus has potentially-unknown arity; /// used for [`OpDef`]'s and passable as a value round a Hugr (see [`Type::new_function`]) /// but not a valid node type. /// /// [`OpDef`]: crate::extension::OpDef -pub type FuncValueType = FuncTypeBase; +pub type FuncValueType = FuncTypeBase; -impl FuncTypeBase { - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { +impl FuncTypeBase { + pub(crate) fn substitute(&self, subst: &Substitution) -> Self { Self { - input: self.input.substitute(tr), - output: self.output.substitute(tr), + input: self.input.substitute(subst), + output: self.output.substitute(subst), } } /// Create a new signature with specified inputs and outputs. - pub fn new(input: impl Into>, output: impl Into>) -> Self { + pub fn new(input: impl Into, output: impl Into) -> Self { Self { input: input.into(), output: output.into(), } } - /// Create a new signature with the same input and output types (signature of an endomorphic - /// function). - pub fn new_endo(row: impl Into>) -> Self { - let row = row.into(); - Self::new(row.clone(), row) - } - - /// True if both inputs and outputs are necessarily empty. - /// (For [`FuncValueType`], even after any possible substitution of row variables) - #[inline(always)] - #[must_use] - pub fn is_empty(&self) -> bool { - self.input.is_empty() && self.output.is_empty() - } - #[inline] /// Returns a row of the value inputs of the function. #[must_use] - pub fn input(&self) -> &TypeRowBase { + pub fn input(&self) -> &T { &self.input } #[inline] /// Returns a row of the value outputs of the function. #[must_use] - pub fn output(&self) -> &TypeRowBase { + pub fn output(&self) -> &T { &self.output } #[inline] /// Returns a tuple with the input and output rows of the function. #[must_use] - pub fn io(&self) -> (&TypeRowBase, &TypeRowBase) { + pub fn io(&self) -> (&T, &T) { (&self.input, &self.output) } @@ -110,7 +83,20 @@ impl FuncTypeBase { self.input.validate(var_decls)?; self.output.validate(var_decls) } +} + +impl FuncTypeBase { + /// Create a new signature with the same input and output types. + pub fn new_endo(io: impl Into) -> Self { + let io = io.into(); + Self { + input: io.clone(), + output: io, + } + } +} +impl Signature { /// Returns a registry with the concrete extensions used by this signature. pub fn used_extensions(&self) -> Result { let mut used = WeakExtensionRegistry::default(); @@ -124,16 +110,34 @@ impl FuncTypeBase { Err(ExtensionCollectionError::dropped_signature(self, missing)) } } + + /// True if both inputs and outputs are necessarily empty. + /// (For [`FuncValueType`], even after any possible substitution of row variables) + #[inline(always)] + #[must_use] + pub fn is_empty(&self) -> bool { + self.input.is_empty() && self.output.is_empty() + } } -impl Transformable for FuncTypeBase { - fn transform(&mut self, tr: &T) -> Result { +impl FuncValueType { + /// True if both inputs and outputs are necessarily empty + /// (even after any possible substitution of row variables) + #[inline(always)] + #[must_use] + pub fn is_empty(&self) -> bool { + self.input.is_empty() && self.output.is_empty() + } +} + +impl Transformable for FuncTypeBase { + fn transform(&mut self, tr: &U) -> Result { // TODO handle extension sets? Ok(self.input.transform(tr)? | self.output.transform(tr)?) } } -impl FuncValueType { +/*impl FuncValueType { /// If this `FuncValueType` contains any row variables, return one. #[must_use] pub fn find_rowvar(&self) -> Option { @@ -142,17 +146,7 @@ impl FuncValueType { .chain(self.output.iter()) .find_map(|t| Type::try_from(t.clone()).err()) } -} - -// deriving Default leads to an impl that only applies for RV: Default -impl Default for FuncTypeBase { - fn default() -> Self { - Self { - input: Default::default(), - output: Default::default(), - } - } -} +}*/ impl Signature { /// Returns the type of a value [`Port`]. Returns `None` if the port is out @@ -274,7 +268,7 @@ impl Signature { } } -impl Display for FuncTypeBase { +impl Display for FuncTypeBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.input.fmt(f)?; f.write_str(" -> ")?; @@ -301,31 +295,40 @@ impl From for FuncValueType { } } -impl PartialEq> for FuncTypeBase { - fn eq(&self, other: &FuncTypeBase) -> bool { +impl PartialEq for Signature { + fn eq(&self, other: &FuncValueType) -> bool { self.input == other.input && self.output == other.output } } -impl PartialEq>> for FuncTypeBase { - fn eq(&self, other: &Cow<'_, FuncTypeBase>) -> bool { - self.eq(other.as_ref()) - } -} - -impl PartialEq> for Cow<'_, FuncTypeBase> { - fn eq(&self, other: &FuncTypeBase) -> bool { - self.as_ref().eq(other) - } -} - #[cfg(test)] mod test { use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::type_row; - use crate::types::{CustomType, TypeEnum, test::FnTransformer}; + use crate::types::{CustomType, Term, test::FnTransformer}; use super::*; + + mod proptest { + use proptest::prelude::{Arbitrary, BoxedStrategy, Strategy, any_with}; + + use super::FuncTypeBase; + use crate::{proptest::RecursionDepth, types::TypeRowLike}; + + impl + 'static> Arbitrary + for FuncTypeBase + { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + (any_with::(params), any_with::(params)) + .prop_map(|(input, output)| Self::new(input, output)) + .boxed() + } + } + } + #[test] fn test_function_type() { let mut f_type = Signature::new(type_row![Type::UNIT], type_row![Type::UNIT]); @@ -354,7 +357,7 @@ mod test { #[test] fn test_transform() { - let TypeEnum::Extension(usz_t) = usize_t().as_type_enum().clone() else { + let Term::RuntimeExtension(usz_t) = usize_t().into() else { panic!() }; let tr = FnTransformer(|ct: &CustomType| (ct == &usz_t).then_some(bool_t())); diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index c9ddcc81b9..4da9546222 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -4,6 +4,7 @@ //! //! [`TypeDef`]: crate::extension::TypeDef +use itertools::Itertools as _; use ordered_float::OrderedFloat; #[cfg(test)] use proptest_derive::Arbitrary; @@ -14,12 +15,9 @@ use std::sync::Arc; use thiserror::Error; use tracing::warn; -use super::row_var::MaybeRV; -use super::{ - NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, - check_typevar_decl, -}; +use super::{Substitution, Transformable, Type, TypeBound, TypeRowLike, TypeTransformer}; use crate::extension::SignatureError; +use crate::types::{CustomType, FuncValueType, SumType}; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) @@ -58,7 +56,20 @@ pub type TypeArg = Term; /// A [`Term`] that is the static type of an operation or constructor parameter. pub type TypeParam = Term; -/// A term in the language of static parameters in HUGR. +/// The main entity in the static language (aka "type system") of Hugr. +/// +/// Terms include types (i.e. which describe sets of runtime values) +/// but also other compile-time entities which can be used to parametrize +/// and instantiate functions, ops, and types. (For example, array lengths +/// are not types but they are static parameters of array types and ops.) +/// +/// Terms are used for both parameter declarations and arguments fitting those +/// parameters, e.g. a [Term::FloatType] parameter would be instantiated (statically) +/// with a [Term::Float] argument. [`check_term_type`] checks that an argument +/// is valid (of the correct kind) for the parameter. +// TODO it might be good to have a separate function that tells, for a given Term, +// whether there is *any* valid argument; we could then rule out using as parameters +// any Term for which there are no valid arguments. #[derive( Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] @@ -95,9 +106,23 @@ pub enum Term { /// The type of static tuples. #[display("TupleType[{_0}]")] TupleType(Box), - /// A runtime type as a term. Instance of [`Term::RuntimeType`]. + /// The type of runtime values defined by an extension type. + /// Instance of [Self::RuntimeType] for some bound. + // + // TODO optimise with `Box`? + // or some static version of this? #[display("{_0}")] - Runtime(Type), + RuntimeExtension(CustomType), + /// The type of runtime values that are function pointers. + /// Instance of [Self::RuntimeType]`(`[TypeBound::Copyable]`)`. + /// Function values may be passed around without knowing their arity + /// (i.e. with row vars) as long as they are not called. + #[display("{_0}")] + RuntimeFunction(Box), + /// The type of runtime values that are sums of products (ADTs) + /// Instance of [Self::RuntimeType]`(bound)` for `bound` calculated from each variant's elements. + #[display("{_0}")] + RuntimeSum(SumType), /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. #[display("{_0}")] BoundedNat(u64), @@ -111,10 +136,7 @@ pub enum Term { #[display("{}", _0.into_inner())] Float(OrderedFloat), /// A list of static terms. Instance of [`Term::ListType`]. - #[display("[{}]", { - use itertools::Itertools as _; - _0.iter().map(|t|t.to_string()).join(",") - })] + #[display("[{}]", _0.iter().map(|t|t.to_string()).join(", "))] List(Vec), /// Instance of [`TypeParam::List`] defined by a sequence of concatenated lists of the same type. #[display("[{}]", { @@ -152,6 +174,9 @@ pub enum Term { } impl Term { + /// An empty list of Terms. + pub const EMPTY_LIST: Self = Self::List(vec![]); + /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1). #[must_use] pub const fn max_nat_type() -> Self { @@ -165,8 +190,8 @@ impl Term { } /// Creates a new [`Term::List`] given a sequence of its items. - pub fn new_list(items: impl IntoIterator) -> Self { - Self::List(items.into_iter().collect()) + pub fn new_list>(items: impl IntoIterator) -> Self { + Self::List(items.into_iter().map_into().collect()) } /// Creates a new [`Term::ListType`] given the type of its elements. @@ -197,24 +222,57 @@ impl Term { (Term::StringType, Term::StringType) => true, (Term::StaticType, Term::StaticType) => true, (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), + // The term inside a TupleType is a list of types, so this is ok as long as + // supertype holds element-wise (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2), (Term::BytesType, Term::BytesType) => true, (Term::FloatType, Term::FloatType) => true, - (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2, + // Needed for TupleType, does not make a great deal of sense otherwise: + (Term::List(es1), Term::List(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) + } + // The following are not types (they have no instances), so these are just to + // maintain reflexivity of the relation: + (Term::RuntimeSum(t1), Term::RuntimeSum(t2)) => t1 == t2, + (Term::RuntimeFunction(f1), Term::RuntimeFunction(f2)) => f1 == f2, + (Term::RuntimeExtension(c1), Term::RuntimeExtension(c2)) => c1 == c2, (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2, (Term::String(s1), Term::String(s2)) => s1 == s2, (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2, (Term::Float(f1), Term::Float(f2)) => f1 == f2, (Term::Variable(v1), Term::Variable(v2)) => v1 == v2, - (Term::List(es1), Term::List(es2)) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) - } (Term::Tuple(es1), Term::Tuple(es2)) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } _ => false, } } + + /// Returns true if this term is an empty list (contains no elements) + pub(super) fn is_empty_list(&self) -> bool { + match self { + Term::List(v) => v.is_empty(), + // We probably don't need to be this thorough in dealing with unnormalized forms but it's easy enough + Term::ListConcat(v) => v.iter().all(Term::is_empty_list), + _ => false, + } + } + + /// Returns the inner [`CustomType`] if this `Term` is a [Self::RuntimeExtension] + pub fn as_extension(&self) -> Option<&CustomType> { + match self { + Term::RuntimeExtension(ct) => Some(ct), + _ => None, + } + } + + /// Returns the inner [`SumType`] if this `Term` is a [Self::RuntimeSum]. + pub fn as_sum(&self) -> Option<&SumType> { + match self { + Term::RuntimeSum(s) => Some(s), + _ => None, + } + } } impl From for Term { @@ -229,15 +287,6 @@ impl From for Term { } } -impl From> for Term { - fn from(value: TypeBase) -> Self { - match value.try_into_type() { - Ok(ty) => Term::Runtime(ty), - Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), - } - } -} - impl From for Term { fn from(n: u64) -> Self { Self::BoundedNat(n) @@ -268,8 +317,7 @@ impl From<[Term; N]> for Term { } } -/// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] -/// - it might be a [`Type::new_row_var_use`]). +/// Variable in a [`Term`], i.e. contents of a [`Term::Variable`]. #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, )] @@ -280,24 +328,23 @@ pub struct TermVar { } impl Term { - /// [`Type::UNIT`] as a [`Term::Runtime`] - pub const UNIT: Self = Self::Runtime(Type::UNIT); - /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. + /// /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: Term) -> Self { - match decl { - // Note a TypeParam::List of TypeParam::Type *cannot* be represented - // as a TypeArg::Type because the latter stores a Type i.e. only a single type, - // not a RowVariable. - Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), - _ => Term::Variable(TermVar { - idx, - cached_decl: Box::new(decl), - }), - } + pub fn new_var_use(idx: usize, decl: impl Into) -> Self { + Term::Variable(TermVar { + idx, + cached_decl: Box::new(decl.into()), + }) + } + + /// Makes a `Term` representing a use (occurrence) of a variable whose + /// kind is a [Term::ListType] of [Term::RuntimeType]. + #[must_use] + pub fn new_row_var_use(idx: usize, b: TypeBound) -> Self { + Self::new_var_use(idx, Term::new_list_type(b)) } /// Creates a new string literal. @@ -306,10 +353,15 @@ impl Term { Self::String(str.to_string()) } - /// Creates a new concatenated list. + /// Creates or returns a term equivalent to concatenating a number of lists. + /// + /// If there is only one list, returns it directly. #[inline] - pub fn new_list_concat(lists: impl IntoIterator) -> Self { - Self::ListConcat(lists.into_iter().collect()) + pub fn concat_lists(lists: impl IntoIterator) -> Self { + match lists.into_iter().exactly_one() { + Ok(list) => list, + Err(e) => Self::ListConcat(e.collect()), + } } /// Creates a new tuple from its items. @@ -333,15 +385,34 @@ impl Term { } } - /// Returns a [`Type`] if the [`Term`] is a runtime type. - #[must_use] - pub fn as_runtime(&self) -> Option> { + pub(crate) fn least_upper_bound(&self) -> Option { match self { - TypeArg::Runtime(ty) => Some(ty.clone()), + Self::RuntimeExtension(ct) => Some(ct.bound()), + Self::RuntimeSum(st) => Some(st.bound()), + Self::RuntimeFunction(_) => Some(TypeBound::Copyable), + Self::Variable(v) => match &*v.cached_decl { + TypeParam::RuntimeType(b) => Some(*b), + _ => None, + }, _ => None, } } + /// Report if this is a copyable runtime type, i.e. an instance + /// of [Self::RuntimeType]`(`[TypeBound::Copyable]`)` + // where the least upper bound of the type is contained by the copyable bound. + pub(crate) fn copyable(&self) -> bool { + self.least_upper_bound() + .is_some_and(|b| TypeBound::Copyable.contains(b)) + } + + /// Report if this is a runtime type, i.e. an instance of [Self::RuntimeType] for some bound. + /// + /// If so, [Type::try_from(Type)] will succeed and can be followed by [Type::least_upper_bound] to get the bound. + pub fn is_runtime_type(&self) -> bool { + self.least_upper_bound().is_some() + } + /// Returns a string if the [`Term`] is a string literal. #[must_use] pub fn as_string(&self) -> Option { @@ -351,31 +422,25 @@ impl Term { } } - /// Much as [`Type::validate`], also checks that the type of any [`TypeArg::Opaque`] - /// is valid and closed. + /// Checks variables are as declared and [CustomType] arguments fit their parameters. + /// Does not check that e.g. list elements all have same type (except inside a + /// [CustomType] where we know the element type from the corresponding list parameter) + /// - this is left to [check_term_type]. pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - Term::Runtime(ty) => ty.validate(var_decls), - Term::List(elems) => { - // TODO: Full validation would check that the type of the elements agrees - elems.iter().try_for_each(|a| a.validate(var_decls)) + Term::RuntimeSum(SumType::General { rows }) => { + rows.iter().try_for_each(|row| row.validate(var_decls))?; + Ok(()) } + Term::RuntimeSum(SumType::Unit { .. }) => Ok(()), // No leaves there + Term::RuntimeExtension(custy) => custy.validate(var_decls), + Term::RuntimeFunction(ft) => ft.validate(var_decls), + Term::List(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), - TypeArg::ListConcat(lists) => { - // TODO: Full validation would check that each of the lists is indeed a - // list or list variable of the correct types. - lists.iter().try_for_each(|a| a.validate(var_decls)) - } + TypeArg::ListConcat(lists) => lists.iter().try_for_each(|a| a.validate(var_decls)), TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)), - Term::Variable(TermVar { idx, cached_decl }) => { - assert!( - !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), - "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" - ); - - check_typevar_decl(var_decls, *idx, cached_decl) - } + Term::Variable(tv) => tv.check_decl(var_decls), Term::RuntimeType { .. } => Ok(()), Term::BoundedNatType { .. } => Ok(()), Term::StringType => Ok(()), @@ -390,36 +455,19 @@ impl Term { pub(crate) fn substitute(&self, t: &Substitution) -> Self { match self { - Term::Runtime(ty) => { - // RowVariables are represented as Term::Variable - ty.substitute1(t).into() + TypeArg::RuntimeSum(SumType::Unit { .. }) => self.clone(), + TypeArg::RuntimeSum(SumType::General { rows }) => { + // A substitution of a row variable for an empty list, + // could make the general case into a unary SumType. + Term::RuntimeSum(SumType::new(rows.iter().map(|r| r.substitute(t)))) } + TypeArg::RuntimeExtension(cty) => Term::RuntimeExtension(cty.substitute(t)), + TypeArg::RuntimeFunction(bf) => Term::RuntimeFunction(Box::new(bf.substitute(t))), + TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => { self.clone() } // We do not allow variables as bounds on BoundedNat's - TypeArg::List(elems) => { - // NOTE: This implements a hack allowing substitutions to - // replace `TypeArg::Variable`s representing "row variables" - // with a list that is to be spliced into the containing list. - // We won't need this code anymore once we stop conflating types - // with lists of types. - - fn is_type(type_arg: &TypeArg) -> bool { - match type_arg { - TypeArg::Runtime(_) => true, - TypeArg::Variable(v) => v.bound_if_row_var().is_some(), - _ => false, - } - } - - let are_types = elems.first().map(is_type).unwrap_or(false); - - Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) { - list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list), - list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list), - elem => SeqPart::Item(elem), - })) - } + TypeArg::List(elems) => Self::List(elems.iter().map(|e| e.substitute(t)).collect()), TypeArg::ListConcat(lists) => { // When a substitution instantiates spliced list variables, we // may be able to merge the concatenated lists. @@ -446,9 +494,9 @@ impl Term { Term::BytesType => self.clone(), Term::FloatType => self.clone(), Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), - Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)), + Term::TupleType(item_types) => Term::new_tuple_type(item_types.substitute(t)), Term::StaticType => self.clone(), - Term::ConstType(ty) => Term::new_const(ty.substitute1(t)), + Term::ConstType(ty) => Term::new_const(ty.substitute(t)), } } @@ -488,7 +536,7 @@ impl Term { Self::new_seq_from_parts( parts.into_iter().flat_map(ListPartIter::new), TypeArg::List, - TypeArg::ListConcat, + TypeArg::concat_lists, ) } @@ -518,7 +566,7 @@ impl Term { /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); - /// let term = Term::new_list_concat([ + /// let term = Term::concat_lists([ /// Term::new_list([a.clone(), b.clone()]), /// var.clone(), /// Term::new_list([c.clone()]) @@ -537,12 +585,12 @@ impl Term { /// # let a = Term::new_string("a"); /// # let b = Term::new_string("b"); /// # let c = Term::new_string("c"); - /// let term = Term::new_list_concat([ - /// Term::new_list_concat([ + /// let term = Term::concat_lists([ + /// Term::concat_lists([ /// Term::new_list([a.clone()]), /// Term::new_list([b.clone()]) /// ]), - /// Term::new_list([]), + /// Term::EMPTY_LIST, /// Term::new_list([c.clone()]) /// ]); /// @@ -565,7 +613,7 @@ impl Term { /// ); /// ``` #[inline] - pub fn into_list_parts(self) -> ListPartIter { + pub fn into_list_parts(self) -> impl Iterator> { ListPartIter::new(SeqPart::Splice(self)) } @@ -584,7 +632,7 @@ impl Term { /// /// Analogous to [`TypeArg::into_list_parts`]. #[inline] - pub fn into_tuple_parts(self) -> TuplePartIter { + pub fn into_tuple_parts(self) -> impl Iterator> { TuplePartIter::new(SeqPart::Splice(self)) } } @@ -592,7 +640,22 @@ impl Term { impl Transformable for Term { fn transform(&mut self, tr: &T) -> Result { match self { - Term::Runtime(ty) => ty.transform(tr), + Term::RuntimeExtension(custom_type) => { + if let Some(nt) = tr.apply_custom(custom_type)? { + *self = nt.0; + Ok(true) + } else { + let args_changed = custom_type.args_mut().transform(tr)?; + if args_changed { + *custom_type = custom_type + .get_type_def(&custom_type.get_extension()?)? + .instantiate(custom_type.args())?; + } + Ok(args_changed) + } + } + Term::RuntimeFunction(fty) => fty.transform(tr), + Term::RuntimeSum(sum_type) => sum_type.transform(tr), Term::List(elems) => elems.transform(tr), Term::Tuple(elems) => elems.transform(tr), Term::BoundedNat(_) @@ -633,6 +696,31 @@ impl TermVar { } None } + + /// Check that the cached declaration of this variable matches the actual one (provided). + /// + /// The cache just mirrors the declaration; the typevar can be used anywhere expecting + /// a kind containing the decl - see [check_term_type] / [Term::is_supertype]. + fn check_decl(&self, decls: &[TypeParam]) -> Result<(), SignatureError> { + let idx = self.idx; + let cached_decl: &TypeParam = &self.cached_decl; + match decls.get(idx) { + None => Err(SignatureError::FreeTypeVar { + idx, + num_decls: decls.len(), + }), + Some(actual) => { + if actual == cached_decl { + Ok(()) + } else { + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + cached: Box::new(cached_decl.clone()), + actual: Box::new(actual.clone()), + }) + } + } + } + } } /// Checks that a [`Term`] is valid for a given type. @@ -641,24 +729,17 @@ pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => { Ok(()) } - (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { + (Term::RuntimeSum(st), Term::RuntimeType(bound)) if bound.contains(st.bound()) => Ok(()), + (Term::RuntimeFunction(_), Term::RuntimeType(_)) => Ok(()), // Function pointers are always Copyable so fit any bound + (Term::RuntimeExtension(cty), Term::RuntimeType(bound)) if bound.contains(cty.bound()) => { Ok(()) } - (Term::List(elems), Term::ListType(item_type)) => { - elems.iter().try_for_each(|term| { - // Also allow elements that are RowVars if fitting into a List of Types - if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) - && v.bound_if_row_var() - .is_some_and(|arg_bound| param_bound.contains(arg_bound)) - { - return Ok(()); - } - check_term_type(term, item_type) - }) - } - (Term::ListConcat(lists), Term::ListType(item_type)) => lists + (Term::List(elems), Term::ListType(item_type)) => elems .iter() - .try_for_each(|list| check_term_type(list, item_type)), + .try_for_each(|elem| check_term_type(elem, item_type)), + (Term::ListConcat(lists), Term::ListType(_)) => lists + .iter() + .try_for_each(|list| check_term_type(list, type_)), (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => { let term_parts: Vec<_> = term.clone().into_tuple_parts().collect(); let type_parts: Vec<_> = item_types.clone().into_list_parts().collect(); @@ -762,7 +843,7 @@ pub enum SeqPart { /// Iterator created by [`TypeArg::into_list_parts`]. #[derive(Debug, Clone)] -pub struct ListPartIter { +pub(crate) struct ListPartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -797,7 +878,7 @@ impl FusedIterator for ListPartIter {} /// Iterator created by [`TypeArg::into_tuple_parts`]. #[derive(Debug, Clone)] -pub struct TuplePartIter { +pub(crate) struct TuplePartIter { parts: SmallVec<[SeqPart; 1]>, } @@ -832,13 +913,10 @@ impl FusedIterator for TuplePartIter {} #[cfg(test)] mod test { - use itertools::Itertools; - use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::Term; use crate::types::type_param::SeqPart; - use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; + use crate::types::{Term, Type, TypeBound, TypeRow, type_param::TermTypeError}; #[test] fn new_list_from_parts_items() { @@ -868,13 +946,13 @@ mod test { let var = Term::new_var_use(0, Term::new_list_type(Term::StringType)); let parts = [ SeqPart::Splice(Term::new_list([a.clone(), b.clone()])), - SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])), + SeqPart::Splice(Term::concat_lists([Term::new_list([c.clone()])])), SeqPart::Item(d.clone()), SeqPart::Splice(var.clone()), ]; assert_eq!( Term::new_list_from_parts(parts), - Term::new_list_concat([Term::new_list([a, b, c, d]), var]) + Term::concat_lists([Term::new_list([a, b, c, d]), var]) ); } @@ -899,7 +977,7 @@ mod test { #[test] fn type_arg_fits_param() { - let rowvar = TypeRV::new_row_var_use; + let rowvar = Term::new_row_var_use; fn check(arg: impl Into, param: &TypeParam) -> Result<(), TermTypeError> { check_term_type(&arg.into(), param) } @@ -907,43 +985,50 @@ mod test { args: &[T], param: &TypeParam, ) -> Result<(), TermTypeError> { - let arg = args.iter().cloned().map_into().collect_vec().into(); - check_term_type(&arg, param) + check_term_type(&Term::new_list(args.to_vec()), param) } - // Simple cases: a Term::Type is a Term::RuntimeType but singleton sequences are lists + // Simple cases: Term::RuntimeXXXs are Term::RuntimeType's check(usize_t(), &TypeBound::Copyable.into()).unwrap(); - let seq_param = TypeParam::new_list_type(TypeBound::Copyable); - check(usize_t(), &seq_param).unwrap_err(); + let lst_of_cpy = TypeParam::new_list_type(TypeBound::Copyable); + check(usize_t(), &lst_of_cpy).unwrap_err(); + // ...but singleton sequences thereof are lists check_seq(&[usize_t()], &TypeBound::Linear.into()).unwrap_err(); // Into a list of type, we can fit a single row var - check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap(); - // or a list of (types or row vars) - check(vec![], &seq_param).unwrap(); - check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap(); - check_seq( - &[ + check(rowvar(0, TypeBound::Copyable), &lst_of_cpy).unwrap(); + // or a list of types, or a "concat" of row vars + check(Term::new_list([usize_t()]), &lst_of_cpy).unwrap(); + check( + Term::ListConcat(vec![rowvar(0, TypeBound::Copyable); 2]), + &lst_of_cpy, + ) + .unwrap(); + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + Term::new_list([usize_t()]), rowvar(0, TypeBound::Copyable), - ], + ]), &TypeParam::new_list_type(TypeBound::Linear), ) .unwrap(); - // Next one fails because a list of Eq is required - check_seq( - &[ + // but a *list* of the rowvar is a list of list of types, which is wrong + check_seq(&[rowvar(0, TypeBound::Copyable)], &lst_of_cpy).unwrap_err(); + + // Next one fails because a list of Copyable is required + check( + Term::concat_lists([ rowvar(1, TypeBound::Linear), - usize_t().into(), + Term::new_list([usize_t()]), rowvar(0, TypeBound::Copyable), - ], - &seq_param, + ]), + &lst_of_cpy, ) .unwrap_err(); // seq of seq of types is not allowed check( - vec![usize_t().into(), vec![usize_t().into()].into()], - &seq_param, + vec![Term::from(usize_t()), Term::new_list([usize_t()])], + &lst_of_cpy, ) .unwrap_err(); @@ -963,7 +1048,7 @@ mod test { // `Term::TupleType` requires a `Term::Tuple` of the same number of elems let usize_and_ty = - TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]); + TypeParam::new_tuple_type([TypeParam::max_nat_type(), Term::from(TypeBound::Copyable)]); check( TypeArg::Tuple(vec![5.into(), usize_t().into()]), &usize_and_ty, @@ -974,34 +1059,33 @@ mod test { &usize_and_ty, ) .unwrap_err(); // Wrong way around - let two_types = TypeParam::new_tuple_type(Term::new_list([ - TypeBound::Linear.into(), - TypeBound::Linear.into(), - ])); + + let two_types = + Term::new_tuple_type(Term::new_list([TypeBound::Linear, TypeBound::Linear])); check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which could have any number of elems - check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); + check(TypeArg::new_var_use(0, lst_of_cpy), &two_types).unwrap_err(); } #[test] fn type_arg_subst_row() { let row_param = Term::new_list_type(TypeBound::Copyable); - let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); + let row_arg: Term = Term::new_list([bool_t(), Type::UNIT]); check_term_type(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). let outer_param = Term::new_list_type(TypeBound::Linear); - let outer_arg = Term::new_list([ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), + let outer_arg = Term::concat_lists([ + Term::new_row_var_use(0, TypeBound::Copyable), + Term::new_list([usize_t()]), ]); check_term_type(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); assert_eq!( outer_arg2, - vec![bool_t().into(), Term::UNIT, usize_t().into()].into() + Term::new_list([bool_t(), Type::UNIT, usize_t()]) ); // Of course this is still valid (as substitution is guaranteed to preserve validity) @@ -1015,9 +1099,9 @@ mod test { let row_var_use = Term::new_var_use(0, row_var_decl.clone()); let good_arg = Term::new_list([ // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), + Term::new_list([usize_t()]), row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), + Term::concat_lists([row_var_use, Term::new_list([usize_t()])]), ]); check_term_type(&good_arg, &outer_param).unwrap(); @@ -1025,7 +1109,8 @@ mod test { let Term::List(mut elems) = good_arg.clone() else { panic!() }; - elems.push(usize_t().into()); + let t: Term = usize_t().into(); + elems.push(t); assert_eq!( check_term_type(&Term::new_list(elems), &outer_param), Err(TermTypeError::TypeMismatch { @@ -1036,20 +1121,33 @@ mod test { ); // Now substitute a list of two types for that row-variable - let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); + let row_var_arg = Term::new_list([usize_t(), bool_t()]); check_term_type(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(std::slice::from_ref(&row_var_arg))); check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, Term::new_list([ - Term::new_list([usize_t().into()]), + Term::new_list([usize_t()]), row_var_arg, - Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) + Term::new_list([usize_t(), bool_t(), usize_t()]) ]) ); } + #[test] + fn test_try_into_list_elements() { + // Test successful conversion with List + let types = vec![Type::new_unit_sum(1), bool_t()]; + let term = TypeArg::new_list(types.clone()); + let result = TypeRow::try_from(term); + assert_eq!(result, Ok(TypeRow::from(types))); + + // Test failure with non-list + let result = TypeRow::try_from(Term::from(Type::UNIT)); + assert!(result.is_err()); + } + #[test] fn bytes_json_roundtrip() { let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); @@ -1058,13 +1156,41 @@ mod test { assert_eq!(deserialized, bytes_arg); } - mod proptest { + #[test] + fn list_from_single_part_item() { + // arbitrary, but not worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + Term::List(vec![term.clone()]), + Term::new_list_from_parts(std::iter::once(SeqPart::Item(term))) + ); + } + + #[test] + fn list_from_single_part_splice() { + // arbitrary, but not worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!( + term.clone(), + Term::new_list_from_parts(std::iter::once(SeqPart::Splice(term))) + ); + } + #[test] + fn list_concat_single_item() { + // arbitrary, but not worth cost of trying everything in a proptest + let term = Term::new_list([Term::new_string("foo")]); + assert_eq!(term.clone(), Term::concat_lists([term])); + } + + mod proptest { + use prop::{collection::vec, strategy::Union}; use proptest::prelude::*; use super::super::{TermVar, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param}; + use crate::types::proptest_utils::any_serde_type_param; + use crate::types::{Term, Type, TypeBound}; impl Arbitrary for TermVar { type Parameters = RecursionDepth; @@ -1083,9 +1209,7 @@ mod test { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - use prop::collection::vec; - use prop::strategy::Union; - let mut strat = Union::new([ + let strat = Union::new([ Just(Self::StringType).boxed(), Just(Self::BytesType).boxed(), Just(Self::FloatType).boxed(), @@ -1100,32 +1224,29 @@ mod test { any::() .prop_map(|value| Self::Float(value.into())) .boxed(), - any_with::(depth).prop_map(Self::from).boxed(), + any_with::(depth).prop_map_into().boxed(), ]); - if !depth.leaf() { - // we descend here because we these constructors contain Terms - strat = strat - .or( - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. - any_with::(depth.descend()) - .prop_map(Self::Variable) - .boxed(), - ) - .or(any_with::(depth.descend()) - .prop_map(Self::new_list_type) - .boxed()) - .or(any_with::(depth.descend()) - .prop_map(Self::new_tuple_type) - .boxed()) - .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(Self::new_list) - .boxed()); + if depth.leaf() { + return strat.boxed(); } - - strat.boxed() + // we descend here because we these constructors contain Terms + let depth = depth.descend(); + strat + .or( + // TODO this means we have two ways to create variables of type + // `RuntimeType`, so we probably get more of them than we should` + any_with::(depth).prop_map(Self::Variable).boxed(), + ) + .or(any_with::(depth) + .prop_map(Self::new_list_type) + .boxed()) + .or(any_with::(depth) + .prop_map(Self::new_tuple_type) + .boxed()) + .or(vec(any_with::(depth), 0..3) + .prop_map(Self::new_list) + .boxed()) + .boxed() } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index db9314ff66..44dde2a059 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -7,43 +7,33 @@ use std::{ ops::{Deref, DerefMut}, }; -use super::{ - MaybeRV, NoRV, RowVariable, Substitution, Term, Transformable, Type, TypeArg, TypeBase, TypeRV, - TypeTransformer, type_param::TypeParam, +use super::{Substitution, Term, Transformable, Type, TypeTransformer, type_param::TypeParam}; +use crate::{ + extension::SignatureError, + types::{ + TypeBound, + type_param::{TermTypeError, check_term_type}, + }, + utils::display_list, }; -use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; +use derive_more::Display; use itertools::Itertools; -/// List of types, used for function signatures. -/// The `ROWVARS` parameter controls whether this may contain [`RowVariable`]s -#[derive(Clone, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] +/// List of types, of known length, used for node signatures. +/// +/// Also allows sharing via `Cow` and static allocation via [type_row!]. +/// +/// [type_row!]: crate::type_row +#[derive(Clone, PartialEq, Eq, Debug, Hash, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] -pub struct TypeRowBase { +pub struct TypeRow { /// The datatypes in the row. - types: Cow<'static, [TypeBase]>, + types: Cow<'static, [Type]>, } -/// Row of single types i.e. of known length, for node inputs/outputs -pub type TypeRow = TypeRowBase; - -/// Row of types and/or row variables, the number of actual types is thus -/// unknown -pub type TypeRowRV = TypeRowBase; - -impl PartialEq> for TypeRowBase { - fn eq(&self, other: &TypeRowBase) -> bool { - self.types.len() == other.types.len() - && self - .types - .iter() - .zip(other.types.iter()) - .all(|(s, o)| s == o) - } -} - -impl Display for TypeRowBase { +impl Display for TypeRow { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_char('[')?; display_list(self.types.as_ref(), f)?; @@ -51,7 +41,7 @@ impl Display for TypeRowBase { } } -impl TypeRowBase { +impl TypeRow { /// Create a new empty row. #[must_use] pub const fn new() -> Self { @@ -61,48 +51,81 @@ impl TypeRowBase { } /// Returns a new `TypeRow` with `xs` concatenated onto `self`. - pub fn extend<'a>(&'a self, rest: impl IntoIterator>) -> Self { - self.iter().chain(rest).cloned().collect_vec().into() + pub fn extend<'a>(&'a self, rest: impl IntoIterator) -> Self { + self.iter().chain(rest).cloned().collect() } /// Returns a reference to the types in the row. #[must_use] - pub fn as_slice(&self) -> &[TypeBase] { + pub fn as_slice(&self) -> &[Type] { &self.types } - /// Applies a substitution to the row. - /// For `TypeRowRV`, note this may change the length of the row. - /// For `TypeRow`, guaranteed not to change the length of the row. - pub(crate) fn substitute(&self, s: &Substitution) -> Self { - self.iter() - .flat_map(|ty| ty.substitute(s)) - .collect::>() - .into() - } - delegate! { to self.types { /// Iterator over the types in the row. - pub fn iter(&self) -> impl Iterator>; + pub fn iter(&self) -> impl Iterator; /// Mutable vector of the types in the row. - pub fn to_mut(&mut self) -> &mut Vec>; + pub fn to_mut(&mut self) -> &mut Vec; /// Allow access (consumption) of the contained elements - #[must_use] pub fn into_owned(self) -> Vec>; + #[must_use] pub fn into_owned(self) -> Vec; /// Returns `true` if the row contains no types. - #[must_use] pub fn is_empty(&self) -> bool ; + #[must_use] pub fn is_empty(&self) -> bool; } } +} - pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { +/// Compared to just `pub(crate) trait`, this avoids a private_bounds +/// warning when the trait is used as a type bound on a public struct. +mod internal { + use super::{SignatureError, Substitution, Transformable, TypeParam}; + + /// Sub-trait of [`Transformable`] implemented by things that represent + /// rows of types (fixed-length [`TypeRow`] or variable-length [`TypeRowRV`]). + /// + /// [`TypeRow`]: super::TypeRow + /// [`TypeRowRV`]: super::TypeRowRV + pub trait TypeRowLike: Transformable { + /// Checks all variables used in `self` are in the provided list of bound + /// variables, and that for each [`CustomType`] the corresponding [`TypeDef`] + /// is in the [`ExtensionRegistry`] and the type arguments validate (recursively) + /// and fit into the declared parameters of the [`TypeDef`]. + /// + /// [`TypeDef`]: crate::extension::TypeDef + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError>; + + /// Applies a [`Substitution`] to this instance, returning a new value. + /// + /// Infallible (assuming the `subst` covers all variables) and will + /// not invalidate the instance (assuming all values substituted in are + /// valid instances of the variables they replace). + /// + /// # Panics + /// + /// If the substitution does not cover all type variables in `self`. + fn substitute(&self, s: &Substitution) -> Self; + } +} + +pub(crate) use internal::TypeRowLike; + +impl TypeRowLike for TypeRow { + fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.iter().try_for_each(|t| t.validate(var_decls)) } + + fn substitute(&self, s: &Substitution) -> Self { + self.iter() + .map(|ty| ty.substitute(s)) + .collect::>() + .into() + } } -impl Transformable for TypeRowBase { +impl Transformable for TypeRow { fn transform(&mut self, tr: &T) -> Result { self.to_mut().transform(tr) } @@ -127,259 +150,322 @@ impl TypeRow { } } -impl TryFrom for TypeRow { - type Error = SignatureError; - - fn try_from(value: TypeRowRV) -> Result { - Ok(Self::from( - value - .into_owned() - .into_iter() - .map(std::convert::TryInto::try_into) - .collect::, _>>() - .map_err(|var| SignatureError::RowVarWhereTypeExpected { var })?, - )) - } -} - -impl Default for TypeRowBase { +impl Default for TypeRow { fn default() -> Self { Self::new() } } -impl From>> for TypeRowBase { - fn from(types: Vec>) -> Self { - Self { - types: types.into(), - } - } -} - -impl From> for TypeRowRV { +impl From> for TypeRow { fn from(types: Vec) -> Self { Self { - types: types.into_iter().map(Type::into_).collect(), + types: types.into(), } } } -impl From for TypeRowRV { - fn from(value: TypeRow) -> Self { - Self { - types: value.into_owned().into_iter().map(Type::into_).collect(), - } - } -} +impl TryFrom> for TypeRow { + type Error = TermTypeError; -impl From<[TypeBase; N]> for TypeRowBase { - fn from(types: [TypeBase; N]) -> Self { - Self::from(Vec::from(types)) + fn try_from(value: Vec) -> Result { + value + .into_iter() + .map(Type::try_from) + .collect::, _>>() + .map(Self::from) } } -impl From<[Type; N]> for TypeRowRV { +impl From<[Type; N]> for TypeRow { fn from(types: [Type; N]) -> Self { Self::from(Vec::from(types)) } } -impl From<&'static [TypeBase]> for TypeRowBase { - fn from(types: &'static [TypeBase]) -> Self { +impl From<&'static [Type]> for TypeRow { + fn from(types: &'static [Type]) -> Self { Self { types: types.into(), } } } -// Fallibly convert a [Term] to a [TypeRV]. -// -// This will fail if `arg` is of non-type kind (e.g. String). -impl TryFrom for TypeRV { - type Error = SignatureError; - - fn try_from(value: Term) -> Result { - match value { - TypeArg::Runtime(ty) => Ok(ty.into()), - TypeArg::Variable(v) => Ok(TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )), - _ => Err(SignatureError::InvalidTypeArgs), +impl PartialEq for TypeRow { + fn eq(&self, other: &Term) -> bool { + let Term::List(items) = other else { + return false; + }; + if self.types.len() != items.len() { + return false; } + self.types.iter().zip_eq(items).all(|(ty, tm)| &**ty == tm) } } -// Fallibly convert a [Term] to a [TypeRow]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type) -// or if the sequence contains row variables. +/// Fallibly convert a [Term] to a [TypeRow]. +/// +/// This will fail if `arg` is not a [Term::List] or any of the elements are not [Type]s impl TryFrom for TypeRow { - type Error = SignatureError; + type Error = TermTypeError; - fn try_from(value: TypeArg) -> Result { + fn try_from(value: Term) -> Result { match value { - TypeArg::List(elems) => elems + Term::List(elems) => Ok(elems .into_iter() - .map(|ta| ta.as_runtime().ok_or(SignatureError::InvalidTypeArgs)) - .collect::, _>>() - .map(TypeRow::from), - _ => Err(SignatureError::InvalidTypeArgs), + .map(Type::try_from) + .collect::, _>>()? + .into()), + v => Err(TermTypeError::InvalidValue(Box::new(v))), } } } -// Fallibly convert a [TypeArg] to a [TypeRowRV]. -// -// This will fail if `arg` is of non-sequence kind (e.g. Type). -impl TryFrom for TypeRowRV { - type Error = SignatureError; +impl From for Term { + fn from(value: TypeRow) -> Self { + Term::new_list(value.into_owned()) + } +} - fn try_from(value: Term) -> Result { - match value { - TypeArg::List(elems) => elems - .into_iter() - .map(TypeRV::try_from) - .collect::, _>>() - .map(|vec| vec.into()), - TypeArg::Variable(v) => Ok(vec![TypeRV::new_row_var_use( - v.index(), - v.bound_if_row_var() - .ok_or(SignatureError::InvalidTypeArgs)?, - )] - .into()), - _ => Err(SignatureError::InvalidTypeArgs), - } +impl Deref for TypeRow { + type Target = [Type]; + + fn deref(&self) -> &Self::Target { + self.as_slice() } } -impl From for Term { +impl DerefMut for TypeRow { + fn deref_mut(&mut self) -> &mut Self::Target { + self.types.to_mut() + } +} + +impl FromIterator for TypeRow { + fn from_iter>(iter: T) -> Self { + Self::from(iter.into_iter().collect_vec()) + } +} + +/// Row of types and/or row variables, the number of actual types is thus +/// unknown. Used for opdef signatures, and types of runtime function pointers. +/// +/// A [Term] that `check_term_type`s against [Term::ListType] of [Term::RuntimeType] +/// (of a [TypeBound]), i.e. one of +/// * A [Term::Variable] of type [Term::ListType] (of [Term::RuntimeType]...) +/// * A [Term::List], each of whose elements is of type some [Term::RuntimeType] +/// * A [Term::ListConcat], each of whose sublists is one of these three +/// +/// [TypeBound]: crate::types::TypeBound +#[derive(Clone, Debug, Display, PartialEq, Eq, Hash)] +#[display("{_0}")] +pub struct TypeRowRV(pub(super) Term); + +impl TypeRowRV { + const EMPTY: TypeRowRV = Self(Term::EMPTY_LIST); + pub(super) const EMPTY_REF: &TypeRowRV = &Self::EMPTY; + + /// Create a new empty row. + pub const fn new() -> Self { + Self::EMPTY + } + + /// Creates a singleton row with just a row variable + /// (a variable ranging over lists of types of any length) + pub fn new_var_use(idx: usize, b: TypeBound) -> Self { + Self(Term::new_row_var_use(idx, b)) + } + + /// Concatenates another TypeRowRV onto the end of this one + pub fn concat(self, other: impl Into) -> Self { + Self(Term::concat_lists([self.0, other.into().0])) + } + + /// Returns `true` if the row contains no types. + pub fn is_empty(&self) -> bool { + self.0.is_empty_list() + } +} + +impl TypeRowLike for TypeRowRV { + /// Checks that this is indeed a list of runtime types; + /// and that all variables are as declared in the supplied list of params. + fn validate(&self, vars: &[TypeParam]) -> Result<(), SignatureError> { + check_term_type(&self.0, &Term::new_list_type(TypeBound::Linear))?; + self.0.validate(vars) + } + + fn substitute(&self, s: &Substitution) -> Self { + // Substitution cannot make this invalid if it was valid previously + Self(self.0.substitute(s)) + } +} + +impl TryFrom for TypeRow { + type Error = TermTypeError; + + fn try_from(value: TypeRowRV) -> Result { + value.0.try_into() + } +} + +impl PartialEq for TypeRow { + fn eq(&self, other: &TypeRowRV) -> bool { + self == &other.0 + } +} + +impl From for TypeRowRV { fn from(value: TypeRow) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) + Self(Term::from(value)) + } +} + +impl Default for TypeRowRV { + /// Makes a new empty list + fn default() -> Self { + Self::EMPTY + } +} + +impl Transformable for TypeRowRV { + fn transform(&mut self, t: &T) -> Result { + self.0.transform(t) + } +} + +impl std::ops::Deref for TypeRowRV { + type Target = Term; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl TryFrom for TypeRowRV { + type Error = TermTypeError; + + fn try_from(t: Term) -> Result { + check_term_type(&t, &Term::new_list_type(TypeBound::Linear))?; + Ok(Self(t)) } } impl From for Term { fn from(value: TypeRowRV) -> Self { - Term::List(value.into_owned().into_iter().map_into().collect()) + value.0 } } -impl Deref for TypeRowBase { - type Target = [TypeBase]; +impl From> for TypeRowRV { + fn from(value: Vec) -> Self { + Self(Term::new_list(value)) + } +} - fn deref(&self) -> &Self::Target { - self.as_slice() +impl From<[Type; N]> for TypeRowRV { + fn from(value: [Type; N]) -> Self { + Self(Term::new_list(value)) } } -impl DerefMut for TypeRowBase { - fn deref_mut(&mut self) -> &mut Self::Target { - self.types.to_mut() +impl FromIterator for TypeRowRV { + fn from_iter>(iter: T) -> Self { + Self(Term::new_list(iter)) } } #[cfg(test)] mod test { use super::*; - use crate::{ - extension::prelude::bool_t, - types::{Type, TypeArg, TypeRV}, - }; + use crate::{extension::prelude::bool_t, types::Type}; mod proptest { use crate::proptest::RecursionDepth; - use crate::types::{MaybeRV, TypeBase, TypeRowBase}; + use crate::types::{Term, Type, TypeBound, TypeRow, TypeRowRV}; use ::proptest::prelude::*; - impl Arbitrary for super::super::TypeRowBase { + impl Arbitrary for TypeRow { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use proptest::collection::vec; if depth.leaf() { - Just(TypeRowBase::new()).boxed() + Just(TypeRow::new()).boxed() } else { - vec(any_with::>(depth), 0..4) + vec(any_with::(depth.descend()), 0..4) .prop_map(|ts| ts.clone().into()) .boxed() } } } - } - #[test] - fn test_try_from_term_to_typerv() { - // Test successful conversion with Runtime type - let runtime_type = Type::UNIT; - let term = TypeArg::Runtime(runtime_type.clone()); - let result = TypeRV::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRV::from(runtime_type)); + impl Arbitrary for TypeRowRV { + type Parameters = RecursionDepth; + type Strategy = BoxedStrategy; + fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { + use proptest::collection::vec; + if depth.leaf() { + Just(TypeRowRV::default()).boxed() + } else { + prop_oneof![ + vec(any_with::(depth.descend()), 0..4) + .prop_map(|ts| ts.clone().into()) + .boxed(), + (any::(), any::()) + .prop_map(|(idx, b)| TypeRowRV::new_var_use(idx, b)) + .boxed(), + ] + .boxed() + } + } + } - // Test failure with non-type kind - let term = Term::String("test".to_string()); - let result = TypeRV::try_from(term); - assert!(result.is_err()); + proptest! { + #[test] + fn type_row_rv_term_roundtrip(tr: TypeRowRV) { + let t: Term = tr.clone().into(); + let tr2: TypeRowRV = t.try_into().unwrap(); + assert_eq!(tr, tr2); + } + } } #[test] fn test_try_from_term_to_typerow() { // Test successful conversion with List let types = vec![Type::new_unit_sum(1), bool_t()]; - let type_args = types.iter().map(|t| TypeArg::Runtime(t.clone())).collect(); - let term = TypeArg::List(type_args); - let result = TypeRow::try_from(term); - assert!(result.is_ok()); - assert_eq!(result.unwrap(), TypeRow::from(types)); + let term = Term::new_list(types.clone()); + assert_eq!( + TypeRow::try_from(term.clone()), + Ok(TypeRow::from(types.clone())) + ); + assert_eq!( + TypeRowRV::try_from(term.clone()), + Ok(TypeRowRV::from(types.clone())) + ); + assert_eq!(*TypeRowRV::try_from(term.clone()).unwrap(), term); // Test failure with non-list - let term = TypeArg::Runtime(Type::UNIT); - let result = TypeRow::try_from(term); - assert!(result.is_err()); - } - - #[test] - fn test_try_from_term_to_typerowrv() { - // Test successful conversion with List - let types = [TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_args = types.iter().map(|t| t.clone().into()).collect(); - let term = TypeArg::List(type_args); - let result = TypeRowRV::try_from(term); - assert!(result.is_ok()); + let term = Term::from(Type::UNIT); + assert!(TypeRow::try_from(term.clone()).is_err()); + assert!(TypeRowRV::try_from(term).is_err()); - // Test failure with non-sequence kind - let term = Term::String("test".to_string()); - let result = TypeRowRV::try_from(term); - assert!(result.is_err()); + assert!(TypeRow::try_from(Term::new_row_var_use(0, TypeBound::Linear)).is_err()); } #[test] fn test_from_typerow_to_term() { let types = vec![Type::UNIT, bool_t()]; let type_row = TypeRow::from(types); - let term = Term::from(type_row); + let term = Term::from(type_row.clone()); - match term { + match &term { Term::List(elems) => { assert_eq!(elems.len(), 2); } _ => panic!("Expected Term::List"), } - } - - #[test] - fn test_from_typerowrv_to_term() { - let types = vec![TypeRV::from(Type::UNIT), TypeRV::from(bool_t())]; - let type_row_rv = TypeRowRV::from(types); - let term = Term::from(type_row_rv); - match term { - TypeArg::List(elems) => { - assert_eq!(elems.len(), 2); - } - _ => panic!("Expected Term::List"), - } + assert_eq!(term.try_into(), Ok(type_row)); } } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 3a53acb97f..82d672ef53 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -6,7 +6,7 @@ use hugr_core::ops::{ }; use hugr_core::{ HugrView, NodeIndex, - types::{SumType, Type, TypeEnum}, + types::{SumType, Type}, }; use inkwell::types::BasicTypeEnum; use inkwell::values::BasicValueEnum; @@ -100,11 +100,11 @@ where } fn get_exactly_one_sum_type(ts: impl IntoIterator) -> Result { - let Some(TypeEnum::Sum(sum_type)) = ts + let Some(sum_type) = ts .into_iter() - .map(|t| t.as_type_enum().clone()) .exactly_one() .ok() + .and_then(|t| t.as_sum().cloned()) else { Err(anyhow!("Not exactly one SumType"))? }; diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 5c93f0a09f..b28cae2fca 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -855,7 +855,7 @@ mod test_fns { #[case::long(&"x".repeat(PANIC_MSG_BUFFER_LEN + 100))] fn test_exec_panic(mut exec_ctx: TestContext, #[case] msg: &str) { let panic_op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) + .instantiate_extension_op(&EXIT_OP_ID, [Term::EMPTY_LIST, Term::EMPTY_LIST]) .unwrap(); let hugr = SimpleHugrConfig::new() diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 7ca58b6e74..3992845038 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -25,7 +25,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayClone, ArrayDiscard, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::Builder; @@ -498,7 +498,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -559,7 +559,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -621,7 +621,7 @@ pub fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/borrow_array.rs b/hugr-llvm/src/extension/collections/borrow_array.rs index 6a1ec224e6..2fff393429 100644 --- a/hugr-llvm/src/extension/collections/borrow_array.rs +++ b/hugr-llvm/src/extension/collections/borrow_array.rs @@ -32,7 +32,7 @@ use hugr_core::std_extensions::collections::borrow_array::{ BArrayRepeat, BArrayScan, BArrayToArray, BArrayToArrayDef, BArrayUnsafeOp, BArrayUnsafeOpDef, borrow_array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::Builder; @@ -1083,7 +1083,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1149,7 +1149,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -1216,7 +1216,7 @@ pub fn emit_barray_op<'c, H: HugrView>( .ok_or(anyhow!("BArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("BArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index 37e173ecac..e0a3ac196c 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -4,7 +4,7 @@ use hugr_core::{ extension::simple_op::MakeExtensionOp as _, ops::ExtensionOp, std_extensions::collections::list::{self, ListOp, ListValue}, - types::{SumType, Type, TypeArg}, + types::{SumType, Type}, }; use inkwell::values::FunctionValue; use inkwell::{ @@ -202,7 +202,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Runtime(ty)] => ty.clone(), + [ty] => ty.clone().try_into()?, _ => { bail!("Collections: invalid type args for list op"); } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap index 38d3385687..2f13d9a6cd 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm21.snap @@ -5,24 +5,24 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.6acc1b76.0, ptr @sa.inner.e637bb5.0, ptr @sa.inner.2b6593f.0, ptr @sa.inner.1b9ad7c.0, ptr @sa.inner.e67fbfa4.0, ptr @sa.inner.15dc27f6.0, ptr @sa.inner.c43a2bb2.0, ptr @sa.inner.7f5d5e16.0, ptr @sa.inner.a0bc9c53.0, ptr @sa.inner.1e8aada3.0] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.ac73413c.0, ptr @sa.inner.3334f213.0, ptr @sa.inner.9447a20a.0, ptr @sa.inner.dfbce68f.0, ptr @sa.inner.5712c1c3.0, ptr @sa.inner.fc8747b9.0, ptr @sa.inner.aaa0b715.0, ptr @sa.inner.7aa729b2.0, ptr @sa.inner.66390c92.0, ptr @sa.inner.4d7c0c80.0] } define internal i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr @sa.outer.e55b610a.0, i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr @sa.outer.f1be8bcf.0, i32 0, i32 0 %1 = load i64, ptr %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index 950cd1316d..0373d2a471 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -5,17 +5,17 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.ac73413c.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.3334f213.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.9447a20a.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.dfbce68f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.5712c1c3.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.fc8747b9.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.aaa0b715.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7aa729b2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.66390c92.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.4d7c0c80.0 to { i64, [0 x i64] }*)] } define internal i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.f1be8bcf.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap index 65ffc76d07..ecafb18565 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm21.snap @@ -5,17 +5,17 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.inner.6acc1b76.0 = constant { i64, [0 x i64] } zeroinitializer -@sa.inner.e637bb5.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } -@sa.inner.2b6593f.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } -@sa.inner.1b9ad7c.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } -@sa.inner.e67fbfa4.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } -@sa.inner.15dc27f6.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } -@sa.inner.c43a2bb2.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } -@sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } -@sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } -@sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.e55b610a.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.6acc1b76.0, ptr @sa.inner.e637bb5.0, ptr @sa.inner.2b6593f.0, ptr @sa.inner.1b9ad7c.0, ptr @sa.inner.e67fbfa4.0, ptr @sa.inner.15dc27f6.0, ptr @sa.inner.c43a2bb2.0, ptr @sa.inner.7f5d5e16.0, ptr @sa.inner.a0bc9c53.0, ptr @sa.inner.1e8aada3.0] } +@sa.inner.ac73413c.0 = constant { i64, [0 x i64] } zeroinitializer +@sa.inner.3334f213.0 = constant { i64, [1 x i64] } { i64 1, [1 x i64] [i64 1] } +@sa.inner.9447a20a.0 = constant { i64, [2 x i64] } { i64 2, [2 x i64] [i64 2, i64 2] } +@sa.inner.dfbce68f.0 = constant { i64, [3 x i64] } { i64 3, [3 x i64] [i64 3, i64 3, i64 3] } +@sa.inner.5712c1c3.0 = constant { i64, [4 x i64] } { i64 4, [4 x i64] [i64 4, i64 4, i64 4, i64 4] } +@sa.inner.fc8747b9.0 = constant { i64, [5 x i64] } { i64 5, [5 x i64] [i64 5, i64 5, i64 5, i64 5, i64 5] } +@sa.inner.aaa0b715.0 = constant { i64, [6 x i64] } { i64 6, [6 x i64] [i64 6, i64 6, i64 6, i64 6, i64 6, i64 6] } +@sa.inner.7aa729b2.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } +@sa.inner.66390c92.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } +@sa.inner.4d7c0c80.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } +@sa.outer.f1be8bcf.0 = constant { i64, [10 x ptr] } { i64 10, [10 x ptr] [ptr @sa.inner.ac73413c.0, ptr @sa.inner.3334f213.0, ptr @sa.inner.9447a20a.0, ptr @sa.inner.dfbce68f.0, ptr @sa.inner.5712c1c3.0, ptr @sa.inner.fc8747b9.0, ptr @sa.inner.aaa0b715.0, ptr @sa.inner.7aa729b2.0, ptr @sa.inner.66390c92.0, ptr @sa.inner.4d7c0c80.0] } define internal i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.outer.e55b610a.0, ptr %"5_0", align 8 + store ptr @sa.outer.f1be8bcf.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x ptr] }, ptr %"5_01", i32 0, i32 0 %1 = load i64, ptr %0, align 4 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap index 59bbf8007b..c775c0323e 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_0.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.a.97cb22bf.0 + ret ptr @sa.a.64db7f63.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap index 0401cf0b50..dab0859bcc 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_2.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.c.d2dddd66.0 + ret ptr @sa.c.d797f156.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap index e682c3b6ce..81eea756ce 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm21_3.snap @@ -5,12 +5,12 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.e9aebfdb.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define internal ptr @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret ptr @sa.d.eee08a59.0 + ret ptr @sa.d.e9aebfdb.0 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap index 480fd4e9c8..682e879acc 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_0.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal { i64, [0 x i64] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.97cb22bf.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 + store { i64, [0 x i64] }* bitcast ({ i64, [10 x i64] }* @sa.a.64db7f63.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"5_0", align 8 store { i64, [0 x i64] }* %"5_01", { i64, [0 x i64] }** %"0", align 8 %"02" = load { i64, [0 x i64] }*, { i64, [0 x i64] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap index 8f5dd5efb6..8594f22b58 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal { i64, [0 x i1] }* @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 + store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d797f156.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8 %"5_01" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"5_0", align 8 store { i64, [0 x i1] }* %"5_01", { i64, [0 x i1] }** %"0", align 8 %"02" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap index 6f341ffa4d..c246496191 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_0.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.a.97cb22bf.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } +@sa.a.64db7f63.0 = constant { i64, [10 x i64] } { i64 10, [10 x i64] [i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.a.97cb22bf.0, ptr %"5_0", align 8 + store ptr @sa.a.64db7f63.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap index 3fb020a531..82f98439a6 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_2.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } +@sa.c.d797f156.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.c.d2dddd66.0, ptr %"5_0", align 8 + store ptr @sa.c.d797f156.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap index 869b5a847f..3bea34aa6c 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm21_3.snap @@ -5,7 +5,7 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } +@sa.d.e9aebfdb.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] } define internal ptr @_hl.main.1() { alloca_block: @@ -14,7 +14,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store ptr @sa.d.eee08a59.0, ptr %"5_0", align 8 + store ptr @sa.d.e9aebfdb.0, ptr %"5_0", align 8 %"5_01" = load ptr, ptr %"5_0", align 8 store ptr %"5_01", ptr %"0", align 8 %"02" = load ptr, ptr %"0", align 8 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index 285a1ba3ec..8dc17cd03d 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -14,7 +14,7 @@ use hugr_core::ops::DataflowOpTrait; use hugr_core::std_extensions::collections::array::{ self, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type, }; -use hugr_core::types::{TypeArg, TypeEnum}; +use hugr_core::types::TypeArg; use hugr_core::{HugrView, Node}; use inkwell::IntPredicate; use inkwell::builder::{Builder, BuilderError}; @@ -135,10 +135,11 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), ty] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; - let elem_ty = ts.llvm_type(ty)?; + let ty = ty.clone().try_into()?; + let elem_ty = ts.llvm_type(&ty)?; Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum()) } }) @@ -357,7 +358,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::get has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::get output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -420,7 +421,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::set has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::set output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? @@ -494,7 +495,7 @@ fn emit_array_op<'c, H: HugrView>( .ok_or(anyhow!("ArrayOp::swap has no outputs"))?; let res_sum_ty = { - let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else { + let Some(st) = res_hugr_ty.as_sum() else { Err(anyhow!("ArrayOp::swap output is not a sum type"))? }; ts.llvm_sum_type(st.clone())? diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 882fbad9f1..de7936d37b 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -339,9 +339,10 @@ impl CodegenExtension for StaticArrayCodegenE { move |ts, custom_type| { // check the arg type, even though the return is always ptr - let _ = custom_type.args()[0] - .as_runtime() - .expect("Type argument for static array must be a type"); + assert!( + custom_type.args()[0].is_runtime_type(), + "Type argument for static array must be a type" + ); Ok(ts.llvm_ptr_type().into()) } }, diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index c620d0926c..c2cf0011b8 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -8,7 +8,7 @@ use hugr_core::{ }, ops::{DataflowOpTrait as _, constant::Value, custom::ExtensionOp}, std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, - types::{TypeEnum, TypeRow}, + types::TypeRow, }; use inkwell::{FloatPredicate, IntPredicate, types::IntType, values::BasicValue}; @@ -189,12 +189,12 @@ fn emit_conversion_op<'c, H: HugrView>( .typing_session() .llvm_type(&INT_TYPES[0])? .into_int_type(); - let sum_ty = context - .typing_session() - .llvm_sum_type(match bool_t().as_type_enum() { - TypeEnum::Sum(st) => st.clone(), - _ => panic!("Hugr prelude bool_t() not a Sum"), - })?; + let sum_ty = context.typing_session().llvm_sum_type( + bool_t() + .as_sum() + .expect("Hugr prelude bool_t() not a Sum") + .clone(), + )?; emit_custom_unary_op(context, args, |ctx, arg, _| { let res = if conversion_op == ConvertOpDef::itobool { diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index 92d31a95ad..a688c9516f 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -699,7 +699,7 @@ mod test { .unwrap(); let panic_op = PRELUDE - .instantiate_extension_op(&PANIC_OP_ID, [Term::new_list([]), Term::new_list([])]) + .instantiate_extension_op(&PANIC_OP_ID, [Term::EMPTY_LIST, Term::EMPTY_LIST]) .unwrap(); let hugr = SimpleHugrConfig::new() diff --git a/hugr-llvm/src/utils/type_map.rs b/hugr-llvm/src/utils/type_map.rs index c8129c2578..2cf006f95b 100644 --- a/hugr-llvm/src/utils/type_map.rs +++ b/hugr-llvm/src/utils/type_map.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use hugr_core::{ extension::ExtensionId, - types::{CustomType, TypeEnum, TypeName, TypeRow}, + types::{CustomType, Term, TypeName, TypeRow}, }; use anyhow::{Result, bail}; @@ -115,18 +115,18 @@ impl<'a, TM: TypeMapping + 'a> TypeMap<'a, TM> { /// Map `hugr_type` using the [`TypeMapping`] `TM`, the registered callbacks, /// and the auxiliary data `inv`. pub fn map_type<'c>(&self, hugr_type: &HugrType, inv: TM::InV<'c>) -> Result> { - match hugr_type.as_type_enum() { - TypeEnum::Extension(custom_type) => { + match &**hugr_type { + Term::RuntimeExtension(custom_type) => { let key = (custom_type.extension().clone(), custom_type.name().clone()); let Some(handler) = self.custom_hooks.get(&key) else { return self.type_map.default_out(inv, &custom_type.clone().into()); }; handler.map_type(inv, custom_type) } - TypeEnum::Sum(sum_type) => self + Term::RuntimeSum(sum_type) => self .map_sum_type(sum_type, inv) .map(|x| self.type_map.sum_into_out(x)), - TypeEnum::Function(function_type) => self + Term::RuntimeFunction(function_type) => self .map_function_type(&function_type.as_ref().clone().try_into()?, inv) .map(|x| self.type_map.func_into_out(x)), _ => self.type_map.default_out(inv, hugr_type), diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index d05896f01b..7abfa875dd 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -1,8 +1,7 @@ // Required for black_box uses #![allow(clippy::unit_arg)] -use hugr::extension::prelude::{qb_t, usize_t}; -use hugr::ops::AliasDecl; -use hugr::types::{Signature, Type, TypeBound}; +use hugr::extension::prelude::{bool_t, qb_t, usize_t}; +use hugr::types::{Signature, Type}; use criterion::{AxisScale, Criterion, PlotConfiguration, criterion_group}; use std::hint::black_box; @@ -13,8 +12,7 @@ fn make_complex_type() -> Type { let int = usize_t(); let q_register = Type::new_tuple(vec![qb; 8]); let b_register = Type::new_tuple(vec![int; 8]); - let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Linear)); - let sum = Type::new_sum([[q_register], [q_alias]]); + let sum = Type::new_sum([[q_register], [bool_t()]]); Type::new_function(Signature::new(vec![sum], vec![b_register])) }