diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 4e01e9003f..d7a6eb81f7 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -45,7 +45,8 @@ FuncGraph::FuncGraph() hyper_param_count_(0), is_generated_(false), return_(nullptr), - manager_(std::weak_ptr()) { + manager_(std::weak_ptr()), + stub_(false) { debug_info_ = std::make_shared(); } diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index b1be892a53..70e53f4828 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase { void SetEffectDepends(const std::vector &depend_inputs); bool HasEffect(const CNodePtr &cnode); + bool stub() const { return stub_; } + void set_stub(bool stub) { stub_ = stub; } + private: // graph is manipulated by manager and others friend FuncGraphManager; @@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase { // CNode order which relates to origin code order std::list order_; + bool stub_; }; inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index 5b9d57ffa4..f720913b98 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); (*target_func_graph)->set_is_generate(func_graph->is_generated()); + (*target_func_graph)->set_stub(func_graph->stub()); TraceManager::EndTrace(); } @@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); + new_func_graph->set_stub(func_graph->stub()); for (auto &item : func_graph->parameter_default_value()) { new_func_graph->set_param_default_value(item.first, cloner[item.second]); } diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc index de6526f642..7919ea5f4f 100644 --- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc @@ -30,6 +30,7 @@ #include "pipeline/static_analysis/param_validator.h" #include "operator/cc_implementations.h" #include "optimizer/opt.h" +#include "utils/context/ms_context.h" #include "utils/symbolic.h" #include "pybind_api/api_register.h" #include "./common.h" @@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { } return item.second; } - // Try best match - py::function py_fn_subclass; - size_t subclass_match_cnt = 0; - for (auto &item : fn_cache_py_) { - TypePtrList sign = item.first; - if (sign.size() != types.size()) { - continue; + return py::none(); +} + +FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + std::vector parameters; + ParameterPtr undetermined_param = nullptr; + auto stub = std::make_shared(); + for (size_t i = 0; i < types.size(); ++i) 
{ + auto param = stub->add_parameter(); + parameters.push_back(param); + if (types[i]->type_id() == kObjectTypeUndeterminedType) { + undetermined_param = param; } - auto match = true; - for (size_t i = 0; i < sign.size(); ++i) { - if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) && - !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) { - match = false; - break; + } + if (undetermined_param != nullptr) { + std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i]->type_id() == kObjectTypeFunction) { + std::vector call_prim{parameters[i], undetermined_param}; + inputs.push_back(stub->NewCNode(call_prim)); + } else { + inputs.push_back(parameters[i]); } } - if (!match) { - continue; - } - py_fn_subclass = item.second; - subclass_match_cnt++; - } - if (subclass_match_cnt > 1) { - MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass"; - } - if (subclass_match_cnt == 1) { - MS_LOG(DEBUG) << "Found one subclass match"; - return py_fn_subclass; + auto stub_output = stub->NewCNode(inputs); + stub->set_output(stub_output); + stub->set_stub(true); + return stub; } - return py::none(); + return nullptr; } FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { @@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); return func_graph; } + auto stub = GenerateStubFunc(types); + if (stub != nullptr) { + MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString(); + return stub; + } std::ostringstream oss; oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ << "`, corresponding location info:\n"; diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc index ff9ec712bb..c6c693b4d8 100644 --- a/mindspore/ccsrc/operator/prim_others.cc +++ b/mindspore/ccsrc/operator/prim_others.cc @@ -23,8 +23,8 @@ #include "pipeline/static_analysis/param_validator.h" #include "pipeline/static_analysis/prim.h" #include "pipeline/static_analysis/utils.h" -#include "utils/symbolic.h" #include "utils/context/ms_context.h" +#include "utils/symbolic.h" namespace mindspore { namespace abstract { @@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit return AbstractFunction::MakeAbstractFunction(jv); } -class UndeterminedShapeType { - public: - explicit UndeterminedShapeType(const std::string &env_str) { - // param_name indices_shape indices_type values_shape values_type dense_shape - // export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 - // 2:Float32:3 1 2" - std::vector fields; - string tmp; - std::stringstream input(env_str); - while (std::getline(input, tmp, ':')) { - fields.push_back(tmp); - } - if (fields.size() != fields_num) { - MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); - } - - param_name_ = fields[0]; - - indices_shape_ = GetShape(fields[1]); - indices_type_ = StringToType(fields[2]); - - values_shape_ = GetShape(fields[3]); - values_type_ = StringToType(fields[4]); - - auto dense_shape_vec = GetShape(fields[5]); - AbstractBasePtrList dense_shape_list; - (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), - [](const auto &elem) { 
return FromValue(elem, false); }); - dense_shape_ = dense_shape_list; - } - ~UndeterminedShapeType() = default; - const std::string ¶m_name() { return param_name_; } - const std::vector &indices_shape() { return indices_shape_; } - const TypePtr &indices_type() { return indices_type_; } - const std::vector &values_shape() { return values_shape_; } - const TypePtr &values_type() { return values_type_; } - const AbstractBasePtrList &dense_shape() { return dense_shape_; } - - private: - std::string param_name_; - std::vector indices_shape_; - TypePtr indices_type_; - std::vector values_shape_; - TypePtr values_type_; - AbstractBasePtrList dense_shape_; - static const size_t fields_num; - - std::vector GetShape(const std::string &shape_str); -}; -std::vector UndeterminedShapeType::GetShape(const std::string &shape_str) { - std::vector ret; - std::istringstream iss(shape_str); - int elem; - while (iss.good()) { - iss >> elem; - ret.emplace_back(elem); - } - return ret; -} -const size_t UndeterminedShapeType::fields_num = 6; - -std::unordered_map g_undetermined_configs; -void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { - std::string tmp; - std::stringstream input(sparse_shape_types); - g_undetermined_configs.clear(); - while (std::getline(input, tmp, ';')) { - auto config = UndeterminedShapeType(tmp); - g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); - MS_LOG(DEBUG) << "Undetermined config from env: " << tmp; - } -} - AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { MS_EXCEPTION_IF_NULL(primitive); @@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); } - if (!key->sparse_grad().empty()) { - // Will be fixed once undetermined type ready - if (g_undetermined_configs.empty()) { - auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); - MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types; - if (sparse_shape_types.empty()) { - sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; - } - InitUndeterminedFromEnv(sparse_shape_types); - } - - auto shape_types = g_undetermined_configs.find(key->sparse_grad()); - if (shape_types == g_undetermined_configs.end()) { - MS_LOG(EXCEPTION) << "Param " << key->ToString() - << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES"; - } - MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); - AbstractBasePtrList sparse_list; - // indices - auto indices_ele = std::make_shared(kAnyValue, shape_types->second.indices_type()); - auto indices = - std::make_shared(indices_ele, std::make_shared(shape_types->second.indices_shape())); - sparse_list.emplace_back(indices); - // values - auto dout_ele = std::make_shared(kAnyValue, shape_types->second.values_type()); - auto dout = std::make_shared(dout_ele, std::make_shared(shape_types->second.values_shape())); - sparse_list.emplace_back(dout); - // dense_shape - sparse_list.emplace_back(std::make_shared(shape_types->second.dense_shape())); - return std::make_shared(sparse_list); - } - auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa()) { + bool 
enable_sparse = context->enable_sparse(); + if (enable_sparse && dflt->isa()) { auto dflt_tensor = dflt->cast(); return std::make_shared(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); } + if (!key->GetValueTrack()->isa()) { return dflt; } @@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & if (type->type_id() != kObjectTypeRefKey) { MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); } - auto ret = std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); - ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); - ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad()); - return ret; + return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); } AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, diff --git a/mindspore/ccsrc/optimizer/irpass/inline.h b/mindspore/ccsrc/optimizer/irpass/inline.h index 64f192347c..4b48d604d9 100644 --- a/mindspore/ccsrc/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/optimizer/irpass/inline.h @@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor { } auto fg = GetValueNode(node); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { return nullptr; } @@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor { // G auto fg = GetValueNode(inputs[0]); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { return nullptr; } // Do not inline GraphKernel to Cell. diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index cea82bc180..1766e29566 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { std::string env = common::GetEnv("SLICE_ENV"); if (!env.empty()) { MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; - abstract::InitUndeterminedFromEnv(env); } } diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index a645452cc0..425ad28fb5 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { ValuePtr value = param_value->value(); constexpr bool broaden = true; AbstractBasePtr ptr = abstract::FromValue(value, broaden); - ptr->set_sparse_grad(param_value->sparse_grad()); - ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad()); parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); args_spec.push_back(ptr); diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 199e841fc9..305acc67ec 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, "Set the GraphKernel switch to on or off.") .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") - .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.") - .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse."); + .def("get_enable_sparse", 
&mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") + .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); (void)py::class_>(m, "MpiConfig") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index f6cfd6362c..abffc37bb2 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { return true; } -std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_a", OptPassAGroup}, +std::vector kVmPasses = {{"opt_a", OptPassAGroup}, + {"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_control_depend", AddControlDependPass}}; -std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_a", OptPassAGroup}, - {"opt_b", OptPassBGroup}, - {"add_control_depend", AddControlDependPass}, - {"opt_control", ControlGroup}, - {"opt_prepare", PrepareGroup}, - {"cconv", CconvPass}}; +std::vector kGePasses = { + {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, + {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, + {"cconv", CconvPass}}; std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; } // namespace pipeline diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc index faf1f2015d..cd79b2466a 100644 --- a/mindspore/ccsrc/pipeline/resource.cc +++ b/mindspore/ccsrc/pipeline/resource.cc @@ -146,37 +146,35 @@ MethodMap &GetMethodMap() { }}, {kObjectTypeTensorType, { - {"__add__", std::string("add")}, // C.add - {"__sub__", std::string("sub")}, // C.sub - {"__mul__", std::string("mul")}, // C.mul - {"__truediv__", std::string("truediv")}, // C.truediv - {"__floordiv__", std::string("floordiv")}, // C.floordiv - {"__mod__", std::string("mod")}, // C.mod - {"__pow__", std::string("pow_")}, // C.pow - {"__floor__", std::string("array_floor")}, // C.array_floor - {"__trunc__", std::string("array_trunc")}, // C.array_trunc - {"__pos__", std::string("array_uadd")}, // C.array_uadd - {"__neg__", std::string("array_usub")}, // C.array_usub - {"__eq__", std::string("eq")}, // C.eq - {"__ne__", std::string("ne")}, // C.ne - {"__lt__", std::string("lt")}, // C.lt - {"__gt__", std::string("gt")}, // C.gt - {"__le__", std::string("le")}, // C.le - {"__ge__", std::string("ge")}, // C.ge - {"__matmul__", prim::kPrimDot}, // P.dot, - {"__len__", prim::kPrimArrayLen}, // P.array_len, - {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, - {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, - {"__ms_iter__", std::string("array_iter")}, // C.array_iter - {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, - {"transpose", std::string("transpose")}, // P.transpose - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool - {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices + {"__add__", std::string("add")}, // C.add + {"__sub__", std::string("sub")}, // C.sub + {"__mul__", std::string("mul")}, // C.mul + 
{"__truediv__", std::string("truediv")}, // C.truediv + {"__floordiv__", std::string("floordiv")}, // C.floordiv + {"__mod__", std::string("mod")}, // C.mod + {"__pow__", std::string("pow_")}, // C.pow + {"__floor__", std::string("array_floor")}, // C.array_floor + {"__trunc__", std::string("array_trunc")}, // C.array_trunc + {"__pos__", std::string("array_uadd")}, // C.array_uadd + {"__neg__", std::string("array_usub")}, // C.array_usub + {"__eq__", std::string("eq")}, // C.eq + {"__ne__", std::string("ne")}, // C.ne + {"__lt__", std::string("lt")}, // C.lt + {"__gt__", std::string("gt")}, // C.gt + {"__le__", std::string("le")}, // C.le + {"__ge__", std::string("ge")}, // C.ge + {"__matmul__", prim::kPrimDot}, // P.dot, + {"__len__", prim::kPrimArrayLen}, // P.array_len, + {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, + {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, + {"__ms_iter__", std::string("array_iter")}, // C.array_iter + {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, + {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"transpose", std::string("transpose")}, // P.transpose + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool }}, {kObjectTypeIndexedSlicesType, { - {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc index 6c07f92274..a2f97cf3b0 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc @@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const { AbstractBasePtr AbstractBase::Broaden() const { AbstractBasePtr clone = Clone(); clone->set_value(kAnyValue); - clone->set_sparse_grad(sparse_grad_); return clone; } @@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const { MS_EXCEPTION_IF_NULL(type_); MS_EXCEPTION_IF_NULL(shape_); buffer << type_name() << "(" - << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() - << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")"; + << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; return buffer.str(); } @@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden() AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); if (*this == *other) { - auto ret = shared_from_base(); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; + return shared_from_base(); } auto value_self = GetValueTrack(); MS_EXCEPTION_IF_NULL(value_self); ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); if (res_value == value_self) { - auto ret = shared_from_base(); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; + return shared_from_base(); } - auto ret = std::make_shared(res_value, res_type); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - 
return ret; + return std::make_shared(res_value, res_type); } AbstractBasePtr AbstractType::Clone() const { @@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); } if (*this == *other) { - if (sparse_grad() == other->sparse_grad()) { - return shared_from_base(); - } + return shared_from_base(); } auto element = element_->Join(other_tensor->element_); auto shape = ShapeJoin(this->shape(), other_tensor->shape()); - auto ret = std::make_shared(element, shape); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; + return std::make_shared(element, shape); } bool AbstractTensor::operator==(const AbstractTensor &other) const { @@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const { ShapePtr shp = shape(); clone->set_shape(shp->Clone()); clone->set_value(GetValueTrack()); - clone->set_sparse_grad(sparse_grad()); - clone->set_has_indexed_slices_grad(has_indexed_slices_grad()); return clone; } @@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const { auto shp = shape(); broaden->set_shape(shp->Clone()); broaden->set_value(kAnyValue); - broaden->set_sparse_grad(sparse_grad()); - broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); return broaden; } @@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { shp->Broaden(); broaden->set_shape(shp); broaden->set_value(kAnyValue); - broaden->set_sparse_grad(sparse_grad()); - broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); return broaden; } @@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const { MS_EXCEPTION_IF_NULL(value_track); buffer << type_name() << "(" << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() - << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() - << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")"; + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; return buffer.str(); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h index 3981a6eb23..f165808fa0 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h @@ -44,7 +44,7 @@ class AbstractBase : public Base { public: explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, const BaseShapePtr &shape = kNoShape) - : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {} + : value_(value), type_(type), shape_(shape) {} ~AbstractBase() override = default; MS_DECLARE_PARENT(AbstractBase, Base) @@ -53,17 +53,11 @@ class AbstractBase : public Base { virtual bool operator==(const AbstractBase &other) const; void set_value(const ValuePtr &value) { value_ = value; } - void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } - void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) { - has_indexed_slices_grad_ = has_indexed_slices_grad; - } void set_type(const TypePtr &type) { type_ = type; } void set_shape(const BaseShapePtr &shape) { shape_ = shape; } void set_value_desc(const std::string &desc) { value_desc_ = desc; } const std::string &value_desc() const { return value_desc_; } ValuePtr GetValueTrack() 
const { return value_; } - const std::string &sparse_grad() const { return sparse_grad_; } - const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; } TypePtr GetTypeTrack() const { return type_; } BaseShapePtr GetShapeTrack() const { return shape_; } @@ -91,8 +85,6 @@ class AbstractBase : public Base { TypePtr type_; BaseShapePtr shape_; std::string value_desc_; // store initial value description for error report - std::string sparse_grad_; - bool has_indexed_slices_grad_; }; class AbstractScalar : public AbstractBase { diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc index 34ecfc8980..a95f686199 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc @@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr } MS_EXCEPTION_IF_NULL(ret_base); - MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); + MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() + << ", is stub: " << fg->stub(); + if (fg->stub()) { + return std::make_shared(std::make_shared(), nullptr); + } return std::make_shared(ret_base, nullptr); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h index f6430eda84..079c1aac61 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h @@ -25,6 +25,7 @@ #include #include "pipeline/static_analysis/static_analysis.h" +#include "utils/context/ms_context.h" namespace mindspore { namespace abstract { @@ -59,6 +60,13 @@ class Evaluator : public Base { } virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { return true; diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 0c9764af93..19aeceb19b 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { using mindspore::parse::PyObjectWrapper; EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { + if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { auto ret_abstract = AbstractEval(args); if (ret_abstract != nullptr) { MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; @@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), 
std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; } @@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt } AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - ScopePtr scope = kDefaultScope; if (out_conf != nullptr) { scope = out_conf->node()->scope(); @@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic } // end anonymous namespace EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; - return ret_abstract; - } + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; + return ret_abstract; } MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); @@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs } EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; - return ret_abstract; - } + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; + return ret_abstract; } // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. 
if (nargs_ != args.size()) { @@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { auto ret = std::make_shared(type); auto ref_value = ref_abs->ref(); MS_EXCEPTION_IF_NULL(ref_value); - ret->set_sparse_grad(ref_value->sparse_grad()); - ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad()); return std::make_shared(ret, std::make_shared()); } @@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { x = SensitivityTransform(x); std::shared_ptr key = std::make_shared(node, x); std::shared_ptr abs_scalar = std::make_shared(key, type); - abs_scalar->set_sparse_grad(x->sparse_grad()); - abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad()); return std::make_shared(abs_scalar, std::make_shared()); } }; @@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args_spec_list); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; - return ret_abstract; - } + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; + return ret_abstract; } // Inputs: data, item if (args_spec_list.size() != 2) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 1346dba2a2..5a686fbadc 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -void InitUndeterminedFromEnv(const std::string &sparse_shape_types); AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc index e01b98841b..b0ad1c3d67 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc @@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co AbstractFunctionPtr func = real_a->GetUnique(); SpecializeStatusCode errcode; ScopeGuard scope_guard(node->scope()); - AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode); + AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode); if (repl == nullptr) { if (errcode == kSpecializeFindUniqueArgvalDead) { const auto error_dead_node = std::make_shared(kDeadNode, node); @@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co return repl; } -AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func, +AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, + const 
AbstractFunctionPtr &func, const AbstractBasePtrList &args, SpecializeStatusCode *errcode) { MS_EXCEPTION_IF_NULL(abs); @@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString(); + if (context->func_graph()->stub()) { + MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString() + << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString() + << ", " << node->ToString(); + return node; + } FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); + v->set_flag(kFuncGraphFlagUndetermined, false); return BuildValueNode(v, abs); } @@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); return kSpecializeSuccess; } else if (choices->empty()) { - MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; + MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | " + << func->type_name(); return kSpecializeFindUniqueArgvalDead; } else { if (IsPolyFunc(func, argvals)) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h index b04978586d..831c404873 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h @@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this MsContext::GetInstance() { diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 3bca16f8ee..19205cccb8 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -161,8 +161,8 @@ class MsContext { void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } bool enable_graph_kernel() const { return enable_graph_kernel_; } - bool enable_sparse_flag() const { return enable_sparse_flag_; } - void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; } + bool enable_sparse() const { return enable_sparse_; } + void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } private: MsContext(const std::string &backend_policy, const std::string &target); @@ -207,7 +207,7 @@ class MsContext { float max_device_memory_; std::string print_file_path_; bool enable_graph_kernel_; - bool enable_sparse_flag_; + bool enable_sparse_; }; } // namespace mindspore diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 5a8f0b8996..1ce98cb147 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -51,18 +51,13 @@ class Parameter: requires_grad (bool): True if the parameter requires gradient. Default: True. layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, broadcast and gradients communication would not be applied on parameters. Default: False. - sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty. 
- has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false. """ - def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, - sparse_grad="", has_indexed_slices_grad=False): + def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): self._value = ParamValue() self.set_parameter_data(default_input) self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel - self.sparse_grad = sparse_grad - self.has_indexed_slices_grad = has_indexed_slices_grad self._is_init = False self._sliced = False if context.get_context("mode") == context.PYNATIVE_MODE: @@ -177,28 +172,6 @@ class Parameter: raise TypeError("`requires_grad` parameter must be bool type") self._value.requires_grad = value - @property - def sparse_grad(self): - """Return whether the parameter's gradient is sparse.""" - return self._value.sparse_grad - - @sparse_grad.setter - def sparse_grad(self, value=""): - if not isinstance(value, str): - raise TypeError("`sparse_grad` parameter must be str type") - self._value.sparse_grad = value - - @property - def has_indexed_slices_grad(self): - """Return whether the parameter's gradient is indexed_slices.""" - return self._value.has_indexed_slices_grad - - @has_indexed_slices_grad.setter - def has_indexed_slices_grad(self, value=False): - if not isinstance(value, bool): - raise TypeError("`has_indexed_slices_grad` parameter must be bool type") - self._value.has_indexed_slices_grad = value - @property def data(self): return self.default_input diff --git a/mindspore/context.py b/mindspore/context.py index fe3d95b192..51418d3965 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -367,14 +367,6 @@ class _Context: def check_bprop(self, check_bprop_flag): self._context_handle.set_check_bprop_flag(check_bprop_flag) - @property - def enable_sparse(self): - return self._context_handle.get_enable_sparse_flag() - - @enable_sparse.setter - def enable_sparse(self, enable_sparse_flag): - self._context_handle.set_enable_sparse_flag(enable_sparse_flag) - @property def max_device_memory(self): return self._context_handle.get_max_device_memory() @@ -408,6 +400,13 @@ class _Context: full_file_name = print_file_path self._context_handle.set_print_file_path(full_file_name) + @property + def enable_sparse(self): + return self._context_handle.get_enable_sparse() + + @enable_sparse.setter + def enable_sparse(self, enable_sparse): + self._context_handle.set_enable_sparse(enable_sparse) def check_input_format(x): import re @@ -601,7 +600,7 @@ def set_context(**kwargs): print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to a file by default, and turn off printing to the screen. If the file already exists, add a timestamp suffix to the file. - enable_sparse (bool): Whether to enable sparse feature. Default: False. + enable_sparse (bool): Whether to enable sparsity feature. Default: False. Raises: ValueError: If input key is not an attribute in context. diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index d33adb04ee..c95f22ee61 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -162,8 +162,8 @@ class Adam(Optimizer): To improve parameter groups performance, the customized order of parameters can be supported. - The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the - `sparse_grad` of `Parameter` being set. 
The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse feature is under continuous development. The sparse
        behavior is currently performed on the CPU.
    Args:
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index b2954430b4..43eba7c8d1 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -72,8 +72,8 @@ class FTRL(Optimizer):
    `_ for engineering document.
    Note:
-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse feature is under continuous development. The sparse
        behavior is currently performed on the CPU.
    Args:
diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py
index 4b97d2eb20..7905398437 100644
--- a/mindspore/nn/optim/lazyadam.py
+++ b/mindspore/nn/optim/lazyadam.py
@@ -91,8 +91,8 @@ class LazyAdam(Optimizer):
    value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will
    be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of
    parameters.
-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse behavior, to be noticed, is not equivalent to the
    original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
    continuous development. The sparse behavior is currently performed on the CPU.
diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 3530065127..25cf438034 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer):
    `_.
    Note:
-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse feature is under continuous development. The sparse
        behavior is currently performed on the CPU.
Args: diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index d23fcd3092..2be011cb77 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices') indexed_slices_get_values = Primitive('IndexedSlicesGetValues') indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') -is_indexed_slices = Primitive('IsIndexedSlices') tensor_operator_registry.register('__add__', tensor_add) diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py index 8690183090..36dfe464cb 100644 --- a/tests/ut/python/ir/test_indexed_slices.py +++ b/tests/ut/python/ir/test_indexed_slices.py @@ -36,6 +36,8 @@ from mindspore._checkparam import Rel from mindspore.nn import Optimizer from mindspore.nn import TrainOneStepCell, WithLossCell +context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + reduce_sum = P.ReduceSum() unsorted_segment_sum = P.UnsortedSegmentSum() transpose = P.Transpose() @@ -44,7 +46,6 @@ reshape = P.Reshape() size_op = P.Size() invert_permutation = P.InvertPermutation() logical_and = P.LogicalAnd() -context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) @constexpr def _generate_shape_index(out_shape, indices_shape, axis): @@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self): adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Undetermined", "Bool") -def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): - if gradient.is_indexed_slices(): - return gradient.values() + "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") +def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): + return gradient.values() + +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): op_mul = P.Mul() op_square = P.Square() op_sqrt = P.Sqrt() @@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices(): self.dense_shape = (3, 4) def construct(self, indices, values): ret = (IndexedSlices(indices, values, self.dense_shape),) - return ret[0].is_indexed_slices() + return ret[0] indices = Tensor([[0, 0], [1, 2]]) values = Tensor([1, 2], dtype=ms.float32) MakeIndexedSlices()(indices, values) @@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): self.network = network def construct(self, x, y): grad = grad_all(self.network)(x, y) - return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices() + return grad, grad[0], grad[1] class SparseGatherV2(nn.Cell): def __init__(self): super(SparseGatherV2, self).__init__() @@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): weights = self.weights grad = grad_by_list(self.network, weights)(x) x = grad[0] - return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape() + return x, x.values(), x.indices(), x.dense_shape() class SparseGatherV2(nn.Cell): def __init__(self): super(SparseGatherV2, self).__init__() self.sparse_gatherv2 = MySparseGatherV2() self.axis = 0 - self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), - name="params", 
has_indexed_slices_grad=True) + self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") def construct(self, indices): return self.sparse_gatherv2(self.params, indices, self.axis) indices = Tensor(np.array([0, 1]).astype(np.int32)) @@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): network(indices) -def test_indexed_slices_is_indexed_slices(): - class MakeIndexedSlices(nn.Cell): - def __init__(self): - super(MakeIndexedSlices, self).__init__() - self.dense_shape = (3, 4) - def construct(self, indices, values): - indexed_slices = IndexedSlices(indices, values, self.dense_shape) - ret = indexed_slices.is_indexed_slices() - return ret - indices = Tensor([[0, 0], [1, 2]]) - values = Tensor([1, 2], dtype=ms.float32) - MakeIndexedSlices()(indices, values) - - def test_indexed_slices_env_get(): class Loss(nn.Cell): def __init__(self): @@ -271,7 +262,7 @@ def test_indexed_slices_env_get(): class NetWithSparseGatherV2(nn.Cell): def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True) + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.gatherv2 = MySparseGatherV2() self.axis = 0 diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index b435bf65b9..03a73893c5 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -17,12 +17,13 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): """ Net definition """ @@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py index 7f9f341a93..23aad24c47 100644 --- a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py +++ b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py @@ -27,6 +27,7 @@ from mindspore.ops import functional as F from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel +context.set_context(enable_sparse=True) adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse(): class NetWithSparseGatherV2(nn.Cell): def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") self.w2 = 
Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.gatherv2 = P.SparseGatherV2() self.axis = 0 diff --git a/tests/ut/python/nn/optim/test_ftrl.py b/tests/ut/python/nn/optim/test_ftrl.py index de59dfdbad..670bebc92d 100644 --- a/tests/ut/python/nn/optim/test_ftrl.py +++ b/tests/ut/python/nn/optim/test_ftrl.py @@ -17,12 +17,13 @@ import numpy as np import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import FTRL from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self): @@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_lazyadam.py b/tests/ut/python/nn/optim/test_lazyadam.py index ce66b404e2..7769597140 100644 --- a/tests/ut/python/nn/optim/test_lazyadam.py +++ b/tests/ut/python/nn/optim/test_lazyadam.py @@ -17,12 +17,13 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import LazyAdam from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): """ Net definition """ @@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index c7e6d3f88a..3077896fed 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -17,12 +17,13 @@ import numpy as np import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import ProximalAdagrad from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self): @@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", - sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 
2]).astype(np.float32)), name="weight2")
        self.axis = 0
        self.gather = P.SparseGatherV2()
diff --git a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py
index 1f669f7355..c292e3662d 100644
--- a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py
+++ b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py
@@ -53,4 +53,4 @@ def test_hypermap_specialize_param():
     expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
     ret = hypermap_specialize_param()
-    assert ret == (expected_ret, expected_ret)
+    assert ret == (expected_ret, list(expected_ret))
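Usage sketch distilled from the test changes in this patch: the per-Parameter `sparse_grad` / `has_indexed_slices_grad` arguments are removed, sparsity is enabled globally through `context.set_context(enable_sparse=True)`, and code that used to branch on `is_indexed_slices()` now registers separate overloads on a MultitypeFuncGraph for the "IndexedSlices" and "Tensor" types. The graph name `get_grad_values` and the cell below are illustrative only, not part of the patch.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

# Sparse handling is now gated by the context flag instead of per-Parameter attributes.
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

# Type-based dispatch replaces the removed is_indexed_slices() runtime check:
# the gradient's own type selects the overload.
get_grad_values = C.MultitypeFuncGraph("get_grad_values")

@get_grad_values.register("IndexedSlices")
def _get_grad_values_indexed_slices(grad):
    # Sparse gradient: return only its values tensor.
    return grad.values()

@get_grad_values.register("Tensor")
def _get_grad_values_tensor(grad):
    # Dense gradient: pass through unchanged.
    return grad

class NetWithSparseGatherV2(nn.Cell):
    """Forward net whose parameter receives an IndexedSlices gradient via SparseGatherV2."""
    def __init__(self):
        super(NetWithSparseGatherV2, self).__init__()
        # No sparse_grad / has_indexed_slices_grad arguments any more.
        self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
        self.gather = P.SparseGatherV2()
        self.axis = 0

    def construct(self, indices):
        return self.gather(self.w1, indices, self.axis)

When an argument's static type is still Undetermined during inference, the new GenerateStubFunc path in multitype_funcgraph.cc returns a stub FuncGraph: inlining skips it via fg->stub(), and BuildSpecializedNodeInner returns the original node instead of a specialized graph, so the call is resolved once the concrete type is known.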