You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.ops.PrimitiveWithCheck.rst 1.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. mindspore.ops.PrimitiveWithCheck
  2. ================================
  3. .. py:class:: mindspore.ops.PrimitiveWithCheck(name)
  4. PrimitiveWithCheck是Python中原语的基类,定义了检查算子输入参数的函数,但是使用了C++源码中注册的推理方法。
  5. 可以重写三个方法来定义Primitive的检查逻辑: __check__()、check_shape()和check_dtype()。如果在Primitive中定义了__check__(),则__check__()的优先级最高。
  6. 如果未定义__check__(),则可以定义check_shape()和check_dtype()来描述形状和类型的检查逻辑。可以定义infer_value()方法(如PrimitiveWithInfer),用于常量传播。
  7. **参数:**
  8. - **name** (str) - 当前Primitive的名称。
  9. .. py:method:: check_dtype(*args)
  10. 检查输入参数的数据类型。
  11. **参数:**
  12. - **args** (:class:`mindspore.dtype`) - 输入的数据类型。
  13. **返回:**
  14. None。
  15. .. py:method:: check_shape(*args)
  16. 检查输入参数的shape。
  17. .. note::
  18. Scalar的shape是一个空元组。
  19. **参数:**
  20. - **args** (tuple(int)) - 输入tensor的shape。
  21. **返回:**
  22. None。