Browse Source

!801 fix checking bug of prelu

Merge pull request !801 from fary86/fix_checking_bug_of_prelu
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
eeb8e4d4d3
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      mindspore/ops/operations/nn_ops.py

+ 3
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer):
return input_x_shape

def infer_dtype(self, input_x_dtype, weight_dtype):
args = {"input_x": input_x_dtype, "weight": weight_dtype}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name)
return input_x_dtype




Loading…
Cancel
Save