技术内幕|StarRocks 标量函数与聚合函数

本文发表于: &{ new Date(1680019200000).toLocaleDateString() }

StarRocks 函数就像预设于数据库中的公式,允许用户调用现有的函数以完成特定功能。函数可以很方便地实现业务逻辑的重用,因此正确使用函数会让读者在编写 SQL 语句时起到事半功倍的效果。

StarRocks 提供了多种内置函数,包括标量函数、聚合函数、窗口函数、Table 函数和 Lambda 函数等,可帮助用户更加便捷地处理表中的数据。此外,StarRocks 还允许用户自定义函数以适应实际的业务操作。本文将以标量函数和聚合函数为例,介绍 StarRocks 常见的两种函数实现原理,希望读者能够借鉴其设计思路,并按需实现所需的函数。同时,我们也欢迎社区小伙伴一起贡献力量,共同完善 StarRocks 的功能,具体的函数任务认领方式请见文末。

 

01 如何为 StarRocks 添加标量函数

1-1 标量函数介绍

标量函数用于处理单行数据,接受一个或多个参数作为输入,并返回一个值作为结果。StarRocks 常见的标量函数有 abs、floor、ceil 等。

1-2 标量函数的实现原理

首先,我们来了解函数签名,函数签名用来唯一标识函数,描述函数的 ID、名字、返回类型、输入参数的类型等基本信息。

标量函数的函数签名定义在 gensrc/script/functions.py,在编译阶段我们会根据 Python 文件中的内容生成对应的 Java 和 C++ 代码,供 FE 和 BE 使用。每个函数签名在 Python 文件中通过一个特定的数组来描述,数组的内容有如下两种格式:

[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>]
or
[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>, <be_prepare_function>, <be_close_function>]

其中基本信息如下:

  • function_id:函数唯一标识,是唯一一串数字,function_id 遵循如下约定,前两位表示 function_type,中间两位表示 function_group,余下的表示具体的 sub_function,后面我们会举例说明
  • function_name:函数名称
  • return_type:返回值类型
  • arg_type:入参类型,如果有多个入参,需要在数组中描述每个入参的类型
  • be_scalar_function:BE 中负责实现该函数计算逻辑的函数
  • be_prepare_function/be_close_function:可选参数,有些函数在执行的过程中可能会传递一些状态,be_prepare_function 和 be_close_function 就是 BE 中负责实现创建状态和回收状态的函数

为了支持多种数据类型作为输入,需要为每种类型单独创建函数签名。以下以 abs 函数为例,该函数用于计算绝对值,需要描述以下五个信息:

  • function_id:10代表它们都属于 math function,04代表它们都属于 abs 这个 function group,余下的数字用来区分具体的 sub-function
  • function_name:函数名称都是 abs
  • return_type:返回值类型,同入参类型一致
  • arg_type:该函数只接受一个入参,所以第四项的数组中只有一个元素。
  • be_eval_function:BE 中实现计算逻辑的函数,StarRocks 针对每种数据类型做了特殊处理,所以每个签名中的函数名也不一样

对于 abs 函数而言,由于不需要传递状态,因此不需要 be_prepare_function 和 be_close_function 这两个选项。请注意,这两个选项在某些情况下可能会用到,具体用法将在后面的示例中介绍。

[10040, "abs", "DOUBLE", ["DOUBLE"], "MathFunctions::abs_double"],
    [10041, "abs", "FLOAT", ["FLOAT"], "MathFunctions::abs_float"],
    [10042, "abs", "LARGEINT", ["LARGEINT"], "MathFunctions::abs_largeint"],
    [10043, "abs", "LARGEINT", ["BIGINT"], "MathFunctions::abs_bigint"],
    [10044, "abs", "BIGINT", ["INT"], "MathFunctions::abs_int"],
    [10045, "abs", "INT", ["SMALLINT"], "MathFunctions::abs_smallint"],
    [10046, "abs", "SMALLINT", ["TINYINT"], "MathFunctions::abs_tinyint"],
    [10047, "abs", "DECIMALV2", ["DECIMALV2"], "MathFunctions::abs_decimalv2val"],
    [100470, "abs", "DECIMAL32", ["DECIMAL32"], "MathFunctions::abs_decimal32"],
    [100471, "abs", "DECIMAL64", ["DECIMAL64"], "MathFunctions::abs_decimal64"],
    [100472, "abs", "DECIMAL128", ["DECIMAL128"], "MathFunctions::abs_decimal128"],

