Spark详解(08) - Spark(3.0)内核解析和源码欣赏
阅读原文时间:2023年07月09日阅读:8

**Spark详解(08) - Spark(3.0)内核解析和源码欣赏
**

源码全流程

HashShuffle流程

优化后的HashShuffle流程

假设前提:每个Executor只有1个CPU core,也就是说,无论这个Executor上分配多少个task线程,同一时间都只能执行一个task线程

SortShuffle流程

bypassShuffle流程

环境准备及提交流程

1)spark-3.0.0-bin-hadoop3.2\bin\spark-submit.cmd => cmd /V /E /C ""%~dp0spark-submit2.cmd" %*"

2)spark-submit2.cmd => set CLASS=org.apache.spark.deploy.SparkSubmit "%~dp0spark-class2.cmd" %CLASS% %*

3)spark-class2.cmd => %SPARK_CMD%

4)在spark-class2.cmd文件中增加打印%SPARK_CMD%语句

echo %SPARK_CMD%

%SPARK_CMD%

5)在spark-3.0.0-bin-hadoop3.2\bin目录上执行cmd命令

6)进入命令行窗口,输入

spark-submit --class org.apache.spark.examples.SparkPi --master local[2] ../examples/jars/spark-examples_2.12-3.0.0.jar 10

7)发现底层执行的命令为

java -cp org.apache.spark.deploy.SparkSubmit

说明:java -cp和 -classpath一样,是指定类运行所依赖其他类的路径。

8)执行java -cp 就会开启JVM虚拟机,在虚拟机上开启SparkSubmit进程,然后开始执行main方法

java -cp =》开启JVM虚拟机 =》开启Process(SparkSubmit)=》程序入口SparkSubmit.main

9)在IDEA中全局查找(ctrl + n):org.apache.spark.deploy.SparkSubmit,找到SparkSubmit的伴生对象,并找到main方法

  1. override def main(args: Array[String]): Unit = {

  2.     val submit = new SparkSubmit() {

  3.         … …

  4.     }

  5. }

程序入口

SparkSubmit.scala

  1. override def main(args: Array[String]): Unit = {

  2.     val submit = new SparkSubmit() {

  3.     … …

  4.         override def doSubmit(args: Array[String]): Unit = {

  5.           super.doSubmit(args)

  6.         }

  7.     }

  8.     submit.doSubmit(args)

  9. }

def doSubmit()方法

  1. def doSubmit(args: Array[String]): Unit = {

  2.     val uninitLog = initializeLogIfNecessary(true, silent = true)

  3.     // 解析参数

  4.     val appArgs = parseArguments(args)

  5.     … …

  6.     appArgs.action match {

  7.         // 提交作业

  8.         case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)

  9.         case SparkSubmitAction.KILL => kill(appArgs)

  10.         case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)

  11.         case SparkSubmitAction.PRINT_VERSION => printVersion()   

  12.     }

  13. }

解析输入参数

  1. protected def parseArguments(args: Array[String]): SparkSubmitArguments = {

  2.     new SparkSubmitArguments(args)

  3. }

SparkSubmitArguments.scala

  1. private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env)

  2.     extends SparkSubmitArgumentsParser with Logging {

  3.     … …

  4.     parse(args.asJava)

  5.     … …

  6. }

SparkSubmitOptionParser.java

  1. protected final void parse(List args) {

  2.     Pattern eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)");

  3.     
     

  4.     int idx = 0;

  5.     for (idx = 0; idx < args.size(); idx++) {

  6.         String arg = args.get(idx);

  7.         String value = null;

  8.     
     

  9.         Matcher m = eqSeparatedOpt.matcher(arg);

  10.         if (m.matches()) {

  11.             arg = m.group(1);

  12.             value = m.group(2);

  13.         }

  14.     
     

  15.         String name = findCliOption(arg, opts);

  16.         if (name != null) {

  17.             if (value == null) {

  18.                 … …

  19.             }

  20.             // handle_的实现类(_ctrl + h)是SparkSubmitArguments.scala中

  21.             if (!handle(name, value)) {

  22.                 break;

  23.             }

  24.             continue;

  25.         }

  26.         … …

  27.     }

  28.     handleExtraArgs(args.subList(idx, args.size()));

  29. }

SparkSubmitArguments.scala

  1. override protected def handle(opt: String, value: String): Boolean = {

  2.     opt match {

  3.         case NAME =>

  4.             name = value

  5.         // protected final String MASTER = "--master";  SparkSubmitOptionParser.java

  6.         case MASTER =>

  7.             master = value

  8.         
     

  9.         case CLASS =>

  10.             mainClass = value

  11.         … …

  12.         case _ =>

  13.             error(s"Unexpected argument '$opt'.")

  14.     }

  15.     action != SparkSubmitAction.PRINT_VERSION

  16. }

  17. private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env)

  18.   extends SparkSubmitArgumentsParser with Logging {

  19.     … …

  20.     var action: SparkSubmitAction = null

  21.     … …

  22.     
     

  23.     private def loadEnvironmentArguments(): Unit = {

  24.         … …

  25.         // Action should be SUBMIT unless otherwise specified

  26.         // action_默认赋值_submit

  27.         action = Option(action).getOrElse(SUBMIT)

  28.     }

  29.     … …

  30. }

选择创建哪种类型的客户端

SparkSubmit.scala

  1. private[spark] class SparkSubmit extends Logging {

  2.     … …

  3.     def doSubmit(args: Array[String]): Unit = {

  4.         val uninitLog = initializeLogIfNecessary(true, silent = true)

  5.         // 解析参数

  6.         val appArgs = parseArguments(args)

  7.         if (appArgs.verbose) {

  8.           logInfo(appArgs.toString)

  9.         }

  10.         appArgs.action match {

  11.           // 提交作业

  12.           case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)

  13.           case SparkSubmitAction.KILL => kill(appArgs)

  14.           case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)

  15.           case SparkSubmitAction.PRINT_VERSION => printVersion()

  16.         }

  17.     }

  18.     
     

  19.     private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {

  20.     
     

  21.         def doRunMain(): Unit = {

  22.             if (args.proxyUser != null) {

  23.                 … …

  24.             } else {

  25.                 runMain(args, uninitLog)

  26.             }

  27.         }

  28.         if (args.isStandaloneCluster && args.useRest) {

  29.             … …

  30.         } else {

  31.             doRunMain()

  32.         }

  33.     }    

  34.     
     

  35.     private def runMain(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {

  36.         // 选择创建什么应用:__YarnClusterApplication

  37.         val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)

  38.         … …

  39.         var mainClass: Class[_] = null

  40.         
     

  41.         try {

  42.             mainClass = Utils.classForName(childMainClass)

  43.         } catch {

  44.             … …

  45.         }

  46.         // 反射创建应用:__YarnClusterApplication

  47.         val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) {

  48.             mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication]

  49.         } else {

  50.             new JavaMainApplication(mainClass)

  51.         }

  52.         … …

  53.         try {

  54.             //_启动应用_

  55.             app.start(childArgs.toArray, sparkConf)

  56.         } catch {

  57.         case t: Throwable =>

  58.             throw findCause(t)

  59.         }

  60.     }

  61.     … … 

  62. }

SparkSubmit.scala

  1. private[deploy] def prepareSubmitEnvironment(

  2.       args: SparkSubmitArguments,

  3.       conf: Option[HadoopConfiguration] = None)

  4.       : (Seq[String], Seq[String], SparkConf, String) = {

  5.     var childMainClass = ""

  6.     … …

  7.     // yarn集群模式

  8.     if (isYarnCluster) {

  9. // YARN_CLUSTER_SUBMIT_CLASS="org.apache.spark.deploy.yarn.YarnClusterApplication"

  10.         childMainClass = YARN_CLUSTER_SUBMIT_CLASS

  11.         … …

  12.     }

  13.     … …

  14.     (childArgs, childClasspath, sparkConf, childMainClass)

  15. }

Yarn客户端参数解析

**1)在pom.xml文件中添加依赖spark-yarn
**

  1.     org.apache.spark

  2.     spark-yarn_2.12

  3.     3.0.0

2)在IDEA中全文查找(ctrl+n)org.apache.spark.deploy.yarn.YarnClusterApplication

**3)Yarn客户端参数解析
**

Client.scala

  1. private[spark] class YarnClusterApplication extends SparkApplication {

  2.   override def start(args: Array[String], conf: SparkConf): Unit = {

  3.     … …

  4.     new Client(new ClientArguments(args), conf, null).run()

  5.   }

  6. }

ClientArguments.scala

  1. private[spark] class ClientArguments(args: Array[String]) {

  2.     … …

  3.     parseArgs(args.toList)

  4.     
     

  5.     private def parseArgs(inputArgs: List[String]): Unit = {

  6.         var args = inputArgs

  7.         while (!args.isEmpty) {

  8.             args match {

  9.                 case ("--jar") :: value :: tail =>

  10.                 userJar = value

  11.                 args = tail

  12.         
     

  13.                 case ("--class") :: value :: tail =>

  14.                 userClass = value

  15.                 args = tail

  16.                 … …

  17.                 case _ =>

  18.                 throw new IllegalArgumentException(getUsageMessage(args))

  19.             }

  20.         }

  21.     }

  22.     … …

  23. }

创建Yarn客户端

Client.scala

  1. private[spark] class Client(

  2.     val args: ClientArguments,

  3.     val sparkConf: SparkConf,

  4.     val rpcEnv: RpcEnv)

  5.     extends Logging {

  6.     // 创建__yarnClient

  7.     private val yarnClient = YarnClient.createYarnClient

  8.     … …

  9. }

YarnClient.java

  1. public abstract class YarnClient extends AbstractService {

  2.     @Public

  3.     public static YarnClient createYarnClient() {

  4.         YarnClient client = new YarnClientImpl();

  5.         return client;

  6.     }

  7.     … …

  8. }

YarnClientImpl.java

  1. public class YarnClientImpl extends YarnClient {

  2.     // yarnClient_主要用来和_RM通信

  3.     protected ApplicationClientProtocol rmClient;

  4.     … …

  5.     
     

  6.     public YarnClientImpl() {

  7.         super(YarnClientImpl.class.getName());

  8.     }

  9.     … …

  10. }

Yarn客户端创建并启动ApplicationMaster

Client.scala

  1. private[spark] class YarnClusterApplication extends SparkApplication {

  2.   override def start(args: Array[String], conf: SparkConf): Unit = {

  3.     // SparkSubmit would use yarn cache to distribute files & jars in yarn mode,

  4.     // so remove them from sparkConf here for yarn mode.

  5.     conf.remove(JARS)

  6.     conf.remove(FILES)

  7.     new Client(new ClientArguments(args), conf, null).run()

  8.   }

  9. }

  10. private[spark] class Client(

  11.     val args: ClientArguments,

  12.     val sparkConf: SparkConf,

  13.     val rpcEnv: RpcEnv)

  14.     extends Logging {

  15.     def run(): Unit = {

  16.         this.appId = submitApplication()

  17.         … …

  18.     }

  19.     
     

  20.     def submitApplication(): ApplicationId = {

  21.         var appId: ApplicationId = null

  22.         try {

  23.             launcherBackend.connect()

  24.             yarnClient.init(hadoopConf)

  25.             yarnClient.start()

  26.     
     

  27.             val newApp = yarnClient.createApplication()

  28.             val newAppResponse = newApp.getNewApplicationResponse()

  29.             appId = newAppResponse.getApplicationId()

  30.         
     

  31.             … …

  32.             // 封装提交参数和命令

  33.             val containerContext = createContainerLaunchContext(newAppResponse)

  34.             val appContext = createApplicationSubmissionContext(newApp, containerContext)

  35.         
     

  36.             yarnClient.submitApplication(appContext)

  37.             … …

  38.             appId

  39.         } catch {

  40.             … …

  41.         }

  42.     }

  43. }

  44. // 封装提交参数和命令

  45. private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse)

  46.     : ContainerLaunchContext = {

  47.     … …

  48.     val amClass =

  49.         // 如果是集群模式启动__ApplicationMaster,如果是客户端模式启动__ExecutorLauncher

  50.         if (isClusterMode) {

  51.             Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName

  52.         } else {

  53.             Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName

  54.         }

  55.         
     

  56.     val amArgs =

  57.       Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++

  58.       Seq("--properties-file",

  59.         buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) ++

  60.       Seq("--dist-cache-conf",

  61.         buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, DIST_CACHE_CONF_FILE))

  62.     // Command for the ApplicationMaster

  63.     val commands = prefixEnv ++

  64.       Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++

  65.       javaOpts ++ amArgs ++

  66.       Seq(

  67.         "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",

  68.         "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")

  69.         
     

  70.     val printableCommands = commands.map(s => if (s == null) "null" else s).toList

  71.     amContainer.setCommands(printableCommands.asJava)

  72.     … …

  73.     val securityManager = new SecurityManager(sparkConf)

  74.     amContainer.setApplicationACLs(

  75.       YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)

  76.     setupSecurityToken(amContainer)

  77.     amContainer

  78. }

1)在IDEA中全局查找(ctrl + n)org.apache.spark.deploy.yarn.ApplicationMaster**,点击对应的伴生对象
**

