Merge pull request !6899 from ghzl/dedepthwiseconv-adaptertags/v1.1.0
| @@ -17,6 +17,13 @@ | |||
| #include "src/ops/deconv2d.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -58,6 +65,121 @@ void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()- | |||
| void DeConv2D::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| template <typename T> | |||
| void ConvertConvWeight(const ParameterPtr ¶m_node) { | |||
| MS_ASSERT(param_node != nullptr); | |||
| auto param = param_node->default_param(); | |||
| auto weight = std::dynamic_pointer_cast<ParamValueLite>(param); | |||
| MS_ASSERT(weight != nullptr); | |||
| std::unique_ptr<T> buf(new (std::nothrow) T[weight->tensor_shape_size()]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new buf failed"; | |||
| return; | |||
| } | |||
| size_t filter_k = weight->tensor_shape()[0]; | |||
| size_t filter_c = weight->tensor_shape()[1]; | |||
| size_t filter_h = weight->tensor_shape()[2]; | |||
| size_t filter_w = weight->tensor_shape()[3]; | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (size_t k = 0; k < filter_k; ++k) { | |||
| for (size_t c = 0; c < filter_c; ++c) { | |||
| for (size_t h = 0; h < filter_h; ++h) { | |||
| for (size_t w = 0; w < filter_w; ++w) { | |||
| p1Buff = reinterpret_cast<float *>(weight->tensor_addr()) + | |||
| ((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w)); | |||
| p2Buff = | |||
| buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(), | |||
| weight->tensor_shape_size() * sizeof(T)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return; | |||
| } | |||
| auto abstract_base = param_node->abstract(); | |||
| MS_ASSERT(abstract_base != nullptr); | |||
| if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[0] = filter_c; | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k; | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h; | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w; | |||
| } | |||
| return; | |||
| } | |||
| void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto attr = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format::Format_NCHW; | |||
| } else if (format == "NHWC") { | |||
| attr->format = schema::Format::Format_NHWC; | |||
| } else { | |||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| attr->strideH = stride[0]; | |||
| attr->strideW = stride[1]; | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (pad_mode == "same") { | |||
| attr->padMode = schema::PadMode_SAME_UPPER; | |||
| } else { | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| if (prim.GetAttr("activation_name") != nullptr) { | |||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||
| attr->activationType = kActivationTypeMap[activate_name]; | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| int channel_mutiplier = 1; | |||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| } | |||
| attr->channelMultiplier = channel_mutiplier; | |||
| MS_ASSERT(inputs.size() == kAnfPopulaterTwo); | |||
| auto input_node = inputs[kAnfPopulaterOne]; | |||
| MS_ASSERT(input_node != nullptr); | |||
| if (input_node->isa<Parameter>()) { | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| ConvertConvWeight<float>(param_node); | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { | |||
| auto attr = std::make_unique<schema::DeConv2DT>(); | |||
| attr->group = group; | |||
| @@ -125,6 +247,8 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||
| int group = GetValue<int>(prim.GetAttr("group")); | |||
| if (group == 1) { | |||
| PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); | |||
| } else if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } | |||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | |||
| @@ -48,6 +48,8 @@ class DeConv2D : public PrimitiveC { | |||
| void SetHasBias(bool has_bias); | |||
| void SetActivationType(int activation_type); | |||
| void PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||
| const std::vector<AnfNodePtr> &inputs); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| @@ -153,7 +153,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| out_shape.at(1) = output_h; | |||
| out_shape.at(2) = output_w; | |||
| if (GetChannelMultiplier() * input_channel != weight->shape()[0]) { | |||
| MS_LOG(ERROR) << "Conv depthwise only support group equals output channel."; | |||
| MS_LOG(ERROR) << "Conv dedepthwise only support group equals output channel."; | |||
| return RET_ERROR; | |||
| } | |||
| out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel | |||
| @@ -14,11 +14,45 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/errorcode.h" | |||
| #include "src/ops/maximum.h" | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Maximum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Maximum; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Maximum) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::MaximumT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| @@ -22,6 +22,7 @@ | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -31,6 +32,7 @@ class Maximum : public Arithmetic { | |||
| MS_DECLARE_PARENT(Arithmetic, Arithmetic); | |||
| Maximum() = default; | |||
| explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Maximum() = default; | |||
| @@ -423,6 +423,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<StridedSlice>(prim, inputs, quantType); | |||
| } else if (op_type == "Cast") { | |||
| return NewPrimitiveC<Cast>(prim, inputs, quantType); | |||
| } else if (op_type == "Maximum") { | |||
| return NewPrimitiveC<Maximum>(prim, inputs, quantType); | |||
| } else if (op_type == "Split") { | |||
| return NewPrimitiveC<Split>(prim, inputs, quantType); | |||
| @@ -18,23 +18,19 @@ | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_MS; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::schema::QuantType_AwareTraining; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { | |||
| this->quant_type = type; | |||
| } | |||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { | |||
| this->fmk_type = type; | |||
| } | |||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; } | |||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||
| lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node, | |||
| const ParamValueLitePtr ¶m_value) const { | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| @@ -42,11 +38,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node | |||
| switch (quant_type) { | |||
| case schema::QuantType_PostTraining: | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW); | |||
| case QuantType_QUANT_NONE: | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| @@ -68,12 +65,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case QuantType_PostTraining: | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE: { | |||
| @@ -81,19 +77,18 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, | |||
| // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | |||
| // deconv (C x K/group x kH x kW) group = 1 | |||
| // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | |||
| if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D | |||
| || op_type == schema::PrimitiveType_DeConv2D) { | |||
| if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D || | |||
| op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| @@ -114,8 +109,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||
| } else { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| case QuantType_PostTraining: | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE: { | |||
| @@ -124,18 +118,19 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| @@ -159,15 +154,14 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_CHWK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| @@ -183,8 +177,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||
| } | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | |||
| @@ -197,15 +191,20 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||
| } | |||
| lite::STATUS status; | |||
| switch (fmk_type) { | |||
| case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value); | |||
| case FmkType_CAFFE: | |||
| status = HardCodeCAFFE(node, param_value); | |||
| break; | |||
| case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value); | |||
| case FmkType_TFLITE: | |||
| status = HardCodeTFLITE(node, param_value); | |||
| break; | |||
| case FmkType_ONNX:status = HardCodeONNX(node, param_value); | |||
| case FmkType_ONNX: | |||
| status = HardCodeONNX(node, param_value); | |||
| break; | |||
| case FmkType_MS:status = HardCodeMS(node, param_value); | |||
| case FmkType_MS: | |||
| status = HardCodeMS(node, param_value); | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| if (status != lite::RET_OK) { | |||