在 StarRocks 的编译和执行阶段,都会使用函数签名来确定函数的输入输出和执行逻辑。具体流程如下:

  1. 在编译阶段,根据 gensrc/script/functions.py 中的内容生成代码供 FE 和 BE 使用。
  2. Java 代码在fe/fe-core/target/generated-sources/build/com/starrocks/builtins/VectorizedBuiltinFunctions.java,FunctionSet[1]保存了所有的函数签名,初始化阶段会调用VectorizedBuiltinFunctions::initBuiltins 来添加标量函数的函数签名。SQL analyze 阶段,会利用 FunctionSet 提供的信息进行校验,如果找不到函数签名会直接返回错误,这部分实现在 ExpressionAnalyzer.Visitor [2]的 visitFunctionCall[3]方法中。
  • C++ 代码在./gensrc/build/gen_C++/opcode/builtin_functions.cpp,BE 标量函数的函数签名保存在 BuiltinFunctions::_fn_tables[4] ,生成的代码用于初始化_fn_tables。在 SQL 执行阶段,VectorizedFunctionCallExpr 会根据 fid(函数唯一标识)从 _fn_tables 中找到执行该函数所需要的信息,包括输入参数的个数,执行函数的函数指针(ScalarFunction),以及执行前后的 PrepareFunction 和 CloseFunction,这部分定义在 FunctionDescriptor[5]

在 BE 实现函数的计算逻辑

这部分此处不做赘述,根据函数的功能实现相关的逻辑即可。

1-3 添加标量函数示例

接下来我们以 sha2 函数为例,介绍引入新函数的具体流程。sha2 函数的功能如下图,其详细信息可以参考官方文档[6]中的介绍。


 


生成函数签名

首先,需要在 gensrc/script/functions.py 中新增签名。

[120160, "sha2", "VARCHAR", ["VARCHAR", "INT"], "EncryptionFunctions::sha2", "EncryptionFunctions::sha2_prepare", "EncryptionFunctions::sha2_close"],

如上述代码所示,sha2 函数输入需要两个参数,根据第二个参数来决定使用哪种加密算法,如果第二个参数本身是个常数,那么不需要每次执行的时候都去判断。我们可以把这部分“状态”保存起来,所以函数签名中除了前文所述的五个基本信息之外,还增加了

EncryptionFunctions::sha2_prepare 和 EncryptionFunctions::sha2_close,用来实现状态的创建和回收。

实现函数的计算逻辑
 

sha2 属于加密函数的一种,所以我们直接在 EncryptionFunctions [7] 中增加相应的方法即可。具体代码如下:

/*
     * Called by sha2 to the corresponding part
     */
    DEFINE_VECTORIZED_FN(sha224);
    DEFINE_VECTORIZED_FN(sha256);
    DEFINE_VECTORIZED_FN(sha384);
    DEFINE_VECTORIZED_FN(sha512);
    DEFINE_VECTORIZED_FN(invalid_sha);
    /**
     * @param: [json_string, tagged_value]
     * @paramType: [BinaryColumn, BinaryColumn]
     * @return: Int32Column
     */
    DEFINE_VECTORIZED_FN(sha2);
    static Status sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);
    static Status sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);

其中,实现标量函数的计算逻辑主要分布在 PrepareFuntionScalarFunctionCloseFunction 三个函数中。

PrepareFunction

Prepare 阶段主要是针对第二个参数进行特殊处理,如果是常数,可以把实现对应加密算法的函数指针保存起来,后面的 ScalarFunction 中可以直接调用。加密算法的函数指针保存在 EncryptionFunctions::SHA2Ctx 中,通过 FunctionContext::set_function_state 保存在上下文中。具体代码如下:

Status EncryptionFunctions::sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::FRAGMENT_LOCAL) {
        return Status::OK();
    }

    if (!context->is_notnull_constant_column(1)) {
        return Status::OK();
    }

    ColumnPtr column = context->get_constant_column(1);
    auto hash_length = ColumnHelper::get_const_value<TYPE_INT>(column);

    ScalarFunction function;
    if (hash_length == 224) {
        function = &EncryptionFunctions::sha224;
    } else if (hash_length == 256 || hash_length == 0) {
        function = &EncryptionFunctions::sha256;
    } else if (hash_length == 384) {
        function = &EncryptionFunctions::sha384;
    } else if (hash_length == 512) {
        function = &EncryptionFunctions::sha512;
    } else {
        function = EncryptionFunctions::invalid_sha;
    }

    auto fc = new EncryptionFunctions::SHA2Ctx();
    fc->function = function;
    context->set_function_state(scope, fc);
    return Status::OK();
}

