图解Spark Graphx基于connectedComponents函数实现连通图底层原理
阅读原文时间:2023年09月04日阅读:2

原创/朱季谦

第一次写这么长的graphx源码解读,还是比较晦涩,有较多不足之处,争取改进。

一、连通图说明

连通图是指图中的任意两个顶点之间都存在路径相连而组成的一个子图。

用一个图来说明,例如,下面这个叫graph的大图里,存在两个连通图。

左边是一个连接图,该子图里每个顶点都存在路径相连,包括了顶点:{(5L, "Eve"), (7L, "Grace"), (1L, "Alice"), (2L, "Bob"), (3L, "Charlie")}。

右边同样是一个连接图,该子图里每个顶点都存在路径相连,包括了顶点:{(8L, "Henry"),(9L, "Ivy"),(6L, "Frank")}。

在现实生活里,这两个子图就相当某个社区里的关系网,在Spark Graphx里,经常需要处理这类关系网的操作,那么,在一个图里如何得到各个子图的数据呢?

这时,就可以使用到Spark Graphx的connectedComponents函数,网上关于它的介绍,基本都是说它是Graphx三大图算法之一的连通组件。

连通组件是指图中的一组顶点,每个顶点之间都存在路径互相关联,也就是前面提到图中的子图概念。

通俗解释,就是通过这个函数,可以将每个顶点都关联到连通图里的最小顶点,例如,前面提到的子图{(8L, "Henry"),(9L, "Ivy"),(6L, "Frank")},在通过connectedComponents函数处理后,就可以得到每个顶点关联到该子网的最小顶点ID。该子图里的最小顶点ID是6L,那么,可以处理成以下数据{(8L,6L),(9L,6L),(6L,6L)}。既然属于同一个子图的各个顶点都关联到一个共同的最小顶点,不就意味着,通过该最小顶点,是可以按照分组的操作,将同一个最小顶点的数据都分组到一块,这样,就能提取出同一个子图的顶点集合了。

二、案例说明

基于以上的图顶点和边数据,创建一个Graphx图——

val conf = new SparkConf().setMaster("local[*]").setAppName("graphx")
val ss = SparkSession.builder().config(conf).getOrCreate()

// 创建顶点RDD
val vertices = ss.sparkContext.parallelize(Seq(
  (1L, "Alice"),
  (2L, "Bob"),
  (3L, "Charlie"),
  (5L, "Eve"),
  (6L, "Frank"),
  (7L, "Grace"),
  (8L, "Henry"),
  (9L, "Ivy")
))

// 创建边RDD
val edges = ss.sparkContext.parallelize(Seq(
  Edge(5L, 7L, "friend"),
  Edge(5L, 1L, "friend"),
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "friend"),
  Edge(6L, 9L, "friend"),
  Edge(9L, 8L, "friend")
))

//创建一个Graph图
val graph = Graph(vertices, edges, null)

调用图graph的connectedComponents函数,顺便打印一下效果,可以看到,左边子图{(5L, "Eve"), (7L, "Grace"), (1L, "Alice"), (2L, "Bob"), (3L, "Charlie")}里的各个顶点都关联到了最小顶点1,右边子图{(8L, "Henry"),(9L, "Ivy"),(6L, "Frank")}里的各个顶点都关联到了最小顶点6。

val cc = graph.connectedComponents()
cc.vertices.foreach(println)

打印的结果——
(2,1)
(6,6)
(7,1)
(1,1)
(9,6)
(8,6)
(3,1)
(5,1)

注意一点,connectedComponents是可以传参的,传入的数字,是代表各个顶点最高可以连通迭代到多少步去寻找所在子图里的最小顶点。

举个例子,可能就能明白了,假如,给connectedComponents传参为1,那么代码执行打印后,如下——

val cc = graph.connectedComponents(1)
cc.vertices.foreach(println)

打印的结果——
(2,1)
(5,1)
(8,8)
(7,5)
(1,1)
(9,6)
(6,6)
(3,2)

你会发现,各个顶点的连通组件即关联所在子图的最小顶点,大多都变了,这是因为设置参数为1 后,各个顶点沿着边去迭代寻找连通组件时,只能迭代一步,相当本顶点只能走到一度邻居顶点,然后将本顶点和邻居顶点比较,谁最小,最小的当作连通组件。

以下图说明,就是顶点(7L, "Grace")迭代一步去寻找最小顶点做连通组件,只能迭代到顶点(5L, "Eve"),没法迭代到 (1L, "Alice"),这时顶点(7L, "Grace")就会拿自身与顶点(5L, "Eve")比较,发现5L更小,就会用5L当作自己的连通组件做关联,即(7,5)。

