
fix EmbeddingLookup grad when param shape is one dim

tags/v0.6.0-beta
wuxuejian committed 5 years ago, commit e64a53bf1b
2 changed files with 4 additions and 2 deletions:
  1. +3 -2   mindspore/ops/_grad/grad_array_ops.py
  2. +1 -0   tests/st/ops/ascend/test_embedding_lookup.py

mindspore/ops/_grad/grad_array_ops.py (+3 -2)

@@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self):
         # Reshape the 'new_indices'
         new_indices_shape_changed = (size_op(new_indices),)
         new_indices = reshape_op(new_indices, new_indices_shape_changed)
-        x_shp_tail = x_shp[1:]
-        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
+        actual_dout_shape_changed = new_indices_shape_changed
+        if len(x_shp) > 1:
+            actual_dout_shape_changed += x_shp[1:]
         actual_dout = reshape_op(dout, actual_dout_shape_changed)
         return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
     return bprop
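
The change makes the tail concatenation conditional: for a rank-1 param, x_shp[1:] is the empty tuple, and the patched code no longer builds it at all, so the gradient is reshaped to the flattened-indices shape alone. A pure-Python sketch of the resulting shape logic (the actual_dout_shape helper is hypothetical, standing in for the inline statements above):

def actual_dout_shape(x_shp, num_indices):
    # Start from the flattened-indices shape, as in the patched bprop.
    shape = (num_indices,)
    if len(x_shp) > 1:
        # Append the param's tail dims only when the param has rank > 1.
        shape += x_shp[1:]
    return shape

print(actual_dout_shape((10, 8), 6))  # rank-2 param: dout reshaped to (6, 8)
print(actual_dout_shape((10,), 6))    # rank-1 param: dout reshaped to (6,)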


tests/st/ops/ascend/test_embedding_lookup.py (+1 -0)

@@ -15,6 +15,7 @@
 import numpy as np
 
 import mindspore.context as context
 import mindspore.nn as nn
+import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore.ops import operations as P

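The test-side change only adds the mindspore.common.dtype import. A hypothetical end-to-end case for the 1-D-param path this commit fixes might look like the sketch below; the Net wrapper, shapes, offset value, and test name are assumptions, not copied from the repository's actual test:

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, offset):
        super(Net, self).__init__()
        self.embedding = P.EmbeddingLookup()
        self.offset = offset

    def construct(self, param, index):
        return self.embedding(param, index, self.offset)

def test_embedding_lookup_1d_param():
    # A 1-D param exercises the bprop branch fixed above.
    params = Tensor(np.array([8, 9, 10, 11, 12, 13, 14, 15]), mstype.float32)
    indices = Tensor(np.array([5, 2, 8, 5]), mstype.int32)
    out = Net(offset=4)(params, indices)
    print(out)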
