人工不智能PyTorch/JIT

PyTorch Internal:算子注册

2023-09-29  本文已影响0人  A君来了

Overhead

PyTorch 执行 eager 操作时,例如,torch.add(a, b),调度器(c10::Dispatcher)会根据分派键(DispatchKey) 来查找并执行 add op 的 op kernel (理解PyTorch分发机制的内部工作原理)。因此,算子注册过程就是在调度器中定义 op,并将 kernel function 注册到 op 的指定分派键条目中。

Torch Library

torch::Library 是算子注册用的 helper,通过它注册的算子有着相同的命名空间、dispatch key等。

TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
  m.def("mysub(Tensor self, Tensor other) -> Tensor", mysub_func);
  m.impl("myadd", myadd_func);
}

m 就是命名空间为 myops 的 library,它通过 m.def 定义了 myadd 和 mysub 这两个 op 的静态信息 schema。mysub 在定义的同时也将 mysub_func 函数注册到 op,而 myadd 的 op kernel 则是通过 m.impl 单独注册的。由于 TORCH_LIBRARY 宏没有指定 dispatch key,因此,这两个 op kernel 都是 CatchAll 函数。

如果要将 kernel function 注册到指定的 dispatch key,需要用到 TORCH_LIBRARY_IMPL 宏:

TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
  m.impl("mysub", mysub_cuda);
}

所有通过 m 注册的 kernel function 都会注册到 op 的 CUDA key 条目中,它执行的优先级会比 CatchAll 更高。

OperatorDef

OperatorDef 用于描述调度器中 op 的静态信息,它会提供 registerSchema()registerKernel() 方法给 m.def() 和 m.impl() 分别用于注册 op 和 kernel。

Kernel list

通过 m.impl() 注册的 kernel function 会插入到指定 dispatch key 的 kernel list(kernels_)的头部,而调度器则会从列表中的首元素中获取 kernel。也就是说,PyTorch 允许为 op 的同一个 dispatch key 注册多个 kernel,而新 kernel 会覆盖旧 kernel。

class TORCH_API OperatorEntry final {
  ...
  ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
};

const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
  auto kern_it = kernels_.find(dispatch_key);
  if (kern_it != kernels_.end()) {
    TORCH_INTERNAL_ASSERT(!kern_it->second.empty());
    TORCH_INTERNAL_ASSERT(kern_it->second.front().kernel.isValid());
    return &kern_it->second.front();
  }
  return nullptr;
}

End

上一篇下一篇

猜你喜欢

热点阅读