Browse Source

!13267 Raise error for BatchNorm3D only support Ascend backend.

From: @liu_xiao_93
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
0a7ca97583
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      mindspore/nn/layer/normalization.py

+ 2
- 0
mindspore/nn/layer/normalization.py View File

@@ -66,6 +66,8 @@ class _BatchNorm(Cell):
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
else:
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
raise ValueError("NCDHW format only support in Ascend target.")
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics


Loading…
Cancel
Save