Contents

Implementing an Arrow-based SQL Query Engine with Rust - Part 1

TL;DR:为了学习查询引擎,最近从0在写 sql-query-engine-rs,它是一个用 Rust 写的 Arrow-based SQL Query Engine。本系列文章会详细介绍它的具体实现,会按照对应的 Roadmap 依次讲解,也可以 checkout 对应 tag 查看代码。Most of ideas inspired by risinglight and datafusion

这篇Part 1会详细介绍它的整体架构,以及实现基础的 SQL query: select c1 from t where c2 = 1select c1, count(c1), max(c2) from t group by c1

Roadmap v0.1

Roadmap v0.1 涉及到:

  • catalog
  • csv storage
  • parser
  • binder
  • logical planner
  • executor

它的 milestone 是实现一个基础的 SQL query: select c1 from t where c2 = 1,将 csv 中的数据查询出来。

catalog

首先构建数据库的最基础组件:catalog,它提供数据库表的元数据信息,用于后续 binder,logical planner 的构建,同时它是贯穿整个查询引擎处理流程中的组件,提供了表的元数据信息。

这里为了实现简单,没有引入 database 与 schema 的概念,RootCatalog 下直接包含一个 HashMap 存放所有的 TableCatalog。TableCatalog 则包含所有的 ColumnCatalog。如下定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
pub struct RootCatalog {
    pub tables: HashMap<TableId, TableCatalog>,
}

pub struct TableCatalog {
    pub id: TableId,
    pub name: String,
    pub column_ids: Vec<ColumnId>,
    pub columns: BTreeMap<ColumnId, ColumnCatalog>,
}

pub struct ColumnCatalog {
    pub id: ColumnId,
    pub desc: ColumnDesc,
}

pub struct ColumnDesc {
    pub name: String,
    pub data_type: DataType,
}

csv storage

由于 sql-query-engine-rs 是基于 Apache Arrow 的内存数据格式建立的查询引擎,它本身并不提供数据文件的存储与查询。因此这里的 storage 是建立 Arrow 支持文件之上抽象,并提供一个类似迭代器的方法来分批吐出 chunk 数据。

抽象的 storage trait 采用 GAT(generic associated types),即 trait 内部定义 associated type,来支持不同的 storage,table 和 transaction 类型。如下定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
pub trait Storage: Sync + Send + 'static {
    type TableType: Table;

    fn create_table(&self, id: String, filepath: String) -> Result<(), StorageError>;

    fn get_table(&self, id: String) -> Result<Self::TableType, StorageError>;

    fn get_catalog(&self) -> RootCatalog;
}

pub trait Table: Sync + Send + Clone + 'static {
    type TransactionType: Transaction;

    fn read(&self) -> Result<Self::TransactionType, StorageError>;
}

pub trait Transaction: Sync + Send + 'static {
    fn next_batch(&mut self) -> Result<Option<RecordBatch>, StorageError>;
}

CsvStorage 实现了上面 trait,每次使用时,从 storage 读出一个 table,然后通过 read 方法生成一个新的 Transaction (包含新的 csv reader) 提供用户分批查询数据。

parser

parser 层用于将用户输入的 raw SQL 转换为 AST,提供给 binder 层,进而转换为对查询引擎有意义的表达式。

raw SQL 的解析的过程和编程语言类似,包含词法与语法解析,涉及大量工作,因此,为了实现简单,这里实现直接采用了开源方案 sqlparser-rs 来作为 parser 层。

binder

基于 parser 给到的 AST 后,查询引擎需要知道一个 raw string,具体代表 table 中的哪一列,以及这一列对应的 datatype 等信息。因此,这些信息都需要在 binder 层进行转换完成。

为了实现最简单的 SQL bind,如 select c1 from t where c2 = 1。定义了最基础的 BoundSelect 如下:

1
2
3
4
5
pub struct BoundSelect {
    pub select_list: Vec<BoundExpr>,
    pub from_table: Option<BoundTableRef>,
    pub where_clause: Option<BoundExpr>,
}

Binder 将传入的 query AST 转换为 BoundSelect,因此,如果要想实现丰富的 SQL 语法支持,第一步就需要在 binder 层先定义明确。比如,上面 SQL 中的 c2 = 1,它就代表中一种 BoundExpr。

查询引擎中存在多种类型的表达式,来定义一个 SQL 语句中的某一个块,如下基础的 BoundExpr:

