| @@ -42,6 +42,7 @@ using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher | |||
| const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; | |||
| const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; | |||
| const char FUNC_GRAPH_FLAG_CORE[] = "core"; | |||
| const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; | |||
| // ANF transform class | |||
| // either a primitive or a func_graph | |||
| @@ -23,6 +23,7 @@ | |||
| #include <sstream> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| @@ -334,6 +335,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL | |||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("hyper_map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| @@ -278,10 +278,12 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { | |||
| // Convert class to Tuple | |||
| // Convert getattr to getitem | |||
| // Convert make_record to make_tuple | |||
| void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||
| bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| bool changed = false; | |||
| // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | |||
| AnfNodeSet all_node = manager->all_nodes(); | |||
| for (auto &node : all_node) { | |||
| @@ -316,7 +318,9 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr | |||
| if (new_node != nullptr) { | |||
| new_node->set_abstract(node->abstract()); | |||
| MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); | |||
| (void)manager->Replace(node, new_node); | |||
| changed = true; | |||
| } | |||
| } | |||
| @@ -324,6 +328,7 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr | |||
| auto ret = Reabs(node->abstract()); | |||
| node->set_abstract(ret); | |||
| } | |||
| return changed; | |||
| } | |||
| // expand tuples in graph parameters | |||
| @@ -31,7 +31,7 @@ namespace mindspore { | |||
| namespace opt { | |||
| // Remove the class type from graphs | |||
| void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); | |||
| bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); | |||
| // Remove most uses of tuples from the graph | |||
| // tuples that are returned will be kept | |||
| @@ -38,13 +38,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod | |||
| // src type check | |||
| auto src_type = src_->Type(); | |||
| if (src_type == nullptr) { | |||
| if (src_type == nullptr || !src_type->isa<TensorType>()) { | |||
| return nullptr; | |||
| } | |||
| if (src_type->isa<TensorType>()) { | |||
| src_type = src_type->cast<TensorTypePtr>()->element(); | |||
| } | |||
| src_type = src_type->cast<TensorTypePtr>()->element(); | |||
| // tgt type check | |||
| auto tgt_type = GetValueNode<TypePtr>(tgt_); | |||
| @@ -52,14 +52,16 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| opt::SimplifyDataStructures(func_graph, res->manager()); | |||
| bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); | |||
| abstract::AbstractBasePtrList args_spec; | |||
| auto parameters = func_graph->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), | |||
| [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| if (changed) { | |||
| FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); | |||
| res->set_func_graph(new_fg); | |||
| } | |||
| res->set_args_spec(args_spec); | |||
| return true; | |||
| } | |||
| @@ -177,8 +177,8 @@ std::size_t FuncGraphAbstractClosure::hash() const { | |||
| std::string FuncGraphAbstractClosure::ToString() const { | |||
| std::stringstream ss; | |||
| ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString() | |||
| << "; Context: " << context_.get() << context_->ToString(); | |||
| ss << "FuncGraphAbstractClosure: " | |||
| << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); | |||
| return ss.str(); | |||
| } | |||
| @@ -166,8 +166,9 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| public: | |||
| // Represents a partial application. | |||
| // args_spec_list: The first few arguments of that function | |||
| PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list) | |||
| : fn_(fn), args_spec_list_(args_spec_list) {} | |||
| PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, | |||
| const AnfNodePtr &node = nullptr) | |||
| : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} | |||
| ~PartialAbstractClosure() override = default; | |||
| MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) | |||
| @@ -175,7 +176,11 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| AbstractFunctionPtr fn() { return fn_; } | |||
| AbstractBasePtrList args() { return args_spec_list_; } | |||
| AbstractFunctionPtr Copy() const override { return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_); } | |||
| AnfNodePtr node() { return node_.lock(); } | |||
| void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } | |||
| AbstractFunctionPtr Copy() const override { | |||
| return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock()); | |||
| } | |||
| bool operator==(const AbstractFunction &other) const override; | |||
| std::size_t hash() const override; | |||
| @@ -184,6 +189,8 @@ class PartialAbstractClosure : public AbstractFuncAtom { | |||
| private: | |||
| AbstractFuncAtomPtr fn_; | |||
| AbstractBasePtrList args_spec_list_; | |||
| // The CNode which this PartialAbstractClosure evaluated from. | |||
| AnfNodeWeakPtr node_; | |||
| }; | |||
| class JTransformedAbstractClosure : public AbstractFuncAtom { | |||
| @@ -951,8 +951,19 @@ class PartialEvaluator : public Evaluator { | |||
| if (args_conf_list.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "Args size should be greater than 0"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | |||
| auto arg0_value = args_conf_list[0]->GetEvaluatedValue(); | |||
| AbstractBasePtrList args_spec_list{arg0_value}; | |||
| // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. | |||
| if (arg0_value->isa<AbstractError>()) { | |||
| auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node()); | |||
| MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() | |||
| << " as func is: " << arg0_value->ToString(); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| } | |||
| auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); | |||
| // Sometimes, node[0] in out_conf becomes phi0; | |||
| if (func->isa<PrimitiveAbstractClosure>()) { | |||
| @@ -962,19 +973,26 @@ class PartialEvaluator : public Evaluator { | |||
| return HandleDoSignature(engine, do_signature_prim->function(), out_conf); | |||
| } | |||
| } | |||
| (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); | |||
| (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); | |||
| AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); | |||
| AbstractFuncAtomPtrList partialPtrList; | |||
| auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) { | |||
| auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args); | |||
| partialPtrList.push_back(new_func); | |||
| auto cnode = out_conf->node()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() != (args_conf_list.size() + 1)) { | |||
| MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() | |||
| << ", args_conf_list: " << mindspore::ToString(args_conf_list); | |||
| } | |||
| AbstractFuncAtomPtrList partial_funcs_list; | |||
| auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { | |||
| auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode); | |||
| partial_funcs_list.push_back(new_func); | |||
| }; | |||
| func->Visit(build_partial); | |||
| auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList); | |||
| auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); | |||
| (*cache_)[args_spec_list] = ret; | |||
| return ret; | |||
| } | |||
| @@ -23,7 +23,9 @@ | |||
| #include "./common.h" | |||
| #include "operator/ops.h" | |||
| #include "operator/composite/do_signature.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/profile.h" | |||
| #include "debug/trace.h" | |||
| @@ -232,6 +234,13 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { | |||
| return; | |||
| } | |||
| new_node->set_abstract(GetEvaluatedValueWrap(conf)); | |||
| if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) { | |||
| auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract()); | |||
| if (partial_abstract->node() == node) { | |||
| partial_abstract->set_node(new_node); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); | |||
| if (node->isa<CNode>()) { | |||
| @@ -383,6 +392,56 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr | |||
| return BuildValueNode(v, abs); | |||
| } | |||
| AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { | |||
| auto new_inputs = new_node->inputs(); | |||
| AnfNodePtr func = new_inputs[0]; | |||
| AbstractBasePtr fnval = new_inputs[0]->abstract(); | |||
| AbstractBasePtrList args; | |||
| auto backed_fnval = fnval; | |||
| if (fnval->isa<PartialAbstractClosure>()) { | |||
| auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval); | |||
| backed_fnval = partial_closure->fn(); | |||
| args = partial_closure->args(); | |||
| } | |||
| std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), | |||
| [](const AnfNodePtr &inp) { return inp->abstract(); }); | |||
| ScopeGuard scope_guard(new_node->scope()); | |||
| auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); | |||
| auto wrapped_node = specialized_node; | |||
| if (fnval->isa<PartialAbstractClosure>()) { | |||
| auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval); | |||
| AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), | |||
| specialized_node}; | |||
| auto anf_node = partial_closure->node(); | |||
| if (!anf_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| if (cnode->size() != partial_closure->args().size() + 2) { | |||
| MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() | |||
| << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); | |||
| } | |||
| for (size_t i = 0; i < partial_closure->args().size(); i++) { | |||
| auto old_node = cnode->input(i + 2); | |||
| auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); | |||
| if (possibile_value_node != nullptr) { | |||
| partial_node_list.push_back(possibile_value_node); | |||
| } else { | |||
| if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) { | |||
| MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); | |||
| } | |||
| partial_node_list.push_back(old_node); | |||
| } | |||
| } | |||
| wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); | |||
| wrapped_node->set_abstract(partial_closure); | |||
| } | |||
| return wrapped_node; | |||
| } | |||
| const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { | |||
| auto cache_iter = evalcaches_.find(eval); | |||
| if (cache_iter == evalcaches_.end()) { | |||
| @@ -465,6 +524,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); | |||
| } | |||
| if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) { | |||
| auto wrapped_node = BuildSpecializedParameterNode(new_node); | |||
| new_inputs[0] = wrapped_node; | |||
| } | |||
| if (CanSpecializeNode(func)) { | |||
| new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); | |||
| } | |||
| @@ -474,16 +538,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| if (CanSpecializeNode(args[i])) { | |||
| new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{}); | |||
| } | |||
| // support for partial(Multitype) which Multitype should not be inferred to POLY. | |||
| // after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only, | |||
| // so even with partial parameter, it will specialize to that graph. | |||
| // Maybe a better idea should inline graph with partial node first, then it will have full | |||
| // parameter list to infer and specialize. | |||
| MS_EXCEPTION_IF_NULL(new_inputs[next]); | |||
| if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) && | |||
| IsPrimitive(func, prim::kPrimPartial)) { | |||
| new_inputs[next] = args[i]; | |||
| } | |||
| i = next; | |||
| } | |||
| @@ -106,6 +106,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia | |||
| // (disconnected). | |||
| AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); | |||
| // Build a value node from parameter if the function graph has special flag to hint it can be done. | |||
| AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); | |||
| // Build a value node if ival is constant and not any-value | |||
| AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival); | |||
| // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a | |||
| @@ -87,11 +87,6 @@ class CumSumNet(nn.Cell): | |||
| raise_set = [ | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('TensorAdd0', { | |||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('TensorAdd1', { | |||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||
| @@ -271,11 +266,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Sub0', { | |||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Sub1', { | |||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||
| @@ -287,11 +277,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Mul0', { | |||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Mul1', { | |||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||
| @@ -352,11 +337,6 @@ raise_set = [ | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Minimum0', { | |||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Minimum1', { | |||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||
| @@ -368,11 +348,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Maximum0', { | |||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Maximum1', { | |||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||
| @@ -384,11 +359,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('RealDiv0', { | |||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('RealDiv1', { | |||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||
| @@ -400,11 +370,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Div0', { | |||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Div1', { | |||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||
| @@ -416,11 +381,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('FloorDiv0', { | |||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('FloorDiv1', { | |||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||
| @@ -439,11 +399,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('FloorMod0', { | |||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('FloorMod1', { | |||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||
| @@ -462,11 +417,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Equal0', { | |||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Equal1', { | |||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||
| @@ -490,11 +440,6 @@ raise_set = [ | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| # input is not tensor | |||
| ('NotEqual0', { | |||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('NotEqual1', { | |||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||
| @@ -506,11 +451,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Greater0', { | |||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Greater1', { | |||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||
| @@ -522,11 +462,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('GreaterEqual0', { | |||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('GreaterEqual1', { | |||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||
| @@ -538,11 +473,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Less0', { | |||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Less1', { | |||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||
| @@ -554,11 +484,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('LessEqual0', { | |||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('LessEqual1', { | |||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||
| @@ -728,11 +653,6 @@ raise_set = [ | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Atan20', { | |||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Atan21', { | |||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_hypermap_partial """ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.api import ms_function | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_hypermap_specialize_param(): | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.mul = P.Mul() | |||
| def construct(self, x, y): | |||
| ret = self.mul(x, y) | |||
| return ret | |||
| factor1 = Tensor(5, dtype=mstype.int32) | |||
| x = Tensor(np.ones([1]).astype(np.int32)) | |||
| y = Tensor(np.ones([2]).astype(np.int32)) | |||
| net = Net() | |||
| hypermap = C.HyperMap() | |||
| @ms_function | |||
| def hypermap_specialize_param(): | |||
| ret1 = hypermap(F.partial(net, factor1), (x, y)) | |||
| # List will be converted to Tuple in SimlifyDataStructurePass. | |||
| ret2 = hypermap(F.partial(net, factor1), [x, y]) | |||
| return ret1, ret2 | |||
| expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) | |||
| ret = hypermap_specialize_param() | |||
| assert(ret == (expected_ret, expected_ret)) | |||