ApplicationMaster.scala

  1. def main(args: Array[String]): Unit = {

  2.    
     

  3.     // 1_解析传递过来的参数_

  4.     val amArgs = new ApplicationMasterArguments(args)

  5.     val sparkConf = new SparkConf()

  6.     … …

  7.     val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))

  8.     // 2_创建_ApplicationMaster对象

  9.     master = new ApplicationMaster(amArgs, sparkConf, yarnConf)

  10.     … …

  11.     ugi.doAs(new PrivilegedExceptionAction[Unit]() {

  12.         // 3_执行_ApplicationMaster

  13.         override def run(): Unit = System.exit(master.run())

  14.     })

  15. }

解析传递过来的参数

ApplicationMasterArguments.scala

  1. class ApplicationMasterArguments(val args: Array[String]) {

  2.     … …

  3.     parseArgs(args.toList)

  4.     
     

  5.     private def parseArgs(inputArgs: List[String]): Unit = {

  6.         val userArgsBuffer = new ArrayBuffer[String]()

  7.         var args = inputArgs

  8.     
     

  9.         while (!args.isEmpty) {

  10.             args match {

  11.                 case ("--jar") :: value :: tail =>

  12.                 userJar = value

  13.                 args = tail

  14.         
     

  15.                 case ("--class") :: value :: tail =>

  16.                 userClass = value

  17.                 args = tail

  18.                 … …

  19.         
     

  20.                 case _ =>

  21.                 printUsageAndExit(1, args)

  22.             }

  23.         }

  24.         … …

  25.     }

  26.     … …

  27. }}

创建RMClient并启动Driver

ApplicationMaster.scala

  1. private[spark] class ApplicationMaster(

  2.     args: ApplicationMasterArguments,

  3.     sparkConf: SparkConf,

  4.     yarnConf: YarnConfiguration) extends Logging {

  5.     … …

  6.     // 1_创建_RMClient

  7.     private val client = new YarnRMClient()

  8.     … …

  9.     final def run(): Int = {

  10.         … …

  11.         if (isClusterMode) {

  12.             runDriver()

  13.         } else {

  14.             runExecutorLauncher()

  15.         }

  16.         … …

  17.     }

  18.     private def runDriver(): Unit = {

  19.         addAmIpFilter(None, System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV))

  20.         // 2_根据输入参数启动_Driver

  21.         userClassThread = startUserApplication()

  22.         val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)

  23.         
     

  24.         try {

  25.             // 3_等待初始化完毕_

  26.             val sc = ThreadUtils.awaitResult(sparkContextPromise.future,

  27.                 Duration(totalWaitTime, TimeUnit.MILLISECONDS))

  28.            // sparkcontext_初始化完毕_

  29.             if (sc != null) {

  30.                 val rpcEnv = sc.env.rpcEnv

  31.                 val userConf = sc.getConf

  32.                 val host = userConf.get(DRIVER_HOST_ADDRESS)

  33.                 val port = userConf.get(DRIVER_PORT)

  34.                 // 4 向__RM注册自己(AM)

  35.                 registerAM(host, port, userConf, sc.ui.map(_.webUrl), appAttemptId)

  36.                 val driverRef = rpcEnv.setupEndpointRef(

  37.                 RpcAddress(host, port),

  38.                 YarnSchedulerBackend.ENDPOINT_NAME)

  39.                 // 5_获取_RM返回的可用资源列表

  40.                 createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)

  41.             } else {

  42.                 … …

  43.             }

  44.             resumeDriver()

  45.             userClassThread.join()

  46.         } catch {

  47.             … …

  48.         } finally {

  49.             resumeDriver()

  50.         }

  51.     }

ApplicationMaster.scala

  1. private def startUserApplication(): Thread = {

  2. … …

  3. // args.userClass来源于__ApplicationMasterArguments.scala

  4.     val mainMethod = userClassLoader.loadClass(args.userClass)

  5.     .getMethod("main", classOf[Array[String]])

  6.     … …

  7.     val userThread = new Thread {

  8.         override def run(): Unit = {

  9.             … …

  10.             if (!Modifier.isStatic(mainMethod.getModifiers)) {

  11.                 logError(s"Could not find static main method in object ${args.userClass}")

  12.                 finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS)

  13.             } else {

  14.                 mainMethod.invoke(null, userArgs.toArray)

  15.                 finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)

  16.                 logDebug("Done running user class")

  17.             }  

  18.             … …  

  19.         }  

  20.     }

  21.     userThread.setContextClassLoader(userClassLoader)

  22.     userThread.setName("Driver")

  23.     userThread.start()

  24.     userThread

  25. }

向RM注册AM

  1. private def registerAM(

  2.     host: String,

  3.     port: Int,

  4.     _sparkConf: SparkConf,

  5.     uiAddress: Option[String],

  6.     appAttempt: ApplicationAttemptId): Unit = {

  7. … …

  8.     client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress)

  9.     registered = true

  10. }

获取RM返回可以资源列表

ApplicationMaster.scala

  1. private def createAllocator(

  2.     driverRef: RpcEndpointRef,

  3.     _sparkConf: SparkConf,

  4.     rpcEnv: RpcEnv,

  5.     appAttemptId: ApplicationAttemptId,

  6.     distCacheConf: SparkConf): Unit = {

  7.     
     

  8.     … …

  9.     // 申请资源 获得资源

  10.     allocator = client.createAllocator(

  11.     yarnConf,

  12.     _sparkConf,

  13.     appAttemptId,

  14.     driverUrl,

  15.     driverRef,

  16.     securityMgr,

  17.     localResources)

  18.     … …

  19.     // 处理资源结果,启动Executor

  20.     allocator.allocateResources()

  21.     … …

  22. }

YarnAllocator.scala

  1. def allocateResources(): Unit = synchronized {

  2.     val progressIndicator = 0.1f

  3.     val allocateResponse = amClient.allocate(progressIndicator)

  4.     // 获取可分配资源

  5.     val allocatedContainers = allocateResponse.getAllocatedContainers()

  6.     allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes)

  7.     // 可分配的资源大于__0

  8.     if (allocatedContainers.size > 0) {

  9.         ……

  10.         // 分配规则

  11.         handleAllocatedContainers(allocatedContainers.asScala)

  12.     }

  13.     … …

  14. }

  15. def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = {

  16.     val containersToUse = new ArrayBuffer[Container](allocatedContainers.size)

  17.     // 分配在同一台主机上资源

  18.     val remainingAfterHostMatches = new ArrayBuffer[Container]

  19.     for (allocatedContainer <- allocatedContainers) {

  20.         … …

  21.     }

  22.     // 分配同一个机架上资源

  23.     val remainingAfterRackMatches = new ArrayBuffer[Container]

  24.     if (remainingAfterHostMatches.nonEmpty) {

  25.         … …

  26.     }

  27.     // 分配既不是本地节点也不是机架本地的剩余部分

  28.     val remainingAfterOffRackMatches = new ArrayBuffer[Container]

  29.     for (allocatedContainer <- remainingAfterRackMatches) {

  30.         … …

  31. }

  32.     // 运行已分配容器

  33.     runAllocatedContainers(containersToUse)

  34. }

根据可用资源创建NMClient

YarnAllocator.scala

  1. private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = {

  2.     for (container <- containersToUse) {

  3.         … …

  4.         if (runningExecutors.size() < targetNumExecutors) {

  5.             numExecutorsStarting.incrementAndGet()

  6.             if (launchContainers) {

  7.                 launcherPool.execute(() => {

  8.                     try {

  9.                         new ExecutorRunnable(

  10.                             … …

  11.                         ).run()

  12.                         updateInternalState()

  13.                     } catch {

  14.                         … …

  15.                     }

  16.                 })

  17.             } else {

  18.                 // For test only

  19.                 updateInternalState()

  20.             }

  21.         } else {

  22.             … …

  23.         }

  24.     }

  25. }

ExecutorRunnable.scala

  1. private[yarn] class ExecutorRunnable(… …) extends Logging {

  2.     var rpc: YarnRPC = YarnRPC.create(conf)

  3.     var nmClient: NMClient = _

  4.     
     

  5.     def run(): Unit = {

  6.         logDebug("Starting Executor Container")

  7.         nmClient = NMClient.createNMClient()

  8.         nmClient.init(conf)

  9.         nmClient.start()

  10.         startContainer()

  11.     }

  12.     … …

  13.     def startContainer(): java.util.Map[String, ByteBuffer] = {

  14.         … …

  15.         // 准备命令,封装到__ctx环境中

  16.         val commands = prepareCommand()

  17.         ctx.setCommands(commands.asJava)

  18.         … …

  19.         // 向指定的__NM启动容器对象

  20.         try {

  21.             nmClient.startContainer(container.get, ctx)

  22.         } catch {

  23.             … …

  24.         }

  25.     }

  26.     private def prepareCommand(): List[String] = {

  27.         … …

  28.         YarnSparkHadoopUtil.addOutOfMemoryErrorArgument(javaOpts)

  29.         val commands = prefixEnv ++

  30.         Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++

  31.         javaOpts ++

  32.         Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend",

  33.             "--driver-url", masterAddress,

  34.             "--executor-id", executorId,

  35.             "--hostname", hostname,

  36.             "--cores", executorCores.toString,

  37.             "--app-id", appId,

  38.             "--resourceProfileId", resourceProfileId.toString) ++

  39.         … …

  40.     }

  41. }

Spark中通信框架的发展

  1. Spark早期版本中采用Akka作为内部通信部件。

  2. Spark1.3中引入Netty通信框架,为了解决Shuffle的大数据传输问题使用

  3. Spark1.6中Akka和Netty可以配置使用。Netty完全实现了Akka在Spark中的功能。

  4. Spark2.x系列中,Spark抛弃Akka,使用Netty。

那么Netty为什么可以取代Akka?

首先不容置疑的是Akka可以做到的,Netty也可以做到,但是Netty可以做到,Akka却无法做到,原因是什么?

在软件栈中,Akka相比Netty要高级一点,它专门针对RPC做了很多事情,而Netty相比更加基础一点,可以为不同的应用层通信协议(RPC,FTP,HTTP等)提供支持,在早期的Akka版本,底层的NIO通信就是用的Netty;其次一个优雅的工程师是不会允许一个系统中容纳两套通信框架,恶心!最后,虽然Netty没有Akka协程级的性能优势,但是Netty内部高效的Reactor线程模型,无锁化的串行设计,高效的序列化,零拷贝,内存池等特性也保证了Netty不会存在性能问题。

Endpoint有1个InBox和N个OutBox(N>=1,N取决于当前Endpoint与多少其他的Endpoint进行通信,一个与其通讯的其他Endpoint对应一个OutBox),Endpoint接收到的消息被写入InBox,发送出去的消息写入OutBox并被发送到其他Endpoint的InBox中。

三种通信方式 BIO NIO AIO

1)三种通信模式

BIO:阻塞式IO

NIO:非阻塞式IO

AIO:异步非阻塞式IO

Spark底层采用Netty

Netty:支持NIO和Epoll模式

默认采用NIO

2)举例说明:

比如去饭店吃饭,老板说你前面有4个人,需要等一会:

(1)那你在桌子前一直等着,就是阻塞式IO——BIO。

(2)如果你和老板说,饭先做着,我先去打会篮球。在打篮球的过程中你时不时的回来看一下饭是否做好,就是非阻塞式IO——NIO。

(3)先给老板说,我去打篮球,一个小时后给我送到指定位置,就是异步非阻塞式——AIO。

3)注意:

Linux对AIO支持的不够好,Windows支持AIO很好

Linux采用Epoll方式模仿AIO操作

Spark底层通信原理

  1. RpcEndpoint:RPC通信终端。Spark针对每个节点(Client/Master/Worker)都称之为一个RPC终端,且都实现RpcEndpoint接口,内部根据不同端点的需求,设计不同的消息和不同的业务处理,如果需要发送(询问)则调用Dispatcher。在Spark中,所有的终端都存在生命周期:

  2. RpcEnv:RPC上下文环境,每个RPC终端运行时依赖的上下文环境称为RpcEnv;在当前Spark版本中使用的NettyRpcEnv

  3. Dispatcher:消息调度(分发)器,针对于RPC终端需要发送远程消息或者从远程RPC接收到的消息,分发至对应的指令收件箱(发件箱)。如果指令接收方是自己则存入收件箱,如果指令接收方不是自己,则放入发件箱;

  4. Inbox:指令消息收件箱。一个本地RpcEndpoint对应一个收件箱,Dispatcher在每次向Inbox存入消息时,都将对应EndpointData加入内部ReceiverQueue中,另外Dispatcher创建时会启动一个单独线程进行轮询ReceiverQueue,进行收件箱消息消费;

  5. RpcEndpointRef:RpcEndpointRef是对远程RpcEndpoint的一个引用。当我们需要向一个具体的RpcEndpoint发送消息时,一般我们需要获取到该RpcEndpoint的引用,然后通过该应用发送消息。

  6. OutBox:指令消息发件箱。对于当前RpcEndpoint来说,一个目标RpcEndpoint对应一个发件箱,如果向多个目标RpcEndpoint发送信息,则有多个OutBox。当消息放入Outbox后,紧接着通过TransportClient将消息发送出去。消息放入发件箱以及发送过程是在同一个线程中进行;

  7. RpcAddress:表示远程的RpcEndpointRef的地址,Host + Port。

  8. TransportClient:Netty通信客户端,一个OutBox对应一个TransportClient,TransportClient不断轮询OutBox,根据OutBox消息的receiver信息,请求对应的远程TransportServer;

  9. TransportServer:Netty通信服务端,一个RpcEndpoint对应一个TransportServer,接受远程消息后调用Dispatcher分发消息至对应收发件箱;