ScalarFunction

ScalarFunction 主要实现 sha2 的计算逻辑,如果第二个参数是常数,那么 PrepareFunction 中保存的 function_state 就可以派上用场了。具体代码如下:

StatusOr<ColumnPtr> EncryptionFunctions::sha2(FunctionContext* ctx, const Columns& columns) {
    if (!ctx->is_notnull_constant_column(1)) {
        auto src_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
        auto length_viewer = ColumnViewer<TYPE_INT>(columns[1]);

        auto size = columns[0]->size();
        ColumnBuilder<TYPE_VARCHAR> result(size);

        for (int row = 0; row < size; row++) {
            if (src_viewer.is_null(row) || length_viewer.is_null(row)) {
                result.append_null();
                continue;
            }

            auto src_value = src_viewer.value(row);
            auto length = length_viewer.value(row);

            if (length == 224) {
                SHA224Digest digest;
                digest.update(src_value.data, src_value.size);
                digest.digest();
                result.append(Slice(digest.hex().c_str(), digest.hex().size()));
            } else if (length == 0 || length == 256) {
                SHA256Digest digest;
                digest.update(src_value.data, src_value.size);
                digest.digest();
                result.append(Slice(digest.hex().c_str(), digest.hex().size()));
            } else if (length == 384) {
                SHA384Digest digest;
                digest.update(src_value.data, src_value.size);
                digest.digest();
                result.append(Slice(digest.hex().c_str(), digest.hex().size()));
            } else if (length == 512) {
                SHA512Digest digest;
                digest.update(src_value.data, src_value.size);
                digest.digest();
                result.append(Slice(digest.hex().c_str(), digest.hex().size()));
            } else {
                result.append_null();
            }
        }

        return result.build(ColumnHelper::is_all_const(columns));
    }

    auto ctc = reinterpret_cast<SHA2Ctx*>(ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
    return ctc->function(ctx, columns);
}

CloseFunction

CloseFunction 主要用来回收资源。函数执行中所依赖的 function state,在执行结束之后不再被需要,那么可以在这个阶段释放内存。具体代码如下:

Status EncryptionFunctions::sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
    if (scope == FunctionContext::FRAGMENT_LOCAL) {
        auto fc = reinterpret_cast<SHA2Ctx*>(context->get_function_state(scope));
        delete fc;
    }

    return Status::OK();
}
  • 增加对应的单元测试
     

具体细节可参考 EntryptionFunctionTest [8] 即可。代码示例如下:

