Browse Source

!29911 Add chinese document for primitive, ...

Merge pull request !29911 from huangbingjian/code_docs_primitive
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
bd8867d69b
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 448 additions and 6 deletions
  1. +31
    -0
      docs/api/api_python/mindspore.ops.rst
  2. +173
    -0
      docs/api/api_python/ops/mindspore.ops.DataType.rst
  3. +100
    -0
      docs/api/api_python/ops/mindspore.ops.Primitive.rst
  4. +41
    -0
      docs/api/api_python/ops/mindspore.ops.PrimitiveWithCheck.rst
  5. +53
    -0
      docs/api/api_python/ops/mindspore.ops.PrimitiveWithInfer.rst
  6. +12
    -0
      docs/api/api_python/ops/mindspore.ops.constexpr.rst
  7. +17
    -0
      docs/api/api_python/ops/mindspore.ops.get_vm_impl_fn.rst
  8. +16
    -0
      docs/api/api_python/ops/mindspore.ops.prim_attr_register.rst
  9. +3
    -4
      mindspore/python/mindspore/ops/op_info_register.py
  10. +2
    -2
      mindspore/python/mindspore/ops/primitive.py

+ 31
- 0
docs/api/api_python/mindspore.ops.rst View File

@@ -75,3 +75,34 @@ Random类型算子

mindspore.ops.Gamma
mindspore.ops.UniformReal


原语
----

.. cnmsplatformautosummary::
:toctree: ops

mindspore.ops.constexpr
mindspore.ops.prim_attr_register
mindspore.ops.Primitive
mindspore.ops.PrimitiveWithCheck
mindspore.ops.PrimitiveWithInfer


函数实现注册
--------------

.. cnmsplatformautosummary::
:toctree: ops

mindspore.ops.get_vm_impl_fn


算子信息注册
--------------

.. cnmsplatformautosummary::
:toctree: ops

mindspore.ops.DataType

+ 173
- 0
docs/api/api_python/ops/mindspore.ops.DataType.rst View File

@@ -0,0 +1,173 @@
mindspore.ops.DataType
======================

.. py:class:: mindspore.ops.DataType:

Ascend算子的dtype和format的多种组合。

当前支持:

.. code-block::

None_None = ("", "")
None_Default = ("", "DefaultFormat")
BOOL_None = ("bool", "")
BOOL_Default = ("bool", "DefaultFormat")
BOOL_5HD = ("bool", "NC1HWC0")
BOOL_FracZ = ("bool", "FracZ")
BOOL_FracNZ = ("bool", "FRACTAL_NZ")
BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
BOOL_ChannelLast = ("bool", "ChannelLast")

I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
I8_5HD = ("int8", "NC1HWC0")
I8_FracZ = ("int8", "FracZ")
I8_FracNZ = ("int8", "FRACTAL_NZ")
I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
I8_ChannelLast = ("int8", "ChannelLast")

U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
U8_5HD = ("uint8", "NC1HWC0")
U8_FracZ = ("uint8", "FracZ")
U8_FracNZ = ("uint8", "FRACTAL_NZ")
U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
U8_ChannelLast = ("uint8", "ChannelLast")

I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
I16_5HD = ("int16", "NC1HWC0")
I16_FracZ = ("int16", "FracZ")
I16_FracNZ = ("int16", "FRACTAL_NZ")
I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
I16_ChannelLast = ("int16", "ChannelLast")

U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
U16_5HD = ("uint16", "NC1HWC0")
U16_FracZ = ("uint16", "FracZ")
U16_FracNZ = ("uint16", "FRACTAL_NZ")
U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
U16_ChannelLast = ("uint16", "ChannelLast")

I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
I32_5HD = ("int32", "NC1HWC0")
I32_FracZ = ("int32", "FracZ")
I32_FracNZ = ("int32", "FRACTAL_NZ")
I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
I32_ChannelLast = ("int32", "ChannelLast")

U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
U32_5HD = ("uint32", "NC1HWC0")
U32_FracZ = ("uint32", "FracZ")
U32_FracNZ = ("uint32", "FRACTAL_NZ")
U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
U32_ChannelLast = ("uint32", "ChannelLast")

I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
I64_5HD = ("int64", "NC1HWC0")
I64_FracZ = ("int64", "FracZ")
I64_FracNZ = ("int64", "FRACTAL_NZ")
I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
I64_ChannelLast = ("int64", "ChannelLast")

