Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions effekt/shared/src/main/scala/effekt/core/Renamer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core
def withBinding[R](id: Id)(f: => R): R = withBindings(List(id))(f)

// free variables are left untouched
override def id: PartialFunction[core.Id, core.Id] = {
id => scope.getOrElse(id, id)
}
override def rewrite(id: Id): Id = scope.getOrElse(id, id)

override def stmt: PartialFunction[Stmt, Stmt] = {
override def rewrite(stmt: Stmt): Stmt = stmt match {
case core.Def(id, block, body) =>
// can be recursive
withBinding(id) { core.Def(rewrite(id), rewrite(block), rewrite(body)) }
Expand Down Expand Up @@ -84,9 +82,11 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core
case core.Shift(p, k, body) =>
val resolvedPrompt = rewrite(p)
withBinding(k.id) { core.Shift(resolvedPrompt, rewrite(k), rewrite(body)) }

case other => super.rewrite(other)
}

override def block: PartialFunction[Block, Block] = {
override def rewrite(block: BlockLit): BlockLit = block match {
case Block.BlockLit(tparams, cparams, vparams, bparams, body) =>
withBindings(tparams ++ cparams ++ vparams.map(_.id) ++ bparams.map(_.id)) {
Block.BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite,
Expand Down
40 changes: 20 additions & 20 deletions effekt/shared/src/main/scala/effekt/core/TestRenamer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,26 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "$", preserv

// Top-level items may be mutually recursive. This means that a bound occurrence may precede its binding.
// We use a separate pass to collect all top-level ids, so that we can distinguish them from free variables.
override def id: PartialFunction[core.Id, core.Id] = {
case id =>
if (builtins.isCoreBuiltin(id)) {
// builtin, do not rename
id
} else {
scopes.collectFirst {
// locally bound variable
case bnds if bnds.contains(id) => bnds(id)
}.getOrElse {
if (toplevelScope.contains(id)) {
// id references a top-level item
toplevelScope(id)
} else {
// free variable, do not rename
id
}
override def rewrite(id: Id): Id =
if (builtins.isCoreBuiltin(id)) {
// builtin, do not rename
id
} else {
scopes.collectFirst {
// locally bound variable
case bnds if bnds.contains(id) => bnds(id)
}.getOrElse {
if (toplevelScope.contains(id)) {
// id references a top-level item
toplevelScope(id)
} else {
// free variable, do not rename
id
}
}
}
}

override def stmt: PartialFunction[Stmt, Stmt] = {
override def rewrite(stmt: Stmt): Stmt = stmt match {
case core.Def(id, block, body) =>
// can be recursive
withBinding(id) { core.Def(rewrite(id), rewrite(block), rewrite(body)) }
Expand Down Expand Up @@ -127,9 +125,11 @@ class TestRenamer(names: Names = Names(Map.empty), prefix: String = "$", preserv
case core.Shift(p, k, body) =>
val resolvedPrompt = rewrite(p)
withBinding(k.id) { core.Shift(resolvedPrompt, rewrite(k), rewrite(body)) }

case other => super.rewrite(other)
}

override def block: PartialFunction[Block, Block] = {
override def rewrite(block: BlockLit): BlockLit = block match {
case Block.BlockLit(tparams, cparams, vparams, bparams, body) =>
withBindings(tparams ++ cparams ++ vparams.map(_.id) ++ bparams.map(_.id)) {
Block.BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite,
Expand Down
95 changes: 41 additions & 54 deletions effekt/shared/src/main/scala/effekt/core/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -414,66 +414,54 @@ object Tree {
def all[T](t: IterableOnce[T], f: T => Res): Res =
t.iterator.foldLeft(empty) { case (xs, t) => combine(f(t), xs) }

def pure(using Ctx): PartialFunction[Expr, Res] = PartialFunction.empty
def stmt(using Ctx): PartialFunction[Stmt, Res] = PartialFunction.empty
def block(using Ctx): PartialFunction[Block, Res] = PartialFunction.empty
def toplevel(using Ctx): PartialFunction[Toplevel, Res] = PartialFunction.empty
def implementation(using Ctx): PartialFunction[Implementation, Res] = PartialFunction.empty
def operation(using Ctx): PartialFunction[Operation, Res] = PartialFunction.empty
def clause(using Ctx): PartialFunction[(Id, BlockLit), Res] = PartialFunction.empty
def externBody(using Ctx): PartialFunction[ExternBody, Res] = PartialFunction.empty

/**
* Hook that can be overridden to perform an action at every node in the tree
*/
def visit[T](t: T)(visitor: Ctx ?=> T => Res)(using Ctx): Res = visitor(t)

inline def structuralQuery[T](el: T, pf: PartialFunction[T, Res])(using Ctx): Res = visit(el) { t =>
if pf.isDefinedAt(el) then pf.apply(el) else queryStructurally(t, empty, combine)
inline def structuralQuery[T](el: T)(using Ctx): Res = visit(el) { t =>
queryStructurally(t, empty, combine)
}

def query(p: Expr)(using Ctx): Res = structuralQuery(p, pure)
def query(s: Stmt)(using Ctx): Res = structuralQuery(s, stmt)
def query(b: Block)(using Ctx): Res = structuralQuery(b, block)
def query(d: Toplevel)(using Ctx): Res = structuralQuery(d, toplevel)
def query(d: Implementation)(using Ctx): Res = structuralQuery(d, implementation)
def query(d: Operation)(using Ctx): Res = structuralQuery(d, operation)
def query(matchClause: (Id, BlockLit))(using Ctx): Res =
if clause.isDefinedAt(matchClause) then clause.apply(matchClause) else matchClause match {
def query(p: Expr)(using Ctx): Res = structuralQuery(p)
def query(s: Stmt)(using Ctx): Res = structuralQuery(s)
def query(b: Block)(using Ctx): Res = structuralQuery(b)
def query(d: Toplevel)(using Ctx): Res = structuralQuery(d)
def query(d: Implementation)(using Ctx): Res = structuralQuery(d)
def query(d: Operation)(using Ctx): Res = structuralQuery(d)
def query(matchClause: (Id, BlockLit))(using Ctx): Res = matchClause match {
case (id, lit) => query(lit)
}
def query(b: ExternBody)(using Ctx): Res = structuralQuery(b, externBody)
def query(m: ModuleDecl)(using Ctx) = structuralQuery(m, PartialFunction.empty)
def query(b: ExternBody)(using Ctx): Res = structuralQuery(b)
def query(m: ModuleDecl)(using Ctx) = structuralQuery(m)
}

class Rewrite extends Structural {
def id: PartialFunction[Id, Id] = PartialFunction.empty
def pure: PartialFunction[Expr, Expr] = PartialFunction.empty
def stmt: PartialFunction[Stmt, Stmt] = PartialFunction.empty
def toplevel: PartialFunction[Toplevel, Toplevel] = PartialFunction.empty
def block: PartialFunction[Block, Block] = PartialFunction.empty
def implementation: PartialFunction[Implementation, Implementation] = PartialFunction.empty

def rewrite(x: Id): Id = if id.isDefinedAt(x) then id(x) else x
def rewrite(p: Expr): Expr = rewriteStructurally(p, pure)
def rewrite(s: Stmt): Stmt = rewriteStructurally(s, stmt)
def rewrite(b: Block): Block = rewriteStructurally(b, block)
def rewrite(d: Toplevel): Toplevel = rewriteStructurally(d, toplevel)
def rewrite(e: Implementation): Implementation = rewriteStructurally(e, implementation)
def rewrite(x: Id): Id = x
def rewrite(p: Expr): Expr = rewriteStructurally(p)
def rewrite(s: Stmt): Stmt = rewriteStructurally(s)
def rewrite(block: Block): Block = block match {
case b : Block.BlockVar => rewrite(b)
case b : Block.BlockLit => rewrite(b)
case Block.Unbox(pure) => Block.Unbox(rewrite(pure))
case Block.New(impl) => Block.New(rewrite(impl))
}
def rewrite(d: Toplevel): Toplevel = rewriteStructurally(d)
def rewrite(e: Implementation): Implementation = rewriteStructurally(e)
def rewrite(o: Operation): Operation = rewriteStructurally(o)
def rewrite(p: ValueParam): ValueParam = rewriteStructurally(p)
def rewrite(p: BlockParam): BlockParam = rewriteStructurally(p)
def rewrite(b: ExternBody): ExternBody= rewriteStructurally(b)
def rewrite(e: Extern): Extern= rewriteStructurally(e)
def rewrite(b: ExternBody): ExternBody = rewriteStructurally(b)
def rewrite(e: Extern): Extern = rewriteStructurally(e)
def rewrite(d: Declaration): Declaration = rewriteStructurally(d)
def rewrite(c: Constructor): Constructor = rewriteStructurally(c)
def rewrite(f: Field): Field = rewriteStructurally(f)

def rewrite(b: BlockLit): BlockLit = if block.isDefinedAt(b) then block(b).asInstanceOf else b match {
def rewrite(b: BlockLit): BlockLit = b match {
case BlockLit(tparams, cparams, vparams, bparams, body) =>
BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, rewrite(body))
}
def rewrite(b: BlockVar): BlockVar = if block.isDefinedAt(b) then block(b).asInstanceOf else b match {
def rewrite(b: BlockVar): BlockVar = b match {
case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt))
}

Expand All @@ -496,7 +484,7 @@ object Tree {
}
}

class TrampolinedRewrite {
class RewriteTrampolined {

import Trampoline.done

Expand Down Expand Up @@ -753,29 +741,28 @@ object Tree {
}

class RewriteWithContext[Ctx] extends Structural {
def id(using Ctx): PartialFunction[Id, Id] = PartialFunction.empty
def expr(using Ctx): PartialFunction[Expr, Expr] = PartialFunction.empty
def stmt(using Ctx): PartialFunction[Stmt, Stmt] = PartialFunction.empty
def toplevel(using Ctx): PartialFunction[Toplevel, Toplevel] = PartialFunction.empty
def block(using Ctx): PartialFunction[Block, Block] = PartialFunction.empty
def implementation(using Ctx): PartialFunction[Implementation, Implementation] = PartialFunction.empty

def rewrite(x: Id)(using Ctx): Id = if id.isDefinedAt(x) then id(x) else x
def rewrite(p: Expr)(using Ctx): Expr = rewriteStructurally(p, expr)
def rewrite(s: Stmt)(using Ctx): Stmt = rewriteStructurally(s, stmt)
def rewrite(b: Block)(using Ctx): Block = rewriteStructurally(b, block)
def rewrite(d: Toplevel)(using Ctx): Toplevel = rewriteStructurally(d, toplevel)
def rewrite(e: Implementation)(using Ctx): Implementation = rewriteStructurally(e, implementation)

def rewrite(x: Id)(using Ctx): Id = x
def rewrite(p: Expr)(using Ctx): Expr = rewriteStructurally(p)
def rewrite(s: Stmt)(using Ctx): Stmt = rewriteStructurally(s)
def rewrite(b: Block)(using Ctx): Block = b match {
case b : Block.BlockVar => rewrite(b)
case b : Block.BlockLit => rewrite(b)
case Block.Unbox(pure) => Block.Unbox(rewrite(pure))
case Block.New(impl) => Block.New(rewrite(impl))
}
def rewrite(d: Toplevel)(using Ctx): Toplevel = rewriteStructurally(d)
def rewrite(e: Implementation)(using Ctx): Implementation = rewriteStructurally(e)
def rewrite(o: Operation)(using Ctx): Operation = rewriteStructurally(o)
def rewrite(p: ValueParam)(using Ctx): ValueParam = rewriteStructurally(p)
def rewrite(p: BlockParam)(using Ctx): BlockParam = rewriteStructurally(p)
def rewrite(b: ExternBody)(using Ctx): ExternBody= rewrite(b)

def rewrite(b: BlockLit)(using Ctx): BlockLit = if block.isDefinedAt(b) then block(b).asInstanceOf else b match {
def rewrite(b: BlockLit)(using Ctx): BlockLit = b match {
case BlockLit(tparams, cparams, vparams, bparams, body) =>
BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, rewrite(body))
}
def rewrite(b: BlockVar)(using Ctx): BlockVar = if block.isDefinedAt(b) then block(b).asInstanceOf else b match {
def rewrite(b: BlockVar)(using Ctx): BlockVar = b match {
case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package optimizer

import util.Trampoline

class Deadcode(reachable: Map[Id, Usage]) extends core.Tree.TrampolinedRewrite {
class Deadcode(reachable: Map[Id, Usage]) extends core.Tree.RewriteTrampolined {

private def used(id: Id): Boolean = reachable.get(id).exists(u => u != Usage.Never)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package optimizer

object DirectStyle extends Tree.Rewrite {

override def stmt = {
override def rewrite(stmt: Stmt): Stmt = stmt match {

// val x = { ... return 42 }; stmt2
//
Expand All @@ -24,6 +24,7 @@ object DirectStyle extends Tree.Rewrite {
else
Val(id, rewrittenBinding, rewrittenBody)

case other => super.rewrite(other)
}

private def canBeDirect(s: Stmt): Boolean = s match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ object DropBindings extends Phase[CoreTransformed, CoreTransformed] {

private object dropping extends Tree.RewriteWithContext[DropContext] {

override def expr(using DropContext) = {
override def rewrite(expr: Expr)(using DropContext): Expr = expr match {
case Expr.ValueVar(id, tpe) if usedOnce(id) && hasDefinition(id) => definitionOf(id)
case other => super.rewrite(other)
}

override def stmt(using C: DropContext) = {
override def rewrite(stmt: Stmt)(using C: DropContext): Stmt = stmt match {
case Stmt.Let(id, p: Expr, body) if usedOnce(id) =>
val transformed = rewrite(p)
rewrite(body)(using C.updated(id, transformed))
case other => super.rewrite(other)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ object RemoveTailResumptions {
def apply(m: ModuleDecl): ModuleDecl = removal.rewrite(m)

object removal extends Tree.Rewrite {
override def stmt: PartialFunction[Stmt, Stmt] = {
override def rewrite(stmt: Stmt): Stmt = stmt match {
case Stmt.Shift(prompt, BlockParam(k, Type.TResume(from, to), capt), body) if tailResumptive(k, body) =>
removeTailResumption(k, from, body)
case Stmt.Shift(prompt, k, body) => Shift(prompt, k, rewrite(body))

case other => super.rewrite(other)
}
}

Expand Down