Browse Source

!12902 【GraphKernel】Add OpInfer for op Select

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @gaoxiong1
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
2b01887371
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      mindspore/_extends/graph_kernel/model/op_infer.py

+ 11
- 0
mindspore/_extends/graph_kernel/model/op_infer.py View File

@@ -273,3 +273,14 @@ class Greater(_CompareOp):

class GreaterEqual(_CompareOp):
pass


class Select(_Elemwise):
def _check_type(self):
if self.inputs[0].dtype != "bool":
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype))
if self.inputs[1].dtype != self.inputs[2].dtype:
raise GKException("Select's input mismatch ({} vs {})".format(self.inputs[1].dtype, self.inputs[2].dtype))

def _infer_type(self):
return self.inputs[1].dtype

Loading…
Cancel
Save