Browse Source

support python func print and != for list with none

tags/v0.2.0-alpha
buxue 5 years ago
parent
commit
7c233a57fa
7 changed files with 43 additions and 10 deletions
  1. +1
    -0
      mindspore/_extends/parse/resources.py
  2. +2
    -2
      mindspore/_extends/parse/trope.py
  3. +34
    -3
      mindspore/ops/composite/multitype_ops/not_equal_impl.py
  4. +1
    -1
      mindspore/ops/functional.py
  5. +1
    -0
      mindspore/ops/operations/_grad_ops.py
  6. +4
    -2
      tests/ut/python/pipeline/parse/test_operator.py
  7. +0
    -2
      tests/vm_impl/nn_ops_vm_impl.py

+ 1
- 0
mindspore/_extends/parse/resources.py View File

@@ -114,6 +114,7 @@ convert_object_map = {
T.map: C.HyperMap(),
T.partial: F.partial,
T.zip: C.zip_operation,
T.print: F.print_,

# custom define operation
T.iter: M.ms_iter,


+ 2
- 2
mindspore/_extends/parse/trope.py View File

@@ -27,7 +27,7 @@ from operator import ( # noqa

# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
)

# support functools
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial',
'partial', 'print',
'exp', 'log', 'sin', 'cos', 'tan']




+ 34
- 3
mindspore/ops/composite/multitype_ops/not_equal_impl.py View File

@@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y):


@not_equal.register("Tuple", "Tuple")
def _euqal_tuple(x, y):
def _not_euqal_tuple(x, y):
"""
Determine if two tuples are not equal by element.

@@ -147,7 +147,7 @@ def _euqal_tuple(x, y):


@not_equal.register("List", "List")
def _euqal_list(x, y):
def _not_euqal_list(x, y):
"""
Determine if two lists are not equal by element.

@@ -162,7 +162,7 @@ def _euqal_list(x, y):


@not_equal.register("Tuple", "None")
def _tuple_euqal_none(x, y):
def _tuple_not_euqal_none(x, y):
"""
Determine if tuple element not equals none element.

@@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y):
"""
return True


@not_equal.register("Tensor", "Number")
@not_equal.register("Number", "Tensor")
@not_equal.register("Tensor", "Tensor")
@@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y):
bool, return True.
"""
return True


@not_equal.register("List", "None")
def _list_not_equal_none(x, y):
"""
Determine if list not equal none.

Args:
x (list): The first input which is a list.
y (none): The second input which is none.

Returns:
bool, return true.
"""
return True


@not_equal.register("None", "List")
def _none_not_equal_list(x, y):
"""
Determine if none not equal list.

Args:
x (none): The first input which is none.
y (list): The second input which is a list.

Returns:
bool, return true.
"""
return True

+ 1
- 1
mindspore/ops/functional.py View File

@@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()

tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')


+ 1
- 0
mindspore/ops/operations/_grad_ops.py View File

@@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type


class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient."""



+ 4
- 2
tests/ut/python/pipeline/parse/test_operator.py View File

@@ -160,8 +160,10 @@ def test_ops():
ret_floor = p // q + q // p
ret = ret_pow + ret_mod + ret_floor
if self.int > self.float:
if self.str_a + self.str_b == "helloworld":
return ret
if [1, 2, 3] != None:
if self.str_a + self.str_b == "helloworld":
print("hello world")
return ret
return x

net = OpsNet(9, 2)


+ 0
- 2
tests/vm_impl/nn_ops_vm_impl.py View File

@@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self):
"""Generate vm_impl function for MaxPoolGradWithArgmax"""

def vm_impl(x, dout, argmax):
print("buxue")
print(argmax)
x = x.asnumpy()
dout = dout.asnumpy()
arg_max = argmax.asnumpy()


Loading…
Cancel
Save