Browse Source

!260 fix embeddinglookupgrad when param shape is one dim

Merge pull request !260 from wuxuejian/grad_embeddinglookup_fix
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
45b6b2f4e3
2 changed files with 4 additions and 2 deletions
  1. +3
    -2
      mindspore/ops/_grad/grad_array_ops.py
  2. +1
    -0
      tests/st/ops/ascend/test_embedding_lookup.py

+ 3
- 2
mindspore/ops/_grad/grad_array_ops.py View File

@@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self):
# Reshape the 'new_indices'
new_indices_shape_changed = (size_op(new_indices),)
new_indices = reshape_op(new_indices, new_indices_shape_changed)
x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
actual_dout_shape_changed = new_indices_shape_changed
if len(x_shp) > 1:
actual_dout_shape_changed += x_shp[1:]
actual_dout = reshape_op(dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
return bprop


+ 1
- 0
tests/st/ops/ascend/test_embedding_lookup.py View File

@@ -15,6 +15,7 @@
import numpy as np


import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P


Loading…
Cancel
Save