Browse Source

do not merge tensor move to one in cse pass

tags/v1.1.0
LianLiguang 5 years ago
parent
commit
414f38df8d
2 changed files with 7 additions and 0 deletions
  1. +6
    -0
      mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc
  2. +1
    -0
      mindspore/core/base/core_ops.h

+ 6
- 0
mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc View File

@@ -37,6 +37,12 @@ bool HasSideEffectAttr(const AnfNodePtr &node) {
bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
if (main->isa<CNode>()) {
auto main_name = AnfAlgo::GetCNodeName(main);
if (main_name == prim::kPrimTensorMove->name() || main_name == prim::kPrimMemCpyAsync->name()) {
return false;
}
}
auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (main_kernel_info == nullptr && node_kernel_info == nullptr) {


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -200,6 +200,7 @@ inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAda
inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");

// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");


Loading…
Cancel
Save