From: @lianliguang Reviewed-by: Signed-off-by: tags/v1.2.0-rc1
@@ -366,7 +366,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
 elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore mindspore::pybind11_module)
     target_link_libraries(mindspore mindspore_gvar)
-    target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
+    target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
 else()
     if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
         target_link_libraries(mindspore proto_input mindspore::protobuf
@@ -376,7 +376,8 @@ else()
         target_link_libraries(mindspore ibverbs rdmacm)
         endif()
     endif()
-    target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive)
+    target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core
+            proto_input -Wl,--no-whole-archive)
     target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
     target_link_libraries(_c_expression PRIVATE mindspore_gvar)
     if(ENABLE_D)
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/optimizer/pass/const_input_to_attr_registry.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
 #include <utility>
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver {
   ::mindspore::opt::ConstInputToAttrInfoRegister(op_name)
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
@@ -31,6 +31,8 @@
 #include "utils/ms_utils.h"
 #include "runtime/device/kernel_info.h"
 #include "utils/ms_context.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
+#include "abstract/primitive_infer_map.h"
 namespace mindspore {
 namespace opt {
@@ -700,6 +702,92 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
   }
   return CreateCNodeWithGraph(input_nodes, graph);
 }
+// Rectify the abstract list if some inputs have been converted to attrs.
+AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &input_abstract) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  opt::ConstInputToAttrInfoRegister reg;
+  if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
+    return input_abstract;
+  }
+  if (AnfAlgo::HasDynamicShapeFlag(primitive) ||
+      DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
+    return input_abstract;
+  }
+  auto convert_input_list = reg.GetConstInputAttrInfo();
+  auto input_names = primitive->GetAttr(kAttrInputNames);
+  if (input_names == nullptr) {
+    return input_abstract;
+  }
+  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
+  AbstractBasePtrList rectify_abs_list;
+  size_t ori_index = 0;
+  rectify_abs_list.resize(input_names_vec.size());
+  for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
+    // If the index is found in the convert-input list, this input has been converted to an attr.
+    if (convert_input_list.find(index) != convert_input_list.end()) {
+      AbstractBasePtr rectify_abs = nullptr;
+      auto input_name = input_names_vec[index];
+      auto attr = primitive->GetAttr(input_name);
+      if (attr != nullptr) {
+        rectify_abs = attr->ToAbstract();
+      } else {
+        MS_LOG(DEBUG) << "The prim " << primitive->name() << " input index " << index << " (input name " << input_name
+                      << ") has not been converted to an attr";
+        rectify_abs = input_abstract[ori_index++];
+      }
+      rectify_abs_list[index] = rectify_abs;
+      continue;
+    }
+    if (ori_index >= input_abstract.size()) {
+      MS_LOG(EXCEPTION) << "Index " << ori_index << " is out of range of the input abstract size "
+                        << input_abstract.size();
+    }
+    rectify_abs_list[index] = input_abstract[ori_index++];
+  }
+  return rectify_abs_list;
+}
+AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &input_abstract) {
+  auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
+  if (dynamic_inputs_list == nullptr) {
+    return input_abstract;
+  }
+  AbstractBasePtrList rectifyed_abs_list;
+  const int kNotDynamicFlag = -1;
+  auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
+  size_t input_index = 0;
+  for (auto item : dynamic_inputs_index) {
+    if (item == kNotDynamicFlag) {
+      if (input_index >= input_abstract.size()) {
+        MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range of the input abstract size "
+                          << input_abstract.size();
+      }
+      rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
+    } else {
+      if (item < 0) {
+        MS_LOG(EXCEPTION) << "Dynamic input size check failed, the value should be -1 or a positive number, but got "
+                          << item;
+      }
+      AbstractBasePtrList dynamic_inputs_abs;
+      for (auto index = item; index > 0; --index) {
+        if (input_index >= input_abstract.size()) {
+          MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range of the input abstract size "
+                            << input_abstract.size();
+        }
+        dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
+      }
+      rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
+    }
+  }
+  return rectifyed_abs_list;
+}
+AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
+  auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
+  return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
+}
 }  // namespace
 AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
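The two rectify helpers above rebuild the full argument list that a front-end infer function expects: const inputs that were converted to attrs are pulled back out of the primitive, and dynamic-input groups are folded into tuples. Below is a minimal, self-contained sketch of the same index bookkeeping, using plain STL types instead of MindSpore's Primitive/AbstractBasePtr classes; all names here (Slot, rectify_from_attrs, fold_dynamic_inputs) are illustrative only.

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Each "slot" stands in for one AbstractBasePtr.
using Slot = std::string;

// Mirror of RectifyAbstractFromRegAttr: rebuild the full argument list by
// pulling converted inputs back out of the attribute map.
std::vector<Slot> rectify_from_attrs(const std::vector<std::string> &input_names,
                                     const std::set<size_t> &converted_indices,
                                     const std::map<std::string, Slot> &attrs,
                                     const std::vector<Slot> &runtime_inputs) {
  std::vector<Slot> out(input_names.size());
  size_t ori = 0;  // cursor into the (shorter) runtime input list
  for (size_t i = 0; i < input_names.size(); ++i) {
    auto it = attrs.find(input_names[i]);
    if (converted_indices.count(i) && it != attrs.end()) {
      out[i] = it->second;                // value came from an attr
    } else {
      out[i] = runtime_inputs.at(ori++);  // value still comes from a real input
    }
  }
  return out;
}

// Mirror of RectifyAbstractFromDynamicInput: fold each dynamic group into a tuple.
std::vector<Slot> fold_dynamic_inputs(const std::vector<int64_t> &dyn_sizes,
                                      const std::vector<Slot> &inputs) {
  std::vector<Slot> out;
  size_t cursor = 0;
  for (int64_t size : dyn_sizes) {
    if (size == -1) {  // not a dynamic input, pass through
      out.push_back(inputs.at(cursor++));
    } else {
      Slot tuple = "(";
      for (int64_t k = 0; k < size; ++k) {
        tuple += inputs.at(cursor++) + (k + 1 < size ? "," : "");
      }
      out.push_back(tuple + ")");
    }
  }
  return out;
}

int main() {
  // ReduceSum-like op: input "axis" (index 1) was converted to an attr.
  auto rectified = rectify_from_attrs({"x", "axis", "y"}, {1}, {{"axis", "attr(axis)"}},
                                      {"abs(x)", "abs(y)"});
  for (const auto &s : rectified) std::cout << s << " ";  // abs(x) attr(axis) abs(y)
  std::cout << "\n";

  // AddN-like op: the second entry is a dynamic input made of two tensors.
  auto folded = fold_dynamic_inputs({-1, 2}, {"a", "b", "c"});
  for (const auto &s : folded) std::cout << s << " ";  // a (b,c)
  std::cout << "\n";
  return 0;
}
```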
@@ -835,5 +923,24 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
     }
   }
 }
+AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
+  auto ret = prim_eval_implement_map.find(prim);
+  if (ret != prim_eval_implement_map.end()) {
+    // Found an infer function in the front-end infer map; restore the input abstracts
+    // from dynamic inputs and registered attrs before calling it.
+    auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
+    return ret->second.impl_(nullptr, prim, infer_spec_list);
+  } else {
+    // If the infer function is not found in the front-end infer map, look it up in the backend infer map instead.
+    auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
+    auto ret_backend = prim_backend_eval_impl_map.find(prim);
+    if (ret_backend != prim_backend_eval_impl_map.end()) {
+      return ret_backend->second.impl_(nullptr, prim, args_spec_list);
+    }
+  }
+  MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name: " << prim->name()
+                    << " primitive type: " << prim->type_name();
+}
 }  // namespace opt
 }  // namespace mindspore
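opt::CppInferShape therefore tries the front-end infer map first (with rectified arguments) and only then the backend-only map, failing loudly when neither has an entry. A small stand-alone sketch of that lookup order, with toy maps keyed by primitive name (FrontMap, BackendMap and InferShape below are illustrative names, not the MindSpore API):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Toy stand-ins for the front-end and backend infer maps, keyed by primitive name.
using InferFn = std::function<std::string(const std::string &)>;

const std::map<std::string, InferFn> &FrontMap() {
  static const std::map<std::string, InferFn> m = {
      {"Square", [](const std::string &in) { return "shape_of(" + in + ")"; }},
  };
  return m;
}

const std::map<std::string, InferFn> &BackendMap() {
  static const std::map<std::string, InferFn> m = {
      {"Transpose", [](const std::string &in) { return "permuted(" + in + ")"; }},
  };
  return m;
}

// Same lookup order as opt::CppInferShape: front-end map first (with rectified
// arguments), then the backend-only map, otherwise fail loudly.
std::string InferShape(const std::string &prim, const std::string &arg) {
  if (auto it = FrontMap().find(prim); it != FrontMap().end()) {
    return it->second("rectified(" + arg + ")");
  }
  if (auto it = BackendMap().find(prim); it != BackendMap().end()) {
    return it->second(arg);
  }
  throw std::runtime_error("no C++ infer function registered for " + prim);
}

int main() {
  std::cout << InferShape("Square", "x") << "\n";     // shape_of(rectified(x))
  std::cout << InferShape("Transpose", "x") << "\n";  // permuted(x)
  return 0;
}
```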
@@ -212,6 +212,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
 // Transfer depend or control_depend to the new node
 void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
+AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
@@ -27,7 +27,7 @@
 #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/session/anf_runtime_algorithm.h"
-#include "backend/optimizer/pass/const_input_to_attr_registry.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
 #include "pipeline/jit/parse/python_adapter.h"
@@ -19,7 +19,7 @@
 #include <string>
 #include <memory>
-#include "backend/optimizer/pass/const_input_to_attr_registry.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
 #include "backend/optimizer/common/helper.h"
 #include "utils/utils.h"
 #include "utils/ms_context.h"
@@ -1534,6 +1534,18 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
   return AnfAlgo::GetNodeAttr<bool>(node, attr);
 }
+bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
+  auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
+    MS_EXCEPTION_IF_NULL(primitive);
+    if (!primitive->HasAttr(attr_name)) {
+      return false;
+    }
+    return GetValue<bool>(primitive->GetAttr(attr_name));
+  };
+  return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
+         get_bool_attr(prim, kAttrIsDynamicShape);
+}
 bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
   return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
          GetBooleanAttr(node, kAttrIsDynamicShape);
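HasDynamicShapeFlag treats a missing attribute as false and ORs the three dynamic-shape flags together. A compact sketch of that pattern follows; the attribute-name strings below are assumptions standing in for the kAttr* constants, and ToyPrimitive is not a MindSpore type.

```cpp
#include <iostream>
#include <map>
#include <string>

// Toy attribute store standing in for Primitive's attr map (values simplified to bool).
struct ToyPrimitive {
  std::map<std::string, bool> attrs;
  bool GetBoolAttrOrFalse(const std::string &name) const {
    auto it = attrs.find(name);
    return it != attrs.end() && it->second;  // a missing attr counts as false
  }
  bool HasDynamicShapeFlag() const {
    return GetBoolAttrOrFalse("input_is_dynamic_shape") || GetBoolAttrOrFalse("output_is_dynamic_shape") ||
           GetBoolAttrOrFalse("is_dynamic_shape");
  }
};

int main() {
  ToyPrimitive static_prim;  // no flags set at all
  ToyPrimitive dyn_prim{{{"output_is_dynamic_shape", true}}};
  std::cout << std::boolalpha << static_prim.HasDynamicShapeFlag() << "\n";  // false
  std::cout << dyn_prim.HasDynamicShapeFlag() << "\n";                       // true
  return 0;
}
```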
@@ -1805,7 +1817,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
       args_spec_list.emplace_back(real_input->abstract());
     }
   }
-  auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
+  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
   node->set_abstract(eval_result);
 }
 }  // namespace session
@@ -230,6 +230,7 @@ class AnfRuntimeAlgorithm {
   // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
   static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
   static bool IsDynamicShape(const AnfNodePtr &node);
+  static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
   static bool IsCondControlKernel(const CNodePtr &node);
   static bool IsIndependentNode(const CNodePtr &node);
   static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
@@ -1311,15 +1311,6 @@ bool IsInWhiteList(const PrimitivePtr &primitive) {
   return false;
 }
-StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
-  if (iter == GetPrimitiveToEvalImplMap().end()) {
-    return nullptr;
-  }
-  return iter->second.impl_;
-}
 PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
   PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
   if (!constructor.empty()) {
@@ -112,7 +112,6 @@ class MixedPrecisionCastEvaluator : public Evaluator {
 };
 bool IsInWhiteList(const PrimitivePtr &primitive);
-StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
 using ValuePtrList = std::vector<ValuePtr>;
 using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);
@@ -357,6 +357,13 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     return std::make_shared<MixedPrecisionCastEvaluator>(prim);
   }
+  // Find the primitive's infer function in the C++ infer-function map and wrap it in a standard evaluator.
+  StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
+  if (eval_impl != nullptr) {
+    return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
+  }
+  // Otherwise fall back to the Python infer function and return a Python evaluator.
   EvaluatorPtr evaluator = nullptr;
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
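The net effect of this hunk and the next one is a new evaluator-selection order: a registered C++ infer function wins first, a Python evaluator is the fallback, and the default constructor table comes last. A toy sketch of that priority (ChooseEvaluator and the kind strings are illustrative, not the MindSpore types):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>

// Simplified picture of the selection order in GetPrimEvaluator after this patch:
// 1) a registered C++ infer function, 2) a Python evaluator, 3) the default table.
struct Evaluator {
  std::string kind;
};
using EvaluatorPtr = std::shared_ptr<Evaluator>;

EvaluatorPtr ChooseEvaluator(const std::function<void()> &cpp_infer, bool has_py_evaluator) {
  if (cpp_infer) {
    return std::make_shared<Evaluator>(Evaluator{"StandardPrimEvaluator (C++ infer)"});
  }
  if (has_py_evaluator) {
    return std::make_shared<Evaluator>(Evaluator{"Python evaluator"});
  }
  return std::make_shared<Evaluator>(Evaluator{"default constructor table"});
}

int main() {
  std::cout << ChooseEvaluator([] {}, true)->kind << "\n";     // StandardPrimEvaluator (C++ infer)
  std::cout << ChooseEvaluator(nullptr, true)->kind << "\n";   // Python evaluator
  std::cout << ChooseEvaluator(nullptr, false)->kind << "\n";  // default constructor table
  return 0;
}
```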
@@ -376,17 +383,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
   }
-  if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
-    if (engine == nullptr) {
-      (void)GetPrimEvaluatorConstructors();
-    }
-    // If a primitive may have attr, try to create a new evaluator.
-    StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
-    if (eval_impl != nullptr) {
-      return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
-    }
-  }
+  // Return a default evaluator.
   if (engine == nullptr) {
     // If engine is nullptr, get constructor from default.
     const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
@@ -778,16 +775,5 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
   auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
   return eval_result;
 }
-AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
-  MS_EXCEPTION_IF_NULL(prim);
-  auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap();
-  auto ret = prim_eval_implement_map.find(prim);
-  if (ret == prim_eval_implement_map.end()) {
-    MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
-                      << " primitive type:" << prim->type_name();
-  }
-  return ret->second.impl_(nullptr, prim, args_spec_list);
-}
 }  // namespace abstract
 }  // namespace mindspore
@@ -331,8 +331,6 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
 }
 EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
-AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
@@ -44,7 +44,7 @@
 #include "pipeline/jit/static_analysis/prim.h"
 #include "pipeline/jit/static_analysis/auto_monad.h"
 #include "backend/session/session_factory.h"
-#include "backend/optimizer/pass/const_input_to_attr_registry.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
 #include "backend/optimizer/common/helper.h"
 #include "pipeline/jit/action.h"
@@ -807,21 +807,13 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
     }
   }
   // get output dynamic shape info
-  auto py_abstract = op_exec_info->abstract;
-  MS_EXCEPTION_IF_NULL(py_abstract);
-  auto py_shape = py_abstract->BuildShape();
-  MS_EXCEPTION_IF_NULL(py_shape);
-  auto py_shape_info = py_shape->ToString();
-  if (py_shape_info.find("-1") != string::npos) {
-    auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
-    MS_EXCEPTION_IF_NULL(c_abstract);
-    auto c_shape = c_abstract->BuildShape();
-    MS_EXCEPTION_IF_NULL(c_shape);
-    auto c_shape_info = c_shape->ToString();
-    MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
-    if (c_shape_info.find("-1") != string::npos) {
-      op_exec_info->is_dynamic_shape = true;
-    }
+  auto abstract = op_exec_info->abstract;
+  MS_EXCEPTION_IF_NULL(abstract);
+  auto shape = abstract->BuildShape();
+  MS_EXCEPTION_IF_NULL(shape);
+  auto shape_info = shape->ToString();
+  if (shape_info.find("-1") != string::npos) {
+    op_exec_info->is_dynamic_shape = true;
   }
 }
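With the second-pass C++ re-inference removed, PyNative now flags dynamic-shape outputs purely from the abstract it already has: any "-1" placeholder in the shape string marks the op as dynamic. A tiny sketch of that check (ToyShapeString stands in for BaseShape::ToString; it is not the MindSpore function):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Build a "(d0, d1, ...)" shape string, the same textual form the check relies on.
std::string ToyShapeString(const std::vector<int64_t> &dims) {
  std::ostringstream oss;
  oss << "(";
  for (size_t i = 0; i < dims.size(); ++i) {
    oss << dims[i] << (i + 1 < dims.size() ? ", " : "");
  }
  oss << ")";
  return oss.str();
}

// An output counts as dynamic when its shape string contains a "-1" dimension.
bool IsDynamicOutput(const std::vector<int64_t> &dims) {
  return ToyShapeString(dims).find("-1") != std::string::npos;
}

int main() {
  std::cout << std::boolalpha << IsDynamicOutput({32, 128}) << "\n";  // false
  std::cout << IsDynamicOutput({-1, 128}) << "\n";                    // true
  return 0;
}
```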
@@ -123,7 +123,7 @@ void DynamicKernel::InferShape() {
     }
   }
-  auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
+  auto eval_result = opt::CppInferShape(primitive, args_spec_list);
   cnode_ptr_->set_abstract(eval_result);
 }
@@ -1041,6 +1041,9 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
   const std::string &op_name = primitive->name();
+  if (args_spec_list.size() == 1) {
+    return args_spec_list[0]->Broaden();
+  }
   CheckArgsSize(op_name, args_spec_list, 3);
   AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
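The early return lets InferImplRange serve two call shapes: the three-tensor front-end form (start, limit, delta) and a single-input form, which appears to correspond to the backend Range op whose scalar inputs have been folded away. A sketch of the arity handling only; Abs and InferRange are toy names, and the real output-length computation is deliberately omitted.

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

// "Abs" is a toy stand-in for an abstract value; Broaden() is simplified to a copy.
struct Abs {
  std::vector<int64_t> shape;
  Abs Broaden() const { return *this; }
};

Abs InferRange(const std::vector<Abs> &args) {
  if (args.size() == 1) {
    // Backend form: only one tensor input is left; pass its abstract through.
    return args[0].Broaden();
  }
  if (args.size() != 3) {
    throw std::invalid_argument("Range expects 1 or 3 abstract arguments");
  }
  // Front-end form: a real implementation would derive the output length from
  // start/limit/delta; here only the arity handling is shown.
  return Abs{{-1}};  // unknown length
}

int main() {
  std::cout << InferRange({Abs{{4}}}).shape[0] << "\n";                  // 4
  std::cout << InferRange({Abs{{}}, Abs{{}}, Abs{{}}}).shape[0] << "\n"; // -1
  return 0;
}
```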
@@ -292,24 +292,6 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
   return std::make_shared<AbstractTuple>(rets);
 }
-AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const AbstractBasePtrList &args_spec_list) {
-  // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
-  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
-  MS_EXCEPTION_IF_NULL(args_spec_list[2]);
-  MS_EXCEPTION_IF_NULL(args_spec_list[3]);
-  CheckArgsSize(primitive->name(), args_spec_list, 5);
-  auto dx = args_spec_list[1]->Broaden();
-  auto dscale = args_spec_list[2]->Broaden();
-  auto dbias = args_spec_list[3]->Broaden();
-  auto reserve_1 = args_spec_list[4]->Broaden();
-  auto reserve_2 = args_spec_list[5]->Broaden();
-  AbstractBasePtrList rets = {dx, dscale, dbias, reserve_1, reserve_2};
-  return std::make_shared<AbstractTuple>(rets);
-}
 AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tensors(y_backprop, x).
@@ -468,20 +450,6 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
   return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
 }
-AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                             const AbstractBasePtrList &args_spec_list) {
-  // Inputs: three tensors(doutput, input, filters).
-  CheckRequiredArgsSize(primitive->name(), args_spec_list, 3);
-  return args_spec_list[1]->Broaden();
-}
-AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                              const AbstractBasePtrList &args_spec_list) {
-  // Inputs: three tensors(inputs, filter, doutput).
-  CheckArgsSize(primitive->name(), args_spec_list, 3);
-  return args_spec_list[2]->Broaden();
-}
 AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
@@ -17,6 +17,11 @@
  */
 #include "abstract/primitive_infer_map.h"
+#include <map>
+#include <string>
+#include <vector>
 #include "abstract/abstract_function.h"
 #include "abstract/infer_functions.h"
@@ -59,40 +64,21 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimNotInDict, {InferImplNotInDict, true}},
     {prim::kPrimIsConsant, {InferImplIsConstant, true}},
     // Maths
-    {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
-    {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
-    {prim::kPrimMul, {InferImplMul, true}},
-    {prim::kPrimAdd, {InferImplAdd, true}},
     {prim::kPrimSquare, {InferImplSquare, true}},
-    {prim::kPrimSqrt, {InferImplSqrt, true}},
-    {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
-    {prim::kPrimSub, {InferImplSub, true}},
-    {prim::kPrimEqual, {InferImplEqual, true}},
-    {prim::kPrimReduceSum, {InferImplReduceFunc, true}},
-    {prim::kPrimReduceMean, {InferImplReduceFunc, true}},
-    {prim::kPrimReduceAll, {InferImplReduceFunc, true}},
-    {prim::kPrimReduceAny, {InferImplReduceFunc, true}},
-    {prim::kPrimReduceMax, {InferImplReduceFunc, true}},
-    {prim::kPrimReduceMin, {InferImplReduceFunc, true}},
-    {prim::kPrimMinimum, {InferImplMinimum, true}},
-    {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
-    {prim::kPrimLinSpace, {InferImplLinSpace, true}},
-    {prim::kPrimAddN, {InferImplAddN, true}},
     {prim::kPrimMatMul, {InferImplMatMul, true}},
     {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
-    {prim::kPrimLess, {InferImplLess, true}},
+    {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
+    {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
+    {prim::kPrimSqrt, {InferImplSqrt, true}},
     // Array
+    {prim::kPrimRange, {InferImplRange, true}},
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
     {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
-    {prim::kPrimStack, {InferImplStack, true}},
-    {prim::kPrimPad, {InferImplPad, true}},
     {prim::kPrimUnique, {InferImplUnique, true}},
     {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
     {prim::kPrimGather, {InferImplGatherV2, true}},
     {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
-    {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
-    {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
     {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
     {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
     {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
@@ -104,18 +90,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
     {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
     {prim::kPrimPadAndShift, {InferImplPadAndShift, true}},
-    {prim::kPrimDiv, {InferImplDiv, true}},
-    {prim::kPrimRealDiv, {InferImplRealDiv, true}},
-    {prim::kPrimShape, {InferImplShape, false}},
     {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
-    {prim::kPrimTranspose, {InferImplTranspose, true}},
-    {prim::kPrimReshape, {InferImplReshape, true}},
     {prim::kPrimMapUniform, {InferImplMapUniform, true}},
     {prim::kPrimSplit, {InferImplSplit, true}},
     {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
-    {prim::kPrimConcat, {InferImplConcat, true}},
-    {prim::kPrimRange, {InferImplRange, true}},
-    {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -139,14 +117,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimPooling, {InferImplPooling, true}},
     {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
     {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
-    {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
     {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
     {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
-    {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
     {prim::kPrimReluGrad, {InferImplReluGrad, true}},
     {prim::kPrimConv2D, {InferImplConv2D, true}},
-    {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
-    {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
     {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
     {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
     {prim::kPrimRelu, {InferImplRelu, true}},
@@ -192,18 +166,60 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
     {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}},
     // Comm Ops
-    {prim::kPrimAllReduce, {InferImplAllReduce, true}},
-    {prim::kPrimBroadcast, {InferImplBroadcast, true}},
-    {prim::kPrimAllGather, {InferImplAllGather, true}},
     {prim::kPrimAllSwap, {InferImplAllSwap, true}},
-    {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
     {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
+  };
+  return prim_eval_implement_map;
+}
+PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
+  static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
+    {prim::kPrimMul, {InferImplMul, true}},
+    {prim::kPrimAdd, {InferImplAdd, true}},
+    {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
+    {prim::kPrimSub, {InferImplSub, true}},
+    {prim::kPrimEqual, {InferImplEqual, true}},
+    {prim::kPrimReduceSum, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceMean, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceAll, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceAny, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceMax, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceMin, {InferImplReduceFunc, true}},
+    {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
     {prim::kPrimCast, {InferImplCast, true}},
     {prim::kPrimExpandDims, {InferImplExpandDims, true}},
     {prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}},
     {prim::kPrimDType, {InferImplDType, true}},
+    {prim::kPrimAllReduce, {InferImplAllReduce, true}},
+    {prim::kPrimBroadcast, {InferImplBroadcast, true}},
+    {prim::kPrimAllGather, {InferImplAllGather, true}},
+    {prim::kPrimMinimum, {InferImplMinimum, true}},
+    {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
+    {prim::kPrimLinSpace, {InferImplLinSpace, true}},
+    {prim::kPrimAddN, {InferImplAddN, true}},
+    {prim::kPrimLess, {InferImplLess, true}},
+    {prim::kPrimStack, {InferImplStack, true}},
+    {prim::kPrimPad, {InferImplPad, true}},
+    {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
+    {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
+    {prim::kPrimDiv, {InferImplDiv, true}},
+    {prim::kPrimRealDiv, {InferImplRealDiv, true}},
+    {prim::kPrimShape, {InferImplShape, false}},
+    {prim::kPrimTranspose, {InferImplTranspose, true}},
+    {prim::kPrimReshape, {InferImplReshape, true}},
+    {prim::kPrimConcat, {InferImplConcat, true}},
+    {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
+    {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
   };
-  return prim_eval_implement_map;
+  return prim_backend_eval_implement_map;
+}
+StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
+  if (iter == GetPrimitiveToEvalImplMap().end()) {
+    return nullptr;
+  }
+  return iter->second.impl_;
 }
 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
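Both maps share the PrimitiveEvalImplMap layout: keys are primitive pointers compared and hashed by name, and values pair the infer function (impl_) with an in-white-list flag. A stand-alone sketch of that lookup behaviour with toy types (ToyPrim, ImplReg and EvalImplMap are illustrative, not the MindSpore definitions):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Toy primitive: only the name participates in hashing and equality,
// so a freshly created primitive with the same name still hits the entry.
struct ToyPrim {
  std::string name;
};
using ToyPrimPtr = std::shared_ptr<ToyPrim>;

struct PrimHash {
  size_t operator()(const ToyPrimPtr &p) const { return std::hash<std::string>()(p->name); }
};
struct PrimEqual {
  bool operator()(const ToyPrimPtr &a, const ToyPrimPtr &b) const { return a->name == b->name; }
};

// Value type: infer function plus the "run in white list" flag, like StandardPrimitiveImplReg.
struct ImplReg {
  std::function<std::string(const std::string &)> impl_;
  bool in_white_list_;
};

using EvalImplMap = std::unordered_map<ToyPrimPtr, ImplReg, PrimHash, PrimEqual>;

int main() {
  EvalImplMap backend_map = {
      {std::make_shared<ToyPrim>(ToyPrim{"Transpose"}),
       {[](const std::string &x) { return "permute(" + x + ")"; }, true}},
  };
  auto lookup_key = std::make_shared<ToyPrim>(ToyPrim{"Transpose"});
  auto it = backend_map.find(lookup_key);
  if (it != backend_map.end()) {
    std::cout << it->second.impl_("x") << "\n";  // permute(x)
  }
  return 0;
}
```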
@@ -18,6 +18,7 @@
 #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
 #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
 #include <unordered_map>
+#include <vector>
 #include "ir/primitive.h"
 #include "base/core_ops.h"
 #include "abstract/abstract_value.h"
@@ -37,6 +38,10 @@ using PrimitiveEvalImplMap =
 PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
+PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
+StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
 std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
 void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
@@ -104,6 +104,5 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr
   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                     InferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolFusion, prim::kPrimMaxPool, MaxPoolFusionInfer);
 }  // namespace ops
 }  // namespace mindspore
@@ -31,8 +31,6 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
   auto element = tensor_type->element();
   return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape);
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer);
 REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad);
 }  // namespace ops
 }  // namespace mindspore
@@ -58,8 +58,6 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
   return std::make_shared<abstract::AbstractTensor>(intype, inshape);
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer);
 REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
 }  // namespace ops
 }  // namespace mindspore
@@ -31,8 +31,6 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
   auto element = tensor_type->element();
   return std::make_shared<abstract::AbstractTensor>(element, x1_shape);
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer);
 REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad);
 }  // namespace ops
 }  // namespace mindspore
@@ -102,7 +102,6 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                     InferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer);
 REGISTER_PRIMITIVE_C(kNameLRN, LRN);
 }  // namespace ops
 }  // namespace mindspore
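For context on the deleted registrations: REGISTER_PRIMITIVE_EVAL_IMPL is a static-registration macro that inserts an infer function into the shared eval-impl map at start-up via RegisterStandardPrimitiveImpl (shown earlier in this patch), so removing the line simply stops these lite ops from re-registering an entry for a primitive the core map already covers. A minimal sketch of that registrar pattern, under the assumption that the real macro works along these lines (ToyRegistry and REGISTER_TOY_INFER are illustrative names):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy registry standing in for the shared PrimitiveEvalImplMap.
using InferFn = std::function<std::string()>;

std::map<std::string, InferFn> &ToyRegistry() {
  static std::map<std::string, InferFn> registry;
  return registry;
}

void RegisterToyInfer(const std::string &prim_name, InferFn fn) { ToyRegistry()[prim_name] = std::move(fn); }

// Static-registration helper: constructing a global instance performs the insert
// before main() runs, which is the usual trick behind REGISTER_* macros.
struct ToyRegistrar {
  ToyRegistrar(const std::string &prim_name, InferFn fn) { RegisterToyInfer(prim_name, std::move(fn)); }
};

#define REGISTER_TOY_INFER(name, fn) static ToyRegistrar g_##name##_registrar(#name, fn)

std::string LrnInferSketch() { return "same shape as input"; }
REGISTER_TOY_INFER(LRN, LrnInferSketch);

int main() {
  // Deleting the REGISTER_* line for an op simply means no entry is added here,
  // so lookups fall back to whatever the core map already provides.
  std::cout << ToyRegistry().count("LRN") << "\n";          // 1
  std::cout << ToyRegistry().count("MaxPoolGrad") << "\n";  // 0
  return 0;
}
```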
@@ -0,0 +1,86 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <memory>
+#include <vector>
+#include <string>
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "abstract/abstract_value.h"
+#include "abstract/primitive_infer_map.h"
+#include "backend/optimizer/common/const_input_to_attr_registry.h"
+#include "backend/optimizer/common/helper.h"
+#include "common/common_test.h"
+namespace mindspore {
+namespace opt {
+constexpr auto kAttrConvertTestName = "attr_convert_test";
+constexpr auto kDynamicInputTestName = "dynamic_input_test";
+inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName);
+inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>(kDynamicInputTestName);
+AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  EXPECT_EQ(args_spec_list.size(), 3);
+  EXPECT_NE(args_spec_list[1], nullptr);
+  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
+  return args_spec_list[0];
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest);
+AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list) {
+  EXPECT_EQ(args_spec_list.size(), 3);
+  EXPECT_NE(args_spec_list[1], nullptr);
+  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
+  auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
+  return args_spec_list[0];
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest);
+class TestAttrAndDynamicBackendInfer : public UT::Common {
+ public:
+  TestAttrAndDynamicBackendInfer() {}
+  void SetUp() override {}
+  void TearDown() override {}
+};
+TEST_F(TestAttrAndDynamicBackendInfer, test_attr_and_dynamic_input_infer) {
+  // Register the const-input-to-attr conversion for this UT: input 1 of the test op becomes an attr.
+  ConstInputToAttrInfoRegistry &reg = ConstInputToAttrInfoRegistry::Instance();
+  reg.Register(kAttrConvertTestName, {1});
+  // Construct the test primitives.
+  PrimitivePtr prim_attr_test = std::make_shared<Primitive>(kAttrConvertTestName);
+  PrimitivePtr prim_dynamic_input_test = std::make_shared<Primitive>(kDynamicInputTestName);
+  // Set the primitive attrs.
+  auto input_names = std::vector<std::string>{"a", "b", "c"};
+  auto attr_name = "b";
+  auto attr = MakeValue(std::vector<int>{1, 2, 3});
+  prim_attr_test->AddAttr(kAttrInputNames, MakeValue(input_names));
+  prim_attr_test->AddAttr(attr_name, attr);
+  // Set the dynamic input list for the primitive.
+  std::vector<int64_t> dynamic_input_list = {-1, 2, -1};
+  prim_dynamic_input_test->AddAttr(kAttrDynInputSizes, MakeValue(dynamic_input_list));
+  // Construct the abstract lists.
+  auto abs_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto abs_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto attr_infer_result = CppInferShape(prim_attr_test, {abs_a, abs_c});
+  auto abs_dynamic_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto abs_dynamic_b = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto abs_dynamic_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto abs_dynamic_d = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
+  auto dynamic_infer_result =
+      CppInferShape(prim_dynamic_input_test, {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d});
+}
+}  // namespace opt
+}  // namespace mindspore
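Spelling out the data flow the two stub infer functions assert on (a reading of the patch itself; the exact abstract types produced by ToAbstract are an inference, not stated in the diff):

```cpp
// attr_convert_test: kAttrInputNames = {"a", "b", "c"}, and input index 1 ("b") is
// registered for const-input-to-attr conversion, carrying the attr value (1, 2, 3).
//   runtime abstracts passed to CppInferShape: {abs_a, abs_c}
//   rectified list seen by InferImplAttrTest:  {abs_a, ToAbstract((1, 2, 3)), abs_c}
//   -> size 3, and element 1 is an AbstractTuple, matching the EXPECTs in the stub.
//
// dynamic_input_test: kAttrDynInputSizes = {-1, 2, -1}.
//   runtime abstracts passed to CppInferShape: {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d}
//   rectified list seen by InferImplDynamicInputTest:
//     {abs_dynamic_a, AbstractTuple{abs_dynamic_b, abs_dynamic_c}, abs_dynamic_d}
//   -> size 3, and element 1 is again an AbstractTuple.
```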