
memcpy_async infershape

tags/v1.1.0
liubuyu 5 years ago
commit 0b79a94e22
9 changed files with 65 additions and 1 deletion
1. +6  -0   mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
2. +3  -0   mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc
3. +6  -0   mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
4. +34 -1   mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
5. +1  -0   mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
6. +3  -0   mindspore/core/abstract/infer_functions.h
7. +10 -0   mindspore/core/abstract/prim_others.cc
8. +1  -0   mindspore/core/abstract/primitive_infer_map.cc
9. +1  -0   mindspore/core/base/core_ops.h

+6 -0  mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc

@@ -78,6 +78,12 @@ AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &gr
     // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
     if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       auto kernel_info = std::make_shared<device::KernelInfo>();
       memcpy_async->set_kernel_info(kernel_info);
       MS_EXCEPTION_IF_NULL(kernel_select_);
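
Note: the null-check plus dynamic-shape tagging added here reappears verbatim in the insert_memcpy_async_for_hccl_op.cc hunk below, and in condensed form in the getnext pass. A minimal sketch of a helper that would consolidate the pattern, built only from the utilities visible in this diff; the name CreateTaggedMemcpyAsync is hypothetical, not part of this commit:

    // Hypothetical helper consolidating the pattern added in this commit:
    // build a memcpy_async node for `input` and mark it dynamic-shape when
    // the tensor it copies has an unknown dimension.
    AnfNodePtr CreateTaggedMemcpyAsync(const FuncGraphPtr &graph, const AnfNodePtr &input) {
      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
      if (memcpy_async == nullptr) {
        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
      }
      if (AnfAlgo::IsNodeDynamicShape(input)) {
        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
      }
      return memcpy_async;
    }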


+3 -0  mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc

@@ -43,6 +43,9 @@ AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, co
     if (new_node == nullptr) {
       MS_LOG(EXCEPTION) << "Create memcpy_async op failed!";
     }
+    if (AnfAlgo::IsNodeDynamicShape(tuple_get_item)) {
+      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), new_node);
+    }
     AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node);
     make_tuple_inputs.push_back(new_node);
   }


+6 -0  mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc

@@ -158,6 +158,12 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     auto input = hccl_node->input(i);
    if (NeedInsertMemcpy(graph, input, hccl_node)) {
       auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      if (memcpy_async == nullptr) {
+        MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
+      }
+      if (AnfAlgo::IsNodeDynamicShape(input)) {
+        AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
+      }
       new_inputs.push_back(memcpy_async);
       memcpy_async_list.push_back(memcpy_async);
     } else {


+34 -1  mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

@@ -29,6 +29,7 @@
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "common/trans.h"
 #include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"

 namespace mindspore {
 namespace session {

@@ -1279,7 +1280,8 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
 }

 bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
-  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
+  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
+         GetBooleanAttr(node, kAttrIsDynamicShape);
 }

 void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,

@@ -1358,5 +1360,36 @@ std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_no
     MS_LOG(EXCEPTION) << "Invalid Shape Type";
   }
 }
+
+bool CheckDynamic(const NotNull<abstract::ShapePtr> &shape) {
+  return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
+}
+
+bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto base_shape = node->Shape();
+  if (base_shape == nullptr) {
+    MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
+    return false;
+  }
+  if (base_shape->isa<abstract::Shape>()) {
+    if (CheckDynamic(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
+      return true;
+    }
+  } else if (base_shape->isa<abstract::TupleShape>()) {
+    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_shape);
+    for (size_t i = 0; i < tuple_shape->size(); i++) {
+      auto b_shape = (*tuple_shape)[i];
+      if (!b_shape->isa<abstract::Shape>()) {
+        continue;
+      }
+      if (CheckDynamic(NOT_NULL(b_shape->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
 }  // namespace session
 }  // namespace mindspore
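
Two related changes land in this file. IsDynamicShape(), which reads attributes from an already-tagged node, now also honors the kAttrIsDynamicShape attribute that the three enhancer passes above set on new memcpy_async nodes. IsNodeDynamicShape() is new and instead inspects the node's abstract shape: CheckDynamic() reports a shape as dynamic as soon as any dimension is not strictly positive, since MindSpore encodes an unknown dimension as -1. A standalone, runnable restatement of that predicate over a plain vector (IsDynamicShapeVec is our illustrative name, not MindSpore API):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Same predicate as CheckDynamic above, over a plain std::vector<int>.
    bool IsDynamicShapeVec(const std::vector<int> &shape) {
      return !std::all_of(shape.begin(), shape.end(), [](int s) { return s > 0; });
    }

    int main() {
      assert(!IsDynamicShapeVec({32, 224, 224, 3}));  // fully static: all dims > 0
      assert(IsDynamicShapeVec({-1, 224, 224, 3}));   // unknown batch dim -> dynamic
      assert(!IsDynamicShapeVec({}));                 // empty (scalar) shape: all_of is
                                                      // vacuously true, so not dynamic
      return 0;
    }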

+1 -0  mindspore/ccsrc/backend/session/anf_runtime_algorithm.h

@@ -229,6 +229,7 @@ class AnfRuntimeAlgorithm {
   static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
   static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
+  static bool IsNodeDynamicShape(const AnfNodePtr &node);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;


+3 -0  mindspore/core/abstract/infer_functions.h

@@ -221,6 +221,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);

+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.


+10 -0  mindspore/core/abstract/prim_others.cc

@@ -430,5 +430,15 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
   tmp_shape[0] = IntMulWithOverflowCheck(tmp_shape[0], rank_size);
   return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
 }
+
+AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
+}
 }  // namespace abstract
 }  // namespace mindspore
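
Shape inference for memcpy_async is a pure identity: the function checks that exactly one AbstractTensor argument is present, then returns a fresh AbstractTensor carrying the same element type and the same (possibly dynamic, i.e. containing -1) shape. A hedged usage sketch, assuming x_abstract is an AbstractTensor built elsewhere and that passing a null engine is acceptable since that parameter is unnamed and unused:

    // Sketch only: exercises InferImplMemCpyAsync directly. x_abstract is a
    // hypothetical AbstractTensorPtr, e.g. float32 with shape (-1, 128).
    auto out = abstract::InferImplMemCpyAsync(nullptr, prim::kPrimMemCpyAsync, {x_abstract});
    // The result mirrors the input: same element type, same (possibly dynamic) shape.
    MS_EXCEPTION_IF_NULL(out);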

+1 -0  mindspore/core/abstract/primitive_infer_map.cc

@@ -127,6 +127,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimBroadcast, {InferImplBroadcast, true}},
     {prim::kPrimAllGather, {InferImplAllGather, true}},
     {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
+    {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
   };
   return prim_eval_implement_map;
 }
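
Registering the pair lets the frontend's static analyzer resolve memcpy_async outputs by table lookup instead of failing on an unknown primitive. A sketch of the consuming side; the impl_ field name is an assumption inferred from the {InferImplMemCpyAsync, true} initializer, not verified against this codebase:

    // Hedged sketch: how a registered infer function would be looked up and called.
    auto &impl_map = abstract::GetPrimitiveToEvalImplMap();
    auto iter = impl_map.find(prim::kPrimMemCpyAsync);
    if (iter != impl_map.end()) {
      // iter->second.impl_ (assumed field name) is InferImplMemCpyAsync.
      auto out = iter->second.impl_(nullptr, prim::kPrimMemCpyAsync, args_spec_list);
    }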


+1 -0  mindspore/core/base/core_ops.h

@@ -181,6 +181,7 @@ inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduc
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
+inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async");

 // RowTensor
 inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
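
Note the name string: "memcpy_async" is lowercase with an underscore, unlike the CamelCase collective primitives above it, presumably so it matches the op name the backend enhancer passes already emit. A minimal check sketch using the core IsPrimitiveCNode helper:

    // True exactly for the memcpy_async nodes created by the enhancer passes.
    if (IsPrimitiveCNode(node, prim::kPrimMemCpyAsync)) {
      // handle the copy node, e.g. read its kAttrIsDynamicShape attribute
    }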

