| @@ -43,8 +43,6 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| todos.push_back(node); | |||
| } | |||
| std::set<string> DynamicShapeConstInputToAttr = { | |||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; | |||
| for (auto &t : todos) { | |||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||
| ConstInputToAttrInfoRegister reg; | |||
| @@ -1636,6 +1636,13 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { | |||
| args_spec_list.emplace_back(real_input->abstract()); | |||
| } | |||
| } | |||
| auto prim_name = primitive->name(); | |||
| if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { | |||
| auto attrs = primitive->attrs(); | |||
| auto new_prim_name = "Dynamic" + prim_name; | |||
| primitive = std::make_shared<Primitive>(new_prim_name); | |||
| primitive->SetAttrs(attrs); | |||
| } | |||
| auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | |||
| auto ret = prim_eval_implement_map.find(primitive); | |||
| if (ret == prim_eval_implement_map.end()) { | |||
| @@ -774,6 +774,12 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, | |||
| MS_EXCEPTION_IF_NULL(py_shape); | |||
| auto py_shape_info = py_shape->ToString(); | |||
| if (py_shape_info.find("-1") != string::npos) { | |||
| if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) { | |||
| auto new_prim_name = "Dynamic" + op_name; | |||
| auto attrs = prim->attrs(); | |||
| prim = std::make_shared<PrimitivePy>(new_prim_name, py::object()); | |||
| prim->SetAttrs(attrs); | |||
| } | |||
| auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(c_abstract); | |||
| auto c_shape = c_abstract->BuildShape(); | |||
| @@ -23,6 +23,7 @@ | |||
| #include "common/trans.h" | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "abstract/dshape.h" | |||
| #include "utils/utils.h" | |||
| #include "abstract/param_validator.h" | |||
| namespace mindspore { | |||
| @@ -123,6 +124,13 @@ void DynamicKernel::InferShape() { | |||
| args_spec_list.emplace_back(real_input->abstract()); | |||
| } | |||
| } | |||
| auto prim_name = primitive->name(); | |||
| if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { | |||
| auto new_prim_name = "Dynamic" + prim_name; | |||
| auto attrs = primitive->attrs(); | |||
| primitive = std::make_shared<Primitive>(new_prim_name); | |||
| primitive->SetAttrs(attrs); | |||
| } | |||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | |||
| cnode_ptr_->set_abstract(eval_result); | |||
| @@ -490,6 +490,9 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH | |||
| const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | |||
| const std::set<std::string> DynamicShapeConstInputToAttr = { | |||
| kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; | |||
| static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | |||
| try { | |||
| if (chmod(file_name.c_str(), mode) != 0) { | |||
| @@ -249,30 +249,30 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -656,9 +656,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name().substr(kDynamic); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto params_shp = params->shape(); | |||
| @@ -752,9 +752,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<AbstractTensor>(input_x->element(), output_shape); | |||
| } | |||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name().substr(kDynamic); | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto input_shp = input->shape()->shape(); | |||
| ValuePtr perm = primitive->GetAttr("perm"); | |||
| @@ -779,9 +779,9 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | |||
| } | |||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name().substr(kDynamic); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| @@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name().substr(kDynamic); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| @@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape())); | |||
| } | |||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name().substr(kDynamic); | |||
| // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize. | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| @@ -491,9 +491,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name().substr(kDynamic); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| @@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||
| {prim::kPrimSub, {InferImplSub, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||
| {prim::kPrimReduceSum, {InferImplReduceSum, true}}, | |||
| {prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | |||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||
| @@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | |||
| @@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}}, | |||
| {prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}}, | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | |||
| @@ -155,8 +155,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||
| {prim::kPrimCast, {InferImplCast, true}}, | |||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||
| {prim::kPrimDynamicCast, {InferImplDynamicCast, true}}, | |||
| {prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| @@ -29,6 +29,8 @@ | |||
| #include "utils/shape_utils.h" | |||
| namespace mindspore { | |||
| // length of string "dynamic" | |||
| const int kDynamic = 7; | |||
| namespace abstract { | |||
| ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); | |||
| TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); | |||
| @@ -80,6 +80,12 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("bro | |||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | |||
| inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | |||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| inline const PrimitivePtr kPrimDynamicCast = std::make_shared<Primitive>("DynamicCast"); | |||
| inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>("DynamicReshape"); | |||
| inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared<Primitive>("DynamicReduceSum"); | |||
| inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared<Primitive>("DynamicTranspose"); | |||
| inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared<Primitive>("DynamicExpandDims"); | |||
| inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared<Primitive>("DynamicEmbeddingLookup"); | |||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||