diff --git a/effekt/shared/src/main/scala/effekt/core/Renamer.scala b/effekt/shared/src/main/scala/effekt/core/Renamer.scala index 40e6d4003..e2d3749ba 100644 --- a/effekt/shared/src/main/scala/effekt/core/Renamer.scala +++ b/effekt/shared/src/main/scala/effekt/core/Renamer.scala @@ -42,11 +42,9 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core def withBinding[R](id: Id)(f: => R): R = withBindings(List(id))(f) // free variables are left untouched - override def id: PartialFunction[core.Id, core.Id] = { - id => scope.getOrElse(id, id) - } + override def rewrite(id: Id): Id = scope.getOrElse(id, id) - override def stmt: PartialFunction[Stmt, Stmt] = { + override def rewrite(stmt: Stmt): Stmt = stmt match { case core.Def(id, block, body) => // can be recursive withBinding(id) { core.Def(rewrite(id), rewrite(block), rewrite(body)) } @@ -84,9 +82,11 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core case core.Shift(p, k, body) => val resolvedPrompt = rewrite(p) withBinding(k.id) { core.Shift(resolvedPrompt, rewrite(k), rewrite(body)) } + + case other => super.rewrite(other) } - override def block: PartialFunction[Block, Block] = { + override def rewrite(block: BlockLit): BlockLit = block match { case Block.BlockLit(tparams, cparams, vparams, bparams, body) => withBindings(tparams ++ cparams ++ vparams.map(_.id) ++ bparams.map(_.id)) { Block.BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, diff --git a/effekt/shared/src/main/scala/effekt/core/TestRenamer.scala b/effekt/shared/src/main/scala/effekt/core/TestRenamer.scala index a7893a7cf..4d11682c4 100644 --- a/effekt/shared/src/main/scala/effekt/core/TestRenamer.scala +++ b/effekt/shared/src/main/scala/effekt/core/TestRenamer.scala @@ -68,28 +68,26 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "$", preserv // Top-level items may be mutually recursive. This means that a bound occurrence may precede its binding. // We use a separate pass to collect all top-level ids, so that we can distinguish them from free variables. - override def id: PartialFunction[core.Id, core.Id] = { - case id => - if (builtins.isCoreBuiltin(id)) { - // builtin, do not rename - id - } else { - scopes.collectFirst { - // locally bound variable - case bnds if bnds.contains(id) => bnds(id) - }.getOrElse { - if (toplevelScope.contains(id)) { - // id references a top-level item - toplevelScope(id) - } else { - // free variable, do not rename - id - } + override def rewrite(id: Id): Id = + if (builtins.isCoreBuiltin(id)) { + // builtin, do not rename + id + } else { + scopes.collectFirst { + // locally bound variable + case bnds if bnds.contains(id) => bnds(id) + }.getOrElse { + if (toplevelScope.contains(id)) { + // id references a top-level item + toplevelScope(id) + } else { + // free variable, do not rename + id } } - } + } - override def stmt: PartialFunction[Stmt, Stmt] = { + override def rewrite(stmt: Stmt): Stmt = stmt match { case core.Def(id, block, body) => // can be recursive withBinding(id) { core.Def(rewrite(id), rewrite(block), rewrite(body)) } @@ -127,9 +125,11 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "$", preserv case core.Shift(p, k, body) => val resolvedPrompt = rewrite(p) withBinding(k.id) { core.Shift(resolvedPrompt, rewrite(k), rewrite(body)) } + + case other => super.rewrite(other) } - override def block: PartialFunction[Block, Block] = { + override def rewrite(block: BlockLit): BlockLit = block match { case Block.BlockLit(tparams, cparams, vparams, bparams, body) => withBindings(tparams ++ cparams ++ vparams.map(_.id) ++ bparams.map(_.id)) { Block.BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, diff --git a/effekt/shared/src/main/scala/effekt/core/Tree.scala b/effekt/shared/src/main/scala/effekt/core/Tree.scala index 53e1010c4..ff8486161 100644 --- a/effekt/shared/src/main/scala/effekt/core/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/core/Tree.scala @@ -414,66 +414,54 @@ object Tree { def all[T](t: IterableOnce[T], f: T => Res): Res = t.iterator.foldLeft(empty) { case (xs, t) => combine(f(t), xs) } - def pure(using Ctx): PartialFunction[Expr, Res] = PartialFunction.empty - def stmt(using Ctx): PartialFunction[Stmt, Res] = PartialFunction.empty - def block(using Ctx): PartialFunction[Block, Res] = PartialFunction.empty - def toplevel(using Ctx): PartialFunction[Toplevel, Res] = PartialFunction.empty - def implementation(using Ctx): PartialFunction[Implementation, Res] = PartialFunction.empty - def operation(using Ctx): PartialFunction[Operation, Res] = PartialFunction.empty - def clause(using Ctx): PartialFunction[(Id, BlockLit), Res] = PartialFunction.empty - def externBody(using Ctx): PartialFunction[ExternBody, Res] = PartialFunction.empty - /** * Hook that can be overridden to perform an action at every node in the tree */ def visit[T](t: T)(visitor: Ctx ?=> T => Res)(using Ctx): Res = visitor(t) - inline def structuralQuery[T](el: T, pf: PartialFunction[T, Res])(using Ctx): Res = visit(el) { t => - if pf.isDefinedAt(el) then pf.apply(el) else queryStructurally(t, empty, combine) + inline def structuralQuery[T](el: T)(using Ctx): Res = visit(el) { t => + queryStructurally(t, empty, combine) } - def query(p: Expr)(using Ctx): Res = structuralQuery(p, pure) - def query(s: Stmt)(using Ctx): Res = structuralQuery(s, stmt) - def query(b: Block)(using Ctx): Res = structuralQuery(b, block) - def query(d: Toplevel)(using Ctx): Res = structuralQuery(d, toplevel) - def query(d: Implementation)(using Ctx): Res = structuralQuery(d, implementation) - def query(d: Operation)(using Ctx): Res = structuralQuery(d, operation) - def query(matchClause: (Id, BlockLit))(using Ctx): Res = - if clause.isDefinedAt(matchClause) then clause.apply(matchClause) else matchClause match { + def query(p: Expr)(using Ctx): Res = structuralQuery(p) + def query(s: Stmt)(using Ctx): Res = structuralQuery(s) + def query(b: Block)(using Ctx): Res = structuralQuery(b) + def query(d: Toplevel)(using Ctx): Res = structuralQuery(d) + def query(d: Implementation)(using Ctx): Res = structuralQuery(d) + def query(d: Operation)(using Ctx): Res = structuralQuery(d) + def query(matchClause: (Id, BlockLit))(using Ctx): Res = matchClause match { case (id, lit) => query(lit) } - def query(b: ExternBody)(using Ctx): Res = structuralQuery(b, externBody) - def query(m: ModuleDecl)(using Ctx) = structuralQuery(m, PartialFunction.empty) + def query(b: ExternBody)(using Ctx): Res = structuralQuery(b) + def query(m: ModuleDecl)(using Ctx) = structuralQuery(m) } class Rewrite extends Structural { - def id: PartialFunction[Id, Id] = PartialFunction.empty - def pure: PartialFunction[Expr, Expr] = PartialFunction.empty - def stmt: PartialFunction[Stmt, Stmt] = PartialFunction.empty - def toplevel: PartialFunction[Toplevel, Toplevel] = PartialFunction.empty - def block: PartialFunction[Block, Block] = PartialFunction.empty - def implementation: PartialFunction[Implementation, Implementation] = PartialFunction.empty - - def rewrite(x: Id): Id = if id.isDefinedAt(x) then id(x) else x - def rewrite(p: Expr): Expr = rewriteStructurally(p, pure) - def rewrite(s: Stmt): Stmt = rewriteStructurally(s, stmt) - def rewrite(b: Block): Block = rewriteStructurally(b, block) - def rewrite(d: Toplevel): Toplevel = rewriteStructurally(d, toplevel) - def rewrite(e: Implementation): Implementation = rewriteStructurally(e, implementation) + def rewrite(x: Id): Id = x + def rewrite(p: Expr): Expr = rewriteStructurally(p) + def rewrite(s: Stmt): Stmt = rewriteStructurally(s) + def rewrite(block: Block): Block = block match { + case b : Block.BlockVar => rewrite(b) + case b : Block.BlockLit => rewrite(b) + case Block.Unbox(pure) => Block.Unbox(rewrite(pure)) + case Block.New(impl) => Block.New(rewrite(impl)) + } + def rewrite(d: Toplevel): Toplevel = rewriteStructurally(d) + def rewrite(e: Implementation): Implementation = rewriteStructurally(e) def rewrite(o: Operation): Operation = rewriteStructurally(o) def rewrite(p: ValueParam): ValueParam = rewriteStructurally(p) def rewrite(p: BlockParam): BlockParam = rewriteStructurally(p) - def rewrite(b: ExternBody): ExternBody= rewriteStructurally(b) - def rewrite(e: Extern): Extern= rewriteStructurally(e) + def rewrite(b: ExternBody): ExternBody = rewriteStructurally(b) + def rewrite(e: Extern): Extern = rewriteStructurally(e) def rewrite(d: Declaration): Declaration = rewriteStructurally(d) def rewrite(c: Constructor): Constructor = rewriteStructurally(c) def rewrite(f: Field): Field = rewriteStructurally(f) - def rewrite(b: BlockLit): BlockLit = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + def rewrite(b: BlockLit): BlockLit = b match { case BlockLit(tparams, cparams, vparams, bparams, body) => BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, rewrite(body)) } - def rewrite(b: BlockVar): BlockVar = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + def rewrite(b: BlockVar): BlockVar = b match { case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt)) } @@ -496,7 +484,7 @@ object Tree { } } - class TrampolinedRewrite { + class RewriteTrampolined { import Trampoline.done @@ -753,29 +741,28 @@ object Tree { } class RewriteWithContext[Ctx] extends Structural { - def id(using Ctx): PartialFunction[Id, Id] = PartialFunction.empty - def expr(using Ctx): PartialFunction[Expr, Expr] = PartialFunction.empty - def stmt(using Ctx): PartialFunction[Stmt, Stmt] = PartialFunction.empty - def toplevel(using Ctx): PartialFunction[Toplevel, Toplevel] = PartialFunction.empty - def block(using Ctx): PartialFunction[Block, Block] = PartialFunction.empty - def implementation(using Ctx): PartialFunction[Implementation, Implementation] = PartialFunction.empty - - def rewrite(x: Id)(using Ctx): Id = if id.isDefinedAt(x) then id(x) else x - def rewrite(p: Expr)(using Ctx): Expr = rewriteStructurally(p, expr) - def rewrite(s: Stmt)(using Ctx): Stmt = rewriteStructurally(s, stmt) - def rewrite(b: Block)(using Ctx): Block = rewriteStructurally(b, block) - def rewrite(d: Toplevel)(using Ctx): Toplevel = rewriteStructurally(d, toplevel) - def rewrite(e: Implementation)(using Ctx): Implementation = rewriteStructurally(e, implementation) + + def rewrite(x: Id)(using Ctx): Id = x + def rewrite(p: Expr)(using Ctx): Expr = rewriteStructurally(p) + def rewrite(s: Stmt)(using Ctx): Stmt = rewriteStructurally(s) + def rewrite(b: Block)(using Ctx): Block = b match { + case b : Block.BlockVar => rewrite(b) + case b : Block.BlockLit => rewrite(b) + case Block.Unbox(pure) => Block.Unbox(rewrite(pure)) + case Block.New(impl) => Block.New(rewrite(impl)) + } + def rewrite(d: Toplevel)(using Ctx): Toplevel = rewriteStructurally(d) + def rewrite(e: Implementation)(using Ctx): Implementation = rewriteStructurally(e) def rewrite(o: Operation)(using Ctx): Operation = rewriteStructurally(o) def rewrite(p: ValueParam)(using Ctx): ValueParam = rewriteStructurally(p) def rewrite(p: BlockParam)(using Ctx): BlockParam = rewriteStructurally(p) def rewrite(b: ExternBody)(using Ctx): ExternBody= rewrite(b) - def rewrite(b: BlockLit)(using Ctx): BlockLit = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + def rewrite(b: BlockLit)(using Ctx): BlockLit = b match { case BlockLit(tparams, cparams, vparams, bparams, body) => BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, rewrite(body)) } - def rewrite(b: BlockVar)(using Ctx): BlockVar = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + def rewrite(b: BlockVar)(using Ctx): BlockVar = b match { case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt)) } diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala index 7df135181..2661637ae 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala @@ -4,7 +4,7 @@ package optimizer import util.Trampoline -class Deadcode(reachable: Map[Id, Usage]) extends core.Tree.TrampolinedRewrite { +class Deadcode(reachable: Map[Id, Usage]) extends core.Tree.RewriteTrampolined { private def used(id: Id): Boolean = reachable.get(id).exists(u => u != Usage.Never) diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/DirectStyle.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/DirectStyle.scala index 09a67333c..72f6e68a3 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/DirectStyle.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/DirectStyle.scala @@ -4,7 +4,7 @@ package optimizer object DirectStyle extends Tree.Rewrite { - override def stmt = { + override def rewrite(stmt: Stmt): Stmt = stmt match { // val x = { ... return 42 }; stmt2 // @@ -24,6 +24,7 @@ object DirectStyle extends Tree.Rewrite { else Val(id, rewrittenBinding, rewrittenBody) + case other => super.rewrite(other) } private def canBeDirect(s: Stmt): Boolean = s match { diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala index 26b149d2a..2bf940a63 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala @@ -46,14 +46,16 @@ object DropBindings extends Phase[CoreTransformed, CoreTransformed] { private object dropping extends Tree.RewriteWithContext[DropContext] { - override def expr(using DropContext) = { + override def rewrite(expr: Expr)(using DropContext): Expr = expr match { case Expr.ValueVar(id, tpe) if usedOnce(id) && hasDefinition(id) => definitionOf(id) + case other => super.rewrite(other) } - override def stmt(using C: DropContext) = { + override def rewrite(stmt: Stmt)(using C: DropContext): Stmt = stmt match { case Stmt.Let(id, p: Expr, body) if usedOnce(id) => val transformed = rewrite(p) rewrite(body)(using C.updated(id, transformed)) + case other => super.rewrite(other) } } } diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala index 6b6855b43..159a9410e 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala @@ -7,10 +7,12 @@ object RemoveTailResumptions { def apply(m: ModuleDecl): ModuleDecl = removal.rewrite(m) object removal extends Tree.Rewrite { - override def stmt: PartialFunction[Stmt, Stmt] = { + override def rewrite(stmt: Stmt): Stmt = stmt match { case Stmt.Shift(prompt, BlockParam(k, Type.TResume(from, to), capt), body) if tailResumptive(k, body) => removeTailResumption(k, from, body) case Stmt.Shift(prompt, k, body) => Shift(prompt, k, rewrite(body)) + + case other => super.rewrite(other) } }