Operators in MXNet

2016-06-13  cptn3m0

An operator in MXNet is a class that contains both the actual computation logic and auxiliary information that helps the system perform optimizations such as in-place updates and automatic differentiation. Before continuing with this document, we strongly recommend that you first understand the mshadow library, since all operators compute on the tensor-like structure mshadow::TBlob that the system provides at runtime. MXNet's operator interface tries to give users as much flexibility as possible.

Operator Interface

The core interface of an operator is Forward:

virtual void Forward(const OpContext &ctx,
                     const std::vector<TBlob> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<TBlob> &out_data,
                     const std::vector<TBlob> &aux_states) = 0;

Apart from Forward, users can optionally implement the Backward interface, defined as follows:

virtual void Backward(const OpContext &ctx,
                      const std::vector<TBlob> &out_grad,
                      const std::vector<TBlob> &in_data,
                      const std::vector<TBlob> &out_data,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &in_grad,
                      const std::vector<TBlob> &aux_states);

This interface follows the same design principle as the Forward interface, except that out_grad, in_data and out_data are given and the operator should compute in_grad as its result. The naming strategy is similar to torch's convention and is summarized in the following figure:

[input/output semantics figure]

Some operators may not need all of out_grad, in_data and out_data. This can be specified through the DeclareBackwardDependency interface in OperatorProperty.
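To make the naming concrete, here is a minimal sketch of an element-wise operator computing y = x * x. It is not MXNet code: Blob is a hypothetical owning stand-in for mshadow::TBlob, SquareOp is an invented name, and ctx, req and aux_states are left out for brevity. Because the stand-in owns its memory, outputs are passed by pointer here rather than as const vectors of views.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mshadow::TBlob: a flat, owning float buffer.
using Blob = std::vector<float>;

// Invented example operator: y = x * x, element-wise.
struct SquareOp {
  // Forward: in_data is given, out_data is the result.
  void Forward(const std::vector<Blob> &in_data,
               std::vector<Blob> *out_data) const {
    assert(in_data.size() == 1 && out_data->size() == 1);
    const Blob &x = in_data[0];
    Blob &y = (*out_data)[0];
    y.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] * x[i];
  }

  // Backward: out_grad and in_data are given, in_grad is the result.
  // dL/dx = dL/dy * 2x, so out_data is not needed for this operator.
  void Backward(const std::vector<Blob> &out_grad,
                const std::vector<Blob> &in_data,
                std::vector<Blob> *in_grad) const {
    assert(out_grad.size() == 1 && in_data.size() == 1 &&
           in_grad->size() == 1);
    const Blob &dy = out_grad[0];
    const Blob &x = in_data[0];
    assert(dy.size() == x.size());
    Blob &dx = (*in_grad)[0];
    dx.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) dx[i] = dy[i] * 2.0f * x[i];
  }
};
```

In this sketch Backward reads only out_grad and in_data and never touches out_data, which is the kind of operator that does not need every backward input.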

Operator Property

A single operation such as convolution may have several implementations, and users may want to switch among them to get the best performance. Therefore, we separate the operator's semantic interface, which goes into the OperatorProperty class, from its implementation interface, the Operator class.

Create Operator from Operator Property

As mentioned above, OperatorProperty holds all the semantic attributes of an operation. It is also in charge of creating the Operator pointer used for the actual computation.

Create Operator

Implement the following interface in OperatorProperty:

virtual Operator* CreateOperator(Context ctx) const = 0;

For example:

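// ConvolutionOp must derive from Operator so that it can be returned as Operator*.
class ConvolutionOp : public Operator {
 public:
  void Forward( ... ) { ... }
  void Backward( ... ) { ... }
};
class ConvolutionOpProperty : public OperatorProperty {
 public:
  // The caller takes ownership of the returned pointer.
  Operator* CreateOperator(Context ctx) const {
    return new ConvolutionOp;
  }
};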
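The motivation for the split, switching among several implementations of one operation, can be sketched as a property class that picks an implementation based on the context. This is only an illustration under stated assumptions: Device, the dev field of Context, Name(), CpuConvolutionOp, GpuConvolutionOp and MakeOperator are all hypothetical stand-ins, not MXNet's actual types or dispatch mechanism.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins; the real Context and Operator live in MXNet headers.
enum class Device { kCPU, kGPU };
struct Context {
  Device dev;
};

class Operator {
 public:
  virtual ~Operator() {}
  // Illustration only: lets a caller see which implementation was chosen.
  virtual std::string Name() const = 0;
};

class OperatorProperty {
 public:
  virtual ~OperatorProperty() {}
  virtual Operator *CreateOperator(Context ctx) const = 0;
};

class CpuConvolutionOp : public Operator {
 public:
  std::string Name() const override { return "conv-cpu"; }
};

class GpuConvolutionOp : public Operator {
 public:
  std::string Name() const override { return "conv-gpu"; }
};

// One semantic description, two interchangeable implementations.
class ConvolutionOpProperty : public OperatorProperty {
 public:
  Operator *CreateOperator(Context ctx) const override {
    if (ctx.dev == Device::kGPU) return new GpuConvolutionOp;
    return new CpuConvolutionOp;
  }
};

// Wrap the raw pointer at once so the caller cannot leak it.
inline std::unique_ptr<Operator> MakeOperator(const OperatorProperty &prop,
                                              Context ctx) {
  Operator *op = prop.CreateOperator(ctx);
  assert(op != nullptr);
  return std::unique_ptr<Operator>(op);
}
```

Calling MakeOperator(prop, Context{Device::kGPU}) on a ConvolutionOpProperty would yield the GPU variant, while the code that uses the property never names a concrete operator class.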

Parameterize Operator

When implementing a convolution operator, we need to know the kernel size, the stride, the padding size and so on. These parameters must be passed to the operator before Forward or Backward is called. To do so, users can define a ConvolutionParam structure:

#include <dmlc/parameter.h>
struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
  TShape kernel, stride, pad;
  uint32_t num_filter, num_group, workspace;
  bool no_bias;
};

Put it in ConvolutionOpProperty and pass it to the operator class during construction:

class ConvolutionOp : public Operator {
 public:
  // Store a copy of the parameters for use in Forward and Backward.
  ConvolutionOp(ConvolutionParam p): param_(p) {}
  void Forward( ... ) { ... }
  void Backward( ... ) { ... }
 private:
  ConvolutionParam param_;
};
class ConvolutionOpProperty : public OperatorProperty {
 public:
  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
    // initialize param_ using kwargs
  }
  Operator* CreateOperator(Context ctx) const {
    return new ConvolutionOp(param_);
  }
 private:
  ConvolutionParam param_;
};
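The body of Init is left as a comment above. As a rough illustration of what "initialize param_ using kwargs" involves, the following self-contained sketch parses string key/value pairs by hand and hands the result to the operator at construction. It makes several assumptions: it does not use dmlc at all, the parameter struct is reduced to three scalar fields, CreateOperator takes no Context, and KwArgs, param() and initialized_ are hypothetical helpers added for the example.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using KwArgs = std::vector<std::pair<std::string, std::string> >;

// Simplified, hypothetical parameter struct: scalar fields only, no dmlc.
struct ConvolutionParam {
  std::uint32_t num_filter = 0;
  std::uint32_t num_group = 1;
  bool no_bias = false;
};

class ConvolutionOp {
 public:
  explicit ConvolutionOp(const ConvolutionParam &p) : param_(p) {}
  const ConvolutionParam &param() const { return param_; }

 private:
  ConvolutionParam param_;
};

class ConvolutionOpProperty {
 public:
  // Hand-written parsing of string key/value pairs into typed fields.
  void Init(const KwArgs &kwargs) {
    for (const auto &kv : kwargs) {
      if (kv.first == "num_filter") {
        param_.num_filter = static_cast<std::uint32_t>(std::stoul(kv.second));
      } else if (kv.first == "num_group") {
        param_.num_group = static_cast<std::uint32_t>(std::stoul(kv.second));
      } else if (kv.first == "no_bias") {
        param_.no_bias = (kv.second == "True" || kv.second == "true" ||
                          kv.second == "1");
      } else {
        throw std::invalid_argument("unknown parameter: " + kv.first);
      }
    }
    initialized_ = true;
  }

  // The caller owns the returned pointer.
  ConvolutionOp *CreateOperator() const {
    assert(initialized_ && "Init must run before CreateOperator");
    return new ConvolutionOp(param_);
  }

 private:
  ConvolutionParam param_;
  bool initialized_ = false;
};
```

The point of the design is that the operator receives an already validated, typed parameter struct, so Forward and Backward never deal with strings.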

Register Operator to MXNet

Use the following macros to register the parameter structure and the operator property class with MXNet:

DMLC_REGISTER_PARAMETER(ConvolutionParam);
MXNET_REGISTER_OP_PROPERTY(Convolution, ConvolutionOpProperty);

The first argument to MXNET_REGISTER_OP_PROPERTY is the operator's name string; the second is the name of the property class.
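How the registration macros work internally is not covered here. As a rough mental model only, a macro of this kind can be built from a global map of names to factory functions that is filled in by static objects before main runs. Everything in the sketch below is hypothetical (OpPropRegistry, Registrar, SKETCH_REGISTER_OP_PROPERTY and the reduced OperatorProperty); it should not be read as MXNet's actual implementation.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Reduced, hypothetical stand-in for the property base class.
class OperatorProperty {
 public:
  virtual ~OperatorProperty() {}
  virtual std::string TypeString() const = 0;
};

// Hypothetical global registry mapping a name to a factory function.
class OpPropRegistry {
 public:
  using Factory = std::function<OperatorProperty *()>;

  static OpPropRegistry &Get() {
    static OpPropRegistry instance;  // created on first use
    return instance;
  }

  // Returns false if the name was already registered.
  bool Register(const std::string &name, Factory f) {
    return factories_.emplace(name, std::move(f)).second;
  }

  // Returns nullptr for an unknown name.
  std::unique_ptr<OperatorProperty> Create(const std::string &name) const {
    auto it = factories_.find(name);
    if (it == factories_.end()) return nullptr;
    return std::unique_ptr<OperatorProperty>(it->second());
  }

 private:
  std::map<std::string, Factory> factories_;
};

// A static object of this type performs the registration in its constructor.
struct Registrar {
  Registrar(const std::string &name, OpPropRegistry::Factory f) {
    bool inserted = OpPropRegistry::Get().Register(name, std::move(f));
    assert(inserted && "duplicate operator name");
    (void)inserted;  // avoid an unused warning when asserts are disabled
  }
};

// First argument: the name (stringified); second: the property class.
#define SKETCH_REGISTER_OP_PROPERTY(name, PropType) \
  static Registrar sketch_registrar_##name(         \
      #name, []() -> OperatorProperty * { return new PropType(); })

class ConvolutionOpProperty : public OperatorProperty {
 public:
  std::string TypeString() const override { return "Convolution"; }
};

SKETCH_REGISTER_OP_PROPERTY(Convolution, ConvolutionOpProperty);
```

With such a registry, a front end could create a property from nothing but its name, for instance OpPropRegistry::Get().Create("Convolution"), which is one plausible reason the name is given as the first macro argument.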

All in a list

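Finally! We have covered most of the interface needed to define a new operator. To recap:

- Use the Operator interface to write the actual computation logic (Forward and Backward).
- Use the OperatorProperty interface to pass parameters to the operator class (Init), to create the operator (CreateOperator), and to declare which tensors Backward needs (DeclareBackwardDependency).
- Register the parameter structure with DMLC_REGISTER_PARAMETER and the property class with MXNET_REGISTER_OP_PROPERTY.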

Enjoy your MXNet trip.