TEST_P(ShaTestFixture, test_sha2) {
    auto [str, len, expected] = GetParam();

    std::unique_ptr<FunctionContext> ctx(FunctionContext::create_test_context());
    Columns columns;

    auto plain = BinaryColumn::create();
    plain->append(str);

    ColumnPtr hash_length =
            len == -1 ? ColumnHelper::create_const_null_column(1) : ColumnHelper::create_const_column<TYPE_INT>(len, 1);

    if (str == "NULL") {
        columns.emplace_back(ColumnHelper::create_const_null_column(1));
    } else {
        columns.emplace_back(plain);
    }
    columns.emplace_back(hash_length);

    ctx->set_constant_columns(columns);
    ASSERT_TRUE(EncryptionFunctions::sha2_prepare(ctx.get(), FunctionContext::FunctionStateScope::FRAGMENT_LOCAL).ok());

    if (len != -1) {
        ASSERT_NE(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
    } else {
        ASSERT_EQ(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
    }

    ColumnPtr result = EncryptionFunctions::sha2(ctx.get(), columns).value();
    if (expected == "NULL") {
        std::cerr << result->debug_string() << std::endl;
        EXPECT_TRUE(result->is_null(0));
    } else {
        auto v = ColumnHelper::cast_to<TYPE_VARCHAR>(result);
        EXPECT_EQ(expected, v->get_data()[0].to_string());
    }

    ASSERT_TRUE(EncryptionFunctions::sha2_close(ctx.get(),
                                                FunctionContext::FunctionContext::FunctionStateScope::FRAGMENT_LOCAL)
                        .ok());
}

完整的改动可以参考 PR:https://github.com/StarRocks/starrocks/pull/1264/files

 

02 如何为 StarRocks 添加聚合函数

2-1 聚合函数介绍

聚合函数用于处理多行数据,接受多行数据作为输入,经过计算后返回一行结果。StarRocks 常见的聚合函数有 count、sum、avg、min、max 等。

2-2 聚合函数的实现原理

在查询执行阶段,Pipeline 引擎的聚合算子通过 Aggregator 完成聚合计算,聚合算子的实现原理可参见文末《StarRocks 聚合算子源码解析》[9],本文主要关注聚合函数的实现原理。

Aggregator 在 prepare 阶段会根据函数名找到对应的 AggregateFunction 并保存下来,AggregateFunction 是最重要的抽象,封装了聚合计算过程中需要的各个接口,每个聚合函数都需要继承 AggregateFunction 实现自己的逻辑。计算的中间结果保存在 AggDataPtr 中,AggDataPrt 是一个指针,指向描述中间结果的数据结构。每种聚合函数的中间结果都不相同,比如求和函数,只需要保存 sum  即可,而平均值函数,除了保存 sum 之外,还需要记录 count。

AggregateFunction提供的接口中,我们需要重点关注以下几个:

// 逐行读取数据,不断更新 state 中保存的中间结果。
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, size_t row_num)

// 通常用在多阶段聚合中,读取已经算好的部分中间结果,合并计算,更新 state 中的数据。
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num)

// 多阶段的聚合可能会通过多个节点执行,计算的中间结果需要跨网络传输,这个方法用来实现序列化的逻辑。
void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) 

// 把中间结果转成最终对用户返回的结果。比如求和函数,直接返回中间结果保存的 sum 即可,而平均值函数,需要返回 sum/count。
void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)

// 重置 state 的状态,比如在 window aggregate 中,我们会用一个的 state 保存中间结果,每次遇到新的 group时,需要通过 reset 重置,然后才能进行接下来的计算。
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state)

除了上述内容之外,为了减少函数调用的开销,AggregateFunction 还封装了批量操作的接口,具体的细节这里就不展开讲解了,可以参考 be/src/exprs/agg/aggregate.h

2-3 添加聚合函数示例

接下来我们以 ANY_VALUE 为例,介绍添加聚合函数的流程,这个函数实现的功能比较简单,可以参考官方文档 [10] 说明:


 

在 FE 创建函数签名

FE 通过

AggregateFunction [11] 来描述聚合函数,所有的聚合函数都会注册在 FunctionSet 中,初始化阶段在 FunctionSet 的 initAggregateBuiltins [12] 内增加对应的函数即可。具体代码如下:

// ANY_VALUE
    addBuiltin(AggregateFunction.createBuiltin(ANY_VALUE,
            Lists.newArrayList(t), t, t, true, false, false));

在 BE 实现函数的计算逻辑

此处重点是如何描述中间结果,以及如何实现 AggregateFunction 的核心接口。ANY_VALUE 的语义很简单,在每个 group 中选择一行返回。中间结果通过 AnyValueAggregateData 描述,只需要记录当前是否已经有结果以及对应的数据是什么即可,AnyValueAggregateData 为每种数据类型进行了特化,实现上几乎一致。具体代码如下:

template <LogicalType LT>
struct AnyValueAggregateData {
    using T = AggDataValueType<LT>;

    T result;
    bool has_value = false;

    void reset() {
        result = T{};
        has_value = false;
    }
};

具体的计算逻辑非常简单,这部分通过 AnyValueElement 实现。具体代码如下:

template <LogicalType LT, typename State>
struct AnyValueElement {
    using RefType = AggDataRefType<LT>;

    void operator()(State& state, RefType right) const {
        if (UNLIKELY(!state.has_value)) {
            AggDataTypeTraits<LT>::assign_value(state.result, right);
            state.has_value = true;
        }
    }
};

最后利用 AnyValueElement 实现 AggregateFunction 所需要的接口即可,具体代码如下:

