diff --git a/qis-compiler/rust/lib.rs b/qis-compiler/rust/lib.rs index a94eb19bf..c7c61f6cf 100644 --- a/qis-compiler/rust/lib.rs +++ b/qis-compiler/rust/lib.rs @@ -38,9 +38,12 @@ use tket::hugr::{self, llvm::inkwell}; use tket::hugr::{Hugr, HugrView, Node}; use tket::llvm::rotation::RotationCodegenExtension; use tket_qsystem::QSystemPass; -use tket_qsystem::extension::{futures as qsystem_futures, qsystem, result as qsystem_result}; +use tket_qsystem::extension::{ + futures as qsystem_futures, globals as qsystem_globals, qsystem, result as qsystem_result, +}; use tket_qsystem::llvm::array_utils::ArrayLowering; pub use tket_qsystem::llvm::futures::FuturesCodegenExtension; +use tket_qsystem::llvm::globals::GlobalsCodegenExtension; use tket_qsystem::llvm::{ debug::DebugCodegenExtension, prelude::QISPreludeCodegen, qsystem::QSystemCodegenExtension, random::RandomCodegenExtension, result::ResultsCodegenExtension, utils::UtilsCodegenExtension, @@ -70,6 +73,7 @@ static REGISTRY: std::sync::LazyLock = std::sync::LazyLock::n collections::static_array::EXTENSION.to_owned(), collections::borrow_array::EXTENSION.to_owned(), qsystem_futures::EXTENSION.to_owned(), + qsystem_globals::EXTENSION.to_owned(), qsystem_result::EXTENSION.to_owned(), qsystem::EXTENSION.to_owned(), ROTATION_EXTENSION.to_owned(), @@ -160,6 +164,7 @@ fn codegen_extensions() -> CodegenExtsMap<'static, Hugr> { .add_default_static_array_extensions() .add_borrow_array_extensions(array::SeleneHeapBorrowArrayCodegen(pcg.clone())) .add_extension(FuturesCodegenExtension) + .add_extension(GlobalsCodegenExtension) .add_extension(QSystemCodegenExtension::from(pcg.clone())) .add_extension(RandomCodegenExtension) // Results use standard arrays. diff --git a/tket-exts/src/tket_exts/__init__.py b/tket-exts/src/tket_exts/__init__.py index 846d8bb80..c46b3a3d7 100644 --- a/tket-exts/src/tket_exts/__init__.py +++ b/tket-exts/src/tket_exts/__init__.py @@ -3,6 +3,7 @@ from tket_exts.tket.bool import BoolExtension from tket_exts.tket.debug import DebugExtension from tket_exts.tket.global_phase import GlobalPhaseExtension +from tket_exts.tket.globals import GlobalsExtension from tket_exts.tket.gpu import GpuExtension from tket_exts.tket.guppy import GuppyExtension from tket_exts.tket.modifier import ModifierExtension @@ -40,6 +41,7 @@ "wasm", "modifier", "global_phase", + "globals", ] bool: BoolExtension = tket.bool.BoolExtension() @@ -56,6 +58,7 @@ wasm: WasmExtension = tket.wasm.WasmExtension() modifier: ModifierExtension = tket.modifier.ModifierExtension() global_phase: GlobalPhaseExtension = tket.global_phase.GlobalPhaseExtension() +globals: GlobalsExtension = tket.globals.GlobalsExtension() @deprecated("Use tket_exts.bool() instead") diff --git a/tket-exts/src/tket_exts/data/tket/globals.json b/tket-exts/src/tket_exts/data/tket/globals.json new file mode 100644 index 000000000..53caebdcf --- /dev/null +++ b/tket-exts/src/tket_exts/data/tket/globals.json @@ -0,0 +1,58 @@ +{ + "version": "0.1.0", + "name": "tket.globals", + "types": {}, + "operations": { + "swap": { + "extension": "tket.globals", + "name": "swap", + "description": "Swap the contents of the named global variable with the argument.", + "signature": { + "params": [ + { + "tp": "String" + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [], + [ + { + "t": "V", + "i": 1, + "b": "A" + } + ] + ] + } + ], + "output": [ + { + "t": "Sum", + "s": "General", + "rows": [ + [], + [ + { + "t": "V", + "i": 1, + "b": "A" + } + ] + ] + } + ] + } + }, + "binary": false + } + } +} diff --git a/tket-exts/src/tket_exts/tket/__init__.py b/tket-exts/src/tket_exts/tket/__init__.py index 76d8fc891..40ebce0f8 100644 --- a/tket-exts/src/tket_exts/tket/__init__.py +++ b/tket-exts/src/tket_exts/tket/__init__.py @@ -12,6 +12,7 @@ wasm, modifier, global_phase, + globals, ) __all__ = [ @@ -27,4 +28,5 @@ "wasm", "modifier", "global_phase", + "globals", ] diff --git a/tket-exts/src/tket_exts/tket/globals.py b/tket-exts/src/tket_exts/tket/globals.py new file mode 100644 index 000000000..e601a258d --- /dev/null +++ b/tket-exts/src/tket_exts/tket/globals.py @@ -0,0 +1,36 @@ +import functools +from typing import List + +from hugr.ops import ExtOp +from hugr.tys import StringArg, TypeTypeArg, Type + +from ._util import TketExtension, load_extension +from hugr.ext import Extension, OpDef, TypeDef + + +class GlobalsExtension(TketExtension): + """Global state operations.""" + + @functools.cache + def __call__(self) -> Extension: + """Returns the globals extension""" + return load_extension("tket.globals") + + def TYPES(self) -> List[TypeDef]: + """Return the types defined by this extension""" + return [] + + def OPS(self) -> List[OpDef]: + """Return the operations defined by this extension""" + return [ + self.swap_def, + ] + + @functools.cached_property + def swap_def(self) -> OpDef: + """Swap the contents of the named global variable with the argument.""" + return self().get_op("swap") + + def swap(self, name: str, ty: Type) -> ExtOp: + """Swap the contents of the named global variable with the argument.""" + return self().get_op("swap").instantiate([StringArg(name), TypeTypeArg(ty)]) diff --git a/tket-exts/tests/test_validate_exts.py b/tket-exts/tests/test_validate_exts.py index 2527bc4e8..a26351f5c 100644 --- a/tket-exts/tests/test_validate_exts.py +++ b/tket-exts/tests/test_validate_exts.py @@ -69,6 +69,16 @@ def ext_futures() -> Tuple[TketExtension, List[ExtType], List[ExtOp]]: ) +def ext_globals() -> Tuple[TketExtension, List[ExtType], List[ExtOp]]: + ext = tket_exts.globals + bool_t = tket_exts.bool.bool_t + return ( + ext, + [], + [ext.swap("test-name", bool_t)], + ) + + def ext_qsystem() -> Tuple[TketExtension, List[ExtType], List[ExtOp]]: ext = tket_exts.qsystem return ( @@ -202,6 +212,7 @@ def ext_wasm() -> Tuple[TketExtension, List[ExtType], List[ExtOp]]: ext_gpu, ext_guppy, ext_futures, + ext_globals, ext_qsystem, ext_qsystem_random, ext_qsystem_utils, diff --git a/tket-qsystem/src/bin/tket-qsystem.rs b/tket-qsystem/src/bin/tket-qsystem.rs index 20925ef65..0564897cb 100644 --- a/tket-qsystem/src/bin/tket-qsystem.rs +++ b/tket-qsystem/src/bin/tket-qsystem.rs @@ -19,6 +19,7 @@ fn main() -> Result<()> { tket_qsystem::extension::gpu::EXTENSION.to_owned(), tket_qsystem::extension::qsystem::EXTENSION.to_owned(), tket_qsystem::extension::futures::EXTENSION.to_owned(), + tket_qsystem::extension::globals::EXTENSION.to_owned(), tket_qsystem::extension::random::EXTENSION.to_owned(), tket_qsystem::extension::result::EXTENSION.to_owned(), tket_qsystem::extension::utils::EXTENSION.to_owned(), diff --git a/tket-qsystem/src/extension.rs b/tket-qsystem/src/extension.rs index 8e8166496..940659341 100644 --- a/tket-qsystem/src/extension.rs +++ b/tket-qsystem/src/extension.rs @@ -3,6 +3,7 @@ pub mod classical_compute; pub use classical_compute::gpu; pub use classical_compute::wasm; pub mod futures; +pub mod globals; pub mod qsystem; pub mod random; pub mod result; diff --git a/tket-qsystem/src/extension/globals.rs b/tket-qsystem/src/extension/globals.rs new file mode 100644 index 000000000..b7ce15e94 --- /dev/null +++ b/tket-qsystem/src/extension/globals.rs @@ -0,0 +1,161 @@ +#![allow(missing_docs)] + +use std::sync::{Arc, Weak}; + +use hugr::{ + Extension, + extension::{ + ExtensionId, SignatureError, SignatureFunc, Version, + prelude::option_type, + simple_op::{ + HasConcrete, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name, + }, + }, + ops::{ExtensionOp, OpName}, + types::{ + PolyFuncType, Signature, Type, TypeArg, TypeBound, + type_param::{TermTypeError, TypeParam}, + }, +}; + +/// The ID of the `tket.globals` extension. +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket.globals"); +/// The "tket.globals" extension version +pub const EXTENSION_VERSION: Version = Version::new(0, 1, 0); + +lazy_static::lazy_static! { + /// The "tket.globals" extension. + pub static ref EXTENSION: Arc = { + Extension::new_arc(EXTENSION_ID, EXTENSION_VERSION, |ext, ext_ref| { + GlobalsOpDef::load_all_ops(ext, ext_ref).unwrap(); + }) + }; + + pub static ref NAME_PARAM: TypeParam = TypeParam::StringType; + pub static ref TYPE_PARAM: TypeParam = TypeParam::RuntimeType(TypeBound::Linear); +} + +#[derive( + Clone, + Copy, + Debug, + serde::Serialize, + serde::Deserialize, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + strum::EnumIter, + strum::IntoStaticStr, + strum::EnumString, +)] +#[expect(non_camel_case_types)] +#[non_exhaustive] +pub enum GlobalsOpDef { + /// Swap the contents of the named global variable with the argument. + swap, +} + +impl MakeOpDef for GlobalsOpDef { + fn opdef_id(&self) -> OpName { + <&'static str>::from(self).into() + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + match self { + Self::swap => PolyFuncType::new( + [NAME_PARAM.to_owned(), TYPE_PARAM.to_owned()], + Signature::new_endo([Type::from(option_type([Type::new_var_use( + 1, + TypeBound::Linear, + )]))]), + ) + .into(), + } + } + + fn from_def(op_def: &hugr::extension::OpDef) -> Result { + try_from_name(op_def.name(), op_def.extension_id()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID + } + + fn description(&self) -> String { + match self { + Self::swap => { + "Swap the contents of the named global variable with the argument.".to_string() + } + } + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } +} + +pub enum GlobalsOp { + Swap { name: String, ty: Type }, +} + +impl MakeExtensionOp for GlobalsOp { + fn op_id(&self) -> OpName { + GlobalsOpDef::swap.opdef_id() + } + + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + GlobalsOpDef::from_def(ext_op.def())?.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + match self { + Self::Swap { name, ty } => { + vec![TypeArg::String(name.clone()), TypeArg::Runtime(ty.clone())] + } + } + } +} + +impl HasConcrete for GlobalsOpDef { + type Concrete = GlobalsOp; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + let [name_arg, ty_arg] = type_args else { + Err(SignatureError::from(TermTypeError::WrongNumberArgs( + type_args.len(), + 2, + )))? + }; + + let Some(name) = name_arg.as_string() else { + Err(SignatureError::from(TermTypeError::TypeMismatch { + term: name_arg.clone().into(), + type_: NAME_PARAM.to_owned().into(), + }))? + }; + + let Some(ty) = ty_arg.as_runtime() else { + Err(SignatureError::from(TermTypeError::TypeMismatch { + term: ty_arg.clone().into(), + type_: TYPE_PARAM.to_owned().into(), + }))? + }; + + Ok(GlobalsOp::Swap { name, ty }) + } +} + +impl MakeRegisteredOp for GlobalsOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID + } + + fn extension_ref(&self) -> Arc { + EXTENSION.clone() + } +} diff --git a/tket-qsystem/src/llvm.rs b/tket-qsystem/src/llvm.rs index 9ececcccf..18f76479d 100644 --- a/tket-qsystem/src/llvm.rs +++ b/tket-qsystem/src/llvm.rs @@ -2,6 +2,7 @@ pub mod array_utils; pub mod debug; pub mod futures; +pub mod globals; pub mod prelude; pub mod qsystem; pub mod random; diff --git a/tket-qsystem/src/llvm/globals.rs b/tket-qsystem/src/llvm/globals.rs new file mode 100644 index 000000000..cb660a289 --- /dev/null +++ b/tket-qsystem/src/llvm/globals.rs @@ -0,0 +1,95 @@ +#![allow(missing_docs)] + +use crate::extension::globals::{GlobalsOp, GlobalsOpDef}; +use anyhow::{Result, bail, ensure}; +use hugr::llvm::{ + CodegenExtension, CodegenExtsBuilder, + emit::{EmitFuncContext, EmitOpArgs}, + inkwell::{AddressSpace, types::BasicType as _}, +}; +use hugr::{ + HugrView, Node, + extension::{prelude::option_type, simple_op::HasConcrete as _}, + ops::ExtensionOp, +}; + +pub struct GlobalsCodegenExtension; + +impl CodegenExtension for GlobalsCodegenExtension { + fn add_extension<'a, H: HugrView + 'a>( + self, + builder: CodegenExtsBuilder<'a, H>, + ) -> CodegenExtsBuilder<'a, H> + where + Self: 'a, + { + builder.simple_extension_op(emit_globals_op) + } +} + +fn emit_globals_op<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, ExtensionOp, H>, + op: GlobalsOpDef, +) -> Result<()> { + let op = op.instantiate(args.node().args())?; + const PREFIX: &str = "__globals__"; + + match op { + GlobalsOp::Swap { name, ty } => { + let sym = format!("{PREFIX}.{name}"); + let sym_ty = context.llvm_sum_type(option_type([ty.clone()]))?; + + let [new_value] = &args.inputs[..] else { + bail!("Expected one input for GlobalsOp::Swap") + }; + let new_value_ty = new_value.get_type(); + ensure!( + new_value_ty == sym_ty.as_basic_type_enum(), + "Input type does not match global variable type. Found {new_value_ty}, Expected {sym_ty}" + ); + + let module = context.get_current_module(); + let builder = context.builder(); + let none_value = sym_ty.build_tag(builder, 0, vec![])?; + + let global = module.get_global(&sym).unwrap_or_else(|| { + let global = module.add_global(sym_ty.clone(), Some(AddressSpace::default()), &sym); + global.set_initializer(&none_value); + global + }); + + let result = builder.build_load(sym_ty, global.as_pointer_value(), "current_value")?; + let _ = builder.build_store(global.as_pointer_value(), *new_value)?; + args.outputs.finish(builder, [result])? + } + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use hugr::extension::prelude::usize_t; + use hugr::llvm::{ + check_emission, + test::{TestContext, llvm_ctx, single_op_hugr}, + }; + + #[rstest::rstest] + fn emit_global_codegen(mut llvm_ctx: TestContext) { + llvm_ctx.add_extensions(move |ceb| { + ceb.add_default_prelude_extensions() + .add_extension(GlobalsCodegenExtension) + }); + let hugr = single_op_hugr( + GlobalsOp::Swap { + name: "my_global".to_string(), + ty: usize_t(), + } + .into(), + ); + check_emission!(hugr, llvm_ctx); + } +} diff --git a/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@llvm21.snap b/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@llvm21.snap new file mode 100644 index 000000000..bdf44ed78 --- /dev/null +++ b/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@llvm21.snap @@ -0,0 +1,19 @@ +--- +source: tket-qsystem/src/llvm/globals.rs +assertion_line: 93 +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@__globals__.my_global = global { i1, i64 } { i1 false, i64 poison } + +define internal { i1, i64 } @_hl.main.1({ i1, i64 } %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %current_value = load { i1, i64 }, ptr @__globals__.my_global, align 4 + store { i1, i64 } %0, ptr @__globals__.my_global, align 4 + ret { i1, i64 } %current_value +} diff --git a/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@pre-mem2reg@llvm21.snap b/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@pre-mem2reg@llvm21.snap new file mode 100644 index 000000000..c39991d60 --- /dev/null +++ b/tket-qsystem/src/llvm/snapshots/tket_qsystem__llvm__globals__test__emit_global_codegen@pre-mem2reg@llvm21.snap @@ -0,0 +1,28 @@ +--- +source: tket-qsystem/src/llvm/globals.rs +assertion_line: 93 +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@__globals__.my_global = global { i1, i64 } { i1 false, i64 poison } + +define internal { i1, i64 } @_hl.main.1({ i1, i64 } %0) { +alloca_block: + %"0" = alloca { i1, i64 }, align 8 + %"2_0" = alloca { i1, i64 }, align 8 + %"4_0" = alloca { i1, i64 }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store { i1, i64 } %0, ptr %"2_0", align 4 + %"2_01" = load { i1, i64 }, ptr %"2_0", align 4 + %current_value = load { i1, i64 }, ptr @__globals__.my_global, align 4 + store { i1, i64 } %"2_01", ptr @__globals__.my_global, align 4 + store { i1, i64 } %current_value, ptr %"4_0", align 4 + %"4_02" = load { i1, i64 }, ptr %"4_0", align 4 + store { i1, i64 } %"4_02", ptr %"0", align 4 + %"03" = load { i1, i64 }, ptr %"0", align 4 + ret { i1, i64 } %"03" +}