|
|
|
@@ -273,3 +273,14 @@ class Greater(_CompareOp): |
|
|
|
|
|
|
|
class GreaterEqual(_CompareOp): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class Select(_Elemwise): |
|
|
|
def _check_type(self): |
|
|
|
if self.inputs[0].dtype != "bool": |
|
|
|
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype)) |
|
|
|
if self.inputs[1].dtype != self.inputs[2].dtype: |
|
|
|
raise GKException("Select's input mismatch ({} vs {})".format(self.inputs[1].dtype, self.inputs[2].dtype)) |
|
|
|
|
|
|
|
def _infer_type(self): |
|
|
|
return self.inputs[1].dtype |