template <LogicalType LT, typename State, class OP, typename T = RunTimeC++Type<LT>, typename = guard::Guard>
class AnyValueAggregateFunction final
        : public AggregateFunctionBatchHelper<State, AnyValueAggregateFunction<LT, State, OP, T>> {
public:
    using InputColumnType = RunTimeColumnType<LT>;

    void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override {
        this->data(state).reset();
    }

    void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
                size_t row_num) const override {
        DCHECK(!columns[0]->is_nullable());
        const auto& column = down_cast<const InputColumnType&>(*columns[0]);
        OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(column, row_num));
    }

    void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
                                   AggDataPtr __restrict state) const override {
        update(ctx, columns, state, 0);
    }

    void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
        DCHECK(!column->is_nullable());
        const auto& input_column = down_cast<const InputColumnType&>(*column);
        OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(input_column, row_num));
    }

    void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
        DCHECK(!to->is_nullable());
        AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);
    }

    void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,
                                     ColumnPtr* dst) const override {
        *dst = src[0];
    }

    void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
        DCHECK(!to->is_nullable());
        AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);
    }

    void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,
                    size_t end) const override {
        DCHECK_GT(end, start);
        InputColumnType* column = down_cast<InputColumnType*>(dst);
        for (size_t i = start; i < end; ++i) {
            AggDataTypeTraits<LT>::append_value(column, this->data(state).result);
        }
    }

    std::string get_name() const override { return "any_value"; }
};

完整的实现细节参见:be/src/exprs/agg/any_value.h

在 AggregateFactory 中注册

这一步是为了让 AggregateFactory 可以根据函数名找到对应的函数,函数的创建通过MakeAnyValueAggregateFunction 实现,相关的改动可以在aggregate_factory.hpp[13]中 grep MakeAnyValueAggregateFunction 看到,比较简单,这里不再过多赘述,具体示例如下:

template <LogicalType LT>
AggregateFunctionPtr AggregateFactory::MakeAnyValueAggregateFunction() {
    return std::make_shared<
            AnyValueAggregateFunction<LT, AnyValueAggregateData<LT>, AnyValueElement<LT, AnyValueAggregateData<LT>>>>();
}

添加单元测试
可以参见 test/exprs/agg/aggregate_test.cpp [14]  添加单测,比如:

TEST_F(AggregateTest, test_any_value) {
    const AggregateFunction* func = get_aggregate_function("any_value", TYPE_SMALLINT, TYPE_SMALLINT, false);
    test_non_deterministic_agg_function<int16_t, int16_t>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_INT, TYPE_INT, false);
    test_non_deterministic_agg_function<int32_t, int32_t>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_BIGINT, TYPE_BIGINT, false);
    test_non_deterministic_agg_function<int64_t, int64_t>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_LARGEINT, TYPE_LARGEINT, false);
    test_non_deterministic_agg_function<int128_t, int128_t>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_FLOAT, TYPE_FLOAT, false);
    test_non_deterministic_agg_function<float, float>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_DOUBLE, TYPE_DOUBLE, false);
    test_non_deterministic_agg_function<double, double>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_VARCHAR, TYPE_VARCHAR, false);
    test_non_deterministic_agg_function<Slice, Slice>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_DECIMALV2, TYPE_DECIMALV2, false);
    test_non_deterministic_agg_function<DecimalV2Value, DecimalV2Value>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_DATETIME, TYPE_DATETIME, false);
    test_non_deterministic_agg_function<TimestampValue, TimestampValue>(ctx, func);

    func = get_aggregate_function("any_value", TYPE_DATE, TYPE_DATE, false);
    test_non_deterministic_agg_function<DateValue, DateValue>(ctx, func);
}

完整的改动见 PR:https://github.com/StarRocks/starrocks/pull/2073

03 总结

本文介绍了 StarRocks 中标量函数和聚合函数的实现原理,并以 sha2 标量函数和 ANY_VALUE 聚合函数为例,说明了如何添加标量函数和新增聚合函数。标量函数定义在 be/src/exprs/ 目录下。若想查看某个函数的实现,可以在函数签名中找到对应的 be function,然后在该目录下使用 grep 进行查找。此外,StarRocks 还实现了多种聚合函数,具体实现可在 be/src/exprs/agg 目录下查找。最后,如果你在阅读完本文后对 StarRocks 函数的实现原理以及如何添加新的函数还有很多疑问,欢迎报名参加 4/6(星期四)的 <StarRocks 源码实验室直播>,以进一步学习。同时,我们也欢迎你领取函数任务,并通过实践学习如何为 StarRocks 添加新的函数!