U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
U64_5HD = ("uint64", "NC1HWC0")
U64_FracZ = ("uint64", "FracZ")
U64_FracNZ = ("uint64", "FRACTAL_NZ")
U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
U64_ChannelLast = ("uint64", "ChannelLast")

F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
F16_5HD = ("float16", "NC1HWC0")
F16_FracZ = ("float16", "FracZ")
F16_FracNZ = ("float16", "FRACTAL_NZ")
F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F16_NCDHW = ("float16", "NCDHW")
F16_DHWCN = ("float16", "DHWCN")
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
F16_ChannelLast = ("float16", "ChannelLast")

F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
F32_5HD = ("float32", "NC1HWC0")
F32_FracZ = ("float32", "FracZ")
F32_FracNZ = ("float32", "FRACTAL_NZ")
F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F32_NCDHW = ("float32", "NCDHW")
F32_DHWCN = ("float32", "DHWCN")
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
F32_ChannelLast = ("float32", "ChannelLast")

F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast")

C64_Default = ("complex64", "DefaultFormat")
C128_Default = ("complex128", "DefaultFormat")

+ 100
- 0
docs/api/api_python/ops/mindspore.ops.Primitive.rst View File

@@ -0,0 +1,100 @@
mindspore.ops.Primitive
=======================

.. py:class:: mindspore.Primitive(name)

Primitive是Python中算子原语的基类。

**参数:**

- **name** (str) - 当前Primitive的名称。

.. py:method:: add_prim_attr(name, value)

添加Primitive的属性。

**参数:**

- **name** (str) - 属性名称。
- **value** (Any) - 属性值。

.. py:method:: del_prim_attr(name)

删除Primitive的属性。

**参数:**

- **name** (str) - 属性名称。

.. py:method:: check_elim(*args)

检查是否可以消除此Primitive。有需要的子类可以重写该方法。

**参数:**

- **args** (Primitive参数的类型) - 与当前Primitive的参数相同。

**返回:**

由两个元素组成的元组。第一个元素是指是否能在编译阶段计算Primitive,第二个元素是计算结果。

.. py:method:: init_prim_io_names(inputs, outputs)

初始化Tensor或属性的输入输出的名称。

**参数:**

- **inputs** (list[str]) - 输入名称的列表。
- **outputs** (list[str]) - 输出名称的列表。

.. py:method:: recompute(mode=True)

设置Primitive的重计算属性。

如果有一个被设置了重计算属性的Primitive,并且其结果在计算导数的时候被使用,那么不会保存该Primitive在前向网络中的中间计算结果,而是在自动微分的时候重新进行计算。

.. note::
- 如果计算涉及随机化或全局变量,则暂无法保证等效性。
- 在PyNative模式下不支持。

**参数:**

- **mode** (bool) - Primitive是否设置了重计算。默认值:True。

.. py:method:: set_prim_instance_name(instance_name)

设置Primitive算子的实例的名称。

.. note::
当用户定义Primitive算子时,默认调用它。

**参数:**

- **instance_name** (str) - 用户设置的Primitive算子的实例的名称。

.. py:method:: set_stage(stage)

将stage的ID添加到Primitive属性中。

.. note::
仅在半自动并行模式下有效。在其他并行模式下,请将其设置为0。

**参数:**

- **stage** (int) - 当前stage的ID。

.. py:method:: shard(in_strategy, out_strategy)

将切分策略添加到Primitive属性中。

.. note::
仅在半自动并行或自动并行模式下有效。在其他并行模式中,将忽略此处设置的策略。

**参数:**

- **in_strategy** (tuple) - 描述算子输入的切分策略。默认值:None。
- **out_strategy** (tuple) - 描述算子输出的切分策略,仅针对某些算子,如MatMul。默认值:None。

.. py:method:: update_parameter()

判断此Primitive是否会更新参数的值。

+ 41
- 0
docs/api/api_python/ops/mindspore.ops.PrimitiveWithCheck.rst View File

@@ -0,0 +1,41 @@
mindspore.ops.PrimitiveWithCheck
================================

.. py:class:: mindspore.PrimitiveWithCheck(name)

PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。

可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。
如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。

**参数:**

- **name** (str) - 当前Primitive的名称。

.. py:method:: check_dtype(*args)

检查输入参数的数据类型。

**参数:**

- **args** (:class:`mindspore.dtype`) - 输入的数据类型。

**返回:**

