|
|
|
@@ -40,6 +40,7 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std |
|
|
|
auto in_shape = input->shape(); |
|
|
|
int input_h = in_shape.at(1); |
|
|
|
int input_w = in_shape.at(2); |
|
|
|
int input_channel = in_shape.at(3); |
|
|
|
int output_w = 0, output_h = 0; |
|
|
|
|
|
|
|
auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D(); |
|
|
|
@@ -58,6 +59,10 @@ int DeconvDepthwiseConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std |
|
|
|
std::vector<int> out_shape{input->shape()}; |
|
|
|
out_shape.at(1) = output_h; |
|
|
|
out_shape.at(2) = output_w; |
|
|
|
if (conv_prim->channelMultiplier() * input_channel != weight->shape()[0]) { |
|
|
|
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel |
|
|
|
|
|
|
|
output->set_shape(out_shape); |
|
|
|
|