From: @lianliguang Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -366,7 +366,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| target_link_libraries(mindspore mindspore::pybind11_module) | |||
| target_link_libraries(mindspore mindspore_gvar) | |||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) | |||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) | |||
| else() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore proto_input mindspore::protobuf | |||
| @@ -376,7 +376,8 @@ else() | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| endif() | |||
| endif() | |||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive) | |||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core | |||
| proto_input -Wl,--no-whole-archive) | |||
| target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) | |||
| target_link_libraries(_c_expression PRIVATE mindspore_gvar) | |||
| if(ENABLE_D) | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/pass/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include <utility> | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| @@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver { | |||
| ::mindspore::opt::ConstInputToAttrInfoRegister(op_name) | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_ | |||
| @@ -31,6 +31,8 @@ | |||
| #include "utils/ms_utils.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -700,6 +702,92 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive | |||
| } | |||
| return CreateCNodeWithGraph(input_nodes, graph); | |||
| } | |||
| // rectify absttract if the input has been converted to the attr | |||
| AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &input_abstract) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| opt::ConstInputToAttrInfoRegister reg; | |||
| if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), ®)) { | |||
| return input_abstract; | |||
| } | |||
| if (AnfAlgo::HasDynamicShapeFlag(primitive) || | |||
| DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) { | |||
| return input_abstract; | |||
| } | |||
| auto convert_input_list = reg.GetConstInputAttrInfo(); | |||
| auto input_names = primitive->GetAttr(kAttrInputNames); | |||
| if (input_names == nullptr) { | |||
| return input_abstract; | |||
| } | |||
| auto input_names_vec = GetValue<std::vector<std::string>>(input_names); | |||
| AbstractBasePtrList rectify_abs_list; | |||
| size_t ori_index = 0; | |||
| rectify_abs_list.resize(input_names_vec.size()); | |||
| for (size_t index = 0; index < rectify_abs_list.size(); ++index) { | |||
| // if convert input list find the index it means the input has been converted to the attr | |||
| if (convert_input_list.find(index) != convert_input_list.end()) { | |||
| AbstractBasePtr rectify_abs = nullptr; | |||
| auto input_name = input_names_vec[index]; | |||
| auto attr = primitive->GetAttr(input_name); | |||
| if (attr != nullptr) { | |||
| rectify_abs = attr->ToAbstract(); | |||
| } else { | |||
| MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index | |||
| << " input name :" << input_name << "has not been converted to the attr"; | |||
| rectify_abs = input_abstract[ori_index++]; | |||
| } | |||
| rectify_abs_list[index] = rectify_abs; | |||
| continue; | |||
| } | |||
| if (ori_index > input_abstract.size()) { | |||
| MS_LOG(EXCEPTION) << "index is out of range input abstract size " << input_abstract.size() | |||
| << " get index :" << ori_index; | |||
| } | |||
| rectify_abs_list[index] = input_abstract[ori_index++]; | |||
| } | |||
| return rectify_abs_list; | |||
| } | |||
| AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &input_abstract) { | |||
| auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes); | |||
| if (dynamic_inputs_list == nullptr) { | |||
| return input_abstract; | |||
| } | |||
| AbstractBasePtrList rectifyed_abs_list; | |||
| const int kNotDynamicFlag = -1; | |||
| auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list); | |||
| size_t input_index = 0; | |||
| for (auto item : dynamic_inputs_index) { | |||
| if (item == kNotDynamicFlag) { | |||
| if (input_index >= input_abstract.size()) { | |||
| MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " << input_abstract.size(); | |||
| } | |||
| rectifyed_abs_list.emplace_back(input_abstract[input_index++]); | |||
| } else { | |||
| if (item < 0) { | |||
| MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got " | |||
| << item; | |||
| } | |||
| AbstractBasePtrList dynamic_inputs_abs; | |||
| for (auto index = item; index > 0; --index) { | |||
| if (input_index >= input_abstract.size()) { | |||
| MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " | |||
| << input_abstract.size(); | |||
| } | |||
| dynamic_inputs_abs.emplace_back(input_abstract[input_index++]); | |||
| } | |||
| rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs)); | |||
| } | |||
| } | |||
| return rectifyed_abs_list; | |||
| } | |||
| AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) { | |||
| auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract); | |||
| return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list); | |||
| } | |||
| } // namespace | |||
| AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { | |||
| @@ -835,5 +923,24 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C | |||
| } | |||
| } | |||
| } | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | |||
| auto ret = prim_eval_implement_map.find(prim); | |||
| if (ret != prim_eval_implement_map.end()) { | |||
| // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr | |||
| auto infer_spec_list = RectifyAbstract(prim, args_spec_list); | |||
| return ret->second.impl_(nullptr, prim, infer_spec_list); | |||
| } else { | |||
| // if the infer function has been not founded in the front infer map find it in the backend infer map instead | |||
| auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); | |||
| auto ret_backend = prim_backend_eval_impl_map.find(prim); | |||
| if (ret_backend != prim_backend_eval_impl_map.end()) { | |||
| return ret_backend->second.impl_(nullptr, prim, args_spec_list); | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | |||
| << " primitive type:" << prim->type_name(); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -212,6 +212,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | |||
| // Transfer depend or control_depend to the new node | |||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ | |||
| @@ -27,7 +27,7 @@ | |||
| #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/pass/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "backend/optimizer/pass/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -1534,6 +1534,18 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri | |||
| return AnfAlgo::GetNodeAttr<bool>(node, attr); | |||
| } | |||
| bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) { | |||
| auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| if (!primitive->HasAttr(attr_name)) { | |||
| return false; | |||
| } | |||
| return GetValue<bool>(primitive->GetAttr(attr_name)); | |||
| }; | |||
| return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) || | |||
| get_bool_attr(prim, kAttrIsDynamicShape); | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { | |||
| return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) || | |||
| GetBooleanAttr(node, kAttrIsDynamicShape); | |||
| @@ -1805,7 +1817,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { | |||
| args_spec_list.emplace_back(real_input->abstract()); | |||
| } | |||
| } | |||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | |||
| auto eval_result = opt::CppInferShape(primitive, args_spec_list); | |||
| node->set_abstract(eval_result); | |||
| } | |||
| } // namespace session | |||
| @@ -230,6 +230,7 @@ class AnfRuntimeAlgorithm { | |||
| // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | |||
| static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); | |||
| static bool IsDynamicShape(const AnfNodePtr &node); | |||
| static bool HasDynamicShapeFlag(const PrimitivePtr &prim); | |||
| static bool IsCondControlKernel(const CNodePtr &node); | |||
| static bool IsIndependentNode(const CNodePtr &node); | |||
| static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); | |||
| @@ -1311,15 +1311,6 @@ bool IsInWhiteList(const PrimitivePtr &primitive) { | |||
| return false; | |||
| } | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto iter = GetPrimitiveToEvalImplMap().find(primitive); | |||
| if (iter == GetPrimitiveToEvalImplMap().end()) { | |||
| return nullptr; | |||
| } | |||
| return iter->second.impl_; | |||
| } | |||
| PrimEvaluatorMap &GetPrimEvaluatorConstructors() { | |||
| PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; | |||
| if (!constructor.empty()) { | |||
| @@ -112,7 +112,6 @@ class MixedPrecisionCastEvaluator : public Evaluator { | |||
| }; | |||
| bool IsInWhiteList(const PrimitivePtr &primitive); | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| using ValuePtrList = std::vector<ValuePtr>; | |||
| using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); | |||
| @@ -357,6 +357,13 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| return std::make_shared<MixedPrecisionCastEvaluator>(prim); | |||
| } | |||
| // find prim infer function in the prim function map return a standard evaluator | |||
| StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl != nullptr) { | |||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | |||
| } | |||
| // use python infer function if the infer function not founded in the map return a python evaluator | |||
| EvaluatorPtr evaluator = nullptr; | |||
| if (prim->HasPyEvaluator()) { | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim); | |||
| @@ -376,17 +383,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr | |||
| MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; | |||
| } | |||
| if (prim->isa<PrimitivePy>() || prim->HasAttr()) { | |||
| if (engine == nullptr) { | |||
| (void)GetPrimEvaluatorConstructors(); | |||
| } | |||
| // If a primitive may have attr, try to create a new evaluator. | |||
| StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); | |||
| if (eval_impl != nullptr) { | |||
| return std::make_shared<StandardPrimEvaluator>(prim, eval_impl); | |||
| } | |||
| } | |||
| // return a default evaluator | |||
| if (engine == nullptr) { | |||
| // If engine is nullptr, get constructor from default. | |||
| const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); | |||
| @@ -778,16 +775,5 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi | |||
| auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); | |||
| return eval_result; | |||
| } | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap(); | |||
| auto ret = prim_eval_implement_map.find(prim); | |||
| if (ret == prim_eval_implement_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() | |||
| << " primitive type:" << prim->type_name(); | |||
| } | |||
| return ret->second.impl_(nullptr, prim, args_spec_list); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -331,8 +331,6 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { | |||
| } | |||
| EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); | |||
| AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -44,7 +44,7 @@ | |||
| #include "pipeline/jit/static_analysis/prim.h" | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "backend/optimizer/pass/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "pipeline/jit/action.h" | |||
| @@ -807,21 +807,13 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, | |||
| } | |||
| } | |||
| // get output dynamic shape info | |||
| auto py_abstract = op_exec_info->abstract; | |||
| MS_EXCEPTION_IF_NULL(py_abstract); | |||
| auto py_shape = py_abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(py_shape); | |||
| auto py_shape_info = py_shape->ToString(); | |||
| if (py_shape_info.find("-1") != string::npos) { | |||
| auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(c_abstract); | |||
| auto c_shape = c_abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(c_shape); | |||
| auto c_shape_info = c_shape->ToString(); | |||
| MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info; | |||
| if (c_shape_info.find("-1") != string::npos) { | |||
| op_exec_info->is_dynamic_shape = true; | |||
| } | |||
| auto abstract = op_exec_info->abstract; | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| auto shape = abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| auto shape_info = shape->ToString(); | |||
| if (shape_info.find("-1") != string::npos) { | |||
| op_exec_info->is_dynamic_shape = true; | |||
| } | |||
| } | |||
| @@ -123,7 +123,7 @@ void DynamicKernel::InferShape() { | |||
| } | |||
| } | |||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | |||
| auto eval_result = opt::CppInferShape(primitive, args_spec_list); | |||
| cnode_ptr_->set_abstract(eval_result); | |||
| } | |||
| @@ -1041,6 +1041,9 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| if (args_spec_list.size() == 1) { | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| @@ -292,24 +292,6 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi | |||
| return std::make_shared<AbstractTuple>(rets); | |||
| } | |||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[3]); | |||
| CheckArgsSize(primitive->name(), args_spec_list, 5); | |||
| auto dx = args_spec_list[1]->Broaden(); | |||
| auto dscale = args_spec_list[2]->Broaden(); | |||
| auto dbias = args_spec_list[3]->Broaden(); | |||
| auto reserve_1 = args_spec_list[4]->Broaden(); | |||
| auto reserve_2 = args_spec_list[5]->Broaden(); | |||
| AbstractBasePtrList rets = {dx, dscale, dbias, reserve_1, reserve_2}; | |||
| return std::make_shared<AbstractTuple>(rets); | |||
| } | |||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: two tensors(y_backprop, x). | |||
| @@ -468,20 +450,6 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| return std::make_shared<AbstractTensor>(x_type, output_shape_ptr); | |||
| } | |||
| AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three tensors(doutput, input, filters). | |||
| CheckRequiredArgsSize(primitive->name(), args_spec_list, 3); | |||
| return args_spec_list[1]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: three tensors(inputs, filter, doutput). | |||
| CheckArgsSize(primitive->name(), args_spec_list, 3); | |||
| return args_spec_list[2]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -17,6 +17,11 @@ | |||
| */ | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/infer_functions.h" | |||
| @@ -59,40 +64,21 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimNotInDict, {InferImplNotInDict, true}}, | |||
| {prim::kPrimIsConsant, {InferImplIsConstant, true}}, | |||
| // Maths | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMul, {InferImplMul, true}}, | |||
| {prim::kPrimAdd, {InferImplAdd, true}}, | |||
| {prim::kPrimSquare, {InferImplSquare, true}}, | |||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | |||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||
| {prim::kPrimSub, {InferImplSub, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||
| {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | |||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||
| {prim::kPrimAddN, {InferImplAddN, true}}, | |||
| {prim::kPrimMatMul, {InferImplMatMul, true}}, | |||
| {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, | |||
| {prim::kPrimLess, {InferImplLess, true}}, | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | |||
| // Array | |||
| {prim::kPrimRange, {InferImplRange, true}}, | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||
| {prim::kPrimStack, {InferImplStack, true}}, | |||
| {prim::kPrimPad, {InferImplPad, true}}, | |||
| {prim::kPrimUnique, {InferImplUnique, true}}, | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| {prim::kPrimGather, {InferImplGatherV2, true}}, | |||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||
| @@ -104,18 +90,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | |||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, | |||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, | |||
| {prim::kPrimDiv, {InferImplDiv, true}}, | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | |||
| {prim::kPrimConcat, {InferImplConcat, true}}, | |||
| {prim::kPrimRange, {InferImplRange, true}}, | |||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| @@ -139,14 +117,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | |||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | |||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | |||
| {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}}, | |||
| {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | |||
| {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, | |||
| {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, | |||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||
| {prim::kPrimRelu, {InferImplRelu, true}}, | |||
| @@ -192,18 +166,60 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||
| {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, | |||
| // Comm Ops | |||
| {prim::kPrimAllReduce, {InferImplAllReduce, true}}, | |||
| {prim::kPrimBroadcast, {InferImplBroadcast, true}}, | |||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | |||
| static PrimitiveEvalImplMap prim_backend_eval_implement_map = { | |||
| {prim::kPrimMul, {InferImplMul, true}}, | |||
| {prim::kPrimAdd, {InferImplAdd, true}}, | |||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||
| {prim::kPrimSub, {InferImplSub, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||
| {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimCast, {InferImplCast, true}}, | |||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||
| {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}}, | |||
| {prim::kPrimDType, {InferImplDType, true}}, | |||
| {prim::kPrimAllReduce, {InferImplAllReduce, true}}, | |||
| {prim::kPrimBroadcast, {InferImplBroadcast, true}}, | |||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | |||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||
| {prim::kPrimAddN, {InferImplAddN, true}}, | |||
| {prim::kPrimLess, {InferImplLess, true}}, | |||
| {prim::kPrimStack, {InferImplStack, true}}, | |||
| {prim::kPrimPad, {InferImplPad, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimDiv, {InferImplDiv, true}}, | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimConcat, {InferImplConcat, true}}, | |||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, | |||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| return prim_backend_eval_implement_map; | |||
| } | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto iter = GetPrimitiveToEvalImplMap().find(primitive); | |||
| if (iter == GetPrimitiveToEvalImplMap().end()) { | |||
| return nullptr; | |||
| } | |||
| return iter->second.impl_; | |||
| } | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { | |||
| @@ -18,6 +18,7 @@ | |||
| #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ir/primitive.h" | |||
| #include "base/core_ops.h" | |||
| #include "abstract/abstract_value.h" | |||
| @@ -37,6 +38,10 @@ using PrimitiveEvalImplMap = | |||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); | |||
| PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap(); | |||
| StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); | |||
| std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode); | |||
| void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); | |||
| @@ -104,6 +104,5 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolFusion, prim::kPrimMaxPool, MaxPoolFusionInfer); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -31,8 +31,6 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||
| auto element = tensor_type->element(); | |||
| return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -58,8 +58,6 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||
| return std::make_shared<abstract::AbstractTensor>(intype, inshape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer); | |||
| REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -31,8 +31,6 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||
| auto element = tensor_type->element(); | |||
| return std::make_shared<abstract::AbstractTensor>(element, x1_shape); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer); | |||
| REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -102,7 +102,6 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer); | |||
| REGISTER_PRIMITIVE_C(kNameLRN, LRN); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "backend/optimizer/common/const_input_to_attr_registry.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "common/common_test.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| constexpr auto kAttrConvertTestName = "attr_convert_test"; | |||
| constexpr auto kDynamicInputTestName = "dynamic_input_test"; | |||
| inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName); | |||
| inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test"); | |||
| AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| EXPECT_EQ(args_spec_list.size(), 3); | |||
| EXPECT_NE(args_spec_list[1], nullptr); | |||
| EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); | |||
| return args_spec_list[0]; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest); | |||
| AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| EXPECT_EQ(args_spec_list.size(), 3); | |||
| EXPECT_NE(args_spec_list[1], nullptr); | |||
| EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); | |||
| auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>(); | |||
| return args_spec_list[0]; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest); | |||
| class TestAttrAndDynamicBackendInfer : public UT::Common { | |||
| public: | |||
| TestAttrAndDynamicBackendInfer() {} | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestAttrAndDynamicBackendInfer, test_attr_and_dynamic_input_infer) { | |||
| // Register Attr for ut | |||
| ConstInputToAttrInfoRegistry ® = ConstInputToAttrInfoRegistry::Instance(); | |||
| reg.Register(kAttrConvertTestName, {1}); | |||
| // construct primitive | |||
| PrimitivePtr prim_attr_test = std::make_shared<Primitive>(kAttrConvertTestName); | |||
| PrimitivePtr prim_dynamic_input_test = std::make_shared<Primitive>(kDynamicInputTestName); | |||
| // set primtive attr | |||
| auto input_names = std::vector<std::string>{"a", "b", "c"}; | |||
| auto attr_name = "b"; | |||
| auto attr = MakeValue(std::vector<int>{1, 2, 3}); | |||
| prim_attr_test->AddAttr(kAttrInputNames, MakeValue(input_names)); | |||
| prim_attr_test->AddAttr(attr_name, attr); | |||
| // set dynameic input list for primtive | |||
| std::vector<int64_t> dynamic_input_list = {-1, 2, -1}; | |||
| prim_dynamic_input_test->AddAttr(kAttrDynInputSizes, MakeValue(dynamic_input_list)); | |||
| // construct Abstract list | |||
| auto abs_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto abs_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto attr_infer_result = CppInferShape(prim_attr_test, {abs_a, abs_c}); | |||
| auto abs_dynamic_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto abs_dynamic_b = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto abs_dynamic_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto abs_dynamic_d = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2}); | |||
| auto dynamic_infer_result = | |||
| CppInferShape(prim_dynamic_input_test, {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d}); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||