@@ -78,6 +78,12 @@ AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &gr
     // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
     if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       auto kernel_info = std::make_shared<device::KernelInfo>();
       memcpy_async->set_kernel_info(kernel_info);
       MS_EXCEPTION_IF_NULL(kernel_select_);
@@ -43,6 +43,9 @@ AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, co
     if (new_node == nullptr) {
       MS_LOG(EXCEPTION) << "Create memcpy_async op failed!";
     }
+    if (AnfAlgo::IsNodeDynamicShape(tuple_get_item)) {
+      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), new_node);
+    }
     AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node);
     make_tuple_inputs.push_back(new_node);
   }
@@ -158,6 +158,12 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     auto input = hccl_node->input(i);
     if (NeedInsertMemcpy(graph, input, hccl_node)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       new_inputs.push_back(memcpy_async);
       memcpy_async_list.push_back(memcpy_async);
     } else {
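// Review note (not part of the patch): the three passes above share one pattern —
// guard the freshly created memcpy_async node against nullptr, then stamp it with
// kAttrIsDynamicShape when the input it copies has a dynamic shape, so later phases
// treat the inserted copy as dynamic too. A minimal self-contained sketch of that
// pattern, with the framework node type replaced by a plain struct:
#include <memory>
#include <stdexcept>

struct Node {
  bool is_dynamic = false;    // stands in for AnfAlgo::IsNodeDynamicShape(input)
  bool dynamic_attr = false;  // stands in for the kAttrIsDynamicShape attribute
};

std::shared_ptr<Node> InsertCopyFor(const std::shared_ptr<Node> &input) {
  auto memcpy_async = std::make_shared<Node>();  // stands in for CreateMemcpyAsyncOp
  if (memcpy_async == nullptr) {  // dead here (make_shared never fails), but the real factory can
    throw std::runtime_error("Create memcpy_async op failed.");
  }
  if (input->is_dynamic) {
    memcpy_async->dynamic_attr = true;  // propagate the dynamic-shape flag to the copy
  }
  return memcpy_async;
}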
@@ -29,6 +29,7 @@
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "common/trans.h"
 #include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
 
 namespace mindspore {
 namespace session {
@@ -1279,7 +1280,8 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
 }
 
 bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
-  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
+  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
+         GetBooleanAttr(node, kAttrIsDynamicShape);
 }
 
 void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
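// Review note (not part of the patch): IsDynamicShape is now an OR over three
// boolean attributes, and the new kAttrIsDynamicShape term is exactly the flag the
// memcpy passes above set. A self-contained sketch of the widened predicate (the
// attribute key strings below are illustrative placeholders, not the real values):
#include <map>
#include <string>

using AttrMap = std::map<std::string, bool>;

bool GetBooleanAttrSketch(const AttrMap &attrs, const std::string &key) {
  auto it = attrs.find(key);
  return it != attrs.end() && it->second;  // absent attribute reads as false
}

bool IsDynamicShapeSketch(const AttrMap &attrs) {
  return GetBooleanAttrSketch(attrs, "input_is_dynamic_shape") ||
         GetBooleanAttrSketch(attrs, "output_is_dynamic_shape") ||
         GetBooleanAttrSketch(attrs, "is_dynamic_shape");  // the attribute added by this patch
}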
@@ -1358,5 +1360,36 @@ std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_no
     MS_LOG(EXCEPTION) << "Invalid Shape Type";
   }
 }
+
+bool CheckDynamic(const NotNull<abstract::ShapePtr> &shape) {
+  return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
+}
+
+bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto base_shape = node->Shape();
+  if (base_shape == nullptr) {
+    MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
+    return false;
+  }
+  if (base_shape->isa<abstract::Shape>()) {
+    if (CheckDynamic(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
+      return true;
+    }
+  } else if (base_shape->isa<abstract::TupleShape>()) {
+    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_shape);
+    for (size_t i = 0; i < tuple_shape->size(); i++) {
+      auto b_shape = (*tuple_shape)[i];
+      if (!b_shape->isa<abstract::Shape>()) {
+        continue;
+      }
+      if (CheckDynamic(NOT_NULL(b_shape->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
 } // namespace session
 } // namespace mindspore
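// Review note (not part of the patch): CheckDynamic declares a shape dynamic when
// any dimension is not strictly positive (unknown dimensions are conventionally
// encoded as -1). A runnable sketch of the same rule over plain vectors:
#include <algorithm>
#include <cassert>
#include <vector>

bool CheckDynamicSketch(const std::vector<int> &shape) {
  return !std::all_of(shape.begin(), shape.end(), [](int s) { return s > 0; });
}

int main() {
  assert(!CheckDynamicSketch({2, 3, 4}));  // fully static: every dim > 0
  assert(CheckDynamicSketch({2, -1, 4}));  // -1 marks an unknown dimension
  assert(!CheckDynamicSketch({}));         // empty (scalar) shape counts as static
  return 0;
}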
@@ -229,6 +229,7 @@ class AnfRuntimeAlgorithm {
   static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
+  static bool IsNodeDynamicShape(const AnfNodePtr &node);
 };
 } // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -221,6 +221,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
 
+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
@@ -430,5 +430,15 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
   tmp_shape[0] = IntMulWithOverflowCheck(tmp_shape[0], rank_size);
   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
 }
+
+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
+}
 } // namespace abstract
 } // namespace mindspore
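// Review note (not part of the patch): the new infer rule makes memcpy_async an
// identity op at the abstract level — one tensor in, same element type and shape
// out — which is what lets dynamic dimensions flow through the inserted copy.
// Self-contained sketch with a plain struct standing in for AbstractTensor:
#include <stdexcept>
#include <string>
#include <vector>

struct AbstractTensorSketch {
  std::string element_type;
  std::vector<int> shape;  // -1 entries denote dynamic dimensions
};

AbstractTensorSketch InferMemCpyAsyncSketch(const std::vector<AbstractTensorSketch> &args) {
  if (args.size() != 1) {  // mirrors CheckArgsSize(op_name, args_spec_list, 1)
    throw std::runtime_error("memcpy_async expects exactly one input");
  }
  return args[0];  // identity: element type and shape pass through unchanged
}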
@@ -127,6 +127,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimBroadcast, {InferImplBroadcast, true}},
     {prim::kPrimAllGather, {InferImplAllGather, true}},
     {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
+    {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
   };
   return prim_eval_implement_map;
 }
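// Review note (not part of the patch): this registration is what routes abstract
// inference for kPrimMemCpyAsync to InferImplMemCpyAsync during static analysis.
// Sketch of the map's shape; the meaning of the boolean flag is an assumption here
// (taken to be a white-list marker, matching the neighboring entries):
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct AbstractSketch { std::vector<int> shape; };
using InferFuncSketch = std::function<AbstractSketch(const std::vector<AbstractSketch> &)>;

std::map<std::string, std::pair<InferFuncSketch, bool>> &GetEvalImplMapSketch() {
  static std::map<std::string, std::pair<InferFuncSketch, bool>> impl_map = {
      // memcpy_async: identity inference, flagged like the other entries
      {"memcpy_async",
       {[](const std::vector<AbstractSketch> &args) { return args.at(0); }, true}},
  };
  return impl_map;
}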
@@ -181,6 +181,7 @@ inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduc
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
+inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async");
 
 // RowTensor
 inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");