1
2
3
4
5
6
7
pub enum BoundExpr {
    Constant(ScalarValue),
    ColumnRef(BoundColumnRef),
    InputRef(BoundInputRef),
    BinaryOp(BoundBinaryOp),
    TypeCast(BoundTypeCast),
}

上面的 BoundExpr 依次解释为:

  • Constant:表示一个常量,比如 1,‘a’,true,false,null,等等。在 c2 = 1 中,1就是一个 constant 表达式。
  • ColumnRef:表示一个列,比如 c1c2。它内部会包括 ColumnCatalog。
  • InputRef:它是一个特殊的表达式,代表executor阶段从最终的 RecordBatch 中读取第几个 index 的数据。它会在 rewriter 阶段将所有的 ColumnRef 转换为 InputRef。
  • BinaryOp:表示一个二元运算,比如上面的 c2 = 1,它就是一个 operator 为 Eq 的 BinaryOp 表达式。它的 left 与 right 同样也是 BoundExpr。
  • TypeCast:表示一个类型转换,比如在 BinaryOp 中,左右为不同的类型,但它们可以通过 implicit cast 转换为相同类型。这时,构造 BinaryOp 时,就可以加入 TypeCast 表达式。

logical planner

到这一层,我们已经有了一个查询的所有必要信息,可以开始构建 PlanNode,来初始化 LogicalPlanNode 和 PhysicalPlanNode。比如:Project、Filter、TableScan。

declarative macro

在 PlanNode 的初始化中,会大量用到 Rust 的 declarative macro 来生成大量模板代码,即通过将一个 macro name 传入另外一个 macro,来实现宏复用。如下代码中的 impl_downcast_utility 宏,通过 for_all_plan_nodes 从外到内展开宏,实现了所有 PlanNode 的 as_xxx 方法 (如 as_logical_project)。这里也记录下idea来源

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
pub trait PlanNode: WithPlanNodeType + PlanTreeNode + Debug + Downcast {
    fn schema(&self) -> Vec<ColumnCatalog> {
        vec![]
    }
}
impl_downcast!(PlanNode);

macro_rules! for_all_plan_nodes {
    ($macro:ident) => {
        $macro! {
            Dummy,
            LogicalTableScan,
            LogicalProject,
            LogicalFilter,
            PhysicalTableScan,
            PhysicalProject,
            PhysicalFilter
        }
    };
}

macro_rules! impl_downcast_utility {
    ($($node_name:ident),*) => {
        impl dyn PlanNode {
            $(
                paste! {
                    #[allow(dead_code)]
                    pub fn [<as_ $node_name:snake>] (&self) -> std::result::Result<&$node_name, ()> {
                        self.downcast_ref::<$node_name>().ok_or_else(|| ())
                    }
                }
            )*
        }
    }
}
for_all_plan_nodes! { impl_downcast_utility }

rewriter

有了 LogicalPlanNode 和 PhysicalPlanNode 定义后,我们需要构建出它们,并组装成一个 PlanTree 来代表整个查询计划,提供 executor 执行。

PlanTree 的构建是通过一个个 PlanNode 组装起来的,除了最底层的 TableScan,每个 PlanNode 都会包含 children,来代表它下层的 PlanNode。如下 trait 定义:

1
2
3
4
5
6
7
pub trait PlanTreeNode {
    /// Get the child plan nodes.
    fn children(&self) -> Vec<PlanRef>;

    /// Clone the node with new children for rewriting plan node.
    fn clone_with_children(&self, children: Vec<PlanRef>) -> PlanRef;
}

构建 LogicalPlanNode 比较简单,只需要将 BoundSelect 中的信息,手动依次组装到 LogicalPlanTree 中即可,注意,一个 PlanTree 在代码中表示为 PlanRef,即Arc<dyn PlanNode>,因此需要一些 downcast 来获取具体的 PlanNode。

构建 PhysicalPlanNode 时,引入了 visitor pattern 来实现 PlanNode rewriter 和 BoundExpr rewriter。

首先讨论 PlanNode rewriter,如下定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
macro_rules! def_rewriter {
    ($($node_name:ident),*) => {
        pub trait PlanRewriter {
            paste! {
                fn rewrite(&mut self, plan: PlanRef) -> PlanRef {
                    match plan.node_type() {
                        $(
                            PlanNodeType::$node_name => self.[<rewrite_$node_name:snake>](plan.downcast_ref::<$node_name>().unwrap()),
                        )*
                    }
                }

                $(
                    fn [<rewrite_$node_name:snake>](&mut self, plan: &$node_name) -> PlanRef {
                        let new_children = plan
                            .children()
                            .into_iter()
                            .map(|child| self.rewrite(child.clone()))
                            .collect_vec();
                        plan.clone_with_children(new_children)
                    }
                )*
            }
        }
    };
}