None。

.. py:method:: check_shape(*args)

检查输入参数的shape。

.. note::
Scalar的shape是一个空元组。

**参数:**

- **args** (tuple(int)) - 输入tensor的shape。

**返回:**

None。

+ 53
- 0
docs/api/api_python/ops/mindspore.ops.PrimitiveWithInfer.rst View File

@@ -0,0 +1,53 @@
mindspore.ops.PrimitiveWithInfer
================================

.. py:class:: mindspore.PrimitiveWithInfer(name)

PrimitiveWithInfer是Python中的原语基类,在python中定义了跟踪推理的函数。

可以重写四个方法来定义Primitive的推断逻辑:__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__(),则__infer__()的优先级最高。

如果未定义__infer__(),则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。

**参数:**

- **name** (str) - 当前Primitive的名称。

.. py:method:: infer_dtype(*args)

根据输入类型推断输出类型。

**参数:**

- **args** (:class:`mindspore.dtype`) - 输入的数据类型。

**返回:**

:class:`mindspore.dtype`,输出的数据类型。

.. py:method:: infer_shape(*args)

根据输入形状推断输出形状。

.. note::
Scalar的shape是一个空元组。

**参数:**

- **args** (tuple(int)) - 输入tensor的shape。

**返回:**

`tuple(int)`,输出tensor的shape。

.. py:method:: infer_value(*args)

根据编译时的输入值推断输出值。
**参数:**

- **args** (Any) - 输入的值。

**返回:**

输出的值。如果编译时无法推断该值,返回`None`。

+ 12
- 0
docs/api/api_python/ops/mindspore.ops.constexpr.rst View File

@@ -0,0 +1,12 @@
mindspore.ops.constexpr
=======================

.. py:function:: mindspore.ops.constexpr(fn=None, get_instance=True, name=None):

创建PrimiveWithInfer算子,用于在编译时推断值。可以用它定义函数,从而使用构造函数中的常量计算出常量值。

**参数:**

- **fn** (function) - `fn` 用作输出算子的infer_value。默认值:None。
- **get_instance** (bool) - 如果为True,返回算子的实例,否则返回算子的类。默认值:True。
- **name** (str) - 定义算子的名称。如果 `name` 为None,则使用函数名称作为算子名称。默认值:None。

+ 17
- 0
docs/api/api_python/ops/mindspore.ops.get_vm_impl_fn.rst View File

@@ -0,0 +1,17 @@
mindspore.ops.get_vm_impl_fn
============================

.. py:function:: mindspore.ops.get_vm_impl_fn(prim):

通过Primitive对象或Primitive名称,获取虚拟实现函数。

**参数:**

- **prim** (Union[Primitive, str]) - 算子注册的Primitive对象或名称。

.. note::
该机制目前适用于调试。

**返回:**

函数,虚拟实现函数。

+ 16
- 0
docs/api/api_python/ops/mindspore.ops.prim_attr_register.rst View File

@@ -0,0 +1,16 @@
mindspore.ops.prim_attr_register
================================

.. py:function:: mindspore.ops.prim_attr_register(fn):

Primitive属性的注册器。

注册装饰器,其中装饰器用于内置算子的Primitive的'__init__'函数。该函数将添加'__init__'的所有参数作为算子属性,并且初始化Primitive的名称。

**参数:**

- **fn** (function) - Primitive的__init__函数。

**返回:**

函数,原始函数。

+ 3
- 4
mindspore/python/mindspore/ops/op_info_register.py View File

@@ -764,10 +764,6 @@ class DataType:
r"""
Various combinations of dtype and format of Ascend ops.

The current list below may be incomplete.

Please add it if necessary.

current support:

.. code-block::
@@ -933,6 +929,9 @@ class DataType:
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")
F64_ChannelLast = ("float64", "ChannelLast")

C64_Default = ("complex64", "DefaultFormat")
C128_Default = ("complex128", "DefaultFormat")
"""

None_None = ("", "")


+ 2
- 2
mindspore/python/mindspore/ops/primitive.py View File

@@ -381,8 +381,8 @@ class Primitive(Primitive_):

class PrimitiveWithCheck(Primitive):
"""
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator
input arguments but used the infer method registered in c++ source codes.
PrimitiveWithCheck is the base class of primitives in python, which defines functions to check the input arguments
of operators, but uses the infer method registered in c++ source codes.

There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.


Loading…
Cancel
Save