Spark source code analysis: task deserialization and execution

1. ==> Receiving the message: org.apache.spark.executor.CoarseGrainedExecutorBackend#receive

case LaunchTask(data) =>  
  if (executor == null) {  
    exitExecutor(1, "Received LaunchTask command but executor was null")  
  } else {  
    val taskDesc = TaskDescription.decode(data.value)  
    logInfo("Got assigned task " + taskDesc.taskId)  
    executor.launchTask(this, taskDesc)  
  }
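TaskDescription.decode rebuilds, from the raw bytes shipped by the driver, the metadata the executor needs before it can run anything. The fields this walkthrough touches later (taskId, attemptNumber, addedFiles, addedJars, properties, serializedTask) give a good picture of what it carries; the case class below is a hypothetical stand-in just to make that shape explicit, not Spark's actual TaskDescription.

import java.nio.ByteBuffer
import java.util.Properties

// Hypothetical stand-in for org.apache.spark.scheduler.TaskDescription, listing only the
// fields this walkthrough uses; the real class has more fields and its own binary encoding.
case class SketchTaskDescription(
    taskId: Long,
    attemptNumber: Int,
    name: String,
    addedFiles: Map[String, Long],   // file name -> timestamp, consumed by updateDependencies
    addedJars: Map[String, Long],    // jar name  -> timestamp, consumed by updateDependencies
    properties: Properties,          // local properties forwarded to the Task
    serializedTask: ByteBuffer)      // the bytes that deserialize into ShuffleMapTask/ResultTask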

2. ==> org.apache.spark.executor.Executor#launchTask

// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}
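launchTask itself does very little: wrap the TaskDescription in a TaskRunner, register it in runningTasks (so it can later be killed or reported on), and hand it to the executor's thread pool. Below is a self-contained sketch of that pattern, with a placeholder SimpleTaskRunner instead of Spark's TaskRunner.

import java.util.concurrent.{ConcurrentHashMap, Executors}

// Minimal sketch of the "register, then submit to the pool" pattern in Executor#launchTask.
// SimpleTaskRunner is a stand-in for Spark's TaskRunner, not the real class.
object LaunchTaskSketch {
  private val runningTasks = new ConcurrentHashMap[Long, SimpleTaskRunner]
  private val threadPool = Executors.newCachedThreadPool()

  class SimpleTaskRunner(val taskId: Long) extends Runnable {
    override def run(): Unit = {
      try {
        println(s"running task $taskId")   // the real runner deserializes and runs the Task here
      } finally {
        runningTasks.remove(taskId)        // TaskRunner removes itself once it finishes
      }
    }
  }

  def launchTask(taskId: Long): Unit = {
    val tr = new SimpleTaskRunner(taskId)
    runningTasks.put(taskId, tr)           // tracked so it can be killed / queried later
    threadPool.execute(tr)                 // actual execution happens on a pool thread
  }
}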

3. ==> org.apache.spark.executor.Executor.TaskRunner#run

override def run(): Unit = {
  threadId = Thread.currentThread.getId
  Thread.currentThread.setName(threadName)
  val threadMXBean = ManagementFactory.getThreadMXBean
  val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)

  // Download any missing file/JAR dependencies
  updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
  // Deserialize the real Task (ShuffleMapTask or ResultTask)
  task = ser.deserialize[Task[Any]](
    taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
  task.localProperties = taskDescription.properties
  task.setTaskMemoryManager(taskMemoryManager)

  val value = Utils.tryWithSafeFinally {
    val res = task.run(
      taskAttemptId = taskId,
      attemptNumber = taskDescription.attemptNumber,
      metricsSystem = env.metricsSystem)
    threwException = false
    res
  } {
    val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
    val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
  }

  // Handle the execution result: serialize the returned value
  val resultSer = env.serializer.newInstance()
  val beforeSerialization = System.currentTimeMillis()
  val valueBytes = resultSer.serialize(value)
  val afterSerialization = System.currentTimeMillis()

  // Note: accumulator updates must be collected after TaskMetrics is updated
  val accumUpdates = task.collectAccumulatorUpdates()
  // TODO: do not serialize value twice
  val directResult = new DirectTaskResult(valueBytes, accumUpdates)
  val serializedDirectResult = ser.serialize(directResult)
  val resultSize = serializedDirectResult.limit()

  // directSend = sending directly back to the driver
  val serializedResult: ByteBuffer = {
    if (maxResultSize > 0 && resultSize > maxResultSize) {
      logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
        s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
        s"dropping it.")
      ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
    } else if (resultSize > maxDirectResultSize) {
      val blockId = TaskResultBlockId(taskId)
      env.blockManager.putBytes(
        blockId,
        new ChunkedByteBuffer(serializedDirectResult.duplicate()),
        StorageLevel.MEMORY_AND_DISK_SER)
      logInfo(
        s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
      ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
    } else {
      logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
      serializedDirectResult
    }
  }

  // Notify the driver that the task finished, shipping the (possibly indirect) result
  setTaskFinishedAndClearInterruptStatus()
  execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
}
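The branch that builds serializedResult is worth reading on its own: results above spark.driver.maxResultSize are dropped (only their size is reported), results above the direct-send threshold are stored in the BlockManager and fetched by the driver, and only small results travel inline with the status update. Here is a simplified, standalone sketch of that decision, with placeholder types instead of DirectTaskResult/IndirectTaskResult.

// Simplified sketch of the "how do we ship the result back" decision in TaskRunner#run.
// Direct / ViaBlockManager / Dropped are placeholder types, not Spark classes.
object ResultRoutingSketch {
  sealed trait PackedResult
  case class Direct(bytes: Array[Byte]) extends PackedResult                   // inline with the status update
  case class ViaBlockManager(blockId: String, size: Long) extends PackedResult // driver fetches it remotely
  case class Dropped(size: Long) extends PackedResult                          // too large, only the size is reported

  def packResult(taskId: Long,
                 resultBytes: Array[Byte],
                 maxResultSize: Long,          // spark.driver.maxResultSize
                 maxDirectResultSize: Long): PackedResult = {
    val resultSize = resultBytes.length.toLong
    if (maxResultSize > 0 && resultSize > maxResultSize) {
      Dropped(resultSize)                      // the driver will fail the job as "result too large"
    } else if (resultSize > maxDirectResultSize) {
      // real code: blockManager.putBytes(TaskResultBlockId(taskId), ..., MEMORY_AND_DISK_SER)
      ViaBlockManager(s"taskresult_$taskId", resultSize)
    } else {
      Direct(resultBytes)
    }
  }
}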

4. ==> org.apache.spark.executor.Executor#updateDependencies

/**
 * Download any missing dependencies if we receive a new set of files and JARs from the
 * SparkContext. Also adds any new JARs we fetched to the class loader.
 */
private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
  lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
  synchronized {
    // Fetch missing dependencies
    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
      logInfo("Fetching " + name + " with timestamp " + timestamp)
      // Fetch file with useCache mode, close cache for local mode.
      Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
        env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
      currentFiles(name) = timestamp
    }
    for ((name, timestamp) <- newJars) {
      val localName = new URI(name).getPath.split("/").last
      val currentTimeStamp = currentJars.get(name)
        .orElse(currentJars.get(localName))
        .getOrElse(-1L)
      if (currentTimeStamp < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        // Fetch file with useCache mode, close cache for local mode.
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
          env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        currentJars(name) = timestamp
        // Add it to our class loader
        val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
        if (!urlClassLoader.getURLs().contains(url)) {
          logInfo("Adding " + url + " to class loader")
          urlClassLoader.addURL(url)
        }
      }
    }
  }
}
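updateDependencies only fetches a file or JAR when the timestamp the driver advertises is newer than the one the executor already has, and newly fetched JARs are appended to a mutable URL class loader. The class below is a reduced sketch of that bookkeeping; `download` is a hypothetical helper standing in for Utils.fetchFile, and AddableClassLoader is just a way to expose the protected addURL.

import java.io.File
import java.net.{URL, URLClassLoader}
import scala.collection.mutable

// Minimal sketch of the timestamp bookkeeping in Executor#updateDependencies.
class DependencyTrackerSketch(rootDir: File, download: (String, File) => Unit) {

  // Expose the protected addURL so new JARs can be appended at runtime
  private class AddableClassLoader(parent: ClassLoader)
      extends URLClassLoader(Array.empty[URL], parent) {
    def addJar(url: URL): Unit = super.addURL(url)
  }

  private val currentJars = mutable.HashMap[String, Long]()
  private val loader = new AddableClassLoader(getClass.getClassLoader)

  def update(newJars: Map[String, Long]): Unit = synchronized {
    for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
      download(name, rootDir)                 // fetch only if the driver advertises a newer copy
      currentJars(name) = timestamp
      val localName = name.split("/").last
      val url = new File(rootDir, localName).toURI.toURL
      if (!loader.getURLs.contains(url)) {    // don't add the same JAR twice
        loader.addJar(url)
      }
    }
  }
}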

5. ==> org.apache.spark.scheduler.Task#run

final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem): T = {
  SparkEnv.get.blockManager.registerTask(taskAttemptId)

  val taskContext = new TaskContextImpl(
    stageId,
    stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
    partitionId,
    taskAttemptId,
    attemptNumber,
    taskMemoryManager,
    localProperties,
    metricsSystem,
    metrics)

  context = if (isBarrier) {
    new BarrierTaskContext(taskContext)
  } else {
    taskContext
  }

  TaskContext.setTaskContext(context)
  taskThread = Thread.currentThread()

  if (_reasonIfKilled != null) {
    kill(interruptThread = false, _reasonIfKilled)
  }

  new CallerContext(
    "TASK",
    SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
    appId,
    appAttemptId,
    jobId,
    Option(stageId),
    Option(stageAttemptId),
    Option(taskAttemptId),
    Option(attemptNumber)).setCurrentContext()

  try {
    // Task itself is only a template (abstract) class; the concrete implementations
    // are ResultTask and ShuffleMapTask
    runTask(context)
  }

}
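As the comment in the try block says, Task.run does only the shared bookkeeping (TaskContext setup, kill check, caller context) and delegates the real work to the abstract runTask, which ResultTask and ShuffleMapTask override. Below is a stripped-down sketch of that template shape, using made-up SketchTask/SketchContext names rather than Spark's classes.

// Stripped-down sketch of the template-method shape of org.apache.spark.scheduler.Task.
case class SketchContext(stageId: Int, partitionId: Int, taskAttemptId: Long)

abstract class SketchTask[T](val stageId: Int, val partitionId: Int) {
  // Common bookkeeping lives in the final run(); subclasses only supply runTask().
  final def run(taskAttemptId: Long): T = {
    val context = SketchContext(stageId, partitionId, taskAttemptId)
    try {
      runTask(context)
    } finally {
      // the real Task#run releases locks/memory and unsets the thread-local TaskContext here
    }
  }

  def runTask(context: SketchContext): T
}

// The two concrete shapes: one returns map-output metadata, one returns a user-visible value.
class SketchShuffleMapTask(stageId: Int, partitionId: Int)
    extends SketchTask[String](stageId, partitionId) {
  override def runTask(context: SketchContext): String =
    s"map output location for partition ${context.partitionId}"   // stands in for MapStatus
}

class SketchResultTask[U](stageId: Int, partitionId: Int, func: SketchContext => U)
    extends SketchTask[U](stageId, partitionId) {
  override def runTask(context: SketchContext): U = func(context)
}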

6. ==> org.apache.spark.scheduler.ShuffleMapTask#runTask

A ShuffleMapTask splits the elements of its RDD partition into multiple buckets, using the partitioner specified by the ShuffleDependency (HashPartitioner by default; a minimal re-implementation is sketched below).

The core call in ShuffleMapTask is RDD.iterator, which under the hood invokes compute (roughly f(context, partition index, partition iterator)).

Once the RDD has been evaluated, the processed partition data is handed to a ShuffleWriter, which routes each record through the partitioner (e.g. HashPartitioner) and writes it to the corresponding output partition.
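HashPartitioner routes a record to a partition by taking a non-negative modulo of the key's hashCode. The class below is a minimal re-implementation of that rule for illustration, not Spark's HashPartitioner itself; like the real one, it sends null keys to partition 0.

// Minimal re-implementation of the HashPartitioner rule used by shuffle writes.
class SimpleHashPartitioner(val numPartitions: Int) {
  require(numPartitions > 0, "numPartitions must be positive")

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ =>
      // Java's % can be negative, so shift negative remainders into [0, numPartitions)
      val mod = key.hashCode % numPartitions
      if (mod < 0) mod + numPartitions else mod
  }
}

// e.g. new SimpleHashPartitioner(4).getPartition("user-42") always falls in 0..3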

// ShuffleMapTask splits the RDD's elements into multiple buckets,
// based on the partitioner specified by the ShuffleDependency (HashPartitioner by default)
private[spark] class ShuffleMapTask(

// ShuffleMapTask#runTask returns a MapStatus
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L

  // Deserialize the data this task needs to process

  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Recover the RDD and the ShuffleDependency
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  var writer: ShuffleWriter[Any, Any] = null
  try {
    // Get the shuffleManager
    val manager = SparkEnv.get.shuffleManager
    // Get the shuffleWriter
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)

    // Core logic: call the RDD's iterator method for the partition this task is processing.
    // Once the RDD has been evaluated, the processed partition data is handed to the
    // shuffleWriter, which routes it through the partitioner (HashPartitioner by default)
    // and writes it to the corresponding output partition.

    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

    // The returned MapStatus records where this ShuffleMapTask's output is stored,
    // i.e. the relevant BlockManager information
    writer.stop(success = true).get
  }

}

}
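The MapStatus returned above is what reduce tasks later use, via the driver's MapOutputTracker, to locate and size their shuffle input: conceptually it pairs the executor's BlockManager location with a (compressed) size per reduce partition. The case classes below are illustrative stand-ins, not Spark's MapStatus/BlockManagerId.

// Illustrative stand-in for MapStatus: where a map task's shuffle output lives and
// roughly how large each reduce partition's slice is.
case class SketchBlockManagerId(executorId: String, host: String, port: Int)

case class SketchMapStatus(location: SketchBlockManagerId, partitionSizes: Array[Long]) {
  // reduce task i uses this to decide whether its slice is empty and where to fetch it from
  def sizeFor(reduceId: Int): Long = partitionSizes(reduceId)
}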

7. ==> org.apache.spark.scheduler.ResultTask#runTask

override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  // Directly invoke the user-defined function on the partition iterator
  func(context, rdd.iterator(partition, context))
}
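The deserialized `func` is the closure the driver built for the action: for collect() it is essentially `(ctx, it) => it.toArray`, and for count() a function that sizes the partition iterator (this is a paraphrase of how the actions are wired up, not the exact driver-side code). A small self-contained illustration:

// Illustrative only: what the deserialized (TaskContext, Iterator[T]) => U function
// conceptually looks like for common actions. Ctx is a dummy stand-in for TaskContext.
object ResultFuncSketch {
  type Ctx = AnyRef

  // collect(): gather the partition into an array; the driver concatenates one array per ResultTask
  val collectFunc: (Ctx, Iterator[Int]) => Array[Int] = (_, it) => it.toArray

  // count(): each ResultTask returns the size of its partition; the driver sums them
  val countFunc: (Ctx, Iterator[Int]) => Long = (_, it) => {
    var n = 0L
    while (it.hasNext) { it.next(); n += 1 }
    n
  }

  def main(args: Array[String]): Unit = {
    val partition = Iterator(1, 2, 3, 4)
    println(countFunc(null, partition))   // 4
  }
}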

8. ==> org.apache.spark.rdd.RDD#iterator

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // a storage level is set: read from / persist to the block manager cache
    getOrCompute(split, context)
  } else {
    // no caching requested: compute the partition, or read it from a checkpoint
    computeOrReadCheckpoint(split, context)
  }
}

9. ==> org.apache.spark.rdd.RDD#computeOrReadCheckpoint

/**
 * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
 */
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
  if (isCheckpointedAndMaterialized) {
    firstParent[T].iterator(split, context)
  } else {
    // Core method. compute is abstract here; each concrete RDD subclass
    // (e.g. MapPartitionsRDD, JdbcRDD) provides its own implementation
    compute(split, context)
  }
}
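Because compute() of a narrow-dependency RDD calls firstParent.iterator(...) on the same partition, all transformations within one stage collapse into a single pull-based iterator chain, with no intermediate collections materialized. The MiniRDD below is a made-up class sketching that call chain, not Spark's RDD.

// Tiny sketch of how narrow transformations pipeline through iterator()/compute().
abstract class MiniRDD[T] {
  def compute(partition: Int): Iterator[T]
  // iterator() would also check cache/checkpoint in Spark; here it just delegates
  final def iterator(partition: Int): Iterator[T] = compute(partition)

  def map[U](f: T => U): MiniRDD[U] = {
    val parent = this
    new MiniRDD[U] {
      // pulls lazily from the parent's iterator for the same partition
      def compute(partition: Int): Iterator[U] = parent.iterator(partition).map(f)
    }
  }
}

class RangeRDD(n: Int) extends MiniRDD[Int] {
  def compute(partition: Int): Iterator[Int] = Iterator.range(0, n)
}

// new RangeRDD(5).map(_ * 2).map(_ + 1).iterator(0) evaluates element by element,
// never building the intermediate collections.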

Demo: two concrete compute implementations, MapPartitionsRDD and JdbcRDD:

class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false,
    isFromBarrier: Boolean = false,
    isOrderSensitive: Boolean = false)
  extends RDD[U](prev) {

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

}

class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // bounds are inclusive, hence the + 1 here and - 1 on end
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong)
    }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
  {
    context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL

    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {

    }

  }
}
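Finally, a usage sketch for JdbcRDD following the constructor shown above. The JDBC URL, table and column names are placeholders for illustration; note that the real JdbcRDD binds each partition's lower/upper bound to two '?' placeholders in the SQL, a step elided from the compute excerpt above.

import java.sql.DriverManager
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

// Usage sketch: build a JdbcRDD that reads (id, name) rows partitioned by id range.
object JdbcRDDExample {
  def build(sc: SparkContext): JdbcRDD[(Long, String)] = {
    new JdbcRDD[(Long, String)](
      sc,
      () => DriverManager.getConnection("jdbc:h2:mem:demo"),   // getConnection: () => Connection
      "SELECT id, name FROM users WHERE id >= ? AND id <= ?",  // the two '?' receive each partition's bounds
      lowerBound = 1L,
      upperBound = 1000L,
      numPartitions = 4,                                       // 4 JdbcPartitions, each with its own id range
      mapRow = rs => (rs.getLong("id"), rs.getString("name"))  // mapRow: ResultSet => T
    )
  }
}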