for_all_plan_nodes! { def_rewriter }


// expand 后的代码如下
pub trait PlanRewriter {
            fn rewrite(&mut self, plan: PlanRef) -> PlanRef {
                match plan.node_type() {
                    PlanNodeType::Dummy => {
                        self.rewrite_dummy(plan.downcast_ref::<Dummy>().unwrap())
                    }
                    PlanNodeType::LogicalTableScan => self.rewrite_logical_table_scan(
                        plan.downcast_ref::<LogicalTableScan>().unwrap(),
                    ),
                    PlanNodeType::LogicalProject => {
                        self.rewrite_logical_project(plan.downcast_ref::<LogicalProject>().unwrap())
                    }
                    PlanNodeType::LogicalFilter => {
                        self.rewrite_logical_filter(plan.downcast_ref::<LogicalFilter>().unwrap())
                    }
                    PlanNodeType::PhysicalTableScan => self.rewrite_physical_table_scan(
                        plan.downcast_ref::<PhysicalTableScan>().unwrap(),
                    ),
                    PlanNodeType::PhysicalProject => self
                        .rewrite_physical_project(plan.downcast_ref::<PhysicalProject>().unwrap()),
                    PlanNodeType::PhysicalFilter => {
                        self.rewrite_physical_filter(plan.downcast_ref::<PhysicalFilter>().unwrap())
                    }
                }
            }
            fn rewrite_dummy(&mut self, plan: &Dummy) -> PlanRef {
                //...
            }
            fn rewrite_logical_table_scan(&mut self, plan: &LogicalTableScan) -> PlanRef {
                let new_children = plan
                    .children()
                    .into_iter()
                    .map(|child| self.rewrite(child.clone()))
                    .collect_vec();
                plan.clone_with_children(new_children)
            }
            fn rewrite_logical_project(&mut self, plan: &LogicalProject) -> PlanRef {
                //...
            }
            fn rewrite_logical_filter(&mut self, plan: &LogicalFilter) -> PlanRef {
                // ...
            }
            fn rewrite_physical_table_scan(&mut self, plan: &PhysicalTableScan) -> PlanRef {
                // ...
            }
            fn rewrite_physical_project(&mut self, plan: &PhysicalProject) -> PlanRef {
                // ...
            }
            fn rewrite_physical_filter(&mut self, plan: &PhysicalFilter) -> PlanRef {
                // ...
            }
        }

这个 trait 包含一些默认方法,来 rewrite 所有的 PlanNode,对每个 PlanNode 的 rewrite 方法都是一样的,即先执行所有 children 的 rewrite 方法,获取 new_children,然后再调用 clone_with_children 来生成新的 PlanNode。

通过这种方式,提供了一个统一的框架代码,即不论输入 LogicalPlanTree 或 PhysicalPlanTree,都可以通过 PlanRewriter 来进行 rewrite。如果需要自定义 rewrite 逻辑以实现特殊功能,可以覆写 PlanRewriter 相应的方法。比如:PhysicalRewriter 和 InputRefRewriter。

  • PhysicalRewriter:只覆写了 rewrite_logical_xxx 方法,将一个 LogicalPlanTree 转换为一个 PhysicalPlanTree。
  • InputRefRewriter:只覆写了 rewrite_logical_xxx 方法,从最底层依次往上 rewrite ColumnRef 表达式为 InputRef。

除了 PlanNode rewriter 外,还有 BoundExpr rewriter,如下定义:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
pub trait ExprRewriter {
    fn rewrite_expr(&self, expr: &mut BoundExpr) {
        match expr {
            BoundExpr::Constant(_) => self.rewrite_constant(expr),
            BoundExpr::ColumnRef(_) => self.rewrite_column_ref(expr),
            BoundExpr::InputRef(_) => self.rewrite_input_ref(expr),
            BoundExpr::BinaryOp(_) => self.rewrite_binary_op(expr),
            BoundExpr::TypeCast(_) => self.rewrite_type_cast(expr),
        }
    }

    fn rewrite_constant(&self, _: &mut BoundExpr) {}

    fn rewrite_column_ref(&self, _: &mut BoundExpr) {}

    fn rewrite_input_ref(&self, _: &mut BoundExpr) {}

