kubeflow - mxnet-operator代码分析

2019-07-19  本文已影响0人  陈先生_9e91

kubeflow - mxnet-operator

要求

example

kubeflow/mxnet-operator/examples/v1beta1/train/mx_job_dist_gpu.yaml

# Example MXJob: distributed MXNet training (jobMode MXTrain) with one
# Scheduler, one Server and one GPU Worker. The operator injects MX_CONFIG
# and DMLC_* environment variables into every container it creates.
apiVersion: "kubeflow.org/v1beta1"
kind: "MXJob"
metadata:
  name: "mxnet-job"
spec:
  # For MXTrain the operator's inspectMXjob check requires Scheduler, Server
  # and Worker replica specs to all be present.
  jobMode: MXTrain
  mxReplicaSpecs:
    # The CRD validation pins the Scheduler to exactly one replica.
    Scheduler:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          containers:
            - name: mxnet
              image: mxjob/mxnet:gpu
    Server:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          containers:
            - name: mxnet
              image: mxjob/mxnet:gpu
    # Only the Worker sets an explicit command: it runs the MNIST
    # image-classification example and requests one GPU.
    Worker:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          containers:
            - name: mxnet
              image: mxjob/mxnet:gpu
              command: ["python"]
              args: ["/incubator-mxnet/example/image-classification/train_mnist.py","--num-epochs","10","--num-layers","2","--kv-store","dist_device_sync","--gpus","0"]
              resources:
                limits:
                  nvidia.com/gpu: 1

三种角色:Scheduler、Server、Worker,每种角色的配置跟常见的 Deployment 没什么太大差别。

CRD

kubeflow/mxnet-operator/manifests/crd-v1beta1.yaml

# CustomResourceDefinition registering the namespaced MXJob kind
# (group kubeflow.org, version v1beta1, plural "mxjobs").
# NOTE(review): apiextensions.k8s.io/v1beta1 CRDs are no longer served since
# Kubernetes 1.22; newer clusters need an apiextensions.k8s.io/v1 manifest.
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
  name: mxjobs.kubeflow.org
spec:
  group: kubeflow.org
  version: v1beta1
  scope: Namespaced
  names:
    kind: MXJob
    singular: mxjob
    plural: mxjobs
  # The schema below only constrains replica counts per replica type.
  validation:
    openAPIV3Schema:
      properties:
        spec:
          properties:
            mxReplicaSpecs:
              properties:
                # Validation applies only to the replica types listed below
                # (Worker, Server, Scheduler, TunerTracker, TunerServer, Tuner);
                # any other key under mxReplicaSpecs is not validated.
                # Singleton role: exactly one replica.
                Scheduler:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                      maximum: 1
                # At least one replica, no upper bound.
                Worker:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                # At least one replica, no upper bound.
                Server:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                # Singleton role: exactly one replica.
                TunerTracker:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                      maximum: 1
                # At least one replica, no upper bound.
                TunerServer:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                # Singleton role: exactly one replica.
                Tuner:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                      maximum: 1

Scheduler

                # Exactly one scheduler is allowed (minimum = maximum = 1).
                Scheduler:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1
                      maximum: 1

Worker

                # At least one worker; no upper bound.
                Worker:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1

Server

                # At least one server; no upper bound.
                Server:
                  properties:
                    replicas:
                      type: integer
                      minimum: 1

launch

后期可以增加launch,简化分布式job

code

create

kubeflow/mxnet-operator/pkg/controller.v1/mxnet/controller.go

    // Set up an event handler for when mxjob resources change.
    // Only the AddFunc hook is shown in this excerpt: every newly observed
    // MXJob is handed to tc.addMXJob.
    mxJobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
        AddFunc:    tc.addMXJob,
    })

注册事件处理函数:新建 MXJob 时回调 addMXJob。

kubeflow/mxnet-operator/pkg/controller.v1/mxnet/job.go

