Browse Source

!14580 [auto-monad] Enforce order of execution for Loads user nodes in frontend

From: @hwhewei
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/14580/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
3bd51c88f4
27 changed files with 569 additions and 82 deletions
  1. +1
    -1
      mindspore/_extends/builtin_operations.py
  2. +11
    -5
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  3. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h
  4. +3
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc
  5. +3
    -0
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
  6. +5
    -3
      mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc
  7. +14
    -1
      mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc
  8. +12
    -5
      mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc
  9. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
  10. +20
    -0
      mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
  11. +1
    -0
      mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h
  12. +10
    -10
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  13. +7
    -6
      mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc
  14. +43
    -47
      mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc
  15. +71
    -0
      mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc
  16. +33
    -0
      mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h
  17. +10
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  18. +1
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  19. +1
    -0
      mindspore/ccsrc/backend/session/ascend_session.cc
  20. +1
    -0
      mindspore/ccsrc/backend/session/cpu_session.cc
  21. +2
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc
  22. +7
    -0
      mindspore/ccsrc/backend/session/session_basic.cc
  23. +1
    -0
      mindspore/ccsrc/backend/session/session_basic.h
  24. +17
    -0
      mindspore/ccsrc/pipeline/jit/action.cc
  25. +258
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
  26. +27
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h
  27. +8
    -2
      tests/st/auto_monad/test_auto_monad.py

+ 1
- 1
mindspore/_extends/builtin_operations.py View File

@@ -132,7 +132,7 @@ def Depend(value, expr):
return value


def UpdateState(monad, expr):
def UpdateState(monad, *exprs):
"""Implement `UpdateState`."""
return monad



+ 11
- 5
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -90,7 +90,7 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(node_with_index.first);
auto real_input = node_with_index.first;
if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
input_node = InsertTransOpForOutput(func_graph, input_node, input_node, kernel_select);
MS_EXCEPTION_IF_NULL(input_node);
AnfAlgo::SetNodeInput(node, input_node, index);
}
@@ -120,10 +120,16 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
return node;
}

AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node,
const AnfNodePtr &node, const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_node);
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, node);
}
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
size_t out_num = AnfAlgo::GetOutputTensorNum(node);
@@ -282,7 +288,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
return cast;
}

AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select) {
size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
if (outputs_num == 0) {
@@ -298,7 +304,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
return new_node;
}
// Multiple output
return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
return InsertTransOpForMultipleOutput(func_graph, orig_node, node, kernel_select);
}

AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,


+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h View File

@@ -103,7 +103,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);

AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select);

CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);


+ 3
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc View File

@@ -66,6 +66,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);


+ 3
- 0
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc View File

@@ -43,6 +43,9 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);


+ 5
- 3
mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc View File

@@ -297,9 +297,11 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
} else {
int64_t prev_idx = 0;
std::vector<AnfNodePtr> tuple_getitem_nodes;
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
std::back_inserter(tuple_getitem_nodes),
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
for (auto &user : manager->node_users()[node]) {
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimTupleGetItem)) {
tuple_getitem_nodes.emplace_back(user.first);
}
}
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
for (auto &getitem : tuple_getitem_nodes) {
MS_EXCEPTION_IF_NULL(getitem);


+ 14
- 1
mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.cc View File

@@ -163,7 +163,20 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get
return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto cnode = orig_cnode;
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
if (!update_states.empty()) {
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
cnode = kernel_graph->NewCNode(orig_cnode);
cnode->set_inputs(orig_cnode->inputs());
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, cnode);
}
}
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs;


+ 12
- 5
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc View File

@@ -30,9 +30,16 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode,
const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto update_states = AnfAlgo::GetUpdateStateUsers(manager, orig_cnode);
for (auto &update_state : update_states) {
manager->SetEdge(update_state.first, update_state.second, cnode);
}
std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list;
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@@ -69,9 +76,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple;
} // namespace
}

AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
@@ -99,7 +106,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return replace_node;
}
// Multiple output
return InsertCastForMultipleOutput(func_graph, cnode);
return InsertCastForMultipleOutput(func_graph, orig_cnode, cnode);
}
} // namespace

@@ -124,7 +131,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
kernel_graph->ReplaceInternalOutput(node, new_node);
}
// process output
return InsertCastForOutput(func_graph, new_node);
return InsertCastForOutput(func_graph, cnode, new_node);
}
} // namespace opt
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc View File