    fn rewrite_type_cast(&self, _: &mut BoundExpr) {}

    fn rewrite_binary_op(&self, expr: &mut BoundExpr) {
        match expr {
            BoundExpr::BinaryOp(e) => {
                self.rewrite_expr(&mut e.left);
                self.rewrite_expr(&mut e.right);
            }
            _ => unreachable!(),
        }
    }
}

比如,InputRefRewriter 需要实现 ExprRewriter,覆写 rewrite_column_ref 方法,从而在 visit LogicalPlanTree 时,将 ColumnRef 转换为 InputRef。

到这一步结束,我们已经可以获取到一个 PhysicalPlanTree,下一步就是 executor 如何构建 volcano 执行模型。

executor

执行引擎的核心是一个 Vectorized Volcano Model。因此,我们仍然需要 visitor pattern 来遍历 PlanTree。仿照上面的 PlanRewriter 实现了类似的 PlanVisitor:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
macro_rules! def_rewriter {
    ($($node_name:ident),*) => {
        pub trait PlanVisitor<R> {
            paste! {
                fn visit(&mut self, plan: PlanRef) -> Option<R> {
                    match plan.node_type() {
                        $(
                            PlanNodeType::$node_name => self.[<visit_$node_name:snake>](plan.downcast_ref::<$node_name>().unwrap()),
                        )*
                    }
                }

                $(
                    fn [<visit_$node_name:snake>](&mut self, _plan: &$node_name) -> Option<R> {
                        unimplemented!("The {} is not implemented visitor yet", stringify!($node_name))
                    }
                )*
            }
        }
    };
}

for_all_plan_nodes! { def_rewriter }

构造 executor 的入口是 ExecutorBuilder,它实现了 PlanVisitor,将输入的 PhysicalPlanTree 转换为 Executor 组装成的 DAG,进而去执行。

对我们要实现的 SQL:select c1 from t where c2 = 1,只需要 project、filter 和 table_scan 这三个 executor。下面依次会介绍用到的技术点。

futures-async-stream

首先,对一个执行算子来说,它本质是一个迭代器,你可以通过 next_batch 去获取数据,但为了更加清晰的代码逻辑,采用了和RisingLight里一样的 futures-async-stream 来包裹执行逻辑。从最外层抽象来说,所有算子会组装成一个 stream,从DAG的最顶层依次往下抽取数据,中间经过各种算子的执行逻辑,最终返回到顶层。

从 TableScanExecutor 开始,它是最底层的 executor,输入的是 storage,通过 transaction.next_batch() 来获取数据,因此,它的核心逻辑如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
#[try_stream(boxed, ok = RecordBatch, error = ExecutorError)]
pub async fn execute(self) {
    let table_id = self.plan.logical().table_id();
    let table = self.storage.get_table(table_id)?;
    let mut tx = table.read()?;
    loop {
        match tx.next_batch() {
            Ok(batch) => {
                if let Some(batch) = batch {
                    yield batch;
                } else {
                    break;
                }
            }
            Err(err) => return Err(err.into()),
        }
    }
}

evaluator

有了 TableScanExecutor 后,需要选取某一列数据,需要实现 ProjectExecutor。它将下游给的 RecordBatch 输入到 BoundExpr 中进行计算。如下几种 BoundExpr:

  • InputRef:取出 RecordBatch 中的某一列数据
  • Constant:构造一个常量 column
  • TypeCast:将一个 column data 转换为指定类型
  • BinaryOp:对 left 和 right expr 计算 BinaryOperator,比如:left + rightleft = right

对于SQL:select c1 from t where c2 = 1,ProjectExecutor 只需要 eval InputRef 获取对于 index 的 column 就可以了,同时构造新的 RecordBatch 返回出去。如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
#[try_stream(boxed, ok = RecordBatch, error = ExecutorError)]
pub async fn execute(self) {
    #[for_await]
    for batch in self.child {
        let batch = batch?;
        let columns = self
            .exprs
            .iter()
            .map(|e| e.eval_column(&batch))
            .try_collect();
        let fields = self.exprs.iter().map(|e| e.eval_field(&batch)).collect();
        let schema = SchemaRef::new(Schema::new_with_metadata(
            fields,
            batch.schema().metadata().clone(),
        ));
        yield RecordBatch::try_new(schema, columns?)?;
    }
}

array_compute

