| @@ -219,6 +219,24 @@ def get_bprop_embedding_lookup(self): | |||||
| return bprop_sparse | return bprop_sparse | ||||
| @bprop_getters.register(P.EmbeddingLookup) | |||||
| def get_bprop_embedding_look_up(self): | |||||
| """Generate bprop for EmbeddingLookup""" | |||||
| sub_op = P.Sub() | |||||
| reshape_op = P.Reshape() | |||||
| def bprop(x, indices, offset, out, dout): | |||||
| x_shp = shape_op(x) | |||||
| new_indices = sub_op(indices, offset) | |||||
| # Reshape the 'new_indices' | |||||
| new_indices_shape_changed = (size_op(new_indices),) | |||||
| new_indices = reshape_op(new_indices, new_indices_shape_changed) | |||||
| x_shp_tail = x_shp[1:] | |||||
| actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail | |||||
| actual_dout = reshape_op(dout, actual_dout_shape_changed) | |||||
| return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) | |||||
| return bprop | |||||
| @bprop_getters.register(P.Transpose) | @bprop_getters.register(P.Transpose) | ||||
| def get_bprop_transpose(self): | def get_bprop_transpose(self): | ||||
| """Generate bprop for Transpose""" | """Generate bprop for Transpose""" | ||||