Browse Source

!1195 add validate for geswitch and merge

Merge pull request !1195 from jiangjinsheng/issue_geswitch
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
c169ac6a18
3 changed files with 16 additions and 2 deletions
  1. +14
    -0
      mindspore/ops/operations/control_ops.py
  2. +1
    -1
      mindspore/ops/operations/nn_ops.py
  3. +1
    -1
      tests/ut/python/ops/test_control_ops.py

+ 14
- 0
mindspore/ops/operations/control_ops.py View File

@@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer):
return (data, data)

def infer_dtype(self, data_type, pred_type):
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
return (data_type, data_type)

@@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer):
raise NotImplementedError

def infer_shape(self, inputs):
validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name)
input_0 = inputs[0]

for i in range(1, len(inputs)):
if inputs[i] != input_0:
raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as "
f"first input {input_0}, but got {inputs[i]}.")

return (inputs[0], [1])

def infer_dtype(self, inputs):
args = {}
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item

validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

+ 1
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2084,7 +2084,7 @@ class GetNext(PrimitiveWithInfer):
Note:
GetNext op needs to be associated with network and also depends on the init_dataset interface,
it can't be used directly as a single op.
For details, please refer to `nn.cell_wrapper.DataWrapper` source code.
For details, please refer to `nn.DataWrapper` source code.

Args:
types (list[:class:`mindspore.dtype`]): The type of the outputs.


+ 1
- 1
tests/ut/python/ops/test_control_ops.py View File

@@ -33,7 +33,7 @@ def cond_data_test(x_init, y_init):
super(Net, self).__init__()
self.square = P.Square()
self.add = P.TensorAdd()
self.value = Tensor(np.full((1), 3, dtype=np.float32))
self.value = Tensor(3, dtype=ms.float32)
self.switch = P.GeSwitch()
self.merge = P.Merge()
self.less = P.Less()


Loading…
Cancel
Save