这一步,需要处理SQL:select c1 from t where c2 = 1 中的 where 表达式。因此,引入了 arrow 的 compute 模块,它提供了许多 array 之间的计算方法,比如:add、multiply、eq 等。

比如,对 where 表达式用到的 binary_op 计算,如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
pub fn binary_op(
    left: &ArrayRef,
    right: &ArrayRef,
    op: &BinaryOperator,
) -> Result<ArrayRef, ExecutorError> {
    match op {
        BinaryOperator::Plus => arithmetic_op!(left, right, add),
        BinaryOperator::Minus => arithmetic_op!(left, right, subtract),
        BinaryOperator::Multiply => arithmetic_op!(left, right, multiply),
        BinaryOperator::Divide => arithmetic_op!(left, right, divide),
        BinaryOperator::Gt => Ok(Arc::new(gt_dyn(left, right)?)),
        BinaryOperator::Lt => Ok(Arc::new(lt_dyn(left, right)?)),
        BinaryOperator::GtEq => Ok(Arc::new(gt_eq_dyn(left, right)?)),
        BinaryOperator::LtEq => Ok(Arc::new(lt_eq_dyn(left, right)?)),
        BinaryOperator::Eq => Ok(Arc::new(eq_dyn(left, right)?)),
        _ => todo!(),
    }
}

同时,arrow 提供了 filter_record_batch 方法,我们只需要传入一个 boolean array 作为 predicate,就可以过滤出 match 的 rows 组织成新的 RecordBatch。

summary

至此,select c1 from t where c2 = 1 这条SQL经过我们的查询引擎,可以将数据查询出来了。

同时,一个简单的查询引擎包含的模块也已经初始化完毕,后续直接在具体的 module 中加入相应特性即可,比如 aggregation、join 等。

Roadmap v0.2

Roadmap v0.2 主要新增功能为:

  • testing framework
  • aggregation operators
  • interactive mode

它的 milestone 是搭建一个 ec2 测试框架加入现有支持的 SQL;同时支持 simple aggregation SQL query: select sum(c1), count(c2), max(c2) from t 和 hash aggregation SQL query: select c1, count(c1), max(c2) from t group by c1

testing framework

为了更方便测试,保证已有功能的正确性,引入 e2e 测试框架 sqllogictest-rs,它来源于 risinglight 社区 port 的 sqllogictest

引入测试的第一步,先重构 src 下的代码为 library 和 binary,即拆分原先 main.rs 中的代码为 lib.rs、db.rs 和 main.rs。

接着新增 sqllogicaltest 的 workspace,它的主要逻辑是,针对每个 slt 的 script 文件,都会读取 csv 生成一个新的 database,然后执行与验证 slt 中的 SQL。

其中特别的部分是,引入 cargo-nextest Custom test harnesses 来实现 cargo-nextest 的自定义测试集成。比如,sqllogicaltest 当前场景下,针对 slt 的每个 script 文件单独生成一个 test,最终 report 也会纳入 cargo-nextest 中。如下展示:

1
2
3
4
5
6
7
Starting 19 tests across 4 binaries
PASS [   0.032s]             sql-query-engine-rs executor::evaluator::evaluator_test::test_eval_column_for_input_ref
PASS [   0.048s]             sql-query-engine-rs storage::csv::tests::test_csv_storage_works
...
PASS [   0.044s] sqllogictest-test::sqllogictest select
PASS [   0.045s] sqllogictest-test::sqllogictest filter
Summary [   0.094s] 19 tests run: 19 passed, 0 skipped

同时,新增了 memory storage 更方便的构造UT测试数据。

simple aggregation

simple aggregation 含义是,没有 group_by 表达式,直接对数据进行 aggregation。

构建它的流程和 Roadmap 0.1 的流程一致,需要从 BoundExpr -> Logical/Physical PlanNode -> Planner -> Executor。其中还涉及到了 agg_expr 的 InputRef resolve 和 aggregate accumulator。

下面会介绍几点不同的部分。

ExprVisitor

Planner 在构建 LogicalPlanTree 时,需要决定是否插入 LogicalAgg PlanNode,判断的依据是 BoundSelect Tree 中是否包含 AggExpr。细想一下,这是一个通用的功能,在一个 BoundTree 中找出特定的 BoundExpr。

而对于一个 tree 的操作,可以通过 visitor 模型来遍历,这里的 find_aggregate_exprs 也是 visitor 一种特殊情况。

