| @@ -387,16 +387,10 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got " | |||
| << dense_shape_vec[i]; | |||
| } | |||
| if (i == 0) { | |||
| if (dense_shape_vec[i] < values_shp[i]) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i | |||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | |||
| } | |||
| } else { | |||
| if (dense_shape_vec[i] != values_shp[i]) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i | |||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | |||
| } | |||
| // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection | |||
| if (i != 0 && dense_shape_vec[i] != values_shp[i]) { | |||
| MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i | |||
| << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; | |||
| } | |||
| } | |||
| auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec); | |||
| @@ -213,9 +213,73 @@ class Tensor(Tensor_): | |||
| class IndexedSlices: | |||
| """ | |||
| A sparse representation of a set of tensor slices at given indices. | |||
| An IndexedSlices is typically used to represent a subset of a larger | |||
| tensor dense of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. | |||
| The values in indices are the indices in the first dimension of the slices | |||
| that have been extracted from the larger tensor. | |||
| The dense tensor dense represented by an IndexedSlices slices has | |||
| `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`. | |||
| IndexedSlices can only be used in `Cell`'s contruct method. | |||
| Args: | |||
| indices (Tensor): A 1-D integer Tensor of shape [D0]. | |||
| values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn]. | |||
| dense_shape: (tuple): A integer tuple containing the shape | |||
| of the corresponding dense tensor. | |||
| Returns: | |||
| IndexedSlices, composed of `indices`, `values`, `dense_shape`. | |||
| Examples: | |||
| >>> # Create a IndexedSlices. | |||
| >>> indices = Tensor([1, 2]) | |||
| >>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) | |||
| >>> dense_shape = (3, 2) | |||
| >>> indexed_slices = IndexedSlices(indices, values, dense_shape) | |||
| >>> | |||
| >>> # Get atrr. | |||
| >>> indices = indexed_slices.indices() | |||
| >>> values = indexed_slices.values() | |||
| >>> dense_shape = indexed_slices.dense_shape() | |||
| """ | |||
| def __init__(self, indices, values, dense_shape): | |||
| raise NotImplementedError | |||
| class SparseTensor: | |||
| """ | |||
| A sparse representation of a set of nonzero elememts from a tensor at given indices. | |||
| SparseTensor can only be used in `Cell`'s contruct method. | |||
| For a tensor dense, its SparseTensor(indices, values, dense_shape) has | |||
| `dense[indices[i]] = values[i]`. | |||
| Args: | |||
| indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`, | |||
| where N and ndims are the number of values and number of dimensions in | |||
| the SparseTensor, respectively. | |||
| values (Tensor): A 1-D tensor of any type and shape `[N]`, which | |||
| supplies the values for each element in indices. | |||
| dense_shape: (tuple): A integer tuple of size `ndims`, | |||
| which specifies the dense_shape of the sparse tensor. | |||
| Returns: | |||
| SparseTensor, composed of `indices`, `values`, `dense_shape`. | |||
| Examples: | |||
| >>> # Create a SparseTensor. | |||
| >>> indices = Tensor([[0, 1], [1, 2]]) | |||
| >>> values = Tensor([1, 2], dtype=ms.float32) | |||
| >>> dense_shape = (3, 4) | |||
| >>> sparse_tensor = SparseTensor(indices, values, dense_shape) | |||
| >>> | |||
| >>> # Get atrr. | |||
| >>> indices = sparse_tensor.indices() | |||
| >>> values = sparse_tensor.values() | |||
| >>> dense_shape = sparse_tensor.dense_shape() | |||
| """ | |||
| def __init__(self, indices, values, dense_shape): | |||
| raise NotImplementedError | |||
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| @@ -100,7 +99,6 @@ def test_embeddinglookup_reducescatter_true_grad(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_embeddinglookup_semi_auto1(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| shape = [64, 32] | |||
| @@ -115,7 +113,6 @@ def test_embeddinglookup_semi_auto1(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_embeddinglookup_semi_auto2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| shape = [64, 32] | |||
| @@ -61,7 +61,6 @@ class Net(nn.Cell): | |||
| return out | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_semi_auto0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy1 = ((1, 8), (1, 1)) | |||
| @@ -134,7 +133,6 @@ def test_gatherv2_semi_auto5(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_semi_auto6(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") | |||
| strategy2 = ((4, 2, 1), (4, 2, 1)) | |||
| @@ -169,7 +167,6 @@ def test_gatherv2_semi_auto8(): | |||
| _executor.compile(net, x, y) | |||
| @pytest.mark.skip(reason="waiting for fix by parallel strategy") | |||
| def test_gatherv2_auto0(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") | |||
| net = GradWrap(NetWithLoss(Net(0))) | |||