Browse Source

!1787 optimize transdata for pynative mode

Merge pull request !1787 from chujinjin/optimize_transdata_for_pynative
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
65eacc9593
5 changed files with 35 additions and 0 deletions
  1. +18
    -0
      mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc
  2. +4
    -0
      tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
  3. +5
    -0
      tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
  4. +4
    -0
      tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
  5. +4
    -0
      tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc

+ 18
- 0
mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc View File

@@ -16,11 +16,13 @@

#include "pre_activate/ascend/format_type/insert_trans_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "pre_activate/ascend/ascend_helper.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace opt {
@@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}

// Returns true when `node` is one of the graph's (flattened) output nodes.
// The caller passes the result of AnfAlgo::GetAllOutput(...) here so that,
// in PyNative mode, no output trans op is inserted for graph-output nodes.
// NOTE(review): file-local helper in a .cc — could be marked `static` or
// moved into an anonymous namespace; confirm no other TU references it.
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
  // A linear membership test; `outputs` is expected to be small (the graph's
  // output list), so std::find is adequate here.
  return std::find(outputs.begin(), outputs.end(), node) != outputs.end();
}

const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
@@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
} // namespace opt


+ 4
- 0
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc View File

@@ -21,6 +21,7 @@
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"

#define private public
#define protected public
@@ -103,6 +104,9 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
* return output
*
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
// Do insert_trans_op_ pass of hardware opt
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();


+ 5
- 0
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc View File

@@ -20,6 +20,8 @@
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/ms_context.h"

#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@@ -91,6 +93,9 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);


+ 4
- 0
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc View File

@@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "utils/context/ms_context.h"
#define private public
#define protected public
#include "pre_activate/ascend/format_type/insert_trans_op.h"
@@ -76,6 +77,9 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
* transdata = Transdata(transpose)
* return transdata
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);


+ 4
- 0
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc View File

@@ -30,6 +30,7 @@
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "utils/context/ms_context.h"  // NOTE(review): duplicate — this header is already included a few lines above (see the unchanged context line in this same hunk); the added include should be dropped

#define private public
#define protected public
@@ -71,6 +72,9 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
* output = make_tuple(res)
* return output
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
// Renormalize func_graph to infer and set shape and type information.
std::vector<int> shp{2, 32, 224, 224};


Loading…
Cancel
Save