Browse Source

fix csr mul broadcast error

r1.7
yanglf1121 4 years ago
parent
commit
3aecb22983
4 changed files with 4 additions and 11 deletions
  1. +1
    -1
      akg
  2. +0
    -4
      mindspore/core/abstract/prim_others.cc
  3. +2
    -2
      mindspore/python/mindspore/common/tensor.py
  4. +1
    -4
      mindspore/python/mindspore/ops/functional.py

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1
Subproject commit 50d3082fdb2d084fff8509b6fbbdab5bc1e75e5c

+ 0
- 4
mindspore/core/abstract/prim_others.cc View File

@@ -46,10 +46,6 @@ inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
if (sparse_shp.size() < 1) {
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
}
if (dense_shp[0] != sparse_shp[0]) {
MS_EXCEPTION(mindspore::ValueError)
<< "Currently, dense tensor and sparse tensor shapes must equal in first dimension.";
}
for (size_t i = 0; i < sparse_shp.size(); i++) {
auto s = sparse_shp[i];
auto d = dense_shp[i];


+ 2
- 2
mindspore/python/mindspore/common/tensor.py View File

@@ -2806,8 +2806,8 @@ class CSRTensor(CSRTensor_):
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import dtype as mstype
>>> indptr = Tensor([0, 1, 2], dtype=ms.int32)
>>> indices = Tensor([0, 1], dtype=ms.int32)
>>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
>>> indices = Tensor([0, 1], dtype=mstype.int32)
>>> values = Tensor([2, 1], dtype=mstype.float32)
>>> dense_shape = (2, 4)
>>> csr_tensor = CSRTensor(indptr, indices, values, dense_shape)


+ 1
- 4
mindspore/python/mindspore/ops/functional.py View File

@@ -171,8 +171,6 @@ def csr_mul(x, y):
Supported Platforms:
``GPU`` ``CPU``
"""
if x.shape[0] != 1 and y.shape[0] == 1:
y = y.expand_as(x)
return _csr_ops.CSRMul()(x, y)

def csr_div(x, y):
@@ -195,8 +193,6 @@ def csr_div(x, y):
Supported Platforms:
``GPU`` ``CPU``
"""
if x.shape[0] != 1 and y.shape[0] == 1:
y = y.expand_as(x)
return _csr_ops.CSRDiv()(x, y)

csr_mv = _csr_ops.CSRMV()
@@ -974,6 +970,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
def print_info(info):
print(info)


def make_sparse_tensor(indices, values, dense_shape):
"""Call make_coo_tensor in this function."""
print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +


Loading…
Cancel
Save