@@ -43,7 +43,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
kernel_graph->ReplaceInternalOutput(node, new_node);
}
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
return InsertTransOpForOutput(func_graph, node, new_node, kernel_select_);
}
} // namespace opt
} // namespace mindspore

+ 20
- 0
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc View File

@@ -25,6 +25,7 @@
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include "backend/optimizer/pass/add_training_attr.h"
#include "backend/optimizer/pass/optimize_updatestate.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"

@@ -58,5 +59,24 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
DumpIR(file_name, kernel_graph);
}
}

void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
// Run optimizer passes.
auto optimizer = std::make_shared<GraphOptimizer>();
auto pm = std::make_shared<PassManager>("final_opt");
pm->AddPass(std::make_shared<OptimizeUpdateState>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
// Dump IR if save_graphs is set.
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
const bool save_graphs = context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs) {
std::string filename = "hwopt_common_final_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(filename, kernel_graph);
}
}
} // namespace opt
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h View File

@@ -20,6 +20,7 @@
namespace mindspore {
namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt
} // namespace mindspore



+ 10
- 10
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -401,11 +401,9 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
(cnode_name == prim::kPrimUpdateState->name())) {
continue;
}
output_node_list->push_back(output_info);
@@ -426,12 +424,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
(cnode_name == prim::kPrimUpdateState->name())) {
continue;
}
size_t used_output_index;
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
if (cnode_name == prim::kPrimTupleGetItem->name()) {
used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
} else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
used_output_index = output_index;
@@ -906,12 +905,13 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// find BatchNorm's output which is a Depend
// Find BatchNorm's output which is a Depend or UpdateState.
for (const auto &node_index : manager->node_users()[old_node]) {
AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node);


+ 7
- 6
mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc View File

@@ -66,13 +66,14 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr
auto output_num = output->cast<CNodePtr>()->size() - 1;
getitem_list->clear();
getitem_list->resize(output_num, nullptr);
const auto &users = mng->node_users()[node];
auto users = mng->node_users()[node];
bool changed = false;
AnfNodePtrList user_nodes;
std::transform(users.begin(), users.end(), std::back_inserter(user_nodes),
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
for (const auto &getitem : user_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
for (const auto &user : users) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState.
continue;
}
auto &getitem = user.first;
auto idx = GetIndex(getitem);
if (idx >= output_num) {
MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();


+ 43
- 47
mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc View File

@@ -35,19 +35,17 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno
const std::vector<AnfNodePtr> &new_depend_inputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend = nullptr;
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
if (kernel_graph == nullptr) {
new_depend = func_graph->NewCNode(new_depend_inputs);
auto new_depend = func_graph->NewCNode(new_depend_inputs);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_abstract(cnode->abstract());
new_depend->set_scope(cnode->scope());
} else {
new_depend = kernel_graph->NewCNode(cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
return new_depend;
}
func_graph->manager()->Replace(cnode, new_depend);
auto new_depend = kernel_graph->NewCNode(cnode);
MS_EXCEPTION_IF_NULL(new_depend);
new_depend->set_inputs(new_depend_inputs);
return new_depend;
}

@@ -77,9 +75,9 @@ AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, con
auto replace_node = eliminate_node->input(kSingleInputIndex);
std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
auto new_node = new_cnode->cast<AnfNodePtr>();
return new_node;
auto new_depend = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
func_graph->manager()->Replace(cnode, new_depend);
return new_depend;
}

AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
@@ -157,55 +155,53 @@ const BaseRef OptimizeDependence::DefinePattern() const {
return VectorRef({X, Xs});
}

std::pair<AnfNodePtr, size_t> SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) {
if (node == nullptr || !node->isa<CNode>()) {
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
}
// get real input of depend and update state.
size_t replace_input_index = 0;
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
replace_input_index = is_first_node ? kUpdateStateStateInput : kUpdateStateRealInput;
} else {
return std::pair<AnfNodePtr, size_t>(nullptr, 0);
}
// check whether real input is cast or trans data
auto real_input = node->cast<CNodePtr>()->input(replace_input_index);
if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) ||
AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) ||
AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) {
return std::pair<AnfNodePtr, size_t>(node, replace_input_index);
}
return SearchTransDataAndCast(real_input, false);
std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
// Search Depend and UpdateState only.
if (!cnode->IsApply(prim::kPrimDepend) && !cnode->IsApply(prim::kPrimUpdateState)) {
return {};
}
// Find inputs which is Cast or TransData.
std::vector<size_t> result;
for (size_t i = 1; i < cnode->size(); ++i) {
auto &input = cnode->input(i);
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
AnfAlgo::CheckPrimitiveType(input, prim::KPrimTransData) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
result.emplace_back(i);
}
}
return result;
}

