|
|
|
@@ -37,6 +37,7 @@ constexpr auto kAttrPadList = "pad_list"; |
|
|
|
constexpr auto kAttrPads = "pads"; |
|
|
|
constexpr auto kAttrMode = "mode"; |
|
|
|
constexpr auto kAttrChannelMultiplier = "channel_multiplier"; |
|
|
|
constexpr auto kAttrPerm = "perm"; |
|
|
|
|
|
|
|
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) { |
|
|
|
MS_EXCEPTION_IF_NULL(conv2d); |
|
|
|
@@ -86,9 +87,16 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(conv2d); |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
auto perm = std::vector<int64_t>{1, 0, 2, 3}; |
|
|
|
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node, |
|
|
|
CreatePermValueNode(graph, perm)}; |
|
|
|
std::vector<AnfNodePtr> transpose_inputs; |
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { |
|
|
|
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node}; |
|
|
|
} else { |
|
|
|
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node, |
|
|
|
CreatePermValueNode(graph, perm)}; |
|
|
|
} |
|
|
|
auto transpose = graph->NewCNode(transpose_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(transpose); |
|
|
|
transpose->set_scope(conv2d->scope()); |
|
|
|
@@ -111,6 +119,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons |
|
|
|
auto output_names = std::vector<std::string>{"output"}; |
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose); |
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose); |
|
|
|
} |
|
|
|
return transpose; |
|
|
|
} |
|
|
|
|
|
|
|
|