学习TreeNode之前,我们先了解下InternalRow。
对于我们一般接触到的数据库关系表来说,我们对于数据库中的数据操作都是按照“行”为单位的。在spark sql内部实现中,InternalRow是用来表示这一行行数据的类。看下源码中的解释,InternalRow作为一个抽象类,包numFields 和 update 方法,以及各列数据对应的 get 与 set 方法,但具体的实现逻辑体现在不同的子类中
/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
* internal types.
一个抽象类,用于表示spark SQL内部行,只包含内部类型的多个列(其实就是表示一行行数据的类)
*/
详细代码这里就不贴了,整理下一些重要接口的功能含义好了,注意InternalRow中都是根据下标来访问和操作列元素的 。
InternalRow实现类包括,BaseGenericinternalRow、UnsafeRow 和 JoinedRow 3 个直接子类
接下来正式开始进行TreeNode的学习
TreeNode是Spark SQL中所有树结构的基类,定义了一系列通用的集合操作和树遍历的操作接口。我们先看下TreeNode的代码
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with TreePatternBits {}
首先TreeNode是一个抽象类,一个泛型类;这里TreeNode[BaseType <: TreeNode[BaseType]]这种书写方式,不知道大家会不会很陌生,反正我一开始看的时候,觉得不知道咋回事,那么我们来一起理解写,这个具体是什么含义:
另外,TreeNode还继承了Product接口,对于product接口相关使用介绍,请看这篇文章(scala之product特质理解_大家都叫我船长的博客-CSDN博客),看完应该就明白了。
接下来,开始详细看看TreeNode一些重要方法:
/*** Returns a Seq of the children of this node.* Children should not change. Immutability required for containsChild optimization*/def children: Seq[BaseType]
lazy val containsChild: Set[TreeNode[_]] = children.toSet
def fastEquals(other: TreeNode[_]): Boolean = {this.eq(other) || this == other}
def find(f: BaseType => Boolean): Option[BaseType] = if (f(this)) {Some(this)} else {children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }}
def foreach(f: BaseType => Unit): Unit = {f(this)children.foreach(_.foreach(f))}
def foreachUp(f: BaseType => Unit): Unit = {children.foreach(_.foreachUp(f))f(this)}
def map[A](f: BaseType => A): Seq[A] = {val ret = new collection.mutable.ArrayBuffer[A]()foreach(ret += f(_))ret.toSeq}
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = {val ret = new collection.mutable.ArrayBuffer[A]()foreach(ret ++= f(_)) //f返回的结果必须是一个集合ret.toSeq}
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = {val ret = new collection.mutable.ArrayBuffer[B]()val lifted = pf.liftforeach(node => lifted(node).foreach(ret.+=))ret.toSeq}
def collectLeaves(): Seq[BaseType] = {this.collect { case p if p.children.isEmpty => p }}
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {val lifted = pf.liftlifted(this).orElse {children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }}}
mapProductIterator其实功能和productIterator.map(f).toArray一致
protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = {val arr = Array.ofDim[B](productArity)var i = 0while (i < arr.length) {arr(i) = f(productElement(i))i += 1}arr}
inal def withNewChildren(newChildren: Seq[BaseType]): BaseType = {val childrenIndexedSeq = asIndexedSeq(children)val newChildrenIndexedSeq = asIndexedSeq(newChildren)assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children")if (childrenIndexedSeq.isEmpty ||childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) {this} else {CurrentOrigin.withOrigin(origin) {val res = withNewChildrenInternal(newChildrenIndexedSeq)res.copyTagsFrom(this)res}}}
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {transformDown(rule)}
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {transformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)}def transformDownWithPruning(cond: TreePatternBits => Boolean,ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType]): BaseType = {if (!cond.apply(this) || isRuleIneffective(ruleId)) {return this}val afterRule = CurrentOrigin.withOrigin(origin) {// 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 thisrule.applyOrElse(this, identity[BaseType])}// Check if unchanged and then possibly return old copy to avoid gc churn.if (this fastEquals afterRule) {// 如果应用了 rule 后节点无变化,则递归将 rule 应用于 childrenval rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))if (this eq rewritten_plan) {markRuleAsIneffective(ruleId)this} else {rewritten_plan}} else {// If the transform function replaces this node with a new one, carry over the tags.// 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点afterRule.copyTagsFrom(this)afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))}}
transformWithPruning,底层调用transformDownWithPruning(功能是返回此节点的副本,其中“规则”已递归应用于树。当“规则”不适用于给定节点时,它将保持不变)
def transformWithPruning(cond: TreePatternBits => Boolean,ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType]): BaseType = {transformDownWithPruning(cond, ruleId)(rule)}def transformDownWithPruning(cond: TreePatternBits => Boolean,ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType]): BaseType = {if (!cond.apply(this) || isRuleIneffective(ruleId)) {return this}val afterRule = CurrentOrigin.withOrigin(origin) {// 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 thisrule.applyOrElse(this, identity[BaseType])}// Check if unchanged and then possibly return old copy to avoid gc churn.if (this fastEquals afterRule) {// 如果应用了 rule 后节点无变化,则递归将 rule 应用于 childrenval rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))if (this eq rewritten_plan) {markRuleAsIneffective(ruleId)this} else {rewritten_plan}} else {// If the transform function replaces this node with a new one, carry over the tags.// 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点afterRule.copyTagsFrom(this)afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))}}
transformUp 用后序遍历方式将规则作用于所有节点,调用transformUpWithPruning
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {transformUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)}def transformUpWithPruning(cond: TreePatternBits => Boolean,ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType]): BaseType = {if (!cond.apply(this) || isRuleIneffective(ruleId)) {return this}val afterRuleOnChildren = mapChildren(_.transformUpWithPruning(cond, ruleId)(rule))val newNode = if (this fastEquals afterRuleOnChildren) {CurrentOrigin.withOrigin(origin) {rule.applyOrElse(this, identity[BaseType])}} else {CurrentOrigin.withOrigin(origin) {rule.applyOrElse(afterRuleOnChildren, identity[BaseType])}}if (this eq newNode) {markRuleAsIneffective(ruleId)this} else {// If the transform function replaces this node with a new one, carry over the tags.newNode.copyTagsFrom(this)newNode}}
f
应用于所有子节点后该节点的 copy。def mapChildren(f: BaseType => BaseType): BaseType = {if (containsChild.nonEmpty) {withNewChildren(children.map(f))} else {this}}
上面罗列的方法,基本就是TreeNode常用的,还有一些不常用的非核心的,这里就不一一介绍了,大家有兴趣的可以自己去看看源码。
另外TreeNode有两个子类,分别是Expression和QueryPlan,这篇文章我们就先讲到这里,后面会对这两个子类也会进行一一介绍的。
有兴趣的可以关注我,后面一起学习sparkSql源码,另外文章中有错误的地方,感谢指出哈。