const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return nullptr;
}
// Get the cnode with repalce input index
auto cnode_with_input_index = SearchTransDataAndCast(node, true);
if (cnode_with_input_index.first == nullptr) {
// Search inputs to be replaced.
auto candidate_inputs = SearchTransDataAndCast(cnode);
if (candidate_inputs.empty()) {
return nullptr;
}
size_t replace_index = cnode_with_input_index.second;
auto depend_cnode = cnode_with_input_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
// Get new node which will act as new input of depend or UpdateState.
std::vector<AnfNodePtr> new_depend_inputs = depend_cnode->inputs();
auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index);
if (replace_node == nullptr) {
return nullptr;
// Get new nodes which will act as new inputs of Depend or UpdateState.
std::vector<AnfNodePtr> new_inputs = cnode->inputs();
bool inputs_changed = false;
for (auto index : candidate_inputs) {
auto replace_node = GetConvertNode(func_graph, cnode, index);
if (replace_node != nullptr) {
new_inputs[index] = replace_node;
inputs_changed = true;
}
}
new_depend_inputs[replace_index] = replace_node;
auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
if (new_depend == nullptr) {
if (!inputs_changed) {
return nullptr;
}
// Create a new Depend node to replace the old one if inputs changed.
auto new_depend = CreateNewDependNode(func_graph, cnode, new_inputs);
func_graph->manager()->Replace(cnode, new_depend);
return nullptr;
}



+ 71
- 0
mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.cc View File

@@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/pass/optimize_updatestate.h"
#include <memory>
#include <vector>
#include <string>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace opt {
constexpr size_t kInputIndex = 1;
constexpr size_t kAttachIndex = 2;
constexpr size_t kAdditionalAttachIndex = 3;

const BaseRef OptimizeUpdateState::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimUpdateState, Xs});
}

const AnfNodePtr OptimizeUpdateState::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
auto update_state = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(update_state);
if (update_state->size() <= kAdditionalAttachIndex) {
// Skip UpdateState nodes with no additional attaches.
return nullptr;
}
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
std::vector<AnfNodePtr> new_inputs;
new_inputs.emplace_back(update_state->input(0));
new_inputs.emplace_back(update_state->input(kInputIndex));
new_inputs.emplace_back(update_state->input(kAttachIndex));
for (size_t i = kAdditionalAttachIndex; i < update_state->size(); ++i) {
auto &attach = update_state->input(i);
auto &users = node_users[attach];
if ((users.size() == 1) && (users.front().first == update_state)) {
// If the only user of attach is the UpdateState node, drop the attach node.
continue;
}
new_inputs.emplace_back(attach);
}
if (new_inputs.size() == update_state->size()) {
// Attaches not changed.
return nullptr;
}
// Attaches changed, make a new UpdateState.
auto new_update_state = func_graph->NewCNode(new_inputs);
new_update_state->set_abstract(update_state->abstract());
new_update_state->set_scope(update_state->scope());
return new_update_state;
}
} // namespace opt
} // namespace mindspore

+ 33
- 0
mindspore/ccsrc/backend/optimizer/pass/optimize_updatestate.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_

#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class OptimizeUpdateState : public PatternProcessPass {
public:
explicit OptimizeUpdateState(bool multigraph = true) : PatternProcessPass("optimize_updatestate", multigraph) {}
~OptimizeUpdateState() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_OPTIMIZE_UPDATESTATE_H_

+ 10
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1931,5 +1931,15 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
root_graph->set_output(make_tuple);
}

