DAGScheduler源码分析(stage划分算法)

2019-02-19  本文已影响0人  有一束阳光叫温暖

DAGScheduler的stage划分算法总结:会从触发action操作的那个rdd开始反向解析,首先会为最后一个rdd创建一个stage,反向解析的时候,遇到窄依赖就把当前的rdd加入到Stage,遇到宽依赖就断开,将宽依赖的那个rdd创建一个新的stage,那个rdd就是这个stage最后一个rdd。依此类推,遍历所有RDD为止。

stage划分源码分析

// DAGScheduler的job调度的核心入口
  // stage划分总结
  // 1、从finalStage推断
  // 2、通过宽依赖,来进行新的stage的划分
  // 3、使用递归,优先提交父stage
  private[scheduler] def handleJobSubmitted(jobId: Int,
      finalRDD: RDD[_],
      func: (TaskContext, Iterator[_]) => _,
      partitions: Array[Int],
      allowLocal: Boolean,
      callSite: CallSite,
      listener: JobListener,
      properties: Properties = null)
  {
    // 使用触发job的最后一个rdd,创建finalStage
    var finalStage: Stage = null
    try {
      // New stage creation may throw an exception if, for example, jobs are run on a
      // HadoopRDD whose underlying HDFS files have been deleted.
      // 创建一个stage对象
      // 并且将stage加入DAGScheduler内部的内存缓存中
      finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite)
    } catch {
      case e: Exception =>
        logWarning("Creating new stage failed due to exception - job: " + jobId, e)
        listener.jobFailed(e)
        return
    }
    // 用finalStage,创建一个job
    // 也就是说,这个job最后一个stage,就是finalStage
    if (finalStage != null) {
      val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
      clearCacheLocs()
      logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
        job.jobId, callSite.shortForm, partitions.length, allowLocal))
      logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
      logInfo("Parents of final stage: " + finalStage.parents)
      logInfo("Missing parents: " + getMissingParentStages(finalStage))
      val shouldRunLocally =
        localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
      val jobSubmissionTime = clock.getTimeMillis()
      if (shouldRunLocally) {
        // Compute very short actions like first() or take() with no parent stages locally.
        listenerBus.post(
          SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
        runLocally(job)
      } else {
        // 将job加入内存缓存中
        jobIdToActiveJob(jobId) = job
        activeJobs += job
        finalStage.resultOfJob = Some(job)
        val stageIds = jobIdToStageIds(jobId).toArray
        val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
        listenerBus.post(
          SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
        // 使用submitStage提交finalStage
        // 这个方法会提交第一个stage
        submitStage(finalStage)
      }
    }
    // 提交等待的stage
    submitWaitingStages()
  }

private def submitStage(stage: Stage) {
    val jobId = activeJobForStage(stage)
    if (jobId.isDefined) {
      logDebug("submitStage(" + stage + ")")
      if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
        // 使用getMissingParentStages(stage)去获取当前stage的父stage
        val missing = getMissingParentStages(stage).sortBy(_.id)
        logDebug("missing: " + missing)
        // 反复递归调用,直到最初的stage,没有父stage.
        // 此时就会首先提交这个第一stage,stage0,其余都在waitingStage中
        if (missing == Nil) {
          logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
          submitMissingTasks(stage, jobId.get)
        } else {
          // 递归调用submit()方法,去提交父stage
          for (parent <- missing) {
            submitStage(parent)
          }
          // 并且将当前stage,放入waitingStage等待执行的stage的队列中
          waitingStages += stage
        }
      }
    } else {
      abortStage(stage, "No active job for stage " + stage.id)
    }
  }
/**
    * 获取父stage
    * 这个方法的意思是,对于一个stage
    * 如果它的最后一个rdd的所有依赖,都是窄依赖,不会创建任何新的stage
    * 但是如果只要发现这个stage的rdd宽依赖了某个rdd,那么就用宽的那个rdd,创建一个新的stage
    * 然后立即将新的stage返回
    * @param stage
    * @return
    */
  private def getMissingParentStages(stage: Stage): List[Stage] = {
    val missing = new HashSet[Stage]
    val visited = new HashSet[RDD[_]]
    // We are manually maintaining a stack here to prevent StackOverflowError
    // caused by recursively visiting
    val waitingForVisit = new Stack[RDD[_]]
    def visit(rdd: RDD[_]) {
      if (!visited(rdd)) {
        visited += rdd
        if (getCacheLocs(rdd).contains(Nil)) {
          // 遍历rdd依赖
          for (dep <- rdd.dependencies) {
            dep match {
                // 如何是宽依赖,那么使用宽依赖那个rdd,创建一个stage,并且会将ishuffleMap设置为true
              case shufDep: ShuffleDependency[_, _, _] =>
                val mapStage = getShuffleMapStage(shufDep, stage.jobId)
                if (!mapStage.isAvailable) {
                  missing += mapStage
                }
                // 如果是窄依赖,那么将依赖的rdd放入栈中
              case narrowDep: NarrowDependency[_] =>
                waitingForVisit.push(narrowDep.rdd)
            }
          }
        }
      }
    }
    // 首先往栈里面,推入了stage最后一个rdd
    waitingForVisit.push(stage.rdd)
    while (!waitingForVisit.isEmpty) {
      // 对stage的最后一个rdd,调用自己内部定义的visit
      visit(waitingForVisit.pop())
    }
    missing.toList
  }

源码分析总结:

1、从finalStage推断
2、通过宽依赖,来进行新的stage的划分
3、使用递归,优先提交父stage

上一篇下一篇

猜你喜欢

热点阅读