@@ -37,6 +37,8 @@ constexpr auto kAttrPadList = "pad_list";
 constexpr auto kAttrMode = "mode";
 constexpr auto kAttrChannelMultiplier = "channel_multiplier";
 constexpr auto kAttrPerm = "perm";
+constexpr auto kAttrInputSizes = "input_sizes";
+constexpr auto kAttrInputSize = "input_size";
 
 bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
   MS_EXCEPTION_IF_NULL(conv2d);
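
Note: the two attr keys added above are the source and target names for the CopyNodeAttr call in the next hunk: Conv2DBackpropInput stores the value under "input_sizes", while DepthwiseConv2dNativeBackpropInput expects it under "input_size". A minimal sketch of what that copy amounts to, assuming the usual templated AnfAlgo attribute accessors (illustrative expansion only, not the actual CopyNodeAttr implementation):

  // Hypothetical expansion of AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, from, to):
  auto sizes = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(from, kAttrInputSizes);  // read "input_sizes" off the old node
  AnfAlgo::SetNodeAttr(kAttrInputSize, MakeValue(sizes), to);                      // write it back as "input_size"
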
@@ -144,14 +146,22 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
                                             const CNodePtr &transpose) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(conv2d_backin);
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
-  }
-  std::vector<AnfNodePtr> depth_conv_backin_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
-    transpose, conv2d_backin->input(1)};
-  auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+  CNodePtr depth_conv_backin = nullptr;
+  if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
+      transpose, conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+  } else {
+    // In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be converted to an
+    // attr in pynative mode.
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
+      conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+    AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
+  }
   MS_EXCEPTION_IF_NULL(depth_conv_backin);
   depth_conv_backin->set_abstract(conv2d_backin->abstract());
   depth_conv_backin->set_scope(conv2d_backin->scope());
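
Note: reading the indices off this hunk (input 0 is the primitive), Conv2DBackpropInput's operands are (1) dout, (2) weight, (3) input_sizes, and the depthwise replacement reorders them, with the weight transposed separately via CreateTranspose:

  Conv2DBackpropInput(dout, weight, input_sizes)
    -> DepthwiseConv2dNativeBackpropInput(input_sizes, Transpose(weight), dout)

In the two-operand pynative form, input_sizes is not an operand at all, so CopyNodeAttr carries it across as an attr instead.
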
@@ -265,10 +275,8 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
 }
 
 const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
-  VarPtr dout = std::make_shared<Var>();
-  VarPtr weight = std::make_shared<Var>();
-  VarPtr input_size = std::make_shared<Var>();
-  VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size});
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
   return pattern;
 }
 
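
Note: a SeqVar matches a variable-length input list, so the relaxed pattern accepts Conv2DBackpropInput nodes with either three operands (graph mode) or two (pynative mode, where input_sizes has been folded into an attr); the old pattern of three fixed Vars would only ever match the three-operand form.
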
@@ -285,9 +293,11 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
     return nullptr;
   }
 
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
-  }
+  auto input_size = conv2d_backin->inputs().size();
+  // In pynative mode, the input_sizes input will be converted to an attr if Conv2DBackpropInput is a forward op.
+  if (input_size != kConv2DBackpropInputNum && input_size != kConv2DBackpropInputNum - 1) {
+    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << " or "
+                      << kConv2DBackpropInputNum - 2 << ", but got " << input_size - 1;
+  }
   auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true);
   auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
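
Note: kConv2DBackpropInputNum counts the primitive plus the operands (presumably 4 here), so the relaxed check admits exactly the two node shapes handled by CreateDepthwiseConv2DBackpropInput above, and the logged counts subtract one to report user-visible operand numbers.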