diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc index 4d6079575c..ba5f235feb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc @@ -78,6 +78,12 @@ AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &gr // when input is also a hccl op and just part outputs of it linking with cur_hccl_op if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) { auto memcpy_async = CreateMemcpyAsyncOp(graph, input); + if (memcpy_async == nullptr) { + MS_LOG(EXCEPTION) << "Create memcpy_async op failed."; + } + if (AnfAlgo::IsNodeDynamicShape(input)) { + AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async); + } auto kernel_info = std::make_shared<device::KernelInfo>(); memcpy_async->set_kernel_info(kernel_info); MS_EXCEPTION_IF_NULL(kernel_select_); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc index bac9f54ace..92186b9fd9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -43,6 +43,9 @@ AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, co if (new_node == nullptr) { MS_LOG(EXCEPTION) << "Create memcpy_async op failed!"; } + if (AnfAlgo::IsNodeDynamicShape(tuple_get_item)) { + AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), new_node); + } AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node); make_tuple_inputs.push_back(new_node); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc 
b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc index 76e301d1f6..9d676bad76 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -158,6 +158,12 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co auto input = hccl_node->input(i); if (NeedInsertMemcpy(graph, input, hccl_node)) { auto memcpy_async = CreateMemcpyAsyncOp(graph, input); + if (memcpy_async == nullptr) { + MS_LOG(EXCEPTION) << "Create memcpy_async op failed."; + } + if (AnfAlgo::IsNodeDynamicShape(input)) { + AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async); + } new_inputs.push_back(memcpy_async); memcpy_async_list.push_back(memcpy_async); } else { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 51ace39176..773dc33084 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -29,6 +29,7 @@ #include "backend/kernel_compiler/kernel_build_info.h" #include "common/trans.h" #include "abstract/param_validator.h" +#include "pipeline/jit/static_analysis/static_analysis.h" namespace mindspore { namespace session { @@ -1279,7 +1280,8 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri } bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { - return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape); + return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) || + GetBooleanAttr(node, kAttrIsDynamicShape); } void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<int> &shape, @@ -1358,5 +1360,36 @@ std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_no 
MS_LOG(EXCEPTION) << "Invalid Shape Type"; } } + +bool CheckDynamic(const NotNull<abstract::ShapePtr> &shape) { + return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; }); +} + +bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto base_shape = node->Shape(); + if (base_shape == nullptr) { + MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope(); + return false; + } + if (base_shape->isa<abstract::Shape>()) { + if (CheckDynamic(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) { + return true; + } + } else if (base_shape->isa<abstract::TupleShape>()) { + auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>(); + MS_EXCEPTION_IF_NULL(tuple_shape); + for (size_t i = 0; i < tuple_shape->size(); i++) { + auto b_shape = (*tuple_shape)[i]; + if (!b_shape->isa<abstract::Shape>()) { + continue; + } + if (CheckDynamic(NOT_NULL(b_shape->cast<abstract::ShapePtr>()))) { + return true; + } + } + } + return false; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 3ffc822353..ca3cee6263 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -229,6 +229,7 @@ class AnfRuntimeAlgorithm { static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index); static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index); static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index); + static bool IsNodeDynamicShape(const AnfNodePtr &node); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 7783009aef..c37194ab2e 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -221,6 +221,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr AbstractBasePtr 
InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + template <typename T> AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 68bab92e40..12fd113b57 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -430,5 +430,15 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv tmp_shape[0] = IntMulWithOverflowCheck(tmp_shape[0], rank_size); return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape)); } + +AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape())); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index eed75d1b44..acde4cfc71 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -127,6 +127,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimBroadcast, {InferImplBroadcast, true}}, {prim::kPrimAllGather, {InferImplAllGather, true}}, {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, + {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/core/base/core_ops.h 
b/mindspore/core/base/core_ops.h index db8fa1ab6f..cf1363d345 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -181,6 +181,7 @@ inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduc inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather"); inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter"); +inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async"); // RowTensor inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");