Browse Source

!1304 Add check for controldepend

Merge pull request !1304 from amongo/AddCheckForControlDepend
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
bddd743ca9
2 changed files with 17 additions and 3 deletions
  1. +8
    -3
      mindspore/ops/operations/control_ops.py
  2. +9
    -0
      tests/ut/python/ops/test_control_ops.py

+ 8
- 3
mindspore/ops/operations/control_ops.py View File

@@ -69,6 +69,8 @@ class ControlDepend(Primitive):
@prim_attr_register
def __init__(self, depend_mode=0):
"""init"""
validator.check_int_range(
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)

def __call__(self, src, dst):
return src
@@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer):
return (data, data)

def infer_dtype(self, data_type, pred_type):
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
{"pred": pred_type}, [mstype.bool_], self.name)
return (data_type, data_type)


@@ -167,5 +171,6 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item

validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

+ 9
- 0
tests/ut/python/ops/test_control_ops.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test control ops """
import pytest
import numpy as np

import mindspore as ms
@@ -434,3 +435,11 @@ def test_index_to_switch_layer():
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))

def test_control_depend_check():
with pytest.raises(TypeError) as e:
depend = P.ControlDepend(0.0)
with pytest.raises(ValueError) as e:
depend = P.ControlDepend(2)
with pytest.raises(TypeError) as e:
depend = P.ControlDepend((2,))

Loading…
Cancel
Save