Browse Source

support python func print and != for list with none

tags/v0.2.0-alpha
buxue 5 years ago
parent
commit
7c233a57fa
7 changed files with 43 additions and 10 deletions
  1. +1
    -0
      mindspore/_extends/parse/resources.py
  2. +2
    -2
      mindspore/_extends/parse/trope.py
  3. +34
    -3
      mindspore/ops/composite/multitype_ops/not_equal_impl.py
  4. +1
    -1
      mindspore/ops/functional.py
  5. +1
    -0
      mindspore/ops/operations/_grad_ops.py
  6. +4
    -2
      tests/ut/python/pipeline/parse/test_operator.py
  7. +0
    -2
      tests/vm_impl/nn_ops_vm_impl.py

+ 1
- 0
mindspore/_extends/parse/resources.py View File

@@ -114,6 +114,7 @@ convert_object_map = {
T.map: C.HyperMap(), T.map: C.HyperMap(),
T.partial: F.partial, T.partial: F.partial,
T.zip: C.zip_operation, T.zip: C.zip_operation,
T.print: F.print_,


# custom define operation # custom define operation
T.iter: M.ms_iter, T.iter: M.ms_iter,


+ 2
- 2
mindspore/_extends/parse/trope.py View File

@@ -27,7 +27,7 @@ from operator import ( # noqa


# support system function call # support system function call
from builtins import ( # noqa from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip
bool, getattr, setattr, len, iter, next, pow, range, map, zip, print
) )


# support functools # support functools
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem', 'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial',
'partial', 'print',
'exp', 'log', 'sin', 'cos', 'tan'] 'exp', 'log', 'sin', 'cos', 'tan']






+ 34
- 3
mindspore/ops/composite/multitype_ops/not_equal_impl.py View File

@@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y):




@not_equal.register("Tuple", "Tuple") @not_equal.register("Tuple", "Tuple")
def _euqal_tuple(x, y):
def _not_euqal_tuple(x, y):
""" """
Determine if two tuples are not equal by element. Determine if two tuples are not equal by element.


@@ -147,7 +147,7 @@ def _euqal_tuple(x, y):




@not_equal.register("List", "List") @not_equal.register("List", "List")
def _euqal_list(x, y):
def _not_euqal_list(x, y):
""" """
Determine if two lists are not equal by element. Determine if two lists are not equal by element.


@@ -162,7 +162,7 @@ def _euqal_list(x, y):




@not_equal.register("Tuple", "None") @not_equal.register("Tuple", "None")
def _tuple_euqal_none(x, y):
def _tuple_not_euqal_none(x, y):
""" """
Determine if tuple element not equals none element. Determine if tuple element not equals none element.


@@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y):
""" """
return True return True



@not_equal.register("Tensor", "Number") @not_equal.register("Tensor", "Number")
@not_equal.register("Number", "Tensor") @not_equal.register("Number", "Tensor")
@not_equal.register("Tensor", "Tensor") @not_equal.register("Tensor", "Tensor")
@@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y):
bool, return True. bool, return True.
""" """
return True return True


@not_equal.register("List", "None")
def _list_not_equal_none(x, y):
"""
Determine if list not equal none.

Args:
x (list): The first input which is a list.
y (none): The second input which is none.

Returns:
bool, return true.
"""
return True


@not_equal.register("None", "List")
def _none_not_equal_list(x, y):
"""
Determine if none not equal list.

Args:
x (none): The first input which is none.
y (list): The second input which is a list.

Returns:
bool, return true.
"""
return True

+ 1
- 1
mindspore/ops/functional.py View File

@@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor() scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray() tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast() scalar_cast = P.ScalarCast()
print_ = P.Print()


tuple_setitem = Primitive('tuple_setitem') tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')


+ 1
- 0
mindspore/ops/operations/_grad_ops.py View File

@@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type return x_type



class ConcatOffset(PrimitiveWithInfer): class ConcatOffset(PrimitiveWithInfer):
"""primitive for computing Concat's gradient.""" """primitive for computing Concat's gradient."""




+ 4
- 2
tests/ut/python/pipeline/parse/test_operator.py View File

@@ -160,8 +160,10 @@ def test_ops():
ret_floor = p // q + q // p ret_floor = p // q + q // p
ret = ret_pow + ret_mod + ret_floor ret = ret_pow + ret_mod + ret_floor
if self.int > self.float: if self.int > self.float:
if self.str_a + self.str_b == "helloworld":
return ret
if [1, 2, 3] != None:
if self.str_a + self.str_b == "helloworld":
print("hello world")
return ret
return x return x


net = OpsNet(9, 2) net = OpsNet(9, 2)


+ 0
- 2
tests/vm_impl/nn_ops_vm_impl.py View File

@@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self):
"""Generate vm_impl function for MaxPoolGradWithArgmax""" """Generate vm_impl function for MaxPoolGradWithArgmax"""


def vm_impl(x, dout, argmax): def vm_impl(x, dout, argmax):
print("buxue")
print(argmax)
x = x.asnumpy() x = x.asnumpy()
dout = dout.asnumpy() dout = dout.asnumpy()
arg_max = argmax.asnumpy() arg_max = argmax.asnumpy()


Loading…
Cancel
Save