Adding Quantize/Dequantize support to the ArmNN ONNX parser


ArmNN mainly targets TFLite models and its ONNX support is limited, but it does provide a basic framework for parsing ONNX models and handles a small set of ONNX operators. To support more ONNX operators you have to add them yourself; this post adds QuantizeLayer and DequantizeLayer (for the ONNX QuantizeLinear and DequantizeLinear ops).

Add the new data type to the armnn::TensorInfo ToTensorInfo(const std::string &name, std::vector<unsigned int> &shape, int data_type) function:

            case onnx::TensorProto::INT8: 
            {
                type = DataType::QAsymmS8;
                break;
            }
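
For context, this case goes into the existing switch over onnx::TensorProto data types inside ToTensorInfo. In the upstream parser that switch looks roughly like the sketch below (simplified; details vary between ArmNN versions):

    // Simplified sketch of the data-type switch in ToTensorInfo (OnnxParser.cpp).
    // The new INT8 case sits next to the types the parser already handles.
    DataType type;
    switch (data_type)
    {
        case onnx::TensorProto::FLOAT:
        {
            type = DataType::Float32;
            break;
        }
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64:
        {
            type = DataType::Signed32;
            break;
        }
        case onnx::TensorProto::INT8:   // new: map ONNX INT8 to QAsymmS8
        {
            type = DataType::QAsymmS8;
            break;
        }
        default:
        {
            // upstream builds a more detailed error message here
            throw ParseException("unsupported data type");
        }
    }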

Declare the parsing functions in the header (in the OnnxParserImpl class in OnnxParser.hpp, next to the existing Parse* declarations):

void ParseQuantize(const onnx::NodeProto &nodeProto);
void ParseDequantize(const onnx::NodeProto &nodeProto);

Add the implementations in OnnxParser.cpp:

void OnnxParserImpl::ParseQuantize(const onnx::NodeProto &node)
    {
        // QuantizeLinear: input(0) = x, input(1) = y_scale, input(2) = y_zero_point
        CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 3);
        CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);

        IConnectableLayer *const layer = m_Network->AddQuantizeLayer(node.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        // Read the scale from the y_scale initializer (stored in float_data here).
        onnx::TensorProto scaleTensor = *m_TensorsInfo[node.input(1)].m_tensor;
        float scale = scaleTensor.float_data().data()[0];

        // Read the zero point from the y_zero_point initializer (stored in raw_data here).
        onnx::TensorProto zeroPointTensor = *m_TensorsInfo[node.input(2)].m_tensor;
        auto srcData = reinterpret_cast<const int32_t *>(zeroPointTensor.raw_data().c_str());
        int32_t zeroPoint = srcData[0];

        // The output keeps the shape of the input; attach the quantization parameters to it.
        auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
        outputInfo[0].SetQuantizationScale(scale);
        outputInfo[0].SetQuantizationOffset(zeroPoint);

        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // register the input connection slots for the layer, connections are made after all layers have been created
        // only the tensors for the inputs are relevant, exclude the const tensors
        RegisterInputSlots(layer, {node.input(0)});

        // register the output connection slots for the layer, connections are made after all layers have been created
        RegisterOutputSlots(layer, {node.output(0)});
    }
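
Two assumptions are baked into this implementation: the ONNX spec makes y_zero_point optional, so the strict check for three inputs will reject QuantizeLinear nodes that omit it; and the scale is read from float_data while the zero point is read from raw_data, which held for the model used here but is not guaranteed in general (ONNX initializers may store their values in either field).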

    void OnnxParserImpl::ParseDequantize(const onnx::NodeProto &node)
    {
        // DequantizeLinear: input(0) = x, input(1) = x_scale, input(2) = x_zero_point
        CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 3);
        CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);

        IConnectableLayer *const layer = m_Network->AddDequantizeLayer(node.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        // The dequantized output is float and keeps the shape of the input.
        auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});

        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // register the input connection slots for the layer, connections are made after all layers have been created
        // only the tensors for the inputs are relevant, exclude the const tensors
        RegisterInputSlots(layer, {node.input(0)});

        // register the output connection slots for the layer, connections are made after all layers have been created
        RegisterOutputSlots(layer, {node.output(0)});
    }
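
ArmNN's Dequantize layer takes its scale and offset from the quantization parameters on its input TensorInfo, so this implementation does not read the x_scale and x_zero_point initializers itself; it relies on the incoming tensor (for example the output of the QuantizeLinear node handled above) already carrying those parameters.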

Finally, add the mappings in std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions so the ONNX op names dispatch to the new functions:

        {"QuantizeLinear", &OnnxParserImpl::ParseQuantize},
        {"DequantizeLinear", &OnnxParserImpl::ParseDequantize},