| @@ -132,7 +132,17 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP | |||
| // Inputs: index, branch | |||
| const std::string op_name = primitive->name(); | |||
| abstract::CheckArgsSize(op_name, args_spec_list, 2); | |||
| (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto &input_shape = index->shape()->shape(); | |||
| if (input_shape.size() != 0) { | |||
| MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size() | |||
| << " dimension tensor"; | |||
| } | |||
| auto dtype = index->element()->BuildType(); | |||
| if (dtype->type_id() != kInt32->type_id()) { | |||
| MS_EXCEPTION(ValueError) << op_name << " index must be a int32, but got " << dtype->ToString(); | |||
| } | |||
| AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| AbstractBasePtrList branches = branches_abs->elements(); | |||
| const size_t maximum_layer_num = 1000; | |||
| @@ -145,9 +145,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| using mindspore::parse::PyObjectWrapper; | |||
| std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem", | |||
| "env_getitem"}; | |||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | |||
| if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch && prim_ != prim::kPrimEnvSetItem && | |||
| prim_ != prim::kPrimEnvGetItem) { | |||
| if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { | |||
| auto ret_abstract = AbstractEval(args); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; | |||
| @@ -167,17 +169,23 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); | |||
| auto ret_abstract = AbstractEval(args_spec_list); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>(); | |||
| auto &func = do_signature->function(); | |||
| if (func->isa<Primitive>()) { | |||
| auto sig_prim = func->cast<PrimitivePtr>(); | |||
| if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) { | |||
| auto ret_abstract = AbstractEval(args_spec_list); | |||
| if (ret_abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined"; | |||
| return ret_abstract; | |||
| } | |||
| } | |||
| } | |||
| if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; | |||
| } | |||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim_); | |||
| auto out_node = dyn_cast<CNode>(out_conf->node()); | |||
| const auto &out_node_inputs = out_node->inputs(); | |||
| if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { | |||
| @@ -447,6 +455,11 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| dic["shape"] = py::none(); | |||
| dic["dtype"] = abs_base->BuildType(); | |||
| dic["value"] = py::none(); | |||
| } else if (abs_base->isa<AbstractUndetermined>()) { | |||
| auto arg = dyn_cast<AbstractUndetermined>(abs_base); | |||
| dic["shape"] = py::none(); | |||
| dic["dtype"] = arg->BuildType(); | |||
| dic["value"] = py::none(); | |||
| } else { | |||
| auto value = abs_base->BuildValue(); | |||
| if ((*value == *kAnyValue)) { | |||
| @@ -32,6 +32,7 @@ | |||
| #include "ir/tensor.h" | |||
| #include "ir/param_value.h" | |||
| #include "utils/base_ref_extends.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| py::object BuiltinsToPyData(const Any &value); | |||
| @@ -404,6 +405,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py | |||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | |||
| return abstract_none; | |||
| } else { | |||
| // When sparse enabled, the undetermined might be raised and eliminated in opt passes | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool enable_sparse = context->enable_sparse(); | |||
| if (enable_sparse) { | |||
| return std::make_shared<abstract::AbstractUndetermined>(); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj); | |||
| } | |||
| } | |||
| @@ -101,6 +101,7 @@ list_type = typing.List | |||
| tuple_type = typing.Tuple | |||
| index_slices = typing.IndexedSlicesType() | |||
| sparse_tensor = typing.SparseTensorType() | |||
| undetermined = typing.UndeterminedType() | |||
| number_type = (int8, | |||
| int16, | |||
| @@ -290,7 +290,19 @@ class IndexedSlices: | |||
| """ | |||
| def __init__(self, indices, values, dense_shape): | |||
| raise NotImplementedError | |||
| "Init IndexedSlices" | |||
| self.__indices = indices | |||
| self.__values = values | |||
| self.__dense_shape = dense_shape | |||
| def indices(self): | |||
| return self.__indices | |||
| def values(self): | |||
| return self.__values | |||
| def dense_shape(self): | |||
| return self.__dense_shape | |||
| class SparseTensor: | |||
| @@ -331,4 +343,16 @@ class SparseTensor: | |||
| """ | |||
| def __init__(self, indices, values, dense_shape): | |||
| raise NotImplementedError | |||
| "Init SparseTensor" | |||
| self.__indices = indices | |||
| self.__values = values | |||
| self.__dense_shape = dense_shape | |||
| def indices(self): | |||
| return self.__indices | |||
| def values(self): | |||
| return self.__values | |||
| def dense_shape(self): | |||
| return self.__dense_shape | |||
| @@ -814,9 +814,13 @@ class AddN(PrimitiveWithInfer): | |||
| validator.check_value_type("inputs", inputs, [tuple, list], cls_name) | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| args = {} | |||
| contains_undetermined = False | |||
| for i, dtype in enumerate(inputs): | |||
| args[f"inputs[{i}]"] = dtype | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| if dtype == mstype.undetermined: | |||
| contains_undetermined = True | |||
| if not contains_undetermined: | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| return inputs[0] | |||
| def infer_value(self, inputs): | |||
| @@ -398,7 +398,7 @@ def test_switch_layer(): | |||
| ret = F.switch_layer(index, self.layers)(x) * self.z3 | |||
| return ret | |||
| index = Tensor(0) | |||
| index = Tensor(0, dtype=mstype.int32) | |||
| net = SwitchLayerCell() | |||
| net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | |||
| C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, | |||
| @@ -436,7 +436,7 @@ def test_index_to_switch_layer(): | |||
| ret = self.layers[index](x) * self.z3 | |||
| return ret | |||
| index = Tensor(0) | |||
| index = Tensor(0, dtype=mstype.int32) | |||
| net = SwitchLayerCell() | |||
| net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) | |||
| C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| @File : test_sparse_pynative.py | |||
| @Author: | |||
| @Date : 2020-08-04 | |||
| @Desc : test mindspore sparse pynative | |||
| """ | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context, Tensor, IndexedSlices, SparseTensor | |||
| from mindspore.ops import composite as C | |||
| context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) | |||
| grad_all = C.GradOperation('get_all', get_all=True) | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, *args): | |||
| grad = grad_all(self.network)(*args) | |||
| return grad | |||
| def test_indexed_slices_attr(): | |||
| class IndexedSlicesGetAttr(nn.Cell): | |||
| def __init__(self, dense_shape): | |||
| super(IndexedSlicesGetAttr, self).__init__() | |||
| self.dense_shape = dense_shape | |||
| def construct(self, indices, values): | |||
| x = IndexedSlices(indices, values, self.dense_shape) | |||
| return x.values(), x.indices(), x.dense_shape() | |||
| indices = Tensor([0]) | |||
| values = Tensor([[1, 2]], dtype=ms.float32) | |||
| IndexedSlicesGetAttr((3, 2))(indices, values) | |||
| GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values) | |||
| def test_sparse_tensor_attr(): | |||
| class SparseTensorGetAttr(nn.Cell): | |||
| def __init__(self): | |||
| super(SparseTensorGetAttr, self).__init__() | |||
| self.dense_shape = (3, 4) | |||
| def construct(self, indices, values): | |||
| x = SparseTensor(indices, values, self.dense_shape) | |||
| return x.values(), x.indices(), x.dense_shape() | |||
| indices = Tensor([[0, 1], [1, 2]]) | |||
| values = Tensor([1, 2], dtype=ms.float32) | |||
| SparseTensorGetAttr()(indices, values) | |||
| GradWrap(SparseTensorGetAttr())(indices, values) | |||