Browse Source

!4210 [MS][LITE] fix bug of arm cpu fp32 conv_depthwise: only support group equals output channel

Merge pull request !4210 from yangruoqi713/lite
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
13a66805b3
5 changed files with 13 additions and 3 deletions
  1. +5
    -0
      mindspore/lite/src/ops/convolution_depthwise.cc
  2. +5
    -0
      mindspore/lite/src/ops/deconvolution_depthwise.cc
  3. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc
  4. +1
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
  5. +1
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc

+ 5
- 0
mindspore/lite/src/ops/convolution_depthwise.cc View File

@@ -40,6 +40,7 @@ int DepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
int input_channel = in_shape.at(3);
int output_w = 0, output_h = 0;

auto conv_prim = this->primitive->value_as_DepthwiseConv2D();
@@ -69,6 +70,10 @@ int DepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
std::vector<int> out_shape{input->shape()};
out_shape.at(1) = output_h;
out_shape.at(2) = output_w;
if (conv_prim->channelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel

output->set_shape(out_shape);


+ 5
- 0
mindspore/lite/src/ops/deconvolution_depthwise.cc View File

@@ -40,6 +40,7 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std
auto in_shape = input->shape();
int input_h = in_shape.at(1);
int input_w = in_shape.at(2);
int input_channel = in_shape.at(3);
int output_w = 0, output_h = 0;

auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D();
@@ -58,6 +59,10 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std
std::vector<int> out_shape{input->shape()};
out_shape.at(1) = output_h;
out_shape.at(2) = output_w;
if (conv_prim->channelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel

output->set_shape(out_shape);


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv_depthwise.cc View File

@@ -53,7 +53,7 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop
sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H
sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_h_; // stride W
sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_; // stride W
sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H
sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W
sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;


+ 1
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc View File

@@ -20,7 +20,7 @@
namespace mindspore {
namespace lite {
void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
if (attr == nullptr || attr->group == 1 || attr->group != attr->channelOut) {
if (attr == nullptr || attr->group == 1) {
return;
}
std::unique_ptr<schema::DepthwiseConv2DT> depthwiseConv2DParam(new schema::DepthwiseConv2DT());


+ 1
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc View File

@@ -20,7 +20,7 @@
namespace mindspore {
namespace lite {
void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr) {
if (attr == nullptr || attr->group == 1 || attr->group != attr->channelIn) {
if (attr == nullptr || attr->group == 1) {
return;
}



Loading…
Cancel
Save