| @@ -378,10 +378,25 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim | |||
| auto elem = GetValue<int>(e); | |||
| return elem; | |||
| }); | |||
| for (auto dense_shape_elem : dense_shape_vec) { | |||
| if (dense_shape_elem < 0) { | |||
| MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got " | |||
| << dense_shape_value->ToString(); | |||
| if (dense_shape_vec.size() != values_shp.size()) { | |||
| MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values " | |||
| << values_shp.size() << ", but got " << dense_shape_value->size(); | |||
| } | |||
| for (size_t i = 0; i < dense_shape_vec.size(); i++) { | |||
| if (dense_shape_vec[i] < 0) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got " | |||
| << dense_shape_vec[i]; | |||
| } | |||
| if (i == 0) { | |||
| if (dense_shape_vec[i] < values_shp[i]) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i | |||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | |||
| } | |||
| } else { | |||
| if (dense_shape_vec[i] != values_shp[i]) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i | |||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | |||
| } | |||
| } | |||
| } | |||
| auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec); | |||
| @@ -386,6 +386,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| dic["shape"] = arg_tensor->shape()->shape(); | |||
| dic["dtype"] = arg_tensor->BuildType(); | |||
| dic["value"] = BuildValue(arg_tensor->BuildValue()); | |||
| } else if (abs_base->isa<AbstractIndexedSlices>()) { | |||
| auto arg = dyn_cast<AbstractIndexedSlices>(abs_base); | |||
| dic["shape"] = arg->shape()->shape(); | |||
| dic["dtype"] = arg->BuildType(); | |||
| dic["value"] = BuildValue(arg->BuildValue()); | |||
| } else if (abs_base->isa<AbstractSparseTensor>()) { | |||
| auto arg = dyn_cast<AbstractSparseTensor>(abs_base); | |||
| dic["shape"] = arg->shape()->shape(); | |||
| dic["dtype"] = arg->BuildType(); | |||
| dic["value"] = BuildValue(arg->BuildValue()); | |||
| } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) { | |||
| std::vector<int> shape; | |||
| dic["shape"] = shape; | |||
| @@ -99,6 +99,8 @@ slice_type = typing.Slice | |||
| ellipsis_type = typing.TypeEllipsis | |||
| list_type = typing.List | |||
| tuple_type = typing.Tuple | |||
| index_slices = typing.IndexedSlicesType() | |||
| sparse_tensor = typing.SparseTensorType() | |||
| number_type = (int8, | |||
| int16, | |||
| @@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean | |||
| from ..cell import Cell | |||
| from ...common import Tensor | |||
| from ...common import Tensor, IndexedSlices | |||
| from ...common.parameter import Parameter | |||
| from ...ops import functional as F | |||
| from ...ops import composite as C | |||
| @@ -35,6 +35,12 @@ reciprocal = P.Reciprocal() | |||
| def tensor_grad_scale(scale, grad): | |||
| return grad * F.cast(reciprocal(scale), F.dtype(grad)) | |||
| @_grad_scale.register("Tensor", "IndexedSlices") | |||
| def tensor_grad_scale_indexed_slices(scale, grad): | |||
| return IndexedSlices(grad.indices(), | |||
| grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())), | |||
| grad.dense_shape()) | |||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||
| grad_overflow = P.FloatStatus() | |||
| @@ -15,6 +15,8 @@ | |||
| """array_ops""" | |||
| import mindspore as ms | |||
| from mindspore.ops import composite as C | |||
| from .. import operations as P | |||
| from ..operations import _grad_ops as G | |||
| from ..operations import _inner_ops as inner | |||
| @@ -35,6 +37,7 @@ reshape = P.Reshape() | |||
| size_op = P.Size() | |||
| invert_permutation = P.InvertPermutation() | |||
| logical_and = P.LogicalAnd() | |||
| is_sub_class = P.IsSubClass() | |||
| @bprop_getters.register(P.Fill) | |||
| @@ -57,6 +60,29 @@ def get_bprop_dtype(self): | |||
| return bprop | |||
| dout_cast = C.MultitypeFuncGraph("dout_cast") | |||
| @dout_cast.register("Tensor", "Tensor") | |||
| def dout_cast_tensor(dout, x): | |||
| cast = P.Cast() | |||
| get_dtype = P.DType() | |||
| dx = cast(dout, get_dtype(x)) | |||
| return dx | |||
| @dout_cast.register("Number", "Number") | |||
| def dout_cast_number(dout, x): | |||
| cast = P.Cast() | |||
| get_dtype = P.DType() | |||
| dx = cast(dout, get_dtype(x)) | |||
| return dx | |||
| @dout_cast.register("IndexedSlices", "Tensor") | |||
| def dout_cast_indexed_slices(dout, x): | |||
| cast = P.Cast() | |||
| get_dtype = P.DType() | |||
| values = cast(dout.values(), get_dtype(x)) | |||
| return IndexedSlices(dout.indices(), values, dout.dense_shape()) | |||
| @bprop_getters.register(P.Cast) | |||
| def get_bprop_cast(self): | |||
| """Generate bprop for Cast""" | |||
| @@ -67,6 +93,13 @@ def get_bprop_cast(self): | |||
| dx = cast(dout, get_dtype(x)) | |||
| return dx, zeros_like(t) | |||
| def bprop_sparse(x, t, out, dout): | |||
| dx = dout_cast(dout, x) | |||
| return dx, zeros_like(t) | |||
| if context.get_context('enable_sparse'): | |||
| return bprop_sparse | |||
| return bprop | |||
| @@ -372,6 +405,11 @@ def get_bprop_pack(self): | |||
| def bprop(x, out, dout): | |||
| pack_grad = P.Unpack(axis) | |||
| out = pack_grad(dout) | |||
| if is_sub_class(F.typeof(x), ms.list_): | |||
| ret = [] | |||
| for item in out: | |||
| ret.append(item) | |||
| return (ret,) | |||
| return (out,) | |||
| return bprop | |||
| @@ -18,8 +18,10 @@ import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| import mindspore as ms | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from tests.ut.python.ut_filter import non_graph_engine | |||
| from tests.mindspore_test_framework.mindspore_test import mindspore_test | |||
| @@ -282,3 +284,26 @@ test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) | |||
| def test_exec(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| return test_exec_case | |||
| def test_grad_make_list(): | |||
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, idx, x): | |||
| return x[idx, :, :] | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return C.grad_all(self.net)(*inputs) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, x) | |||
| @@ -19,6 +19,7 @@ | |||
| @Desc : test mindspore indexed_slices's operation | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -222,7 +223,7 @@ def test_indexed_slices_make_indexed_slices(): | |||
| class MakeIndexedSlices(nn.Cell): | |||
| def __init__(self): | |||
| super(MakeIndexedSlices, self).__init__() | |||
| self.dense_shape = (3, 4) | |||
| self.dense_shape = (3, 2) | |||
| def construct(self, indices, values): | |||
| ret = (IndexedSlices(indices, values, self.dense_shape),) | |||
| return ret[0] | |||
| @@ -231,17 +232,19 @@ def test_indexed_slices_make_indexed_slices(): | |||
| MakeIndexedSlices()(indices, values) | |||
| class IndexedSlicesGetAttr(nn.Cell): | |||
| def __init__(self, dense_shape): | |||
| super(IndexedSlicesGetAttr, self).__init__() | |||
| self.dense_shape = dense_shape | |||
| def construct(self, indices, values): | |||
| x = IndexedSlices(indices, values, self.dense_shape) | |||
| return x.values(), x.indices(), x.dense_shape() | |||
| def test_indexed_slices_attr(): | |||
| class IndexedSlicesGetAttr(nn.Cell): | |||
| def __init__(self): | |||
| super(IndexedSlicesGetAttr, self).__init__() | |||
| self.dense_shape = (3, 4) | |||
| def construct(self, indices, values): | |||
| x = IndexedSlices(indices, values, self.dense_shape) | |||
| return x.values(), x.indices(), x.dense_shape() | |||
| indices = Tensor([0]) | |||
| values = Tensor([[1, 2]], dtype=ms.float32) | |||
| IndexedSlicesGetAttr()(indices, values) | |||
| IndexedSlicesGetAttr((3, 2))(indices, values) | |||
| def test_indexed_slices_sparse_gatherv2_grad_all(): | |||
| @@ -342,3 +345,109 @@ def test_indexed_slices_model_train(): | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| def test_indexed_slices_values_dim_greater_than_dense_shape_dim(): | |||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | |||
| values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) | |||
| dense_shape = (3, 4) | |||
| with pytest.raises(TypeError): | |||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||
| def test_indexed_slices_values_dim_less_than_dense_shape_dim(): | |||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | |||
| values = Tensor(np.random.randn(2, 4).astype(np.float32)) | |||
| dense_shape = (3, 4, 5) | |||
| with pytest.raises(TypeError): | |||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||
| def test_indexed_slices_value_and_dense_shape_illegal(): | |||
| indices = Tensor(np.array([0, 1], dtype=np.int32)) | |||
| values = Tensor(np.random.randn(2, 4).astype(np.float32)) | |||
| dense_shape = (3, 5) | |||
| with pytest.raises(TypeError): | |||
| IndexedSlicesGetAttr(dense_shape)(indices, values) | |||
| class IndexedSlicesValuesDouble(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| indices = x.indices() | |||
| values = x.values() * 2 | |||
| dense_shape = x.dense_shape() | |||
| return IndexedSlices(indices, values, dense_shape) | |||
| class IndexedSlicesValuesAdd2(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| indices = x.indices() | |||
| values = x.values() + 2 | |||
| dense_shape = x.dense_shape() | |||
| return IndexedSlices(indices, values, dense_shape) | |||
| class IndexedSlicesWithControlIf(nn.Cell): | |||
| def __init__(self, dense_shape): | |||
| super().__init__() | |||
| self.op1 = IndexedSlicesValuesDouble() | |||
| self.op2 = IndexedSlicesValuesAdd2() | |||
| self.dense_shape = dense_shape | |||
| def construct(self, a, b, indices, values): | |||
| x = IndexedSlices(indices, values, self.dense_shape) | |||
| if a > b: | |||
| x = self.op1(x) | |||
| else: | |||
| x = self.op2(x) | |||
| return x.indices(), x.values() | |||
| def test_indexed_slices_with_control_flow_if(): | |||
| a = Tensor(np.array(0).astype(np.int32)) | |||
| b = Tensor(np.array(2).astype(np.int32)) | |||
| indices = Tensor(np.array([0, 2]).astype(np.int32)) | |||
| values = Tensor(np.ones([2, 2]).astype(np.float32)) | |||
| dense_shape = (5, 2) | |||
| net = IndexedSlicesWithControlIf(dense_shape) | |||
| net(a, b, indices, values) | |||
| class EmbeddingLookUpBnNet(nn.Cell): | |||
| def __init__(self, param_np, target='CPU'): | |||
| super().__init__() | |||
| self.param = Parameter(Tensor(param_np), name="w1") | |||
| self.embedding_lookup = nn.EmbeddingLookup(target=target) | |||
| self.bn = nn.BatchNorm2d(num_features=3) | |||
| self.mul = P.Mul() | |||
| self.reshape = P.Reshape() | |||
| self.relu = nn.PReLU() | |||
| def construct(self, indices): | |||
| x = self.embedding_lookup(self.param, indices) | |||
| x = self.reshape(x, (2, 3, 2, 2)) | |||
| x = self.relu(x) | |||
| x = self.bn(x) | |||
| return x | |||
| def test_embedding_lookup_with_mix_precision(): | |||
| param_np = np.ones([8, 8]).astype(np.float32) | |||
| data = Tensor(np.array([0, 1, 2]).astype(np.int32)) | |||
| label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32)) | |||
| net = EmbeddingLookUpBnNet(param_np, target='CPU') | |||
| criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean') | |||
| optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1) | |||
| optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU") | |||
| train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2") | |||
| train_network.set_train() | |||
| for _ in range(2): | |||
| train_network(data, label) | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -99,6 +100,7 @@ def test_embeddinglookup_reducescatter_true_grad(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_embeddinglookup_semi_auto1(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| shape = [64, 32] | |||
| @@ -113,6 +115,7 @@ def test_embeddinglookup_semi_auto1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_embeddinglookup_semi_auto2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| shape = [64, 32] | |||
| @@ -61,6 +61,7 @@ class Net(nn.Cell): | |||
| return out | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_semi_auto0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||
| @@ -133,6 +134,7 @@ def test_gatherv2_semi_auto5(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_semi_auto6(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -167,6 +169,7 @@ def test_gatherv2_semi_auto8(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_auto0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") | |||
| net = GradWrap(NetWithLoss(Net(0))) | |||