|
|
@@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer): |
|
|
return (data, data) |
|
|
return (data, data) |
|
|
|
|
|
|
|
|
def infer_dtype(self, data_type, pred_type): |
|
|
def infer_dtype(self, data_type, pred_type): |
|
|
|
|
|
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name) |
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) |
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) |
|
|
return (data_type, data_type) |
|
|
return (data_type, data_type) |
|
|
|
|
|
|
|
|
@@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer): |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
def infer_shape(self, inputs): |
|
|
def infer_shape(self, inputs): |
|
|
|
|
|
validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name) |
|
|
|
|
|
input_0 = inputs[0] |
|
|
|
|
|
|
|
|
|
|
|
for i in range(1, len(inputs)): |
|
|
|
|
|
if inputs[i] != input_0: |
|
|
|
|
|
raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as " |
|
|
|
|
|
f"first input {input_0}, but got {inputs[i]}.") |
|
|
|
|
|
|
|
|
return (inputs[0], [1]) |
|
|
return (inputs[0], [1]) |
|
|
|
|
|
|
|
|
def infer_dtype(self, inputs): |
|
|
def infer_dtype(self, inputs): |
|
|
|
|
|
args = {} |
|
|
|
|
|
for i, item in enumerate(inputs): |
|
|
|
|
|
args['inputs[%d]' % i] = item |
|
|
|
|
|
|
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
return (inputs[0], mstype.int32) |
|
|
return (inputs[0], mstype.int32) |