Merge pull request !6899 from ghzl/dedepthwiseconv-adaptertags/v1.1.0
| @@ -17,6 +17,13 @@ | |||||
| #include "src/ops/deconv2d.h" | #include "src/ops/deconv2d.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| #include <float.h> | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -58,6 +65,121 @@ void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()- | |||||
| void DeConv2D::SetActivationType(int activation_type) { | void DeConv2D::SetActivationType(int activation_type) { | ||||
| this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type; | this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type; | ||||
| } | } | ||||
| template <typename T> | |||||
| void ConvertConvWeight(const ParameterPtr ¶m_node) { | |||||
| MS_ASSERT(param_node != nullptr); | |||||
| auto param = param_node->default_param(); | |||||
| auto weight = std::dynamic_pointer_cast<ParamValueLite>(param); | |||||
| MS_ASSERT(weight != nullptr); | |||||
| std::unique_ptr<T> buf(new (std::nothrow) T[weight->tensor_shape_size()]); | |||||
| if (buf == nullptr) { | |||||
| MS_LOG(ERROR) << "new buf failed"; | |||||
| return; | |||||
| } | |||||
| size_t filter_k = weight->tensor_shape()[0]; | |||||
| size_t filter_c = weight->tensor_shape()[1]; | |||||
| size_t filter_h = weight->tensor_shape()[2]; | |||||
| size_t filter_w = weight->tensor_shape()[3]; | |||||
| T *p1Buff = nullptr; | |||||
| T *p2Buff = nullptr; | |||||
| for (size_t k = 0; k < filter_k; ++k) { | |||||
| for (size_t c = 0; c < filter_c; ++c) { | |||||
| for (size_t h = 0; h < filter_h; ++h) { | |||||
| for (size_t w = 0; w < filter_w; ++w) { | |||||
| p1Buff = reinterpret_cast<float *>(weight->tensor_addr()) + | |||||
| ((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w)); | |||||
| p2Buff = | |||||
| buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w)); | |||||
| *p2Buff = *p1Buff; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(), | |||||
| weight->tensor_shape_size() * sizeof(T)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return; | |||||
| } | |||||
| auto abstract_base = param_node->abstract(); | |||||
| MS_ASSERT(abstract_base != nullptr); | |||||
| if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[0] = filter_c; | |||||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k; | |||||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h; | |||||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w; | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs) { | |||||
| auto attr = std::make_unique<schema::DeDepthwiseConv2DT>(); | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[0]; | |||||
| attr->dilateW = dilation[1]; | |||||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[0]; | |||||
| attr->strideW = stride[1]; | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME_UPPER; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| int channel_mutiplier = 1; | |||||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||||
| } | |||||
| attr->channelMultiplier = channel_mutiplier; | |||||
| MS_ASSERT(inputs.size() == kAnfPopulaterTwo); | |||||
| auto input_node = inputs[kAnfPopulaterOne]; | |||||
| MS_ASSERT(input_node != nullptr); | |||||
| if (input_node->isa<Parameter>()) { | |||||
| auto param_node = input_node->cast<ParameterPtr>(); | |||||
| ConvertConvWeight<float>(param_node); | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { | void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { | ||||
| auto attr = std::make_unique<schema::DeConv2DT>(); | auto attr = std::make_unique<schema::DeConv2DT>(); | ||||
| attr->group = group; | attr->group = group; | ||||
| @@ -125,6 +247,8 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||||
| int group = GetValue<int>(prim.GetAttr("group")); | int group = GetValue<int>(prim.GetAttr("group")); | ||||
| if (group == 1) { | if (group == 1) { | ||||
| PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); | PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); | ||||
| } else if (group > 1) { | |||||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||||
| } | } | ||||
| if (GetQuantType() == schema::QuantType_AwareTraining) { | if (GetQuantType() == schema::QuantType_AwareTraining) { | ||||
| @@ -48,6 +48,8 @@ class DeConv2D : public PrimitiveC { | |||||
| void SetHasBias(bool has_bias); | void SetHasBias(bool has_bias); | ||||
| void SetActivationType(int activation_type); | void SetActivationType(int activation_type); | ||||
| void PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | void PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | ||||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | |||||
| const std::vector<AnfNodePtr> &inputs); | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| #else | #else | ||||
| @@ -153,7 +153,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||||
| out_shape.at(1) = output_h; | out_shape.at(1) = output_h; | ||||
| out_shape.at(2) = output_w; | out_shape.at(2) = output_w; | ||||
| if (GetChannelMultiplier() * input_channel != weight->shape()[0]) { | if (GetChannelMultiplier() * input_channel != weight->shape()[0]) { | ||||
| MS_LOG(ERROR) << "Conv depthwise only support group equals output channel."; | |||||
| MS_LOG(ERROR) << "Conv dedepthwise only support group equals output channel."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel | out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel | ||||
| @@ -14,11 +14,45 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "include/errorcode.h" | |||||
| #include "src/ops/maximum.h" | #include "src/ops/maximum.h" | ||||
| #include "src/common/log_adapter.h" | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| #include <float.h> | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Maximum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Maximum; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Maximum) { | |||||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| auto attr = new (std::nothrow) schema::MaximumT(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.value = attr; | |||||
| if (this->primitive_->value.value == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/ops/arithmetic.h" | #include "src/ops/arithmetic.h" | ||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -31,6 +32,7 @@ class Maximum : public Arithmetic { | |||||
| MS_DECLARE_PARENT(Arithmetic, Arithmetic); | MS_DECLARE_PARENT(Arithmetic, Arithmetic); | ||||
| Maximum() = default; | Maximum() = default; | ||||
| explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| #else | #else | ||||
| Maximum() = default; | Maximum() = default; | ||||
| @@ -423,6 +423,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| return NewPrimitiveC<StridedSlice>(prim, inputs, quantType); | return NewPrimitiveC<StridedSlice>(prim, inputs, quantType); | ||||
| } else if (op_type == "Cast") { | } else if (op_type == "Cast") { | ||||
| return NewPrimitiveC<Cast>(prim, inputs, quantType); | return NewPrimitiveC<Cast>(prim, inputs, quantType); | ||||
| } else if (op_type == "Maximum") { | |||||
| return NewPrimitiveC<Maximum>(prim, inputs, quantType); | |||||
| } else if (op_type == "Split") { | } else if (op_type == "Split") { | ||||
| return NewPrimitiveC<Split>(prim, inputs, quantType); | return NewPrimitiveC<Split>(prim, inputs, quantType); | ||||
| @@ -18,23 +18,19 @@ | |||||
| #include "tools/optimizer/common/gllo_utils.h" | #include "tools/optimizer/common/gllo_utils.h" | ||||
| using mindspore::lite::converter::FmkType_CAFFE; | using mindspore::lite::converter::FmkType_CAFFE; | ||||
| using mindspore::lite::converter::FmkType_TFLITE; | |||||
| using mindspore::lite::converter::FmkType_ONNX; | |||||
| using mindspore::lite::converter::FmkType_MS; | using mindspore::lite::converter::FmkType_MS; | ||||
| using mindspore::schema::QuantType_WeightQuant; | |||||
| using mindspore::schema::QuantType_QUANT_NONE; | |||||
| using mindspore::lite::converter::FmkType_ONNX; | |||||
| using mindspore::lite::converter::FmkType_TFLITE; | |||||
| using mindspore::schema::QuantType_AwareTraining; | using mindspore::schema::QuantType_AwareTraining; | ||||
| using mindspore::schema::QuantType_PostTraining; | using mindspore::schema::QuantType_PostTraining; | ||||
| using mindspore::schema::QuantType_QUANT_NONE; | |||||
| using mindspore::schema::QuantType_WeightQuant; | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| constexpr size_t kConvWeightIndex = 2; | constexpr size_t kConvWeightIndex = 2; | ||||
| } // namespace | } // namespace | ||||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { | |||||
| this->quant_type = type; | |||||
| } | |||||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { | |||||
| this->fmk_type = type; | |||||
| } | |||||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; } | |||||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||||
| lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node, | lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node, | ||||
| const ParamValueLitePtr ¶m_value) const { | const ParamValueLitePtr ¶m_value) const { | ||||
| MS_ASSERT(conv_cnode != nullptr); | MS_ASSERT(conv_cnode != nullptr); | ||||
| @@ -42,11 +38,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node | |||||
| switch (quant_type) { | switch (quant_type) { | ||||
| case schema::QuantType_PostTraining: | case schema::QuantType_PostTraining: | ||||
| case QuantType_WeightQuant: | case QuantType_WeightQuant: | ||||
| case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW); | |||||
| case QuantType_QUANT_NONE: | |||||
| param_value->set_format(schema::Format::Format_KCHW); | |||||
| break; | break; | ||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -68,12 +65,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, | |||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | } else if (op_type == schema::PrimitiveType_DeConv2D) { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case QuantType_PostTraining: | case QuantType_PostTraining: | ||||
| case QuantType_WeightQuant: | case QuantType_WeightQuant: | ||||
| case QuantType_QUANT_NONE: { | case QuantType_QUANT_NONE: { | ||||
| @@ -81,19 +77,18 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, | |||||
| // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | ||||
| // deconv (C x K/group x kH x kW) group = 1 | // deconv (C x K/group x kH x kW) group = 1 | ||||
| // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | ||||
| if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D | |||||
| || op_type == schema::PrimitiveType_DeConv2D) { | |||||
| if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D || | |||||
| op_type == schema::PrimitiveType_DeConv2D) { | |||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -114,8 +109,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||||
| } else { | } else { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| case QuantType_PostTraining: | case QuantType_PostTraining: | ||||
| case QuantType_WeightQuant: | case QuantType_WeightQuant: | ||||
| case QuantType_QUANT_NONE: { | case QuantType_QUANT_NONE: { | ||||
| @@ -124,18 +118,19 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| param_value->set_format(schema::Format::Format_CKHW); | param_value->set_format(schema::Format::Format_CKHW); | ||||
| } else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) { | |||||
| param_value->set_format(schema::Format::Format_CKHW); | |||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | } else if (op_type == schema::PrimitiveType_DeConv2D) { | ||||
| param_value->set_format(schema::Format::Format_KCHW); | param_value->set_format(schema::Format::Format_KCHW); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -159,15 +154,14 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod | |||||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | } else if (op_type == schema::PrimitiveType_DeConv2D) { | ||||
| param_value->set_format(schema::Format::Format_CHWK); | param_value->set_format(schema::Format::Format_CHWK); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | |||||
| break; | |||||
| } break; | |||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||||
| << conv_node->fullname_with_scope(); | |||||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||||
| << ", node: " << conv_node->fullname_with_scope(); | |||||
| return lite::RET_ERROR; | return lite::RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -183,8 +177,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| auto conv_cnode = node->cast<CNodePtr>(); | auto conv_cnode = node->cast<CNodePtr>(); | ||||
| auto type = opt::GetCNodeType(node); | auto type = opt::GetCNodeType(node); | ||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | ||||
| @@ -197,15 +191,20 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| lite::STATUS status; | lite::STATUS status; | ||||
| switch (fmk_type) { | switch (fmk_type) { | ||||
| case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value); | |||||
| case FmkType_CAFFE: | |||||
| status = HardCodeCAFFE(node, param_value); | |||||
| break; | break; | ||||
| case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value); | |||||
| case FmkType_TFLITE: | |||||
| status = HardCodeTFLITE(node, param_value); | |||||
| break; | break; | ||||
| case FmkType_ONNX:status = HardCodeONNX(node, param_value); | |||||
| case FmkType_ONNX: | |||||
| status = HardCodeONNX(node, param_value); | |||||
| break; | break; | ||||
| case FmkType_MS:status = HardCodeMS(node, param_value); | |||||
| case FmkType_MS: | |||||
| status = HardCodeMS(node, param_value); | |||||
| break; | break; | ||||
| default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (status != lite::RET_OK) { | if (status != lite::RET_OK) { | ||||