Merge pull request !4451 from yankai10/mergetags/v0.7.0-beta
| @@ -27,7 +27,7 @@ | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "src/ir/primitive_value.h" | #include "src/ir/primitive_value.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #if 0 | #if 0 | ||||
| @@ -159,7 +159,7 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| int AnfImporter::Import() { | |||||
| int AnfImporter::Import(const schema::QuantType &quantType) { | |||||
| ConverterConstTensor(); | ConverterConstTensor(); | ||||
| auto ret = ConverterCNode(); | auto ret = ConverterCNode(); | ||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "schema/inner/model_generated.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporter { | class AnfImporter { | ||||
| @@ -29,7 +30,7 @@ class AnfImporter { | |||||
| virtual ~AnfImporter() = default; | virtual ~AnfImporter() = default; | ||||
| virtual int Import(); | |||||
| virtual int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE); | |||||
| virtual FuncGraphPtr GetResult() = 0; | virtual FuncGraphPtr GetResult() = 0; | ||||
| @@ -1,5 +1,6 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * This is the C++ adaptation and derivative work of Myia | |||||
| * (https://github.com/mila-iqia/myia/). | |||||
| * | * | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | * Copyright 2019 Huawei Technologies Co., Ltd | ||||
| * | * | ||||
| @@ -17,101 +18,218 @@ | |||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" | #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" | ||||
| #include <mindspore/lite/src/ir/tensor.h> | |||||
| #include <memory> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| int group = GetValue<int>(prim->GetAttr("group")); | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (group > 1) { | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| void AnfConvPopulater::PopulaterConv2DMultiGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) { | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | } else { | ||||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||||
| attr->group = group; | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| attr->channelOut = GetValue<int>(prim->GetAttr("out_channel")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| void AnfConvPopulater::PopulaterConv2DSingleGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) { | |||||
| auto attr = std::make_unique<schema::Conv2DT>(); | |||||
| attr->group = group; | |||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| attr->channelOut = GetValue<int>(prim->GetAttr("out_channel")); | |||||
| auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, | |||||
| float *mMin, float *mMax) { | |||||
| constexpr float qmin = 0; | |||||
| constexpr float qmax = 255; | |||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | |||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | |||||
| } | |||||
| void AnfConvPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | |||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||||
| auto num_bits = prim->GetAttr("num_bits"); | |||||
| int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits); | |||||
| std::vector<schema::QuantParamT> quants; | |||||
| schema::QuantParamT quantParam; | |||||
| auto mean = prim->GetAttr("mean"); | |||||
| auto std_dev = prim->GetAttr("std_dev"); | |||||
| if (mean != nullptr && std_dev != nullptr) { | |||||
| auto meanQuantOaram = GetValue<double>(mean); | |||||
| double stddevQuantOaram = GetValue<double>(std_dev); | |||||
| float mMin = 0.0; | |||||
| float mMax = 0.0; | |||||
| CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); | |||||
| quantParam.min = mMin; | |||||
| quantParam.max = mMax; | |||||
| } else { | |||||
| auto inputMin = prim->GetAttr("input_minq"); | |||||
| auto inputMax = prim->GetAttr("input_maxq"); | |||||
| auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(inputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| } | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| quants.clear(); | |||||
| int biasQuantSize = 0; | |||||
| auto filterMin = prim->GetAttr("filter_minq"); | |||||
| auto filterMax = prim->GetAttr("filter_maxq"); | |||||
| if (filterMin != nullptr && filterMax != nullptr) { | |||||
| auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(filterMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(filterMaxPtr->Data()); | |||||
| biasQuantSize = filterMinPtr->DataSize(); | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | |||||
| quantParam.min = *(minBuf++); | |||||
| quantParam.max = *(maxBuf++); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_Conv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | } | ||||
| quants.clear(); | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | |||||
| quantParam.min = 0.0; | |||||
| quantParam.max = 0.0; | |||||
| quantParam.zeroPoint = 0; | |||||
| quantParam.scale = | |||||
| vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quants.emplace_back(quantParam); | |||||
| } | |||||
| vecQuantParam->emplace_back(quants); | |||||
| quants.clear(); | |||||
| auto outputMin = prim->GetAttr("output_minq"); | |||||
| auto outputMax = prim->GetAttr("output_maxq"); | |||||
| if (outputMin != nullptr && outputMax != nullptr) { | |||||
| auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(outputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | |||||
| } | |||||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| int group = GetValue<int>(prim->GetAttr("group")); | |||||
| if (group > 1) { | |||||
| PopulaterConv2DMultiGroup(prim, primitive, group); | |||||
| } else { | |||||
| PopulaterConv2DSingleGroup(prim, primitive, group); | |||||
| } | |||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | primitiveTValuePtr->SetPrimitiveT(primitive.release()); | ||||
| if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | |||||
| PopulaterQuantParam(prim, &vecQuantParam); | |||||
| primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | |||||
| } | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); | AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); | ||||
| @@ -1,5 +1,6 @@ | |||||
| /** | /** | ||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * This is the C++ adaptation and derivative work of Myia | |||||
| * (https://github.com/mila-iqia/myia/). | |||||
| * | * | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | * Copyright 2019 Huawei Technologies Co., Ltd | ||||
| * | * | ||||
| @@ -18,8 +19,9 @@ | |||||
| #ifndef MINDSPORE_ANF_CONV_PARSER_H | #ifndef MINDSPORE_ANF_CONV_PARSER_H | ||||
| #define MINDSPORE_ANF_CONV_PARSER_H | #define MINDSPORE_ANF_CONV_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfConvPopulater : public AnfNodePopulater { | class AnfConvPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| @@ -27,6 +29,18 @@ class AnfConvPopulater : public AnfNodePopulater { | |||||
| ~AnfConvPopulater() override = default; | ~AnfConvPopulater() override = default; | ||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | |||||
| void PopulaterConv2DMultiGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); | |||||
| void PopulaterConv2DSingleGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -14,15 +14,113 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | ||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, | |||||
| const double &stdDev, float *mMin, | |||||
| float *mMax) { | |||||
| constexpr float qmin = 0; | |||||
| constexpr float qmax = 255; | |||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | |||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | |||||
| } | |||||
| void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | |||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||||
| auto num_bits = prim->GetAttr("num_bits"); | |||||
| int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits); | |||||
| std::vector<schema::QuantParamT> quants; | |||||
| schema::QuantParamT quantParam; | |||||
| auto mean = prim->GetAttr("mean"); | |||||
| auto std_dev = prim->GetAttr("std_dev"); | |||||
| if (mean != nullptr && std_dev != nullptr) { | |||||
| auto meanQuantOaram = GetValue<double>(mean); | |||||
| double stddevQuantOaram = GetValue<double>(std_dev); | |||||
| float mMin = 0.0; | |||||
| float mMax = 0.0; | |||||
| CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); | |||||
| quantParam.min = mMin; | |||||
| quantParam.max = mMax; | |||||
| } else { | |||||
| auto inputMin = prim->GetAttr("input_minq"); | |||||
| auto inputMax = prim->GetAttr("input_maxq"); | |||||
| auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(inputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| } | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| quants.clear(); | |||||
| int biasQuantSize = 0; | |||||
| auto filterMin = prim->GetAttr("filter_minq"); | |||||
| auto filterMax = prim->GetAttr("filter_maxq"); | |||||
| if (filterMin != nullptr && filterMax != nullptr) { | |||||
| auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(filterMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(filterMaxPtr->Data()); | |||||
| biasQuantSize = filterMinPtr->DataSize(); | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | |||||
| quantParam.min = *(minBuf++); | |||||
| quantParam.max = *(maxBuf++); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| } | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | |||||
| quants.clear(); | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | |||||
| quantParam.min = 0.0; | |||||
| quantParam.max = 0.0; | |||||
| quantParam.zeroPoint = 0; | |||||
| quantParam.scale = | |||||
| vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quants.emplace_back(quantParam); | |||||
| } | |||||
| vecQuantParam->emplace_back(quants); | |||||
| quants.clear(); | |||||
| auto outputMin = prim->GetAttr("output_minq"); | |||||
| auto outputMax = prim->GetAttr("output_maxq"); | |||||
| if (outputMin != nullptr && outputMax != nullptr) { | |||||
| auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(outputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | |||||
| } | |||||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | ||||
| @@ -36,9 +134,9 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads")); | auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads")); | ||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | attr->padRight = pad_list[3]; | ||||
| auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation")); | ||||
| @@ -73,10 +171,13 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu | |||||
| auto abstractBase = paramNode->abstract(); | auto abstractBase = paramNode->abstract(); | ||||
| MS_ASSERT(abstractBase != nullptr); | MS_ASSERT(abstractBase != nullptr); | ||||
| if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | ||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| auto abstractTensor = | |||||
| utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| MS_ASSERT(abstractTensor != nullptr); | MS_ASSERT(abstractTensor != nullptr); | ||||
| if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | ||||
| auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| auto dims = | |||||
| utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) | |||||
| ->shape(); | |||||
| attr->channelIn = dims[kAnfPopulaterOne]; | attr->channelIn = dims[kAnfPopulaterOne]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -86,8 +187,16 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | primitiveTValuePtr->SetPrimitiveT(primitive.release()); | ||||
| if (primitiveTValuePtr->GetQuantType()) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | |||||
| PopulaterQuantParam(prim, &vecQuantParam); | |||||
| primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | |||||
| } | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater( | |||||
| "DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater( | |||||
| "DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,8 +15,9 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| @@ -24,6 +25,11 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | |||||
| ~AnfDepwiseconv2DPopulater() override = default; | ~AnfDepwiseconv2DPopulater() override = default; | ||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -14,14 +14,98 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" | #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" | ||||
| #include <vector> | |||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include <vector> | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, | |||||
| float *mMin, float *mMax) { | |||||
| constexpr float qmin = 0; | |||||
| constexpr float qmax = 255; | |||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | |||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | |||||
| } | |||||
| void AnfMatmulPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | |||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | |||||
| auto num_bits = prim->GetAttr("num_bits"); | |||||
| int32_t numbitsRangeQuantParam = GetValue<int32_t>(num_bits); | |||||
| std::vector<schema::QuantParamT> quants; | |||||
| schema::QuantParamT quantParam; | |||||
| auto mean = prim->GetAttr("mean"); | |||||
| auto std_dev = prim->GetAttr("std_dev"); | |||||
| if (mean != nullptr && std_dev != nullptr) { | |||||
| auto meanQuantOaram = GetValue<double>(mean); | |||||
| double stddevQuantOaram = GetValue<double>(std_dev); | |||||
| float mMin = 0.0; | |||||
| float mMax = 0.0; | |||||
| CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); | |||||
| quantParam.min = mMin; | |||||
| quantParam.max = mMax; | |||||
| } else { | |||||
| auto inputMin = prim->GetAttr("input_minq"); | |||||
| auto inputMax = prim->GetAttr("input_maxq"); | |||||
| auto inputMinPtr = inputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto inputMaxPtr = inputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(inputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(inputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| } | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| quants.clear(); | |||||
| auto filterMin = prim->GetAttr("filter_minq"); | |||||
| auto filterMax = prim->GetAttr("filter_maxq"); | |||||
| if (filterMin != nullptr && filterMax != nullptr) { | |||||
| auto filterMinPtr = filterMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto filterMaxPtr = filterMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(filterMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(filterMaxPtr->Data()); | |||||
| for (int i = 0; i < filterMinPtr->DataSize(); ++i) { | |||||
| quantParam.min = *(minBuf++); | |||||
| quantParam.max = *(maxBuf++); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| } | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | |||||
| quants.clear(); | |||||
| auto outputMin = prim->GetAttr("output_minq"); | |||||
| auto outputMax = prim->GetAttr("output_maxq"); | |||||
| if (outputMin != nullptr && outputMax != nullptr) { | |||||
| auto outputMinPtr = outputMin->cast<lite::tensor::TensorPtr>(); | |||||
| auto outputMaxPtr = outputMax->cast<lite::tensor::TensorPtr>(); | |||||
| float *minBuf = static_cast<float *>(outputMinPtr->Data()); | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | |||||
| quantParam.min = *minBuf; | |||||
| quantParam.max = *maxBuf; | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | |||||
| vecQuantParam->emplace_back(quants); | |||||
| } | |||||
| } | |||||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| auto attr = std::make_unique<schema::MatMulT>(); | auto attr = std::make_unique<schema::MatMulT>(); | ||||
| @@ -32,8 +116,16 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | primitiveTValuePtr->SetPrimitiveT(primitive.release()); | ||||
| if (primitiveTValuePtr->GetQuantType()) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | |||||
| PopulaterQuantParam(prim, &vecQuantParam); | |||||
| primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | |||||
| } | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", | |||||
| new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", | |||||
| new AnfMatmulPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -24,6 +24,11 @@ class AnfMatmulPopulater : public AnfNodePopulater { | |||||
| ~AnfMatmulPopulater() override = default; | ~AnfMatmulPopulater() override = default; | ||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -28,18 +28,18 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "schema/inner/model_generated.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | #include "google/protobuf/io/zero_copy_stream_impl.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "schema/inner/model_generated.h" | |||||
| #include "securec/include/securec.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "securec/include/securec.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| using string = std::string; | using string = std::string; | ||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| @@ -60,16 +60,24 @@ enum ParseForm : int { | |||||
| }; | }; | ||||
| static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | ||||
| {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; | |||||
| {"type", FORM_PARSE_TYPE}, | |||||
| {"scalar", FORM_PARSE_SCALAR}, | |||||
| {"tensor", FORM_PARSE_TENSOR}}; | |||||
| static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | ||||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, | |||||
| {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, | |||||
| {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, | |||||
| {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, | |||||
| {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, | |||||
| {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, | |||||
| {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| }; | }; | ||||
| #if 0 | #if 0 | ||||
| @@ -189,15 +197,16 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ | |||||
| if (attr_tensor.type##_data_size() == 1) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } else { \ | |||||
| MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ | |||||
| } \ | |||||
| return {}; \ | |||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype( \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| if (attr_tensor.type##_data_size() == 1) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } else { \ | |||||
| MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ | |||||
| } \ | |||||
| return {}; \ | |||||
| } | } | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | ||||
| @@ -643,20 +652,21 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc | |||||
| } | } | ||||
| #else | #else | ||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| MS_EXCEPTION_IF_NULL(prim); \ | |||||
| std::vector<ValuePtr> attr_value_vec; \ | |||||
| for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ | |||||
| attr_value_vec.push_back(MakeValue<valuetype>(value)); \ | |||||
| } \ | |||||
| if (attr_value_vec.size() == 1) { \ | |||||
| prim->AddAttr(attr_name, attr_value_vec[0]); \ | |||||
| } else { \ | |||||
| prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ | |||||
| } \ | |||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| void ParseAttrInScalar_##type##_##valuetype( \ | |||||
| const PrimitivePtr &prim, const std::string &attr_name, \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| MS_EXCEPTION_IF_NULL(prim); \ | |||||
| std::vector<ValuePtr> attr_value_vec; \ | |||||
| for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ | |||||
| attr_value_vec.push_back(MakeValue<valuetype>(value)); \ | |||||
| } \ | |||||
| if (attr_value_vec.size() == 1) { \ | |||||
| prim->AddAttr(attr_name, attr_value_vec[0]); \ | |||||
| } else { \ | |||||
| prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ | |||||
| } \ | |||||
| } | } | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | ||||
| @@ -667,8 +677,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | ||||
| bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( | |||||
| const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!value_proto.has_type() || !value_proto.has_name()) { | if (!value_proto.has_type() || !value_proto.has_name()) { | ||||
| MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | ||||
| @@ -691,24 +701,30 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod | |||||
| shape.push_back(tensor_shape.dim(i).dim_value()); | shape.push_back(tensor_shape.dim(i).dim_value()); | ||||
| } | } | ||||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { | |||||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| auto type_ptr = | |||||
| TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| node->set_abstract(abstract_tensor); | node->set_abstract(abstract_tensor); | ||||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | ||||
| tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||||
| tensor::Tensor *tensor_info = new tensor::Tensor( | |||||
| kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | MS_EXCEPTION_IF_NULL(tensor_info); | ||||
| tensor_info->MallocData(); | tensor_info->MallocData(); | ||||
| const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; | |||||
| const onnx::TensorProto initialize_proto = | |||||
| default_para_map_[value_proto.name()]; | |||||
| std::string initial_data = initialize_proto.raw_data(); | std::string initial_data = initialize_proto.raw_data(); | ||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | ||||
| MS_EXCEPTION_IF_NULL(tensor_data_buf); | MS_EXCEPTION_IF_NULL(tensor_data_buf); | ||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), | |||||
| initial_data.data(), initial_data.size()); | |||||
| if (EOK != ret) { | if (EOK != ret) { | ||||
| MS_LOG(ERROR) << "memcpy_s error"; | MS_LOG(ERROR) << "memcpy_s error"; | ||||
| return false; | return false; | ||||
| @@ -724,15 +740,18 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| bool AnfImporterFromProtobuf::ImportParametersForGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); | |||||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " | |||||
| << importProto.initializer_size(); | |||||
| for (int i = 0; i < importProto.initializer_size(); ++i) { | for (int i = 0; i < importProto.initializer_size(); ++i) { | ||||
| const onnx::TensorProto &initializer_proto = importProto.initializer(i); | const onnx::TensorProto &initializer_proto = importProto.initializer(i); | ||||
| if (!initializer_proto.has_name()) { | if (!initializer_proto.has_name()) { | ||||
| MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; | |||||
| MS_LOG(ERROR) | |||||
| << "initializer vector of onnx GraphProto has no name at index: " | |||||
| << i; | |||||
| return false; | return false; | ||||
| } | } | ||||
| default_para_map_[initializer_proto.name()] = initializer_proto; | default_para_map_[initializer_proto.name()] = initializer_proto; | ||||
| @@ -741,7 +760,8 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu | |||||
| MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | ||||
| for (int i = 0; i < importProto.input_size(); ++i) { | for (int i = 0; i < importProto.input_size(); ++i) { | ||||
| const onnx::ValueInfoProto &input_proto = importProto.input(i); | const onnx::ValueInfoProto &input_proto = importProto.input(i); | ||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { | |||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), | |||||
| input_proto)) { | |||||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -749,20 +769,25 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" | |||||
| << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| prim->AddAttr(attr_name, | |||||
| TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| switch (attr_tensor_type) { | switch (attr_tensor_type) { | ||||
| @@ -796,20 +821,59 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; | |||||
| return false; | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | |||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||||
| std::vector<int> shape; | |||||
| auto ret = EOK; | |||||
| if (attr_tensor.dims_size() != 0) { | |||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | |||||
| shape.push_back(attr_tensor.dims(i)); | |||||
| } | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>( | |||||
| kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor_info->MallocData(); | |||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | |||||
| ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue(tensor_info)); | |||||
| } else { | |||||
| if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { | |||||
| size_t data_size = sizeof(double); | |||||
| double attr_value = 0.0; | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<double>(attr_value)); | |||||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { | |||||
| size_t data_size = sizeof(int64_t); | |||||
| int32_t attr_value = 0; | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<int32_t>(attr_value)); | |||||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { | |||||
| size_t data_size = sizeof(bool); | |||||
| bool attr_value = false; | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<bool>(attr_value)); | |||||
| } | |||||
| } | |||||
| return ret == EOK; | |||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode( | |||||
| const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const std::string &attr_name = attr_proto.name(); | const std::string &attr_name = attr_proto.name(); | ||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| @@ -833,18 +897,20 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| shape.push_back(attr_tensor.dims(i)); | shape.push_back(attr_tensor.dims(i)); | ||||
| } | } | ||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>( | |||||
| kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor_info->MallocData(); | tensor_info->MallocData(); | ||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | const std::string &tensor_buf = attr_tensor.raw_data(); | ||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | ||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| if (EOK != ret) { | if (EOK != ret) { | ||||
| MS_LOG(ERROR) << "memcpy_s error"; | MS_LOG(ERROR) << "memcpy_s error"; | ||||
| return false; | return false; | ||||
| @@ -852,14 +918,15 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val | |||||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | auto new_value_node = NewValueNode(MakeValue(tensor_info)); | ||||
| MS_EXCEPTION_IF_NULL(new_value_node); | MS_EXCEPTION_IF_NULL(new_value_node); | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| new_value_node->set_abstract(abstract_tensor); | new_value_node->set_abstract(abstract_tensor); | ||||
| anfnode_build_map_[value_node_name] = new_value_node; | anfnode_build_map_[value_node_name] = new_value_node; | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| ValuePtr value_ptr = nullptr; | ValuePtr value_ptr = nullptr; | ||||
| switch (attr_tensor_type) { | switch (attr_tensor_type) { | ||||
| @@ -871,7 +938,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val | |||||
| if (add_data.size() == 1) { | if (add_data.size() == 1) { | ||||
| value_ptr = MakeValue(add_data[0]); | value_ptr = MakeValue(add_data[0]); | ||||
| } else if (!add_data.empty()) { | } else if (!add_data.empty()) { | ||||
| value_ptr = MakeValue<std::vector<int32>>(add_data); | |||||
| value_ptr = MakeValue<std::vector<int32> >(add_data); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -884,7 +951,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val | |||||
| if (add_data.size() == 1) { | if (add_data.size() == 1) { | ||||
| value_ptr = MakeValue(add_data[0]); | value_ptr = MakeValue(add_data[0]); | ||||
| } else if (!add_data.empty()) { | } else if (!add_data.empty()) { | ||||
| value_ptr = MakeValue<std::vector<float>>(add_data); | |||||
| value_ptr = MakeValue<std::vector<float> >(add_data); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -894,7 +961,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto new_value_node = NewValueNode(value_ptr); | auto new_value_node = NewValueNode(value_ptr); | ||||
| @@ -905,23 +973,28 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) | |||||
| << "Obtain ValueNode attr in type-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| auto new_value_node = | |||||
| NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = | |||||
| std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| new_value_node->set_abstract(abs_type); | new_value_node->set_abstract(abs_type); | ||||
| anfnode_build_map_[value_node_name] = new_value_node; | anfnode_build_map_[value_node_name] = new_value_node; | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, | |||||
| const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode( | |||||
| const std::string &ref_attr_name, const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| switch (kParseTypeSwitchMap[ref_attr_name]) { | switch (kParseTypeSwitchMap[ref_attr_name]) { | ||||
| case FORM_PARSE_SCALAR: { | case FORM_PARSE_SCALAR: { | ||||
| return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); | return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); | ||||
| @@ -933,12 +1006,14 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at | |||||
| return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; | |||||
| MS_LOG(ERROR) | |||||
| << "parse ValueNode value don't support input of ref_attr_name"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( | |||||
| const onnx::NodeProto &node_proto) { | |||||
| const std::string &value_node_name = node_proto.output(0); | const std::string &value_node_name = node_proto.output(0); | ||||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | ||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| @@ -951,20 +1026,23 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto & | |||||
| return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); | return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); | ||||
| } | } | ||||
| abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { | |||||
| abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| std::vector<int> shape_vec; | std::vector<int> shape_vec; | ||||
| const onnx::TensorProto &attr_tensor = attr_proto.t(); | const onnx::TensorProto &attr_tensor = attr_proto.t(); | ||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| shape_vec.push_back(attr_tensor.dims(i)); | shape_vec.push_back(attr_tensor.dims(i)); | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); | |||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tensor); | MS_EXCEPTION_IF_NULL(abstract_tensor); | ||||
| return abstract_tensor; | return abstract_tensor; | ||||
| } | } | ||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto) { | |||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| if (!node_proto.has_op_type()) { | if (!node_proto.has_op_type()) { | ||||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | MS_LOG(ERROR) << "Get CNode op_type failed!"; | ||||
| @@ -1004,20 +1082,24 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| for (int i = 0; i < node_proto.input_size(); ++i) { | for (int i = 0; i < node_proto.input_size(); ++i) { | ||||
| const std::string &input_name = node_proto.input(i); | const std::string &input_name = node_proto.input(i); | ||||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | ||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name | |||||
| << "can't find in nodes have parsed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| inputs.push_back(anfnode_build_map_[input_name]); | inputs.push_back(anfnode_build_map_[input_name]); | ||||
| } | } | ||||
| std::string opType = prim->name(); | std::string opType = prim->name(); | ||||
| auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| auto node_parser = | |||||
| AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | auto primitiveT = std::make_unique<schema::PrimitiveT>(); | ||||
| // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | ||||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = | |||||
| std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||||
| primitiveTValuePtr->SetQuantType(quantType); | |||||
| node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); | inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); | ||||
| @@ -1048,8 +1130,9 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out | |||||
| return cnode_ptr; | return cnode_ptr; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | MS_EXCEPTION_IF_NULL(cnode_ptr); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| @@ -1064,7 +1147,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | ||||
| } | } | ||||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | ||||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| maketuple_ptr->set_abstract( | |||||
| std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| inputs.clear(); | inputs.clear(); | ||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | inputs.push_back(NewValueNode(prim::kPrimReturn)); | ||||
| inputs.push_back(maketuple_ptr); | inputs.push_back(maketuple_ptr); | ||||
| @@ -1077,11 +1161,14 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| const onnx::TypeProto &output_typeproto = output_node.type(); | const onnx::TypeProto &output_typeproto = output_node.type(); | ||||
| int output_type = output_typeproto.tensor_type().elem_type(); | int output_type = output_typeproto.tensor_type().elem_type(); | ||||
| std::vector<int> output_shape; | std::vector<int> output_shape; | ||||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { | |||||
| output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); | |||||
| ++i) { | |||||
| output_shape.push_back( | |||||
| output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | |||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | |||||
| inputs.clear(); | inputs.clear(); | ||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | inputs.push_back(NewValueNode(prim::kPrimReturn)); | ||||
| @@ -1095,8 +1182,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| bool AnfImporterFromProtobuf::ImportNodesForGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | ||||
| CNodePtr cnode_ptr = nullptr; | CNodePtr cnode_ptr = nullptr; | ||||
| @@ -1110,7 +1198,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); | |||||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); | |||||
| if (cnode_ptr == nullptr) { | if (cnode_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | ||||
| return false; | return false; | ||||
| @@ -1122,7 +1210,9 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc | |||||
| } | } | ||||
| #endif | #endif | ||||
| bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||||
| bool AnfImporterFromProtobuf::BuildFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | ||||
| MS_EXCEPTION_IF_NULL(debug_info_ptr); | MS_EXCEPTION_IF_NULL(debug_info_ptr); | ||||
| @@ -1135,10 +1225,11 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph | |||||
| if (!ImportParametersForGraph(outputFuncGraph, importProto)) { | if (!ImportParametersForGraph(outputFuncGraph, importProto)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| return ImportNodesForGraph(outputFuncGraph, importProto); | |||||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | |||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||||
| bool AnfImporterFromProtobuf::ParseModelConfigureInfo( | |||||
| const onnx::ModelProto &model_proto) { | |||||
| if (!model_proto.has_producer_name()) { | if (!model_proto.has_producer_name()) { | ||||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | ||||
| return false; | return false; | ||||
| @@ -1159,14 +1250,14 @@ bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mo | |||||
| return true; | return true; | ||||
| } | } | ||||
| int AnfImporterFromProtobuf::Import() { | |||||
| int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>(); | ||||
| MS_EXCEPTION_IF_NULL(dstGraph); | MS_EXCEPTION_IF_NULL(dstGraph); | ||||
| if (!ParseModelConfigureInfo(*onnx_model_)) { | if (!ParseModelConfigureInfo(*onnx_model_)) { | ||||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | ||||
| } | } | ||||
| const onnx::GraphProto &graphBuild = onnx_model_->graph(); | const onnx::GraphProto &graphBuild = onnx_model_->graph(); | ||||
| if (!BuildFuncGraph(dstGraph, graphBuild)) { | |||||
| if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) { | |||||
| MS_LOG(ERROR) << "Build funcgraph failed!"; | MS_LOG(ERROR) << "Build funcgraph failed!"; | ||||
| func_graph_ = nullptr; | func_graph_ = nullptr; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -1176,7 +1267,8 @@ int AnfImporterFromProtobuf::Import() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | |||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( | |||||
| const std::string &model_path) { | |||||
| std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | ||||
| if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { | if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { | ||||
| MS_LOG(ERROR) << "open file failed."; | MS_LOG(ERROR) << "open file failed."; | ||||
| @@ -17,20 +17,21 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | ||||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | ||||
| #include <string> | |||||
| #include <map> | #include <map> | ||||
| #include <string> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporterFromProtobuf : public AnfImporter { | class AnfImporterFromProtobuf : public AnfImporter { | ||||
| public: | public: | ||||
| explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) | |||||
| : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | |||||
| explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, | |||||
| FuncGraphPtr func_graph) | |||||
| : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | |||||
| ~AnfImporterFromProtobuf() override = default; | ~AnfImporterFromProtobuf() override = default; | ||||
| @@ -38,15 +39,17 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| FuncGraphPtr GetResult() override; | FuncGraphPtr GetResult() override; | ||||
| int Import() override; | |||||
| int Import(const schema::QuantType &quantType = | |||||
| schema::QuantType_QUANT_NONE) override; | |||||
| private: | private: | ||||
| void ConverterConstTensor() override {}; | |||||
| int ConverterCNode() override {}; | |||||
| void AddReturnCNode() override {}; | |||||
| void ConverterConstTensor() override{}; | |||||
| int ConverterCNode() override{}; | |||||
| void AddReturnCNode() override{}; | |||||
| bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | ||||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, | bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| const onnx::GraphProto &importProto); | |||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType); | |||||
| #if 0 | #if 0 | ||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| const onnx::GraphProto &importProto); | const onnx::GraphProto &importProto); | ||||
| @@ -78,31 +81,46 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> | std::unordered_map<std::string, abstract::AbstractTensorPtr> | ||||
| GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | ||||
| #else | #else | ||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); | |||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType); | |||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType); | |||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr); | const CNodePtr &cnode_ptr); | ||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, | |||||
| const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | ||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, | |||||
| bool ObtainValueNodeInScalarForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const string &ref_attr_name, | |||||
| const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| abstract::AbstractTensorPtr GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto); | |||||
| #endif | #endif | ||||
| private: | private: | ||||
| std::string producer_name_; | std::string producer_name_; | ||||
| int model_version_{}; | int model_version_{}; | ||||
| @@ -115,4 +133,3 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ | ||||
| @@ -46,6 +46,9 @@ class PrimitiveTValue : public Value { | |||||
| } | } | ||||
| } | } | ||||
| void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { | |||||
| } | |||||
| void AddInputQuantParam(schema::QuantParamT quant_param) { | void AddInputQuantParam(schema::QuantParamT quant_param) { | ||||
| this->input_quant_param_.emplace_back(quant_param); | this->input_quant_param_.emplace_back(quant_param); | ||||
| } | } | ||||
| @@ -73,6 +73,9 @@ class Tensor : public mindspore::tensor::MetaTensor { | |||||
| size_t Size() const { | size_t Size() const { | ||||
| size_t size = 0; | size_t size = 0; | ||||
| switch (this->data_type_) { | switch (this->data_type_) { | ||||
| case kNumberTypeFloat64: | |||||
| size = sizeof(double); | |||||
| break; | |||||
| case kNumberTypeFloat: | case kNumberTypeFloat: | ||||
| case kNumberTypeFloat32: | case kNumberTypeFloat32: | ||||
| size = sizeof(float); | size = sizeof(float); | ||||
| @@ -71,7 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| FuncGraphPtr graph = nullptr; | FuncGraphPtr graph = nullptr; | ||||
| if (flag->fmk == converter::FmkType_MS) { | if (flag->fmk == converter::FmkType_MS) { | ||||
| MS_ASSERT(nullptr != modelImporter); | MS_ASSERT(nullptr != modelImporter); | ||||
| modelImporter->Import(); | |||||
| modelImporter->Import(flag->quantType); | |||||
| graph = modelImporter->GetResult(); | graph = modelImporter->GetResult(); | ||||
| } else { | } else { | ||||
| MS_ASSERT(nullptr != modelParser); | MS_ASSERT(nullptr != modelParser); | ||||