|
|
|
@@ -84,7 +84,7 @@ class GeSwitch(PrimitiveWithInfer): |
|
|
|
the true branch will be activated, or vise verse. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **data** (Tensor) - The data to be used for switch control. |
|
|
|
- **data** (Union[Tensor, Number]) - The data to be used for switch control. |
|
|
|
- **pred** (Tensor) - It should be a scalar whose type is bool and shape is `()`, It is used as condition for |
|
|
|
switch control. |
|
|
|
Outputs: |
|
|
|
@@ -144,7 +144,7 @@ class Merge(PrimitiveWithInfer): |
|
|
|
One and only one of the inputs should be selected as the output |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **inputs** (Tuple) - The data to be merged. |
|
|
|
- **inputs** (Tuple) - The data to be merged. All tuple elements should have same data type. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element. |
|
|
|
@@ -171,6 +171,5 @@ class Merge(PrimitiveWithInfer): |
|
|
|
for i, item in enumerate(inputs): |
|
|
|
args['inputs[%d]' % i] = item |
|
|
|
|
|
|
|
validator.check_tensor_type_same( |
|
|
|
args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) |
|
|
|
return (inputs[0], mstype.int32) |