Browse Source

add topk and randomchoicewithmask data type for aicpu

tags/v0.5.0-beta
yanzhenxiang2020 5 years ago
parent
commit
587091dbe3
2 changed files with 2 additions and 0 deletions
  1. +1
    -0
      mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py
  2. +1
    -0
      mindspore/ops/_op_impl/aicpu/topk.py

+ 1
- 0
mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py View File

@@ -25,6 +25,7 @@ random_choice_with_mask_op_info = AiCPURegOp("RandomChoiceWithMask") \
.attr("seed", "int") \
.attr("seed2", "int") \
.dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
.get_op_info()

@op_info_register(random_choice_with_mask_op_info)


+ 1
- 0
mindspore/ops/_op_impl/aicpu/topk.py View File

@@ -24,6 +24,7 @@ top_k_op_info = AiCPURegOp("TopK") \
.output(0, "values", "required") \
.output(1, "indices", "required") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()

@op_info_register(top_k_op_info)


Loading…
Cancel
Save