// addMXJob is the informer AddFunc for MXJob resources. When an MXJob is
// added it converts the unstructured object, applies defaults, records a
// "created" condition, writes the result back into obj and enqueues the job.
// If any step fails, the error is logged and the job is not enqueued.
func (tc *MXController) addMXJob(obj interface{}) {
	// Convert from unstructured object.
	mxJob, err := mxJobFromUnstructured(obj)
	if err != nil {
		logger.Errorf("Failed to convert the MXJob: %v", err)
		return
	}

	// Set default for the new mxjob.
	scheme.Scheme.Default(mxJob)

	msg := fmt.Sprintf("MXJob %s is created.", mxJob.Name)
	logger.Info(msg)

	// Add a created condition.
	if err := updateMXJobConditions(mxJob, mxv1.MXJobCreated, mxJobCreatedReason, msg); err != nil {
		logger.Errorf("Append mxJob condition error: %v", err)
		return
	}

	// Write the defaulted job, including its new condition, back into obj.
	if err := unstructuredFromMXJob(obj, mxJob); err != nil {
		logger.Errorf("Failed to convert the obj: %v", err)
		return
	}

	tc.enqueueMXJob(obj)
}

// enqueueMXJob derives the work-queue key for mxjob with KeyFunc and adds it
// to tc.WorkQueue. If KeyFunc fails, the object is logged and not enqueued,
// because the key returned alongside an error is not meaningful.
func (tc *MXController) enqueueMXJob(mxjob interface{}) {
	key, err := KeyFunc(mxjob)
	if err != nil {
		logger.Errorf("Couldn't get key for mxjob object %#v: %v", mxjob, err)
		return
	}

	// TODO: we may need add backoff here
	tc.WorkQueue.Add(key)
}

kubeflow/mxnet-operator/pkg/controller.v1/mxnet/controller.go

// runWorker is the body of a long-lived worker goroutine. It drains the work
// queue one item at a time and returns as soon as processNextWorkItem reports
// (by returning false) that no further items should be processed.
func (tc *MXController) runWorker() {
	for {
		if keepGoing := tc.processNextWorkItem(); !keepGoing {
			return
		}
	}
}

// processNextWorkItem reads a single key off the work queue and processes it
// by calling the syncHandler.
//
// It returns false only when the queue has been shut down, which tells
// runWorker to exit. Every other outcome returns true so the worker keeps
// going:
//   - a successful sync may Forget the key;
//   - a failed sync re-adds the key with rate limiting;
//   - a job that cannot be loaded or fails inspectMXjob is logged and dropped.
func (tc *MXController) processNextWorkItem() bool {
	key, quit := tc.WorkQueue.Get()
	if quit {
		// The queue is shutting down; stop this worker.
		return false
	}
	// Per client-go workqueue semantics, every key returned by Get must be
	// marked Done, whatever the outcome below.
	defer tc.WorkQueue.Done(key)

	keyStr, ok := key.(string)
	if !ok {
		// NOTE(review): keys are expected to be namespace/name strings from
		// KeyFunc; anything else cannot be synced, so drop it.
		logger.Errorf("Unexpected work queue key type %T: %v", key, key)
		tc.WorkQueue.Forget(key)
		return true
	}

	mxJob, err := tc.getMXJobFromKey(keyStr)
	if err != nil {
		// Without the job object there is nothing to sync.
		logger.Errorf("Failed to get MXJob from key %s: %v", keyStr, err)
		return true
	}

	// Verify that the replica specs match the job mode. An invalid spec is
	// not retried; it has to be fixed by updating the MXJob.
	if err := tc.inspectMXjob(mxJob); err != nil {
		logger.Errorf("Inspect MXJob %s: %v", keyStr, err)
		return true
	}

	// Sync MXJob to match the actual state to this desired state.
	forget, err := tc.syncHandler(keyStr)
	if err == nil {
		if forget {
			tc.WorkQueue.Forget(key)
		}
		return true
	}

	// Sync failed: log it and retry later with rate limiting.
	logger.Errorf("error syncing mxjob %s: %v", keyStr, err)
	tc.WorkQueue.AddRateLimited(key)
	return true
}

