diff --git a/mindspore/lite/src/ops/populate/resize_populate.cc b/mindspore/lite/src/ops/populate/resize_populate.cc
index 4c9362826f..af6be62d77 100644
--- a/mindspore/lite/src/ops/populate/resize_populate.cc
+++ b/mindspore/lite/src/ops/populate/resize_populate.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,8 +32,8 @@ OpParameter *PopulateResizeParameter(const mindspore::lite::PrimitiveC *primitive) {
   resize_param->op_parameter_.type_ = primitive->Type();
   auto param = reinterpret_cast<mindspore::lite::Resize *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   resize_param->method_ = static_cast<int>(param->GetMethod());
-  resize_param->new_height_ = param->GetNewHeight();
-  resize_param->new_width_ = param->GetNewWidth();
+  resize_param->new_height_ = param->new_height();
+  resize_param->new_width_ = param->new_width();
   resize_param->coordinate_transform_mode_ = param->GetCoordinateTransformMode();
   resize_param->preserve_aspect_ratio_ = param->GetPreserveAspectRatio();
   return reinterpret_cast<OpParameter *>(resize_param);
diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc
index cf7748e0ba..daff37aef1 100644
--- a/mindspore/lite/src/ops/resize.cc
+++ b/mindspore/lite/src/ops/resize.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -122,6 +122,8 @@ Registry ResizeRegistry(schema::PrimitiveType_Resize, ResizeCreator);
 namespace {
 constexpr int kInputRank = 4;
 }  // namespace
+int64_t Resize::new_height() const { return new_height_; }
+int64_t Resize::new_width() const { return new_width_; }
 int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   MS_ASSERT(this->primitive_ != nullptr);
   auto input = inputs_.front();
@@ -145,15 +147,27 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
   std::vector<int> output_shape;
   output_shape.push_back(input->Batch());
-  if (inputs_.size() == kDoubleNum) {
-    auto shape_tensor = inputs_.at(1);
+  auto ret = CalculateNewHeightAndWidth(inputs_);
+  if (ret == RET_OK) {
+    output_shape.push_back(new_height_);
+    output_shape.push_back(new_width_);
+    output_shape.push_back(input->Channel());
+    output->set_shape(output_shape);
+  }
+  return ret;
+}
+
+int Resize::CalculateNewHeightAndWidth(const std::vector<lite::Tensor *> &inputs) {
+  auto input = inputs.front();
+  if (inputs.size() == kDoubleNum) {
+    auto shape_tensor = inputs.at(1);
     if (shape_tensor->data_c() == nullptr) {
       MS_LOG(INFO) << "Do infer shape in runtime.";
       return RET_INFER_INVALID;
     }
     size_t shape_size = shape_tensor->ElementsNum();
     switch (shape_size) {
-      case kInputRank: {
+      case kQuadrupleNum: {
         if (shape_tensor->data_type() == kNumberTypeInt32) {
           auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
           if (data == nullptr) {
@@ -162,12 +176,12 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
             return RET_INFER_INVALID;
           }
           switch (shape_tensor->format()) {
             case schema::Format_NCHW:
-              output_shape.push_back(data[2]);
-              output_shape.push_back(data[3]);
+              new_height_ = data[2];
+              new_width_ = data[3];
               break;
             case schema::Format_NHWC:
-              output_shape.push_back(data[1]);
-              output_shape.push_back(data[2]);
+              new_height_ = data[1];
+              new_width_ = data[2];
               break;
             default:
               MS_LOG(INFO) << "Resize don't support tensor format.";
@@ -181,12 +195,12 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
             return RET_INFER_INVALID;
           }
           switch (shape_tensor->format()) {
             case schema::Format_NCHW:
-              output_shape.push_back(data[2] * input->Height());
-              output_shape.push_back(data[3] * input->Width());
+              new_height_ = data[2] * input->Height();
+              new_width_ = data[3] * input->Width();
               break;
             case schema::Format_NHWC:
-              output_shape.push_back(data[1] * input->Height());
-              output_shape.push_back(data[2] * input->Width());
+              new_height_ = data[1] * input->Height();
+              new_width_ = data[2] * input->Width();
               break;
             default:
               MS_LOG(INFO) << "Resize don't support tensor format.";
@@ -195,36 +209,52 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
         }
       } break;
       case kDoubleNum: {
         auto data = reinterpret_cast<float *>(shape_tensor->data_c());
         if (data == nullptr) {
           MS_LOG(INFO) << "Resize op size can't cast float.";
           return RET_INFER_INVALID;
         }
-        for (size_t i = 0; i < shape_size; i++) {
-          output_shape.push_back(data[i]);
+        new_height_ = data[0];
+        new_width_ = data[1];
+        break;
+      }
+      case kSingleNum: {
+        // caffe zoom_factor
+        int scale;
+        if (shape_tensor->data_type() == kNumberTypeInt32) {
+          auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
+          if (data == nullptr) {
+            MS_LOG(INFO) << "Resize op size can't cast int.";
+            return RET_INFER_INVALID;
+          }
+          scale = data[0];
+        } else {
+          MS_LOG(ERROR) << "Unsupported data type:" << shape_tensor->data_type();
+          return RET_INFER_ERR;
         }
+        new_height_ = input->Height() + (input->Height() - 1) * (scale - 1);
+        new_width_ = input->Width() + (input->Width() - 1) * (scale - 1);
         break;
       }
+      default: {
+        MS_LOG(ERROR) << "Unsupported shape size:" << shape_size;
+        return RET_INFER_ERR;
+      }
     }
-  } else if (inputs_.size() == kSingleNum) {
-    auto new_height = GetNewHeight();
-    auto new_width = GetNewWidth();
-    output_shape.push_back(new_height);
-    output_shape.push_back(new_width);
-  } else if (inputs_.size() == kQuadrupleNum) {
-    if (inputs_[3]->data_c() == nullptr) {
+  } else if (inputs.size() == kSingleNum) {
+    new_height_ = GetNewHeight();
+    new_width_ = GetNewWidth();
+  } else if (inputs.size() == kQuadrupleNum) {
+    if (inputs[3]->data_c() == nullptr) {
       return RET_INFER_INVALID;
     }
-    output_shape.push_back(static_cast<int32_t *>(inputs_.at(3)->data_c())[0]);
-    output_shape.push_back(static_cast<int32_t *>(inputs_.at(3)->data_c())[1]);
+    new_height_ = static_cast<int32_t *>(inputs.at(3)->data_c())[0];
+    new_width_ = static_cast<int32_t *>(inputs.at(3)->data_c())[1];
   } else {
     MS_LOG(ERROR) << "inputs tensor size invalid.";
     return RET_INFER_ERR;
   }
-  output_shape.push_back(input->Channel());
-  output->set_shape(output_shape);
-
   return RET_OK;
 }
 }  // namespace lite
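Note on the resize.cc change above: the new kSingleNum case reproduces Caffe's Interp zoom_factor arithmetic, where each gap between adjacent input rows/columns gains zoom - 1 interpolated lines. A self-contained sketch of that formula (the function name is illustrative, not from the patch):

```cpp
#include <cassert>

// Caffe Interp zoom_factor size rule, matching CalculateNewHeightAndWidth above:
// out = in + (in - 1) * (zoom - 1). An input of size `in` has `in - 1` gaps, and
// each gap gains `zoom - 1` interpolated lines.
static int ZoomedSize(int in, int zoom) { return in + (in - 1) * (zoom - 1); }

int main() {
  assert(ZoomedSize(10, 1) == 10);  // zoom_factor 1 leaves the size unchanged
  assert(ZoomedSize(10, 2) == 19);  // 9 gaps, 1 extra line each
  assert(ZoomedSize(4, 3) == 10);   // 3 gaps, 2 extra lines each
  return 0;
}
```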
diff --git a/mindspore/lite/src/ops/resize.h b/mindspore/lite/src/ops/resize.h
index ec6846797f..275ac07399 100644
--- a/mindspore/lite/src/ops/resize.h
+++ b/mindspore/lite/src/ops/resize.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,14 @@ class Resize : public PrimitiveC {
   int64_t GetNewWidth() const;
   bool GetPreserveAspectRatio() const;
   int GetCoordinateTransformMode() const;
+
+  int64_t new_height() const;
+  int64_t new_width() const;
+
+ private:
+  int CalculateNewHeightAndWidth(const std::vector<lite::Tensor *> &inputs);
+  int64_t new_height_;
+  int64_t new_width_;
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc
index f8f4641b6d..542a772028 100644
--- a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc
+++ b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc
@@ -78,7 +78,8 @@ domi::ModelBufferData *SubGraphNpuKernel::BuildIRModel() {
 }
 
 int SubGraphNpuKernel::Run() {
-  return reinterpret_cast<lite::NPUExecutor *>(this->executor_)->Run(in_tensors_, out_tensors_, out_nodes_, nodes_);
+  return reinterpret_cast<lite::NPUExecutor *>(this->executor_)
+    ->Run(in_tensors_, out_tensor_sorted_, out_nodes_, nodes_);
 }
 
 int SubGraphNpuKernel::BuildNPUInputOp() {
@@ -156,6 +157,14 @@ std::vector<ge::Operator> SubGraphNpuKernel::GetNPUNodes(const vector<kernel::LiteKernel *> &nodes) {
 }
 
 int SubGraphNpuKernel::BuildNPUOutputOp() {
+  out_tensor_sorted_ = std::vector<lite::Tensor *>(out_tensors_.size());
+  int i = 0;
+  for (auto node : out_nodes_) {
+    for (auto tensor : node->out_tensors()) {
+      if (std::find(out_tensors_.begin(), out_tensors_.end(), tensor) != out_tensors_.end())
+        this->out_tensor_sorted_[i++] = tensor;
+    }
+  }
   if (subgraph_output_op_.empty()) {
     MS_LOG(ERROR) << "NPU subgraph output op is empty.";
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.h b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.h
index 513a2b4147..5aa5f5adac 100644
--- a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.h
+++ b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.h
@@ -74,6 +74,8 @@ class SubGraphNpuKernel : public SubGraphKernel {
 
   std::vector<ge::Operator> subgraph_input_op_;
   std::vector<ge::Operator> subgraph_output_op_;
+
+  std::vector<lite::Tensor *> out_tensor_sorted_;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_AGENT_SUBGRAPH_NPU_KERNEL_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc
index 73c3c1f22f..a43f0145b4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc
@@ -62,9 +62,9 @@ int ResizeBaseCPUKernel::CheckParameters() {
       MS_LOG(INFO) << "Out shape is not assigned";
       const_shape_ = false;
     } else {
-      auto ret = CalculateLinearNewHeightWidth();
-      if (ret != RET_OK) {
-        return ret;
+      if (InferShapeDone()) {
+        new_height_ = out_tensors_.at(0)->shape().at(1);
+        new_width_ = out_tensors_.at(0)->shape().at(2);
       }
       const_shape_ = true;
     }
@@ -78,52 +78,6 @@ int ResizeBaseCPUKernel::CheckParameters() {
   return RET_OK;
 }
 
-int ResizeBaseCPUKernel::CalculateLinearNewHeightWidth() {
-  if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR)) {
-    return RET_OK;
-  }
-  if (in_tensors_.size() != 2) {
-    return RET_ERROR;
-  }
-  auto input_tensor = in_tensors_.at(0);
-  auto shape_scale_tensor = in_tensors_.at(1);
-  if (shape_scale_tensor->data_type() == kNumberTypeFloat32) {
-    // float type means scale
-    float *shape_scale = reinterpret_cast<float *>(shape_scale_tensor->data_c());
-    if (shape_scale == nullptr) {
-      return RET_ERROR;
-    }
-    if (shape_scale_tensor->format() == schema::Format_NHWC) {
-      new_height_ = input_tensor->Height() * shape_scale[1];
-      new_width_ = input_tensor->Width() * shape_scale[2];
-    } else if (shape_scale_tensor->format() == schema::Format_NCHW) {
-      new_height_ = input_tensor->Height() * shape_scale[2];
-      new_width_ = input_tensor->Width() * shape_scale[3];
-    } else {
-      MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
-      return RET_ERROR;
-    }
-  } else if (shape_scale_tensor->data_type() == kNumberTypeInt32) {
-    // int32 type means real shape
-    int32_t *shape_data = reinterpret_cast<int32_t *>(shape_scale_tensor->data_c());
-    if (shape_data == nullptr) {
-      return RET_ERROR;
-    }
-    if (shape_scale_tensor->format() == schema::Format_NHWC) {
-      new_height_ = shape_data[1];
-      new_width_ = shape_data[2];
-    } else if (shape_scale_tensor->format() == schema::Format_NCHW) {
-      new_height_ = shape_data[2];
-      new_width_ = shape_data[3];
-    } else {
-      MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
-      return RET_ERROR;
-    }
-  }
-
-  return RET_OK;
-}
-
 int ResizeBaseCPUKernel::CheckInputsOuputs() {
   if (in_tensors_.size() <= lite::kQuadrupleNum) {
     for (size_t i = 0; i < in_tensors_.size(); i++) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h
index 27a6c03bb3..4ec58c94c7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,7 +47,6 @@ class ResizeBaseCPUKernel : public LiteKernel {
  private:
   int CheckParameters();
   int CheckInputsOuputs();
-  int CalculateLinearNewHeightWidth();
 };
 }  // namespace mindspore::kernel
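The resize_base.cc change above drops the LINEAR-only recomputation and instead trusts shape inference: once InferShape has run, the output tensor's NHWC shape already carries the resolved height and width. A minimal sketch of that assumption (TensorStub is a hypothetical stand-in for lite::Tensor):

```cpp
#include <vector>

// Hypothetical stand-in for lite::Tensor; only the inferred shape matters here.
struct TensorStub {
  std::vector<int> shape;  // NHWC: {N, H, W, C}
};

// Mirrors the new CheckParameters() branch: read H and W straight from the
// inferred output shape instead of re-parsing the size/scale input tensor.
void ReadNewSize(const TensorStub &out, int *new_height, int *new_width) {
  *new_height = out.shape.at(1);  // H in NHWC
  *new_width = out.shape.at(2);   // W in NHWC
}
```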
diff --git a/mindspore/lite/src/runtime/kernel/npu/scale_npu.cc b/mindspore/lite/src/runtime/kernel/npu/scale_npu.cc
index b31b9034e8..6ac66e4b43 100644
--- a/mindspore/lite/src/runtime/kernel/npu/scale_npu.cc
+++ b/mindspore/lite/src/runtime/kernel/npu/scale_npu.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -44,7 +44,9 @@ int ScaleNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
   op_->set_attr_axis(scale_parameter_->axis_);
   op_->set_input_x(*npu_inputs[0]);
   op_->set_input_scale(*npu_inputs[1]);
-  op_->set_input_bias(*npu_inputs[2]);
+  if (npu_inputs[2] != nullptr) {
+    op_->set_input_bias(*npu_inputs[2]);
+  }
   return RET_OK;
 }
 
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 9cf5ba5cfe..c0fa2d8f7e 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -218,6 +218,7 @@ if(ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/graph/if_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
+        ${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
         )
 endif()
 ### train
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
index 6e0c50920d..01b880ce2b 100644
--- a/mindspore/lite/tools/converter/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/CMakeLists.txt
@@ -69,6 +69,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/graph/mindir_inputs_adjust_pass.cc
         ../optimizer/graph/functionalize_control_op_pass.cc
         ../optimizer/graph/functionalize_while.cc
+        ../optimizer/graph/inputs_adjust_pass.cc
         )
 
 add_subdirectory(../anf_importer anf_importer)
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 62547ce35e..e01af3355c 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@
 #include "tools/optimizer/graph/while_pass.h"
 #include "tools/optimizer/graph/if_pass.h"
 #include "tools/optimizer/graph/functionalize_control_op_pass.h"
+#include "tools/optimizer/graph/inputs_adjust_pass.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
 #include "tools/converter/quantizer/huffman_encode.h"
@@ -124,6 +125,7 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
   auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
   slice_prepose_pass->SetFmkType(config->fmk);
   graph_pm->AddPass(slice_prepose_pass);
+  graph_pm->AddPass(std::make_shared<opt::InputAdjustPass>());
   optimizer->AddPassManager(graph_pm);
   return RET_OK;
 }
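The scale_npu.cc hunk above makes the third Scale input optional, since Caffe Scale layers may omit the bias term. The shape of that guard in isolation (OperatorStub stands in for the NPU operator wrapper; the extra size check is a precaution, not part of the committed fix):

```cpp
#include <vector>

// Illustrative stand-in for the NPU operator wrapper used by ScaleNPUKernel.
struct OperatorStub {
  void set_input_bias(const OperatorStub &bias) { /* wire the optional input */ }
};

void WireScaleInputs(OperatorStub *op, const std::vector<OperatorStub *> &npu_inputs) {
  // Bias is optional: only connect it when the converter actually produced one.
  if (npu_inputs.size() > 2 && npu_inputs[2] != nullptr) {
    op->set_input_bias(*npu_inputs[2]);
  }
}
```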
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc
index c3ddb94aa6..afc3420606 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,9 +27,9 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &proto,
     return nullptr;
   }
 
-  const caffe::InterpParameter &interpParam = proto.interp_param();
-  if (interpParam.has_height()) {
-    int64_t height = interpParam.height();
+  const caffe::InterpParameter &interp_param = proto.interp_param();
+  if (interp_param.has_height()) {
+    int64_t height = interp_param.height();
     if (height < 0) {
       MS_LOG(ERROR) << "Interp height must be > 0";
       return nullptr;
@@ -37,8 +37,8 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &proto,
     attr->newHeight = height;
   }
 
-  if (interpParam.has_width()) {
-    int64_t width = interpParam.width();
+  if (interp_param.has_width()) {
+    int64_t width = interp_param.width();
     if (width < 0) {
       MS_LOG(ERROR) << "Interp width must be > 0";
       return nullptr;
@@ -50,7 +50,11 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &proto,
   auto primitive = std::make_unique<schema::PrimitiveT>();
   primitive->value.type = schema::PrimitiveType_Resize;
   primitive->value.value = attr.release();
-  return PrimitiveC::Create(primitive.release());
+  auto primitive_c = PrimitiveC::Create(primitive.release());
+  if (interp_param.has_zoom_factor()) {
+    primitive_c->AddAttr("zoom_factor", MakeValue(interp_param.zoom_factor()));
+  }
+  return primitive_c;
 }
 
 CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser());
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc
index 082f621e29..0726aa010a 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,14 +34,15 @@ PrimitiveC *CaffeScaleParser::ParseLitePrimitive(const caffe::LayerParameter &proto,
   }
 
   const caffe::ScaleParameter &scaleParam = weight.scale_param();
+  attr->axis = 1;
   if (scaleParam.has_axis()) {
     uint32_t axis_index = 1;
     if (GetAxisIndex(scaleParam.axis(), &axis_index)) {
       MS_LOG(ERROR) << "scale get axis failed for layer " << weight.name().c_str();
       return nullptr;
     }
+    attr->axis = axis_index;
   }
-  attr->axis = 1;
   auto primitive = std::make_unique<schema::PrimitiveT>();
   primitive->value.type = schema::PrimitiveType_Scale;
   primitive->value.value = attr.release();
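The Interp parser above now stashes zoom_factor on the primitive as a named attribute instead of widening the schema; a later stage (presumably the new inputs-adjust pass registered in anf_transform.cc, or the Resize infer-shape path) reads it back. A self-contained sketch of that attribute round trip (PrimitiveStub and its map are illustrative; the real code boxes the value through MakeValue/ValuePtr):

```cpp
#include <map>
#include <string>

// Illustrative stand-in for the primitive's attribute storage.
struct PrimitiveStub {
  std::map<std::string, int> attrs;
  void AddAttr(const std::string &key, int value) { attrs[key] = value; }
  int GetAttrOr(const std::string &key, int fallback) const {
    auto it = attrs.find(key);
    return it == attrs.end() ? fallback : it->second;
  }
};

int main() {
  PrimitiveStub prim;
  prim.AddAttr("zoom_factor", 2);  // what CaffeInterpParser stores
  return prim.GetAttrOr("zoom_factor", 1) == 2 ? 0 : 1;  // what a consumer reads back
}
```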
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 31221f6068..c1d18337a6 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,11 @@
  */
 #include "tools/optimizer/common/gllo_utils.h"
 #include <vector>
-#include <memory>
 #include <algorithm>
+#include <memory>
+#include <unordered_map>
+#include <string>
+#include <functional>
 #include "src/ops/primitive_c.h"
 #include "src/common/common.h"
 #include "frontend/operator/ops.h"
@@ -120,6 +123,19 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMapPtr primitive_vars) {
   }
 }  // namespace
 
+bool CheckInputs(const CNodePtr &cnode) {
+  if (cnode == nullptr) {
+    MS_LOG(ERROR) << "cnode is nullptr.";
+    return false;
+  }
+  if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
+                  [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
+    MS_LOG(ERROR) << "input is nullptr.";
+    return false;
+  }
+  return true;
+}
+
 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
   if (node == nullptr) {
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@@ -136,6 +152,55 @@ bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
   return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
 }
 
+bool AnfEqualPrimitive(AnfNodePtr a_node, AnfNodePtr b_node) {
+  auto a_value_node = a_node->cast<ValueNodePtr>();
+  auto b_value_node = b_node->cast<ValueNodePtr>();
+  if (a_value_node == nullptr || b_value_node == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+
+  auto a_value = a_value_node->value();
+  auto b_value = b_value_node->value();
+  if (a_value == nullptr || b_value == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+
+  auto a_prim = a_value->cast<PrimitivePtr>();
+  auto b_prim = b_value->cast<PrimitivePtr>();
+  if (a_prim == nullptr || b_prim == nullptr) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+  return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
+}
+
+bool AnfEqualValueNode(AnfNodePtr a_node, AnfNodePtr b_node) {
+  auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
+  auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
+  if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
+    MS_LOG(ERROR) << "cast value node ptr fail";
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+  auto a_value_ptr = a_value_node_ptr->value();
+  auto b_value_ptr = b_value_node_ptr->value();
+  if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
+    MS_LOG(ERROR) << "value ptr is nullptr";
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return false;
+  }
+
+  if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
+    auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
+    auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
+    return (*a_obj) == (*b_obj);
+  } else {
+    return (*a_value_ptr) == (*b_value_ptr);
+  }
+}
+
 bool AnfEqual(const BaseRef &a, const BaseRef &b) {
   if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
     auto a_node = utils::cast<AnfNodePtr>(a);
@@ -145,49 +210,10 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
     auto b_node = utils::cast<AnfNodePtr>(b);
     if (a_node == nullptr || b_node == nullptr) {
       lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
       return false;
     }
     if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
-      auto a_value_node = a_node->cast<ValueNodePtr>();
-      auto b_value_node = b_node->cast<ValueNodePtr>();
-      if (a_value_node == nullptr || b_value_node == nullptr) {
-        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return false;
-      }
-
-      auto a_value = a_value_node->value();
-      auto b_value = b_value_node->value();
-      if (a_value == nullptr || b_value == nullptr) {
-        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return false;
-      }
-
-      auto a_prim = a_value->cast<PrimitivePtr>();
-      auto b_prim = b_value->cast<PrimitivePtr>();
-      if (a_prim == nullptr || b_prim == nullptr) {
-        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return false;
-      }
-      return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
-    } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
-      auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
-      auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
-      if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
-        MS_LOG(ERROR) << "cast value node ptr fail";
-        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return false;
-      }
-      auto a_value_ptr = a_value_node_ptr->value();
-      auto b_value_ptr = b_value_node_ptr->value();
-      if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
-        MS_LOG(ERROR) << "value ptr is nullptr";
-        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-        return false;
-      }
-      if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
-        auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
-        auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
-        return (*a_obj) == (*b_obj);
-      } else {
-        return (*a_value_ptr) == (*b_value_ptr);
-      }
+      return AnfEqualPrimitive(a_node, b_node);
+    }
+    if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
+      return AnfEqualValueNode(a_node, b_node);
     }
   }
   if (a.m_ptr->isa<Primitive>() && b.m_ptr->isa<Primitive>()) {
@@ -639,70 +665,266 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const AnfNodePtr &node,
 STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                     int32_t *filterH, int32_t *filterW) {
   MS_ASSERT(oriDims.size() == 4);
-  if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
-    *filterK = oriDims.at(lite::KCHW_K);
-    *filterC = oriDims.at(lite::KCHW_C);
-    *filterH = oriDims.at(lite::KCHW_H);
-    *filterW = oriDims.at(lite::KCHW_W);
-  } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
-    *filterC = oriDims.at(lite::CKHW_C);
-    *filterK = oriDims.at(lite::CKHW_K);
-    *filterH = oriDims.at(lite::CKHW_H);
-    *filterW = oriDims.at(lite::CKHW_W);
-  } else if (type == kHWCK2KCHW || type == kHWCK2CKHW || type == kHWCK2KHWC) {
-    *filterH = oriDims.at(lite::HWCK_H);
-    *filterW = oriDims.at(lite::HWCK_W);
-    *filterC = oriDims.at(lite::HWCK_C);
-    *filterK = oriDims.at(lite::HWCK_K);
-  } else if (type == kHWKC2KCHW || type == kHWKC2CKHW || type == kHWKC2KHWC) {
-    *filterH = oriDims.at(lite::HWKC_H);
-    *filterW = oriDims.at(lite::HWKC_W);
-    *filterK = oriDims.at(lite::HWKC_K);
-    *filterC = oriDims.at(lite::HWKC_C);
-  } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
-    *filterK = oriDims.at(lite::NHWC_N);
-    *filterH = oriDims.at(lite::NHWC_H);
-    *filterW = oriDims.at(lite::NHWC_W);
-    *filterC = oriDims.at(lite::NHWC_C);
-  } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
-    *filterC = oriDims.at(lite::CHWK_C);
-    *filterH = oriDims.at(lite::CHWK_H);
-    *filterW = oriDims.at(lite::CHWK_W);
-    *filterK = oriDims.at(lite::CHWK_K);
-  } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
-    *filterK = oriDims.at(lite::KHWC_K);
-    *filterH = oriDims.at(lite::KHWC_H);
-    *filterW = oriDims.at(lite::KHWC_W);
-    *filterC = oriDims.at(lite::KHWC_C);
-  } else {
+  std::unordered_map<kTransFilterType, int> maps = {
+      {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2},
+      {kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3},
+      {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5},
+      {kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6},
+      {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
+  };
+  if (maps.find(type) == maps.end()) {
     MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
     return RET_ERROR;
   }
+  switch (maps.find(type)->second) {
+    case 1:
+      *filterK = oriDims.at(lite::KCHW_K);
+      *filterC = oriDims.at(lite::KCHW_C);
+      *filterH = oriDims.at(lite::KCHW_H);
+      *filterW = oriDims.at(lite::KCHW_W);
+      break;
+    case 2:
+      *filterC = oriDims.at(lite::CKHW_C);
+      *filterK = oriDims.at(lite::CKHW_K);
+      *filterH = oriDims.at(lite::CKHW_H);
+      *filterW = oriDims.at(lite::CKHW_W);
+      break;
+    case 3:
+      *filterH = oriDims.at(lite::HWCK_H);
+      *filterW = oriDims.at(lite::HWCK_W);
+      *filterC = oriDims.at(lite::HWCK_C);
+      *filterK = oriDims.at(lite::HWCK_K);
+      break;
+    case 4:
+      *filterH = oriDims.at(lite::HWKC_H);
+      *filterW = oriDims.at(lite::HWKC_W);
+      *filterK = oriDims.at(lite::HWKC_K);
+      *filterC = oriDims.at(lite::HWKC_C);
+      break;
+    case 5:
+      *filterK = oriDims.at(lite::NHWC_N);
+      *filterH = oriDims.at(lite::NHWC_H);
+      *filterW = oriDims.at(lite::NHWC_W);
+      *filterC = oriDims.at(lite::NHWC_C);
+      break;
+    case 6:
+      *filterC = oriDims.at(lite::CHWK_C);
+      *filterH = oriDims.at(lite::CHWK_H);
+      *filterW = oriDims.at(lite::CHWK_W);
+      *filterK = oriDims.at(lite::CHWK_K);
+      break;
+    case 7:
+      *filterK = oriDims.at(lite::KHWC_K);
+      *filterH = oriDims.at(lite::KHWC_H);
+      *filterW = oriDims.at(lite::KHWC_W);
+      *filterC = oriDims.at(lite::KHWC_C);
+      break;
+    default:
+      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
+      return RET_ERROR;
+  }
   return RET_OK;
 }
 
 STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                     int32_t filterH, int32_t filterW) {
   MS_ASSERT(tensor != nullptr);
-  if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
-    tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
-  } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
-    tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
-  } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
-    tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
-  } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
-    tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
-  } else if (type == kKHWC2CHWK) {
-    tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
-  } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC ||
-             type == kHWKC2KHWC) {
-    tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
-  } else {
+  std::unordered_map<kTransFilterType, int> maps = {
+      {kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1},
+      {kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3},
+      {kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5},
+      {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
+  };
+  if (maps.find(type) == maps.end()) {
     MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
     return RET_ERROR;
   }
+
+  switch (maps.find(type)->second) {
+    case 1:
+      tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
+      break;
+    case 2:
+      tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
+      break;
+    case 3:
+      tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
+      break;
+    case 4:
+      tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
+      break;
+    case 5:
+      tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
+      break;
+    case 6:
+      tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
+      break;
+    default:
+      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
+      return RET_ERROR;
+  }
   return RET_OK;
 }
+
+template <typename T>
+void TransFilterDataCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int c = 0; c < filterC; ++c) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
+          if (type == kCHWK2HWCK) {
+            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kCHWK2KHWC) {
+            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+template <typename T>
+void TransFilterDataKHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void TransFilterDataKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int k = 0; k < filterK; ++k) {
+    for (int c = 0; c < filterC; ++c) {
+      for (int h = 0; h < filterH; ++h) {
+        for (int w = 0; w < filterW; ++w) {
+          p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          if (type == kKCHW2HWCK) {
+            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kKCHW2KHWC) {
+            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          } else if (type == kKCHW2CKHW) {
+            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void TransFilterDataCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int c = 0; c < filterC; ++c) {
+    for (int k = 0; k < filterK; ++k) {
+      for (int h = 0; h < filterH; ++h) {
+        for (int w = 0; w < filterW; ++w) {
+          p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          if (type == kCKHW2HWCK) {
+            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kCKHW2KHWC) {
+            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          } else {
+            p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+template <typename T>
+void TransFilterDataHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int h = 0; h < filterH; ++h) {
+    for (int w = 0; w < filterW; ++w) {
+      for (int c = 0; c < filterC; ++c) {
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          if (type == kHWCK2KCHW) {
+            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          } else if (type == kHWCK2CKHW) {
+            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void TransFilterDataHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int h = 0; h < filterH; ++h) {
+    for (int w = 0; w < filterW; ++w) {
+      for (int c = 0; c < filterC; ++c) {
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
+          if (type == kHWKC2KCHW) {
+            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void TransFilterDataNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                         T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
+          if (type == kNHWC2HWCK) {
+            p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kNHWC2CKHW) {
+            p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void TransFilterDataKHWC2CHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
+                              int32_t filterW, T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+
 template <typename T>
 static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
                               int32_t filterH, int32_t filterW) {
@@ -727,173 +949,41 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
   }
   T *p1Buff = nullptr;
   T *p2Buff = nullptr;
-  switch (type) {
-    case kCHWK2HWCK:
-    case kCHWK2KHWC: {
-      for (int c = 0; c < filterC; ++c) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
-              if (type == kCHWK2HWCK) {
-                p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kCHWK2KHWC) {
-                p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+
+  std::unordered_map<kTransFilterType, int> maps = {
+      {kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3},
+      {kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4},
+      {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6},
+      {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
+  };
+  if (maps.find(type) == maps.end()) {
+    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
+    return RET_ERROR;
+  }
+  switch (maps.find(type)->second) {
+    case 1: {
+      TransFilterDataCHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kKHWC2HWCK: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 2: {
+      TransFilterDataKHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kKCHW2HWCK:
-    case kKCHW2CKHW:
-    case kKCHW2KHWC:
-    case kKCHW2HWKC: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int c = 0; c < filterC; ++c) {
-          for (int h = 0; h < filterH; ++h) {
-            for (int w = 0; w < filterW; ++w) {
-              p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              if (type == kKCHW2HWCK) {
-                p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kKCHW2KHWC) {
-                p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              } else if (type == kKCHW2CKHW) {
-                p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 3: {
+      TransFilterDataKCHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kCKHW2HWCK:
-    case kCKHW2KHWC:
-    case kCKHW2HWKC: {
-      for (int c = 0; c < filterC; ++c) {
-        for (int k = 0; k < filterK; ++k) {
-          for (int h = 0; h < filterH; ++h) {
-            for (int w = 0; w < filterW; ++w) {
-              p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              if (type == kCKHW2HWCK) {
-                p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kCKHW2KHWC) {
-                p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              } else {
-                p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 4: {
+      TransFilterDataCKHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kHWCK2KCHW:
-    case kHWCK2CKHW:
-    case kHWCK2KHWC: {
-      for (int h = 0; h < filterH; ++h) {
-        for (int w = 0; w < filterW; ++w) {
-          for (int c = 0; c < filterC; ++c) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              if (type == kHWCK2KCHW) {
-                p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              } else if (type == kHWCK2CKHW) {
-                p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 5: {
+      TransFilterDataHWCK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kHWKC2KCHW:
-    case kHWKC2KHWC:
-    case kHWKC2CKHW: {
-      for (int h = 0; h < filterH; ++h) {
-        for (int w = 0; w < filterW; ++w) {
-          for (int c = 0; c < filterC; ++c) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
-              if (type == kHWKC2KCHW) {
-                p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 6: {
+      TransFilterDataHWKC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kNHWC2HWCK:
-    case kNHWC2KCHW:
-    case kNHWC2CKHW: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
-              if (type == kNHWC2HWCK) {
-                p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kNHWC2CKHW) {
-                p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 7: {
+      TransFilterDataNHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
-    case kKHWC2CHWK: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+    case 8: {
+      TransFilterDataKHWC2CHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
     } break;
     default: {
       MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@@ -941,6 +1031,22 @@ static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
   return lite::RET_OK;
 }
 
+STATUS TransFilterFormatWithType(const ParamValueLitePtr &tensor, TypeId data_type,
+                                 kTransFilterType trans_filter_type) {
+  if (data_type == kNumberTypeFloat32) {
+    return TransFilterFormat<float>(tensor, trans_filter_type);
+  } else if (data_type == kNumberTypeUInt8) {
+    return TransFilterFormat<uint8_t>(tensor, trans_filter_type);
+  } else if (data_type == kNumberTypeInt8) {
+    return TransFilterFormat<int8_t>(tensor, trans_filter_type);
+  } else if (data_type == kNumberTypeFloat16) {
+    return TransFilterFormat<float16>(tensor, trans_filter_type);
+  } else {
+    MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
+    return RET_ERROR;
+  }
+}
+
 STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
   if (tensor == nullptr) {
     return lite::RET_NULL_PTR;
@@ -953,302 +1059,78 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
   auto src_format = tensor->format();
   auto data_type = tensor->tensor_type();
   lite::STATUS status;
+  std::unordered_map<schema::Format, kTransFilterType> khwc_trans_maps = {
+      {schema::Format::Format_KCHW, kKCHW2KHWC}, {schema::Format::Format_CKHW, kCKHW2KHWC},
+      {schema::Format::Format_CHWK, kCHWK2KHWC}, {schema::Format::Format_HWCK, kHWCK2KHWC},
+      {schema::Format::Format_HWKC, kHWKC2KHWC},
+  };
+  std::unordered_map<schema::Format, kTransFilterType> hwck_trans_maps = {
+      {schema::Format::Format_KCHW, kKCHW2HWCK},
+      {schema::Format::Format_KHWC, kKHWC2HWCK},
+      {schema::Format::Format_CKHW, kCKHW2HWCK},
+      {schema::Format::Format_CHWK, kCHWK2HWCK},
+  };
+  std::unordered_map<schema::Format, kTransFilterType> kchw_trans_maps = {
+      {schema::Format::Format_HWCK, kHWCK2KCHW}, {schema::Format::Format_HWKC, kHWKC2KCHW},
+      {schema::Format::Format_KHWC, kKHWC2KCHW}, {schema::Format::Format_CKHW, kCKHW2KCHW},
+      {schema::Format::Format_CHWK, kCHWK2KCHW},
+  };
+  std::unordered_map<schema::Format, kTransFilterType> ckhw_trans_maps = {{schema::Format::Format_HWCK, kHWCK2CKHW},
+                                                                          {schema::Format::Format_HWKC, kHWKC2CKHW},
+                                                                          {schema::Format::Format_KCHW, kKCHW2CKHW}};
+  std::unordered_map<schema::Format, kTransFilterType> chwk_trans_maps = {{schema::Format::Format_KHWC, kKHWC2CHWK}};
+  if (src_format == dst_format) {
+    return RET_OK;
+  }
   switch (dst_format) {
     case schema::Format::Format_KHWC: {
-      switch (src_format) {
-        case schema::Format::Format_KCHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKCHW2KHWC);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CKHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCKHW2KHWC);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CHWK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCHWK2KHWC);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_HWCK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWCK2KHWC);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWCK2KHWC);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWCK2KHWC);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWCK2KHWC);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_HWKC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWKC2KHWC);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWKC2KHWC);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWKC2KHWC);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWKC2KHWC);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_KHWC:
-          return RET_OK;
-        default:
-          MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
-          return RET_ERROR;
+      if (khwc_trans_maps.find(static_cast<schema::Format>(src_format)) == khwc_trans_maps.end()) {
+        MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
+                      << " to " << EnumNameFormat(dst_format);
+        return RET_ERROR;
+      } else {
+        status = TransFilterFormatWithType(tensor, data_type,
+                                           khwc_trans_maps.find(static_cast<schema::Format>(src_format))->second);
       }
     } break;
     case schema::Format::Format_HWCK: {
-      switch (src_format) {
-        case schema::Format::Format_KCHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKCHW2HWCK);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_KHWC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKHWC2HWCK);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CKHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCKHW2HWCK);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CHWK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCHWK2HWCK);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return lite::RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_HWCK:
-          return RET_OK;
-        default:
-          MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
-          return RET_ERROR;
+      if (hwck_trans_maps.find(static_cast<schema::Format>(src_format)) == hwck_trans_maps.end()) {
+        MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
+                      << " to " << EnumNameFormat(dst_format);
+        return RET_ERROR;
+      } else {
+        status = TransFilterFormatWithType(tensor, data_type,
+                                           hwck_trans_maps.find(static_cast<schema::Format>(src_format))->second);
       }
     } break;
     case schema::Format::Format_KCHW: {
-      switch (src_format) {
-        case schema::Format::Format_KCHW:
-          return RET_OK;
-        case schema::Format::Format_HWCK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWCK2KCHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_HWKC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWCK2KCHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_KHWC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKHWC2KCHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CKHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCKHW2KCHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CHWK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kCKHW2KCHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        default:
-          MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
-          return RET_ERROR;
+      if (kchw_trans_maps.find(static_cast<schema::Format>(src_format)) == kchw_trans_maps.end()) {
+        MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
+                      << " to " << EnumNameFormat(dst_format);
+        return RET_ERROR;
+      } else {
+        status = TransFilterFormatWithType(tensor, data_type,
+                                           kchw_trans_maps.find(static_cast<schema::Format>(src_format))->second);
       }
     } break;
     case schema::Format::Format_CKHW: {
-      switch (src_format) {
-        case schema::Format::Format_HWCK:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWCK2CKHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_HWKC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kHWKC2CKHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_KCHW:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKCHW2CKHW);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CKHW:
-          return RET_OK;
-        default:
-          MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
-          return RET_ERROR;
+      if (ckhw_trans_maps.find(static_cast<schema::Format>(src_format)) == ckhw_trans_maps.end()) {
+        MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
+                      << " to " << EnumNameFormat(dst_format);
+        return RET_ERROR;
+      } else {
+        status = TransFilterFormatWithType(tensor, data_type,
+                                           ckhw_trans_maps.find(static_cast<schema::Format>(src_format))->second);
       }
     } break;
     case schema::Format::Format_CHWK: {
-      switch (src_format) {
-        case schema::Format::Format_KHWC:
-          if (data_type == kNumberTypeFloat32) {
-            status = TransFilterFormat<float>(tensor, kKHWC2CHWK);
-          } else if (data_type == kNumberTypeUInt8) {
-            status = TransFilterFormat<uint8_t>(tensor, kKHWC2CHWK);
-          } else if (data_type == kNumberTypeInt8) {
-            status = TransFilterFormat<int8_t>(tensor, kKHWC2CHWK);
-          } else if (data_type == kNumberTypeFloat16) {
-            status = TransFilterFormat<float16>(tensor, kKHWC2CHWK);
-          } else {
-            MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
-            return RET_ERROR;
-          }
-          break;
-        case schema::Format::Format_CHWK:
-          return RET_OK;
-        default:
-          MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
-          return RET_ERROR;
+      if (chwk_trans_maps.find(static_cast<schema::Format>(src_format)) == chwk_trans_maps.end()) {
+        MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
+                      << " to " << EnumNameFormat(dst_format);
+
+ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
+                                        const std::string &node_name) {
+  MS_ASSERT(func_graph != nullptr);
+  auto param_node = func_graph->add_parameter();
+
+  auto type_ptr = TypeIdToType(kNumberTypeInt32);
+  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr);
+  param_node->set_abstract(abstract_tensor);
+  param_node->set_name(node_name);
+
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  MS_ASSERT(param_value != nullptr);
+  param_value->set_tensor_shape({1});
+  param_value->set_tensor_type(kNumberTypeInt32);
+
+  char *default_data = new (std::nothrow) char[sizeof(int32_t)];
+  if (default_data == nullptr) {
+    MS_LOG(ERROR) << "new data failed.";
+    return nullptr;
+  }
+  *(reinterpret_cast<int32_t *>(default_data)) = data;
+  param_value->SetTensorData(default_data, sizeof(int32_t));
+  param_node->set_default_param(param_value);
+  return param_node;
+}
+
+ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
+                                      const std::string &node_name) {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(data.size() != 0);
+  auto param_node = func_graph->add_parameter();
+
+  auto type_ptr = TypeIdToType(kNumberTypeInt32);
+  std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
+  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  param_node->set_abstract(abstract_tensor);
+  param_node->set_name(node_name);
+
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  MS_ASSERT(param_value != nullptr);
+  std::vector<int> shape{static_cast<int>(data.size())};
+  param_value->set_tensor_shape(shape);
+  param_value->set_tensor_type(kNumberTypeInt32);
+  char *default_data = new (std::nothrow) char[data.size() * sizeof(int32_t)];
+  if (memcpy_s(default_data, data.size() * sizeof(int32_t), data.data(), data.size() * sizeof(int32_t)) != EOK) {
+    MS_LOG(ERROR) << "memcpy data failed.";
+    delete[] default_data;
+    return nullptr;
+  }
+  param_value->SetTensorData(default_data, data.size() * sizeof(int32_t));
+  param_node->set_default_param(param_value);
+  return param_node;
+}
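A hypothetical call site for the two builders above; the attribute values and the "_axis"/"_size" name suffixes are illustrative, not taken from this patch:

    // Turn a scalar and a vector value into constant Parameter inputs of a cnode.
    auto prefix = cnode->fullname_with_scope();
    ParameterPtr axis_param = BuildIntValueParameterNode(func_graph, 2, prefix + "_axis");
    ParameterPtr size_param = BuildIntVecParameterNode(func_graph, {224, 224}, prefix + "_size");
    auto inputs = cnode->inputs();
    inputs.push_back(axis_param);  // each builder returns a Parameter carrying a default ParamValueLite
    inputs.push_back(size_param);
    cnode->set_inputs(inputs);

Callers should also null-check the returned ParameterPtr, since the builders return nullptr on allocation or memcpy_s failure.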
+
+ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
+                                        const std::string &node_name) {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(data.size() != 0);
+  auto param_node = func_graph->add_parameter();
+
+  auto type_ptr = TypeIdToType(kNumberTypeInt32);
+  std::vector<int64_t> shape_vector;
+  shape_vector.push_back(data.size());
+  shape_vector.push_back(2);
+
+  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  param_node->set_abstract(abstract_tensor);
+  param_node->set_name(node_name);
+
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  MS_ASSERT(param_value != nullptr);
+  std::vector<int> shape;
+  shape.push_back(data.size());
+  shape.push_back(2);
+  param_value->set_tensor_shape(shape);
+  param_value->set_tensor_type(kNumberTypeInt32);
+
+  // Flatten the row-major 2-D data into one contiguous buffer of shape {N, 2}.
+  std::vector<int32_t> data_1d;
+  for (const auto &pair : data) {
+    data_1d.insert(data_1d.end(), pair.begin(), pair.end());
+  }
+
+  auto size = data_1d.size() * sizeof(int32_t);
+  char *default_data = new (std::nothrow) char[size];
+  if (memcpy_s(default_data, size, data_1d.data(), size) != EOK) {
+    MS_LOG(ERROR) << "memcpy data failed.";
+    delete[] default_data;
+    return nullptr;
+  }
+  param_value->SetTensorData(default_data, size);
+  param_node->set_default_param(param_value);
+  return param_node;
+}
+
+ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
+                                          const std::string &node_name) {
+  MS_ASSERT(func_graph != nullptr);
+  auto param_node = func_graph->add_parameter();
+
+  auto type_ptr = TypeIdToType(kNumberTypeFloat32);
+  std::vector<int64_t> shape_vector = {1};
+  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
+  param_node->set_abstract(abstract_tensor);
+  param_node->set_name(node_name);
+
+  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
+  MS_ASSERT(param_value != nullptr);
+  param_value->set_tensor_shape({1});
+  param_value->set_tensor_type(kNumberTypeFloat32);
+
+  char *default_data = new (std::nothrow) char[sizeof(float)];
+  if (memcpy_s(default_data, sizeof(float), &data, sizeof(float)) != EOK) {
+    MS_LOG(ERROR) << "memcpy data failed.";
+    delete[] default_data;
+    return nullptr;
+  }
+  param_value->SetTensorData(default_data, sizeof(float));
+  param_node->set_default_param(param_value);
+  return param_node;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 5ba4e6bf68..bb32ac402a 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
+#include <vector>
 #include "src/ops/primitive_c.h"
 #include "ir/anf.h"
 #include "ir/func_graph.h"
@@ -40,6 +41,8 @@ bool IsRealCNodeKernel(const AnfNodePtr &node);
 
 bool IsGraphKernel(const AnfNodePtr &node);
 
+bool CheckInputs(const CNodePtr &cnode);
+
 int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
 
 int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
@@ -121,6 +124,19 @@ template <typename T>
 static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type);
 
 STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format);
+
+ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
+                                        const std::string &node_name);
+
+ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
+                                      const std::string &node_name);
+
+ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
+                                        const std::string &node_name);
+
+ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
+                                          const std::string &node_name);
+
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
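CheckInputs is declared here but defined in a hunk outside this excerpt. Judging from how AddAttrToInput uses it below, a plausible implementation, offered as an assumption rather than the patch's actual body, is:

    bool CheckInputs(const CNodePtr &cnode) {
      if (cnode == nullptr) {
        MS_LOG(ERROR) << "cnode is nullptr.";
        return false;
      }
      // Reject a node whose input list contains a null entry before any pass dereferences it.
      for (const auto &input : cnode->inputs()) {
        if (input == nullptr) {
          MS_LOG(ERROR) << "input is nullptr.";
          return false;
        }
      }
      return true;
    }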
diff --git a/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.cc
new file mode 100644
index 0000000000..0e7a867948
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.cc
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/optimizer/graph/inputs_adjust_pass.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore::opt {
+STATUS InputAdjustPass::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
+                                       const std::string &attr_name, int flag) {
+  MS_ASSERT(cnode != nullptr);
+  if (!CheckInputs(cnode)) {
+    MS_LOG(ERROR) << "input is invalid.";
+    return lite::RET_INPUT_TENSOR_ERROR;
+  }
+  auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr.";
+    return lite::RET_NULL_PTR;
+  }
+  auto value_ptr = primitive_c->GetAttr(attr_name);
+  if (value_ptr == nullptr) {
+    MS_LOG(DEBUG) << "there is no attr: " << attr_name;
+    return lite::RET_NO_CHANGE;
+  }
+  auto inputs = cnode->inputs();
+  if (static_cast<int>(inputs.size()) > input_num) {
+    primitive_c->EraseAttr(attr_name);
+    MS_LOG(DEBUG) << "input num has been met, which is " << inputs.size();
+    return lite::RET_OK;
+  } else if (static_cast<int>(inputs.size()) < input_num) {
+    MS_LOG(ERROR) << "input num is invalid.";
+    return lite::RET_ERROR;
+  }
+  // flag selects how the attribute value is materialized as a constant input:
+  // 1: int32 scalar, 2: int32 vector, 3: int32 2-D vector, 4: float scalar.
+  switch (flag) {
+    case 1: {
+      auto value_data = GetValue<int32_t>(value_ptr);
+      auto param_node =
+        BuildIntValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
+      inputs.push_back(param_node);
+      break;
+    }
+    case 2: {
+      auto value_data = GetValue<std::vector<int32_t>>(value_ptr);
+      auto param_node =
+        BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
+      inputs.push_back(param_node);
+      break;
+    }
+    case 3: {
+      auto value_data = GetValue<std::vector<std::vector<int32_t>>>(value_ptr);
+      auto param_node =
+        BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
+      inputs.push_back(param_node);
+      break;
+    }
+    case 4: {
+      auto value_data = GetValue<float>(value_ptr);
+      auto param_node =
+        BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
+      inputs.push_back(param_node);
+      break;
+    }
+    default: {
+      MS_LOG(ERROR) << "Error attr flag";
+      return lite::RET_ERROR;
+    }
+  }
+  cnode->set_inputs(inputs);
+
+  return lite::RET_OK;
+}
+
+bool InputAdjustPass::Run(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = Manage(func_graph, true);
+  if (manager == nullptr) {
+    MS_LOG(ERROR) << "manager is nullptr.";
+    return false;
+  }
+  auto node_list = TopoSort(func_graph->get_return());
+  STATUS status = lite::RET_OK;
+  for (auto &node : node_list) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+
+    if (GetCNodeType(node) == schema::PrimitiveType_Resize) {
+      status = AddAttrToInput(func_graph, cnode, 2, "zoom_factor", 1);
+    }
+    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
+      MS_LOG(ERROR) << "adjust input pass failed.";
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace mindspore::opt
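The pass only takes effect once it is registered with the converter's optimizer; that wiring is not part of this excerpt. Following the GraphOptimizer/PassManager conventions used elsewhere in the converter, registration would look roughly like this (treat the surrounding names and the call site as a sketch):

    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    auto graph_pm = std::make_shared<opt::PassManager>("graph pass manager", true);
    graph_pm->AddPass(std::make_shared<opt::InputAdjustPass>());
    optimizer->AddPassManager(graph_pm);
    func_graph = optimizer->Optimize(func_graph);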
diff --git a/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h
new file mode 100644
index 0000000000..368174c94e
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
+
+#include <string>
+#include <vector>
+#include "tools/optimizer/common/gllo_utils.h"
+#include "backend/optimizer/common/pass.h"
+#include "src/param_value_lite.h"
+#include "mindspore/lite/include/errorcode.h"
+
+using mindspore::lite::STATUS;
+namespace mindspore::opt {
+class InputAdjustPass : public Pass {
+ public:
+  InputAdjustPass() : Pass("input_adjust") {}
+  ~InputAdjustPass() override = default;
+
+  static STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
+                               const std::string &attr_name, int flag);
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace mindspore::opt
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
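As a worked example of the calling convention: for Resize, Run passes input_num = 2 and flag = 1, meaning the cnode is expected to have exactly two inputs so far (the primitive value node and the image tensor), and the int32 scalar attribute zoom_factor becomes the third input. A second mapping for a vector-valued attribute, with an invented attribute name, would differ only in the flag:

    // From Run() above: zoom_factor is an int32 scalar, hence flag 1.
    status = AddAttrToInput(func_graph, cnode, 2, "zoom_factor", 1);
    // Hypothetical: an int32-vector attribute named "sizes" would use flag 2.
    status = AddAttrToInput(func_graph, cnode, 2, "sizes", 2);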