From 9927e6eb5cbcea9ed4cd752206927f7c575795e9 Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Fri, 31 Jul 2020 11:47:31 +0800
Subject: [PATCH] eager mode sparse

---
 .../ccsrc/frontend/operator/prim_statement.cc | 12 +++-
 .../pipeline/jit/static_analysis/prim.cc      | 27 ++++++--
 mindspore/ccsrc/utils/convert_utils.cc        |  8 +++
 mindspore/common/dtype.py                     |  1 +
 mindspore/common/tensor.py                    | 28 +++++++-
 mindspore/ops/operations/math_ops.py          |  6 +-
 tests/ut/python/ops/test_control_ops.py       |  4 +-
 .../pynative_mode/test_sparse_pynative.py     | 65 +++++++++++++++++++
 8 files changed, 138 insertions(+), 13 deletions(-)
 create mode 100644 tests/ut/python/pynative_mode/test_sparse_pynative.py

diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/ccsrc/frontend/operator/prim_statement.cc
index e193ff1dab..6a7f54007b 100644
--- a/mindspore/ccsrc/frontend/operator/prim_statement.cc
+++ b/mindspore/ccsrc/frontend/operator/prim_statement.cc
@@ -132,7 +132,17 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   // Inputs: index, branch
   const std::string op_name = primitive->name();
   abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto &input_shape = index->shape()->shape();
+  if (input_shape.size() != 0) {
+    MS_EXCEPTION(ValueError) << op_name << " index must be a 0-dimensional tensor, but got a " << input_shape.size()
+                             << "-dimensional tensor";
+  }
+  auto dtype = index->element()->BuildType();
+  if (dtype->type_id() != kInt32->type_id()) {
+    MS_EXCEPTION(ValueError) << op_name << " index must be an int32, but got " << dtype->ToString();
+  }
+
   AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
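The prim_statement.cc hunk above makes switch_layer validate its index at
infer time: anything other than a 0-dimensional int32 tensor is now rejected
with a ValueError. A minimal sketch of what the new check accepts and
rejects, using only the public Tensor API (values are arbitrary):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    index_ok = Tensor(0, dtype=ms.int32)          # 0-dim int32: accepted
    index_bad_rank = Tensor(np.array([0], dtype=np.int32))  # 1-dim: "index must be a 0-dimensional tensor"
    index_bad_type = Tensor(0, dtype=ms.float32)  # wrong dtype: "index must be an int32"

This is also why the two tests in test_control_ops.py further down switch
from Tensor(0) to Tensor(0, dtype=mstype.int32).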
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index fa8c36e460..4f0840b4fc 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -145,9 +145,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 
 using mindspore::parse::PyObjectWrapper;
 
+std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
+                                                                 "env_getitem"};
+
 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
-  if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch && prim_ != prim::kPrimEnvSetItem &&
-      prim_ != prim::kPrimEnvGetItem) {
+  if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
     auto ret_abstract = AbstractEval(args);
     if (ret_abstract != nullptr) {
       MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
@@ -167,17 +169,23 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   AbstractBasePtrList args_spec_list;
   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                        [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
-  auto ret_abstract = AbstractEval(args_spec_list);
-  if (ret_abstract != nullptr) {
-    MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
-    return ret_abstract;
+  auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
+  auto &func = do_signature->function();
+  if (func->isa<Primitive>()) {
+    auto sig_prim = func->cast<PrimitivePtr>();
+    if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
+      auto ret_abstract = AbstractEval(args_spec_list);
+      if (ret_abstract != nullptr) {
+        MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
+        return ret_abstract;
+      }
+    }
   }
 
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
-  auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim_);
   auto out_node = dyn_cast<CNode>(out_conf->node());
   const auto &out_node_inputs = out_node->inputs();
   if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
@@ -447,6 +455,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = py::none();
     dic["dtype"] = abs_base->BuildType();
     dic["value"] = py::none();
+  } else if (abs_base->isa<AbstractUndetermined>()) {
+    auto arg = dyn_cast<AbstractUndetermined>(abs_base);
+    dic["shape"] = py::none();
+    dic["dtype"] = arg->BuildType();
+    dic["value"] = py::none();
   } else {
     auto value = abs_base->BuildValue();
     if ((*value == *kAnyValue)) {
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index b427831738..f5a6738c8d 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -32,6 +32,7 @@
 #include "ir/tensor.h"
 #include "ir/param_value.h"
 #include "utils/base_ref_extends.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 py::object BuiltinsToPyData(const Any &value);
@@ -404,6 +405,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
     auto abstract_none = std::make_shared<abstract::AbstractNone>();
     return abstract_none;
   } else {
+    // When sparse is enabled, Undetermined may be raised here and then eliminated by later opt passes
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse = context->enable_sparse();
+    if (enable_sparse) {
+      return std::make_shared<abstract::AbstractUndetermined>();
+    }
     MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj);
   }
 }
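The two C++ hunks above cooperate through the enable_sparse context flag:
prim.cc decides which primitives skip the Undetermined fast path, and
convert_utils.cc wraps an otherwise-fatal infer result as
AbstractUndetermined so later optimization passes can eliminate it. On the
Python side this is gated by the same switch the new unit test flips:

    from mindspore import context

    # Without enable_sparse=True, the convert_utils.cc fallback still raises
    # "Python evaluator return invalid shape or type."
    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)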
" << (std::string)py::str(type_obj); } } diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index e5c8933fe2..61c7f3096c 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -101,6 +101,7 @@ list_type = typing.List tuple_type = typing.Tuple index_slices = typing.IndexedSlicesType() sparse_tensor = typing.SparseTensorType() +undetermined = typing.UndeterminedType() number_type = (int8, int16, diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index fde5ad4d12..106d0eccad 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -290,7 +290,19 @@ class IndexedSlices: """ def __init__(self, indices, values, dense_shape): - raise NotImplementedError + "Init IndexedSlices" + self.__indices = indices + self.__values = values + self.__dense_shape = dense_shape + + def indices(self): + return self.__indices + + def values(self): + return self.__values + + def dense_shape(self): + return self.__dense_shape class SparseTensor: @@ -331,4 +343,16 @@ class SparseTensor: """ def __init__(self, indices, values, dense_shape): - raise NotImplementedError + "Init SparseTensor" + self.__indices = indices + self.__values = values + self.__dense_shape = dense_shape + + def indices(self): + return self.__indices + + def values(self): + return self.__values + + def dense_shape(self): + return self.__dense_shape diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index a704a573dc..9bfa078560 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -814,9 +814,13 @@ class AddN(PrimitiveWithInfer): validator.check_value_type("inputs", inputs, [tuple, list], cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) args = {} + contains_undetermined = False for i, dtype in enumerate(inputs): args[f"inputs[{i}]"] = dtype - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) + if dtype == mstype.undetermined: + contains_undetermined = True + if not contains_undetermined: + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) return inputs[0] def infer_value(self, inputs): diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 13b0aa9ce3..753c4856a3 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -398,7 +398,7 @@ def test_switch_layer(): ret = F.switch_layer(index, self.layers)(x) * self.z3 return ret - index = Tensor(0) + index = Tensor(0, dtype=mstype.int32) net = SwitchLayerCell() net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, @@ -436,7 +436,7 @@ def test_index_to_switch_layer(): ret = self.layers[index](x) * self.z3 return ret - index = Tensor(0) + index = Tensor(0, dtype=mstype.int32) net = SwitchLayerCell() net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, diff --git a/tests/ut/python/pynative_mode/test_sparse_pynative.py b/tests/ut/python/pynative_mode/test_sparse_pynative.py new file mode 100644 index 0000000000..17c908e07f --- /dev/null +++ b/tests/ut/python/pynative_mode/test_sparse_pynative.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/tests/ut/python/pynative_mode/test_sparse_pynative.py b/tests/ut/python/pynative_mode/test_sparse_pynative.py
new file mode 100644
index 0000000000..17c908e07f
--- /dev/null
+++ b/tests/ut/python/pynative_mode/test_sparse_pynative.py
@@ -0,0 +1,65 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+@File  : test_sparse_pynative.py
+@Author:
+@Date  : 2020-08-04
+@Desc  : test mindspore sparse pynative
+"""
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context, Tensor, IndexedSlices, SparseTensor
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
+
+
+grad_all = C.GradOperation('get_all', get_all=True)
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+    def construct(self, *args):
+        grad = grad_all(self.network)(*args)
+        return grad
+
+
+def test_indexed_slices_attr():
+    class IndexedSlicesGetAttr(nn.Cell):
+        def __init__(self, dense_shape):
+            super(IndexedSlicesGetAttr, self).__init__()
+            self.dense_shape = dense_shape
+        def construct(self, indices, values):
+            x = IndexedSlices(indices, values, self.dense_shape)
+            return x.values(), x.indices(), x.dense_shape()
+    indices = Tensor([0])
+    values = Tensor([[1, 2]], dtype=ms.float32)
+    IndexedSlicesGetAttr((3, 2))(indices, values)
+    GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values)
+
+
+def test_sparse_tensor_attr():
+    class SparseTensorGetAttr(nn.Cell):
+        def __init__(self):
+            super(SparseTensorGetAttr, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            x = SparseTensor(indices, values, self.dense_shape)
+            return x.values(), x.indices(), x.dense_shape()
+
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    SparseTensorGetAttr()(indices, values)
+    GradWrap(SparseTensorGetAttr())(indices, values)
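Taken together, these pieces let PyNative-mode code build sparse values
inside a Cell and differentiate through them, which is what the new test
file exercises. A condensed variant of those tests, using the same APIs
with arbitrary data:

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import context, Tensor, SparseTensor
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)

    class MakeSparse(nn.Cell):
        def construct(self, indices, values):
            x = SparseTensor(indices, values, (3, 4))
            return x.values(), x.indices(), x.dense_shape()

    indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
    values = Tensor([1.0, 2.0], dtype=ms.float32)
    out = MakeSparse()(indices, values)
    grads = C.GradOperation('get_all', get_all=True)(MakeSparse())(indices, values)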