因此,采用与 ExprRewriter 类似的代码,事先构造好 visit path。但是这里提供了一个 pre_visit 的 hook 用于自定义逻辑,比如提供给 find_aggregate_exprs。如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
pub trait ExprVisitor {
    fn pre_visit(&mut self, _: &BoundExpr) {}

    fn visit_expr(&mut self, expr: &BoundExpr) {
        self.pre_visit(expr);
        match expr {
            BoundExpr::Constant(expr) => self.visit_constant(expr),
            BoundExpr::ColumnRef(expr) => self.visit_column_ref(expr),
            BoundExpr::InputRef(expr) => self.visit_input_ref(expr),
            BoundExpr::BinaryOp(expr) => self.visit_binary_op(expr),
            BoundExpr::TypeCast(expr) => self.visit_type_cast(expr),
            BoundExpr::AggFunc(expr) => self.visit_agg_func(expr),
        }
    }

    // ...
}

find_aggregate_exprs 方法利用的 ExprFinder 参考于 datafusion,从输入的 BoundExprs 中返回具体类型的 exprs。如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// Visitor that find expressions that match a particular predicate
struct ExprFinder<'a, F>
where
    F: Fn(&BoundExpr) -> bool,
{
    test_fn: &'a F,
    exprs: Vec<BoundExpr>,
}

impl<'a, F> ExprFinder<'a, F>
where
    F: Fn(&BoundExpr) -> bool,
{
    fn new(test_fn: &'a F) -> Self {
        Self {
            test_fn,
            exprs: Vec::new(),
        }
    }
}

impl<'a, F> ExprVisitor for ExprFinder<'a, F>
where
    F: Fn(&BoundExpr) -> bool,
{
    fn pre_visit(&mut self, expr: &BoundExpr) {
        if (self.test_fn)(expr) && !self.exprs.contains(expr) {
            self.exprs.push(expr.clone());
        }
    }
}

/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
/// provided test. The returned `Expr`'s are deduplicated and returned in order
/// of appearance (depth first).
fn find_exprs_in_expr<F>(expr: &BoundExpr, test_fn: &F) -> Vec<BoundExpr>
where
    F: Fn(&BoundExpr) -> bool,
{
    let mut finder = ExprFinder::new(test_fn);
    finder.visit_expr(expr);
    finder.exprs
}

InputRefRewriter

之前的 InputRefRewriter 还未遇到 nested BoundExpr 的情况,但是对于 agg_expr 的情况,比如 sum (a+1) 这种情况,需要 resolve a 在 RecordBatch 中的 index。因此,需要对之前 InputRefRewriter 的 ExprRewriter 进行重构。

最初逻辑:只针对 ColumnRef 去 bindings 中找到相同 expr 对应 index 并替换成 input_ref。

改变后的逻辑:

  1. 不论 BoundExpr 的类型,都会先从 bindings 中搜索相同 expr 对应的 index。
  2. 如果未搜寻到,则继续递归搜寻 BoundExpr 的 nested expr。

如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
impl InputRefRewriter {
    fn rewrite_internal(&self, expr: &mut BoundExpr) {
        // Find input expr in bindings.
        if let Some(idx) = self.bindings.iter().position(|e| *e == expr.clone()) {
            *expr = BoundExpr::InputRef(BoundInputRef {
                index: idx,
                return_type: expr.return_type().unwrap(),
            });
            return;
        }

        // If not found in bindings, expand nested expr and then continuity rewrite_expr.
        match expr {
            BoundExpr::BinaryOp(e) => {
                self.rewrite_expr(e.left.as_mut());
                self.rewrite_expr(e.right.as_mut());
            }
            BoundExpr::TypeCast(e) => self.rewrite_expr(e.expr.as_mut()),
            BoundExpr::AggFunc(e) => {
                for arg in &mut e.exprs {
                    self.rewrite_expr(arg);
                }
            }
            _ => unreachable!(
                "unexpected expr type {:?} for InputRefRewriter, binding: {:?}",
                expr, self.bindings
            ),
        }
    }
}

impl ExprRewriter for InputRefRewriter {
    fn rewrite_column_ref(&self, expr: &mut BoundExpr) {
        self.rewrite_internal(expr);
    }

    fn rewrite_type_cast(&self, expr: &mut BoundExpr) {
        self.rewrite_internal(expr);
    }

    fn rewrite_binary_op(&self, expr: &mut BoundExpr) {
        self.rewrite_internal(expr);
    }

    fn rewrite_agg_func(&self, expr: &mut BoundExpr) {
        self.rewrite_internal(expr);
    }
}

Accumulator

