@@ -17,6 +17,16 @@
#include "nnacl/infer/arithmetic_infer.h"
#include "nnacl/infer/infer_register.h"

void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) {
  output->format_ = input0->format_;
  output->data_type_ = input0->data_type_;
  // when input0 is const, it is quantized before the quant transform op is inserted, so use input1's data type instead
  if (input0->data_ != NULL ||
      ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) {
    output->data_type_ = input1->data_type_;
  }
}
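
Note: for intuition, here is a minimal standalone sketch of the dtype-resolution rule above; the enum and `Tensor` struct are hypothetical stand-ins, not nnacl's `TensorC`:

```cpp
#include <cstdio>

// Hypothetical stand-ins for illustration only.
enum DataType { kFloat32 = 0, kInt8 = 1 };
struct Tensor {
  DataType data_type;
  const void *data;  // non-null means the tensor holds const data
};

// Mirrors SetOutputDtypeFormat: prefer input1's dtype when input0 is a const
// tensor (already quantized) or when an int8 input0 meets a float32 input1.
DataType ResolveOutputDType(const Tensor &in0, const Tensor &in1) {
  if (in0.data != nullptr || (in0.data_type == kInt8 && in1.data_type == kFloat32)) {
    return in1.data_type;
  }
  return in0.data_type;
}

int main() {
  int payload = 0;                          // dummy const data
  Tensor quantized_const{kInt8, &payload};  // const int8 weight
  Tensor activation{kFloat32, nullptr};     // runtime float activation
  // The output follows the float32 activation, not the quantized weight.
  std::printf("%d\n", static_cast<int>(ResolveOutputDType(quantized_const, activation)));
  return 0;
}
```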
void UpdateInputShape(const int input_shape0_size, const int input_shape1_size, int *ndim, const int *input_shape0,
                      const int *input_shape1, int *in_shape0, int *in_shape1) {
  if (input_shape0_size < input_shape1_size) {

@@ -71,11 +81,7 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
  size_t input_shape0_size = input0->shape_size_;
  const int *input_shape1 = input1->shape_;
  size_t input_shape1_size = input1->shape_size_;
  output->format_ = input0->format_;
  output->data_type_ = input0->data_type_;
  if ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32)) {
    output->data_type_ = input1->data_type_;
  }
  SetOutputDtypeFormat(input0, input1, output);
  if (!parameter->infer_flag_) {
    return NNACL_INFER_INVALID;

@@ -1,4 +1,4 @@
video_infer.tflite
video_infer2.tflite
mobilenet_v1_0.25_128_quant.tflite
mobilenet_v1_0.25_160_quant.tflite
mobilenet_v1_0.25_192_quant.tflite
@@ -38,6 +38,8 @@
#include "ops/op_utils.h"
#include "tools/common/graph_util.h"
#include "src/ops/ops_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"

using mindspore::ops::PrimitiveC;

@@ -81,9 +83,9 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
  }
}  // namespace

int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                          const std::shared_ptr<mindspore::Primitive> &primitive,
                                          const std::unique_ptr<schema::CNodeT> &dst_node) {
int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                              const std::shared_ptr<mindspore::Primitive> &primitive,
                                              const std::unique_ptr<schema::CNodeT> &dst_node) {
  auto first_output_index = dst_node->outputIndex[0];
  auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
  if (dst_node->quantType == schema::QuantType_PostTraining) {
@@ -116,82 +118,63 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
  QuantParamsVector output_quant_params;
  dst_node->quantType = schema::QuantType_QUANT_NONE;
  auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
  if (quant_tensor_info_ptr != nullptr) {
    auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
    if (quant_param_holder == nullptr) {
      MS_LOG(ERROR) << "quant param is invalid.";
      return RET_ERROR;
    }
    input_quant_params = quant_param_holder->get_input_quant_params();
    output_quant_params = quant_param_holder->get_output_quant_params();
    dst_node->quantType = quant_param_holder->quant_type();
  QuantParamHolderPtr quant_param_holder = nullptr;
  if (quant_tensor_info_ptr == nullptr ||
      (quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>()) == nullptr) {
    quant_param_holder = std::make_shared<QuantParamHolder>(dst_node->inputIndex.size(), dst_node->outputIndex.size());
  }
  // add quant param
  if (!input_quant_params.empty()) {
    for (size_t i = 0; i < input_quant_params.size(); i++) {
      if (i >= dst_node->inputIndex.size()) {
        MS_LOG(INFO) << "node: " << dst_node->name << " has " << input_quant_params.size()
                     << " input quant params, but only " << dst_node->inputIndex.size() << " inputs";
        break;
      }
      auto activate_index = dst_node->inputIndex[i];
      auto tensor_input = meta_graph->allTensors[activate_index].get();
      if (tensor_input->quantParams.empty()) {
        for (auto input_quant_param : input_quant_params[i]) {
          auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
          MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                        << " zp: " << input_quant_param_ptr->zeroPoint;
          input_quant_param_ptr->dstDtype = tensor_input->dataType;
          tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
        }
  input_quant_params = quant_param_holder->get_input_quant_params();
  output_quant_params = quant_param_holder->get_output_quant_params();
  dst_node->quantType = quant_param_holder->quant_type();
  // convert input quant param
  for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
    if (i >= input_quant_params.size()) {
      MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size() << " inputs, but only "
                   << input_quant_params.size() << " quant params";
      break;
    }
    auto activate_index = dst_node->inputIndex[i];
    auto tensor_input = meta_graph->allTensors[activate_index].get();
    if (tensor_input->quantParams.empty()) {
      for (auto input_quant_param : input_quant_params[i]) {
        auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
        MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                      << " zp: " << input_quant_param_ptr->zeroPoint;
        input_quant_param_ptr->dstDtype = tensor_input->dataType;
        tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
      }
      if (!tensor_input->quantParams.empty()) {
        int bit_num = tensor_input->quantParams.at(0)->numBits;
        if (bit_num != 8 && bit_num != 16) {
          auto status = DoBitPack(bit_num, tensor_input);
          if (status != RET_OK) {
            MS_LOG(ERROR) << "do bit pack failed. " << status;
            return RET_ERROR;
          }
        }
    if (!tensor_input->quantParams.empty()) {
      int bit_num = tensor_input->quantParams.at(0)->numBits;
      if (bit_num != 8 && bit_num != 16) {
        auto status = DoBitPack(bit_num, tensor_input);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "do bit pack failed. " << status;
          return RET_ERROR;
        }
      }
    }
  } else {
    MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
  }
  // output
  if (output_quant_params.empty()) {
    if (primitive->name() != mindspore::ops::kNameQuantDTypeCast) {
      MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
    }
  } else {
    if (dst_node->outputIndex.size() != output_quant_params.size()) {
      MS_LOG(INFO) << "node: " << dst_node->name << " has " << output_quant_params.size()
                   << " output quant params, but only " << dst_node->outputIndex.size() << " outputs";
      return RET_ERROR;
    }
    int output_idx = 0;
    for (const auto &output_quant_param : output_quant_params) {
      auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
      output_idx++;
      for (const auto &channel_quant_param : output_quant_param) {
        if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
          std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
            std::make_unique<schema::QuantParamT>(channel_quant_param);
          MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                        << " zp: " << output_quant_param_ptr->zeroPoint;
          output_quant_param_ptr->dstDtype = output_tensor->dataType;
          output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
        }
  int output_idx = 0;
  for (const auto &output_quant_param : output_quant_params) {
    auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
    output_idx++;
    for (const auto &channel_quant_param : output_quant_param) {
      if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
        std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
          std::make_unique<schema::QuantParamT>(channel_quant_param);
        MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                      << " zp: " << output_quant_param_ptr->zeroPoint;
        output_quant_param_ptr->dstDtype = output_tensor->dataType;
        output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
      }
    }
  }
  auto status = SetQuantOutputTensorType(meta_graph, primitive, dst_node);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "set quant output tensor data type failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
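
Note: the reworked ConvertQuantParam above replaces the hard error on a missing or invalid `quant_params` attribute with a fallback to an empty, correctly sized holder. A standalone sketch of that pattern (hypothetical `Holder` type, not the real QuantParamHolder):

```cpp
#include <memory>

// Hypothetical stand-in for QuantParamHolder: just records its slot counts.
struct Holder {
  Holder(size_t num_inputs, size_t num_outputs)
      : num_inputs(num_inputs), num_outputs(num_outputs) {}
  size_t num_inputs;
  size_t num_outputs;
};

// Mirrors the fallback above: if the attribute is missing (nullptr) or fails
// the cast, substitute an empty holder sized to the node, so the quant-param
// loops that follow always have a valid object to read from.
std::shared_ptr<Holder> GetOrMakeHolder(const std::shared_ptr<Holder> &from_attr,
                                        size_t num_inputs, size_t num_outputs) {
  if (from_attr != nullptr) {
    return from_attr;
  }
  return std::make_shared<Holder>(num_inputs, num_outputs);
}
```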
@@ -224,6 +207,7 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m
  tensor->format = schema::Format_NHWC;
  if (!IsContain(subgraph->inputIndices, input)) {
    if (subgraph_index == kMainGraphIndex) {
      TensorDataType::GetInstance()->UpdateGraphInputDType(meta_graphT->inputIndex.size(), tensor->dataType);
      meta_graphT->inputIndex.push_back(input);
    }
    subgraph->inputIndices.push_back(input);

@@ -262,6 +246,8 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap
  }
  for (unsigned int &i : return_node->inputIndex) {
    if (subgraph_index == kMainGraphIndex) {
      auto &tensor = meta_graphT->allTensors.at(i);
      TensorDataType::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
      meta_graphT->outputIndex.push_back(i);
    }
    meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);

@@ -354,6 +340,13 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
      MS_LOG(ERROR) << "ConvertQuantParam failed";
      break;
    }
    auto status = SetPostTrainOutputTensorType(meta_graphT, prim, node);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "set quant output tensor data type failed.";
      break;
    }
    meta_graphT->nodes.push_back(std::move(node));
    meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
  }

@@ -615,7 +608,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
    return;
  }
  auto elements = tuple->elements();
  for (size_t i = 0; i < elements.size(); i++) {
  for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
    auto msTensor = new (std::nothrow) schema::TensorT();
    if (msTensor == nullptr) {
      MS_LOG(ERROR) << "new msTensor failed";

@@ -627,8 +620,6 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
      std::string name = cnode_name + "_o:" + std::to_string(i);
      node_id_map_[name] = meta_graphT->allTensors.size();
      meta_graphT->allTensors.emplace_back(msTensor);
      if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))
        break;
    } else {
      if (elements.size() == 1) {
        node_id_map_[cnode_name] = meta_graphT->allTensors.size();

@@ -57,9 +57,9 @@ class AnfExporter {
  int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
  int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
                          const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
  static int SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                      const std::shared_ptr<mindspore::Primitive> &primitive,
                                      const std::unique_ptr<schema::CNodeT> &dst_node);
  static int SetPostTrainOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                          const std::shared_ptr<mindspore::Primitive> &primitive,
                                          const std::unique_ptr<schema::CNodeT> &dst_node);
  static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                               const std::shared_ptr<mindspore::Primitive> &primitive,
                               const std::unique_ptr<schema::CNodeT> &dst_node);
@@ -495,13 +495,21 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
  if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
    auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
    MS_ASSERT(prim != nullptr);
    if (prim->src_t == TypeId::kNumberTypeUInt8) {
      if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
        toAddTensor->quantParams.front()->zeroPoint -= 128;
      } else {
        preTensor->quantParams.front()->zeroPoint += 128;
      }
    } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
      if (preTensor->dataType == TypeId::kNumberTypeInt8) {
        toAddTensor->quantParams.front()->zeroPoint += 128;
      } else {
        preTensor->quantParams.front()->zeroPoint -= 128;
      }
    }
    preTensor->dataType = prim->src_t;
    toAddTensor->dataType = prim->dst_t;
    if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
      preTensor->quantParams.front()->zeroPoint += 128;
    } else if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
      toAddTensor->quantParams.front()->zeroPoint += 128;
    }
  }
  graphT->allTensors.emplace_back(std::move(toAddTensor));
  size_t toAddTensorIdx = graphT->allTensors.size() - 1;
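
Note: the zero-point adjustments above all come from the affine quantization identity real = scale × (q − zero_point): re-encoding uint8 data as int8 shifts every q by 128, so the zero point must shift by the same amount for the dequantized value to be preserved. A self-contained sanity check with hypothetical values:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const float scale = 0.5f;
  const int32_t zp_u8 = 140;  // uint8 zero point
  const uint8_t q_u8 = 200;   // one quantized sample

  // Re-encode as int8: data and zero point both shift by 128.
  const int8_t q_i8 = static_cast<int8_t>(q_u8 - 128);
  const int32_t zp_i8 = zp_u8 - 128;

  // The dequantized real value is preserved by the paired shift.
  const float real_u8 = scale * (static_cast<int32_t>(q_u8) - zp_u8);
  const float real_i8 = scale * (static_cast<int32_t>(q_i8) - zp_i8);
  assert(real_u8 == real_i8);
  return 0;
}
```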
@@ -565,13 +573,21 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
  if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
    auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
    MS_ASSERT(prim != nullptr);
    if (prim->dst_t == TypeId::kNumberTypeUInt8) {
      if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
        postTensor->quantParams.front()->zeroPoint -= 128;
      } else {
        toAddTensor->quantParams.front()->zeroPoint += 128;
      }
    } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
      if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
        toAddTensor->quantParams.front()->zeroPoint -= 128;
      } else {
        postTensor->quantParams.front()->zeroPoint += 128;
      }
    }
    postTensor->dataType = prim->src_t;
    toAddTensor->dataType = prim->dst_t;
    if (prim->src_t == TypeId::kNumberTypeInt8 && prim->dst_t == TypeId::kNumberTypeUInt8) {
      toAddTensor->quantParams.front()->zeroPoint += 128;
    } else if (prim->src_t == TypeId::kNumberTypeUInt8 && prim->dst_t == TypeId::kNumberTypeInt8) {
      postTensor->quantParams.front()->zeroPoint += 128;
    }
  }
  graphT->allTensors.emplace_back(std::move(toAddTensor));
  size_t toAddTensorIdx = graphT->allTensors.size() - 1;
@@ -446,5 +446,20 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  }
  return RET_OK;
}

size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
  auto cnode = anf_node->cast<CNodePtr>();
  if (train_flag &&
      (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
    return 1;
  }
  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    return tuple->elements().size();
  } else {
    return 1;
  }
}
}  // namespace lite
}  // namespace mindspore
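
Note: GetCNodeOutputsSize centralizes the output-count rule previously duplicated in mindir_adjust_pass and anf_exporter. A condensed standalone restatement (hypothetical `Node` struct in place of AnfNode/AbstractTuple):

```cpp
#include <cstddef>

// Hypothetical stand-in for a CNode, for illustration only.
struct Node {
  bool is_conv_or_adam;  // Conv2DFusion or Adam primitive
  size_t tuple_size;     // number of tuple elements; 0 if abstract is not a tuple
};

// Mirrors GetCNodeOutputsSize: when train_flag is set, Conv2DFusion/Adam
// export a single output even though their abstract is a tuple; otherwise
// the output count is the tuple size, defaulting to 1 for non-tuple nodes.
size_t OutputsSize(const Node &node, bool train_flag) {
  if (train_flag && node.is_conv_or_adam) {
    return 1;
  }
  return node.tuple_size > 0 ? node.tuple_size : 1;
}
```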
@@ -26,6 +26,7 @@
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace lite {

@@ -401,6 +402,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type)
}
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag = false);
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
@@ -77,18 +77,29 @@ class TensorDataType {
    static TensorDataType tensor_data_type;
    return &tensor_data_type;
  }
  void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
  int32_t GetTensorType(int32_t index) const {
    if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {
  void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; }
  int32_t GetGraphInputDType(int32_t index) const {
    if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) {
      return TypeId::kTypeUnknown;
    }
    return graph_input_data_type_map_.at(index);
  }
  void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; }
  int32_t GetGraphOutputDType(int32_t index) const {
    if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) {
      return TypeId::kTypeUnknown;
    }
    return tensor_data_type_map_.at(index);
    return graph_output_data_type_map_.at(index);
  }

 private:
  TensorDataType() {}
  virtual ~TensorDataType() = default;
  std::map<int32_t, int32_t> tensor_data_type_map_;
  std::map<int32_t, int32_t> graph_input_data_type_map_;
  std::map<int32_t, int32_t> graph_output_data_type_map_;
};
}  // namespace lite
}  // namespace mindspore
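
Note: the split into per-input and per-output maps replaces the single tensor-index map. A minimal standalone sketch of the lookup-with-unknown-default pattern both getters use (hypothetical names):

```cpp
#include <cstdint>
#include <map>

// Standalone sketch of the registry pattern above (hypothetical names): record
// a dtype per graph input/output index and report "unknown" for indices that
// were never registered, instead of throwing like map::at would.
class DTypeRegistry {
 public:
  void Update(int32_t index, int32_t dtype) { map_[index] = dtype; }
  int32_t Get(int32_t index) const {
    auto it = map_.find(index);
    return it == map_.end() ? kUnknown : it->second;
  }

 private:
  static constexpr int32_t kUnknown = 0;  // stands in for TypeId::kTypeUnknown
  std::map<int32_t, int32_t> map_;
};
```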
@@ -23,6 +23,7 @@
#include "tools/converter/converter_context.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/quantize_util.h"

namespace mindspore {
namespace lite {

@@ -38,18 +39,17 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
    return status;
  }
  status = DoNodeInoutDTypeTrans(graph);
  status = DoModelOutputDTypeTrans(graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
    MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
    return status;
  }
  status = DoModelOutputDTypeTrans(graph);
  status = DoNodeInoutDTypeTrans(graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
    MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
    return status;
  }
  return RET_OK;
}

@@ -61,15 +61,23 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
    MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
    return RET_ERROR;
  }
  for (auto graph_in_idx : graph_in_idxes) {
  for (size_t i = 0; i < graph_in_idxes.size(); i++) {
    size_t graph_in_idx = graph_in_idxes.at(i);
    MS_ASSERT(graph_in_idx < graph->allTensors.size());
    auto &tensor = graph->allTensors.at(graph_in_idx);
    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
    if (!quant::TensorQuantParamsInited(*tensor)) {
      continue;
    }
    if (this->input_data_dtype == TypeId::kTypeUnknown) {
      if (tensor->dataType != TensorDataType::GetInstance()->GetGraphInputDType(i)) {
        MS_LOG(ERROR) << "Changing the graph input dtype is not allowed.";
        return RET_ERROR;
      }
      continue;
    }
    int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
                                 ? this->input_data_dtype
                                 : TensorDataType::GetInstance()->GetTensorType(graph_in_idx);
    int32_t tensor_data_type = this->input_data_dtype;
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto node_name = (*iter)->name;
      for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {

@@ -77,7 +85,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
        STATUS status = RET_OK;
        // insert dtype cast node between input tensor and input node
        if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
        if (tensor_data_type != tensor->dataType) {
          iter =
            InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
        }
@@ -101,15 +109,23 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
    return RET_ERROR;
  }
  auto &graph_out_idxes = graph->outputIndex;
  for (auto graph_out_idx : graph_out_idxes) {
  for (size_t i = 0; i < graph_out_idxes.size(); i++) {
    size_t graph_out_idx = graph_out_idxes.at(i);
    MS_ASSERT(graph_out_idx < graph->allTensors.size());
    auto &tensor = graph->allTensors.at(graph_out_idx);
    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
    if (!quant::TensorQuantParamsInited(*tensor)) {
      continue;
    }
    int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
                                 ? this->output_data_dtype
                                 : TensorDataType::GetInstance()->GetTensorType(graph_out_idx);
    if (this->output_data_dtype == TypeId::kTypeUnknown) {
      if (tensor->dataType != TensorDataType::GetInstance()->GetGraphOutputDType(i)) {
        MS_LOG(ERROR) << "Changing the graph output dtype is not allowed.";
        return RET_ERROR;
      }
      continue;
    }
    int32_t tensor_data_type = this->output_data_dtype;
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto node_name = (*iter)->name;
      MS_ASSERT(node != nullptr);

@@ -117,7 +133,7 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
      if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
        // insert transNode
        STATUS status = RET_OK;
        if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
        if (tensor_data_type != tensor->dataType) {
          iter =
            InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
        }
@@ -136,27 +152,34 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) {
  auto node_name = (**iter)->name;
  auto status = RET_OK;
  // insert fp32 to int8 before
  // insert fp32/uint8 to int8 before
  for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
    auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
    if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() &&
        pre_tensor->quantParams.front()->inited) {
      *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
    // insert a quant cast op for tensors that should be int8
    if ((pre_tensor->dataType == kNumberTypeFloat32 || pre_tensor->dataType == kNumberTypeUInt8) &&
        quant::TensorQuantParamsInited(*pre_tensor)) {
      if (!pre_tensor->data.empty()) {
        MS_LOG(ERROR) << "tensor with float data should be quantized at tensor_quant_pass.";
        return RET_ERROR;
      }
      *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, pre_tensor->dataType, kNumberTypeInt8, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";
        MS_LOG(ERROR) << "Insert float32/uint8 to int8 node before " << node_name.c_str() << " failed";
        return RET_ERROR;
      }
    }
  }
  // insert int8 to fp32 after
  // insert int8 to fp32/uint8 after
  for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
    auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
    if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() &&
        post_tensor->quantParams.front()->inited) {
      *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
    // insert a quant cast op for tensors that should be int8
    // e.g. reshape's shape tensor doesn't need a quant op inserted, so its quant param isn't inited
    if ((post_tensor->dataType == kNumberTypeFloat32 || post_tensor->dataType == kNumberTypeUInt8) &&
        quant::TensorQuantParamsInited(*post_tensor)) {
      *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, post_tensor->dataType, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";
        MS_LOG(ERROR) << "Insert int8 to float32/uint8 node after " << node_name.c_str() << " failed";
        return RET_ERROR;
      }
    }

@@ -170,8 +193,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraph
  // insert int8 to fp32 before
  for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) {
    auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i));
    if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() &&
        pre_tensor->quantParams.front()->inited) {
    if (pre_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*pre_tensor)) {
      *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed";

@@ -183,8 +205,7 @@ STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraph
  // insert fp32 to int8 after
  for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) {
    auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i));
    if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() &&
        post_tensor->quantParams.front()->inited) {
    if (post_tensor->dataType == kNumberTypeInt8 && quant::TensorQuantParamsInited(*post_tensor)) {
      *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed";

@@ -200,8 +221,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto node_name = (*iter)->name;
    if ((*iter)->inputIndex.empty()) {
      MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have at least " << kMinInputNum << " input tensor";
      return RET_ERROR;
      MS_LOG(WARNING) << "Op " << node_name.c_str() << " should have at least " << kMinInputNum << " input tensor";
      continue;
    }
    if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast ||

@@ -270,6 +291,10 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
    trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
  } else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
    trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++);
  } else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeFloat32) {
    trans_node->name = "uint8tofp32_" + tile_name + std::to_string(id_++);
  } else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeUInt8) {
    trans_node->name = "fp32touint8_" + tile_name + std::to_string(id_++);
  }
  trans_node->primitive->value.value = quant_dtype_cast_param;
  int insert_num = 0;
@@ -25,29 +25,15 @@
namespace mindspore::lite {
namespace {
STATUS PreHandleQuantDtypeCast(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    if (node == nullptr || node->primitive == nullptr) {
      MS_LOG(ERROR) << " node or node->primitive is nullptr";
      return RET_ERROR;
    }
    if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) {
      auto attr = node->primitive->value.AsQuantDTypeCast();
      auto &inputTensor = graph->allTensors.at(node->inputIndex.front());
      inputTensor->dataType = attr->src_t;
      auto &outputTensor = graph->allTensors.at(node->outputIndex.front());
      outputTensor->dataType = attr->dst_t;
      if (attr->src_t == TypeId::kNumberTypeUInt8) {
        attr->src_t = TypeId::kNumberTypeInt8;
      }
      if (attr->dst_t == TypeId::kNumberTypeUInt8) {
        attr->dst_t = TypeId::kNumberTypeInt8;
      }
    }
bool TensorNeedQuant(const std::unique_ptr<TensorT> &tensor) {
  if (!quant::TensorQuantParamsInited(*tensor)) {
    return false;
  }
  return RET_OK;
  if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
      tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) {
    return false;
  }
  return !tensor->data.empty();
}
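
Note: a compact standalone restatement of the new predicate (hypothetical `Tensor` type and illustrative TypeId values, not the real schema), showing the three gates a tensor must pass to be quantized by this pass:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for schema::TensorT, for illustration.
struct Tensor {
  bool quant_params_inited;   // every quant param present and inited
  int32_t data_type;          // a TypeId value
  std::vector<uint8_t> data;  // const payload; empty for activations
};

// Illustrative TypeId stand-ins (not the real enumerator values).
constexpr int32_t kTypeUnknown = 0;
constexpr int32_t kNumberTypeFloat = 1;
constexpr int32_t kNumberTypeFloat32 = 2;
constexpr int32_t kNumberTypeUInt8 = 3;

// Mirrors TensorNeedQuant: quant params must be inited, the dtype must be one
// that still needs conversion to int8, and the tensor must carry const data.
bool NeedQuant(const Tensor &t) {
  if (!t.quant_params_inited) {
    return false;
  }
  if (t.data_type != kNumberTypeFloat32 && t.data_type != kNumberTypeFloat &&
      t.data_type != kNumberTypeUInt8 && t.data_type != kTypeUnknown) {
    return false;
  }
  return !t.data.empty();
}
```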
STATUS ComputeDataToInt8(const std::unique_ptr<TensorT> &tensor, int32_t index) {

@@ -73,7 +59,6 @@ STATUS ComputeDataToInt8(const std::unique_ptr<TensorT> &tensor, int32_t index)
    weightQauntParam->zeroPoint -= 128;
    tensor->quantParams.clear();
    tensor->quantParams.emplace_back(weightQauntParam.release());
    TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8);
  }
  tensor->dataType = TypeId::kNumberTypeInt8;
  if (tensor->data.empty()) {

@@ -174,23 +159,15 @@ STATUS ComputeQuantTensorPerChannel(TensorT *tensor, const int &tensor_index, co
STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  auto status = PreHandleQuantDtypeCast(graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "pre adjust failed.";
    return status;
  }
  int32_t index = 0;
  auto status = RET_OK;
  for (auto &tensor : graph->allTensors) {
    if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
      index++;
      continue;
    }
    if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat &&
        tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) {
    if (!TensorNeedQuant(tensor)) {
      index++;
      continue;
    }
    if (tensor->quantParams.size() != 1) {  // perchannel
    if (tensor->quantParams.size() > 1) {  // perchannel
      status = ComputeQuantTensorPerChannel(tensor.get(), index, *graph);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "compute tensor to int8 perchannel failed.";

@@ -201,8 +178,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
    }
    // perlayer
    auto &quantParam = tensor->quantParams.front();
    if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 ||
        quantParam->dstDtype == TypeId::kNumberTypeFloat) {
    if (quantParam->dstDtype == TypeId::kNumberTypeInt8 || quantParam->dstDtype == TypeId::kNumberTypeUInt8 ||
        quantParam->dstDtype == TypeId::kNumberTypeFloat32 || quantParam->dstDtype == TypeId::kNumberTypeFloat) {
      status = ComputeDataToInt8(tensor, index);
    } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
      // quant bias data
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"

namespace mindspore::lite {
static constexpr size_t kInputIndex = 0;
static constexpr size_t kWeightIndex = 1;

STATUS QuantDtypeCastQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
  auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
  if (!input_tensor->quantParams.empty() && input_tensor->quantParams.front()->inited) {
    input_tensor->quantParams.front()->dstDtype = input_tensor->dataType;
  }
  auto &output_tensor = graph->allTensors.at(node.outputIndex.at(0));
  if (!output_tensor->quantParams.empty() && output_tensor->quantParams.front()->inited) {
    output_tensor->quantParams.front()->dstDtype = output_tensor->dataType;
  }
  return RET_OK;
}
}  // namespace mindspore::lite
@@ -0,0 +1,27 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H

#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"

namespace mindspore::lite {
class QuantDtypeCastQuantParamPropogator : public QuantParamPropogator {
 public:
  STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H
@@ -25,6 +25,7 @@
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"

namespace mindspore::lite {
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {

@@ -100,6 +101,7 @@ QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_ty
QuantHelperRegister::QuantHelperRegister() {
  auto base_propogator = std::make_shared<QuantParamPropogator>();
  auto base_determiner = std::make_shared<QuantTypeDeterminer>();
  auto quant_dtype_cast_propogator = std::make_shared<QuantDtypeCastQuantParamPropogator>();
  auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
  auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
  auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();

@@ -127,7 +129,7 @@ QuantHelperRegister::QuantHelperRegister() {
  register_map_[schema::PrimitiveType_MatMul] = new QuantNodeHelper(conv_propogator, conv_determiner);
  register_map_[schema::PrimitiveType_QuantDTypeCast] =
    new QuantNodeHelper(base_propogator, default_quant_all_determiner);
    new QuantNodeHelper(quant_dtype_cast_propogator, default_quant_all_determiner);
  register_map_[schema::PrimitiveType_DetectionPostProcess] =
    new QuantNodeHelper(base_propogator, only_need_inputs_determiner);
@@ -248,6 +248,25 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
  }
  return quant_params_holder;
}

bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2) {
  return quant_param1.inited == quant_param2.inited && quant_param1.scale == quant_param2.scale &&
         quant_param1.zeroPoint == quant_param2.zeroPoint && quant_param1.min == quant_param2.min &&
         quant_param1.max == quant_param2.max && quant_param1.narrowRange == quant_param2.narrowRange &&
         quant_param1.numBits == quant_param2.numBits && quant_param1.varCorr == quant_param2.varCorr &&
         quant_param1.meanCorr == quant_param2.meanCorr;
}

bool TensorQuantParamsInited(const schema::TensorT &tensor) {
  if (tensor.quantParams.empty()) {
    return false;
  }
  for (auto &quant_param : tensor.quantParams) {
    if (!quant_param->inited) {
      return false;
    }
  }
  return true;
}
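
Note: TensorQuantParamsInited is now the single point of truth for "this tensor is really quantized". An equivalent standalone formulation with std::all_of (hypothetical `QuantParam` stand-in for schema::QuantParamT):

```cpp
#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical stand-in for schema::QuantParamT.
struct QuantParam {
  bool inited = false;
};

// Equivalent to TensorQuantParamsInited: empty -> false, otherwise true only
// when every (per-channel) quant param has been initialized.
bool AllQuantParamsInited(const std::vector<std::unique_ptr<QuantParam>> &params) {
  return !params.empty() &&
         std::all_of(params.begin(), params.end(),
                     [](const std::unique_ptr<QuantParam> &p) { return p->inited; });
}
```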
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max,
                             int quant_min, int num_bits) {

@@ -175,6 +175,10 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
  }();
}

bool QuantParamEqual(const schema::QuantParamT &quant_param1, const schema::QuantParamT &quant_param2);
bool TensorQuantParamsInited(const schema::TensorT &tensor);

template <typename T>
STATUS DoPerChannelQuant(const tensor::TensorPtr &weight, const QuantType &quant_type,
                         std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,

@@ -28,6 +28,7 @@
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"
#include "tools/converter/quant_param_holder.h"

using float16 = Eigen::half;

@@ -1416,6 +1417,9 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
  auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm_node});
  MS_ASSERT(cnode != nullptr);
  cnode->set_fullname_with_scope(cnode_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
  auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  trans_insert_prim->AddAttr("quant_params", quant_params_holder);
  return cnode;
}
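
Note: GenTransposeNode attaches a QuantParamHolder sized (2, 1) — the transpose's two inputs (data and perm) and one output — so downstream quant passes never hit a missing `quant_params` attribute. A standalone sketch of the pre-sizing pattern with hypothetical types:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for illustration.
struct QuantParam {
  bool inited = false;
};

// Pre-size one (possibly per-channel) quant-param slot per input and output,
// so later passes can index slots directly instead of probing for an attr.
class Holder {
 public:
  Holder(size_t input_size, size_t output_size)
      : input_params_(input_size), output_params_(output_size) {}
  size_t input_size() const { return input_params_.size(); }
  size_t output_size() const { return output_params_.size(); }

 private:
  std::vector<std::vector<QuantParam>> input_params_;
  std::vector<std::vector<QuantParam>> output_params_;
};

// Usage mirroring GenTransposeNode: Transpose(data, perm) -> out.
Holder MakeTransposeHolder() { return Holder(2, 1); }
```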
@@ -21,20 +21,11 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
#include "tools/common/node_util.h"

namespace mindspore {
namespace opt {
namespace {
size_t GetCNodeOutputsSize(std::shared_ptr<AnfNode> anf_node) {
  auto cnode = anf_node->cast<CNodePtr>();
  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    return tuple->elements().size();
  } else {
    return 1;
  }
}
int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
  auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
  std::vector<schema::QuantParamT> quants;

@@ -112,27 +103,6 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
  return lite::RET_OK;
}

void CheckQuantParams(const PrimitivePtr &prim) {
  auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
  auto input_quant_params = quant_param_holder->get_input_quant_params();
  bool is_quant = false;
  for (size_t i = 0; i < input_quant_params.size(); ++i) {
    if (!input_quant_params.at(i).empty() && input_quant_params.at(i).at(0).inited) {
      is_quant = true;
      break;
    }
  }
  auto output_quant_params = quant_param_holder->get_output_quant_params();
  for (size_t i = 0; i < output_quant_params.size(); ++i) {
    if (!output_quant_params.at(i).empty() && output_quant_params.at(i).at(0).inited) {
      is_quant = true;
    }
  }
  if (!is_quant) {
    prim->EraseAttr("quant_params");
  }
}

int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
  auto narrow_range = prim->GetAttr("narrow_range");
  bool narrow_range_param = false;

@@ -170,7 +140,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i
    MS_LOG(ERROR) << "compute output quant param failed.";
    return status;
  }
  CheckQuantParams(prim);
  return lite::RET_OK;
}
}  // namespace

@@ -236,7 +205,8 @@ int MindirAdjustPass::ComputeQuantParams(std::shared_ptr<AnfNode> anf_node) {
  auto inputs = cnode->inputs();
  inputs.erase(inputs.begin());
  auto quant_param_holder = std::make_shared<lite::QuantParamHolder>(inputs.size(), GetCNodeOutputsSize(anf_node));
  auto quant_param_holder =
    std::make_shared<lite::QuantParamHolder>(inputs.size(), lite::GetCNodeOutputsSize(anf_node, train_flag_));
  primitive->AddAttr("quant_params", quant_param_holder);
  if (ConvertQuantParam(primitive, inputs) != lite::RET_OK) {
@@ -25,7 +25,6 @@
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "ops/strided_slice.h"
#include "tools/converter/quant_param_holder.h"

namespace mindspore {
namespace opt {

@@ -93,9 +92,6 @@ AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_gr
  std::string trans_name =
    before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
  auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(1, 1);
  auto trans_insert_prim = GetValueNode<PrimitivePtr>(trans_insert_node->input(0));
  trans_insert_prim->AddAttr("quant_params", quant_params_holder);
  return trans_insert_node;
}

@@ -20,6 +20,7 @@
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quant_param_holder.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;

@@ -92,6 +93,7 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu
  }
  auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
  auto prim = std::make_shared<ops::Transpose>();
  prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>(1, 1));
  auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
  if (!weight_node->has_default()) {
    MS_LOG(DEBUG) << "Weight parameter should have a default parameter.";