// inspectMXjob makes sure an MXJob has the MXReplicaSpecs members its jobMode
// requires, returning errWrongJobMode if it does not.
//
// For MXTrain jobs, the Scheduler, Server and Worker replica types must all be
// present and a Tuner replica must not be. Other job modes are not checked
// here.
func (tc *MXController) inspectMXjob(mxjob *mxv1.MXJob) error {
	if mxjob.Spec.JobMode != mxv1.MXTrain {
		return nil
	}

	specs := mxjob.Spec.MXReplicaSpecs
	for _, required := range []mxv1.MXReplicaType{
		mxv1.MXReplicaTypeScheduler,
		mxv1.MXReplicaTypeServer,
		mxv1.MXReplicaTypeWorker,
	} {
		if _, ok := specs[required]; !ok {
			return errWrongJobMode
		}
	}

	// A training job must not carry a Tuner replica.
	if _, ok := specs[mxv1.MXReplicaTypeTuner]; ok {
		return errWrongJobMode
	}
	return nil
}
// syncMXJob syncs the mxjob with the given key if it has had its expectations
// fulfilled, meaning it did not expect to see any more of its pods/services
// created or deleted.
//
// Return values:
//   - forget=true, err=nil: the key may be dropped from the rate limiter;
//   - err!=nil: the key should be retried.
//
// This function is not meant to be invoked concurrently with the same key.
func (tc *MXController) syncMXJob(key string) (bool, error) {
	namespace, name, err := cache.SplitMetaNamespaceKey(key)
	if err != nil {
		return false, err
	}
	if namespace == "" || name == "" {
		return false, fmt.Errorf("invalid mxjob key %q: either namespace or name is missing", key)
	}

	sharedMXJob, err := tc.getMXJobFromName(namespace, name)
	if err != nil {
		// NOTE(review): a "not found" error most likely means the job was
		// deleted and could return (true, nil) instead of being retried.
		return false, err
	}

	// Work on a deep copy so the defaults applied below do not mutate sharedMXJob.
	mxjob := sharedMXJob.DeepCopy()
	mxjobNeedsSync := tc.satisfiedExpectations(mxjob)

	// Set default for the new mxjob.
	scheme.Scheme.Default(mxjob)

	// Reconcile only when no pod/service creations or deletions are still
	// expected, and never for a job that is being deleted.
	if mxjobNeedsSync && mxjob.DeletionTimestamp == nil {
		if err := tc.reconcileMXJobs(mxjob); err != nil {
			return false, err
		}
	}

	return true, nil
}

先检查 Job 所需的 pod、svc 是否已满足预期(expectations,即不再有尚未观察到的创建/删除);如果满足预期且 Job 没有处于删除中(DeletionTimestamp 为空),就调用 reconcileMXJobs 调整 Job。

// reconcileMXJobs checks and updates replicas for each given MXReplicaSpec.
// It returns an error, so the caller requeues the mxjob, when listing,
// creating or deleting pods/services fails.
//
// A terminated job (succeeded, failed, or over its limits) is handled as
// follows: its pods and services are deleted, it is cleaned up, its PodGroup
// is deleted when gang scheduling is enabled, and then its status is written.
//
// Any other job has its PodGroup synced (gang scheduling only) and every
// replica type reconciled before its status is written.
func (tc *MXController) reconcileMXJobs(mxjob *mxv1.MXJob) error {
	logger.Infof("Reconcile MXJobs %s", mxjob.Name)

	pods, err := tc.GetPodsForJob(mxjob)
	if err != nil {
		return err
	}

	services, err := tc.GetServicesForJob(mxjob)
	if err != nil {
		return err
	}

	// NOTE(review): nothing in this excerpt sets this flag; it is the hook for
	// backoff-limit / active-deadline checks.
	mxJobExceedsLimit := false

	// If the MXJob is terminated, delete all pods and services.
	if isSucceeded(mxjob.Status) || isFailed(mxjob.Status) || mxJobExceedsLimit {
		if err := tc.deletePodsAndServices(mxjob, pods); err != nil {
			return err
		}
		if err := tc.cleanupMXJob(mxjob); err != nil {
			return err
		}
		if tc.Config.EnableGangScheduling {
			if err := tc.DeletePodGroup(mxjob); err != nil {
				return err
			}
		}
		return tc.updateStatusHandler(mxjob)
	}

	if tc.Config.EnableGangScheduling {
		minAvailableReplicas := getTotalReplicas(mxjob)
		// A PodGroup sync failure is only logged; pods are still reconciled.
		if _, err := tc.SyncPodGroup(mxjob, minAvailableReplicas); err != nil {
			logger.Warningf("Sync PodGroup %v: %v", mxjob.Name, err)
		}
	}

	// Save the current state of the replicas
	replicasStatus := make(map[string]v1.PodPhase)

	// Diff current active pods/services with replicas.
	for rtype, spec := range mxjob.Spec.MXReplicaSpecs {
		if err := tc.reconcilePods(mxjob, pods, rtype, spec, replicasStatus); err != nil {
			return err
		}
		if err := tc.reconcileServices(mxjob, services, rtype, spec); err != nil {
			return err
		}
	}

	// TODO(CPH): Add check here, no need to update the mxjob if the status hasn't changed since last time.
	return tc.updateStatusHandler(mxjob)
}