对于聚合函数,需要存在一个累加器,在遍历每次 chunk data 的时候,进行结果的累加。

比如,sum 函数,需要一个 SumAccumulator,它存储着一个计算总数 result,在每次遍历数据时,将当前 chunk 的总数累加到 result 中。因此,基于此我们定义 accumulator trait:

1
2
3
4
5
6
7
pub trait Accumulator: Send + Sync {
    /// updates the accumulator's state from a vector of arrays.
    fn update_batch(&mut self, array: &ArrayRef) -> Result<(), ExecutorError>;

    /// returns its value based on its current state.
    fn evaluate(&self) -> Result<ScalarValue, ExecutorError>;
}

仔细发现,我们的计算层完全是依赖于 arrow,因此对两个 ArrayRef 的 sum 函数,也是直接使用 arrow 的 compute 库。

在具体实现中,参考了 datafusion 的宏实现,对 ArrayRef 进行类型转换。

最后一个步骤,新增 SimpleAggExecutor 与 Accumulator 的组合起来。它的基本思路和其它 executor 一样,区别在于,聚合函数 executor 需要将 child 的所有 data chunk 遍历完后,才能计算出最终的汇总结果。

在每次遍历 child chunk data 时候,用 accumulator 进行累加。当所有数据遍历完后,调用 accumulator 的 evaluate 函数,获取最终的结果。如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#[try_stream(boxed, ok = RecordBatch, error = ExecutorError)]
pub async fn execute(self) {
    let mut accs = create_accumulators(&self.agg_funcs);
    let agg_funcs = self.cast_agg_funcs();

    let mut agg_fileds: Option<Vec<Field>> = None;

    #[for_await]
    for batch in self.child {
        let batch = batch?;
        // only support one epxrssion in aggregation, not supported example: `sum(distinct a)`
        let columns: Result<Vec<_>, ExecutorError> = agg_funcs
            .iter()
            .map(|agg| agg.exprs[0].eval_column(&batch))
            .try_collect();

        // build new schema for aggregation result
        if agg_fileds.is_none() {
            agg_fileds = Some(
                agg_funcs
                    .iter()
                    .map(|agg| {
                        let inner_name = agg.exprs[0].eval_field(&batch).name().clone();
                        let new_name = format!("{}({})", agg.func, inner_name);
                        Field::new(new_name.as_str(), agg.return_type.clone(), false)
                    })
                    .collect(),
            );
        }
        let columns = columns?;
        for (acc, column) in accs.iter_mut().zip_eq(columns.iter()) {
            acc.update_batch(column)?;
        }
    }

    let mut columns: Vec<ArrayRef> = Vec::new();
    for acc in accs.iter() {
        let res = acc.evaluate()?;
        columns.push(build_scalar_value_array(&res, 1));
    }

    let schema = SchemaRef::new(Schema::new(agg_fileds.unwrap()));
    yield RecordBatch::try_new(schema, columns)?;
}

hash aggregation

首先,我们实现的 aggregation 的前提都是,所有数据都是放在 memory 中,为了简单没有考虑 external hashing aggregation,如果要实现的话,可以参考 15445 external hashing aggregation

对于 aggregation 的实现,一般还会提到 sort aggregation。但大多数情况下,hash aggregation 会比它更加高效。有种情况下:group keys 很多时候,遇到了内存问题,会退化成 sort aggregation 来实现。参考 spark hash vs sort aggregation

下面我们开始介绍它的具体实现,整体的实现步骤和 simple aggregation (no group by exprs) 的流程一样,只不过 hash aggregation 需要在每个步骤中加入 group by exprs 的处理。

开始,需要创建一个新的 PlanNode,PhysicalHashAgg 来代表 LogicalAgg 的另外一种实现,与 PhysicalSimpleAgg 区别。

由于有了两种 PhysicalAgg 的实现,因此需要在 PhysicalRewriter 来决定选用哪一种实现。同时,由于引入了 group by exprs,也需要在 InputRefRewriter 中对 group by exprs 进行 InputRef 的解析。

上面两步骤做完后,一个SQL select c1, count(c1), max(c2) from t group by c1 就可以转成对应的 PhysicalPlanTree 了。下一步,需要给 PhysicalHashAgg 配套上对应的 executor 来实现 hash aggregation,也是实现中最重要的部分。

由于 executor 的计算单元是 Arrow RecordBatch,它是多个 colums 组装而成的集合。而我们在 hashing 的过程中是针对于 row 来进行的,因此整个逻辑中,有个行列数据的转换需要考虑到。

