Merge pull request !5040 from yeyunpeng2020/mastertags/v1.0.0
| @@ -165,14 +165,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||
| for (const auto &cnode : cnodes) { | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto primT = primitiveT_value->GetPrimitiveT(); | |||
| if (primitiveT_value->Type() == schema::PrimitiveType_TupleGetItem || | |||
| primitiveT_value->Type() == schema::PrimitiveType_MakeTuple) { | |||
| auto primT = primitive_c->GetPrimitiveT(); | |||
| if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || | |||
| primitive_c->Type() == schema::PrimitiveType_MakeTuple) { | |||
| continue; | |||
| } | |||
| RemoveIfMakeTuple(cnode); | |||
| @@ -196,7 +196,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { | |||
| return nullptr; | |||
| } | |||
| SetOpOutputNode(cnode, meta_graphT, node.get()); | |||
| ret = ConvertQuantParam(meta_graphT, primitiveT_value, node); | |||
| ret = ConvertQuantParam(meta_graphT, primitive_c, node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConvertQuantParam failed"; | |||
| return nullptr; | |||
| @@ -62,12 +62,12 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| for (auto &cnode : cnodes) { | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return; | |||
| } | |||
| auto primT = primitiveT_value->GetPrimitiveT(); | |||
| auto primT = primitive_c->GetPrimitiveT(); | |||
| if (primT == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveT is nullptr"; | |||
| return; | |||
| @@ -75,7 +75,7 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) { | |||
| if (primT->value.type == schema::PrimitiveType_TupleGetItem || | |||
| primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) { | |||
| delete primT; | |||
| primitiveT_value->SetPrimitiveT(nullptr); | |||
| primitive_c->SetPrimitiveT(nullptr); | |||
| } | |||
| } | |||
| } | |||
| @@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||
| return RET_OK; | |||
| } | |||
| STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, | |||
| STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, | |||
| bool perchanel, bool depthwise) { | |||
| // const vector<int> dims = filter->dims; | |||
| // perlayer | |||
| @@ -552,7 +552,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P | |||
| MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num, | |||
| auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, | |||
| perchanel, depthwise); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed: " << status; | |||
| @@ -573,8 +573,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P | |||
| return RET_OK; | |||
| } | |||
| STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value) { | |||
| if (primitiveT_value == nullptr || bias == nullptr) { | |||
| STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c) { | |||
| if (primitive_c == nullptr || bias == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer!"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| @@ -583,7 +583,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||
| auto bias_default_param = bias_parameter_ptr->default_param(); | |||
| auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param); | |||
| auto active_weight_quant_params = primitiveT_value->GetInputQuantParams(); | |||
| auto active_weight_quant_params = primitive_c->GetInputQuantParams(); | |||
| if (active_weight_quant_params.size() != 2) { | |||
| MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); | |||
| return RET_ERROR; | |||
| @@ -627,7 +627,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi | |||
| quant_param.inited = true; | |||
| quant_params.emplace_back(quant_param); | |||
| } | |||
| primitiveT_value->AddInputQuantParam(quant_params); | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| // quant bias data | |||
| int32_t *quant_datas = new (std::nothrow) int32_t[shape_size]; | |||
| if (quant_datas == nullptr) { | |||
| @@ -683,18 +683,18 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| MS_LOG(INFO) << cnode_name << " can not do quant"; | |||
| continue; | |||
| } | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| continue; | |||
| } | |||
| if (input_scale.find(cnode) == input_scale.end()) { | |||
| primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); | |||
| primitive_c->SetQuantType(schema::QuantType_QUANT_NONE); | |||
| continue; | |||
| } | |||
| primitiveT_value->ClearInputOutputQuantParam(); | |||
| primitive_c->ClearInputOutputQuantParam(); | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| auto op_type = (schema::PrimitiveType)primitiveT_value->Type(); | |||
| auto op_type = (schema::PrimitiveType)primitive_c->Type(); | |||
| MS_LOG(INFO) << "OpName: " << op_name; | |||
| if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && | |||
| op_type != PrimitiveType_FullConnection) { | |||
| @@ -715,35 +715,35 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { | |||
| MS_LOG(DEBUG) << "this parameter do quant"; | |||
| DoWeightQuant(input_node, primitiveT_value, false, false); | |||
| DoWeightQuant(input_node, primitive_c, false, false); | |||
| } else { | |||
| MS_LOG(DEBUG) << "this parameter no need to do quant"; | |||
| } | |||
| continue; | |||
| } | |||
| auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node); | |||
| auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitiveT_value == nullptr) { | |||
| auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitive_c == nullptr) { | |||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||
| << " PrimitiveC is null"; | |||
| continue; | |||
| } | |||
| if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) { | |||
| for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { | |||
| primitiveT_value->AddInputQuantParam(quant_param); | |||
| if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) { | |||
| for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) { | |||
| primitive_c->AddInputQuantParam(quant_param); | |||
| } | |||
| } else { | |||
| // do input quant | |||
| double scale = input_scale[cnode]; | |||
| int32_t zp = input_zero_point[cnode]; | |||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value); | |||
| DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); | |||
| } | |||
| } | |||
| } else { | |||
| // do input quant | |||
| double scale = input_scale[cnode]; | |||
| int32_t convInputzeropoint = input_zero_point[cnode]; | |||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value); | |||
| DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); | |||
| // do weight quant | |||
| auto weight = cnode->input(2); | |||
| bool depthwise = op_type == PrimitiveType_DepthwiseConv2D; | |||
| @@ -751,18 +751,18 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| if (op_type == PrimitiveType_FullConnection) { | |||
| perchannel = false; | |||
| } | |||
| DoWeightQuant(weight, primitiveT_value, perchannel, depthwise); | |||
| DoWeightQuant(weight, primitive_c, perchannel, depthwise); | |||
| // do bias quant | |||
| if (cnode->inputs().size() == 4) { | |||
| auto bias = cnode->input(3); | |||
| DoBiasQuant(bias, primitiveT_value); | |||
| DoBiasQuant(bias, primitive_c); | |||
| } | |||
| } | |||
| // do output quant | |||
| double OutputScale = output_scale[cnode]; | |||
| int32_t OutputZeropoint = output_zeropoint[cnode]; | |||
| DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value); | |||
| primitiveT_value->SetQuantType(schema::QuantType_PostTraining); | |||
| DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c); | |||
| primitive_c->SetQuantType(schema::QuantType_PostTraining); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -95,10 +95,10 @@ class PostTrainingQuantizer : public Quantizer { | |||
| STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | |||
| STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); | |||
| STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, bool perchannel, | |||
| STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel, | |||
| bool depthwise); | |||
| STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value); | |||
| STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c); | |||
| }; | |||
| struct DivergInfo; | |||
| @@ -44,17 +44,17 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| bool first = true; | |||
| for (auto &cnode : cnodes) { | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| auto curnode_quant_type = schema::QuantType_QUANT_NONE; | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); | |||
| } else { | |||
| curnode_quant_type = primitiveT_value->GetQuantType(); | |||
| curnode_quant_type = primitive_c->GetQuantType(); | |||
| } | |||
| if (first) { | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | |||
| auto value_node = | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | |||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||
| @@ -72,24 +72,24 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| continue; | |||
| } | |||
| auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node); | |||
| auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitiveT_value == nullptr) { | |||
| auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitive_c == nullptr) { | |||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||
| << " PrimitiveC is null"; | |||
| continue; | |||
| } | |||
| auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType(); | |||
| auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); | |||
| if (curnode_quant_type != input_cnode_quant_type) { | |||
| ValueNodePtr value_node = nullptr; | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && | |||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | |||
| primitiveT_value->GetInputQuantParams().front()); | |||
| primitive_c->GetInputQuantParams().front()); | |||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | |||
| input_cnode_quant_type == schema::QuantType_PostTraining) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | |||
| input_cnode_primitiveT_value->GetInputQuantParams().front()); | |||
| input_cnode_primitive_c->GetInputQuantParams().front()); | |||
| } | |||
| if (value_node == nullptr) { | |||
| MS_LOG(WARNING) << "value_node is null! " | |||
| @@ -87,13 +87,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||
| } | |||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| auto type = (schema::PrimitiveType)primitiveT_value->Type(); | |||
| auto type = (schema::PrimitiveType)primitive_c->Type(); | |||
| MS_LOG(INFO) << "Primitive type: " << type; | |||
| static const std::vector<schema::PrimitiveType> uint8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, | |||
| @@ -279,7 +279,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl | |||
| return RET_OK; | |||
| } | |||
| STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType, | |||
| STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, | |||
| int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { | |||
| auto dims = weight->tensor_shape(); | |||
| if (per_channel) { | |||
| @@ -450,7 +450,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||
| MS_LOG(ERROR) << "quant_params empty"; | |||
| return RET_ERROR; | |||
| } | |||
| primitiveT_value->AddInputQuantParam(quant_params); | |||
| primitive_c->AddInputQuantParam(quant_params); | |||
| return RET_OK; | |||
| } | |||
| @@ -118,7 +118,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan | |||
| }(); | |||
| } | |||
| STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType, | |||
| STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, | |||
| int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false, | |||
| bool depth_wise = false); | |||
| @@ -135,6 +135,26 @@ void FreeInputTensor(std::vector<Tensor *> *input_tensor) { | |||
| } | |||
| return; | |||
| } | |||
| schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| return nullptr; | |||
| } | |||
| auto *lite_primitive = primitive_c->GetPrimitiveT(); | |||
| if (lite_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive in primitive_c is nullptr"; | |||
| return nullptr; | |||
| } | |||
| flatbuffers::FlatBufferBuilder builder(1024); | |||
| auto offset = schema::Primitive::Pack(builder, lite_primitive); | |||
| builder.Finish(offset); | |||
| auto buf = builder.GetBufferPointer(); | |||
| auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf); | |||
| return const_cast<schema::Primitive *>(primitive); | |||
| } | |||
| const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| CheckIfFuncGraphIsNull(func_graph); | |||
| @@ -155,10 +175,16 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| } | |||
| MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); | |||
| auto output_nums = GetOutputTensorNum(input_cnode); | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| std::vector<Tensor *> output_tensors{output_nums, new Tensor()}; | |||
| primitiveT_value->InferShape(input_tensors, output_tensors); | |||
| auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get()); | |||
| auto scheam_primitive = PackPrimitiveT(input_cnode); | |||
| auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive); | |||
| if (lite_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr"; | |||
| FreeInputTensor(&input_tensors); | |||
| return nullptr; | |||
| } | |||
| lite_primitive->InferShape(input_tensors, output_tensors); | |||
| auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive); | |||
| if (lite_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; | |||
| FreeInputTensor(&input_tensors); | |||
| @@ -62,17 +62,17 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c | |||
| } | |||
| auto conv_node = pre_node->cast<CNodePtr>(); | |||
| auto node_type = GetCNodeType(conv_node); | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitiveT_value); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitive_c); | |||
| if (node_type == schema::PrimitiveType_Conv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetActivationType(activation_type); | |||
| return pre_node; | |||
| } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetActivationType(activation_type); | |||
| return pre_node; | |||
| @@ -160,22 +160,22 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons | |||
| auto conv_node = conv_node_anf->cast<CNodePtr>(); | |||
| CheckIfCNodeIsNull(conv_node); | |||
| GenConvNewBias(func_graph, conv_node, add_node); | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitiveT_value != nullptr); | |||
| auto type = primitiveT_value->Type(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| auto type = primitive_c->Type(); | |||
| if (type == schema::PrimitiveType_Conv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetHasBias(true); | |||
| } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetHasBias(true); | |||
| } else if (type == schema::PrimitiveType_DeConv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetHasBias(true); | |||
| } else { | |||
| @@ -115,14 +115,14 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern | |||
| AnfNodePtr bn_scale_node = nullptr; | |||
| AnfNodePtr bn_bias_node = nullptr; | |||
| float eps = 0; | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0)); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0)); | |||
| if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) { | |||
| bn_mean_node = bn_node->input(kCaffeBNMeanIndex); | |||
| bn_variance_node = bn_node->input(kCaffeBNVarIndex); | |||
| CheckIfNodeIsParam(bn_mean_node); | |||
| CheckIfNodeIsParam(bn_variance_node); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| eps = primc->GetEpsilon(); | |||
| } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { | |||
| @@ -130,8 +130,8 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern | |||
| bn_bias_node = bn_node->input(kTFBNBiasIndex); | |||
| bn_mean_node = bn_node->input(kTFBNMeanIndex); | |||
| bn_variance_node = bn_node->input(kTFBNVarIndex); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| eps = primc->GetEpsilon(); | |||
| } else { | |||
| @@ -97,17 +97,17 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||
| GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); | |||
| delete[] trans_bias; | |||
| delete[] trans_scale; | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitiveT_value != nullptr); | |||
| auto type = primitiveT_value->Type(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0)); | |||
| MS_ASSERT(primitive_c != nullptr); | |||
| auto type = primitive_c->Type(); | |||
| if (type == schema::PrimitiveType_Conv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetHasBias(true); | |||
| } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value); | |||
| MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c)); | |||
| auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c); | |||
| MS_ASSERT(primc != nullptr); | |||
| primc->SetHasBias(true); | |||
| } else { | |||