AnfNodeIndexSet AnfRuntimeAlgorithm::GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
AnfNodeIndexSet update_states;
for (auto &user : manager->node_users()[node]) {
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimUpdateState)) {
update_states.insert(user);
}
}
return update_states;
}
} // namespace session
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -267,6 +267,7 @@ class AnfRuntimeAlgorithm {
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


+ 1
- 0
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -936,6 +936,7 @@ void AscendSession::InitRuntimeResource() {
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "HardwareOptimize start!";
opt::AscendBackendOptimization(kernel_graph);
FinalOptimize(kernel_graph);
GraphKernelOptimize(kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->SetExecOrderByDefault();


+ 1
- 0
mindspore/ccsrc/backend/session/cpu_session.cc View File

@@ -104,6 +104,7 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
SetKernelInfo(graph.get());
MS_LOG(INFO) << "Set kernel info end";
Optimize(graph);
FinalOptimize(graph);
MS_LOG(INFO) << "Build kernel";
BuildKernel(graph.get());
// Remove reorder after PS feature finish adapting push/pull in auto_monad.


+ 2
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -341,6 +341,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
SelectKernel(graph);
// Graph optimization relevant to device data format
HardwareOptimize(graph);
// Run final optimization
FinalOptimize(graph);
// Graph kernel fusion optimization
GraphKernelOptimize(graph);
// Start gpu kernel runtime


+ 7
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -17,6 +17,7 @@

#include <algorithm>
#include <set>
#include <queue>
#include <unordered_map>
#include <utility>

@@ -2343,6 +2344,12 @@ void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
}
}

void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
opt::CommonFinalOptimization(graph);
MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}

void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();


+ 1
- 0
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void UpdateOutputTensors(const VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
virtual void UnifyMindIR(const KernelGraphPtr &graph) {}
virtual void FinalOptimize(const KernelGraphPtr &graph) const;
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
virtual void BuildGraphImpl(GraphId) {}


+ 17
- 0
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -32,6 +32,7 @@
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "pipeline/jit/static_analysis/order_enforce.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/program_specialize.h"
@@ -343,6 +344,18 @@ bool AutoMonadAction(const ResourcePtr &res) {
return true;
}

bool OrderEnforceAction(const ResourcePtr &res) {
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
}
auto func_graph = res->func_graph();
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null";
}
pipeline::OrderEnforce(func_graph);
return true;
}

bool InferenceOptPrepareAction(const ResourcePtr &res) {
if (res->manager() == nullptr) {
MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
@@ -752,6 +765,7 @@ std::vector<ActionItem> GePipeline() {
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
return actions;
}
@@ -765,6 +779,8 @@ std::vector<ActionItem> VmPipeline() {
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));

actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));

actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PSContext::instance()->is_worker()) {
@@ -784,6 +800,7 @@ std::vector<ActionItem> VmPipeline() {
std::vector<ActionItem> PServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
return actions;


+ 258
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc View File

@@ -0,0 +1,258 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pipeline/jit/static_analysis/order_enforce.h"
#include <algorithm>
#include <map>
#include <queue>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "base/core_ops.h"

namespace mindspore::pipeline {
namespace {

// Enforces execution order between Load nodes' users and the UpdateState
// node that commits a new value to the loaded parameter: every side-effect
// free user of a Load is made an extra input of the UpdateState, so backend
// scheduling cannot move the user after the parameter update.
class OrderEnforcer {
 public:
  explicit OrderEnforcer(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) {
    MS_EXCEPTION_IF_NULL(func_graph_);
    MS_EXCEPTION_IF_NULL(manager_);
  }
  ~OrderEnforcer() = default;

  // Visit every node of the graph in topological order and process the
  // UpdateState nodes found among them.
  void Run() {
    auto nodes = MakeTopoSortMap();
    for (auto &node : nodes) {
      HandleNode(node);
    }
  }

 private:
  // Topologically sort the graph from its return node and record each
  // node's position in topo_sort_map_ for later order comparisons.
  AnfNodePtrList MakeTopoSortMap() {
    auto nodes = TopoSort(func_graph_->get_return());
    for (size_t i = 0; i < nodes.size(); ++i) {
      topo_sort_map_.emplace(nodes[i], i);
    }
    return nodes;
  }

  // Process one node: only UpdateState(U, Load) and UpdateState(U, MakeTuple)
  // forms whose state input is a U monad and that actually update refs are handled.
  void HandleNode(const AnfNodePtr &node) {
    if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
      // Skip nodes other than UpdateState.
      return;
    }
    auto update_state = node->cast<CNodePtr>();
    if (!HasAbstractUMonad(update_state->input(1))) {
      // Skip UpdateStates for IO.
      return;
    }
    auto updated_refs = FindUpdatedRefs(update_state);
    if (updated_refs.empty()) {
      // Skip UpdateStates that do not have updated refs.
      return;
    }
    auto &attach = update_state->input(2);
    if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
      // Handle UpdateState with Load.
      EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
    } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
      // Handle UpdateState with MakeTuple.
      EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
    }
  }

  // Collect the ref (parameter) inputs of the UpdateState's users, excluding
  // pure monad-plumbing users (Load/Depend/UpdateState). These are the refs
  // considered "updated" by side-effect operations chained on this state.
  std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
    std::unordered_set<AnfNodePtr> updated_refs;
    auto &users = manager_->node_users()[update_state];
    for (auto &user : users) {
      auto cnode = dyn_cast<CNode>(user.first);
      if (cnode == nullptr) {
        continue;
      }
      if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
          cnode->IsApply(prim::kPrimUpdateState)) {
        continue;
      }
      for (auto &input : cnode->inputs()) {
        if (IsRef(input)) {
          updated_refs.insert(input);
        }
      }
    }
    return updated_refs;
  }

  // A node is a ref if its abstract is AbstractRef (i.e. a Parameter-like
  // value that can be written in place).
  bool IsRef(const AnfNodePtr &node) {
    auto &abs = node->abstract();
    return abs != nullptr && abs->isa<abstract::AbstractRef>();
  }

  // UpdateState(U, Load) case: if the loaded parameter is among the updated
  // refs, make all order-free users of the Load inputs of the UpdateState.
  void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
                           const std::unordered_set<AnfNodePtr> &refs) {
    if (refs.find(load->input(1)) == refs.end()) {
      // Skip if loaded parameter is not updated.
      return;
    }
    // Find load users, ignore processed nodes.
    auto load_users = FindLoadUsers(load, update_state);
    // Find load users that do not depend on the UpdateState,
    // and then let UpdateState depend on them.
    AddInputEdges(update_state, load_users);
  }

  // UpdateState(U, MakeTuple(Load...)) case: gather users of every Load in
  // the tuple whose parameter is updated, then add them as UpdateState inputs.
  void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
                            const std::unordered_set<AnfNodePtr> &refs) {
    // The UpdateState should be the only one user of the make_tuple.
    // For performance, we only check the number of output edges.
    if (manager_->node_users()[make_tuple].size() != 1) {
      return;
    }
    // Find load users from the tuple of Load nodes.
    std::unordered_set<AnfNodePtr> all_load_users;
    auto &inputs = make_tuple->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      auto &input = inputs.at(i);
      if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
        // Skip non-Load nodes.
        continue;
      }
      auto load = input->cast<CNodePtr>();
      if (refs.find(load->input(1)) == refs.end()) {
        // Skip if loaded parameter is not updated.
        continue;
      }
      auto load_users = FindLoadUsers(load, make_tuple);
      all_load_users.insert(load_users.begin(), load_users.end());
    }
    // Find load users that do not depend on the UpdateState,
    // and then let UpdateState depend on them.
    AddInputEdges(update_state, all_load_users);
  }

  // Add load users as input edges of the update_state node.
  // Users already (transitively) depending on the UpdateState are skipped to
  // avoid creating a cycle; added users are marked processed so they are not
  // attached again for a later UpdateState.
  void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
    auto sorted_load_users = SortLoadUsers(load_users);
    for (auto &load_user : sorted_load_users) {
      if (!IsDependOn(load_user, update_state)) {
        processed_nodes_.insert(load_user);
        manager_->AddEdge(update_state, load_user);
      }
    }
  }

  // Sort load users by their topo sort order, so edges are added deterministically.
  std::vector<AnfNodePtr> SortLoadUsers(const std::unordered_set<AnfNodePtr> &load_users) {
    std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
    std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
    return vec;
  }

  // Check if the load user node depends on the given UpdateState node.
  // BFS backwards over inputs, using the node `seen_` generation to mark
  // visited nodes and pruning any input that sorts before the UpdateState
  // (such an input cannot reach it).
  bool IsDependOn(const AnfNodePtr &load_user, const AnfNodePtr &update_state) {
    size_t update_state_order = topo_sort_map_[update_state];
    if (topo_sort_map_[load_user] < update_state_order) {
      return false;
    }
    auto user_cnode = dyn_cast<CNode>(load_user);
    if (user_cnode == nullptr) {
      return false;
    }
    size_t seen = NewSeenGeneration();
    std::queue<CNodePtr> q;
    user_cnode->seen_ = seen;
    q.push(user_cnode);
    while (!q.empty()) {
      auto cnode = q.front();
      q.pop();
      for (auto &input : cnode->inputs()) {
        if (input == update_state) {
          // Dependency found.
          return true;
        }
        if (input->seen_ == seen) {
          // Skip visited nodes.
          continue;
        }
        if (topo_sort_map_[input] < update_state_order) {
          // Skip input nodes that come before the UpdateState node.
          continue;
        }
        auto input_cnode = dyn_cast<CNode>(input);
        if (input_cnode != nullptr) {
          input_cnode->seen_ = seen;
          q.push(input_cnode);
        }
      }
    }
    return false;
  }

  // True if node1 comes before node2 in topological order.
  bool IsBefore(const AnfNodePtr &node1, const AnfNodePtr &node2) {
    return topo_sort_map_[node1] < topo_sort_map_[node2];
  }

  // Find Load users as the candidate nodes to enforce order of execution.
  // Excludes the given node (the Load's own UpdateState/MakeTuple), already
  // processed nodes, and nodes that carry a U-monad input (side-effect ops,
  // which are already ordered by the monad chain).
  std::unordered_set<AnfNodePtr> FindLoadUsers(const CNodePtr &load, const AnfNodePtr &exclude) {
    auto &node_users = manager_->node_users();
    auto iter = node_users.find(load);
    if (iter == node_users.end()) {
      return {};
    }
    std::unordered_set<AnfNodePtr> load_users;
    auto &users = iter->second;
    for (auto &user : users) {
      auto &user_node = user.first;
      if (user_node == exclude) {
        // Skip excluded node.
        continue;
      }
      if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
        // Skip processed nodes.
        continue;
      }
      auto cnode = dyn_cast<CNode>(user_node);
      MS_EXCEPTION_IF_NULL(cnode);
      auto &inputs = cnode->inputs();
      const bool has_u_input =
        std::any_of(inputs.begin() + 1, inputs.end(), [](const AnfNodePtr &input) { return HasAbstractUMonad(input); });
      if (has_u_input) {
        // Skip nodes with memory side effects, which use u input.
        continue;
      }
      load_users.insert(cnode);
    }
    return load_users;
  }

 private:
  // NOTE(review): func_graph_ is a reference member bound to the ctor's const
  // ref argument — safe here because OrderEnforcer is only used as a local in
  // OrderEnforce(); confirm no longer-lived usage is added.
  const FuncGraphPtr &func_graph_;
  FuncGraphManagerPtr manager_;
  // Node -> topological index of the whole graph, filled by MakeTopoSortMap().
  std::unordered_map<AnfNodePtr, size_t> topo_sort_map_;
  // Load users already attached to some UpdateState; skipped on later passes.
  std::unordered_set<AnfNodePtr> processed_nodes_;
};

} // namespace

//
// Enforce order of execution for Load users node.
//
void OrderEnforce(const FuncGraphPtr &func_graph) {
OrderEnforcer enforcer(func_graph);
enforcer.Run();
}
} // namespace mindspore::pipeline

+ 27
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.h View File

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_

#include "ir/func_graph.h"

namespace mindspore::pipeline {
// Enforce order of execution of the given graph: add dependency edges from
// UpdateState nodes to the users of their attached Load nodes, so a loaded
// value is consumed before the parameter it reads from is updated.
void OrderEnforce(const FuncGraphPtr &func_graph);
}  // namespace mindspore::pipeline

#endif // MINDSPORE_CCSRC_PIPELINE_JIT_ORDER_ENFORCE_H_

+ 8
- 2
tests/st/auto_monad/test_auto_monad.py View File

@@ -1456,7 +1456,10 @@ def test_while_forward():
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_add_assign():
class Net(Cell):
def __init__(self, i1):
@@ -1493,7 +1496,10 @@ def test_multi_add_assign():
np.testing.assert_array_equal(outputs, expects)


@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_abs_add_assign():
class Net(Cell):
def __init__(self, para):


Loading…
Cancel
Save