Browse Source

fix embeddinglookupgrad when param shape is one dim

tags/v0.6.0-beta
wuxuejian 5 years ago
parent
commit
e64a53bf1b
2 changed files with 4 additions and 2 deletions
  1. +3
    -2
      mindspore/ops/_grad/grad_array_ops.py
  2. +1
    -0
      tests/st/ops/ascend/test_embedding_lookup.py

+ 3
- 2
mindspore/ops/_grad/grad_array_ops.py View File

@@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self):
# Reshape the 'new_indices' # Reshape the 'new_indices'
new_indices_shape_changed = (size_op(new_indices),) new_indices_shape_changed = (size_op(new_indices),)
new_indices = reshape_op(new_indices, new_indices_shape_changed) new_indices = reshape_op(new_indices, new_indices_shape_changed)
x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
actual_dout_shape_changed = new_indices_shape_changed
if len(x_shp) > 1:
actual_dout_shape_changed += x_shp[1:]
actual_dout = reshape_op(dout, actual_dout_shape_changed) actual_dout = reshape_op(dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
return bprop return bprop


+ 1
- 0
tests/st/ops/ascend/test_embedding_lookup.py View File

@@ -15,6 +15,7 @@
import numpy as np import numpy as np


import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P


Loading…
Cancel
Save