diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 84fd265c75..613367d613 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -66,6 +66,8 @@ class _BatchNorm(Cell): self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) else: self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) + if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": + raise ValueError("NCDHW format only support in Ascend target.") if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") self.use_batch_statistics = use_batch_statistics