| @@ -45,7 +45,8 @@ FuncGraph::FuncGraph() | |||
| hyper_param_count_(0), | |||
| is_generated_(false), | |||
| return_(nullptr), | |||
| manager_(std::weak_ptr<FuncGraphManager>()) { | |||
| manager_(std::weak_ptr<FuncGraphManager>()), | |||
| stub_(false) { | |||
| debug_info_ = std::make_shared<GraphDebugInfo>(); | |||
| } | |||
| @@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase { | |||
| void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs); | |||
| bool HasEffect(const CNodePtr &cnode); | |||
| bool stub() const { return stub_; } | |||
| void set_stub(bool stub) { stub_ = stub; } | |||
| private: | |||
| // graph is manipulated by manager and others | |||
| friend FuncGraphManager; | |||
| @@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase { | |||
| // CNode order which relates to origin code order | |||
| std::list<CNodePtr> order_; | |||
| bool stub_; | |||
| }; | |||
| inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { | |||
| @@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||
| (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); | |||
| (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); | |||
| (*target_func_graph)->set_is_generate(func_graph->is_generated()); | |||
| (*target_func_graph)->set_stub(func_graph->stub()); | |||
| TraceManager::EndTrace(); | |||
| } | |||
| @@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP | |||
| new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); | |||
| new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | |||
| new_func_graph->set_is_generate(func_graph->is_generated()); | |||
| new_func_graph->set_stub(func_graph->stub()); | |||
| for (auto &item : func_graph->parameter_default_value()) { | |||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | |||
| } | |||
| @@ -30,6 +30,7 @@ | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "operator/cc_implementations.h" | |||
| #include "optimizer/opt.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "./common.h" | |||
| @@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { | |||
| } | |||
| return item.second; | |||
| } | |||
| // Try best match | |||
| py::function py_fn_subclass; | |||
| size_t subclass_match_cnt = 0; | |||
| for (auto &item : fn_cache_py_) { | |||
| TypePtrList sign = item.first; | |||
| if (sign.size() != types.size()) { | |||
| continue; | |||
| return py::none(); | |||
| } | |||
| FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| if (!enable_sparse) { | |||
| return nullptr; | |||
| } | |||
| std::vector<AnfNodePtr> parameters; | |||
| ParameterPtr undetermined_param = nullptr; | |||
| auto stub = std::make_shared<FuncGraph>(); | |||
| for (size_t i = 0; i < types.size(); ++i) { | |||
| auto param = stub->add_parameter(); | |||
| parameters.push_back(param); | |||
| if (types[i]->type_id() == kObjectTypeUndeterminedType) { | |||
| undetermined_param = param; | |||
| } | |||
| auto match = true; | |||
| for (size_t i = 0; i < sign.size(); ++i) { | |||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) && | |||
| !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) { | |||
| match = false; | |||
| break; | |||
| } | |||
| if (undetermined_param != nullptr) { | |||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)}; | |||
| for (size_t i = 0; i < types.size(); ++i) { | |||
| if (types[i]->type_id() == kObjectTypeFunction) { | |||
| std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param}; | |||
| inputs.push_back(stub->NewCNode(call_prim)); | |||
| } else { | |||
| inputs.push_back(parameters[i]); | |||
| } | |||
| } | |||
| if (!match) { | |||
| continue; | |||
| } | |||
| py_fn_subclass = item.second; | |||
| subclass_match_cnt++; | |||
| } | |||
| if (subclass_match_cnt > 1) { | |||
| MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass"; | |||
| } | |||
| if (subclass_match_cnt == 1) { | |||
| MS_LOG(DEBUG) << "Found one subclass match"; | |||
| return py_fn_subclass; | |||
| auto stub_output = stub->NewCNode(inputs); | |||
| stub->set_output(stub_output); | |||
| stub->set_stub(true); | |||
| return stub; | |||
| } | |||
| return py::none(); | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||
| @@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||
| MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); | |||
| return func_graph; | |||
| } | |||
| auto stub = GenerateStubFunc(types); | |||
| if (stub != nullptr) { | |||
| MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString(); | |||
| return stub; | |||
| } | |||
| std::ostringstream oss; | |||
| oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | |||
| << "`, corresponding location info:\n"; | |||
| @@ -23,8 +23,8 @@ | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "pipeline/static_analysis/prim.h" | |||
| #include "pipeline/static_analysis/utils.h" | |||
| #include "utils/symbolic.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit | |||
| return AbstractFunction::MakeAbstractFunction(jv); | |||
| } | |||
| class UndeterminedShapeType { | |||
| public: | |||
| explicit UndeterminedShapeType(const std::string &env_str) { | |||
| // param_name indices_shape indices_type values_shape values_type dense_shape | |||
| // export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 | |||
| // 2:Float32:3 1 2" | |||
| std::vector<string> fields; | |||
| string tmp; | |||
| std::stringstream input(env_str); | |||
| while (std::getline(input, tmp, ':')) { | |||
| fields.push_back(tmp); | |||
| } | |||
| if (fields.size() != fields_num) { | |||
| MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); | |||
| } | |||
| param_name_ = fields[0]; | |||
| indices_shape_ = GetShape(fields[1]); | |||
| indices_type_ = StringToType(fields[2]); | |||
| values_shape_ = GetShape(fields[3]); | |||
| values_type_ = StringToType(fields[4]); | |||
| auto dense_shape_vec = GetShape(fields[5]); | |||
| AbstractBasePtrList dense_shape_list; | |||
| (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), | |||
| [](const auto &elem) { return FromValue(elem, false); }); | |||
| dense_shape_ = dense_shape_list; | |||
| } | |||
| ~UndeterminedShapeType() = default; | |||
| const std::string ¶m_name() { return param_name_; } | |||
| const std::vector<int> &indices_shape() { return indices_shape_; } | |||
| const TypePtr &indices_type() { return indices_type_; } | |||
| const std::vector<int> &values_shape() { return values_shape_; } | |||
| const TypePtr &values_type() { return values_type_; } | |||
| const AbstractBasePtrList &dense_shape() { return dense_shape_; } | |||
| private: | |||
| std::string param_name_; | |||
| std::vector<int> indices_shape_; | |||
| TypePtr indices_type_; | |||
| std::vector<int> values_shape_; | |||
| TypePtr values_type_; | |||
| AbstractBasePtrList dense_shape_; | |||
| static const size_t fields_num; | |||
| std::vector<int> GetShape(const std::string &shape_str); | |||
| }; | |||
| std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) { | |||
| std::vector<int> ret; | |||
| std::istringstream iss(shape_str); | |||
| int elem; | |||
| while (iss.good()) { | |||
| iss >> elem; | |||
| ret.emplace_back(elem); | |||
| } | |||
| return ret; | |||
| } | |||
| const size_t UndeterminedShapeType::fields_num = 6; | |||
| std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs; | |||
| void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { | |||
| std::string tmp; | |||
| std::stringstream input(sparse_shape_types); | |||
| g_undetermined_configs.clear(); | |||
| while (std::getline(input, tmp, ';')) { | |||
| auto config = UndeterminedShapeType(tmp); | |||
| g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); | |||
| MS_LOG(DEBUG) << "Undetermined config from env: " << tmp; | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| @@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt | |||
| MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); | |||
| } | |||
| if (!key->sparse_grad().empty()) { | |||
| // Will be fixed once undetermined type ready | |||
| if (g_undetermined_configs.empty()) { | |||
| auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); | |||
| MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types; | |||
| if (sparse_shape_types.empty()) { | |||
| sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; | |||
| } | |||
| InitUndeterminedFromEnv(sparse_shape_types); | |||
| } | |||
| auto shape_types = g_undetermined_configs.find(key->sparse_grad()); | |||
| if (shape_types == g_undetermined_configs.end()) { | |||
| MS_LOG(EXCEPTION) << "Param " << key->ToString() | |||
| << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES"; | |||
| } | |||
| MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); | |||
| AbstractBasePtrList sparse_list; | |||
| // indices | |||
| auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type()); | |||
| auto indices = | |||
| std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape())); | |||
| sparse_list.emplace_back(indices); | |||
| // values | |||
| auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type()); | |||
| auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape())); | |||
| sparse_list.emplace_back(dout); | |||
| // dense_shape | |||
| sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape())); | |||
| return std::make_shared<AbstractTuple>(sparse_list); | |||
| } | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||
| if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) { | |||
| bool enable_sparse = context->enable_sparse(); | |||
| if (enable_sparse && dflt->isa<AbstractTensor>()) { | |||
| auto dflt_tensor = dflt->cast<AbstractTensorPtr>(); | |||
| return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); | |||
| } | |||
| if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | |||
| return dflt; | |||
| } | |||
| @@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| if (type->type_id() != kObjectTypeRefKey) { | |||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | |||
| } | |||
| auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad()); | |||
| return ret; | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| } | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| @@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor { | |||
| } | |||
| auto fg = GetValueNode<FuncGraphPtr>(node); | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { | |||
| return nullptr; | |||
| } | |||
| @@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor { | |||
| // G | |||
| auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { | |||
| if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { | |||
| return nullptr; | |||
| } | |||
| // Do not inline GraphKernel to Cell. | |||
| @@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| std::string env = common::GetEnv("SLICE_ENV"); | |||
| if (!env.empty()) { | |||
| MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; | |||
| abstract::InitUndeterminedFromEnv(env); | |||
| } | |||
| } | |||
| @@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| ValuePtr value = param_value->value(); | |||
| constexpr bool broaden = true; | |||
| AbstractBasePtr ptr = abstract::FromValue(value, broaden); | |||
| ptr->set_sparse_grad(param_value->sparse_grad()); | |||
| ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad()); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||
| args_spec.push_back(ptr); | |||
| @@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | |||
| "Set the GraphKernel switch to on or off.") | |||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") | |||
| .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.") | |||
| .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse."); | |||
| .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") | |||
| .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); | |||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | |||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | |||
| @@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup}, | |||
| {"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_b", OptPassBGroup}, | |||
| {"cconv", CconvPass}, | |||
| {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | |||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | |||
| {"add_control_depend", AddControlDependPass}}; | |||
| std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_a", OptPassAGroup}, | |||
| {"opt_b", OptPassBGroup}, | |||
| {"add_control_depend", AddControlDependPass}, | |||
| {"opt_control", ControlGroup}, | |||
| {"opt_prepare", PrepareGroup}, | |||
| {"cconv", CconvPass}}; | |||
| std::vector<PassItem> kGePasses = { | |||
| {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, | |||
| {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, | |||
| {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, | |||
| {"cconv", CconvPass}}; | |||
| std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; | |||
| } // namespace pipeline | |||
| @@ -146,37 +146,35 @@ MethodMap &GetMethodMap() { | |||
| }}, | |||
| {kObjectTypeTensorType, | |||
| { | |||
| {"__add__", std::string("add")}, // C.add | |||
| {"__sub__", std::string("sub")}, // C.sub | |||
| {"__mul__", std::string("mul")}, // C.mul | |||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||
| {"__mod__", std::string("mod")}, // C.mod | |||
| {"__pow__", std::string("pow_")}, // C.pow | |||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||
| {"__eq__", std::string("eq")}, // C.eq | |||
| {"__ne__", std::string("ne")}, // C.ne | |||
| {"__lt__", std::string("lt")}, // C.lt | |||
| {"__gt__", std::string("gt")}, // C.gt | |||
| {"__le__", std::string("le")}, // C.le | |||
| {"__ge__", std::string("ge")}, // C.ge | |||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||
| {"transpose", std::string("transpose")}, // P.transpose | |||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||
| {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices | |||
| {"__add__", std::string("add")}, // C.add | |||
| {"__sub__", std::string("sub")}, // C.sub | |||
| {"__mul__", std::string("mul")}, // C.mul | |||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||
| {"__mod__", std::string("mod")}, // C.mod | |||
| {"__pow__", std::string("pow_")}, // C.pow | |||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||
| {"__eq__", std::string("eq")}, // C.eq | |||
| {"__ne__", std::string("ne")}, // C.ne | |||
| {"__lt__", std::string("lt")}, // C.lt | |||
| {"__gt__", std::string("gt")}, // C.gt | |||
| {"__le__", std::string("le")}, // C.le | |||
| {"__ge__", std::string("ge")}, // C.ge | |||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||
| {"transpose", std::string("transpose")}, // P.transpose | |||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||
| }}, | |||
| {kObjectTypeIndexedSlicesType, | |||
| { | |||
| {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices | |||
| {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values | |||
| {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices | |||
| {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape | |||
| @@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const { | |||
| AbstractBasePtr AbstractBase::Broaden() const { | |||
| AbstractBasePtr clone = Clone(); | |||
| clone->set_value(kAnyValue); | |||
| clone->set_sparse_grad(sparse_grad_); | |||
| return clone; | |||
| } | |||
| @@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(type_); | |||
| MS_EXCEPTION_IF_NULL(shape_); | |||
| buffer << type_name() << "(" | |||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() | |||
| << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")"; | |||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; | |||
| return buffer.str(); | |||
| } | |||
| @@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden() | |||
| AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||
| MS_EXCEPTION_IF_NULL(other); | |||
| if (*this == *other) { | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return ret; | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| auto value_self = GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_self); | |||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||
| TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); | |||
| if (res_value == value_self) { | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return ret; | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| auto ret = std::make_shared<AbstractScalar>(res_value, res_type); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return ret; | |||
| return std::make_shared<AbstractScalar>(res_value, res_type); | |||
| } | |||
| AbstractBasePtr AbstractType::Clone() const { | |||
| @@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||
| MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | |||
| } | |||
| if (*this == *other) { | |||
| if (sparse_grad() == other->sparse_grad()) { | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| auto element = element_->Join(other_tensor->element_); | |||
| auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | |||
| auto ret = std::make_shared<AbstractTensor>(element, shape); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return ret; | |||
| return std::make_shared<AbstractTensor>(element, shape); | |||
| } | |||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||
| @@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const { | |||
| ShapePtr shp = shape(); | |||
| clone->set_shape(shp->Clone()); | |||
| clone->set_value(GetValueTrack()); | |||
| clone->set_sparse_grad(sparse_grad()); | |||
| clone->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return clone; | |||
| } | |||
| @@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const { | |||
| auto shp = shape(); | |||
| broaden->set_shape(shp->Clone()); | |||
| broaden->set_value(kAnyValue); | |||
| broaden->set_sparse_grad(sparse_grad()); | |||
| broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return broaden; | |||
| } | |||
| @@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { | |||
| shp->Broaden(); | |||
| broaden->set_shape(shp); | |||
| broaden->set_value(kAnyValue); | |||
| broaden->set_sparse_grad(sparse_grad()); | |||
| broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||
| return broaden; | |||
| } | |||
| @@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(value_track); | |||
| buffer << type_name() << "(" | |||
| << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | |||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() | |||
| << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")"; | |||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; | |||
| return buffer.str(); | |||
| } | |||
| @@ -44,7 +44,7 @@ class AbstractBase : public Base { | |||
| public: | |||
| explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | |||
| const BaseShapePtr &shape = kNoShape) | |||
| : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {} | |||
| : value_(value), type_(type), shape_(shape) {} | |||
| ~AbstractBase() override = default; | |||
| MS_DECLARE_PARENT(AbstractBase, Base) | |||
| @@ -53,17 +53,11 @@ class AbstractBase : public Base { | |||
| virtual bool operator==(const AbstractBase &other) const; | |||
| void set_value(const ValuePtr &value) { value_ = value; } | |||
| void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } | |||
| void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) { | |||
| has_indexed_slices_grad_ = has_indexed_slices_grad; | |||
| } | |||
| void set_type(const TypePtr &type) { type_ = type; } | |||
| void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | |||
| void set_value_desc(const std::string &desc) { value_desc_ = desc; } | |||
| const std::string &value_desc() const { return value_desc_; } | |||
| ValuePtr GetValueTrack() const { return value_; } | |||
| const std::string &sparse_grad() const { return sparse_grad_; } | |||
| const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; } | |||
| TypePtr GetTypeTrack() const { return type_; } | |||
| BaseShapePtr GetShapeTrack() const { return shape_; } | |||
| @@ -91,8 +85,6 @@ class AbstractBase : public Base { | |||
| TypePtr type_; | |||
| BaseShapePtr shape_; | |||
| std::string value_desc_; // store initial value description for error report | |||
| std::string sparse_grad_; | |||
| bool has_indexed_slices_grad_; | |||
| }; | |||
| class AbstractScalar : public AbstractBase { | |||
| @@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| } | |||
| MS_EXCEPTION_IF_NULL(ret_base); | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() | |||
| << ", is stub: " << fg->stub(); | |||
| if (fg->stub()) { | |||
| return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr); | |||
| } | |||
| return std::make_shared<EvalResult>(ret_base, nullptr); | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include <vector> | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "utils/context/ms_context.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -59,6 +60,13 @@ class Evaluator : public Base { | |||
| } | |||
| virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| if (!enable_sparse) { | |||
| return nullptr; | |||
| } | |||
| auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { | |||
| if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||
| return true; | |||
| @@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| using mindspore::parse::PyObjectWrapper; | |||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||
| if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { | |||
| if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; | |||
| @@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c | |||
| EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| AnfNodeConfigPtr out_conf) { | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| auto ret_abstract = AbstractEval(args_spec_list); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| @@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt | |||
| } | |||
| AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| ScopePtr scope = kDefaultScope; | |||
| if (out_conf != nullptr) { | |||
| scope = out_conf->node()->scope(); | |||
| @@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| } // end anonymous namespace | |||
| EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||
| if (enable_sparse_flag) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | |||
| @@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||
| } | |||
| EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||
| if (enable_sparse_flag) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | |||
| if (nargs_ != args.size()) { | |||
| @@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| auto ret = std::make_shared<AbstractScalar>(type); | |||
| auto ref_value = ref_abs->ref(); | |||
| MS_EXCEPTION_IF_NULL(ref_value); | |||
| ret->set_sparse_grad(ref_value->sparse_grad()); | |||
| ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad()); | |||
| return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||
| } | |||
| @@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| x = SensitivityTransform(x); | |||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | |||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | |||
| abs_scalar->set_sparse_grad(x->sparse_grad()); | |||
| abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad()); | |||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||
| } | |||
| }; | |||
| @@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { | |||
| MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | |||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | |||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||
| if (enable_sparse_flag) { | |||
| auto ret_abstract = AbstractEval(args_spec_list); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| auto ret_abstract = AbstractEval(args_spec_list); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| // Inputs: data, item | |||
| if (args_spec_list.size() != 2) { | |||
| @@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv | |||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| void InitUndeterminedFromEnv(const std::string &sparse_shape_types); | |||
| AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| @@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co | |||
| AbstractFunctionPtr func = real_a->GetUnique(); | |||
| SpecializeStatusCode errcode; | |||
| ScopeGuard scope_guard(node->scope()); | |||
| AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode); | |||
| AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode); | |||
| if (repl == nullptr) { | |||
| if (errcode == kSpecializeFindUniqueArgvalDead) { | |||
| const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node); | |||
| @@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co | |||
| return repl; | |||
| } | |||
| AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func, | |||
| AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, | |||
| const AbstractFunctionPtr &func, | |||
| const AbstractBasePtrList &args, | |||
| SpecializeStatusCode *errcode) { | |||
| MS_EXCEPTION_IF_NULL(abs); | |||
| @@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr | |||
| AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); | |||
| MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() | |||
| << ", graph: " << context->func_graph()->get_return()->DebugString(); | |||
| if (context->func_graph()->stub()) { | |||
| MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString() | |||
| << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString() | |||
| << ", " << node->ToString(); | |||
| return node; | |||
| } | |||
| FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); | |||
| v->set_flag(kFuncGraphFlagUndetermined, false); | |||
| return BuildValueNode(v, abs); | |||
| } | |||
| @@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct | |||
| *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); | |||
| return kSpecializeSuccess; | |||
| } else if (choices->empty()) { | |||
| MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; | |||
| MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | " | |||
| << func->type_name(); | |||
| return kSpecializeFindUniqueArgvalDead; | |||
| } else { | |||
| if (IsPolyFunc(func, argvals)) { | |||
| @@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia | |||
| // Build a specialized node from given argvals; | |||
| AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, | |||
| const AbstractBasePtrList &argvals); | |||
| AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func, | |||
| const AbstractBasePtrList &args, SpecializeStatusCode *errcode); | |||
| AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, | |||
| const AbstractFunctionPtr &func, const AbstractBasePtrList &args, | |||
| SpecializeStatusCode *errcode); | |||
| // Find the unique argument values which can be used to specialize a primitive or graph function. | |||
| SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval, | |||
| @@ -89,7 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||
| max_device_memory_ = kDefaultMaxDeviceMemory; | |||
| print_file_path_ = ""; | |||
| enable_graph_kernel_ = false; | |||
| enable_sparse_flag_ = false; | |||
| enable_sparse_ = false; | |||
| } | |||
| std::shared_ptr<MsContext> MsContext::GetInstance() { | |||
| @@ -161,8 +161,8 @@ class MsContext { | |||
| void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } | |||
| bool enable_graph_kernel() const { return enable_graph_kernel_; } | |||
| bool enable_sparse_flag() const { return enable_sparse_flag_; } | |||
| void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; } | |||
| bool enable_sparse() const { return enable_sparse_; } | |||
| void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } | |||
| private: | |||
| MsContext(const std::string &backend_policy, const std::string &target); | |||
| @@ -207,7 +207,7 @@ class MsContext { | |||
| float max_device_memory_; | |||
| std::string print_file_path_; | |||
| bool enable_graph_kernel_; | |||
| bool enable_sparse_flag_; | |||
| bool enable_sparse_; | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -51,18 +51,13 @@ class Parameter: | |||
| requires_grad (bool): True if the parameter requires gradient. Default: True. | |||
| layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, | |||
| broadcast and gradients communication would not be applied on parameters. Default: False. | |||
| sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty. | |||
| has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false. | |||
| """ | |||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, | |||
| sparse_grad="", has_indexed_slices_grad=False): | |||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | |||
| self._value = ParamValue() | |||
| self.set_parameter_data(default_input) | |||
| self.name = name | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| self.sparse_grad = sparse_grad | |||
| self.has_indexed_slices_grad = has_indexed_slices_grad | |||
| self._is_init = False | |||
| self._sliced = False | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
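With `sparse_grad` and `has_indexed_slices_grad` gone from the constructor, a parameter that ends up with a sparse gradient is declared like any other one. A minimal sketch of the slimmed-down call, mirroring the test updates later in this diff (the shape and name are illustrative only):

```python
import numpy as np
from mindspore import Tensor, Parameter

# No sparse-specific keyword arguments any more; sparsity is now driven by
# context.set_context(enable_sparse=True) and by the operators used in the net.
weight = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
                   name="weight1", requires_grad=True)
```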
| @@ -177,28 +172,6 @@ class Parameter: | |||
| raise TypeError("`requires_grad` parameter must be bool type") | |||
| self._value.requires_grad = value | |||
| @property | |||
| def sparse_grad(self): | |||
| """Return whether the parameter's gradient is sparse.""" | |||
| return self._value.sparse_grad | |||
| @sparse_grad.setter | |||
| def sparse_grad(self, value=""): | |||
| if not isinstance(value, str): | |||
| raise TypeError("`sparse_grad` parameter must be str type") | |||
| self._value.sparse_grad = value | |||
| @property | |||
| def has_indexed_slices_grad(self): | |||
| """Return whether the parameter's gradient is indexed_slices.""" | |||
| return self._value.has_indexed_slices_grad | |||
| @has_indexed_slices_grad.setter | |||
| def has_indexed_slices_grad(self, value=False): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`has_indexed_slices_grad` parameter must be bool type") | |||
| self._value.has_indexed_slices_grad = value | |||
| @property | |||
| def data(self): | |||
| return self.default_input | |||
| @@ -367,14 +367,6 @@ class _Context: | |||
| def check_bprop(self, check_bprop_flag): | |||
| self._context_handle.set_check_bprop_flag(check_bprop_flag) | |||
| @property | |||
| def enable_sparse(self): | |||
| return self._context_handle.get_enable_sparse_flag() | |||
| @enable_sparse.setter | |||
| def enable_sparse(self, enable_sparse_flag): | |||
| self._context_handle.set_enable_sparse_flag(enable_sparse_flag) | |||
| @property | |||
| def max_device_memory(self): | |||
| return self._context_handle.get_max_device_memory() | |||
| @@ -408,6 +400,13 @@ class _Context: | |||
| full_file_name = print_file_path | |||
| self._context_handle.set_print_file_path(full_file_name) | |||
| @property | |||
| def enable_sparse(self): | |||
| return self._context_handle.get_enable_sparse() | |||
| @enable_sparse.setter | |||
| def enable_sparse(self, enable_sparse): | |||
| self._context_handle.set_enable_sparse(enable_sparse) | |||
| def check_input_format(x): | |||
| import re | |||
| @@ -601,7 +600,7 @@ def set_context(**kwargs): | |||
| print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | |||
| a file by default, and turn off printing to the screen. If the file already exists, add a timestamp | |||
| suffix to the file. | |||
| enable_sparse (bool): Whether to enable sparse feature. Default: False. | |||
| enable_sparse (bool): Whether to enable the sparsity feature. Default: False. | |||
| Raises: | |||
| ValueError: If input key is not an attribute in context. | |||
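For reference, the renamed switch is toggled from Python exactly as the updated tests below do. A minimal usage sketch (the GRAPH_MODE setting is shown only for completeness):

```python
from mindspore import context

# Turn the sparsity feature on before the network is built and compiled.
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
```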
| @@ -162,8 +162,8 @@ class Adam(Optimizer): | |||
| To improve parameter groups performance, the customized order of parameters can be supported. | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the | |||
| `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse | |||
| The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network. | |||
| The sparse feature is under continuous development. The sparse | |||
| behavior is currently performed on the CPU. | |||
| Args: | |||
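The FTRL, LazyAdam and ProximalAdagrad notes below describe the same usage pattern. A minimal sketch of that pattern, assembled from the test cases later in this diff (layer shape and learning rate are illustrative only):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class NetWithSparseGatherV2(nn.Cell):
    """Forward network whose gradient w.r.t. weight1 is sparse."""
    def __init__(self):
        super(NetWithSparseGatherV2, self).__init__()
        self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
        self.axis = 0
        self.gather = P.SparseGatherV2()

    def construct(self, indices):
        return self.gather(self.weight1, indices, self.axis)

net = NetWithSparseGatherV2()
# The sparse update path of the optimizer currently runs on the CPU.
optimizer = Adam(net.trainable_params(), learning_rate=0.1)
```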
| @@ -72,8 +72,8 @@ class FTRL(Optimizer): | |||
| <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document. | |||
| Note: | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the | |||
| `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse | |||
| The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network. | |||
| The sparse feature is under continuous development. The sparse | |||
| behavior is currently performed on the CPU. | |||
| Args: | |||
| @@ -91,8 +91,8 @@ class LazyAdam(Optimizer): | |||
| value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be | |||
| applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the | |||
| `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the | |||
| The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network. | |||
| The sparse behavior, it should be noted, is not equivalent to the | |||
| original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under | |||
| continuous development. The sparse behavior is currently performed on the CPU. | |||
| @@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer): | |||
| <http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_. | |||
| Note: | |||
| The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the | |||
| `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse | |||
| The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network. | |||
| The sparse feature is under continuous development. The sparse | |||
| behavior is currently performed on the CPU. | |||
| Args: | |||
| @@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices') | |||
| indexed_slices_get_values = Primitive('IndexedSlicesGetValues') | |||
| indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') | |||
| indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') | |||
| is_indexed_slices = Primitive('IsIndexedSlices') | |||
| tensor_operator_registry.register('__add__', tensor_add) | |||
| @@ -36,6 +36,8 @@ from mindspore._checkparam import Rel | |||
| from mindspore.nn import Optimizer | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) | |||
| reduce_sum = P.ReduceSum() | |||
| unsorted_segment_sum = P.UnsortedSegmentSum() | |||
| transpose = P.Transpose() | |||
| @@ -44,7 +46,6 @@ reshape = P.Reshape() | |||
| size_op = P.Size() | |||
| invert_permutation = P.InvertPermutation() | |||
| logical_and = P.LogicalAnd() | |||
| context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) | |||
| @constexpr | |||
| def _generate_shape_index(out_shape, indices_shape, axis): | |||
| @@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self): | |||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | |||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Undetermined", "Bool") | |||
| def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||
| if gradient.is_indexed_slices(): | |||
| return gradient.values() | |||
| "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") | |||
| def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, | |||
| m, v, gradient, decay_flag): | |||
| return gradient.values() | |||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, | |||
| m, v, gradient, decay_flag): | |||
| op_mul = P.Mul() | |||
| op_square = P.Square() | |||
| op_sqrt = P.Sqrt() | |||
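The split above relies on MultitypeFuncGraph choosing an overload from the registered type signature, so the `IndexedSlices` branch is taken only when the incoming gradient really is an IndexedSlices. A small standalone sketch of that dispatch pattern (the graph name and function bodies are illustrative, not part of this change):

```python
from mindspore.ops import composite as C

# Overloads of a MultitypeFuncGraph are selected by the registered type strings.
grad_values = C.MultitypeFuncGraph("grad_values")

@grad_values.register("Tensor")
def _grad_values_tensor(gradient):
    # Dense gradient: pass it through unchanged.
    return gradient

@grad_values.register("IndexedSlices")
def _grad_values_indexed_slices(gradient):
    # Sparse gradient: only the values of the IndexedSlices are used.
    return gradient.values()
```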
| @@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices(): | |||
| self.dense_shape = (3, 4) | |||
| def construct(self, indices, values): | |||
| ret = (IndexedSlices(indices, values, self.dense_shape),) | |||
| return ret[0].is_indexed_slices() | |||
| return ret[0] | |||
| indices = Tensor([[0, 0], [1, 2]]) | |||
| values = Tensor([1, 2], dtype=ms.float32) | |||
| MakeIndexedSlices()(indices, values) | |||
| @@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| grad = grad_all(self.network)(x, y) | |||
| return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices() | |||
| return grad, grad[0], grad[1] | |||
| class SparseGatherV2(nn.Cell): | |||
| def __init__(self): | |||
| super(SparseGatherV2, self).__init__() | |||
| @@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||
| weights = self.weights | |||
| grad = grad_by_list(self.network, weights)(x) | |||
| x = grad[0] | |||
| return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape() | |||
| return x, x.values(), x.indices(), x.dense_shape() | |||
| class SparseGatherV2(nn.Cell): | |||
| def __init__(self): | |||
| super(SparseGatherV2, self).__init__() | |||
| self.sparse_gatherv2 = MySparseGatherV2() | |||
| self.axis = 0 | |||
| self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), | |||
| name="params", has_indexed_slices_grad=True) | |||
| self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") | |||
| def construct(self, indices): | |||
| return self.sparse_gatherv2(self.params, indices, self.axis) | |||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| @@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||
| network(indices) | |||
| def test_indexed_slices_is_indexed_slices(): | |||
| class MakeIndexedSlices(nn.Cell): | |||
| def __init__(self): | |||
| super(MakeIndexedSlices, self).__init__() | |||
| self.dense_shape = (3, 4) | |||
| def construct(self, indices, values): | |||
| indexed_slices = IndexedSlices(indices, values, self.dense_shape) | |||
| ret = indexed_slices.is_indexed_slices() | |||
| return ret | |||
| indices = Tensor([[0, 0], [1, 2]]) | |||
| values = Tensor([1, 2], dtype=ms.float32) | |||
| MakeIndexedSlices()(indices, values) | |||
| def test_indexed_slices_env_get(): | |||
| class Loss(nn.Cell): | |||
| def __init__(self): | |||
| @@ -271,7 +262,7 @@ def test_indexed_slices_env_get(): | |||
| class NetWithSparseGatherV2(nn.Cell): | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True) | |||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") | |||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||
| self.gatherv2 = MySparseGatherV2() | |||
| self.axis = 0 | |||
| @@ -17,12 +17,13 @@ import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR | |||
| from mindspore.ops import operations as P | |||
| context.set_context(enable_sparse=True) | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| @@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell): | |||
| """ NetWithSparseGatherV2 definition """ | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), | |||
| name="weight1", sparse_grad="sparse_key_w1") | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") | |||
| self.axis = 0 | |||
| self.gather = P.SparseGatherV2() | |||
| @@ -27,6 +27,7 @@ from mindspore.ops import functional as F | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| context.set_context(enable_sparse=True) | |||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | |||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| @@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse(): | |||
| class NetWithSparseGatherV2(nn.Cell): | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") | |||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") | |||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||
| self.gatherv2 = P.SparseGatherV2() | |||
| self.axis = 0 | |||
| @@ -17,12 +17,13 @@ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import FTRL | |||
| from mindspore.ops import operations as P | |||
| context.set_context(enable_sparse=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| @@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell): | |||
| """ NetWithSparseGatherV2 definition """ | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), | |||
| name="weight1", sparse_grad="sparse_key_w1") | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") | |||
| self.axis = 0 | |||
| self.gather = P.SparseGatherV2() | |||
| @@ -17,12 +17,13 @@ import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import LazyAdam | |||
| from mindspore.ops import operations as P | |||
| context.set_context(enable_sparse=True) | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| @@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell): | |||
| """ NetWithSparseGatherV2 definition """ | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), | |||
| name="weight1", sparse_grad="sparse_key_w1") | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") | |||
| self.axis = 0 | |||
| self.gather = P.SparseGatherV2() | |||
| @@ -17,12 +17,13 @@ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import ProximalAdagrad | |||
| from mindspore.ops import operations as P | |||
| context.set_context(enable_sparse=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| @@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell): | |||
| """ NetWithSparseGatherV2 definition """ | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", | |||
| sparse_grad="sparse_key_w1") | |||
| self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2") | |||
| self.axis = 0 | |||
| self.gather = P.SparseGatherV2() | |||
| @@ -53,4 +53,4 @@ def test_hypermap_specialize_param(): | |||
| expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) | |||
| ret = hypermap_specialize_param() | |||
| assert ret == (expected_ret, expected_ret) | |||
| assert ret == (expected_ret, list(expected_ret)) | |||