From: @liubuyu Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @kisnwangtags/v1.2.0-rc1
| @@ -19,9 +19,11 @@ import subprocess | |||||
| import sys | import sys | ||||
| import os | import os | ||||
| import json | import json | ||||
| from mindspore import log as logger | |||||
| from .common import check_kernel_info, TBEException | from .common import check_kernel_info, TBEException | ||||
| from .helper import _op_select_format, _check_supported | from .helper import _op_select_format, _check_supported | ||||
| def create_tbe_parallel_process(): | def create_tbe_parallel_process(): | ||||
| """ | """ | ||||
| create TBEParallelCompiler object | create TBEParallelCompiler object | ||||
| @@ -31,6 +33,7 @@ def create_tbe_parallel_process(): | |||||
| """ | """ | ||||
| return tbe_process | return tbe_process | ||||
| def op_select_format(op_json: str): | def op_select_format(op_json: str): | ||||
| """ | """ | ||||
| call op's op_select_format to get op supported format | call op's op_select_format to get op supported format | ||||
| @@ -75,6 +78,7 @@ def check_supported(op_json: str): | |||||
| return ret | return ret | ||||
| def run_compiler(op_json): | def run_compiler(op_json): | ||||
| """ | """ | ||||
| run compiler to compile op with subprocess | run compiler to compile op with subprocess | ||||
| @@ -96,6 +100,7 @@ def run_compiler(op_json): | |||||
| except subprocess.CalledProcessError as e: | except subprocess.CalledProcessError as e: | ||||
| return "TBEException", "ERROR:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json | return "TBEException", "ERROR:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json | ||||
| class TbeProcess: | class TbeProcess: | ||||
| """tbe process""" | """tbe process""" | ||||
| @@ -106,9 +111,11 @@ class TbeProcess: | |||||
| process_num = os.getenv("MS_BUILD_PROCESS_NUM") | process_num = os.getenv("MS_BUILD_PROCESS_NUM") | ||||
| if process_num is None: | if process_num is None: | ||||
| self.max_processes_num = 24 | self.max_processes_num = 24 | ||||
| logger.info(f"Using default compile process num {self.max_processes_num}") | |||||
| elif process_num.isdigit(): | elif process_num.isdigit(): | ||||
| if int(process_num) in range(1, 25): | if int(process_num) in range(1, 25): | ||||
| self.max_processes_num = int(process_num) | self.max_processes_num = int(process_num) | ||||
| logger.info(f"Using custom compile process num {self.max_processes_num}") | |||||
| else: | else: | ||||
| raise EnvironmentError( | raise EnvironmentError( | ||||
| f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but: {process_num}") | f"Env ERROR, [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but: {process_num}") | ||||
| @@ -177,4 +184,5 @@ class TbeProcess: | |||||
| if self.__running_tasks: | if self.__running_tasks: | ||||
| self.__running_tasks.clear() | self.__running_tasks.clear() | ||||
| tbe_process = TbeProcess() | tbe_process = TbeProcess() | ||||
| @@ -1636,13 +1636,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { | |||||
| args_spec_list.emplace_back(real_input->abstract()); | args_spec_list.emplace_back(real_input->abstract()); | ||||
| } | } | ||||
| } | } | ||||
| auto prim_name = primitive->name(); | |||||
| if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { | |||||
| auto attrs = primitive->attrs(); | |||||
| auto new_prim_name = "Dynamic" + prim_name; | |||||
| primitive = std::make_shared<Primitive>(new_prim_name); | |||||
| primitive->SetAttrs(attrs); | |||||
| } | |||||
| auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | ||||
| auto ret = prim_eval_implement_map.find(primitive); | auto ret = prim_eval_implement_map.find(primitive); | ||||
| if (ret == prim_eval_implement_map.end()) { | if (ret == prim_eval_implement_map.end()) { | ||||
| @@ -771,12 +771,6 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, | |||||
| MS_EXCEPTION_IF_NULL(py_shape); | MS_EXCEPTION_IF_NULL(py_shape); | ||||
| auto py_shape_info = py_shape->ToString(); | auto py_shape_info = py_shape->ToString(); | ||||
| if (py_shape_info.find("-1") != string::npos) { | if (py_shape_info.find("-1") != string::npos) { | ||||
| if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) { | |||||
| auto new_prim_name = "Dynamic" + op_name; | |||||
| auto attrs = prim->attrs(); | |||||
| prim = std::make_shared<PrimitivePy>(new_prim_name, py::object()); | |||||
| prim->SetAttrs(attrs); | |||||
| } | |||||
| auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | ||||
| MS_EXCEPTION_IF_NULL(c_abstract); | MS_EXCEPTION_IF_NULL(c_abstract); | ||||
| auto c_shape = c_abstract->BuildShape(); | auto c_shape = c_abstract->BuildShape(); | ||||
| @@ -124,13 +124,6 @@ void DynamicKernel::InferShape() { | |||||
| args_spec_list.emplace_back(real_input->abstract()); | args_spec_list.emplace_back(real_input->abstract()); | ||||
| } | } | ||||
| } | } | ||||
| auto prim_name = primitive->name(); | |||||
| if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { | |||||
| auto new_prim_name = "Dynamic" + prim_name; | |||||
| auto attrs = primitive->attrs(); | |||||
| primitive = std::make_shared<Primitive>(new_prim_name); | |||||
| primitive->SetAttrs(attrs); | |||||
| } | |||||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | ||||
| cnode_ptr_->set_abstract(eval_result); | cnode_ptr_->set_abstract(eval_result); | ||||
| @@ -251,30 +251,30 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -658,9 +658,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 2); | CheckArgsSize(op_name, args_spec_list, 2); | ||||
| auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| auto params_shp = params->shape(); | auto params_shp = params->shape(); | ||||
| @@ -754,9 +754,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<AbstractTensor>(input_x->element(), output_shape); | return std::make_shared<AbstractTensor>(input_x->element(), output_shape); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string &op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string &op_name = primitive->name(); | |||||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| auto input_shp = input->shape()->shape(); | auto input_shp = input->shape()->shape(); | ||||
| ValuePtr perm = primitive->GetAttr("perm"); | ValuePtr perm = primitive->GetAttr("perm"); | ||||
| @@ -781,9 +781,9 @@ AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const Primi | |||||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| MS_EXCEPTION_IF_NULL(x); | MS_EXCEPTION_IF_NULL(x); | ||||
| MS_EXCEPTION_IF_NULL(x->shape()); | MS_EXCEPTION_IF_NULL(x->shape()); | ||||
| @@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| MS_EXCEPTION_IF_NULL(input_x); | MS_EXCEPTION_IF_NULL(input_x); | ||||
| @@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP | |||||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape())); | return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape())); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize. | // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize. | ||||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| MS_EXCEPTION_IF_NULL(input_x); | MS_EXCEPTION_IF_NULL(input_x); | ||||
| @@ -491,9 +491,9 @@ AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitiveP | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name().substr(kDynamic); | |||||
| AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| MS_EXCEPTION_IF_NULL(x); | MS_EXCEPTION_IF_NULL(x); | ||||
| @@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | ||||
| {prim::kPrimSub, {InferImplSub, true}}, | {prim::kPrimSub, {InferImplSub, true}}, | ||||
| {prim::kPrimEqual, {InferImplEqual, true}}, | {prim::kPrimEqual, {InferImplEqual, true}}, | ||||
| {prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}}, | |||||
| {prim::kPrimReduceSum, {InferImplReduceSum, true}}, | |||||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | {prim::kPrimMinimum, {InferImplMinimum, true}}, | ||||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | ||||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | ||||
| @@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | ||||
| {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | {prim::kPrimGatherV2, {InferImplGatherV2, true}}, | ||||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | ||||
| {prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}}, | |||||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | ||||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | ||||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | ||||
| @@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | ||||
| {prim::kPrimShape, {InferImplShape, false}}, | {prim::kPrimShape, {InferImplShape, false}}, | ||||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | ||||
| {prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}}, | |||||
| {prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}}, | |||||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | ||||
| {prim::kPrimSplit, {InferImplSplit, true}}, | {prim::kPrimSplit, {InferImplSplit, true}}, | ||||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | ||||
| @@ -157,8 +157,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | ||||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | ||||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | ||||
| {prim::kPrimDynamicCast, {InferImplDynamicCast, true}}, | |||||
| {prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}}, | |||||
| {prim::kPrimCast, {InferImplCast, true}}, | |||||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||||
| }; | }; | ||||
| return prim_eval_implement_map; | return prim_eval_implement_map; | ||||
| } | } | ||||
| @@ -29,8 +29,6 @@ | |||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| // length of string "dynamic" | |||||
| const int kDynamic = 7; | |||||
| namespace abstract { | namespace abstract { | ||||
| ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); | ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); | ||||
| TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); | TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); | ||||
| @@ -80,12 +80,6 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("bro | |||||
| inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map"); | ||||
| inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce"); | ||||
| inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | ||||
| inline const PrimitivePtr kPrimDynamicCast = std::make_shared<Primitive>("DynamicCast"); | |||||
| inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>("DynamicReshape"); | |||||
| inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared<Primitive>("DynamicReduceSum"); | |||||
| inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared<Primitive>("DynamicTranspose"); | |||||
| inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared<Primitive>("DynamicExpandDims"); | |||||
| inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared<Primitive>("DynamicEmbeddingLookup"); | |||||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | ||||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | ||||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | ||||