
Spark Source Code Analysis: The SparkSubmit Flow

2019-06-15  叫我不矜持

Preparation

This article analyzes the source code of SparkSubmit's job submission flow. The Spark source version is 2.3.1.

Let's start by reading the launch script to see which class is loaded first. Here is the relevant content of the spark-submit script.

(Image in the original: the spark-submit shell script.)

As you can see, the class loaded here is org.apache.spark.deploy.SparkSubmit, and the launch arguments are passed along to it. Let's follow the source code and see how the whole flow works.

Flow Analysis

SparkSubmit's main method is as follows:

// Main entry point for job submission
  override def main(args: Array[String]): Unit = {
    // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to
    // be reset before the application starts.
    val uninitLog = initializeLogIfNecessary(true, silent = true)
    // parse the submission arguments
    val appArgs = new SparkSubmitArguments(args)
    if (appArgs.verbose) {
      // scalastyle:off println
      printStream.println(appArgs)
      // scalastyle:on println
    }
    appArgs.action match {
      // a job submission matches SUBMIT and goes to submit()
      case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)
      case SparkSubmitAction.KILL => kill(appArgs)
      case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
    }
  }

Since we are submitting a job, the flow takes the submit(appArgs, uninitLog) branch above.

private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
    // prepareSubmitEnvironment returns a 4-tuple; pay special attention to childMainClass. Standalone cluster mode is used as the example here.
    val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)
   
    def doRunMain(): Unit = {... }

    ...
}

The submit method first prepares the submission environment by calling prepareSubmitEnvironment, which returns a 4-tuple (internally it delegates to doPrepareSubmitEnvironment). The key thing to watch is what childMainClass ends up being, because it determines how our main class is started later.

Here is the (abridged) source of doPrepareSubmitEnvironment:

 private def doPrepareSubmitEnvironment(
      args: SparkSubmitArguments,
      conf: Option[HadoopConfiguration] = None)
      : (Seq[String], Seq[String], SparkConf, String) = {
    // Return values
    val childArgs = new ArrayBuffer[String]()
    val childClasspath = new ArrayBuffer[String]()
    val sparkConf = new SparkConf()
    var childMainClass = ""

   val clusterManager: Int = args.master match {
      case "yarn" => YARN
      case "yarn-client" | "yarn-cluster" =>
        printWarning(s"Master ${args.master} is deprecated since 2.0." +
          " Please use master \"yarn\" with specified deploy mode instead.")
        YARN
      case m if m.startsWith("spark") => STANDALONE
      case m if m.startsWith("mesos") => MESOS
      case m if m.startsWith("k8s") => KUBERNETES
      case m if m.startsWith("local") => LOCAL
      case _ =>
        printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local")
        -1
  }
    // Set the deploy mode; default is client mode
    var deployMode: Int = args.deployMode match {
      case "client" | null => CLIENT
      case "cluster" => CLUSTER
      case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1
    }
   ....
  if (deployMode == CLIENT) {
      childMainClass = args.mainClass
      if (localPrimaryResource != null && isUserJar(localPrimaryResource)) {
        childClasspath += localPrimaryResource
      }
      if (localJars != null) { childClasspath ++= localJars.split(",") }
    }
....
// In standalone cluster mode, use the REST client to submit the application (Spark 1.3+).
    // All Spark parameters are expected to be passed to the client through system properties.
    // standalone cluster mode
    if (args.isStandaloneCluster) {
      // REST submission: the job is submitted as JSON over HTTP (supported since Spark 1.3)
      if (args.useRest) {
        childMainClass = REST_CLUSTER_SUBMIT_CLASS
        childArgs += (args.primaryResource, args.mainClass)
      } else {
        // legacy submission path
        // In legacy standalone cluster mode, use Client as a wrapper around the user class
        childMainClass = STANDALONE_CLUSTER_SUBMIT_CLASS
        if (args.supervise) { childArgs += "--supervise" }
        Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) }
        Option(args.driverCores).foreach { c => childArgs += ("--cores", c) }
        childArgs += "launch"
        childArgs += (args.master, args.primaryResource, args.mainClass)
      }
      if (args.childArgs != null) {
        childArgs ++= args.childArgs
      }
    }
....
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
    if (isYarnCluster) {
      childMainClass = YARN_CLUSTER_SUBMIT_CLASS
      if (args.isPython) {
        childArgs += ("--primary-py-file", args.primaryResource)
        childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
      } else if (args.isR) {
        val mainFile = new Path(args.primaryResource).getName
        childArgs += ("--primary-r-file", mainFile)
        childArgs += ("--class", "org.apache.spark.deploy.RRunner")
      } else {
        if (args.primaryResource != SparkLauncher.NO_RESOURCE) {
          childArgs += ("--jar", args.primaryResource)
        }
        childArgs += ("--class", args.mainClass)
      }
      if (args.childArgs != null) {
        args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
      }
    }

if (isMesosCluster) {
      assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API")
      childMainClass = REST_CLUSTER_SUBMIT_CLASS
      if (args.isPython) {
        // Second argument is main class
        childArgs += (args.primaryResource, "")
        if (args.pyFiles != null) {
          sparkConf.set("spark.submit.pyFiles", args.pyFiles)
        }
      } else if (args.isR) {
        // Second argument is main class
        childArgs += (args.primaryResource, "")
      } else {
        childArgs += (args.primaryResource, args.mainClass)
      }
      if (args.childArgs != null) {
        childArgs ++= args.childArgs
      }
    }

    if (isKubernetesCluster) {
      childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS
      if (args.primaryResource != SparkLauncher.NO_RESOURCE) {
        childArgs ++= Array("--primary-java-resource", args.primaryResource)
      }
      childArgs ++= Array("--main-class", args.mainClass)
      if (args.childArgs != null) {
        args.childArgs.foreach { arg =>
          childArgs += ("--arg", arg)
        }
      }
    }

  .....

  (childArgs, childClasspath, sparkConf, childMainClass)
}

The method first parses the relevant arguments (the jar, the fully qualified name of the main class, the system configuration, and so on) and validates them. The crucial step is that, based on the deploy-mode argument, it decides how our main class will be run: childMainClass determines which class is launched first in the next step.

childMainClass takes a different value depending on the deployment mode (the constant definitions are shown in the sketch below):

- client mode: args.mainClass, i.e. the user's own main class
- standalone cluster, REST submission: REST_CLUSTER_SUBMIT_CLASS (RestSubmissionClientApp)
- standalone cluster, legacy submission: STANDALONE_CLUSTER_SUBMIT_CLASS (ClientApp)
- yarn cluster: YARN_CLUSTER_SUBMIT_CLASS (YarnClusterApplication)
- mesos cluster: REST_CLUSTER_SUBMIT_CLASS (RestSubmissionClientApp)
- kubernetes cluster: KUBERNETES_CLUSTER_SUBMIT_CLASS (KubernetesClientApplication)
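
These constants are defined in the SparkSubmit object; the snippet below is paraphrased from the 2.3.x source, so minor details may differ:

// Abridged from object SparkSubmit (Spark 2.3.x); paraphrased.
private[deploy] val YARN_CLUSTER_SUBMIT_CLASS =
  "org.apache.spark.deploy.yarn.YarnClusterApplication"
private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName()
private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName()
private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS =
  "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication"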

The method then returns the prepared 4-tuple, and we are back in the submit method we saw earlier:

private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
    // prepareSubmitEnvironment returns a 4-tuple; pay special attention to childMainClass. Standalone cluster mode is used as the example here.
    val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)
   
    def doRunMain(): Unit = {... }
    ...
    if (args.isStandaloneCluster && args.useRest) {
      try {
        // scalastyle:off println
        printStream.println("Running Spark using the REST application submission protocol.")
        // scalastyle:on println
        doRunMain()
      } catch {
        // Fail over to use the legacy submission gateway
        case e: SubmitRestConnectionException =>
          printWarning(s"Master endpoint ${args.master} was not a REST server. " +
            "Falling back to legacy submission gateway instead.")
          args.useRest = false
          submit(args, false)
      }
    // In all other modes, just run the main class as prepared
    } else {
      doRunMain()
    }
}

As you can see, this eventually calls doRunMain() to move on to the next step.

doRunMain is implemented as follows:

def doRunMain(): Unit = {
      if (args.proxyUser != null) {
        val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser,
          UserGroupInformation.getCurrentUser())
        try {
          proxyUser.doAs(new PrivilegedExceptionAction[Unit]() {
            override def run(): Unit = {
              runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose)
            }
          })
        } catch {
          case e: Exception =>
            // Hadoop's AuthorizationException suppresses the exception's stack trace, which
            // makes the message printed to the output by the JVM not very helpful. Instead,
            // detect exceptions with empty stack traces here, and treat them differently.
            if (e.getStackTrace().length == 0) {
              // scalastyle:off println
              printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
              // scalastyle:on println
              exitFn(1)
            } else {
              throw e
            }
        }
      } else {
        // run directly (no proxy user)
        runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose)
      }
    }

doRunMain checks whether a proxy user is required, but either way it ends up calling runMain. Let's look at how runMain is implemented.

private def runMain(
      childArgs: Seq[String],
      childClasspath: Seq[String],
      sparkConf: SparkConf,
      childMainClass: String,
      verbose: Boolean): Unit = {
....
var mainClass: Class[_] = null

    try {
      // load the class
      mainClass = Utils.classForName(childMainClass)
    } catch { .... }

    // wrap mainClass in a SparkApplication
    val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) {
      mainClass.newInstance().asInstanceOf[SparkApplication]
    } else {
      // SPARK-4170
      if (classOf[scala.App].isAssignableFrom(mainClass)) {
        printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.")
      }
      new JavaMainApplication(mainClass)
    }
   ....

   try {
      // call start; in standalone cluster mode this is ClientApp's start method
      app.start(childArgs.toArray, sparkConf)
    } catch {
      case t: Throwable =>
        findCause(t) match {
          case SparkUserAppException(exitCode) =>
            System.exit(exitCode)

          case t: Throwable =>
            throw t
        }
    }

}

Here we assume standalone cluster mode with the legacy submission path. First the class is loaded: childMainClass is loaded into the Class object mainClass, which is then wrapped as a SparkApplication. Because the childMainClass returned in the 4-tuple above is the fully qualified name of ClientApp, the app.start call here ends up invoking ClientApp's start method.
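
For reference, SparkApplication is just a small trait, and JavaMainApplication is the adapter used when the loaded class is an ordinary main class. The sketch below is paraphrased from SparkApplication.scala in 2.3.x, so details may differ slightly:

import java.lang.reflect.Modifier

import org.apache.spark.SparkConf

// Entry point for a Spark application; implementations must have a no-arg constructor.
private[spark] trait SparkApplication {
  def start(args: Array[String], conf: SparkConf): Unit
}

// Wraps a plain class that only provides a static main method.
private[deploy] class JavaMainApplication(klass: Class[_]) extends SparkApplication {
  override def start(args: Array[String], conf: SparkConf): Unit = {
    val mainMethod = klass.getMethod("main", classOf[Array[String]])
    if (!Modifier.isStatic(mainMethod.getModifiers)) {
      throw new IllegalStateException("The main method in the given main class must be static")
    }
    // Configuration is handed to the application through system properties.
    conf.getAll.foreach { case (k, v) => sys.props(k) = v }
    mainMethod.invoke(null, args)
  }
}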

ClientApp's start method is as follows:

override def start(args: Array[String], conf: SparkConf): Unit = {
    val driverArgs = new ClientArguments(args)

    if (!conf.contains("spark.rpc.askTimeout")) {
      conf.set("spark.rpc.askTimeout", "10s")
    }
    Logger.getRootLogger.setLevel(driverArgs.logLevel)

    // create the RPC environment
    val rpcEnv =
      RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
    // obtain endpoint refs to the Master(s)
    val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL).
      map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME))
    // register the endpoint that submits this job; registration guarantees that ClientEndpoint's onStart method runs
    rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf))

    rpcEnv.awaitTermination()
  }

This looks quite similar to the Master startup flow we covered before. For more background, see my previous article, Spark源码分析之Master的启动流程 (Spark Source Code Analysis: The Master Startup Flow).

First the RpcEnv is created. Then the Master endpoint refs, masterEndpoints, are obtained from the master addresses (the arguments configured earlier are all available here because they are passed into start). Next, the client endpoint is registered with the RpcEnv. Note that masterEndpoints is also passed into ClientEndpoint's constructor, which means this ClientEndpoint will communicate with the Master endpoints.

As mentioned in my previous article, whenever setupEndpoint is called, the registered endpoint's onStart method is guaranteed to run, so ClientEndpoint's onStart is invoked here.
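
To see why registering an endpoint always triggers onStart: when the Dispatcher registers an endpoint it creates an Inbox, and the Inbox enqueues an OnStart control message before anything else. The sketch below only illustrates that pattern with made-up names; it is not Spark's actual code (the real implementation lives in org.apache.spark.rpc.netty.Dispatcher and Inbox):

import scala.collection.mutable

// Illustrative sketch of the "register, then onStart runs first" pattern.
trait SketchEndpoint {
  def onStart(): Unit
  def receive(msg: Any): Unit
}

class SketchInbox(endpoint: SketchEndpoint) {
  private case object OnStart
  // OnStart is enqueued the moment the inbox is created, i.e. at registration time.
  private val messages = mutable.Queue[Any](OnStart)

  def post(msg: Any): Unit = messages.enqueue(msg)

  // onStart is therefore always processed before any user message.
  def process(): Unit = while (messages.nonEmpty) {
    messages.dequeue() match {
      case OnStart => endpoint.onStart()
      case m       => endpoint.receive(m)
    }
  }
}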

ClientEndpoint's onStart matches on the launch command. The source is as follows:

override def onStart(): Unit = {
    driverArgs.cmd match {
      case "launch" =>
          
        val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"
       ....

        // wrap the DriverWrapper class in a Command
        val command = new Command(mainClass,
          Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions,
          sys.env, classPathEntries, libraryPathEntries, javaOpts)

       val driverDescription = new DriverDescription(
          driverArgs.jarUrl,
          driverArgs.memory,
          driverArgs.cores,
          driverArgs.supervise,
          command)

        // ask the Master to launch the Driver; the Master's receiveAndReply method handles this request
        asyncSendToMasterAndForwardReply[SubmitDriverResponse](
          RequestSubmitDriver(driverDescription))
    ....
}

The "launch" case in onStart is what starts the DriverWrapper: the mainClass above is the fully qualified name of DriverWrapper, which is wrapped into a Command, which in turn is wrapped into a DriverDescription, and the client then asks the Master to launch the Driver.
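
For reference, Command and DriverDescription are plain case classes whose field order matches the constructor calls above. Paraphrased from the 2.3.x source:

// Abridged from org.apache.spark.deploy.Command / DriverDescription; paraphrased.
private[spark] case class Command(
    mainClass: String,        // here: org.apache.spark.deploy.worker.DriverWrapper
    arguments: Seq[String],   // {{WORKER_URL}}, {{USER_JAR}}, the user's main class, driver options
    environment: Map[String, String],
    classPathEntries: Seq[String],
    libraryPathEntries: Seq[String],
    javaOpts: Seq[String])

private[deploy] case class DriverDescription(
    jarUrl: String,
    mem: Int,
    cores: Int,
    supervise: Boolean,
    command: Command)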

The request is sent to the Master through the RpcEnv, which brings in the outbox: postToOutbox appends the message to the outbox, and it is then written out via TransportClient's send or sendRpc method. Enqueuing into the outbox and the actual send happen on the same thread.

Attentive readers will notice that the method called here is asyncSendToMasterAndForwardReply; as the name suggests, it sends a message to the Master and expects a reply.
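
Here is roughly what that method looks like in ClientEndpoint (abridged and paraphrased from deploy/Client.scala in 2.3.x): it asks every configured Master and forwards whatever reply comes back to itself.

// Abridged from ClientEndpoint.asyncSendToMasterAndForwardReply; paraphrased.
private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = {
  for (masterEndpoint <- masterEndpoints) {
    masterEndpoint.ask[T](message).onComplete {
      case Success(v) => self.send(v)
      case Failure(e) =>
        logWarning(s"Error sending messages to master $masterEndpoint", e)
    }(forwardMessageExecutionContext)
  }
}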

Below is the call chain the RpcEnv goes through to send a message to a remote endpoint; the message is ultimately written out by Netty's TransportClient.

// NettyRpcEndpointRef.send
 override def send(message: Any): Unit = {
    require(message != null, "Message is null")
    nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
  }

// NettyRpcEnv.send
private[netty] def send(message: RequestMessage): Unit = {
    val remoteAddr = message.receiver.address
    if (remoteAddr == address) {
     ....
    } else {
      // Message to a remote RPC endpoint.
      postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
    }
  }


// NettyRpcEnv.postToOutbox
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
  if (receiver.client != null) {
      message.sendWith(receiver.client)
    } else {
       ....
   }
}

// RpcOutboxMessage.sendWith
override def sendWith(client: TransportClient): Unit = {
    this.client = client
    this.requestId = client.sendRpc(content, this)
  }

// TransportClient.sendRpc (Java)
public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
    long startTime = System.currentTimeMillis();
    if (logger.isTraceEnabled()) {
      logger.trace("Sending RPC to {}", getRemoteAddress(channel));
    }

    long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
    handler.addRpcRequest(requestId, callback);

    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)))
        .addListener(future -> {
          if (future.isSuccess()) {
            long timeTaken = System.currentTimeMillis() - startTime;
            if (logger.isTraceEnabled()) {
              logger.trace("Sending request {} to {} took {} ms", requestId,
                getRemoteAddress(channel), timeTaken);
            }
          } else {
            String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
              getRemoteAddress(channel), future.cause());
            logger.error(errorMsg, future.cause());
            handler.removeRpcRequest(requestId);
            channel.close();
            try {
              callback.onFailure(new IOException(errorMsg, future.cause()));
            } catch (Exception e) {
              logger.error("Uncaught exception in RPC response callback handler!", e);
            }
          }
        });

    return requestId;
  }

The Master then handles the request in receiveAndReply, matches the RequestSubmitDriver case class, and carries out the rest of the flow.

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case RequestSubmitDriver(description) =>
      // check the Master's state
      if (state != RecoveryState.ALIVE) {
        val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
          "Can only accept driver submissions in ALIVE state."
        context.reply(SubmitDriverResponse(self, false, None, msg))
      } else {
        logInfo("Driver submitted " + description.command.mainClass)
        val driver = createDriver(description)
        persistenceEngine.addDriver(driver)
        waitingDrivers += driver
        drivers.add(driver)
        schedule()

        // TODO: It might be good to instead have the submission client poll the master to determine
        //       the current status of the driver. For now it's simply "fire and forget".

        context.reply(SubmitDriverResponse(self, true, Some(driver.id),
          s"Driver successfully submitted as ${driver.id}"))
      }

....

}

As you can see, the driver description is first wrapped into a DriverInfo, persisted, added to the waitingDrivers list, and then the general-purpose schedule() method is called.
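
For reference, createDriver just wraps the description with an id and a submission date (abridged and paraphrased from Master.scala):

// Abridged from Master.createDriver; paraphrased.
private def createDriver(desc: DriverDescription): DriverInfo = {
  val now = System.currentTimeMillis()
  val date = new Date(now)
  new DriverInfo(now, newDriverId(date), desc, date)
}

The schedule method itself is shown next.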

 private def schedule(): Unit = {
    // check the Master's state
    if (state != RecoveryState.ALIVE) {
      return
    }
    // Drivers take strict precedence over executors. Shuffle the alive workers here.
    val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
    // number of alive workers
    val numWorkersAlive = shuffledAliveWorkers.size
    var curPos = 0
    for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
      // We assign workers to each waiting driver in a round-robin fashion. For each driver, we
      // start from the last worker that was assigned a driver, and continue onwards until we have
      // explored all alive workers.
      var launched = false
      var numWorkersVisited = 0
      while (numWorkersVisited < numWorkersAlive && !launched) {
        // take the worker at position curPos
        val worker = shuffledAliveWorkers(curPos)
        numWorkersVisited += 1
        if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
          // launch the Driver here; once the Driver is up it will request resources for the application
          launchDriver(worker, driver)
          waitingDrivers -= driver
          launched = true
        }
        // advance curPos round-robin until a worker with enough free resources is found
        curPos = (curPos + 1) % numWorkersAlive
      }
    }
    startExecutorsOnWorkers()
  }

Since waitingDrivers is not empty, the launchDriver path is taken (once the Driver is running it will request resources for the application). launchDriver sends a message to the chosen Worker, which triggers the Worker's receive method.
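
For reference, launchDriver on the Master side records the assignment and sends the LaunchDriver message to the chosen Worker (abridged and paraphrased from Master.scala):

// Abridged from Master.launchDriver; paraphrased.
private def launchDriver(worker: WorkerInfo, driver: DriverInfo): Unit = {
  logInfo("Launching driver " + driver.id + " on worker " + worker.id)
  worker.addDriver(driver)
  driver.worker = Some(worker)
  worker.endpoint.send(LaunchDriver(driver.id, driver.desc))
  driver.state = DriverState.RUNNING
}

The Worker's receive handler is next.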

override def receive: PartialFunction[Any, Unit] = synchronized {
....

    /*
     * Launching the Driver means launching the DriverWrapper class: the Worker creates a
     * Driver process and then runs DriverWrapper's main method.
     */
    case LaunchDriver(driverId, driverDesc) =>
      logInfo(s"Asked to launch driver $driverId")
      val driver = new DriverRunner(
        conf,
        driverId,
        workDir,
        sparkHome,
        driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)),
        self,
        workerUri,
        securityMgr)
      drivers(driverId) = driver
      // start the Driver: this launches org.apache.spark.deploy.worker.DriverWrapper and runs its main method
      driver.start()

      coresUsed += driverDesc.cores
      memoryUsed += driverDesc.mem

....

}

In the Worker's receive method, when the Worker gets a LaunchDriver command it creates and starts a DriverRunner. The DriverRunner spawns a thread that handles the Driver launch asynchronously. The Driver being launched here is the org.apache.spark.deploy.worker.DriverWrapper mentioned earlier.

private[worker] def start() = {
    new Thread("DriverRunner for " + driverId) {
      override def run() {
          ....
        try {
          ....
          // prepare driver jars and run driver
          // prepareAndRunDriver eventually launches the Driver, i.e. the DriverWrapper wrapper class
          val exitCode = prepareAndRunDriver()

          // set final state depending on if forcibly killed and process exit code
          finalState = if (exitCode == 0) {
            Some(DriverState.FINISHED)
          } else if (killed) {
            Some(DriverState.KILLED)
          } else {
            Some(DriverState.FAILED)
          }
        } catch {
          ...
        } finally {
         ...
        }
        worker.send(DriverStateChanged(driverId, finalState.get, finalException))
      }
    }.start()
  }

As shown above, DriverRunner handles the Driver launch on a separate thread, so it does not block the Worker process. prepareAndRunDriver eventually calls runDriver.
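
Here is roughly what prepareAndRunDriver does (abridged and paraphrased from DriverRunner.scala in 2.3.x): it creates the driver's working directory, downloads the user jar into it, builds the driver process from the Command while substituting the {{WORKER_URL}} and {{USER_JAR}} placeholders, and then hands off to runDriver.

// Abridged from DriverRunner.prepareAndRunDriver; paraphrased.
private[worker] def prepareAndRunDriver(): Int = {
  val driverDir = createWorkingDirectory()
  val localJarFilename = downloadUserJar(driverDir)

  def substituteVariables(argument: String): String = argument match {
    case "{{WORKER_URL}}" => workerUrl
    case "{{USER_JAR}}" => localJarFilename
    case other => other
  }

  // Build the java command that will run org.apache.spark.deploy.worker.DriverWrapper
  val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager,
    driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables)

  runDriver(builder, driverDir, driverDesc.supervise)
}

runDriver then looks like this: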

 private def runDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean): Int = {
    builder.directory(baseDir)
    // initialize: redirect the new driver (DriverWrapper) process's stdout/stderr to files
    def initialize(process: Process): Unit = {
      // Redirect stdout and stderr to files
      val stdout = new File(baseDir, "stdout")
      CommandUtils.redirectStream(process.getInputStream, stdout)

      val stderr = new File(baseDir, "stderr")
      val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"")
      val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40)
      Files.append(header, stderr, StandardCharsets.UTF_8)
      CommandUtils.redirectStream(process.getErrorStream, stderr)
    }
    runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
  }

runDriver first does some initialization (redirecting the driver process's stdout and stderr to files) and then actually starts the driver.

The Driver launch described above boils down to the following steps:

- Create a working directory for the driver on the Worker.
- Download the user jar into that directory.
- Build a ProcessBuilder from the Command in the DriverDescription, substituting the {{WORKER_URL}} and {{USER_JAR}} placeholders.
- Start the driver process (DriverWrapper), redirecting its stdout and stderr to files and retrying if --supervise was specified.
- Report the final state back to the Worker via a DriverStateChanged message.

Now let's look at DriverWrapper's implementation directly:

def main(args: Array[String]) {
    args.toList match {
      /*
       * IMPORTANT: Spark 1.3 provides a stable application submission gateway that is both
       * backward and forward compatible across future Spark versions. Because this gateway
       * uses this class to launch the driver, the ordering and semantics of the arguments
       * here must also remain consistent across versions.
       */
      // the mainClass below is the application class we actually submitted
      case workerUrl :: userJar :: mainClass :: extraArgs =>
        val conf = new SparkConf()
        val host: String = Utils.localHostName()
        val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt
        val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf))
        logInfo(s"Driver address: ${rpcEnv.address}")
        rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))

        val currentLoader = Thread.currentThread.getContextClassLoader
        val userJarUrl = new File(userJar).toURI().toURL()
        val loader =
          if (sys.props.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
            new ChildFirstURLClassLoader(Array(userJarUrl), currentLoader)
          } else {
            new MutableURLClassLoader(Array(userJarUrl), currentLoader)
          }
        Thread.currentThread.setContextClassLoader(loader)
        setupDependencies(loader, userJar)

        // Delegate to supplied main class
        val clazz = Utils.classForName(mainClass)
        // get the main method of the submitted application
        val mainMethod = clazz.getMethod("main", classOf[Array[String]])

        /**
          * Invoke the main method of the submitted application. This is where the SparkConf and
          * SparkContext are created (in the 2.3.1 source, the try block around line 362 of
          * SparkContext creates the TaskScheduler, around line 492).
          */
        mainMethod.invoke(null, extraArgs.toArray[String])

        rpcEnv.shutdown()
      ....
    }
  }

DriverWrapper creates an RpcEnv and registers a WorkerWatcher endpoint, whose only job is to monitor the Worker node and exit immediately if something goes wrong. It then loads the user jar with the appropriate ClassLoader, invokes the main method of the user's main class, and shuts down the RpcEnv (and with it the WorkerWatcher) once the user's main method returns.
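
For reference, WorkerWatcher does little more than watch the connection to its Worker and kill the JVM when that connection goes away (abridged and paraphrased from WorkerWatcher.scala; details may differ slightly):

// Abridged from org.apache.spark.deploy.worker.WorkerWatcher; paraphrased.
private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)

override def onDisconnected(remoteAddress: RpcAddress): Unit = {
  if (isWorker(remoteAddress)) {
    // The Worker we were watching is gone, so the driver JVM exits as well.
    logError(s"Lost connection to worker rpc endpoint $workerUrl. Exiting.")
    exitNonZero()
  }
}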

That concludes the SparkSubmit flow. In the next article I will walk through the SparkContext source code.

Feel free to follow for more...
