Browse Source

!2511 add mod and floordiv opration for tensor

Merge pull request !2511 from wangqiuliang/add-tensor-mod-operation
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
67d664d7dd
5 changed files with 30 additions and 21 deletions
  1. +3
    -1
      mindspore/common/api.py
  2. +15
    -6
      mindspore/common/tensor.py
  3. +2
    -0
      mindspore/ops/functional.py
  4. +9
    -1
      tests/ut/python/ir/test_tensor.py
  5. +1
    -13
      tests/ut/python/pynative_mode/ge/model/test_lenet_model.py

+ 3
- 1
mindspore/common/api.py View File

@@ -158,7 +158,9 @@ class _MindSporeFunction:
# replace key with obj info and object ext info when fn is a method
if self.obj is not None:
self.obj.__parse_method__ = method_name
generate_name = self.obj.__module__ + "." + str(self.obj.create_time)
generate_name = self.obj.__module__ + "."
if self.obj.__class__.__name__ != "ClipByNorm":
generate_name = generate_name + str(self.obj.create_time)
if self.identify_obj is not None:
generate_name = generate_name + str(id(self.identify_obj))



+ 15
- 6
mindspore/common/tensor.py View File

@@ -102,16 +102,14 @@ class Tensor(Tensor_):
return out

def __iadd__(self, other):
out = self.__add__(other)
return out
return self.__add__(other)

def __radd__(self, other):
out = tensor_operator_registry.get('__add__')(self, other)
return out

def __imul__(self, other):
out = self.__mul__(other)
return out
return self.__mul__(other)

def __rmul__(self, other):
out = tensor_operator_registry.get('__mul__')(self, other)
@@ -130,8 +128,7 @@ class Tensor(Tensor_):
return out

def __isub__(self, other):
out = self.__sub__(other)
return out
return self.__sub__(other)

def __rsub__(self, other):
out = tensor_operator_registry.get('__sub__')(other, self)
@@ -168,6 +165,18 @@ class Tensor(Tensor_):
return 1
return out[0]

def __mod__(self, other):
return tensor_operator_registry.get('__mod__')(self, other)

def __imod__(self, other):
return self.__mod__(other)

def __floordiv__(self, other):
return tensor_operator_registry.get('__floordiv__')(self, other)

def __ifloordiv__(self, other):
return self.__floordiv__(other)

def __str__(self):
if self.dtype == mstype.type_none:
return "Unknown Tensor type!"


+ 2
- 0
mindspore/ops/functional.py View File

@@ -157,6 +157,8 @@ tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__truediv__', tensor_div)
tensor_operator_registry.register('__mod__', tensor_mod)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)


+ 9
- 1
tests/ut/python/ir/test_tensor.py View File

@@ -24,13 +24,15 @@ import pytest
import mindspore as ms
import mindspore.common.api as me
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, context
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine

ndarr = np.ones((2, 3))

context.set_context(mode=context.GRAPH_MODE)


def test_tensor_flatten():
with pytest.raises(AttributeError):
@@ -452,5 +454,11 @@ def test_tensor_operation():
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = 8 / x
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
res = x % 3
assert np.all(res.asnumpy() == np.ones((3, 3)))
res = x // 3
assert np.all(res.asnumpy() == np.ones((3, 3)))
x %= 3
assert np.all(x.asnumpy() == np.ones((3, 3)))
with pytest.raises(ValueError):
res = x * (2, 3)

+ 1
- 13
tests/ut/python/pynative_mode/ge/model/test_lenet_model.py View File

@@ -18,8 +18,7 @@ import pytest

import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.nn import WithGradCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.nn import WithGradCell
from mindspore.ops import operations as P


@@ -63,17 +62,6 @@ def test_lenet_pynative_train_net():
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False)
grad_fn = nn.SoftmaxCrossEntropyWithLogits()
grad_net = WithGradCell(net, grad_fn, sens=dout)
gradients = grad_net(data, label)

# update parameters
opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
opt(gradients)

# verification
if i == verification_step:
loss_net = WithLossCell(net, loss_fn)
loss_output = loss_net(data, label)
print("The loss of %s-th iteration is %s" % (i, loss_output.asnumpy()))


def test_lenet_pynative_train_model():


Loading…
Cancel
Save