@@ -42,6 +42,7 @@ using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher
 const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
 const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
 const char FUNC_GRAPH_FLAG_CORE[] = "core";
+const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";

 // ANF transform class
 // either a primitive or a func_graph
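Note: this new flag is the thread that ties the patch together. It is set on the graphs that `HyperMap::GenerateFromTypes` creates (below) and checked in `FuncGraphSpecializer::ProcessCNode`, where a call made through a parameter of a flagged graph is rewritten by the new `BuildSpecializedParameterNode`.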
@@ -23,6 +23,7 @@
 #include <sstream>

 #include "ir/anf.h"
+#include "ir/func_graph.h"
 #include "pipeline/static_analysis/abstract_value.h"
 #include "pipeline/static_analysis/abstract_function.h"
 #include "pipeline/static_analysis/dshape.h"

@@ -334,6 +335,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList
 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
   FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
   ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+  ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
   ptrGraph->debug_info()->set_name("hyper_map");

   AnfNodePtr ptrFnArg = nullptr;

@@ -278,10 +278,12 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
 // Convert class to Tuple
 // Convert getattr to getitem
 // Convert make_record to make_tuple
-void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
+bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(root);
+  bool changed = false;
   // `manager->Replace(...)` modifies the member `all_nodes_`, so `all_node` can't be a reference
   AnfNodeSet all_node = manager->all_nodes();
   for (auto &node : all_node) {

@@ -316,7 +318,9 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
     if (new_node != nullptr) {
       new_node->set_abstract(node->abstract());
+      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
       (void)manager->Replace(node, new_node);
+      changed = true;
     }
   }

@@ -324,6 +328,7 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
     auto ret = Reabs(node->abstract());
     node->set_abstract(ret);
   }
+  return changed;
 }

 // expand tuples in graph parameters
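A side note on the `all_node` copy in the hunk above: `manager->Replace(...)` mutates the manager's `all_nodes_` while the loop walks it, which is why the code snapshots the set first. A minimal standalone sketch of that copy-before-iterate pattern in plain C++ (standard containers only, no MindSpore types):

    #include <unordered_set>
    #include <vector>

    int main() {
      std::unordered_set<int> live_nodes = {1, 2, 3, 4};

      // Erasing from live_nodes while iterating it would invalidate the iterator,
      // so take a snapshot first -- the same reason
      // `AnfNodeSet all_node = manager->all_nodes();` copies instead of referencing.
      std::vector<int> snapshot(live_nodes.begin(), live_nodes.end());
      for (int node : snapshot) {
        if (node % 2 == 0) {
          live_nodes.erase(node);  // safe: the loop walks the snapshot, not live_nodes
        }
      }
      return live_nodes.size() == 2 ? 0 : 1;
    }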
@@ -31,7 +31,7 @@ namespace mindspore {
 namespace opt {

 // Remove the class type from graphs
-void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
+bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);

 // Remove most uses of tuples from the graph
 // tuples that are returned will be kept

@@ -38,13 +38,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
   // src type check
   auto src_type = src_->Type();
-  if (src_type == nullptr) {
+  if (src_type == nullptr || !src_type->isa<TensorType>()) {
     return nullptr;
   }
-  if (src_type->isa<TensorType>()) {
-    src_type = src_type->cast<TensorTypePtr>()->element();
-  }
+  src_type = src_type->cast<TensorTypePtr>()->element();

   // tgt type check
   auto tgt_type = GetValueNode<TypePtr>(tgt_);

@@ -52,14 +52,16 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res->func_graph());

   FuncGraphPtr func_graph = res->func_graph();
-  opt::SimplifyDataStructures(func_graph, res->manager());
+  bool changed = opt::SimplifyDataStructures(func_graph, res->manager());

   abstract::AbstractBasePtrList args_spec;
   auto parameters = func_graph->parameters();
   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
                        [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
-  FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
-  res->set_func_graph(new_fg);
+  if (changed) {
+    FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
+    res->set_func_graph(new_fg);
+  }
   res->set_args_spec(args_spec);
   return true;
 }

@@ -177,8 +177,8 @@ std::size_t FuncGraphAbstractClosure::hash() const {
 std::string FuncGraphAbstractClosure::ToString() const {
   std::stringstream ss;
-  ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString()
-     << "; Context: " << context_.get() << context_->ToString();
+  ss << "FuncGraphAbstractClosure: "
+     << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString();
   return ss.str();
 }

@@ -166,8 +166,9 @@ class PartialAbstractClosure : public AbstractFuncAtom {
  public:
   // Represents a partial application.
   // args_spec_list: The first few arguments of that function
-  PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list)
-      : fn_(fn), args_spec_list_(args_spec_list) {}
+  PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
+                         const AnfNodePtr &node = nullptr)
+      : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
   ~PartialAbstractClosure() override = default;
   MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)

@@ -175,7 +176,11 @@ class PartialAbstractClosure : public AbstractFuncAtom {
   AbstractFunctionPtr fn() { return fn_; }
   AbstractBasePtrList args() { return args_spec_list_; }
-  AbstractFunctionPtr Copy() const override { return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_); }
+  AnfNodePtr node() { return node_.lock(); }
+  void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
+  AbstractFunctionPtr Copy() const override {
+    return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock());
+  }
   bool operator==(const AbstractFunction &other) const override;
   std::size_t hash() const override;

@@ -184,6 +189,8 @@ class PartialAbstractClosure : public AbstractFuncAtom {
  private:
   AbstractFuncAtomPtr fn_;
   AbstractBasePtrList args_spec_list_;
+  // The CNode from which this PartialAbstractClosure was evaluated.
+  AnfNodeWeakPtr node_;
 };

 class JTransformedAbstractClosure : public AbstractFuncAtom {
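`node_` is deliberately a weak pointer (`AnfNodeWeakPtr`), read back through `node_.lock()`. A plausible reason — an assumption, the patch does not say — is that the CNode transitively owns its abstract value, so a strong back-pointer from the closure to the node would form a `shared_ptr` cycle. A standalone sketch of the pattern with plain `std::shared_ptr`/`std::weak_ptr`:

    #include <cassert>
    #include <memory>

    struct Closure;

    struct Node {
      std::shared_ptr<Closure> abstract;  // the node (transitively) owns its abstract value
    };

    struct Closure {
      std::weak_ptr<Node> node;  // weak back-pointer, like `AnfNodeWeakPtr node_`
      std::shared_ptr<Node> GetNode() { return node.lock(); }  // like `node() { return node_.lock(); }`
    };

    int main() {
      auto node = std::make_shared<Node>();
      node->abstract = std::make_shared<Closure>();
      node->abstract->node = node;  // like `set_node(...)`

      assert(node->abstract->GetNode() == node);  // back-pointer resolves while the node lives
      std::weak_ptr<Node> observer = node;
      node.reset();                // drop the last strong reference
      assert(observer.expired());  // a strong back-pointer would have leaked the pair
      return 0;
    }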
@@ -951,8 +951,19 @@ class PartialEvaluator : public Evaluator {
     if (args_conf_list.size() == 0) {
       MS_LOG(EXCEPTION) << "Args size should be greater than 0";
     }
+    MS_EXCEPTION_IF_NULL(out_conf);
+    MS_EXCEPTION_IF_NULL(out_conf->node());

     auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
     AbstractBasePtrList args_spec_list{arg0_value};
+    // Func in hypermap(partial(Func, arg0), arg1, arg2) may become a Poly node.
+    if (arg0_value->isa<AbstractError>()) {
+      auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
+      MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
+                    << " as func is: " << arg0_value->ToString();
+      (*cache_)[args_spec_list] = ret;
+      return ret;
+    }
     auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);

     // Sometimes, node[0] in out_conf becomes phi0;
     if (func->isa<PrimitiveAbstractClosure>()) {

@@ -962,19 +973,26 @@
         return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
       }
     }
-    (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
-                         [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
+    (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
+                         [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
     AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
-    AbstractFuncAtomPtrList partialPtrList;
-    auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) {
-      auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args);
-      partialPtrList.push_back(new_func);
+    auto cnode = out_conf->node()->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->size() != (args_conf_list.size() + 1)) {
+      MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
+                        << ", args_conf_list: " << mindspore::ToString(args_conf_list);
+    }
+    AbstractFuncAtomPtrList partial_funcs_list;
+    auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
+      auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
+      partial_funcs_list.push_back(new_func);
     };
     func->Visit(build_partial);
-    auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList);
+    auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
     (*cache_)[args_spec_list] = ret;
     return ret;
   }
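The `cnode->size() != (args_conf_list.size() + 1)` guard in the hunk above reflects the ANF layout of a `partial` call: the CNode's inputs are `{Partial, Func, arg0, ...}`, while `args_conf_list` covers only `{Func, arg0, ...}`, so the inputs hold exactly one extra slot for the primitive itself. A standalone restatement of that invariant (plain C++; strings stand in for AnfNode inputs):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // CNode for `partial(Func, arg0)`: input 0 is the Partial primitive,
      // followed by the function and its bound arguments.
      std::vector<std::string> cnode_inputs = {"Partial", "Func", "arg0"};

      // What PartialEvaluator sees: the function first, then the bound arguments.
      std::vector<std::string> args_conf_list = {"Func", "arg0"};

      // The guard added above, restated: inputs carry one extra slot (the primitive).
      assert(cnode_inputs.size() == args_conf_list.size() + 1);
      return 0;
    }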
@@ -23,7 +23,9 @@
 #include "./common.h"
 #include "operator/ops.h"
 #include "operator/composite/do_signature.h"
+#include "pipeline/static_analysis/abstract_function.h"
 #include "utils/graph_utils.h"
+#include "utils/log_adapter.h"
 #include "utils/profile.h"
 #include "debug/trace.h"

@@ -232,6 +234,13 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
     return;
   }
   new_node->set_abstract(GetEvaluatedValueWrap(conf));
+  if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
+    auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
+    if (partial_abstract->node() == node) {
+      partial_abstract->set_node(new_node);
+    }
+  }
   MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();

   if (node->isa<CNode>()) {

@@ -383,6 +392,56 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
   return BuildValueNode(v, abs);
 }

+AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
+  auto new_inputs = new_node->inputs();
+  AnfNodePtr func = new_inputs[0];
+  AbstractBasePtr fnval = new_inputs[0]->abstract();
+
+  AbstractBasePtrList args;
+  auto backed_fnval = fnval;
+  if (fnval->isa<PartialAbstractClosure>()) {
+    auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
+    backed_fnval = partial_closure->fn();
+    args = partial_closure->args();
+  }
+  std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
+                 [](const AnfNodePtr &inp) { return inp->abstract(); });
+
+  ScopeGuard scope_guard(new_node->scope());
+  auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
+  auto wrapped_node = specialized_node;
+  if (fnval->isa<PartialAbstractClosure>()) {
+    auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
+    AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
+                                        specialized_node};
+    auto anf_node = partial_closure->node();
+    if (!anf_node->isa<CNode>()) {
| MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); | |||||
+    }
+    auto cnode = anf_node->cast<CNodePtr>();
+    if (cnode->size() != partial_closure->args().size() + 2) {
| MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() | |||||
| << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); | |||||
+    }
+    for (size_t i = 0; i < partial_closure->args().size(); i++) {
+      auto old_node = cnode->input(i + 2);
+      auto possible_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
+      if (possible_value_node != nullptr) {
+        partial_node_list.push_back(possible_value_node);
+      } else {
+        if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
+          MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
+        }
+        partial_node_list.push_back(old_node);
+      }
+    }
+    wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
+    wrapped_node->set_abstract(partial_closure);
+  }
+  return wrapped_node;
+}
+
 const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
   auto cache_iter = evalcaches_.find(eval);
   if (cache_iter == evalcaches_.end()) {
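`BuildSpecializedParameterNode` above unpacks a partial application and re-wraps it: the original CNode `partial(Func, a0, ..., ak)` carries `k + 2` inputs (the primitive, the function, then the k bound arguments — hence the size check), and the rebuilt node is `Partial(specialized_func, <frozen-or-forwarded args>...)`. A standalone sketch of the shape of that rewrite (plain C++; strings stand in for AnfNode inputs, and the push of `old_cnode[i + 2]` stands in for the `BuildPossibleValueNode(...)`-or-`old_node` choice):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // Original CNode: partial(Func, a0, a1) -> inputs {Partial, Func, a0, a1}.
      std::vector<std::string> old_cnode = {"Partial", "Func", "a0", "a1"};
      const size_t num_partial_args = 2;
      assert(old_cnode.size() == num_partial_args + 2);  // the size check in the hunk above

      // Rebuilt node: keep Partial, swap in the specialized function, then for each
      // bound argument either freeze it to a value node or forward cnode->input(i + 2).
      std::vector<std::string> partial_node_list = {"Partial", "Func_specialized"};
      for (size_t i = 0; i < num_partial_args; ++i) {
        partial_node_list.push_back(old_cnode[i + 2]);  // stand-in for the frozen/forwarded arg
      }
      assert(partial_node_list.size() == old_cnode.size());
      return 0;
    }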
@@ -465,6 +524,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
                   << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
   }

+  if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
+    auto wrapped_node = BuildSpecializedParameterNode(new_node);
+    new_inputs[0] = wrapped_node;
+  }
+
   if (CanSpecializeNode(func)) {
     new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
   }
@@ -474,16 +538,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
     if (CanSpecializeNode(args[i])) {
       new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
     }
-    // support for partial(Multitype) which Multitype should not be inferred to POLY.
-    // after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
-    // so even with partial parameter, it will specialize to that graph.
-    // Maybe a better idea should inline graph with partial node first, then it will have full
-    // parameter list to infer and specialize.
-    MS_EXCEPTION_IF_NULL(new_inputs[next]);
-    if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
-        IsPrimitive(func, prim::kPrimPartial)) {
-      new_inputs[next] = args[i];
-    }
     i = next;
   }
@@ -106,6 +106,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
   // (disconnected).
   AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);

+  // Build a specialized parameter node if the function graph has the special flag hinting it can be done.
+  AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
+
   // Build a value node if ival is constant and not any-value
   AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
   // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
@@ -87,11 +87,6 @@ class CumSumNet(nn.Cell):

 raise_set = [
-    # one input is scalar, and another is Tensor(float32)
-    ('TensorAdd0', {
-        'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('TensorAdd1', {
         'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),

@@ -271,11 +266,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Sub0', {
-        'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Sub1', {
         'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),

@@ -287,11 +277,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Mul0', {
-        'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Mul1', {
         'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),

@@ -352,11 +337,6 @@ raise_set = [
         'desc_inputs': [5.0],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Minimum0', {
-        'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Minimum1', {
         'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),

@@ -368,11 +348,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Maximum0', {
-        'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Maximum1', {
         'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),

@@ -384,11 +359,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('RealDiv0', {
-        'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('RealDiv1', {
         'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),

@@ -400,11 +370,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Div0', {
-        'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Div1', {
         'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),

@@ -416,11 +381,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('FloorDiv0', {
-        'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('FloorDiv1', {
         'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),

@@ -439,11 +399,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('FloorMod0', {
-        'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('FloorMod1', {
         'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),

@@ -462,11 +417,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
         'skip': ['backward']}),
-    # input is not tensor
-    ('Equal0', {
-        'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('Equal1', {
         'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),

@@ -490,11 +440,6 @@ raise_set = [
         'skip': ['backward']}),
     # shape of x and y not match
-    # input is not tensor
-    ('NotEqual0', {
-        'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('NotEqual1', {
         'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),

@@ -506,11 +451,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
         'skip': ['backward']}),
-    # input is not tensor
-    ('Greater0', {
-        'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('Greater1', {
         'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),

@@ -522,11 +462,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
         'skip': ['backward']}),
-    # input is not tensor
-    ('GreaterEqual0', {
-        'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('GreaterEqual1', {
         'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),

@@ -538,11 +473,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
         'skip': ['backward']}),
-    # input is not tensor
-    ('Less0', {
-        'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('Less1', {
         'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),

@@ -554,11 +484,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
         'skip': ['backward']}),
-    # input is not tensor
-    ('LessEqual0', {
-        'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # type of x and y not match
     ('LessEqual1', {
         'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),

@@ -728,11 +653,6 @@ raise_set = [
         'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
         'skip': ['backward']}),
-    # one input is scalar, and another is Tensor(float32)
-    ('Atan20', {
-        'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
-        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
-        'skip': ['backward']}),
     # input two tensors, but element types are not same
     ('Atan21', {
         'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test_hypermap_partial """
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, context
+import mindspore.common.dtype as mstype
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.api import ms_function
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+def test_hypermap_specialize_param():
+    class Net(nn.Cell):
+        """ Net definition """
+
+        def __init__(self):
+            super(Net, self).__init__()
+            self.mul = P.Mul()
+
+        def construct(self, x, y):
+            ret = self.mul(x, y)
+            return ret
+
+    factor1 = Tensor(5, dtype=mstype.int32)
+    x = Tensor(np.ones([1]).astype(np.int32))
+    y = Tensor(np.ones([2]).astype(np.int32))
+    net = Net()
+    hypermap = C.HyperMap()
+
+    @ms_function
+    def hypermap_specialize_param():
+        ret1 = hypermap(F.partial(net, factor1), (x, y))
+        # List will be converted to Tuple in SimplifyDataStructuresPass.
+        ret2 = hypermap(F.partial(net, factor1), [x, y])
+        return ret1, ret2
+
+    expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
+    ret = hypermap_specialize_param()
+    assert ret == (expected_ret, expected_ret)