|
|
|
@@ -69,6 +69,8 @@ class ControlDepend(Primitive): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, depend_mode=0): |
|
|
|
"""init""" |
|
|
|
validator.check_int_range( |
|
|
|
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name) |
|
|
|
|
|
|
|
def __call__(self, src, dst): |
|
|
|
return src |
|
|
|
@@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer): |
|
|
|
return (data, data) |
|
|
|
|
|
|
|
def infer_dtype(self, data_type, pred_type): |
|
|
|
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name) |
|
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) |
|
|
|
validator.check_subclass( |
|
|
|
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name) |
|
|
|
validator.check_tensor_type_same( |
|
|
|
{"pred": pred_type}, [mstype.bool_], self.name) |
|
|
|
return (data_type, data_type) |
|
|
|
|
|
|
|
|
|
|
|
@@ -167,5 +171,6 @@ class Merge(PrimitiveWithInfer): |
|
|
|
for i, item in enumerate(inputs): |
|
|
|
args['inputs[%d]' % i] = item |
|
|
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
validator.check_tensor_type_same( |
|
|
|
args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
return (inputs[0], mstype.int32) |