Browse Source

fix output of BiasAddGrad.

pull/15245/head
liuxiao93 4 years ago
parent
commit
bd01a358f8
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      mindspore/ops/operations/_grad_ops.py

+ 2
- 0
mindspore/ops/operations/_grad_ops.py View File

@@ -232,6 +232,8 @@ class BiasAddGrad(PrimitiveWithInfer):
self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
if self.format == "NCDHW":
self.format = "NCHW"
self.add_prim_attr('data_format', self.format)

def infer_shape(self, d_output):


Loading…
Cancel
Save