diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 010b637f39..fb9d92f355 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -232,6 +232,8 @@ class BiasAddGrad(PrimitiveWithInfer): self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name) if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") + if self.format == "NCDHW": + self.format = "NCHW" self.add_prim_attr('data_format', self.format) def infer_shape(self, d_output):