当然,实际底层的源码实现,并非是通过迭代多少步去寻找最小顶点,它的实现方式更精妙,站在原地就可以收集到所能迭代最大次数范围内的最小顶点。

如果connectedComponents没有设置参数,就会默认最大迭代次数是Int.MaxValue,2 的 31 次方 - 1即2147483647

在实际业务当中,可以通过设置参数来避免在过大规模的子图里做耗时过长的迭代操作

接下来,就可以通过连通组件做分组,将具有共同连通组件的顶点分组到一块,这样就知道哪些顶点属于同一子图了。

val cc = graph.connectedComponents()
val group = cc.vertices.map{
  case (verticeId, minId) => (minId, verticeId)
}.groupByKey()

group.foreach(println)

打印结果——
(1,CompactBuffer(1, 2, 3, 5, 7))
(6,CompactBuffer(8, 9, 6))

基于这个函数,就可以得到哪些顶点在一张关系网里了。

三、connectedComponents源码解析

先来看一下connectedComponents函数源码,在ConnectedComponents单例对象里,可以看到,如果没有传参的话,默认迭代次数是Int.MaxValue,如果传参的话,就使用参数的maxIterations做迭代次数——

/**
*无参数
*/
def connectedComponents(): Graph[VertexId, ED] = {
  ConnectedComponents.run(graph)
}

def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = {
    run(graph, Int.MaxValue)
}

/**
*有参数
*/
def connectedComponents(maxIterations: Int): Graph[VertexId, ED] = {
    ConnectedComponents.run(graph, maxIterations)
}

在run方法里,主要是做了一些函数和常量的准备工作,然后将这些函数和常量传给单例对象Pregel的apply方法。apply是单例对象的特殊方法,就像Java类里的构造方法一样,创建对象时可以直接被调用。Pregel(ccGraph, initialMessage,maxIterations, EdgeDirection.Either)(……)最后调用的就是Pregel里的apply方法。

def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
                                    maxIterations: Int): Graph[VertexId, ED] = {
  require(maxIterations > 0, s"Maximum of iterations must be greater than 0," +
    s" but got ${maxIterations}")
  //step1 初始化图,将各顶点id设置为顶点属性,图顶点结构(vid,vid)
  val ccGraph = graph.mapVertices { case (vid, _) => vid }

  //step2 处理图里的每一个三元组边对象,该对象edge包含了源顶点(srcId,srcAttr)和目标顶点(dstId,dstAttr)的信息,及边属性attr,即(srcId,srcAttr,dstId,dstAttr,attr)
  def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = {
    //如果源顶点属性小于目标顶点属性
    if (edge.srcAttr < edge.dstAttr) {
      //保存(目标顶点id,源顶点属性),这里的源顶点属性等于源顶点id,其实保存的是(目标顶点id,源顶点id)
      Iterator((edge.dstId, edge.srcAttr))
       //如果源顶点属性大于目标顶点属性
    } else if (edge.srcAttr > edge.dstAttr) {
      //保存(源顶点id,目标顶点id)
      Iterator((edge.srcId, edge.dstAttr))
    } else {
      //如果两个顶点属性相同,说明已经在同一个子网里,不需要处理
      Iterator.empty
    }
  }
  //step3 设置一个初始最大值,用于在初始化阶段,比较每个顶点的属性,这样顶点属性值在最初阶段就相当是最小顶点
  val initialMessage = Long.MaxValue

  //step4 将上面设置的常量和函数当作参数传给Pregel,其中EdgeDirection.Either表示处理包括出度和入度的顶点。
  val pregelGraph = Pregel(ccGraph, initialMessage,
    maxIterations, EdgeDirection.Either)(
    //将最初顶点的属性attr与initialMessage比较,相当是子图的0次迭代寻找最小顶点
    vprog = (id, attr, msg) => math.min(attr, msg),
    //上面定义的sendMessage方法
    sendMsg = sendMessage,
    //处理各个顶点收到的消息,然后将最小的顶点保存
    mergeMsg = (a, b) => math.min(a, b))
  ccGraph.unpersist()
  pregelGraph
}

step1 初始化图,将各顶点id设置为顶点属性,图顶点结构(vid,vid)——

 val ccGraph = graph.mapVertices { case (vid, _) => vid }

写一个简单的代码验证一下即可知道得到的ccGraph处理后顶点是否为(vid,vid)结构了。

// 创建顶点RDD
val vertices = ss.sparkContext.parallelize(Seq(
  (1L, "Alice"),
  (2L, "Bob"),
  (3L, "Charlie"),
  (5L, "Eve"),
  (6L, "Frank"),
  (7L, "Grace"),
  (8L, "Henry"),
  (9L, "Ivy")
))

