
!2511 add mod and floordiv operation for tensor

Merge pull request !2511 from wangqiuliang/add-tensor-mod-operation
tags/v0.6.0-beta
mindspore-ci-bot committed 5 years ago
parent commit 67d664d7dd
5 changed files with 30 additions and 21 deletions

  1. +3  -1   mindspore/common/api.py
  2. +15 -6   mindspore/common/tensor.py
  3. +2  -0   mindspore/ops/functional.py
  4. +9  -1   tests/ut/python/ir/test_tensor.py
  5. +1  -13  tests/ut/python/pynative_mode/ge/model/test_lenet_model.py

+3 -1  mindspore/common/api.py

@@ -158,7 +158,9 @@ class _MindSporeFunction:
         # replace key with obj info and object ext info when fn is a method
         if self.obj is not None:
             self.obj.__parse_method__ = method_name
-            generate_name = self.obj.__module__ + "." + str(self.obj.create_time)
+            generate_name = self.obj.__module__ + "."
+            if self.obj.__class__.__name__ != "ClipByNorm":
+                generate_name = generate_name + str(self.obj.create_time)
         if self.identify_obj is not None:
             generate_name = generate_name + str(id(self.identify_obj))
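For orientation: this key decides whether a cell reuses a compiled graph. Including create_time gives each instance its own key, while the ClipByNorm carve-out makes every instance of that class share one key (and thus one compiled graph). A minimal sketch of the idea, where CompileCache and compile_fn are hypothetical stand-ins, not MindSpore internals:

    # Sketch only: CompileCache and compile_fn are hypothetical stand-ins.
    class CompileCache:
        def __init__(self):
            self._graphs = {}

        def key_for(self, obj):
            key = obj.__module__ + "."
            # Per-instance key via creation time, except for ClipByNorm,
            # whose instances deliberately share one compiled graph.
            if obj.__class__.__name__ != "ClipByNorm":
                key += str(obj.create_time)
            return key

        def get_or_compile(self, obj, compile_fn):
            key = self.key_for(obj)
            if key not in self._graphs:
                self._graphs[key] = compile_fn(obj)
            return self._graphs[key]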




+15 -6  mindspore/common/tensor.py

@@ -102,16 +102,14 @@ class Tensor(Tensor_):
         return out

     def __iadd__(self, other):
-        out = self.__add__(other)
-        return out
+        return self.__add__(other)

     def __radd__(self, other):
         out = tensor_operator_registry.get('__add__')(self, other)
         return out

     def __imul__(self, other):
-        out = self.__mul__(other)
-        return out
+        return self.__mul__(other)

     def __rmul__(self, other):
         out = tensor_operator_registry.get('__mul__')(self, other)
@@ -130,8 +128,7 @@ class Tensor(Tensor_):
         return out

     def __isub__(self, other):
-        out = self.__sub__(other)
-        return out
+        return self.__sub__(other)

     def __rsub__(self, other):
         out = tensor_operator_registry.get('__sub__')(other, self)
@@ -168,6 +165,18 @@ class Tensor(Tensor_):
             return 1
         return out[0]

+    def __mod__(self, other):
+        return tensor_operator_registry.get('__mod__')(self, other)
+
+    def __imod__(self, other):
+        return self.__mod__(other)
+
+    def __floordiv__(self, other):
+        return tensor_operator_registry.get('__floordiv__')(self, other)
+
+    def __ifloordiv__(self, other):
+        return self.__floordiv__(other)
+
     def __str__(self):
         if self.dtype == mstype.type_none:
             return "Unknown Tensor type!"


+2 -0  mindspore/ops/functional.py

@@ -157,6 +157,8 @@ tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
+tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
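Once registered (tensor_mod and tensor_floordiv are assumed to be defined earlier in functional.py, like the neighboring entries), the operators work directly on tensors. A hedged usage example, assuming a MindSpore build that includes this change:

    # Usage sketch, assuming a build with this change applied.
    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.ones((3, 3)) * 4)
    print((x % 3).asnumpy())   # elementwise remainder: all ones
    print((x // 3).asnumpy())  # elementwise floor division: all ones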


+9 -1  tests/ut/python/ir/test_tensor.py

@@ -24,13 +24,15 @@ import pytest
 import mindspore as ms
 import mindspore.common.api as me
 import mindspore.nn as nn
-from mindspore import Tensor
+from mindspore import Tensor, context
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from ..ut_filter import non_graph_engine

 ndarr = np.ones((2, 3))

+context.set_context(mode=context.GRAPH_MODE)
+

 def test_tensor_flatten():
     with pytest.raises(AttributeError):
@@ -452,5 +454,11 @@ def test_tensor_operation():
     assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
     res = 8 / x
     assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
+    res = x % 3
+    assert np.all(res.asnumpy() == np.ones((3, 3)))
+    res = x // 3
+    assert np.all(res.asnumpy() == np.ones((3, 3)))
+    x %= 3
+    assert np.all(x.asnumpy() == np.ones((3, 3)))
     with pytest.raises(ValueError):
         res = x * (2, 3)
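A quick sanity check on the new assertions: the earlier asserts force x to hold all 4s (8 / x == 2), so 4 % 3 and 4 // 3 are both 1, i.e. all-ones arrays. The same arithmetic restated in plain NumPy:

    # Plain-NumPy restatement of the arithmetic behind the new asserts.
    import numpy as np

    x = np.ones((3, 3)) * 4
    assert np.all(x % 3 == np.ones((3, 3)))   # 4 % 3 == 1
    assert np.all(x // 3 == np.ones((3, 3)))  # 4 // 3 == 1
    x %= 3
    assert np.all(x == np.ones((3, 3)))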

+1 -13  tests/ut/python/pynative_mode/ge/model/test_lenet_model.py

@@ -18,8 +18,7 @@ import pytest

 import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
-from mindspore.nn import WithGradCell, WithLossCell
-from mindspore.nn.optim import Momentum
+from mindspore.nn import WithGradCell
 from mindspore.ops import operations as P

@@ -63,17 +62,6 @@ def test_lenet_pynative_train_net():
     loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False)
     grad_fn = nn.SoftmaxCrossEntropyWithLogits()
     grad_net = WithGradCell(net, grad_fn, sens=dout)
-    gradients = grad_net(data, label)
-
-    # update parameters
-    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    opt(gradients)
-
-    # verification
-    if i == verification_step:
-        loss_net = WithLossCell(net, loss_fn)
-        loss_output = loss_net(data, label)
-        print("The loss of %s-th iteration is %s" % (i, loss_output.asnumpy()))


 def test_lenet_pynative_train_model():