具体的实现步骤分为:

构造schema

1.构造最终结果 RecordBatch 的 schema

这部分实现比较简单,直接计算出 group_by 和 agg_funcs exprs 对应的 fields。注意,我们默认是将 group_by 对应的 fields 放在 agg_funcs 之前的,这个约定也贯穿在后续计算中。

eval所需数据

2.1 evaluate agg exprs 聚合函数所需计算的 column data,比如对 Sum(a+1),获取到的 column data 为 column a + 1 所得到的数据

2.2 evaluate group by exprs 获取后续计算 hashing 的 column data

构造hashmap

3.1 将2.2中获取到的 group by column data 计算出每一行的 hash value

这一步骤参考 arrow-datafusion 中的 hash_utils.rs 中的 create_hashes 方法,它可接收多个 column data,并计算出每一行的 hash value。如下使用代码,两个 arr 中的 0.12 组成的一行数据的 hash 是相同的:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
let f64_arr = Arc::new(Float64Array::from_iter_values(vec![
    0.12, 0.12, 1f64, 444.7,
]));
let f64_arr_2 = Arc::new(Float64Array::from_iter_values(vec![
    0.12, 0.12, 1f64, 444.7,
]));

let random_state = RandomState::with_seeds(0, 0, 0, 0);
let hashes_buff = &mut vec![0; f64_arr.len()];

let hashes = create_hashes(&[f64_arr, f64_arr_2], &random_state, hashes_buff)?;
assert_eq!(hashes.len(), 4);
assert_eq!(hashes.clone(), hashes_buff.clone());
assert_eq!(
    hashes_buff,
    &[
        15550289857534363376,
        15550289857534363376,
        13221404211197939868,
        9886939767832447622
    ]
);

3.2 构造计算所需的 hashmap

  • hash->accumulator map:给同一个 group key 创建一系列 accumulators (取决于agg_funcs个数),用于计算聚合结果
  • hash->group row indices map:它是标记 column data 中哪些行是同一个 group 的数据,用于后续 accumulator 取出自己 group 的 column data
  • hash->group keys:它记录了所有的 group keys,用于后续与聚合结果拼接在一起

执行accumulator

利用到3.2中准备的 hash->group row indices map,将同一个 group 中的 data 取出执行 acc.update_batch 计算。

注意,由于存在多个 accumulators,所有这里的计算逻辑是一个双层 for 循环,即对所有的 group 组合,都执行所有的聚合函数累加操作。如下代码:

1
2
3
4
5
6
7
8
9
for (hash, mut idx_builder) in group_hash_2_row_indices {
    let indices = idx_builder.finish();
    let accs = group_hash_2_accs.get_mut(&hash).unwrap();
    for (acc, column) in accs.iter_mut().zip_eq(columns.iter()) {
        // take one group rows from a column
        let new_array = compute::take(column.as_ref(), &indices, None)?;
        acc.update_batch(&new_array)?;
    }
}

组装RecordBatch所需data

经过上面 accumulator 计算后,我们获取到的是对一个 group key 的聚合结果,换句话说,它是针对行的结果。因此,我们需要转换为 columnar data 用于构造 RecordBatch。

思路是,对每个 column 都构造一个 array_builder,再去依次遍历 group_values 和 group_hash_2_accs,相当于将遍历行数据时候,对每一个列 array_builder 执行 append 操作,也是一个场景的行转列操作。如下代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
for hash in group_hashs {
    let group_values = group_hash_2_keys.get(&hash).unwrap();
    for (idx, group_key) in group_values.iter().enumerate() {
        append_scalar_value_for_builder(group_key, &mut builders[idx])?;
    }

    for (idx, acc) in group_hash_2_accs.get(&hash).unwrap().iter().enumerate() {
        append_scalar_value_for_builder(
            &acc.evaluate()?,
            &mut builders[idx + group_values.len()],
        )?;
    }
}

完成RecordBatch

获取上一步的 column builder 的结果,加上第一步构造的 schema,最终就可以拼接成 HashAggregation 的输出 RecordBatch 了。

1
2
3
let columns = builders.iter_mut().map(|b| b.finish()).collect::<Vec<_>>();
let schema = SchemaRef::new(Schema::new(fields));
yield RecordBatch::try_new(schema, columns)?;

Summary

至此,我们的 sql-query-engine 已经有了一个基础的查询功能,后续会给它加入更多算子,为优化器部分准备。