kubeflow/mxnet-operator/pkg/controller.v1/mxnet/pod.go

// reconcilePods checks and updates pods for each given MXReplicaSpec.
// It will requeue the mxjob in case of an error while creating/deleting pods.
func (tc *MXController) reconcilePods(
   mxjob *mxv1.MXJob,
   pods []*v1.Pod,
   rtype mxv1.MXReplicaType,
   spec *mxv1.MXReplicaSpec, rstatus map[string]v1.PodPhase) error {

   // Convert MXReplicaType to lower string.
   rt := strings.ToLower(string(rtype))
   // Get all pods for the type rt.
   pods, err := tc.FilterPodsForReplicaType(pods, rt)

   replicas := int(*spec.Replicas)

   podSlices := tc.GetPodSlices(pods, replicas, logger)
   for index, podSlice := range podSlices {
      if len(podSlice) > 1 {
         logger.Warningf("We have too many pods for %s %d", rt, index)
      } else if len(podSlice) == 0 {
         logger.Infof("Need to create new pod: %s-%d", rt, index)
          
         err = tc.createNewPod(mxjob, rt, strconv.Itoa(index), spec)
          
      } else {
         // Check the status of the current pod.
         pod := podSlice[0]
         // Get the exit code of the mxnet container.
         var exitCode int32 = 0xbeef // magic number
          
         for _, status := range pod.Status.ContainerStatuses {
            state := status.State
            if status.Name == mxv1.DefaultContainerName && state.Terminated != nil {
               exitCode = state.Terminated.ExitCode
               logger.Infof("Pod: %v.%v exited with code %v", pod.Namespace, pod.Name, exitCode)
         }

         // Check if the pod is retryable.
         if spec.RestartPolicy == mxv1.RestartPolicyExitCode {
            if pod.Status.Phase == v1.PodFailed && train_util.IsRetryableExitCode(exitCode) {
               logger.Infof("Need to restart the pod: %v.%v", pod.Namespace, pod.Name)
               
                err := tc.PodControl.DeletePod(pod.Namespace, pod.Name, mxjob)
            }
         }

         // Check whether scheduler is exited without error.
         if rtype == mxv1.MXReplicaTypeScheduler && exitCode == 0 {
            schedulerCompleted = true
         }
         updateMXJobReplicaStatuses(mxjob, rtype, pod)
      }
   }

   return tc.updateStatusSingle(mxjob, rtype, replicas, restart, schedulerCompleted)
}
// createNewPod creates a new pod for the given index and type.
//
// The pod is built from a deep copy of spec.Template. It is:
//   - named after the job, replica type and index;
//   - labelled with the job labels plus the replica type/index;
//   - given the MXNet cluster environment (setClusterSpec) and the replica's
//     restart policy;
//   - created with mxjob as its controller owner.
//
// One pod creation is registered in tc.Expectations right before the pod is
// submitted. Any error is returned so the caller can requeue the job.
func (tc *MXController) createNewPod(mxjob *mxv1.MXJob, rt, index string, spec *mxv1.MXReplicaSpec) error {
	mxjobKey, err := KeyFunc(mxjob)
	if err != nil {
		return err
	}

	// Create OwnerReference.
	controllerRef := tc.GenOwnerReference(mxjob)

	// Set type and index for the worker.
	labels := tc.GenLabels(mxjob.Name)
	labels[mxReplicaTypeLabel] = rt
	labels[mxReplicaIndexLabel] = index

	podTemplate := spec.Template.DeepCopy()

	// Set name for the template.
	podTemplate.Name = jobcontroller.GenGeneralName(mxjob.Name, rt, index)

	// A template without metadata.labels has a nil Labels map, and writing
	// into a nil map panics.
	if podTemplate.Labels == nil {
		podTemplate.Labels = make(map[string]string, len(labels))
	}
	for key, value := range labels {
		podTemplate.Labels[key] = value
	}

	if err := setClusterSpec(podTemplate, mxjob, rt, index); err != nil {
		return err
	}

	setRestartPolicy(podTemplate, spec)

	// Register the expected creation only once nothing but the API call can
	// fail, so an error while building the template leaves no stale
	// expectation behind.
	expectationPodsKey := jobcontroller.GenExpectationPodsKey(mxjobKey, rt)
	if err := tc.Expectations.ExpectCreations(expectationPodsKey, 1); err != nil {
		return err
	}

	if err := tc.PodControl.CreatePodsWithControllerRef(mxjob.Namespace, podTemplate, mxjob, controllerRef); err != nil {
		// The informer will never observe a pod that was not created, so
		// lower the expectation again instead of waiting for it to expire.
		// NOTE(review): later kubeflow controllers treat apierrors.IsTimeout(err)
		// as success because the pod may still be created; consider the same.
		tc.Expectations.CreationObserved(expectationPodsKey)
		return err
	}
	return nil
}

根据用户指定的MXReplicaSpec创建pod;除此以外,为了多机cluster通信,会增加cluster相关配置

// setClusterSpec injects the distributed-MXNet configuration into every
// container of podTemplateSpec:
//   - the MX_CONFIG JSON generated by genMXConfig;
//   - DMLC_PS_ROOT_PORT / DMLC_PS_ROOT_URI (scheduler 0's address);
//   - DMLC_NUM_SERVER / DMLC_NUM_WORKER (replica counts);
//   - DMLC_ROLE (this pod's task type);
//   - DMLC_USE_KUBERNETES=1.
//
// If the config cannot be generated or encoded, the error is returned and the
// template is left unmodified.
func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, mxjob *mxv1.MXJob, rt, index string) error {
	// Generate MX_CONFIG JSON.
	mxConfigData, err := genMXConfig(mxjob, rt, index)
	if err != nil {
		return err
	}

	// Generate MX_CONFIG JSON Str.
	mxConfigJSON, err := json.Marshal(mxConfigData)
	if err != nil {
		return err
	}

	// The variables are identical for every container, so build them once.
	scheduler := getConfigAddr(&mxConfigData, mxv1.MXReplicaTypeScheduler, 0)
	env := []v1.EnvVar{
		{Name: mxConfig, Value: string(mxConfigJSON)},
		{Name: "DMLC_PS_ROOT_PORT", Value: strconv.Itoa(scheduler.Port)},
		{Name: "DMLC_PS_ROOT_URI", Value: scheduler.Url},
		{Name: "DMLC_NUM_SERVER", Value: strconv.Itoa(getConfigReplica(&mxConfigData, mxv1.MXReplicaTypeServer))},
		{Name: "DMLC_NUM_WORKER", Value: strconv.Itoa(getConfigReplica(&mxConfigData, mxv1.MXReplicaTypeWorker))},
		{Name: "DMLC_ROLE", Value: mxConfigData.Task.Type},
		{Name: "DMLC_USE_KUBERNETES", Value: "1"},
	}

	// Index into the slice so the containers in the template are modified in place.
	for i := range podTemplateSpec.Spec.Containers {
		c := &podTemplateSpec.Spec.Containers[i]
		c.Env = append(c.Env, env...)
	}
	return nil
}

设置分布式mxnet所需的环境变量

上一篇 下一篇

猜你喜欢

热点阅读