技术内幕|StarRocks 标量函数与聚合函数
本文发表于: &{ new Date(1680019200000).toLocaleDateString() }
StarRocks 函数就像预设于数据库中的公式,允许用户调用现有的函数以完成特定功能。函数可以很方便地实现业务逻辑的重用,因此正确使用函数会让读者在编写 SQL 语句时起到事半功倍的效果。
StarRocks 提供了多种内置函数,包括标量函数、聚合函数、窗口函数、Table 函数和 Lambda 函数等,可帮助用户更加便捷地处理表中的数据。此外,StarRocks 还允许用户自定义函数以适应实际的业务操作。本文将以标量函数和聚合函数为例,介绍 StarRocks 常见的两种函数实现原理,希望读者能够借鉴其设计思路,并按需实现所需的函数。同时,我们也欢迎社区小伙伴一起贡献力量,共同完善 StarRocks 的功能,具体的函数任务认领方式请见文末。
01 如何为 StarRocks 添加标量函数
1-1 标量函数介绍
标量函数用于处理单行数据,接受一个或多个参数作为输入,并返回一个值作为结果。StarRocks 常见的标量函数有 abs、floor、ceil 等。
1-2 标量函数的实现原理
首先,我们来了解函数签名,函数签名用来唯一标识函数,描述函数的 ID、名字、返回类型、输入参数的类型等基本信息。
标量函数的函数签名定义在 gensrc/script/functions.py,在编译阶段我们会根据 Python 文件中的内容生成对应的 Java 和 C++ 代码,供 FE 和 BE 使用。每个函数签名在 Python 文件中通过一个特定的数组来描述,数组的内容有如下两种格式:
[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>]
or
[<function_id>, <function_name>, <return_type>, [<arg_type>...], <be_scalar_function>, <be_prepare_function>, <be_close_function>]
其中基本信息如下:
- function_id:函数唯一标识,是唯一一串数字,function_id 遵循如下约定,前两位表示 function_type,中间两位表示 function_group,余下的表示具体的 sub_function,后面我们会举例说明
- function_name:函数名称
- return_type:返回值类型
- arg_type:入参类型,如果有多个入参,需要在数组中描述每个入参的类型
- be_scalar_function:BE 中负责实现该函数计算逻辑的函数
- be_prepare_function/be_close_function:可选参数,有些函数在执行的过程中可能会传递一些状态,be_prepare_function 和 be_close_function 就是 BE 中负责实现创建状态和回收状态的函数
为了支持多种数据类型作为输入,需要为每种类型单独创建函数签名。以下以 abs 函数为例,该函数用于计算绝对值,需要描述以下五个信息:
- function_id:10代表它们都属于 math function,04代表它们都属于 abs 这个 function group,余下的数字用来区分具体的 sub-function
- function_name:函数名称都是 abs
- return_type:返回值类型,同入参类型一致
- arg_type:该函数只接受一个入参,所以第四项的数组中只有一个元素。
- be_eval_function:BE 中实现计算逻辑的函数,StarRocks 针对每种数据类型做了特殊处理,所以每个签名中的函数名也不一样
对于 abs 函数而言,由于不需要传递状态,因此不需要 be_prepare_function 和 be_close_function 这两个选项。请注意,这两个选项在某些情况下可能会用到,具体用法将在后面的示例中介绍。
[10040, "abs", "DOUBLE", ["DOUBLE"], "MathFunctions::abs_double"],
[10041, "abs", "FLOAT", ["FLOAT"], "MathFunctions::abs_float"],
[10042, "abs", "LARGEINT", ["LARGEINT"], "MathFunctions::abs_largeint"],
[10043, "abs", "LARGEINT", ["BIGINT"], "MathFunctions::abs_bigint"],
[10044, "abs", "BIGINT", ["INT"], "MathFunctions::abs_int"],
[10045, "abs", "INT", ["SMALLINT"], "MathFunctions::abs_smallint"],
[10046, "abs", "SMALLINT", ["TINYINT"], "MathFunctions::abs_tinyint"],
[10047, "abs", "DECIMALV2", ["DECIMALV2"], "MathFunctions::abs_decimalv2val"],
[100470, "abs", "DECIMAL32", ["DECIMAL32"], "MathFunctions::abs_decimal32"],
[100471, "abs", "DECIMAL64", ["DECIMAL64"], "MathFunctions::abs_decimal64"],
[100472, "abs", "DECIMAL128", ["DECIMAL128"], "MathFunctions::abs_decimal128"],
在 StarRocks 的编译和执行阶段,都会使用函数签名来确定函数的输入输出和执行逻辑。具体流程如下:
- 在编译阶段,根据 gensrc/script/functions.py 中的内容生成代码供 FE 和 BE 使用。
- Java 代码在fe/fe-core/target/generated-sources/build/com/starrocks/builtins/VectorizedBuiltinFunctions.java,FunctionSet[1]保存了所有的函数签名,初始化阶段会调用VectorizedBuiltinFunctions::initBuiltins 来添加标量函数的函数签名。SQL analyze 阶段,会利用 FunctionSet 提供的信息进行校验,如果找不到函数签名会直接返回错误,这部分实现在 ExpressionAnalyzer.Visitor [2]的 visitFunctionCall[3]方法中。
- C++ 代码在./gensrc/build/gen_C++/opcode/builtin_functions.cpp,BE 标量函数的函数签名保存在 BuiltinFunctions::_fn_tables[4] ,生成的代码用于初始化_fn_tables。在 SQL 执行阶段,VectorizedFunctionCallExpr 会根据 fid(函数唯一标识)从 _fn_tables 中找到执行该函数所需要的信息,包括输入参数的个数,执行函数的函数指针(ScalarFunction),以及执行前后的 PrepareFunction 和 CloseFunction,这部分定义在 FunctionDescriptor[5]
在 BE 实现函数的计算逻辑
这部分此处不做赘述,根据函数的功能实现相关的逻辑即可。
1-3 添加标量函数示例
接下来我们以 sha2 函数为例,介绍引入新函数的具体流程。sha2 函数的功能如下图,其详细信息可以参考官方文档[6]中的介绍。
生成函数签名
首先,需要在 gensrc/script/functions.py 中新增签名。
[120160, "sha2", "VARCHAR", ["VARCHAR", "INT"], "EncryptionFunctions::sha2", "EncryptionFunctions::sha2_prepare", "EncryptionFunctions::sha2_close"],
如上述代码所示,sha2 函数输入需要两个参数,根据第二个参数来决定使用哪种加密算法,如果第二个参数本身是个常数,那么不需要每次执行的时候都去判断。我们可以把这部分“状态”保存起来,所以函数签名中除了前文所述的五个基本信息之外,还增加了
EncryptionFunctions::sha2_prepare 和 EncryptionFunctions::sha2_close,用来实现状态的创建和回收。
实现函数的计算逻辑
sha2 属于加密函数的一种,所以我们直接在 EncryptionFunctions [7] 中增加相应的方法即可。具体代码如下:
/*
* Called by sha2 to the corresponding part
*/
DEFINE_VECTORIZED_FN(sha224);
DEFINE_VECTORIZED_FN(sha256);
DEFINE_VECTORIZED_FN(sha384);
DEFINE_VECTORIZED_FN(sha512);
DEFINE_VECTORIZED_FN(invalid_sha);
/**
* @param: [json_string, tagged_value]
* @paramType: [BinaryColumn, BinaryColumn]
* @return: Int32Column
*/
DEFINE_VECTORIZED_FN(sha2);
static Status sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope);
static Status sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope);
其中,实现标量函数的计算逻辑主要分布在 PrepareFuntion、ScalarFunction、CloseFunction 三个函数中。
PrepareFunction
Prepare 阶段主要是针对第二个参数进行特殊处理,如果是常数,可以把实现对应加密算法的函数指针保存起来,后面的 ScalarFunction 中可以直接调用。加密算法的函数指针保存在 EncryptionFunctions::SHA2Ctx 中,通过 FunctionContext::set_function_state 保存在上下文中。具体代码如下:
Status EncryptionFunctions::sha2_prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
if (scope != FunctionContext::FRAGMENT_LOCAL) {
return Status::OK();
}
if (!context->is_notnull_constant_column(1)) {
return Status::OK();
}
ColumnPtr column = context->get_constant_column(1);
auto hash_length = ColumnHelper::get_const_value<TYPE_INT>(column);
ScalarFunction function;
if (hash_length == 224) {
function = &EncryptionFunctions::sha224;
} else if (hash_length == 256 || hash_length == 0) {
function = &EncryptionFunctions::sha256;
} else if (hash_length == 384) {
function = &EncryptionFunctions::sha384;
} else if (hash_length == 512) {
function = &EncryptionFunctions::sha512;
} else {
function = EncryptionFunctions::invalid_sha;
}
auto fc = new EncryptionFunctions::SHA2Ctx();
fc->function = function;
context->set_function_state(scope, fc);
return Status::OK();
}
ScalarFunction
ScalarFunction 主要实现 sha2 的计算逻辑,如果第二个参数是常数,那么 PrepareFunction 中保存的 function_state 就可以派上用场了。具体代码如下:
StatusOr<ColumnPtr> EncryptionFunctions::sha2(FunctionContext* ctx, const Columns& columns) {
if (!ctx->is_notnull_constant_column(1)) {
auto src_viewer = ColumnViewer<TYPE_VARCHAR>(columns[0]);
auto length_viewer = ColumnViewer<TYPE_INT>(columns[1]);
auto size = columns[0]->size();
ColumnBuilder<TYPE_VARCHAR> result(size);
for (int row = 0; row < size; row++) {
if (src_viewer.is_null(row) || length_viewer.is_null(row)) {
result.append_null();
continue;
}
auto src_value = src_viewer.value(row);
auto length = length_viewer.value(row);
if (length == 224) {
SHA224Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 0 || length == 256) {
SHA256Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 384) {
SHA384Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else if (length == 512) {
SHA512Digest digest;
digest.update(src_value.data, src_value.size);
digest.digest();
result.append(Slice(digest.hex().c_str(), digest.hex().size()));
} else {
result.append_null();
}
}
return result.build(ColumnHelper::is_all_const(columns));
}
auto ctc = reinterpret_cast<SHA2Ctx*>(ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
return ctc->function(ctx, columns);
}
CloseFunction
CloseFunction 主要用来回收资源。函数执行中所依赖的 function state,在执行结束之后不再被需要,那么可以在这个阶段释放内存。具体代码如下:
Status EncryptionFunctions::sha2_close(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
if (scope == FunctionContext::FRAGMENT_LOCAL) {
auto fc = reinterpret_cast<SHA2Ctx*>(context->get_function_state(scope));
delete fc;
}
return Status::OK();
}
- 增加对应的单元测试
具体细节可参考 EntryptionFunctionTest [8] 即可。代码示例如下:
TEST_P(ShaTestFixture, test_sha2) {
auto [str, len, expected] = GetParam();
std::unique_ptr<FunctionContext> ctx(FunctionContext::create_test_context());
Columns columns;
auto plain = BinaryColumn::create();
plain->append(str);
ColumnPtr hash_length =
len == -1 ? ColumnHelper::create_const_null_column(1) : ColumnHelper::create_const_column<TYPE_INT>(len, 1);
if (str == "NULL") {
columns.emplace_back(ColumnHelper::create_const_null_column(1));
} else {
columns.emplace_back(plain);
}
columns.emplace_back(hash_length);
ctx->set_constant_columns(columns);
ASSERT_TRUE(EncryptionFunctions::sha2_prepare(ctx.get(), FunctionContext::FunctionStateScope::FRAGMENT_LOCAL).ok());
if (len != -1) {
ASSERT_NE(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
} else {
ASSERT_EQ(nullptr, ctx->get_function_state(FunctionContext::FRAGMENT_LOCAL));
}
ColumnPtr result = EncryptionFunctions::sha2(ctx.get(), columns).value();
if (expected == "NULL") {
std::cerr << result->debug_string() << std::endl;
EXPECT_TRUE(result->is_null(0));
} else {
auto v = ColumnHelper::cast_to<TYPE_VARCHAR>(result);
EXPECT_EQ(expected, v->get_data()[0].to_string());
}
ASSERT_TRUE(EncryptionFunctions::sha2_close(ctx.get(),
FunctionContext::FunctionContext::FunctionStateScope::FRAGMENT_LOCAL)
.ok());
}
完整的改动可以参考 PR:https://github.com/StarRocks/starrocks/pull/1264/files
02 如何为 StarRocks 添加聚合函数
2-1 聚合函数介绍
聚合函数用于处理多行数据,接受多行数据作为输入,经过计算后返回一行结果。StarRocks 常见的聚合函数有 count、sum、avg、min、max 等。
2-2 聚合函数的实现原理
在查询执行阶段,Pipeline 引擎的聚合算子通过 Aggregator 完成聚合计算,聚合算子的实现原理可参见文末《StarRocks 聚合算子源码解析》[9],本文主要关注聚合函数的实现原理。
Aggregator 在 prepare 阶段会根据函数名找到对应的 AggregateFunction 并保存下来,AggregateFunction 是最重要的抽象,封装了聚合计算过程中需要的各个接口,每个聚合函数都需要继承 AggregateFunction 实现自己的逻辑。计算的中间结果保存在 AggDataPtr 中,AggDataPrt 是一个指针,指向描述中间结果的数据结构。每种聚合函数的中间结果都不相同,比如求和函数,只需要保存 sum 即可,而平均值函数,除了保存 sum 之外,还需要记录 count。
在AggregateFunction提供的接口中,我们需要重点关注以下几个:
// 逐行读取数据,不断更新 state 中保存的中间结果。
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, size_t row_num)
// 通常用在多阶段聚合中,读取已经算好的部分中间结果,合并计算,更新 state 中的数据。
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num)
// 多阶段的聚合可能会通过多个节点执行,计算的中间结果需要跨网络传输,这个方法用来实现序列化的逻辑。
void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)
// 把中间结果转成最终对用户返回的结果。比如求和函数,直接返回中间结果保存的 sum 即可,而平均值函数,需要返回 sum/count。
void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to)
// 重置 state 的状态,比如在 window aggregate 中,我们会用一个的 state 保存中间结果,每次遇到新的 group时,需要通过 reset 重置,然后才能进行接下来的计算。
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state)
除了上述内容之外,为了减少函数调用的开销,AggregateFunction 还封装了批量操作的接口,具体的细节这里就不展开讲解了,可以参考 be/src/exprs/agg/aggregate.h。
2-3 添加聚合函数示例
接下来我们以 ANY_VALUE 为例,介绍添加聚合函数的流程,这个函数实现的功能比较简单,可以参考官方文档 [10] 说明:
在 FE 创建函数签名
FE 通过
AggregateFunction [11] 来描述聚合函数,所有的聚合函数都会注册在 FunctionSet 中,初始化阶段在 FunctionSet 的 initAggregateBuiltins [12] 内增加对应的函数即可。具体代码如下:
// ANY_VALUE
addBuiltin(AggregateFunction.createBuiltin(ANY_VALUE,
Lists.newArrayList(t), t, t, true, false, false));
在 BE 实现函数的计算逻辑
此处重点是如何描述中间结果,以及如何实现 AggregateFunction 的核心接口。ANY_VALUE 的语义很简单,在每个 group 中选择一行返回。中间结果通过 AnyValueAggregateData 描述,只需要记录当前是否已经有结果以及对应的数据是什么即可,AnyValueAggregateData 为每种数据类型进行了特化,实现上几乎一致。具体代码如下:
template <LogicalType LT>
struct AnyValueAggregateData {
using T = AggDataValueType<LT>;
T result;
bool has_value = false;
void reset() {
result = T{};
has_value = false;
}
};
具体的计算逻辑非常简单,这部分通过 AnyValueElement 实现。具体代码如下:
template <LogicalType LT, typename State>
struct AnyValueElement {
using RefType = AggDataRefType<LT>;
void operator()(State& state, RefType right) const {
if (UNLIKELY(!state.has_value)) {
AggDataTypeTraits<LT>::assign_value(state.result, right);
state.has_value = true;
}
}
};
最后利用 AnyValueElement 实现 AggregateFunction 所需要的接口即可,具体代码如下:
template <LogicalType LT, typename State, class OP, typename T = RunTimeC++Type<LT>, typename = guard::Guard>
class AnyValueAggregateFunction final
: public AggregateFunctionBatchHelper<State, AnyValueAggregateFunction<LT, State, OP, T>> {
public:
using InputColumnType = RunTimeColumnType<LT>;
void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override {
this->data(state).reset();
}
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
size_t row_num) const override {
DCHECK(!columns[0]->is_nullable());
const auto& column = down_cast<const InputColumnType&>(*columns[0]);
OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(column, row_num));
}
void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
AggDataPtr __restrict state) const override {
update(ctx, columns, state, 0);
}
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
DCHECK(!column->is_nullable());
const auto& input_column = down_cast<const InputColumnType&>(*column);
OP()(this->data(state), AggDataTypeTraits<LT>::get_row_ref(input_column, row_num));
}
void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);
}
void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size,
ColumnPtr* dst) const override {
*dst = src[0];
}
void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
AggDataTypeTraits<LT>::append_value(down_cast<InputColumnType*>(to), this->data(state).result);
}
void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,
size_t end) const override {
DCHECK_GT(end, start);
InputColumnType* column = down_cast<InputColumnType*>(dst);
for (size_t i = start; i < end; ++i) {
AggDataTypeTraits<LT>::append_value(column, this->data(state).result);
}
}
std::string get_name() const override { return "any_value"; }
};
完整的实现细节参见:be/src/exprs/agg/any_value.h
在 AggregateFactory 中注册
这一步是为了让 AggregateFactory 可以根据函数名找到对应的函数,函数的创建通过MakeAnyValueAggregateFunction 实现,相关的改动可以在aggregate_factory.hpp[13]中 grep MakeAnyValueAggregateFunction 看到,比较简单,这里不再过多赘述,具体示例如下:
template <LogicalType LT>
AggregateFunctionPtr AggregateFactory::MakeAnyValueAggregateFunction() {
return std::make_shared<
AnyValueAggregateFunction<LT, AnyValueAggregateData<LT>, AnyValueElement<LT, AnyValueAggregateData<LT>>>>();
}
添加单元测试
可以参见 test/exprs/agg/aggregate_test.cpp [14] 添加单测,比如:
TEST_F(AggregateTest, test_any_value) {
const AggregateFunction* func = get_aggregate_function("any_value", TYPE_SMALLINT, TYPE_SMALLINT, false);
test_non_deterministic_agg_function<int16_t, int16_t>(ctx, func);
func = get_aggregate_function("any_value", TYPE_INT, TYPE_INT, false);
test_non_deterministic_agg_function<int32_t, int32_t>(ctx, func);
func = get_aggregate_function("any_value", TYPE_BIGINT, TYPE_BIGINT, false);
test_non_deterministic_agg_function<int64_t, int64_t>(ctx, func);
func = get_aggregate_function("any_value", TYPE_LARGEINT, TYPE_LARGEINT, false);
test_non_deterministic_agg_function<int128_t, int128_t>(ctx, func);
func = get_aggregate_function("any_value", TYPE_FLOAT, TYPE_FLOAT, false);
test_non_deterministic_agg_function<float, float>(ctx, func);
func = get_aggregate_function("any_value", TYPE_DOUBLE, TYPE_DOUBLE, false);
test_non_deterministic_agg_function<double, double>(ctx, func);
func = get_aggregate_function("any_value", TYPE_VARCHAR, TYPE_VARCHAR, false);
test_non_deterministic_agg_function<Slice, Slice>(ctx, func);
func = get_aggregate_function("any_value", TYPE_DECIMALV2, TYPE_DECIMALV2, false);
test_non_deterministic_agg_function<DecimalV2Value, DecimalV2Value>(ctx, func);
func = get_aggregate_function("any_value", TYPE_DATETIME, TYPE_DATETIME, false);
test_non_deterministic_agg_function<TimestampValue, TimestampValue>(ctx, func);
func = get_aggregate_function("any_value", TYPE_DATE, TYPE_DATE, false);
test_non_deterministic_agg_function<DateValue, DateValue>(ctx, func);
}
完整的改动见 PR:https://github.com/StarRocks/starrocks/pull/2073
03 总结
本文介绍了 StarRocks 中标量函数和聚合函数的实现原理,并以 sha2 标量函数和 ANY_VALUE 聚合函数为例,说明了如何添加标量函数和新增聚合函数。标量函数定义在 be/src/exprs/ 目录下。若想查看某个函数的实现,可以在函数签名中找到对应的 be function,然后在该目录下使用 grep 进行查找。此外,StarRocks 还实现了多种聚合函数,具体实现可在 be/src/exprs/agg 目录下查找。最后,如果你在阅读完本文后对 StarRocks 函数的实现原理以及如何添加新的函数还有很多疑问,欢迎报名参加 4/6(星期四)的 <StarRocks 源码实验室直播>,以进一步学习。同时,我们也欢迎你领取函数任务,并通过实践学习如何为 StarRocks 添加新的函数!