Executor通信终端

1)在IDEA中全局查找(ctrl + n)org.apache.spark.executor.YarnCoarseGrainedExecutorBackend**,点击对应的伴生对象
**

2)YarnCoarseGrainedExecutorBackend.scala 继承CoarseGrainedExecutorBackend继承RpcEndpoint

  1. // constructor -> onStart -> receive* -> onStop

  2. private[spark] trait RpcEndpoint {

  3.   val rpcEnv: RpcEnv

  4.   final def self: RpcEndpointRef = {

  5.     require(rpcEnv != null, "rpcEnv has not been initialized")

  6.     rpcEnv.endpointRef(this)

  7.   }

  8.   def receive: PartialFunction[Any, Unit] = {

  9.     case _ => throw new SparkException(self + " does not implement 'receive'")

  10.   }

  11.   
     

  12.   def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

  13.     case _ => context.sendFailure(new SparkException(self + " won't reply anything"))

  14.   }

  15.   def onStart(): Unit = {

  16.     // By default, do nothing.

  17.   }

  18.   def onStop(): Unit = {

  19.     // By default, do nothing.

  20.   }

  21. }

  22. private[spark] abstract class RpcEndpointRef(conf: SparkConf)

  23.   extends Serializable with Logging {

  24.   … …

  25.   def send(message: Any): Unit

  26.   def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]

  27.   … …

  28. }

Driver通信终端

ExecutorBackend发送向Driver发送请求后,Driver开始接收消息。全局查找(ctrl + n)SparkContext类

SparkContext.scala

  1. class SparkContext(config: SparkConf) extends Logging {

  2.     … …

  3.     private var _schedulerBackend: SchedulerBackend = _

  4.     … …

  5. }

点击SchedulerBackend进入SchedulerBackend.scala,查找实现类(ctrl+h),找到CoarseGrainedSchedulerBackend.scala,在该类内部创建DriverEndpoint对象。

  1. private[spark]

  2. class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)

  3.     extends ExecutorAllocationClient with SchedulerBackend with Logging {

  4.   
     

  5.     class DriverEndpoint extends IsolatedRpcEndpoint with Logging {

  6.         override def receive: PartialFunction[Any, Unit] = {

  7.             … …

  8.             // 接收注册成功后的消息

  9.             case LaunchedExecutor(executorId) =>

  10.             executorDataMap.get(executorId).foreach { data =>

  11.                 data.freeCores = data.totalCores

  12.             }

  13.             makeOffers(executorId)

  14.         }

  15.         
     

  16.         // 接收__ask消息,并回复

  17.         override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

  18.           case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls,

  19.               attributes, resources, resourceProfileId) =>

  20.             … …

  21.             context.reply(true)

  22.             … …

  23.         }

  24.         … …

  25.     }

  26.     
     

  27.     val driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint())

  28.     
     

  29.     protected def createDriverEndpoint(): DriverEndpoint = new DriverEndpoint()

  30. }

DriverEndpoint继承IsolatedRpcEndpoint继承RpcEndpoint

  1. // constructor -> onStart -> receive* -> onStop

  2. private[spark] trait RpcEndpoint {

  3.   val rpcEnv: RpcEnv

  4.   final def self: RpcEndpointRef = {

  5.     require(rpcEnv != null, "rpcEnv has not been initialized")

  6.     rpcEnv.endpointRef(this)

  7.   }

  8.   def receive: PartialFunction[Any, Unit] = {

  9.     case _ => throw new SparkException(self + " does not implement 'receive'")

  10.   }

  11.   
     

  12.   def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

  13.     case _ => context.sendFailure(new SparkException(self + " won't reply anything"))

  14.   }

  15.   def onStart(): Unit = {

  16.     // By default, do nothing.

  17.   }

  18.   def onStop(): Unit = {

  19.     // By default, do nothing.

  20.   }

  21. }

  22. private[spark] abstract class RpcEndpointRef(conf: SparkConf)

  23.   extends Serializable with Logging {

  24.   … …

  25.   def send(message: Any): Unit

  26.   def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]

  27.   … …

  28. }

创建RPC通信环境

1)在IDEA中全局查找(ctrl + n)org.apache.spark.executor.YarnCoarseGrainedExecutorBackend**,点击对应的伴生对象
**

2)运行CoarseGrainedExecutorBackend

YarnCoarseGrainedExecutorBackend.scala

  1. private[spark] object YarnCoarseGrainedExecutorBackend extends Logging {

  2.     def main(args: Array[String]): Unit = {

  3.         val createFn: (RpcEnv, CoarseGrainedExecutorBackend.Arguments, SparkEnv, ResourceProfile) =>

  4.         CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) =>

  5.         new YarnCoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId,

  6.             arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath, env,

  7.             arguments.resourcesFileOpt, resourceProfile)

  8.         }

  9.         val backendArgs = CoarseGrainedExecutorBackend.parseArguments(args,

  10.         this.getClass.getCanonicalName.stripSuffix("$"))

  11.         CoarseGrainedExecutorBackend.run(backendArgs, createFn)

  12.         System.exit(0)

  13.     }

  14. }

CoarseGrainedExecutorBackend.scala

  1. def run(

  2.     arguments: Arguments,

  3.     backendCreateFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) =>

  4.         CoarseGrainedExecutorBackend): Unit = {

  5.     SparkHadoopUtil.get.runAsSparkUser { () =>

  6.         // Bootstrap to fetch the driver's Spark properties.

  7.         val executorConf = new SparkConf

  8.         val fetcher = RpcEnv.create(

  9.             "driverPropsFetcher",

  10.             arguments.bindAddress,

  11.             arguments.hostname,

  12.             -1,

  13.             executorConf,

  14.             new SecurityManager(executorConf),

  15.             numUsableCores = 0,

  16.             clientMode = true)

  17.         … …

  18.         driverConf.set(EXECUTOR_ID, arguments.executorId)

  19.         val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.bindAddress,

  20.             arguments.hostname, arguments.cores, cfg.ioEncryptionKey, isLocal = false)

  21.         env.rpcEnv.setupEndpoint("Executor",

  22.             backendCreateFn(env.rpcEnv, arguments, env, cfg.resourceProfile))

  23.         arguments.workerUrl.foreach { url =>

  24.             env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))

  25.         }

  26.         env.rpcEnv.awaitTermination()

  27.     }

  28. }

3)点击create,进入RpcEnv.Scala

  1. def create(

  2.     name: String,

  3.     bindAddress: String,

  4.     advertiseAddress: String,

  5.     port: Int,

  6.     conf: SparkConf,

  7.     securityManager: SecurityManager,

  8.     numUsableCores: Int,

  9.     clientMode: Boolean): RpcEnv = {

  10.     val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,

  11.       numUsableCores, clientMode)

  12.     new NettyRpcEnvFactory().create(config)

  13. }

NettyRpcEnv.scala

  1. private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  2.     def create(config: RpcEnvConfig): RpcEnv = {

  3.         … …

  4.         val nettyEnv =

  5.             new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,

  6.                 config.securityManager, config.numUsableCores)

  7.         if (!config.clientMode) {

  8.             val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>

  9.                 nettyEnv.startServer(config.bindAddress, actualPort)

  10.                 (nettyEnv, nettyEnv.address.port)

  11.             }

  12.             try {

  13.                 Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1

  14.             } catch {

  15.                 case NonFatal(e) =>

  16.                 nettyEnv.shutdown()

  17.                 throw e

  18.             }

  19.         }

  20.         nettyEnv

  21.     }

  22. }

创建多个发件箱

NettyRpcEnv.scala

  1. NettyRpcEnv.scala 

  2. private[netty] class NettyRpcEnv(

  3.     val conf: SparkConf,

  4.     javaSerializerInstance: JavaSerializerInstance,

  5.     host: String,

  6.     securityManager: SecurityManager,

  7.     numUsableCores: Int) extends RpcEnv(conf) with Logging {

  8.     … …

  9.     private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()

  10.     … …

  11. }

启动TransportServer

NettyRpcEnv.scala

  1. def startServer(bindAddress: String, port: Int): Unit = {

  2.     … …

  3.     server = transportContext.createServer(bindAddress, port, bootstraps)

  4.     dispatcher.registerRpcEndpoint(

  5.         RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))

  6. }

TransportContext.scala

  1. public TransportServer createServer(

  2.     String host, int port, List bootstraps) {

  3.     return new TransportServer(this, host, port, rpcHandler, bootstraps);

  4. }

TransportServer.java

  1. public TransportServer(

  2.     TransportContext context,

  3.     String hostToBind,

  4.     int portToBind,

  5.     RpcHandler appRpcHandler,

  6.     List bootstraps) {

  7.     … …

  8.     init(hostToBind, portToBind);

  9.     … …

  10. }

  11. private void init(String hostToBind, int portToBind) {

  12.     // 默认是__NIO模式

  13.     IOMode ioMode = IOMode.valueOf(conf.ioMode());

  14.     
     

  15.     EventLoopGroup bossGroup = NettyUtils.createEventLoop(ioMode, 1,

  16.         conf.getModuleName() + "-boss");

  17.     EventLoopGroup workerGroup =  NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server");

  18.     
     

  19.     bootstrap = new ServerBootstrap()

  20.         .group(bossGroup, workerGroup)

  21.         .channel(NettyUtils.getServerChannelClass(ioMode))

  22.         .option(ChannelOption.ALLOCATOR, pooledAllocator)

  23.         .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)

  24.         .childOption(ChannelOption.ALLOCATOR, pooledAllocator);

  25.     … …

  26. }

NettyUtils.java

  1. public static Class getServerChannelClass(IOMode mode) {

  2.     switch (mode) {

  3.         case NIO:

  4.             return NioServerSocketChannel.class;

  5.         case EPOLL:

  6.             return EpollServerSocketChannel.class;

  7.         default:

  8.             throw new IllegalArgumentException("Unknown io mode: " + mode);

  9.     }

  10. }

注册通信终端RpcEndpoint

NettyRpcEnv.scala

  1. def startServer(bindAddress: String, port: Int): Unit = {

  2.     … …

  3.     server = transportContext.createServer(bindAddress, port, bootstraps)

  4.     dispatcher.registerRpcEndpoint(

  5.         RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))

  6. }

创建TransportClient

Dispatcher.scala

  1. def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {

  2.     … …

  3.     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)

  4.     … …

  5. }

  6. private[netty] class NettyRpcEndpointRef(… …) extends RpcEndpointRef(conf) {

  7.     … …

  8.     @transient @volatile var client: TransportClient = _

  9.     // 创建__TransportClient

  10.     private[netty] def createClient(address: RpcAddress): TransportClient = {

  11.       clientFactory.createClient(address.host, address.port)

  12.     }

  13.     
     

  14.     private val clientFactory = transportContext.createClientFactory(createClientBootstraps())

  15.     … …

  16. }

收发邮件箱

1)接收邮件箱1个

Dispatcher.scala

  1. def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {

  2.         … …

  3.         var messageLoop: MessageLoop = null

  4.         try {

  5.             messageLoop = endpoint match {

  6.             case e: IsolatedRpcEndpoint =>

  7.                 new DedicatedMessageLoop(name, e, this)

  8.             case _ =>

  9.                 sharedLoop.register(name, endpoint)

  10.                 sharedLoop

  11.             }

  12.             endpoints.put(name, messageLoop)

  13.         } catch {

  14.             … …

  15.         }

  16.     }

  17.     endpointRef

  18. }

DedicatedMessageLoop.scala

  1. private class DedicatedMessageLoop(

  2.     name: String,

  3.     endpoint: IsolatedRpcEndpoint,

  4.     dispatcher: Dispatcher)

  5.   extends MessageLoop(dispatcher) {

  6.     private val inbox = new Inbox(name, endpoint)

  7.     … …

  8. }

Inbox.scala

  1. private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)

  2.   extends Logging {

  3.     … …

  4.     inbox.synchronized {

  5.         messages.add(OnStart)

  6.     }

  7.     … …

  8. }

CoarseGrainedExecutorBackend.scala

  1. // RPC生命周期: constructor -> onStart -> receive* -> onStop

  2. private[spark] class CoarseGrainedExecutorBackend(… …)

  3.   extends IsolatedRpcEndpoint with ExecutorBackend with Logging {

  4.     … …

  5.     override def onStart(): Unit = {

  6.         … …

  7.         rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>

  8.         // This is a very fast action so we can use "ThreadUtils.sameThread"

  9.         driver = Some(ref)

  10. // 1向Driver注册自己

  11.         ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls, extractAttributes, _resources, resourceProfile.id))

  12.         }(ThreadUtils.sameThread).onComplete {

  13. // 2接收Driver返回成功的消息,并给自己发送注册成功消息

  14.         case Success(_) =>

  15.             self.send(RegisteredExecutor)

  16.         case Failure(e) =>

  17.             exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false)

  18.         }(ThreadUtils.sameThread)

  19.     }

  20.     … …

  21.     override def receive: PartialFunction[Any, Unit] = {

  22. // 3收到注册成功的消息后,创建Executor,并启动__Executor

  23.         case RegisteredExecutor =>

  24.         try {

  25.             // 创建__Executor

  26.             executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false, resources = _resources)

  27.             driver.get.send(LaunchedExecutor(executorId))

  28.         } catch {

  29.             case NonFatal(e) =>

  30.             exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)

  31.         }

  32.         … …

  33.     }

  34. }

