diff --git a/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py b/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py index badbea3208..a0eabbc4bc 100644 --- a/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +++ b/mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py @@ -25,6 +25,7 @@ random_choice_with_mask_op_info = AiCPURegOp("RandomChoiceWithMask") \ .attr("seed", "int") \ .attr("seed2", "int") \ .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \ + .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \ .get_op_info() @op_info_register(random_choice_with_mask_op_info) diff --git a/mindspore/ops/_op_impl/aicpu/topk.py b/mindspore/ops/_op_impl/aicpu/topk.py index 95bffbdb8c..a68ae3557d 100644 --- a/mindspore/ops/_op_impl/aicpu/topk.py +++ b/mindspore/ops/_op_impl/aicpu/topk.py @@ -24,6 +24,7 @@ top_k_op_info = AiCPURegOp("TopK") \ .output(0, "values", "required") \ .output(1, "indices", "required") \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .get_op_info() @op_info_register(top_k_op_info)