Adding quant and dequant operators before and after an operator

2023-04-09  i_1312
    void OnnxParserImpl::ParseActivation(const onnx::NodeProto &node, const armnn::ActivationFunction func)
    {
        CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
        CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

        VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

        ActivationDescriptor desc;
        desc.m_Function = func;

        if (func == ActivationFunction::BoundedReLu)
        {
            if (node.input_size() == 1 && node.attribute_size() > 0)
            {
                desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
                desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
            }
            else
            {
                desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
                desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
            }
        }

        ///   Add the dequant operator
        IConnectableLayer *const layer_dequant = m_Network->AddDequantizeLayer((node.name() + "dequant").c_str());
        ARMNN_ASSERT(layer_dequant != nullptr);
        auto outputInfo_dequant = ComputeOutputInfo({node.output(0)}, layer_dequant, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
        layer_dequant->GetOutputSlot(0).SetTensorInfo(outputInfo_dequant[0]);
        // register the input connection slots for the layer, connections are made after all layers have been created
        // only the tensors for the inputs are relevant, exclude the const tensors
        RegisterInputSlots(layer_dequant, {node.input(0)});
        // register the output connection slots for the layer, connections are made after all layers have been created
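        // the dequant result is published under the renamed tensor node.input(0) + "dequant"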
        RegisterOutputSlots(layer_dequant, {node.input(0) + "dequant"});


        IConnectableLayer *const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
        // activation output
        // if (m_jsonRoot.isMember(node.output(0)))
        // {
        //     Json::Value prop = m_jsonRoot[node.output(0)];
        //     std::string type = prop["type"].asString();
        //     float scale = prop["scale"].asFloat();
        //     int zeropoint = prop["zeropoint"].asInt();
        //     outputInfo[0].SetDataType(mapDataType(type));
        //     outputInfo[0].SetQuantizationScale(scale);
        //     outputInfo[0].SetQuantizationOffset(zeropoint);
        // }

        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // register the input connection slots for the layer, connections are made after all layers have been created
        // only the tensors for the inputs are relevant, exclude the const tensors
        RegisterInputSlots(layer, {node.input(0) + "dequant"}); // register the activation's input: it now reads the dequant layer's renamed output

        // register the output connection slots for the layer, connections are made after all layers have been created
        RegisterOutputSlots(layer, {node.output(0) + "quant"}); // register the activation's output under a renamed tensor for the quant layer to consume

        IConnectableLayer *const quant_layer = m_Network->AddQuantizeLayer((node.name() + "quant").c_str());
        ARMNN_ASSERT(quant_layer != nullptr);
        auto outputInfo_quant = ComputeOutputInfo({node.output(0)}, quant_layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
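        // NOTE: fixed example quantization parameters; a real flow would read
        // them from the per-tensor JSON, as in the commented-out lookup above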
        outputInfo_quant[0].SetDataType(mapDataType("int8"));
        outputInfo_quant[0].SetQuantizationScale(0.0002f);
        outputInfo_quant[0].SetQuantizationOffset(23);
        quant_layer->GetOutputSlot(0).SetTensorInfo(outputInfo_quant[0]);
        RegisterInputSlots(quant_layer, {node.output(0) + "quant"});
        RegisterOutputSlots(quant_layer, {node.output(0)});
    }


The main idea:

Check whether the input and output data types of an op match.
If the input is float but the output is quantized, insert a quant operator in front of the op.
If the input is quantized but the output is float, insert a dequant operator in front of the op.

In the code above the splice is done purely through tensor names: the dequant layer consumes node.input(0) and republishes it as node.input(0) + "dequant"; the activation layer reads that renamed tensor and writes node.output(0) + "quant"; the quant layer then consumes the renamed tensor and restores the original name node.output(0), so downstream nodes connect without any changes. The type check itself can look like the sketch below.
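
A minimal sketch of that check; Conversion and NeededConversion are hypothetical names, not part of the ArmNN API, while armnn::IsQuantizedType is a real helper from armnn/TypesUtils.hpp:

    #include <armnn/Tensor.hpp>
    #include <armnn/TypesUtils.hpp>

    // Which conversion layer, if any, must be spliced in front of an op whose
    // input tensor is `in` and whose expected output tensor is `out`.
    enum class Conversion { None, InsertQuantize, InsertDequantize };

    static Conversion NeededConversion(const armnn::TensorInfo &in, const armnn::TensorInfo &out)
    {
        const bool inQuant  = armnn::IsQuantizedType(in.GetDataType());
        const bool outQuant = armnn::IsQuantizedType(out.GetDataType());

        if (!inQuant && outQuant)
        {
            return Conversion::InsertQuantize;   // float in, quantized out
        }
        if (inQuant && !outQuant)
        {
            return Conversion::InsertDequantize; // quantized in, float out
        }
        return Conversion::None;                 // types already agree
    }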

The parser also needs to check internally that an op's quantization parameters are correct (a sketch follows the list):
1 All required quantization parameters must be present; e.g. for conv the input, output and weight quantization types must match, and the bias must carry quantization parameters too.
2 Incorrect quantization parameters must be reported to the user.
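
A minimal sketch of such a check for conv; ValidateConvQuantParams is a hypothetical helper, not an ArmNN API, and the exact rules would follow your quantization scheme:

    #include <armnn/Exceptions.hpp>
    #include <armnn/Tensor.hpp>
    #include <sstream>
    #include <string>

    // Hypothetical checker: verifies the quantization parameters of a conv
    // before the layer is created and reports every problem in one message.
    static void ValidateConvQuantParams(const armnn::TensorInfo &input,
                                        const armnn::TensorInfo &output,
                                        const armnn::TensorInfo &weights,
                                        const armnn::TensorInfo &bias,
                                        const std::string &nodeName)
    {
        std::ostringstream err;

        // 1: input, output and weights must share the same quantized data type
        if (input.GetDataType() != output.GetDataType() ||
            input.GetDataType() != weights.GetDataType())
        {
            err << "input/output/weight data types differ; ";
        }

        // the bias must carry quantization parameters as well
        if (!bias.IsQuantized())
        {
            err << "bias has no quantization parameters; ";
        }

        // a quantized tensor with a non-positive scale is unusable
        for (const armnn::TensorInfo *info : {&input, &output, &weights, &bias})
        {
            if (info->IsQuantized() && info->GetQuantizationScale() <= 0.0f)
            {
                err << "a tensor has a non-positive quantization scale; ";
            }
        }

        // 2: surface incorrect parameters to the user instead of failing later
        if (!err.str().empty())
        {
            throw armnn::ParseException("conv '" + nodeName + "': " + err.str());
        }
    }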
