You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.ops.PrimitiveWithInfer.rst 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. mindspore.ops.PrimitiveWithInfer
  2. ================================
  3. .. py:class:: mindspore.ops.PrimitiveWithInfer(name)
  4. PrimitiveWithInfer是Python中的原语基类,在python中定义了跟踪推理的函数。
  5. 可以重写四个方法来定义Primitive的推断逻辑:__infer__()、infer_shape()、infer_dtype()和infer_value()。如果在Primitive中定义了__infer__(),则__infer__()的优先级最高。
  6. 如果未定义__infer__(),则可以定义infer_shape()和infer_dtype()来描述shape和类型的推断逻辑。infer_value()用于常量传播。
  7. **参数:**
  8. - **name** (str) - 当前Primitive的名称。
  9. .. py:method:: infer_dtype(*args)
  10. 根据输入类型推断输出类型。
  11. **参数:**
  12. - **args** (:class:`mindspore.dtype`) - 输入的数据类型。
  13. **返回:**
  14. :class:`mindspore.dtype`,输出的数据类型。
  15. .. py:method:: infer_shape(*args)
  16. 根据输入形状推断输出形状。
  17. .. note::
  18. Scalar的shape是一个空元组。
  19. **参数:**
  20. - **args** (tuple(int)) - 输入tensor的shape。
  21. **返回:**
  22. `tuple(int)`,输出tensor的shape。
  23. .. py:method:: infer_value(*args)
  24. 根据编译时的输入值推断输出值。
  25. **参数:**
  26. - **args** (Any) - 输入的值。
  27. **返回:**
  28. 输出的值。如果编译时无法推断该值,返回 `None` 。