// 创建边RDD
val edges = ss.sparkContext.parallelize(Seq(
  Edge(5L, 7L, "friend"),
  Edge(5L, 1L, "friend"),
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "friend"),
  Edge(6L, 9L, "friend"),
  Edge(9L, 8L, "friend")
))

//创建一个Graph图
val graph = Graph(vertices, edges, null)
graph.mapVertices{case  (vid,_) => vid}.vertices.foreach(println)

打印结果——
(2,2)
(5,5)
(3,3)
(6,6)
(7,7)
(8,8)
(1,1)
(9,9)

可见,ccGraph的图顶点已经被处理成(vid,vid),即(顶点id, 顶点属性),方便用于在sendMessage方法做属性判断处理。

step2 sendMessage处理图里的每一个三元组边对象

前面处理的ccGraph顶点数据变成(顶点id, 顶点属性)就是为了放在这里做处理,这里的if (edge.srcAttr < edge.dstAttr) 相当是if (edge.srcId < edge.dstId)。

这个方法是基于边的三元组做处理,将同一边的源顶点和目标顶点比较,筛选出两个顶点最小的顶点,然后针对最大的顶点,保留(最大顶点,最小顶点属性)这样的数据。

  def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = {
    //如果源顶点属性小于目标顶点属性
    if (edge.srcAttr < edge.dstAttr) {
      //保存(目标顶点id,源顶点属性),这里的源顶点属性等于源顶点id,其实保存的是(目标顶点id,源顶点id)
      Iterator((edge.dstId, edge.srcAttr))
       //如果源顶点属性大于目标顶点属性
    } else if (edge.srcAttr > edge.dstAttr) {
      //保存(源顶点id,目标顶点id)
      Iterator((edge.srcId, edge.dstAttr))
    } else {
      //如果两个顶点属性相同,说明已经在同一个子网里,不需要处理
      Iterator.empty
    }
  }

这个方法的作用,就是找出同一条边上哪个顶点最小,例如下图中,2L比3L小,那么2L是这条边上最小的顶点,将以最大点关联最小点的方式(edge.dstId, edge.srcAttr)即(3L,2L)保存下来。最后会将(3L,2L)中的_.2也就是2L发送给顶点(3L,3L),而顶点(3L,3L)后续需要做的事情是,是将这一轮收到的消息即最小顶点2L与现在的属性3L值通过math.min(a, b)做比较,保留最小顶点当作属性值,即变成了(3L,2L)。

可见,在子图里,每一轮迭代后,各个顶点的属性值都可能会被更新接收到的最小顶点值,这就是连通组件迭代的精妙。

这个方法会在后面的Pregel对象里用到。

step3 设置一个初始最大值,用于比较后初始化每个顶点最初的属性值

val initialMessage = Long.MaxValue需要与vprog = (id, attr, msg) => math.min(attr, msg)结合来看,相当在0次迭代时,将顶点(id,attr)的属性值与initialMessage做比较,理论上,肯定是attr比较小,就意味着初始化时,顶点关联的最小顶点就是attr,在这里,就相当关联的最小顶点是它本身,相当于子图做了0次迭代处理。

step4 执行Pregel的构造函数apply方法

可以看到,前面创建的ccGraph,initialMessage,maxIterations(最大迭代次数),EdgeDirection.Either都当作参数传给了Pregel。

val pregelGraph = Pregel(ccGraph, initialMessage,
    maxIterations, EdgeDirection.Either)(
    //将最初顶点的属性attr与initialMessage比较,相当是子图的0次迭代寻找最小顶点
    vprog = (id, attr, msg) => math.min(attr, msg),
    //上面定义的sendMessage方法
    sendMsg = sendMessage,
    //处理各个顶点收到的消息,然后将最小的顶点保存
    mergeMsg = (a, b) => math.min(a, b))

该Pregel对象底层主要就是对一系列的三元组边的源顶点和目标顶点做比较,将两顶点最小的顶点值发送给该条边最大的顶点,最大的顶点收到消息后,会比较当前属性与收到的最小顶点值比较,然后保留最小值。这样,每一轮迭代,可能关联的属性值都会一直变化,不断保留历史最小顶点值,直到迭代完成。最后,就可以实现通过connectedComponents得到每个顶点都关联到最小顶点的数据。

三、Pregel源码解析

Pregel是一个图处理模型和计算框架,核心思想是将一系列顶点之间的消息做传递和状态更新操作,并以迭代的方式进行计算。让我们继续深入看一下它的底层实现。

以下是保留主要核心代码的函数——

