|
|
|
@@ -66,6 +66,8 @@ class _BatchNorm(Cell): |
|
|
|
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) |
|
|
|
else: |
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) |
|
|
|
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": |
|
|
|
raise ValueError("NCDHW format only support in Ascend target.") |
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC": |
|
|
|
raise ValueError("NHWC format only support in GPU target.") |
|
|
|
self.use_batch_statistics = use_batch_statistics |
|
|
|
|