Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions be/src/vec/functions/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@

namespace doris::vectorized {

struct GammaName {
static constexpr auto name = "gamma";
};
using FunctionGamma = FunctionMathUnary<UnaryFunctionPlain<GammaName, std::tgamma>>;

struct AcosName {
static constexpr auto name = "acos";
// https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_acos
Expand Down Expand Up @@ -991,5 +996,6 @@ void register_function_math(SimpleFunctionFactory& factory) {
factory.register_function<FunctionMathBinary<LcmImpl<TYPE_LARGEINT>>>();
factory.register_function<FunctionIsNan>();
factory.register_function<FunctionIsInf>();
factory.register_function<FunctionGamma>();
}
} // namespace doris::vectorized
11 changes: 11 additions & 0 deletions be/test/vec/function/function_math_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,4 +900,15 @@ TEST(MathFunctionTest, factorial_test) {
}
}
}

TEST(MathFunctionTest, gamma_test) {
std::string func_name = "gamma";

InputTypeSet input_types = {PrimitiveType::TYPE_DOUBLE};

DataSet data_set = {{{1.0}, 1.0}, {{2.0}, 1.0}, {{3.0}, 2.0}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.FromSecond;
import org.apache.doris.nereids.trees.expressions.functions.scalar.FromUnixtime;
import org.apache.doris.nereids.trees.expressions.functions.scalar.G;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Gamma;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Gcd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetFormat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetVariantType;
Expand Down Expand Up @@ -769,6 +770,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(FromIso8601Date.class, "from_iso8601_date"),
scalar(FromUnixtime.class, "from_unixtime"),
scalar(G.class, "g"),
scalar(Gamma.class, "gamma"),
scalar(Gcd.class, "gcd"),
scalar(GetFormat.class, "get_format"),
scalar(GetVariantType.class, "variant_type"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Database;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Date;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncryptKeyRef;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Gamma;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LastQueryId;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
Expand All @@ -86,6 +87,7 @@
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
Expand Down Expand Up @@ -185,7 +187,8 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
matches(LastQueryId.class, this::visitLastQueryId),
matches(Nvl.class, this::visitNvl),
matches(NullIf.class, this::visitNullIf),
matches(Match.class, this::visitMatch)
matches(Match.class, this::visitMatch),
matches(Gamma.class, this::visitGamma)
);
}

Expand Down Expand Up @@ -218,6 +221,21 @@ public Expression visitLiteral(Literal literal, ExpressionRewriteContext context
return literal;
}

@Override
public Expression visitGamma(Gamma gamma, ExpressionRewriteContext context) {
gamma = rewriteChildren(gamma, context);
Optional<Expression> checkedExpr = preProcess(gamma);
if (checkedExpr.isPresent()) {
return checkedExpr.get();
}
Expression child = gamma.child();
if (!(child instanceof Literal)) {
return gamma;
}
double value = ((Literal) child).getDouble();
return new DoubleLiteral(org.apache.commons.math3.special.Gamma.gamma(value));
}

@Override
public Expression visitMatch(Match match, ExpressionRewriteContext context) {
match = rewriteChildren(match, context);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'gamma'. This class is generated by GenerateFunction.
*/
public class Gamma extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);

/**
* constructor with 1 argument.
*/
public Gamma(Expression arg) {
super("gamma", arg);
}

/** constructor for withChildren and reuse signature */
private Gamma(ScalarFunctionParams functionParams) {
super(functionParams);
}

/**
* withChildren.
*/
@Override
public Gamma withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Gamma(getFunctionParams(children));
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitGamma(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.FromIso8601Date;
import org.apache.doris.nereids.trees.expressions.functions.scalar.FromUnixtime;
import org.apache.doris.nereids.trees.expressions.functions.scalar.G;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Gamma;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Gcd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetFormat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.GetVariantType;
Expand Down Expand Up @@ -1401,6 +1402,10 @@ default R visitG(G g, C context) {
return visitScalarFunction(g, context);
}

default R visitGamma(Gamma gamma, C context) {
return visitScalarFunction(gamma, context);
}

default R visitGcd(Gcd gcd, C context) {
return visitScalarFunction(gcd, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_const --
24 1.772453850905516 1

-- !select_null --
\N

-- !select_zero --
NaN

-- !select_neg --
NaN

-- !select_const_no_fold --
24 1.772453850905516 1

-- !select_null_no_fold --
\N

-- !select_zero_no_fold --
NaN

-- !select_neg_no_fold --
NaN

-- !select_table --
1 1 1
2 2 1
3 3 2
4 4 6
5 5 24
6 0.5 1.772453850905516
7 0 Infinity
8 -1 NaN
9 \N \N

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
suite("test_gamma") {
sql "set debug_skip_fold_constant=false"
qt_select_const "select gamma(5), gamma(0.5), gamma(1)"
qt_select_null "select gamma(null)"
qt_select_zero "select gamma(0)"
qt_select_neg "select gamma(-1)"

sql "set debug_skip_fold_constant=true"
qt_select_const_no_fold "select gamma(5), gamma(0.5), gamma(1)"
qt_select_null_no_fold "select gamma(null)"
qt_select_zero_no_fold "select gamma(0)"
qt_select_neg_no_fold "select gamma(-1)"

sql "drop table if exists test_gamma_tbl"
sql """
create table test_gamma_tbl (
k1 int,
v1 double
) distributed by hash(k1) properties("replication_num" = "1");
"""

sql """
insert into test_gamma_tbl values
(1, 1.0),
(2, 2.0),
(3, 3.0),
(4, 4.0),
(5, 5.0),
(6, 0.5),
(7, 0.0),
(8, -1.0),
(9, null);
"""

qt_select_table "select k1, v1, gamma(v1) from test_gamma_tbl order by k1"

sql "drop table test_gamma_tbl"
}