ExecutorBackend发送向Driver发送请求后,Driver开始接收消息。全局查找(ctrl + n)SparkContext类

SparkContext.scala

  1. class SparkContext(config: SparkConf) extends Logging {

  2.     … …

  3.     private var _schedulerBackend: SchedulerBackend = _

  4.     … …

  5. }

点击SchedulerBackend进入SchedulerBackend.scala,查找实现类(ctrl+h),找到CoarseGrainedSchedulerBackend.scala

  1. private[spark]

  2. class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)

  3.     extends ExecutorAllocationClient with SchedulerBackend with Logging {

  4.   
     

  5.     class DriverEndpoint extends IsolatedRpcEndpoint with Logging {

  6.         override def receive: PartialFunction[Any, Unit] = {

  7.             … …

  8.             // 接收注册成功后的消息

  9.             case LaunchedExecutor(executorId) =>

  10.             executorDataMap.get(executorId).foreach { data =>

  11.                 data.freeCores = data.totalCores

  12.             }

  13.             makeOffers(executorId)

  14.         }

  15.         
     

  16.         // 接收__ask消息,并回复

  17.         override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

  18.           case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls,

  19.               attributes, resources, resourceProfileId) =>

  20.             … …

  21.             context.reply(true)

  22.             … …

  23.         }

  24.         … …

  25.     }

  26.     
     

  27.     val driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint())

  28.     
     

  29.     protected def createDriverEndpoint(): DriverEndpoint = new DriverEndpoint()

  30. }

SparkContext初始化完毕,通知执行后续代码

1)进入到ApplicationMaster

ApplicationMaster.scala

  1. private[spark] class ApplicationMaster(

  2.     args: ApplicationMasterArguments,

  3.     sparkConf: SparkConf,

  4.     yarnConf: YarnConfiguration) extends Logging {

  5.     
     

  6.     private def runDriver(): Unit = {

  7.         addAmIpFilter(None, System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV))

  8.         userClassThread = startUserApplication()

  9.     
     

  10.         val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)

  11.         try {

  12.             val sc = ThreadUtils.awaitResult(sparkContextPromise.future,

  13.                 Duration(totalWaitTime, TimeUnit.MILLISECONDS))

  14.             if (sc != null) {

  15.                 val rpcEnv = sc.env.rpcEnv

  16.         
     

  17.                 val userConf = sc.getConf

  18.                 val host = userConf.get(DRIVER_HOST_ADDRESS)

  19.                 val port = userConf.get(DRIVER_PORT)

  20.                 registerAM(host, port, userConf, sc.ui.map(_.webUrl), appAttemptId)

  21.         
     

  22.                 val driverRef = rpcEnv.setupEndpointRef(

  23.                 RpcAddress(host, port),

  24.                 YarnSchedulerBackend.ENDPOINT_NAME)

  25.                 createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)

  26.             } else {

  27.                 … …

  28.             }

  29.             // 执行程序

  30.             resumeDriver()

  31.             userClassThread.join()

  32.         } catch {

  33.             … …

  34.         } finally {

  35.             resumeDriver()

  36.         }

  37.     }

  38.     … …

  39.     private def resumeDriver(): Unit = {

  40.         sparkContextPromise.synchronized {

  41.             sparkContextPromise.notify()

  42.         }

  43.     }

  44. }

接收代码继续执行消息

在SparkContext.scala文件中查找_taskScheduler.postStartHook(),点击postStartHook,查找实现类(ctrl + h)

  1. private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnScheduler(sc) {

  2.     logInfo("Created YarnClusterScheduler")

  3.     
     

  4.     override def postStartHook(): Unit = {

  5.         ApplicationMaster.sparkContextInitialized(sc)

  6.         super.postStartHook()

  7.         logInfo("YarnClusterScheduler.postStartHook done")

  8.     }

  9. }

点击super.postStartHook()

TaskSchedulerImpl.scala

  1. override def postStartHook(): Unit = {

  2.     waitBackendReady()

  3. }

  4. private def waitBackendReady(): Unit = {

  5.     if (backend.isReady) {

  6.         return

  7.     }

  8.     while (!backend.isReady) { 

  9.         if (sc.stopped.get) {

  10.             throw new IllegalStateException("Spark context stopped while waiting for backend")

  11.         }

  12.         synchronized {

  13.             this.wait(100)

  14.         }

  15.     }

  16. }

任务的执行

任务切分和任务调度原理

Stage任务划分

Task任务调度执行

本地化调度

任务分配原则:根据每个Task的优先位置,确定Task的Locality(本地化)级别,本地化一共有五种,优先级由高到低顺序:

移动数据不如移动计算。

名称

解析

PROCESS_LOCAL

进程本地化,task和数据在同一个Executor中,性能最好。

NODE_LOCAL

节点本地化,task和数据在同一个节点中,但是task和数据不在同一个Executor中,数据需要在进程间进行传输。

RACK_LOCAL

机架本地化,task和数据在同一个机架的两个节点上,数据需要通过网络在节点之间进行传输。

NO_PREF

对于task来说,从哪里获取都一样,没有好坏之分。

ANY

task和数据可以在集群的任何地方,而且不在一个机架中,性能最差。

失败重试与黑名单机制

除了选择合适的Task调度运行外,还需要监控Task的执行状态,前面也提到,与外部打交道的是SchedulerBackend,Task被提交到Executor启动执行后,Executor会将执行状态上报给SchedulerBackend,SchedulerBackend则告诉TaskScheduler,TaskScheduler找到该Task对应的TaskSetManager,并通知到该TaskSetManager,这样TaskSetManager就知道Task的失败与成功状态,对于失败的Task,会记录它失败的次数,如果失败次数还没有超过最大重试次数,那么就把它放回待调度的Task池子中,否则整个Application失败。

在记录Task失败次数过程中,会记录它上一次失败所在的Executor Id和Host,这样下次再调度这个Task时,会使用黑名单机制,避免它被调度到上一次失败的节点上,起到一定的容错作用。黑名单记录Task上一次失败所在的Executor Id和Host,以及其对应的"拉黑"时间,"拉黑"时间是指这段时间内不要再往这个节点上调度这个Task了。

0)在WordCount程序中查看源码

  1. import org.apache.spark.rdd.RDD

  2. import org.apache.spark.{SparkConf, SparkContext}

  3. object WordCount {

  4.     def main(args: Array[String]): Unit = {

  5.         val conf: SparkConf = new SparkConf().setAppName("WC").setMaster("local[*]")

  6.         val sc: SparkContext = new SparkContext(conf)

  7.         // 2 读取数据 hello atguigu spark spark

  8.         val lineRDD: RDD[String] = sc.textFile("input")

  9.         // 3 一行变多行

  10.         val wordRDD: RDD[String] = lineRDD.flatMap((x: String) => x.split(" "))

  11.         // 4 变换结构 一行变一行

  12.         val wordToOneRDD: RDD[(String, Int)] = wordRDD.map((x: String) => (x, 1))

  13.         // 5  聚合key相同的单词

  14.         val wordToSumRDD: RDD[(String, Int)] = wordToOneRDD.reduceByKey((v1, v2) => v1 + v2)

  15.         // 6 收集打印

  16.         wordToSumRDD.collect().foreach(println)

  17.         //7 关闭资源

  18.         sc.stop()

  19.     }

  20. }

1)在WordCount代码中点击collect

RDD.scala

  1. def collect(): Array[T] = withScope {

  2.     val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)

  3.     Array.concat(results: _*)

  4. }

SparkContext.scala

  1. def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {

  2.     runJob(rdd, func, 0 until rdd.partitions.length)

  3. }

  4. def runJob[T, U: ClassTag](

  5.     rdd: RDD[T],

  6.     func: Iterator[T] => U,

  7.     partitions: Seq[Int]): Array[U] = {

  8.     val cleanedFunc = clean(func)

  9.     runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions)

  10. }

  11. def runJob[T, U: ClassTag](

  12.     rdd: RDD[T],

  13.     func: (TaskContext, Iterator[T]) => U,

  14.     partitions: Seq[Int]): Array[U] = {

  15.     val results = new Array[U](partitions.size)

  16.     runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)

  17.     results

  18. }

  19. def runJob[T, U: ClassTag](

  20.     rdd: RDD[T],

  21.     func: (TaskContext, Iterator[T]) => U,

  22.     partitions: Seq[Int],

  23.     resultHandler: (Int, U) => Unit): Unit = {

  24.     … …

  25.     dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)

  26.     … …

  27. }

DAGScheduler.scala

  1. def runJob[T, U](

  2.     rdd: RDD[T],

  3.     func: (TaskContext, Iterator[T]) => U,

  4.     partitions: Seq[Int],

  5.     callSite: CallSite,

  6.     resultHandler: (Int, U) => Unit,

  7.     properties: Properties): Unit = {  

  8.     … … 

  9.     val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)

  10.     … …

  11. }

  12. def submitJob[T, U](

  13.     rdd: RDD[T],

  14.     func: (TaskContext, Iterator[T]) => U,

  15.     partitions: Seq[Int],

  16.     callSite: CallSite,

  17.     resultHandler: (Int, U) => Unit,

  18.     properties: Properties): JobWaiter[U] = {

  19.     … …

  20.     val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)

  21.     eventProcessLoop.post(JobSubmitted(

  22.         jobId, rdd, func2, partitions.toArray, callSite, waiter,

  23.         Utils.cloneProperties(properties)))

  24.     waiter

  25. }

EventLoop.scala

  1. def post(event: E): Unit = {

  2.     if (!stopped.get) {

  3.         if (eventThread.isAlive) {

  4.             eventQueue.put(event)

  5.         } else {

  6.             … …

  7.         }

  8.     }

  9. }

  10. private[spark] val eventThread = new Thread(name) {

  11.     override def run(): Unit = {

  12.         while (!stopped.get) {

  13.             val event = eventQueue.take()

  14.             try {

  15.                 onReceive(event)

  16.             } catch {

  17.                 … …

  18.             }

  19.         }

  20.     }

  21. }

查找onReceive实现类(ctrl + h)

DAGScheduler.scala

  1. private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)

  2.     extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {

  3.     … …

  4.     override def onReceive(event: DAGSchedulerEvent): Unit = {

  5.         val timerContext = timer.time()

  6.         try {

  7.             doOnReceive(event)

  8.         } finally {

  9.             timerContext.stop()

  10.         }

  11.     }

  12.     
     

  13.     private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {

  14.         case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>

  15.         dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

  16.         … …

  17.     }

  18.     … …

  19.     private[scheduler] def handleJobSubmitted(jobId: Int,

  20.         finalRDD: RDD[_],

  21.         func: (TaskContext, Iterator[_]) => _,

  22.         partitions: Array[Int],

  23.         callSite: CallSite,

  24.         listener: JobListener,

  25.         properties: Properties): Unit = {

  26.           
     

  27.         var finalStage: ResultStage = null

  28.         finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)

  29.         … …

  30.     }

  31.     private def createResultStage(

  32.         rdd: RDD[_],

  33.         func: (TaskContext, Iterator[_]) => _,

  34.         partitions: Array[Int],

  35.         jobId: Int,

  36.         callSite: CallSite): ResultStage = {

  37.         … …

  38.         val parents = getOrCreateParentStages(rdd, jobId)

  39.         val id = nextStageId.getAndIncrement()

  40.         
     

  41.         val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)

  42.         stageIdToStage(id) = stage

  43.         updateJobIdStageIdMaps(jobId, stage)

  44.         stage

  45.     }

  46.     private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {

  47.         getShuffleDependencies(rdd).map { shuffleDep =>

  48.             getOrCreateShuffleMapStage(shuffleDep, firstJobId)

  49.         }.toList

  50.     }

  51.     private[scheduler] def getShuffleDependencies(

  52.         rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {

  53.         
     

  54.             val parents = new HashSet[ShuffleDependency[_, _, _]]

  55.             val visited = new HashSet[RDD[_]]

  56.             val waitingForVisit = new ListBuffer[RDD[_]]

  57.             waitingForVisit += rdd

  58.             while (waitingForVisit.nonEmpty) {

  59.             val toVisit = waitingForVisit.remove(0)

  60.             
     

  61.             if (!visited(toVisit)) {

  62.                 visited += toVisit

  63.                 toVisit.dependencies.foreach {

  64.                 case shuffleDep: ShuffleDependency[_, _, _] =>

  65.                     parents += shuffleDep

  66.                 case dependency =>

  67.                     waitingForVisit.prepend(dependency.rdd)

  68.                 }

  69.             }

  70.         }

  71.         parents

  72.     }

  73.     private def getOrCreateShuffleMapStage(

  74.         shuffleDep: ShuffleDependency[_, _, _],

  75.         firstJobId: Int): ShuffleMapStage = {

  76.         
     

  77.         shuffleIdToMapStage.get(shuffleDep.shuffleId) match {

  78.             case Some(stage) =>

  79.                 stage

  80.             
     

  81.             case None =>

  82.                 getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>

  83.                 
     

  84.                     if (!shuffleIdToMapStage.contains(dep.shuffleId)) {

  85.                         createShuffleMapStage(dep, firstJobId)

  86.                     }

  87.                 }

  88.                 // Finally, create a stage for the given shuffle dependency.

  89.                 createShuffleMapStage(shuffleDep, firstJobId)

  90.         }

  91.     }

  92.     def createShuffleMapStage[K, V, C](

  93.         shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = {

  94.         … …

  95.         val rdd = shuffleDep.rdd

  96.         val numTasks = rdd.partitions.length

  97.         val parents = getOrCreateParentStages(rdd, jobId)

  98.         val id = nextStageId.getAndIncrement()        

  99.         
     

  100.         val stage = new ShuffleMapStage(

  101.             id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)

  102.         … …      

  103.     }    

  104.     … …

  105. }

