Browse Source

!11075 dynamic op re primitive when infer

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
92a85d1061
12 changed files with 68 additions and 38 deletions
  1. +0
    -2
      mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc
  2. +7
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  3. +6
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  4. +8
    -0
      mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc
  5. +3
    -0
      mindspore/ccsrc/utils/utils.h
  6. +12
    -12
      mindspore/core/abstract/infer_functions.h
  7. +9
    -9
      mindspore/core/abstract/prim_arrays.cc
  8. +3
    -3
      mindspore/core/abstract/prim_maths.cc
  9. +6
    -6
      mindspore/core/abstract/prim_others.cc
  10. +6
    -6
      mindspore/core/abstract/primitive_infer_map.cc
  11. +2
    -0
      mindspore/core/abstract/utils.h
  12. +6
    -0
      mindspore/core/base/core_ops.h

+ 0
- 2
mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc View File

@@ -43,8 +43,6 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
todos.push_back(node);
}

std::set<string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
ConstInputToAttrInfoRegister reg;


+ 7
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1636,6 +1636,13 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto prim_name = primitive->name();
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
auto attrs = primitive->attrs();
auto new_prim_name = "Dynamic" + prim_name;
primitive = std::make_shared<Primitive>(new_prim_name);
primitive->SetAttrs(attrs);
}
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(primitive);
if (ret == prim_eval_implement_map.end()) {


+ 6
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -777,6 +777,12 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(py_shape);
auto py_shape_info = py_shape->ToString();
if (py_shape_info.find("-1") != string::npos) {
if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) {
auto new_prim_name = "Dynamic" + op_name;
auto attrs = prim->attrs();
prim = std::make_shared<PrimitivePy>(new_prim_name, py::object());
prim->SetAttrs(attrs);
}
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
MS_EXCEPTION_IF_NULL(c_abstract);
auto c_shape = c_abstract->BuildShape();


+ 8
- 0
mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc View File

@@ -23,6 +23,7 @@
#include "common/trans.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/dshape.h"
#include "utils/utils.h"
#include "abstract/param_validator.h"

namespace mindspore {
@@ -123,6 +124,13 @@ void DynamicKernel::InferShape() {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto prim_name = primitive->name();
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
auto new_prim_name = "Dynamic" + prim_name;
auto attrs = primitive->attrs();
primitive = std::make_shared<Primitive>(new_prim_name);
primitive->SetAttrs(attrs);
}

auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
cnode_ptr_->set_abstract(eval_result);


+ 3
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -491,6 +491,9 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH

const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};

const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};

static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
try {
if (chmod(file_name.c_str(), mode) != 0) {


+ 12
- 12
mindspore/core/abstract/infer_functions.h View File

@@ -251,30 +251,30 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 9
- 9
mindspore/core/abstract/prim_arrays.cc View File

@@ -658,9 +658,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
}
}

AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 2);
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto params_shp = params->shape();
@@ -754,9 +754,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(input_x->element(), output_shape);
}

AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name().substr(kDynamic);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto input_shp = input->shape()->shape();
ValuePtr perm = primitive->GetAttr("perm");
@@ -781,9 +781,9 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
}

AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name().substr(kDynamic);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());


+ 3
- 3
mindspore/core/abstract/prim_maths.cc View File

@@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
return ret;
}

AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 1);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);


+ 6
- 6
mindspore/core/abstract/prim_others.cc View File

@@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
}

AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name().substr(kDynamic);
// GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
@@ -491,9 +491,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
return ret;
}

AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 1);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);


+ 6
- 6
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimReduceSum, {InferImplReduceSum, true}},
{prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}},
{prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
@@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
@@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}},
{prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
@@ -157,8 +157,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
{prim::kPrimCast, {InferImplCast, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimDynamicCast, {InferImplDynamicCast, true}},
{prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}},
};
return prim_eval_implement_map;
}


+ 2
- 0
mindspore/core/abstract/utils.h View File

@@ -29,6 +29,8 @@
#include "utils/shape_utils.h"

namespace mindspore {
// length of string "dynamic"
const int kDynamic = 7;
namespace abstract {
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);


+ 6
- 0
mindspore/core/base/core_ops.h View File

@@ -80,6 +80,12 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("bro
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimDynamicCast = std::make_shared<Primitive>("DynamicCast");
inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>("DynamicReshape");
inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared<Primitive>("DynamicReduceSum");
inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared<Primitive>("DynamicTranspose");
inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared<Primitive>("DynamicExpandDims");
inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared<Primitive>("DynamicEmbeddingLookup");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");


Loading…
Cancel
Save