diff --git a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala index 64436275..f04b80a8 100644 --- a/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala +++ b/compiler/src/main/scala/play/twirl/compiler/TwirlCompiler.scala @@ -426,7 +426,11 @@ object TwirlCompiler { additionalImports: collection.Seq[String], constructorAnnotations: collection.Seq[String] ): collection.Seq[Any] = { - val (renderCall, f, templateType) = TemplateAsFunctionCompiler.getFunctionMapping(root.params.str, resultType) + val (renderCall, f, templateType) = TemplateAsFunctionCompiler.getFunctionMapping( + root.templateFunctionName.map(_.str).getOrElse("f"), + root.params.str, + resultType + ) // Get the imports that we need to include, filtering out empty imports val imports: Seq[Any] = Seq(additionalImports.map(i => Seq("import ", i, "\n")), formatImports(root.topImports)) @@ -520,7 +524,7 @@ package """ :+ packageName :+ """ /** The maximum time in milliseconds to wait for a compiler response to finish. */ private val Timeout = 10000 - def getFunctionMapping(signature: String, returnType: String): (String, String, String) = synchronized { + def getFunctionMapping(fn: String, signature: String, returnType: String): (String, String, String) = synchronized { def filterType(t: String) = t.replace("_root_.scala.", "Array") @@ -584,7 +588,8 @@ package """ :+ packageName :+ """ (if (params.flatten.isEmpty) "" else ",") + returnType ) - val f = "def f:%s = %s => apply%s".format( + val f = "def %s:%s = %s => apply%s".format( + fn, functionType, params.map(group => "(" + group.map(_.name.toString).mkString(",") + ")").mkString(" => "), params diff --git a/parser/src/main/scala/play/twirl/parser/TreeNodes.scala b/parser/src/main/scala/play/twirl/parser/TreeNodes.scala index bdbe938d..812210a5 100644 --- a/parser/src/main/scala/play/twirl/parser/TreeNodes.scala +++ b/parser/src/main/scala/play/twirl/parser/TreeNodes.scala @@ -15,6 +15,7 @@ object TreeNodes { name: PosString, constructor: Option[Constructor], comment: Option[Comment], + templateFunctionName: Option[PosString], params: PosString, topImports: collection.Seq[Simple], imports: collection.Seq[Simple], diff --git a/parser/src/main/scala/play/twirl/parser/TwirlParser.scala b/parser/src/main/scala/play/twirl/parser/TwirlParser.scala index 663613e8..da2c794f 100644 --- a/parser/src/main/scala/play/twirl/parser/TwirlParser.scala +++ b/parser/src/main/scala/play/twirl/parser/TwirlParser.scala @@ -810,7 +810,7 @@ class TwirlParser(val shouldParseInclusiveDot: Boolean) { if (check("{")) { val (imports, localDefs, templates, mixeds) = templateContent() if (check("}")) - result = Template(templDecl._1, None, None, templDecl._2, Nil, imports, localDefs, templates, mixeds) + result = Template(templDecl._1, None, None, None, templDecl._2, Nil, imports, localDefs, templates, mixeds) } } } @@ -931,6 +931,29 @@ class TwirlParser(val shouldParseInclusiveDot: Boolean) { } else None } + /** + * Parse the template function name, if it exists + */ + private def maybeTemplateFunctionName(): Option[PosString] = { + if (check("@templateFunctionName(")) { + val p = input.offset() + whitespaceNoBreak + val name = stringLiteral("\"", "\\") + if (name != null) { + whitespaceNoBreak + if (!check(")")) { + error("Expected closing parenthesis after template function name") + None + } else { + Some(position(PosString(name), p)) + } + } else { + error("Expected template function name") + None + } + } else None + } + /** * Parse the template arguments, if they exist */ @@ -965,12 +988,14 @@ class TwirlParser(val shouldParseInclusiveDot: Boolean) { } } val args = maybeTemplateArgs() + val templateFunctionName = maybeTemplateFunctionName() val (imports, localDefs, templates, mixeds) = templateContent() val template = Template( PosString(""), constructor, argsComment, + templateFunctionName, args.getOrElse(PosString("()")), topImports, imports,