DAGScheduler.scala

  1. private[scheduler] def handleJobSubmitted(jobId: Int,

  2.     finalRDD: RDD[_],

  3.     func: (TaskContext, Iterator[_]) => _,

  4.     partitions: Array[Int],

  5.     callSite: CallSite,

  6.     listener: JobListener,

  7.     properties: Properties): Unit = {

  8.     
     

  9.     var finalStage: ResultStage = null

  10.     try {

  11.         finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)

  12.     } catch {

  13.         … …

  14.     }

  15.     
     

  16.     val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)    

  17.     … …

  18.     submitStage(finalStage)

  19. }

  20. private def submitStage(stage: Stage): Unit = {

  21.     val jobId = activeJobForStage(stage)

  22.     
     

  23.     if (jobId.isDefined) {        

  24.         if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {

  25.             val missing = getMissingParentStages(stage).sortBy(_.id)

  26.             if (missing.isEmpty) {

  27.                 submitMissingTasks(stage, jobId.get)

  28.             } else {

  29.                 for (parent <- missing) {

  30.                     submitStage(parent)

  31.                 }

  32.                 waitingStages += stage

  33.             }

  34.         }

  35.     } else {

  36.         abortStage(stage, "No active job for stage " + stage.id, None)

  37.     }

  38. }

  39. private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {

  40.     val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()

  41.     … …

  42.     val tasks: Seq[Task[_]] = try {

  43.         val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()

  44.         
     

  45.         stage match {

  46.         case stage: ShuffleMapStage =>        

  47.             stage.pendingPartitions.clear()

  48.             partitionsToCompute.map { id =>

  49.             val locs = taskIdToLocations(id)

  50.             val part = partitions(id)

  51.             stage.pendingPartitions += id            

  52.             new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,

  53.                 taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),

  54.                 Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())

  55.             }

  56.         case stage: ResultStage =>

  57.             partitionsToCompute.map { id =>

  58.             val p: Int = stage.partitions(id)

  59.             val part = partitions(p)

  60.             val locs = taskIdToLocations(id)

  61.             new ResultTask(stage.id, stage.latestInfo.attemptNumber,

  62.                 taskBinary, part, locs, id, properties, serializedTaskMetrics,

  63.                 Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,

  64.                 stage.rdd.isBarrier())

  65.             }

  66.         }

  67.     } catch {

  68.         … …

  69.     }

  70. }

Stage.scala

  1. private[scheduler] abstract class Stage(… …)

  2.     extends Logging {

  3.     … …

  4.     def findMissingPartitions(): Seq[Int]

  5.     … …

  6. }

全局查找(ctrl + h)findMissingPartitions实现类。

ShuffleMapStage.scala

  1. private[spark] class ShuffleMapStage(… …)

  2.     extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {

  3.     private[this] var _mapStageJobs: List[ActiveJob] = Nil

  4.     … …

  5.     override def findMissingPartitions(): Seq[Int] = {

  6.         mapOutputTrackerMaster

  7.         .findMissingPartitions(shuffleDep.shuffleId)

  8.         .getOrElse(0 until numPartitions)

  9.     }

  10. }

ResultStage.scala

  1. private[spark] class ResultStage(… …)

  2.     extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) {

  3.     … …

  4.     override def findMissingPartitions(): Seq[Int] = {

  5.         val job = activeJob.get(0 until job.numPartitions).filter(id => !job.finished(id))

  6.     }

  7.     … …

  8. }

1)提交任务

DAGScheduler.scala

  1. private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {

  2.     … …

  3.     if (tasks.nonEmpty) {

  4.         taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))

  5.     } else {

  6.         markStageAsFinished(stage, None)

  7.     
     

  8.         stage match {

  9.         case stage: ShuffleMapStage =>

  10.             markMapStageJobsAsFinished(stage)

  11.         case stage : ResultStage =>

  12.             logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})")

  13.         }

  14.         submitWaitingChildStages(stage)

  15.     }

  16. }

TaskScheduler.scala

def submitTasks(taskSet: TaskSet): Unit

全局查找submitTasks的实现类TaskSchedulerImpl

TaskSchedulerImpl.scala

  1. override def submitTasks(taskSet: TaskSet): Unit = {

  2.     val tasks = taskSet.tasks

  3.     this.synchronized {

  4.         val manager = createTaskSetManager(taskSet, maxTaskFailures)

  5.         val stage = taskSet.stageId

  6.         val stageTaskSets =

  7.         taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])

  8.         … …

  9.         stageTaskSets(taskSet.stageAttemptId) = manager

  10.         // 向队列里面设置任务

  11.         schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) 

  12.         … …

  13.     }

  14.     // 取任务

  15.     backend.reviveOffers()

  16. }

2)FIFO和公平调度器

点击schedulableBuilder,查找schedulableBuilder初始化赋值的地方

  1. private var schedulableBuilder: SchedulableBuilder = null

  2. def initialize(backend: SchedulerBackend): Unit = {

  3.     this.backend = backend

  4.     schedulableBuilder = {

  5.         schedulingMode match {

  6.         case SchedulingMode.FIFO =>

  7.             new FIFOSchedulableBuilder(rootPool)

  8.         case SchedulingMode.FAIR =>

  9.             new FairSchedulableBuilder(rootPool, conf)

  10.         case _ =>

  11.             throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " +

  12.             s"$schedulingMode")

  13.         }

  14.     }

  15.     schedulableBuilder.buildPools()

  16. }

点击schedulingMode,default scheduler is FIFO

  1. private val schedulingModeConf = conf.get(SCHEDULER_MODE)

  2. val schedulingMode: SchedulingMode =

  3.     … …

  4.       SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT))

  5.     … …

  6. }

  7. private[spark] val SCHEDULER_MODE =

  8. ConfigBuilder("spark.scheduler.mode")

  9.     .version("0.8.0")

  10.     .stringConf

  11.     .createWithDefault(SchedulingMode.FIFO.toString)

3)读取任务

SchedulerBackend.scala

  1. private[spark] trait SchedulerBackend {

  2.     … …

  3.     def reviveOffers(): Unit

  4.     … …

  5. }

全局查找reviveOffers实现类CoarseGrainedSchedulerBackend

CoarseGrainedSchedulerBackend.scala

  1. override def reviveOffers(): Unit = {

  2.     // 自己给自己发消息

  3.     driverEndpoint.send(ReviveOffers)

  4. }

  5. // 自己接收到消息

  6. override def receive: PartialFunction[Any, Unit] = {

  7.     … …

  8.     case ReviveOffers =>

  9.         makeOffers()

  10.     … …

  11. }

  12. private def makeOffers(): Unit = {

  13.     val taskDescs = withLock {

  14.         … …

  15.         // 取任务

  16.         scheduler.resourceOffers(workOffers)

  17.     }

  18.     if (taskDescs.nonEmpty) {

  19.         launchTasks(taskDescs)

  20.     }

  21. }

TaskSchedulerImpl.scala

  1. def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {

  2.     … …

  3.     val sortedTaskSets = rootPool.getSortedTaskSetQueue.filterNot(_.isZombie)

  4.       
     

  5.     for (taskSet <- sortedTaskSets) {

  6.         val availableSlots = availableCpus.map(c => c / CPUS_PER_TASK).sum

  7.         if (taskSet.isBarrier && availableSlots < taskSet.numTasks) {

  8.     
     

  9.         } else {

  10.             var launchedAnyTask = false

  11.             val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]()

  12.             for (currentMaxLocality <- taskSet.myLocalityLevels) {

  13.                 var launchedTaskAtCurrentMaxLocality = false

  14.                 do {

  15.                     launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet,

  16.                         currentMaxLocality, shuffledOffers, availableCpus,

  17.                         availableResources, tasks, addressesWithDescs)

  18.                     launchedAnyTask |= launchedTaskAtCurrentMaxLocality

  19.                 } while (launchedTaskAtCurrentMaxLocality)

  20.             }

  21.             … …

  22.         }

  23.     }

  24.     … …

  25.     return tasks

  26. }

Pool.scala

  1. override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = {

  2.     val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]

  3.     val sortedSchedulableQueue =

  4.         schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator)

  5.     
     

  6.     for (schedulable <- sortedSchedulableQueue) {

  7.         sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue

  8.     }

  9.     sortedTaskSetQueue

  10. }

  11. private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = {

  12.     schedulingMode match {

  13.         case SchedulingMode.FAIR =>

  14.             new FairSchedulingAlgorithm()

  15.         case SchedulingMode.FIFO =>

  16.             new FIFOSchedulingAlgorithm()

  17.         case _ =>

  18.             … …

  19.     }

  20. }

4)FIFO和公平调度器规则

SchedulingAlgorithm.scala

  1. private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {

  2.     override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {

  3.         val priority1 = s1.priority

  4.         val priority2 = s2.priority

  5.         var res = math.signum(priority1 - priority2)

  6.         if (res == 0) {

  7.             val stageId1 = s1.stageId

  8.             val stageId2 = s2.stageId

  9.             res = math.signum(stageId1 - stageId2)

  10.         }

  11.         res < 0

  12.     }

  13. }

  14. private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {

  15.     override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {

  16.         val minShare1 = s1.minShare

  17.         val minShare2 = s2.minShare

  18.         val runningTasks1 = s1.runningTasks

  19.         val runningTasks2 = s2.runningTasks

  20.         val s1Needy = runningTasks1 < minShare1

  21.         val s2Needy = runningTasks2 < minShare2

  22.         val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0)

  23.         val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0)

  24.         val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble

  25.         val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble

  26.         … …

  27.     }

  28. }

5)发送给Executor端执行任务

CoarseGrainedSchedulerBackend.scala

  1. private def makeOffers(): Unit = {

  2.     val taskDescs = withLock {

  3.         … …

  4.         // 取任务

  5.         scheduler.resourceOffers(workOffers)

  6.     }

  7.     if (taskDescs.nonEmpty) {

  8.         launchTasks(taskDescs)

  9.     }

  10. }

  11. private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {

  12.     for (task <- tasks.flatten) {

  13.         val serializedTask = TaskDescription.encode(task)

  14.         if (serializedTask.limit() >= maxRpcMessageSize) {

  15.             … …

  16.         }

  17.         else {

  18.             … …

  19.             // 序列化任务发往__Executor远程终端

  20.             executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))

  21.         }

  22.     }

  23. }

在CoarseGrainedExecutorBackend.scala中接收数据LaunchTask

  1. override def receive: PartialFunction[Any, Unit] = {

  2.     … …

  3.     case LaunchTask(data) =>

  4.       if (executor == null) {

  5.         exitExecutor(1, "Received LaunchTask command but executor was null")

  6.       } else {

  7.         val taskDesc = TaskDescription.decode(data.value)

  8.         logInfo("Got assigned task " + taskDesc.taskId)

  9.         taskResources(taskDesc.taskId) = taskDesc.resources

  10.         executor.launchTask(this, taskDesc)

  11.       }

  12.       … …

  13. }

Executor.scala

  1. def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {

  2.     val tr = new TaskRunner(context, taskDescription)

  3.     runningTasks.put(taskDescription.taskId, tr)

  4.     threadPool.execute(tr)

  5. }

Shuffle

Spark最初版本HashShuffle

Spark0.8.1版本以后优化后的HashShuffle

Spark1.1版本加入SortShuffle,默认是HashShuffle

Spark1.2版本默认是SortShuffle,但是可配置HashShuffle

Spark2.0版本删除HashShuffle只有SortShuffle

Shuffle一定会有落盘。

如果shuffle过程中落盘数据量减少,那么可以提高性能。

算子如果存在预聚合功能,可以提高shuffle的性能。

未优化的HashShuffle

优化后的HashShuffle

优化的HashShuffle过程就是启用合并机制,合并机制就是复用buffer,开启合并机制的配置是spark.shuffle.consolidateFiles。该参数默认值为false,将其设置为true即可开启优化机制。通常来说,如果我们使用HashShuffleManager,那么都建议开启这个选项。

官网参数说明:http://spark.apache.org/docs/0.8.1/configuration.html

SortShuffle

在该模式下,数据会先写入一个数据结构,reduceByKey写入Map,一边通过Map局部聚合,一边写入内存。Join算子写入ArrayList直接写入内存中。然后需要判断是否达到阈值,如果达到就会将内存数据结构的数据写入到磁盘,清空内存数据结构。

在溢写磁盘前,先根据key进行排序,排序过后的数据,会分批写入到磁盘文件中。默认批次为10000条,数据会以每批一万条写入到磁盘文件。写入磁盘文件通过缓冲区溢写的方式,每次溢写都会产生一个磁盘文件,也就是说一个Task过程会产生多个临时文件。

最后在每个Task中,将所有的临时文件合并,这就是merge过程,此过程将所有临时文件读取出来,一次写入到最终文件。意味着一个Task的所有数据都在这一个文件中。同时单独写一份索引文件,标识下游各个Task的数据在文件中的索引,start offset和end offset。

