在之前的文章中,我介绍了flink广播状态,从而了解了flink广播状态实际上就是将一个流广播到下游所有算子之中。在本文中我将介绍spark中类似的概念,为了方便理解,先放张spark应用程序架构图。
实际上,如果我们在main函数中定义了一个变量,并且在rdd算子中使用到了这个变量,那么spark已经将这个变量广播到了每个task之中。看一下下列代码:
def main(args: Array[String]): Unit = {val conf = new SparkConf().setAppName("broadcast-learn").setMaster("local[4]")val sc = new SparkContext(conf)val userActions: RDD[(Int, String)] = sc.makeRDD(List((1, "LOGIN"), (2, "BUY"), (4, "LOGIN"), (5, "LOGOUT"), (1, "BUY")), 3)// driver运行到这一行时,会在自己的内存创建keyUsers这个对象val keyUsers: List[Int] = List(1, 2, 3)val resultRDD = userActions.map((userAction) => {if (keyUsers.contains(userAction._1)) {print(userAction._1 + ", " + userAction._2+"\n")userAction}})resultRDD.collect()
}
在这段代码里面,我们使用keyUsers
这个变量对userActions
这个RDD进行过滤;因为RDD时分散在多个task中的,因此keyUsers
变量会被driver广播到这个stage上的多个task上。实际上就是完成了变量的广播。
上面的代码已经完成了变量的广播,那为什么还需要广播变量呢?因为在上面的代码中具有一个痛点:假设说一个executor运行了10个task,那么这个executor上每个task都需要从driver那里拉取一次keyUsers
,这造成了很大的网络负载,而且这个内存负载完全是可以避免的,因为这十个task使用的是完全相同的变量!所以我们为什么不让每个executor去driver拉取一次,然后让运行在其上的所有task去使用这个变量呢?(附加内容:executor实际上就是一个jvm进程,而task就是运行在executor上面的线程;因此广播变量的实现可以是executor启动一个线程去拉取广播变量,存放在堆内存中;因为所有线程共享堆内存,因此直接从堆内存中读变量值就好了。)
以上的想法就是spark共享变量的本质:
def main(args: Array[String]): Unit = {val conf = new SparkConf().setAppName("broadcast-learn").setMaster("local[4]")val sc = new SparkContext(conf)val userActions: RDD[(Int, String)] = sc.makeRDD(List((1, "LOGIN"), (2, "BUY"), (4, "LOGIN"), (5, "LOGOUT"), (1, "BUY")), 3)// driver运行到这一行时,会在自己的内存创建keyUsers这个对象val keyUsers = sc.broadcast(List(1, 2, 3))val resultRDD = userActions.map((userAction) => {// 去除共享变量中的值val keyUsersValue = keyUsers.valueif (keyUsersValue.contains(userAction._1)) {print(userAction._1 + ", " + userAction._2+"\n")userAction}})resultRDD.collect()
}
上面的累加变量实现了一件事情:把变量向所有task广播转换为向所有executor广播,大大降低了网络开销。广播变量是每个executor存放一份,而累加器是全局存放一份(存放在driver上),非常适合进行全局计数等场景。
def main(args: Array[String]): Unit = {val conf = new SparkConf().setAppName("broadcast-learn").setMaster("local[4]")val sc = new SparkContext(conf)val userActions: RDD[(Int, String)] = sc.makeRDD(List((1, "LOGIN"), (2, "BUY"), (4, "LOGIN"), (5, "LOGOUT"), (1, "BUY")), 3)// driver运行到这一行时,会在自己的内存创建keyUsers这个对象val userActionCount = sc.longAccumulator("user action count")val resultRDD = userActions.map((_) => {userActionCount.add(1)})resultRDD.collect()println(userActionCount.value)
}