def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
   (graph: Graph[VD, ED],
    initialMsg: A,
    maxIterations: Int = Int.MaxValue,
    activeDirection: EdgeDirection = EdgeDirection.Either)
   (vprog: (VertexId, VD, A) => VD,
    sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    mergeMsg: (A, A) => A)
  : Graph[VD, ED] =
{
  ......
  //step1
  var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
  ......
  //step2
  var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
  ......
  //step3
  var activeMessages = messages.count()
  var prevG: Graph[VD, ED] = null
  var i = 0
  //step4
  while (activeMessages > 0 && i < maxIterations) {
    prevG = g
    g = g.joinVertices(messages)(vprog)
    val oldMessages = messages
    messages = GraphXUtils.mapReduceTriplets(
      g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
    activeMessages = messages.count()
    i += 1
  }

  g
}

这行 var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))代码,需要联系到前面传过来的参数,它的真实面目其实是这样的——

var g = graph.mapVertices((vid, vdata) => {
      (id, attr, initialMsg) => math.min(attr, initialMsg)
})

也就是前面step3里提到的,这里相当做了0次迭代,将attr当作顶点id关联的最小顶点,初始化后,attr其实是顶点id本身。

var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)这行代码中,主要定义了一个函数sendMsg和调用了aggregateMessagesWithActiveSet方法。

private[graphx] def mapReduceTriplets[VD: ClassTag, ED: ClassTag, A: ClassTag](
    g: Graph[VD, ED],
    mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    reduceFunc: (A, A) => A,
    activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
  def sendMsg(ctx: EdgeContext[VD, ED, A]) {
    mapFunc(ctx.toEdgeTriplet).foreach { kv =>
      val id = kv._1
      val msg = kv._2
      if (id == ctx.srcId) {
        ctx.sendToSrc(msg)
      } else {
        assert(id == ctx.dstId)
        ctx.sendToDst(msg)
      }
    }
  }
  g.aggregateMessagesWithActiveSet(
    sendMsg, reduceFunc, TripletFields.All, activeSetOpt)
}

函数sendMsg里需要看懂一点是,这里的mapFunc(ctx.toEdgeTriplet)正是调用了前面定义的ConnectedComponents里的sendMessage方法,因此,这个方法恢复原样,是这样的——

    def sendMsg(ctx: EdgeContext[VD, ED, A]) {
      (ctx.toEdgeTriplet => {
        case edge =>
        if (edge.srcAttr < edge.dstAttr) {
          Iterator((edge.dstId, edge.srcAttr))
        } else if (edge.srcAttr > edge.dstAttr) {
          Iterator((edge.srcId, edge.dstAttr))
        } else {
          Iterator.empty
        }
      }).foreach { kv =>
        val id = kv._1
        val msg = kv._2
        if (id == ctx.srcId) {
          ctx.sendToSrc(msg)
        } else {
          assert(id == ctx.dstId)
          ctx.sendToDst(msg)
        }
      }
    }

这个方法的作用,就是找出同一条边上哪个顶点最小,例如下图中,2L比3L小,那么2L是这条边上最小的顶点,将以最大点关联最小点的方式(edge.dstId, edge.srcAttr)即(3L,2L)保存下来。最后会将(3L,2L)中的_.2也就是2L发送给顶点(3L,3L),而顶点(3L,3L)后续需要做的事情是,是将这一轮收到的消息即最小顶点2L与现在的属性3L值通过math.min(a, b)做比较,保留最小顶点当作属性值,即变成了(3L,2L)。

剩下aggregateMessagesWithActiveSet就是做聚合了,sendMsg就是上面的获取最小顶点后发送给顶点的操作,reduceFunc对应的是 mergeMsg = (a, b) => math.min(a, b)),保留历史最小顶点当作该顶点属性。

g.aggregateMessagesWithActiveSet(
  sendMsg, reduceFunc, TripletFields.All, activeSetOpt)

最后这个while遍历,如果设置了迭代次数,迭代次数就会传至给maxIterations,activeMessages表示还有多少顶点需要处理。

  while (activeMessages > 0 && i < maxIterations) {
    prevG = g
    g = g.joinVertices(messages)(vprog)
    val oldMessages = messages
    messages = GraphXUtils.mapReduceTriplets(
      g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
    activeMessages = messages.count()
    i += 1
  }

这个方法,就是不断做迭代,不断更新各个顶点属性对应的最小顶点,直到迭代出子图里的最小顶点。

很精妙的一点设计是,每个顶点只需要不断迭代,以三元组边为维度,互相将最小顶点发送给属性值(顶点保留的上一轮最小顶点所做的属性)较大的顶点,顶点只需要保留收到的消息里最小的顶点更新为属性值即可。

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器