bypassShuffle

bypassShuffle和SortShuffle的区别就是不对数据排序。

bypass运行机制的触发条件如下:

1)shuffle reduce task数量小于等于spark.shuffle.sort.bypassMergeThreshold参数的值,默认为200。

2)不是聚合类的shuffle算子(比如reduceByKey不行)。

shuffleWriterProcessor(写处理器)

DAGScheduler.scala

  1. private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {

  2.     … …

  3.     val tasks: Seq[Task[_]] = try {

  4.         val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()

  5.         
     

  6.         stage match {

  7.         // shuffle_写过程_

  8.         case stage: ShuffleMapStage =>        

  9.             stage.pendingPartitions.clear()

  10.             partitionsToCompute.map { id =>

  11.             val locs = taskIdToLocations(id)

  12.             val part = partitions(id)

  13.             stage.pendingPartitions += id            

  14.             new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,

  15.                 taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),

  16.                 Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())

  17.             }

  18.         // shuffle_读过程_

  19.         case stage: ResultStage =>

  20.         … …

  21.         }

  22.     } catch {

  23.         … …

  24.     }

  25. }

Task.scala

  1. private[spark] abstract class Task[T](… …) extends Serializable {

  2.    
     

  3.     final def run(… …): T = {      

  4.         runTask(context)

  5.     }

  6. }

Ctrl+h查找runTask 实现类ShuffleMapTask.scala

  1. private[spark] class ShuffleMapTask(… …)

  2.   extends Task[MapStatus](… …){

  3.   
     

  4.     override def runTask(context: TaskContext): MapStatus = {

  5.         … …

  6.         dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)

  7.     }

  8. }

ShuffleWriteProcessor.scala

  1. def write(… …): MapStatus = {

  2.     
     

  3.     var writer: ShuffleWriter[Any, Any] = null

  4.     try {

  5.         val manager = SparkEnv.get.shuffleManager

  6.         writer = manager.getWriter[Any, Any](

  7.         dep.shuffleHandle,

  8.         mapId,

  9.         context,

  10.         createMetricsReporter(context))

  11.         
     

  12.         writer.write(

  13.         rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

  14.         writer.stop(success = true).get

  15.     } catch {

  16.         … …

  17.     }

  18. }

查找(ctrl + h)ShuffleManager的实现类,SortShuffleManager

SortShuffleManager.scala

  1. override def getWriter[K, V]( handle: ShuffleHandle,

  2.       mapId: Long,

  3.       context: TaskContext,

  4.       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] =

  5.     … …

  6.     handle match {

  7.     case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>

  8.         new UnsafeShuffleWriter(… …)

  9.     case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>

  10.         new BypassMergeSortShuffleWriter(… …)

  11.     case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>

  12.         new SortShuffleWriter(… …)

  13.     }

  14. }

因为getWriter的第一个输入参数是dep.shuffleHandle,点击dep.shuffleHandle

Dependency.scala

val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(shuffleId, this)

ShuffleManager.scala

def registerShuffle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle

使用BypassShuffle条件

BypassMergeSortShuffleHandle使用条件:

**1)不能使用预聚合
**

**2)如果下游的分区数量小于等于200(可配置)
**

处理器

写对象

判断条件

SerializedShuffleHandle

UnsafeShuffleWriter

1.序列化规则支持重定位操作(java序列化不支持,Kryo支持)

2.不能使用预聚合

3.如果下游的分区数量小于或等于1677216

BypassMergeSortShuffleHandle

BypassMergeSortShuffleWriter

1.不能使用预聚合

2.如果下游的分区数量小于等于200(可配置)

BaseShuffleHandle

SortShuffleWriter

其他情况

查找(ctrl + h)registerShuffle 实现类,SortShuffleManager.scala

  1. override def registerShuffle[K, V, C](

  2.     shuffleId: Int,

  3.     dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {

  4.     //_使用_BypassShuffle条件:不能使用预聚合功能;默认下游分区数据不能大于__200

  5.     if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {

  6.         new BypassMergeSortShuffleHandle[K, V](

  7.             shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])

  8.     } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {

  9.         new SerializedShuffleHandle[K, V](

  10.             shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])

  11.     } else {

  12.         new BaseShuffleHandle(shuffleId, dependency)

  13.     }

  14. }

点击shouldBypassMergeSort

SortShuffleWriter.scala

  1. private[spark] object SortShuffleWriter {

  2.     def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {

  3.         // 是否有__map阶段预聚合(支持预聚合不能用)

  4.         if (dep.mapSideCombine) {

  5.             false

  6.         } else {

  7. // SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD = 200分区

  8.             val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)

  9.             // 如果下游分区器的数量,小于__200(可配置),可以使用__bypass

  10.             dep.partitioner.numPartitions <= bypassMergeThreshold

  11.         }

  12.     }

  13. }

使用SerializedShuffle条件

SerializedShuffleHandle使用条件:

**1)序列化规则支持重定位操作(java序列化不支持,Kryo支持)
**

**2)不能使用预聚合
**

**3)如果下游的分区数量小于或等于1677216
**

点击canUseSerializedShuffle

SortShuffleManager.scala

  1. def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {

  2.     val shufId = dependency.shuffleId

  3.     val numPartitions = dependency.partitioner.numPartitions

  4.     
     

  5.     // 是否支持将两个独立的序列化对象 重定位,聚合到一起

  6.     // 1_默认的_java序列化不支持;Kryo序列化支持重定位(可以用)

  7.     if (!dependency.serializer.supportsRelocationOfSerializedObjects) {

  8.         false

  9.     } else if (dependency.mapSideCombine) { // 2_支持预聚合也不能用_

  10.         false

  11.     } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {//3_如果下游分区的数量大于_16777216,也不能用

  12.         false

  13.     } else {

  14.         true

  15.     }

  16. }

使用BaseShuffle

点击SortShuffleWriter

SortShuffleWriter.scala

  1. override def write(records: Iterator[Product2[K, V]]): Unit = {

  2.     // 判断是否有预聚合功能,支持会有__aggregator和排序规则

  3.     sorter = if (dep.mapSideCombine) {

  4.         new ExternalSorter[K, V, C](

  5.         context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)

  6.     } else {

  7.         new ExternalSorter[K, V, V](

  8.         context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)

  9.     }

  10.     // 插入数据

  11.     sorter.insertAll(records)

  12.     
     

  13.     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(

  14.         dep.shuffleId, mapId, dep.partitioner.numPartitions)

  15.     // 插入数据

  16.     sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)

  17.     
     

  18.     val partitionLengths = mapOutputWriter.commitAllPartitions()

  19.     
     

  20.     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)

  21. }

插入数据(缓存+溢写)

ExternalSorter.scala

  1. def insertAll(records: Iterator[Product2[K, V]]): Unit = {

  2.     val shouldCombine = aggregator.isDefined

  3.     // 判断是否支持预聚合,支持预聚合,采用__map结构,不支持预聚合采用buffer结构

  4.     if (shouldCombine) {

  5.         val mergeValue = aggregator.get.mergeValue

  6.         val createCombiner = aggregator.get.createCombiner

  7.         var kv: Product2[K, V] = null

  8.         val update = (hadValue: Boolean, oldValue: C) => {

  9.             if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)

  10.         }

  11.         
     

  12.         while (records.hasNext) {

  13.             addElementsRead()

  14.             kv = records.next()

  15.             // 如果支持预聚合,在__map阶段聚合,将相同key,的value聚合

  16.             map.changeValue((getPartition(kv._1), kv._1), update)

  17.             // 是否能够溢写

  18.             maybeSpillCollection(usingMap = true)

  19.         }

  20.     } else {

  21.         while (records.hasNext) {

  22.             addElementsRead()

  23.             val kv = records.next()

  24.             // 如果不支持预聚合,value不需要聚合 (key,(value1,value2))

  25.             buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])

  26.             maybeSpillCollection(usingMap = false)

  27.         }

  28.     }

  29. }

  30. private def maybeSpillCollection(usingMap: Boolean): Unit = {

  31.     var estimatedSize = 0L

  32.     if (usingMap) {

  33.         estimatedSize = map.estimateSize()

  34.         if (maybeSpill(map, estimatedSize)) {

  35.             map = new PartitionedAppendOnlyMap[K, C]

  36.         }

  37.     } else {

  38.         estimatedSize = buffer.estimateSize()

  39.         if (maybeSpill(buffer, estimatedSize)) {

  40.             buffer = new PartitionedPairBuffer[K, C]

  41.         }

  42.     }

  43.     
     

  44.     if (estimatedSize > _peakMemoryUsedBytes) {

  45.         _peakMemoryUsedBytes = estimatedSize

  46.     }

  47. }

Spillable.scala

  1. protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {

  2.     var shouldSpill = false

  3.     // myMemoryThreshold_默认值内存门槛是_5m

  4.     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {

  5.     
     

  6.         val amountToRequest = 2 * currentMemory - myMemoryThreshold

  7.         // 申请内存

  8.         val granted = acquireMemory(amountToRequest)

  9.         myMemoryThreshold += granted

  10.         // 当前内存大于(尝试申请的内存_+门槛),就需要溢写了_

  11.         shouldSpill = currentMemory >= myMemoryThreshold

  12.     }

  13.     // 强制溢写 读取数据的值 超过了Int的最大值

  14.     shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold

  15.     
     

  16.     if (shouldSpill) {

  17.         _spillCount += 1

  18.         logSpillage(currentMemory)

  19.         // 溢写

  20.         spill(collection)

  21.         _elementsRead = 0

  22.         _memoryBytesSpilled += currentMemory

  23.         // 释放内存

  24.         releaseMemory()

  25.     }

  26.     shouldSpill

  27. }

  28. protected def spill(collection: C): Unit

查找(ctrl +h)spill 的实现类ExternalSorter

ExternalSorter.scala

  1. override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {

  2.     val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)

  3.     val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)

  4.     spills += spillFile

  5. }

  6. private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)

  7.      : SpilledFile = {

  8.    // 创建临时文件

  9.    val (blockId, file) = diskBlockManager.createTempShuffleBlock()

  10.    var objectsWritten: Long = 0

  11.    val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics

  12.    // 溢写文件前,__fileBufferSize缓冲区大小默认__32m

  13.    val writer: DiskBlockObjectWriter =

  14.      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)

  15.    … …

  16.    SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)

  17. }

merge合并

来到SortShuffleWriter.scala

  1. override def write(records: Iterator[Product2[K, V]]): Unit = {

  2.     sorter = if (dep.mapSideCombine) {

  3.         new ExternalSorter[K, V, C](

  4.         context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)

  5.     } else {

  6.         new ExternalSorter[K, V, V](

  7.         context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)

  8.     }

  9.     sorter.insertAll(records)

  10.     
     

  11.     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(

  12.         dep.shuffleId, mapId, dep.partitioner.numPartitions)

  13.     // 合并

  14.     sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)

  15.     
     

  16.     val partitionLengths = mapOutputWriter.commitAllPartitions()

  17.     
     

  18.     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)

  19. }

ExternalSorter.scala

  1. def writePartitionedMapOutput(

  2.     shuffleId: Int,

  3.     mapId: Long,

  4.     mapOutputWriter: ShuffleMapOutputWriter): Unit = {

  5.     var nextPartitionId = 0

  6.     // 如果溢写文件为空,只对内存中数据处理

  7.     if (spills.isEmpty) {

  8.         // Case where we only have in-memory data

  9.         … …

  10.     } else {

  11.         // We must perform merge-sort; get an iterator by partition and write everything directly.

  12.         //_如果溢写文件不为空,需要将多个溢写文件合并_

  13.         for ((id, elements) <- this.partitionedIterator) {

  14.             val blockId = ShuffleBlockId(shuffleId, mapId, id)

  15.             var partitionWriter: ShufflePartitionWriter = null

  16.             var partitionPairsWriter: ShufflePartitionPairsWriter = null

  17.             … …

  18.             } {

  19.                 if (partitionPairsWriter != null) {

  20.                     partitionPairsWriter.close()

  21.                 }

  22.             }

  23.             nextPartitionId = id + 1

  24.         }

  25.     }

  26.     … …

  27. }

  28. def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {

  29.     val usingMap = aggregator.isDefined

  30.     val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer

  31.     
     

  32.     if (spills.isEmpty) {

  33.         if (ordering.isEmpty) {      

  34.             groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))

  35.         } else {

  36.             groupByPartition(destructiveIterator(

  37.             collection.partitionedDestructiveSortedIterator(Some(keyComparator))))

  38.         }

  39.     } else {

  40.         // 合并溢写文件和内存中数据

  41.         merge(spills, destructiveIterator(

  42.         collection.partitionedDestructiveSortedIterator(comparator)))

  43.     }

  44. }

  45. private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])

  46.     : Iterator[(Int, Iterator[Product2[K, C]])] = {

  47.     
     

  48.     val readers = spills.map(new SpillReader(_))

  49.     val inMemBuffered = inMemory.buffered

  50.     
     

  51.     (0 until numPartitions).iterator.map { p =>

  52.         val inMemIterator = new IteratorForPartition(p, inMemBuffered)

  53.         val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)

  54.         if (aggregator.isDefined) {

  55.             (p, mergeWithAggregation(

  56.             iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))

  57.         } else if (ordering.isDefined) {

  58.         // 归并排序

  59.             (p, mergeSort(iterators, ordering.get))

  60.         } else {

  61.             (p, iterators.iterator.flatten)

  62.         }

  63.     }

  64. }

