Browse Source

update ms hybrid test cases

fix ms hybrid doc

revert changes
r1.7
Zichun Ye 4 years ago
parent
commit
61469ad999
3 changed files with 42 additions and 13 deletions
  1. +9
    -8
      mindspore/python/mindspore/ops/operations/_ms_hybrid.py
  2. +1
    -1
      mindspore/python/mindspore/ops/operations/custom_ops.py
  3. +32
    -4
      tests/st/ops/graph_kernel/custom/test_ms_hybrid.py

+ 9
- 8
mindspore/python/mindspore/ops/operations/_ms_hybrid.py View File

@@ -473,7 +473,7 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
When a function written by the Hybrid DSL is decorated by ms_hybrid,
it can be run as a usual Python function.
Also, this function can be used in the api Custom and to create a Custom op, with func_type
"hybrid" or "py_func". Creating a custom op with mode "hybrid" by the Hybrid DSL function
"hybrid" or "pyfunc". Creating a custom op with mode "hybrid" by the Hybrid DSL function
will enjoy the automatic dtype/shape infer for free.

Args:
@@ -502,12 +502,13 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
... }
>>> # Create the reg info json string.
>>> op_gpu_info = CustomRegOp() \
... .input(0, "a") \
... .input(0, "b") \
... .output(0, "y") \
... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
... .target("GPU") \
... .get_op_info()
... .input(0, "a") \
... .input(0, "b") \
... .output(0, "y") \
... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
... .target("GPU") \
... .get_op_info()
>>>
>>> # Create inputs for the custom op.
>>> input_x = np.ones([4, 4]).astype(np.float32)
>>> input_y = np.ones([4, 4]).astype(np.float32)
@@ -534,7 +535,7 @@ def ms_hybrid(fn=None, reg_info=None, compile_attrs=None):
>>> # In this case, we will enjoy the automatic dtype/shape infer for free.
>>> # The inputs should be mindspore tensors.
>>> test_op_hybrid = ops.Custom(outer_product)
>>> output = test_op_akg(Tensor(input_x), Tensor(input_y))
>>> output = test_op_hybrid(Tensor(input_x), Tensor(input_y))
"""
if compile_attrs is None:
compile_attrs = {}


+ 1
- 1
mindspore/python/mindspore/ops/operations/custom_ops.py View File

@@ -445,7 +445,7 @@ class Custom(ops.PrimitiveWithInfer):
"""Update information of func"""
if callable(self.func):
# For the func_type other then hybrid, get the original function if func is decorated
if "__wrapped__" in self.func.__dict__ and not self._is_ms_hybrid:
if "__wrapped__" in self.func.__dict__ and not self.func_type in ["hybrid", "pyfunc"]:
self.func = self.func.__dict__["__wrapped__"]
# func name
self.func_name = self.func.__name__


+ 32
- 4
tests/st/ops/graph_kernel/custom/test_ms_hybrid.py View File

@@ -23,6 +23,9 @@ from mindspore.ops import ms_hybrid

@ms_hybrid
def dtype_and_cast_example(a, b):
"""
test function for dtype and cast in Hybrid DSL
"""
d = allocate(a.shape, "float16")
c = output_tensor(a.shape, "float16")

@@ -36,6 +39,9 @@ def dtype_and_cast_example(a, b):

@ms_hybrid
def allocate_and_math_intrin_example(a, b):
"""
test function for allocate and math function in Hybrid DSL
"""
d = allocate(a.shape, a.dtype)
c = output_tensor(a.shape, a.dtype)

@@ -48,6 +54,9 @@ def allocate_and_math_intrin_example(a, b):

@ms_hybrid
def grid_example(a, b):
"""
test function for grid in Hybrid DSL
"""
c = output_tensor(a.shape, a.dtype)

for arg in grid(a.shape):
@@ -68,6 +77,9 @@ class TestMsHybridDSL(Cell):


def ms_hybrid_cast_with_infer():
"""
test case Custom Op with functions written in Hybrid DSL and infer functions
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float16)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float16)
@@ -81,6 +93,9 @@ def ms_hybrid_cast_with_infer():


def ms_hybrid_cast_without_infer():
"""
test case Custom Op with functions written in Hybrid DSL and without infer functions
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float16)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float16)
@@ -94,6 +109,9 @@ def ms_hybrid_cast_without_infer():


def ms_hybrid_cast_pyfunc():
"""
test case Custom Op with functions written in Hybrid DSL and func_type pyfunc
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float16)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float16)
@@ -107,6 +125,9 @@ def ms_hybrid_cast_pyfunc():


def ms_hybrid_allocate():
"""
test case Custom Op with functions written in Hybrid DSL about math functions and allocate
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float16)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float16)
@@ -120,6 +141,9 @@ def ms_hybrid_allocate():


def ms_hybrid_grid():
"""
test case Custom Op with functions written in Hybrid DSL about grid
"""
np.random.seed(10)
input_x = np.random.normal(0, 1, [4, 4]).astype(np.float16)
input_y = np.random.normal(0, 1, [4, 4]).astype(np.float16)
@@ -143,10 +167,11 @@ def test_ms_hybrid_ascend_graph_mode():
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_cast_pyfunc()
ms_hybrid_allocate()
ms_hybrid_grid()


@ pytest.mark.level0
@@ -160,10 +185,11 @@ def test_ms_hybrid_ascend_pynative_mode():
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_cast_pyfunc()
ms_hybrid_allocate()
ms_hybrid_grid()


@ pytest.mark.level0
@@ -176,10 +202,11 @@ def test_ms_hybrid_gpu_graph_mode():
Expectation: the result match with numpy result
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_cast_pyfunc()
ms_hybrid_allocate()
ms_hybrid_grid()


@ pytest.mark.level0
@@ -192,7 +219,8 @@ def test_ms_hybrid_gpu_pynative_mode():
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
ms_hybrid_cast_pyfunc()
ms_hybrid_cast_with_infer()
ms_hybrid_cast_without_infer()
ms_hybrid_cast_pyfunc()
ms_hybrid_allocate()
ms_hybrid_grid()

Loading…
Cancel
Save