Merge pull request !29911 from huangbingjian/code_docs_primitivefeature/build-system-rewrite
| @@ -75,3 +75,34 @@ Random类型算子 | |||
| mindspore.ops.Gamma | |||
| mindspore.ops.UniformReal | |||
| 原语 | |||
| ---- | |||
| .. cnmsplatformautosummary:: | |||
| :toctree: ops | |||
| mindspore.ops.constexpr | |||
| mindspore.ops.prim_attr_register | |||
| mindspore.ops.Primitive | |||
| mindspore.ops.PrimitiveWithCheck | |||
| mindspore.ops.PrimitiveWithInfer | |||
| 函数实现注册 | |||
| -------------- | |||
| .. cnmsplatformautosummary:: | |||
| :toctree: ops | |||
| mindspore.ops.get_vm_impl_fn | |||
| 算子信息注册 | |||
| -------------- | |||
| .. cnmsplatformautosummary:: | |||
| :toctree: ops | |||
| mindspore.ops.DataType | |||
| @@ -0,0 +1,173 @@ | |||
| mindspore.ops.DataType | |||
| ====================== | |||
| .. py:class:: mindspore.ops.DataType: | |||
| Ascend算子的dtype和format的多种组合。 | |||
| 当前支持: | |||
| .. code-block:: | |||
| None_None = ("", "") | |||
| None_Default = ("", "DefaultFormat") | |||
| BOOL_None = ("bool", "") | |||
| BOOL_Default = ("bool", "DefaultFormat") | |||
| BOOL_5HD = ("bool", "NC1HWC0") | |||
| BOOL_FracZ = ("bool", "FracZ") | |||
| BOOL_FracNZ = ("bool", "FRACTAL_NZ") | |||
| BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0") | |||
| BOOL_NCHW = ("bool", "NCHW") | |||
| BOOL_NHWC = ("bool", "NHWC") | |||
| BOOL_HWCN = ("bool", "HWCN") | |||
| BOOL_NDHWC = ("bool", "NDHWC") | |||
| BOOL_ChannelLast = ("bool", "ChannelLast") | |||
| I8_None = ("int8", "") | |||
| I8_Default = ("int8", "DefaultFormat") | |||
| I8_5HD = ("int8", "NC1HWC0") | |||
| I8_FracZ = ("int8", "FracZ") | |||
| I8_FracNZ = ("int8", "FRACTAL_NZ") | |||
| I8_C1HWNCoC0 = ("int8", "C1HWNCoC0") | |||
| I8_NCHW = ("int8", "NCHW") | |||
| I8_NHWC = ("int8", "NHWC") | |||
| I8_HWCN = ("int8", "HWCN") | |||
| I8_NDHWC = ("int8", "NDHWC") | |||
| I8_ChannelLast = ("int8", "ChannelLast") | |||
| U8_None = ("uint8", "") | |||
| U8_Default = ("uint8", "DefaultFormat") | |||
| U8_5HD = ("uint8", "NC1HWC0") | |||
| U8_FracZ = ("uint8", "FracZ") | |||
| U8_FracNZ = ("uint8", "FRACTAL_NZ") | |||
| U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0") | |||
| U8_NCHW = ("uint8", "NCHW") | |||
| U8_NHWC = ("uint8", "NHWC") | |||
| U8_HWCN = ("uint8", "HWCN") | |||
| U8_NDHWC = ("uint8", "NDHWC") | |||
| U8_ChannelLast = ("uint8", "ChannelLast") | |||
| I16_None = ("int16", "") | |||
| I16_Default = ("int16", "DefaultFormat") | |||
| I16_5HD = ("int16", "NC1HWC0") | |||
| I16_FracZ = ("int16", "FracZ") | |||
| I16_FracNZ = ("int16", "FRACTAL_NZ") | |||
| I16_C1HWNCoC0 = ("int16", "C1HWNCoC0") | |||
| I16_NCHW = ("int16", "NCHW") | |||
| I16_NHWC = ("int16", "NHWC") | |||
| I16_HWCN = ("int16", "HWCN") | |||
| I16_NDHWC = ("int16", "NDHWC") | |||
| I16_ChannelLast = ("int16", "ChannelLast") | |||
| U16_None = ("uint16", "") | |||
| U16_Default = ("uint16", "DefaultFormat") | |||
| U16_5HD = ("uint16", "NC1HWC0") | |||
| U16_FracZ = ("uint16", "FracZ") | |||
| U16_FracNZ = ("uint16", "FRACTAL_NZ") | |||
| U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0") | |||
| U16_NCHW = ("uint16", "NCHW") | |||
| U16_NHWC = ("uint16", "NHWC") | |||
| U16_HWCN = ("uint16", "HWCN") | |||
| U16_NDHWC = ("uint16", "NDHWC") | |||
| U16_ChannelLast = ("uint16", "ChannelLast") | |||
| I32_None = ("int32", "") | |||
| I32_Default = ("int32", "DefaultFormat") | |||
| I32_5HD = ("int32", "NC1HWC0") | |||
| I32_FracZ = ("int32", "FracZ") | |||
| I32_FracNZ = ("int32", "FRACTAL_NZ") | |||
| I32_C1HWNCoC0 = ("int32", "C1HWNCoC0") | |||
| I32_NCHW = ("int32", "NCHW") | |||
| I32_NHWC = ("int32", "NHWC") | |||
| I32_HWCN = ("int32", "HWCN") | |||
| I32_NDHWC = ("int32", "NDHWC") | |||
| I32_ChannelLast = ("int32", "ChannelLast") | |||
| U32_None = ("uint32", "") | |||
| U32_Default = ("uint32", "DefaultFormat") | |||
| U32_5HD = ("uint32", "NC1HWC0") | |||
| U32_FracZ = ("uint32", "FracZ") | |||
| U32_FracNZ = ("uint32", "FRACTAL_NZ") | |||
| U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0") | |||
| U32_NCHW = ("uint32", "NCHW") | |||
| U32_NHWC = ("uint32", "NHWC") | |||
| U32_HWCN = ("uint32", "HWCN") | |||
| U32_NDHWC = ("uint32", "NDHWC") | |||
| U32_ChannelLast = ("uint32", "ChannelLast") | |||
| I64_None = ("int64", "") | |||
| I64_Default = ("int64", "DefaultFormat") | |||
| I64_5HD = ("int64", "NC1HWC0") | |||
| I64_FracZ = ("int64", "FracZ") | |||
| I64_FracNZ = ("int64", "FRACTAL_NZ") | |||
| I64_C1HWNCoC0 = ("int64", "C1HWNCoC0") | |||
| I64_NCHW = ("int64", "NCHW") | |||
| I64_NHWC = ("int64", "NHWC") | |||
| I64_HWCN = ("int64", "HWCN") | |||
| I64_NDHWC = ("int64", "NDHWC") | |||
| I64_ChannelLast = ("int64", "ChannelLast") | |||
| U64_None = ("uint64", "") | |||
| U64_Default = ("uint64", "DefaultFormat") | |||
| U64_5HD = ("uint64", "NC1HWC0") | |||
| U64_FracZ = ("uint64", "FracZ") | |||
| U64_FracNZ = ("uint64", "FRACTAL_NZ") | |||
| U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0") | |||
| U64_NCHW = ("uint64", "NCHW") | |||
| U64_NHWC = ("uint64", "NHWC") | |||
| U64_HWCN = ("uint64", "HWCN") | |||
| U64_NDHWC = ("uint64", "NDHWC") | |||
| U64_ChannelLast = ("uint64", "ChannelLast") | |||
| F16_None = ("float16", "") | |||
| F16_Default = ("float16", "DefaultFormat") | |||
| F16_5HD = ("float16", "NC1HWC0") | |||
| F16_FracZ = ("float16", "FracZ") | |||
| F16_FracNZ = ("float16", "FRACTAL_NZ") | |||
| F16_C1HWNCoC0 = ("float16", "C1HWNCoC0") | |||
| F16_NCHW = ("float16", "NCHW") | |||
| F16_NHWC = ("float16", "NHWC") | |||
| F16_HWCN = ("float16", "HWCN") | |||
| F16_NDHWC = ("float16", "NDHWC") | |||
| F16_NCDHW = ("float16", "NCDHW") | |||
| F16_DHWCN = ("float16", "DHWCN") | |||
| F16_NDC1HWC0 = ("float16", "NDC1HWC0") | |||
| F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D") | |||
| F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM") | |||
| F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN") | |||
| F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS") | |||
| F16_ChannelLast = ("float16", "ChannelLast") | |||
| F32_None = ("float32", "") | |||
| F32_Default = ("float32", "DefaultFormat") | |||
| F32_5HD = ("float32", "NC1HWC0") | |||
| F32_FracZ = ("float32", "FracZ") | |||
| F32_FracNZ = ("float32", "FRACTAL_NZ") | |||
| F32_C1HWNCoC0 = ("float32", "C1HWNCoC0") | |||
| F32_NCHW = ("float32", "NCHW") | |||
| F32_NHWC = ("float32", "NHWC") | |||
| F32_HWCN = ("float32", "HWCN") | |||
| F32_NDHWC = ("float32", "NDHWC") | |||
| F32_NCDHW = ("float32", "NCDHW") | |||
| F32_DHWCN = ("float32", "DHWCN") | |||
| F32_NDC1HWC0 = ("float32", "NDC1HWC0") | |||
| F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D") | |||
| F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM") | |||
| F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN") | |||
| F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS") | |||
| F32_ChannelLast = ("float32", "ChannelLast") | |||
| F64_None = ("float64", "") | |||
| F64_Default = ("float64", "DefaultFormat") | |||
| F64_5HD = ("float64", "NC1HWC0") | |||
| F64_FracZ = ("float64", "FracZ") | |||
| F64_FracNZ = ("float64", "FRACTAL_NZ") | |||
| F64_C1HWNCoC0 = ("float64", "C1HWNCoC0") | |||
| F64_NCHW = ("float64", "NCHW") | |||
| F64_NHWC = ("float64", "NHWC") | |||
| F64_HWCN = ("float64", "HWCN") | |||
| F64_NDHWC = ("float64", "NDHWC") | |||
| F64_ChannelLast = ("float64", "ChannelLast") | |||
| C64_Default = ("complex64", "DefaultFormat") | |||
| C128_Default = ("complex128", "DefaultFormat") | |||
| @@ -0,0 +1,100 @@ | |||
| mindspore.ops.Primitive | |||
| ======================= | |||
| .. py:class:: mindspore.Primitive(name) | |||
| Primitive是Python中算子原语的基类。 | |||
| **参数:** | |||
| - **name** (str) - 当前Primitive的名称。 | |||
| .. py:method:: add_prim_attr(name, value) | |||
| 添加Primitive的属性。 | |||
| **参数:** | |||
| - **name** (str) - 属性名称。 | |||
| - **value** (Any) - 属性值。 | |||
| .. py:method:: del_prim_attr(name) | |||
| 删除Primitive的属性。 | |||
| **参数:** | |||
| - **name** (str) - 属性名称。 | |||
| .. py:method:: check_elim(*args) | |||
| 检查是否可以消除此Primitive。有需要的子类可以重写该方法。 | |||
| **参数:** | |||
| - **args** (Primitive参数的类型) - 与当前Primitive的参数相同。 | |||
| **返回:** | |||
| 由两个元素组成的元组。第一个元素是指是否能在编译阶段计算Primitive,第二个元素是计算结果。 | |||
| .. py:method:: init_prim_io_names(inputs, outputs) | |||
| 初始化Tensor或属性的输入输出的名称。 | |||
| **参数:** | |||
| - **inputs** (list[str]) - 输入名称的列表。 | |||
| - **outputs** (list[str]) - 输出名称的列表。 | |||
| .. py:method:: recompute(mode=True) | |||
| 设置Primitive的重计算属性。 | |||
| 如果有一个被设置了重计算属性的Primitive,并且其结果在计算导数的时候被使用,那么不会保存该Primitive在前向网络中的中间计算结果,而是在自动微分的时候重新进行计算。 | |||
| .. note:: | |||
| - 如果计算涉及随机化或全局变量,则暂无法保证等效性。 | |||
| - 在PyNative模式下不支持。 | |||
| **参数:** | |||
| - **mode** (bool) - Primitive是否设置了重计算。默认值:True。 | |||
| .. py:method:: set_prim_instance_name(instance_name) | |||
| 设置Primitive算子的实例的名称。 | |||
| .. note:: | |||
| 当用户定义Primitive算子时,默认调用它。 | |||
| **参数:** | |||
| - **instance_name** (str) - 用户设置的Primitive算子的实例的名称。 | |||
| .. py:method:: set_stage(stage) | |||
| 将stage的ID添加到Primitive属性中。 | |||
| .. note:: | |||
| 仅在半自动并行模式下有效。在其他并行模式下,请将其设置为0。 | |||
| **参数:** | |||
| - **stage** (int) - 当前stage的ID。 | |||
| .. py:method:: shard(in_strategy, out_strategy) | |||
| 将切分策略添加到Primitive属性中。 | |||
| .. note:: | |||
| 仅在半自动并行或自动并行模式下有效。在其他并行模式中,将忽略此处设置的策略。 | |||
| **参数:** | |||
| - **in_strategy** (tuple) - 描述算子输入的切分策略。默认值:None。 | |||
| - **out_strategy** (tuple) - 描述算子输出的切分策略,仅针对某些算子,如MatMul。默认值:None。 | |||
| .. py:method:: update_parameter() | |||
| 判断此Primitive是否会更新参数的值。 | |||
| @@ -0,0 +1,41 @@ | |||
| mindspore.ops.PrimitiveWithCheck | |||
| ================================ | |||
| .. py:class:: mindspore.PrimitiveWithCheck(name) | |||
| PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。 | |||
| 可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。 | |||
| 如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。 | |||
| **参数:** | |||
| - **name** (str) - 当前Primitive的名称。 | |||
| .. py:method:: check_dtype(*args) | |||
| 检查输入参数的数据类型。 | |||
| **参数:** | |||
| - **args** (:class:`mindspore.dtype`) - 输入的数据类型。 | |||
| **返回:** | |||
| None。 | |||
| .. py:method:: check_shape(*args) | |||
| 检查输入参数的shape。 | |||
| .. note:: | |||
| Scalar的shape是一个空元组。 | |||
| **参数:** | |||
| - **args** (tuple(int)) - 输入tensor的shape。 | |||
| **返回:** | |||
| None。 | |||
| @@ -0,0 +1,53 @@ | |||
| mindspore.ops.PrimitiveWithInfer | |||
| ================================ | |||
| .. py:class:: mindspore.PrimitiveWithInfer(name) | |||
| PrimitiveWithInfer是Python中的原语基类,在python中定义了跟踪推理的函数。 | |||
| 可以重写四个方法来定义Primitive的推断逻辑:__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__(),则__infer__()的优先级最高。 | |||
| 如果未定义__infer__(),则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。 | |||
| **参数:** | |||
| - **name** (str) - 当前Primitive的名称。 | |||
| .. py:method:: infer_dtype(*args) | |||
| 根据输入类型推断输出类型。 | |||
| **参数:** | |||
| - **args** (:class:`mindspore.dtype`) - 输入的数据类型。 | |||
| **返回:** | |||
| :class:`mindspore.dtype`,输出的数据类型。 | |||
| .. py:method:: infer_shape(*args) | |||
| 根据输入形状推断输出形状。 | |||
| .. note:: | |||
| Scalar的shape是一个空元组。 | |||
| **参数:** | |||
| - **args** (tuple(int)) - 输入tensor的shape。 | |||
| **返回:** | |||
| `tuple(int)`,输出tensor的shape。 | |||
| .. py:method:: infer_value(*args) | |||
| 根据编译时的输入值推断输出值。 | |||
| **参数:** | |||
| - **args** (Any) - 输入的值。 | |||
| **返回:** | |||
| 输出的值。如果编译时无法推断该值,返回`None`。 | |||
| @@ -0,0 +1,12 @@ | |||
| mindspore.ops.constexpr | |||
| ======================= | |||
| .. py:function:: mindspore.ops.constexpr(fn=None, get_instance=True, name=None): | |||
| 创建PrimiveWithInfer算子,用于在编译时推断值。可以用它定义函数,从而使用构造函数中的常量计算出常量值。 | |||
| **参数:** | |||
| - **fn** (function) - `fn` 用作输出算子的infer_value。默认值:None。 | |||
| - **get_instance** (bool) - 如果为True,返回算子的实例,否则返回算子的类。默认值:True。 | |||
| - **name** (str) - 定义算子的名称。如果 `name` 为None,则使用函数名称作为算子名称。默认值:None。 | |||
| @@ -0,0 +1,17 @@ | |||
| mindspore.ops.get_vm_impl_fn | |||
| ============================ | |||
| .. py:function:: mindspore.ops.get_vm_impl_fn(prim): | |||
| 通过Primitive对象或Primitive名称,获取虚拟实现函数。 | |||
| **参数:** | |||
| - **prim** (Union[Primitive, str]) - 算子注册的Primitive对象或名称。 | |||
| .. note:: | |||
| 该机制目前适用于调试。 | |||
| **返回:** | |||
| 函数,虚拟实现函数。 | |||
| @@ -0,0 +1,16 @@ | |||
| mindspore.ops.prim_attr_register | |||
| ================================ | |||
| .. py:function:: mindspore.ops.prim_attr_register(fn): | |||
| Primitive属性的注册器。 | |||
| 注册装饰器,其中装饰器用于内置算子的Primitive的'__init__'函数。该函数将添加'__init__'的所有参数作为算子属性,并且初始化Primitive的名称。 | |||
| **参数:** | |||
| - **fn** (function) - Primitive的__init__函数。 | |||
| **返回:** | |||
| 函数,原始函数。 | |||
| @@ -764,10 +764,6 @@ class DataType: | |||
| r""" | |||
| Various combinations of dtype and format of Ascend ops. | |||
| The current list below may be incomplete. | |||
| Please add it if necessary. | |||
| current support: | |||
| .. code-block:: | |||
| @@ -933,6 +929,9 @@ class DataType: | |||
| F64_HWCN = ("float64", "HWCN") | |||
| F64_NDHWC = ("float64", "NDHWC") | |||
| F64_ChannelLast = ("float64", "ChannelLast") | |||
| C64_Default = ("complex64", "DefaultFormat") | |||
| C128_Default = ("complex128", "DefaultFormat") | |||
| """ | |||
| None_None = ("", "") | |||
| @@ -381,8 +381,8 @@ class Primitive(Primitive_): | |||
| class PrimitiveWithCheck(Primitive): | |||
| """ | |||
| PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator | |||
| input arguments but used the infer method registered in c++ source codes. | |||
| PrimitiveWithCheck is the base class of primitives in python, which defines functions to check the input arguments | |||
| of operators, but uses the infer method registered in c++ source codes. | |||
| There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(), | |||
| check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called. | |||