写磁盘

来到SortShuffleWriter.scala

  1. override def write(records: Iterator[Product2[K, V]]): Unit = {

  2.     sorter = if (dep.mapSideCombine) {

  3.         new ExternalSorter[K, V, C](

  4.         context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)

  5.     } else {

  6.         new ExternalSorter[K, V, V](

  7.         context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)

  8.     }

  9.     sorter.insertAll(records)

  10.     
     

  11.     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(

  12.         dep.shuffleId, mapId, dep.partitioner.numPartitions)

  13.     // 合并

  14.     sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)

  15.     // 写磁盘

  16.     val partitionLengths = mapOutputWriter.commitAllPartitions()

  17.     
     

  18.     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)

  19. }

  20. 查找(ctrl + h)commitAllPartitions实现类,来到LocalDiskShuffleMapOutputWriter.java

  21. public long[] commitAllPartitions() throws IOException {

  22.     if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) {

  23.         … …

  24.     }

  25.     cleanUp();

  26.     File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;

  27.     blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);

  28.     return partitionLengths;

  29. }

查找(ctrl + h)commitAllPartitions实现类,来到LocalDiskShuffleMapOutputWriter.java

  1. public long[] commitAllPartitions() throws IOException {

  2.     if (outputFileChannel != null && outputFileChannel.position() != bytesWrittenToMergedFile) {

  3.         … …

  4.     }

  5.     cleanUp();

  6.     File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;

  7.     blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);

  8.     return partitionLengths;

  9. }

IndexShuffleBlockResolver.scala

  1. def writeIndexFileAndCommit(

  2.     shuffleId: Int,

  3.     mapId: Long,

  4.     lengths: Array[Long],

  5.     dataTmp: File): Unit = {

  6.     
     

  7.     val indexFile = getIndexFile(shuffleId, mapId)

  8.     val indexTmp = Utils.tempFileWith(indexFile)

  9.     try {

  10.         val dataFile = getDataFile(shuffleId, mapId)

  11.     
     

  12.         synchronized {

  13.             val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)

  14.             if (existingLengths != null) {

  15.                 System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)

  16.                 if (dataTmp != null && dataTmp.exists()) {

  17.                     dataTmp.delete()

  18.                 }

  19.             } else {

  20.                 val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))

  21.                 Utils.tryWithSafeFinally {

  22.                     var offset = 0L

  23.                     out.writeLong(offset)

  24.                     for (length <- lengths) {

  25.                         offset += length

  26.                         out.writeLong(offset)

  27.                     }

  28.                 } {

  29.                     out.close()

  30.                 }

  31.         
     

  32.                 if (indexFile.exists()) {

  33.                     indexFile.delete()

  34.                 }

  35.                 if (dataFile.exists()) {

  36.                     dataFile.delete()

  37.                 }

  38.                 if (!indexTmp.renameTo(indexFile)) {

  39.                     throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)

  40.                 }

  41.                 if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {

  42.                     throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)

  43.                 }

  44.             }

  45.         }

  46.     } finally {

  47.         … …

  48.     }

  49. }

DAGScheduler.scala

  1. private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {

  2.     … …

  3.     val tasks: Seq[Task[_]] = try {

  4.         val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()

  5.         
     

  6.         stage match {

  7.         case stage: ShuffleMapStage =>        

  8.         … …

  9.         case stage: ResultStage =>

  10.             partitionsToCompute.map { id =>

  11.             val p: Int = stage.partitions(id)

  12.             val part = partitions(p)

  13.             val locs = taskIdToLocations(id)

  14.             new ResultTask(stage.id, stage.latestInfo.attemptNumber,

  15.                 taskBinary, part, locs, id, properties, serializedTaskMetrics,

  16.                 Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,

  17.                 stage.rdd.isBarrier())

  18.             }

  19.         }

  20.     } catch {

  21.         … …

  22.     }

  23. }

ResultTask.scala

  1. private[spark] class ResultTask[T, U](… …)

  2.   extends Task[U](… …)

  3.   with Serializable {

  4.   
     

  5.     override def runTask(context: TaskContext): U = {

  6.         func(context, rdd.iterator(partition, context))

  7.     }  

  8. }

RDD.scala

  1. final def iterator(split: Partition, context: TaskContext): Iterator[T] = {

  2.     if (storageLevel != StorageLevel.NONE) {

  3.         getOrCompute(split, context)

  4.     } else {

  5.         computeOrReadCheckpoint(split, context)

  6.     }

  7. }

  8. private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {

  9.     … …

  10.     computeOrReadCheckpoint(partition, context) 

  11.     … …

  12. }

  13. def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] ={

  14.     if (isCheckpointedAndMaterialized) {

  15.         firstParent[T].iterator(split, context)

  16.     } else {

  17.         compute(split, context)

  18.     }

  19. }

  20. def compute(split: Partition, context: TaskContext): Iterator[T]

全局查找compute,由于我们是ShuffledRDD,所以点击ShuffledRDD.scala,搜索compute

  1. override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {

  2.     val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]

  3.     val metrics = context.taskMetrics().createTempShuffleReadMetrics()

  4.     SparkEnv.get.shuffleManager.getReader(

  5.     dep.shuffleHandle, split.index, split.index + 1, context, metrics)

  6.     .read()

  7.     .asInstanceOf[Iterator[(K, C)]]

  8. }

ShuffleManager.scala文件

def getReader[K, C](… …): ShuffleReader[K, C]

查找(ctrl + h)getReader 的实现类,SortShuffleManager.scala

  1. override def getReader[K, C](… …): ShuffleReader[K, C] = {

  2.     val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(

  3.         handle.shuffleId, startPartition, endPartition)

  4.     new BlockStoreShuffleReader(

  5.         handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,

  6.         shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))

  7. }

在BlockStoreShuffleReader.scala文件中查找read方法

  1. override def read(): Iterator[Product2[K, C]] = {

  2.     val wrappedStreams = new ShuffleBlockFetcherIterator(

  3.         … …

  4.         // 读缓冲区大小 默认 48m

  5.         SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,

  6.         SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),

  7.         … …

  8. }

Spark内存管理

概念

Spark支持堆内内存也支持堆外内存

1)堆内内存:程序在运行时动态地申请某个大小的内存空间

2)堆外内存:直接向操作系统进行申请的内存,不受JVM控制

堆内内存和对外内存优缺点

1)堆外内存,相比于堆内内存有几个优势:

(1)减少了垃圾回收的工作,因为垃圾回收会暂停其他的工作

(2)加快了复制的速度。因为堆内在Flush到远程时,会先序列化,然后在发送;而堆外内存本身是序列化的相当于省略掉了这个工作。

说明:堆外内存是序列化的,其占用的内存大小可直接计算。堆内内存是非序列化的对象,其占用的内存是通过周期性地采样近似估算而得,即并不是每次新增的数据项都会计算一次占用的内存大小,这种方法降低了时间开销但是有可能误差较大,导致某一时刻的实际内存有可能远远超出预期。此外,在被Spark标记为释放的对象实例,很有可能在实际上并没有被JVM回收,导致实际可用的内存小于Spark记录的可用内存。所以 Spark并不能准确记录实际可用的堆内内存,从而也就无法完全避免内存溢出OOM的异常。

2)堆外内存,相比于堆内内存有几个缺点:

(1)堆外内存难以控制,如果内存泄漏,那么很难排查

(2)堆外内存相对来说,不适合存储很复杂的对象。一般简单的对象或者扁平化的比较适合。

如何配置

1)堆内内存大小设置:–executor-memory 或 spark.executor.memory

2)在默认情况下堆外内存并不启用,spark.memory.offHeap.enabled 参数启用,并由 spark.memory.offHeap.size 参数设定堆外空间的大小。

官网配置地址:http://spark.apache.org/docs/3.0.0/configuration.html

堆内内存包括:存储(Storage)内存执行(Execution)内存、**其他内存
**

静态内存管理

Spark最初采用的静态内存管理机制下,存储内存、执行内存和其他内存的大小在Spark应用程序运行期间均为固定的,但用户可以应用程序启动前进行配置,堆内内存的分配如图所示:

可以看到,可用的堆内内存的大小需要按照下列方式计算:

可用的存储内存 = systemMaxMemory * spark.storage.memoryFraction * spark.storage.safety Fraction

可用的执行内存 = systemMaxMemory * spark.shuffle.memoryFraction * spark.shuffle.safety Fraction

其中systemMaxMemory取决于当前JVM堆内内存的大小,最后可用的执行内存或者存储内存要在此基础上与各自的memoryFraction 参数和safetyFraction 参数相乘得出。上述计算公式中的两个 safetyFraction 参数,其意义在于在逻辑上预留出 1-safetyFraction 这么一块保险区域,降低因实际内存超出当前预设范围而导致 OOM 的风险(上文提到,对于非序列化对象的内存采样估算会产生误差)。值得注意的是,这个预留的保险区域仅仅是一种逻辑上的规划,在具体使用时 Spark 并没有区别对待,和"其它内存"一样交给了 JVM 去管理。

Storage内存和Execution内存都有预留空间,目的是防止OOM,因为Spark堆内内存大小的记录是不准确的,需要留出保险区域。

堆外的空间分配较为简单,只有存储内存和执行内存,如下图所示。可用的执行内存和存储内存占用的空间大小直接由参数spark.memory.storageFraction 决定,由于堆外内存占用的空间可以被精确计算,所以无需再设定保险区域。

静态内存管理机制实现起来较为简单,但如果用户不熟悉Spark的存储机制,或没有根据具体的数据规模和计算任务或做相应的配置,很容易造成"一半海水,一半火焰"的局面,即存储内存和执行内存中的一方剩余大量的空间,而另一方却早早被占满,不得不淘汰或移出旧的内容以存储新的内容。由于新的内存管理机制的出现,这种方式目前已经很少有开发者使用,出于兼容旧版本的应用程序的目的,Spark 仍然保留了它的实现。

统一内存管理

Spark1.6 之后引入的统一内存管理机制,与静态内存管理的区别在于存储内存和执行内存共享同一块空间,可以动态占用对方的空闲区域,统一内存管理的堆内内存结构如图所示:

统一内存管理的堆外内存结构如下图所示:

其中最重要的优化在于动态占用机制,其规则如下:

  1. 设定基本的存储内存和执行内存区域(spark.storage.storageFraction参数),该设定确定了双方各自拥有的空间的范围;

  2. 双方的空间都不足时,则存储到硬盘;若己方空间不足而对方空余时,可借用对方的空间;(存储空间不足是指不足以放下一个完整的Block)

  3. 执行内存的空间被对方占用后,可让对方将占用的部分转存到硬盘,然后"归还"借用的空间;

  4. 存储内存的空间被对方占用后,无法让对方"归还",因为需要考虑 Shuffle过程中的很多因素,实现起来较为复杂。

统一内存管理的动态占用机制如图所示:

凭借统一内存管理机制,Spark在一定程度上提高了堆内和堆外内存资源的利用率,降低了开发者维护Spark内存的难度,但并不意味着开发者可以高枕无忧。如果存储内存的空间太大或者说缓存的数据过多,反而会导致频繁的全量垃圾回收,降低任务执行时的性能,因为缓存的RDD数据通常都是长期驻留内存的。所以要想充分发挥Spark的性能,需要开发者进一步了解存储内存和执行内存各自的管理方式和实现原理。

4.2.3 内存空间分配

全局查找(ctrl + n)SparkEnv,并找到create方法

SparkEnv.scala

  1. private def create(

  2.     conf: SparkConf,

  3.     executorId: String,

  4.     bindAddress: String,

  5.     advertiseAddress: String,

  6.     port: Option[Int],

  7.     isLocal: Boolean,

  8.     numUsableCores: Int,

  9.     ioEncryptionKey: Option[Array[Byte]],

  10.     listenerBus: LiveListenerBus = null,

  11.     mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {

  12.     
     

  13.     … …

  14.     val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)

  15.     … …

  16. }

UnifiedMemoryManager.scala

  1. def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = {

  2.     // 获取最大的可用内存为总内存的__0.6

  3.     val maxMemory = getMaxMemory(conf)

  4.     // 最大可用内存的__0.5 MEMORY_STORAGE_FRACTION=0.5

  5.     new UnifiedMemoryManager(

  6.         conf,

  7.         maxHeapMemory = maxMemory,

  8.         onHeapStorageRegionSize =

  9.         (maxMemory * conf.get(config.MEMORY_STORAGE_FRACTION)).toLong,

  10.         numCores = numCores)

  11. }

  12. private def getMaxMemory(conf: SparkConf): Long = {

  13.     // 获取系统内存

  14.     val systemMemory = conf.get(TEST_MEMORY)

  15.     
     

  16.     // 获取系统预留内存,默认300m(RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024)

  17.     val reservedMemory = conf.getLong(TEST_RESERVED_MEMORY.key,

  18.         if (conf.contains(IS_TESTING)) 0 else RESERVED_SYSTEM_MEMORY_BYTES)

  19.     val minSystemMemory = (reservedMemory * 1.5).ceil.toLong

  20.     
     

  21.     if (systemMemory < minSystemMemory) {

  22.         throw new IllegalArgumentException(s"System memory $systemMemory must " +

  23.         s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " +

  24.         s"option or ${config.DRIVER_MEMORY.key} in Spark configuration.")

  25.     }

  26.     if (conf.contains(config.EXECUTOR_MEMORY)) {

  27.         val executorMemory = conf.getSizeAsBytes(config.EXECUTOR_MEMORY.key)

  28.         if (executorMemory < minSystemMemory) {

  29.         throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " +

  30.             s"$minSystemMemory. Please increase executor memory using the " +

  31.             s"--executor-memory option or ${config.EXECUTOR_MEMORY.key} in Spark configuration.")

  32.         }

  33.     }

  34.     val usableMemory = systemMemory - reservedMemory

  35.     val memoryFraction = conf.get(config.MEMORY_FRACTION)

  36.     (usableMemory * memoryFraction).toLong

  37. }

