From 930448717a238d2fa69bb137e9de5b6d440b1539 Mon Sep 17 00:00:00 2001 From: Moritz Scheuerle <68465911+Plixo2@users.noreply.github.com> Date: Tue, 18 Nov 2025 18:24:24 +0100 Subject: [PATCH] Refactor TransformerOps into a helper to operate on state in Context --- .../main/scala/effekt/context/Context.scala | 6 +- .../main/scala/effekt/core/BindingDB.scala | 15 +++ .../main/scala/effekt/core/Transformer.scala | 92 +++++++++---------- 3 files changed, 65 insertions(+), 48 deletions(-) create mode 100644 effekt/shared/src/main/scala/effekt/core/BindingDB.scala diff --git a/effekt/shared/src/main/scala/effekt/context/Context.scala b/effekt/shared/src/main/scala/effekt/context/Context.scala index 47b8c84c3b..dbc7fac50a 100644 --- a/effekt/shared/src/main/scala/effekt/context/Context.scala +++ b/effekt/shared/src/main/scala/effekt/context/Context.scala @@ -3,7 +3,7 @@ package context import effekt.namer.NamerOps import effekt.typer.{TyperOps, Unification} -import effekt.core.TransformerOps +import effekt.core.{BindingDB, TransformerOps} import effekt.source.Tree import effekt.util.messages.{EffektMessages, ErrorReporter} import effekt.util.Timers @@ -42,12 +42,14 @@ abstract class Context extends NamerOps with TyperOps with ModuleDB - with TransformerOps with Timers { // bring the context itself in scope implicit val context: Context = this + // Storage for bindings + var bindingDB: BindingDB = new BindingDB + // the currently processed module var module: Module = _ diff --git a/effekt/shared/src/main/scala/effekt/core/BindingDB.scala b/effekt/shared/src/main/scala/effekt/core/BindingDB.scala new file mode 100644 index 0000000000..41c47991f7 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/BindingDB.scala @@ -0,0 +1,15 @@ +package effekt.core + + +import scala.collection.mutable.ListBuffer + +/** + * Storage for bindings + */ +final class BindingDB { + + /** + * A _mutable_ ListBuffer that stores all bindings to be inserted at the current scope + */ + var bindings: ListBuffer[Binding] = ListBuffer() +} diff --git a/effekt/shared/src/main/scala/effekt/core/Transformer.scala b/effekt/shared/src/main/scala/effekt/core/Transformer.scala index f56064b43e..1cd482bcef 100644 --- a/effekt/shared/src/main/scala/effekt/core/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/core/Transformer.scala @@ -18,7 +18,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { def run(input: Typechecked)(using Context) = val Typechecked(source, tree, mod) = input - Context.initTransformerState() + TransformerOps.initTransformerState() if (Context.messaging.hasErrors) { None @@ -90,7 +90,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { case v @ source.DefDef(id, captures, annot, binding, doc, span) => val sym = v.symbol - val (definition, bindings) = Context.withBindings { + val (definition, bindings) = TransformerOps.withBindings { Toplevel.Def(sym, transformAsBlock(binding)) } @@ -186,13 +186,13 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { case v @ source.RegDef(id, _, reg, binding, doc, span) => val sym = v.symbol insertBindings { - Alloc(sym, Context.bind(transform(binding)), sym.region, transform(rest)) + Alloc(sym, TransformerOps.bind(transform(binding)), sym.region, transform(rest)) } case v @ source.VarDef(id, _, binding, doc, span) => val sym = v.symbol insertBindings { - Var(sym, Context.bind(transform(binding)), sym.capture, transform(rest)) + Var(sym, TransformerOps.bind(transform(binding)), sym.capture, transform(rest)) } case d: source.Def.Extern => Context.panic("Only allowed on the toplevel") @@ -321,7 +321,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { val tpe = TState.extractType(stateType) val stateId = Id("s") // emits `let s = !ref; return s` - Context.bind(Get(stateId, transform(tpe), sym, transform(Context.captureOf(sym)), Return(core.ValueVar(stateId, transform(tpe))))) + TransformerOps.bind(Get(stateId, transform(tpe), sym, transform(Context.captureOf(sym)), Return(core.ValueVar(stateId, transform(tpe))))) case sym: ValueSymbol => ValueVar(sym) case sym: BlockSymbol => transformBox(tree) } @@ -356,7 +356,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { val tpe = transform(substitution.substitute(f.returnType)) core.ValueParam(if f == field then selected else Id("_"), tpe) } - Context.bind(Stmt.Match(transformAsExpr(receiver), + TransformerOps.bind(Stmt.Match(transformAsExpr(receiver), List((constructor, BlockLit(Nil, Nil, params, Nil, Stmt.Return(Expr.ValueVar(selected, tpe))))), None)) case source.Box(capt, block, _) => @@ -373,12 +373,12 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { case source.If(List(MatchGuard.BooleanGuard(cond, _)), thn, els, _) => val c = transformAsExpr(cond) - Context.bind(If(c, transform(thn), transform(els))) + TransformerOps.bind(If(c, transform(thn), transform(els))) case source.If(guards, thn, els, _) => val thnClause = preprocess("thn", Nil, guards, transform(thn)) val elsClause = preprocess("els", Nil, Nil, transform(els)) - Context.bind(PatternMatchingCompiler.compile(List(thnClause, elsClause))) + TransformerOps.bind(PatternMatchingCompiler.compile(List(thnClause, elsClause))) // case i @ source.If(guards, thn, els) => // val compiled = collectClauses(i) @@ -412,22 +412,22 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { } } - Context.bind(loopName, Block.BlockLit(Nil, Nil, Nil, Nil, loopBody)) + TransformerOps.bind(loopName, Block.BlockLit(Nil, Nil, Nil, Nil, loopBody)) - Context.bind(loopCall) + TransformerOps.bind(loopCall) // Empty match (matching on Nothing) case source.Match(List(sc), Nil, None, _) => - val scrutinee: ValueVar = Context.bind(transformAsExpr(sc)) - Context.bind(core.Match(scrutinee, Nil, None)) + val scrutinee: ValueVar = TransformerOps.bind(transformAsExpr(sc)) + TransformerOps.bind(core.Match(scrutinee, Nil, None)) case source.Match(scs, cs, default, _) => // (1) Bind scrutinee and all clauses so we do not have to deal with sharing on demand. - val scrutinees: List[ValueVar] = scs.map{ sc => Context.bind(transformAsExpr(sc)) } + val scrutinees: List[ValueVar] = scs.map{ sc => TransformerOps.bind(transformAsExpr(sc)) } val clauses = cs.zipWithIndex.map((c, i) => preprocess(s"k${i}", scrutinees, c)) val defaultClause = default.map(stmt => preprocess("k_els", Nil, Nil, transform(stmt))).toList val compiledMatch = PatternMatchingCompiler.compile(clauses ++ defaultClause) - Context.bind(compiledMatch) + TransformerOps.bind(compiledMatch) case source.TryHandle(prog, handlers, _) => @@ -451,21 +451,21 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { val body: BlockLit = BlockLit(Nil, List(promptCapt), Nil, List(promptParam), Binding(transformedHandlers, transform(prog))) - Context.bind(Reset(body)) + TransformerOps.bind(Reset(body)) case r @ source.Region(name, body, _) => val region = r.symbol val tpe = Context.blockTypeOf(region) val cap: core.BlockParam = core.BlockParam(region, transform(tpe), Set(region.capture)) - Context.bind(Region(BlockLit(Nil, List(region.capture), Nil, List(cap), transform(body)))) + TransformerOps.bind(Region(BlockLit(Nil, List(region.capture), Nil, List(cap), transform(body)))) case source.Hole(id, stmts, span) => - Context.bind(core.Hole(span)) + TransformerOps.bind(core.Hole(span)) case a @ source.Assign(id, expr, _) => val sym = a.definition // emits `ref := value; return ()` - Context.bind(Put(sym, transform(Context.captureOf(sym)), transformAsExpr(expr), Return(Literal((), core.Type.TUnit)))) + TransformerOps.bind(Put(sym, transform(Context.captureOf(sym)), transformAsExpr(expr), Return(Literal((), core.Type.TUnit)))) Literal((), core.Type.TUnit) // methods are dynamically dispatched, so we have to assume they are `control`, hence no PureApp. @@ -485,7 +485,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { // Do not pass type arguments for the type constructor of the receiver. val remainingTypeArgs = typeArgs.drop(operation.interface.tparams.size) - Context.bind(Invoke(rec, operation, opType, remainingTypeArgs, valueArgs, blockArgs)) + TransformerOps.bind(Invoke(rec, operation, opType, remainingTypeArgs, valueArgs, blockArgs)) case c @ source.Call(source.ExprTarget(source.Unbox(expr, _)), targs, vargs, bargs, _) => @@ -499,7 +499,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { val blockArgs = bargs.map(transformAsBlock) // val captArgs = blockArgs.map(b => b.capt) //transform(Context.inferredCapture(b))) - Context.bind(App(Unbox(e), typeArgs, valueArgs, blockArgs)) + TransformerOps.bind(App(Unbox(e), typeArgs, valueArgs, blockArgs)) case c @ source.Call(fun: source.IdTarget, _, vargs, bargs, _) => // assumption: typer removed all ambiguous references, so there is exactly one @@ -694,7 +694,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { // create joinpoint val tparams = patterns.flatMap { case (sc, p) => boundTypesInPattern(p) } ++ guards.flatMap(boundTypesInGuard) val params = patterns.flatMap { case (sc, p) => boundInPattern(p) } ++ guards.flatMap(boundInGuard) - val joinpoint = Context.bind(TmpBlock(label), BlockLit(tparams, Nil, params, Nil, body)) + val joinpoint = TransformerOps.bind(TmpBlock(label), BlockLit(tparams, Nil, params, Nil, body)) def transformPattern(p: source.MatchPattern): Pattern = p match { case source.AnyPattern(id, _) => @@ -721,12 +721,12 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { } def transformGuard(p: source.MatchGuard): List[Condition] = - val (cond, bindings) = Context.withBindings { + val (cond, bindings) = TransformerOps.withBindings { p match { case MatchGuard.BooleanGuard(condition, _) => Condition.Predicate(transformAsExpr(condition)) case MatchGuard.PatternGuard(scrutinee, pattern, _) => - val x = Context.bind(transformAsExpr(scrutinee)) + val x = TransformerOps.bind(transformAsExpr(scrutinee)) Condition.Patterns(Map(x -> transformPattern(pattern))) } } @@ -807,7 +807,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { case f: Callable if callingConvention(f) == CallingConvention.Pure => PureApp(BlockVar(f), targs, vargsT) case f: Callable if callingConvention(f) == CallingConvention.Direct => - Context.bind(BlockVar(f), targs, vargsT, bargsT) + TransformerOps.bind(BlockVar(f), targs, vargsT, bargsT) case r: Constructor => if (bargs.nonEmpty) Context.abort("Constructors cannot take block arguments.") val universals = targs.take(r.tpe.tparams.length) @@ -818,9 +818,9 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { case f: Field => Context.panic("Should have been translated to a select!") case f: BlockSymbol => - Context.bind(App(BlockVar(f), targs, vargsT, bargsT)) + TransformerOps.bind(App(BlockVar(f), targs, vargsT, bargsT)) case f: ValueSymbol => - Context.bind(App(Unbox(ValueVar(f)), targs, vargsT, bargsT)) + TransformerOps.bind(App(Unbox(ValueVar(f)), targs, vargsT, bargsT)) } } @@ -829,7 +829,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { def transform(p: source.ValueParam)(using Context): core.ValueParam = ValueParam(p.symbol) def insertBindings(stmt: => Stmt)(using Context): Stmt = { - val (body, bindings) = Context.withBindings { stmt } + val (body, bindings) = TransformerOps.withBindings { stmt } Binding(bindings, body) } @@ -901,15 +901,15 @@ object Transformer extends Phase[Typechecked, CoreTransformed] { } -trait TransformerOps extends ContextOps { Context: Context => +/** + * Helper for dealing with bindings in the [[BindingDB]] + * As [[Context]] stores the [[BindingDB]], it is implicitly passed to all the functions + */ +object TransformerOps { - /** - * A _mutable_ ListBuffer that stores all bindings to be inserted at the current scope - */ - private var bindings: ListBuffer[Binding] = ListBuffer() - private[core] def initTransformerState() = { - bindings = ListBuffer() + private[core] def initTransformerState()(using Context) = { + Context.bindingDB.bindings = ListBuffer() } /** @@ -918,50 +918,50 @@ trait TransformerOps extends ContextOps { Context: Context => * @param tpe the type of the bound statement * @param s the statement to be bound */ - private[core] def bind(s: Stmt): ValueVar = { + private[core] def bind(s: Stmt)(using Context): ValueVar = { // create a fresh symbol and assign the type val x = TmpValue("r") val binding = Binding.Val(x, s.tpe, s) - bindings += binding + Context.bindingDB.bindings += binding ValueVar(x, s.tpe) } - private[core] def bind(e: Expr): ValueVar = e match { + private[core] def bind(e: Expr)(using Context): ValueVar = e match { case x: ValueVar => x case e => // create a fresh symbol and assign the type val x = TmpValue("r") val binding = Binding.Let(x, e.tpe, e) - bindings += binding + Context.bindingDB.bindings += binding ValueVar(x, e.tpe) } - private[core] def bind(callee: Block.BlockVar, targs: List[core.ValueType], vargs: List[Expr], bargs: List[Block]): ValueVar = { + private[core] def bind(callee: Block.BlockVar, targs: List[core.ValueType], vargs: List[Expr], bargs: List[Block])(using Context): ValueVar = { // create a fresh symbol and assign the type val x = TmpValue("r") val binding: Binding.ImpureApp = Binding.ImpureApp(x, callee, targs, vargs, bargs) - bindings += binding + Context.bindingDB.bindings += binding ValueVar(x, Type.bindingType(binding)) } - private[core] def bind(name: BlockSymbol, b: Block): BlockVar = { + private[core] def bind(name: BlockSymbol, b: Block)(using Context): BlockVar = { val binding = Binding.Def(name, b) - bindings += binding + Context.bindingDB.bindings += binding BlockVar(name, b.tpe, b.capt) } - private[core] def withBindings[R](block: => R): (R, List[Binding]) = Context in { - val before = bindings + private[core] def withBindings[R](block: => R)(using Context): (R, List[Binding]) = Context in { + val before = Context.bindingDB.bindings val b = ListBuffer.empty[Binding] - bindings = b + Context.bindingDB.bindings = b val result = block - bindings = before + Context.bindingDB.bindings = before (result, b.toList) } }