config\package.scala

  1. private[spark] val MEMORY_FRACTION = ConfigBuilder("spark.memory.fraction")

  2.     … …

  3.     .createWithDefault(0.6)

点击UnifiedMemoryManager.apply()中的UnifiedMemoryManager

  1. private[spark] class UnifiedMemoryManager(

  2.     conf: SparkConf,

  3.     val maxHeapMemory: Long,

  4.     onHeapStorageRegionSize: Long,

  5.     numCores: Int)

  6.   extends MemoryManager(

  7.     conf,

  8.     numCores,

  9.     onHeapStorageRegionSize,

  10.     maxHeapMemory - onHeapStorageRegionSize) {// 执行内存__0.6 -0.3 = 0.3

  11.     
     

  12. }

点击MemoryManager

MemoryManager.scala

  1. private[spark] abstract class MemoryManager(

  2.     conf: SparkConf,

  3.     numCores: Int,

  4.     onHeapStorageMemory: Long,

  5.     onHeapExecutionMemory: Long) extends Logging {// 执行内存__0.6 -0.3 = 0.3

  6.     … …

  7.     // 堆内存储内存

  8.     protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP)

  9.     // 堆外存储内存

  10.     protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP)

  11.     // 堆内执行内存

  12.     protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP)

  13.     // 堆外执行内存

  14.     protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP)

  15.     protected[this] val maxOffHeapMemory = conf.get(MEMORY_OFFHEAP_SIZE)

  16.     // 堆外内存__MEMORY_STORAGE_FRACTION = 0.5

  17.     protected[this] val offHeapStorageMemory =

  18.         (maxOffHeapMemory * conf.get(MEMORY_STORAGE_FRACTION)).toLong

  19.     … …

  20. }

RDD的持久化机制

弹性分布式数据集(RDD)作为 Spark 最根本的数据抽象,是只读的分区记录(Partition)的集合,只能基于在稳定物理存储中的数据集上创建,或者在其他已有的RDD上执行转换(Transformation)操作产生一个新的RDD。转换后的RDD与原始的RDD之间产生的依赖关系,构成了血统(Lineage)。凭借血统,Spark 保证了每一个RDD都可以被重新恢复。但RDD的所有转换都是惰性的,即只有当一个返回结果给Driver的行动(Action)发生时,Spark才会创建任务读取RDD,然后真正触发转换的执行。

Task在启动之初读取一个分区时,会先判断这个分区是否已经被持久化,如果没有则需要检查Checkpoint 或按照血统重新计算。所以如果一个 RDD 上要执行多次行动,可以在第一次行动中使用 persist或cache 方法,在内存或磁盘中持久化或缓存这个RDD,从而在后面的行动时提升计算速度。

事实上,cache 方法是使用默认的 MEMORY_ONLY 的存储级别将 RDD 持久化到内存,故缓存是一种特殊的持久化。 堆内和堆外存储内存的设计,便可以对缓存RDD时使用的内存做统一的规划和管理。

RDD的持久化由 Spark的Storage模块负责,实现了RDD与物理存储的解耦合。Storage模块负责管理Spark在计算过程中产生的数据,将那些在内存或磁盘、在本地或远程存取数据的功能封装了起来。在具体实现时Driver端和 Executor 端的Storage模块构成了主从式的架构,即Driver端的BlockManager为Master,Executor端的BlockManager 为 Slave。

Storage模块在逻辑上以Block为基本存储单位,RDD的每个Partition经过处理后唯一对应一个 Block(BlockId 的格式为rdd_RDD-ID_PARTITION-ID )。Driver端的Master负责整个Spark应用程序的Block的元数据信息的管理和维护,而Executor端的Slave需要将Block的更新等状态上报到Master,同时接收Master 的命令,例如新增或删除一个RDD。

在对RDD持久化时,Spark规定了MEMORY_ONLY、MEMORY_AND_DISK 等7种不同的存储级别,而存储级别是以下5个变量的组合:

  1. class StorageLevel private(

  2. private var _useDisk: Boolean, //_磁盘_

  3. private var _useMemory: Boolean, //_这里其实是指堆内内存_

  4. private var _useOffHeap: Boolean, //_堆外内存_

  5. private var _deserialized: Boolean, //_是否为非序列化_

  6. private var _replication: Int = 1 //_副本个数_

  7. )

Spark中7种存储级别如下:

持久化级别

含义

MEMORY_ONLY

以非序列化的Java对象的方式持久化在JVM内存中。如果内存无法完全存储RDD所有的partition,那么那些没有持久化的partition就会在下一次需要使用它们的时候,重新被计算

MEMORY_AND_DISK

同上,但是当某些partition无法存储在内存中时,会持久化到磁盘中。下次需要使用这些partition时,需要从磁盘上读取

MEMORY_ONLY_SER

同MEMORY_ONLY,但是会使用Java序列化方式,将Java对象序列化后进行持久化。可以减少内存开销,但是需要进行反序列化,因此会加大CPU开销

MEMORY_AND_DISK_SER

同MEMORY_AND_DISK,但是使用序列化方式持久化Java对象

DISK_ONLY

使用非序列化Java对象的方式持久化,完全存储到磁盘上

MEMORY_ONLY_2

MEMORY_AND_DISK_2

等等

如果是尾部加了2的持久化级别,表示将持久化数据复用一份,保存到其他节点,从而在数据丢失时,不需要再次计算,只需要使用备份数据即可

通过对数据结构的分析,可以看出存储级别从三个维度定义了RDD的 Partition(同时也就是Block)的存储方式:

存储位置:磁盘/堆内内存/堆外内存。如MEMORY_AND_DISK是同时在磁盘和堆内内存上存储,实现了冗余备份。OFF_HEAP 则是只在堆外内存存储,目前选择堆外内存时不能同时存储到其他位置。

存储形式:Block 缓存到存储内存后,是否为非序列化的形式。如 MEMORY_ONLY是非序列化方式存储,OFF_HEAP 是序列化方式存储。

副本数量:大于1时需要远程冗余备份到其他节点。如DISK_ONLY_2需要远程备份1个副本。

  1. **RDD的缓存过程
    **

RDD 在缓存到存储内存之前,Partition中的数据一般以迭代器(Iterator)的数据结构来访问,这是Scala语言中一种遍历数据集合的方法。通过Iterator可以获取分区中每一条序列化或者非序列化的数据项(Record),这些Record的对象实例在逻辑上占用了JVM堆内内存的other部分的空间,同一Partition的不同 Record 的存储空间并不连续。

RDD 在缓存到存储内存之后,Partition 被转换成Block,Record在堆内或堆外存储内存中占用一块连续的空间。将Partition由不连续的存储空间转换为连续存储空间的过程,Spark称之为"展开"(Unroll)。

Block 有序列化和非序列化两种存储格式,具体以哪种方式取决于该 RDD 的存储级别。非序列化的Block以一种 DeserializedMemoryEntry 的数据结构定义,用一个数组存储所有的对象实例,序列化的Block则以SerializedMemoryEntry的数据结构定义,用字节缓冲区(ByteBuffer)来存储二进制数据。每个 Executor 的 Storage模块用一个链式Map结构(LinkedHashMap)来管理堆内和堆外存储内存中所有的Block对象的实例,对这个LinkedHashMap新增和删除间接记录了内存的申请和释放。

因为不能保证存储空间可以一次容纳 Iterator 中的所有数据,当前的计算任务在 Unroll 时要向 MemoryManager 申请足够的Unroll空间来临时占位,空间不足则Unroll失败,空间足够时可以继续进行。

对于序列化的Partition,其所需的Unroll空间可以直接累加计算,一次申请。

对于非序列化的 Partition 则要在遍历 Record 的过程中依次申请,即每读取一条 Record,采样估算其所需的Unroll空间并进行申请,空间不足时可以中断,释放已占用的Unroll空间。

如果最终Unroll成功,当前Partition所占用的Unroll空间被转换为正常的缓存 RDD的存储空间,如下图所示。

在静态内存管理时,Spark 在存储内存中专门划分了一块 Unroll 空间,其大小是固定的,统一内存管理时则没有对 Unroll 空间进行特别区分,当存储空间不足时会根据动态占用机制进行处理。

淘汰与落盘

由于同一个Executor的所有的计算任务共享有限的存储内存空间,当有新的 Block 需要缓存但是剩余空间不足且无法动态占用时,就要对LinkedHashMap中的旧Block进行淘汰(Eviction),而被淘汰的Block如果其存储级别中同时包含存储到磁盘的要求,则要对其进行落盘(Drop),否则直接删除该Block。

存储内存的淘汰规则为:

被淘汰的旧Block要与新Block的MemoryMode相同,即同属于堆外或堆内内存;

新旧Block不能属于同一个RDD,避免循环淘汰;

旧Block所属RDD不能处于被读状态,避免引发一致性问题;

遍历LinkedHashMap中Block,按照最近最少使用(LRU)的顺序淘汰,直到满足新Block所需的空间。其中LRU是LinkedHashMap的特性。

落盘的流程则比较简单,如果其存储级别符合_useDisk为true的条件,再根据其_deserialized判断是否是非序列化的形式,若是则对其进行序列化,最后将数据存储到磁盘,在Storage模块中更新其信息。

执行内存主要用来存储任务在执行Shuffle时占用的内存,Shuffle是按照一定规则对RDD数据重新分区的过程,我们来看Shuffle的Write和Read两阶段对执行内存的使用:

**1)Shuffle Write
**

若在map端选择普通的排序方式,会采用ExternalSorter进行外排,在内存中存储数据时主要占用堆内执行空间。

若在map端选择 Tungsten 的排序方式,则采用ShuffleExternalSorter直接对以序列化形式存储的数据排序,在内存中存储数据时可以占用堆外或堆内执行空间,取决于用户是否开启了堆外内存以及堆外执行内存是否足够。

**2)Shuffle Read
**

在对reduce端的数据进行聚合时,要将数据交给Aggregator处理,在内存中存储数据时占用堆内执行空间。

如果需要进行最终结果排序,则要将再次将数据交给ExternalSorter 处理,占用堆内执行空间。

在ExternalSorter和Aggregator中,Spark会使用一种叫AppendOnlyMap的哈希表在堆内执行内存中存储数据,但在 Shuffle 过程中所有数据并不能都保存到该哈希表中,当这个哈希表占用的内存会进行周期性地采样估算,当其大到一定程度,无法再从MemoryManager 申请到新的执行内存时,Spark就会将其全部内容存储到磁盘文件中,这个过程被称为溢存(Spill),溢存到磁盘的文件最后会被归并(Merge)。

Shuffle Write 阶段中用到的Tungsten是Databricks公司提出的对Spark优化内存和CPU使用的计划(钨丝计划),解决了一些JVM在性能上的限制和弊端。Spark会根据Shuffle的情况来自动选择是否采用Tungsten排序。

Tungsten 采用的页式内存管理机制建立在MemoryManager之上,即 Tungsten 对执行内存的使用进行了一步的抽象,这样在 Shuffle 过程中无需关心数据具体存储在堆内还是堆外。

每个内存页用一个MemoryBlock来定义,并用 Object obj 和 long offset 这两个变量统一标识一个内存页在系统内存中的地址。

堆内的MemoryBlock是以long型数组的形式分配的内存,其obj的值为是这个数组的对象引用,offset是long型数组的在JVM中的初始偏移地址,两者配合使用可以定位这个数组在堆内的绝对地址;堆外的 MemoryBlock是直接申请到的内存块,其obj为null,offset是这个内存块在系统内存中的64位绝对地址。Spark用MemoryBlock巧妙地将堆内和堆外内存页统一抽象封装,并用页表(pageTable)管理每个Task申请到的内存页。

Tungsten 页式管理下的所有内存用64位的逻辑地址表示,由页号和页内偏移量组成:

页号:占13位,唯一标识一个内存页,Spark在申请内存页之前要先申请空闲页号。

页内偏移量:占51位,是在使用内存页存储数据时,数据在页内的偏移地址。

有了统一的寻址方式,Spark 可以用64位逻辑地址的指针定位到堆内或堆外的内存,整个Shuffle Write排序的过程只需要对指针进行排序,并且无需反序列化,整个过程非常高效,对于内存访问效率和CPU使用效率带来了明显的提升。

Spark的存储内存和执行内存有着截然不同的管理方式:对于存储内存来说,Spark用一个LinkedHashMap来集中管理所有的Block,Block由需要缓存的 RDD的Partition转化而成;而对于执行内存,Spark用AppendOnlyMap来存储 Shuffle过程中的数据,在Tungsten排序中甚至抽象成为页式内存管理,开辟了全新的JVM内存管理机制。