You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

gather.py 1.9 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for gather"""
  16. from ._utils import Expander, ExpanderInfoValidator as VLD
  @VLD.check_all_formats_same
  @VLD.check_attrs('axis')
  class Gather(Expander):
      """Expand Gather.

      Indices of rank 1 are forwarded unchanged to a single ``Gather`` op.
      Indices of any other rank are flattened to a 1-D tensor with ``Reshape``
      and gathered along ``axis``. The result is then reshaped so that the
      gathered axis is replaced by the original indices shape.

      NOTE(review): the two decorators come from ``._utils``; their exact
      validation semantics are not visible here.
      """

      def _expand(self, graph_builder):
          data, idx = self.inputs
          gather_axis = self.attrs['axis']
          # Normalise a negative axis against the rank of the data input.
          if gather_axis < 0:
              gather_axis += len(data.shape)

          # Fast path: 1-D indices need no reshaping around the Gather.
          if len(idx.shape) == 1:
              return graph_builder.emit('Gather', [data, idx], attrs={'axis': gather_axis})

          idx_shape = idx.shape
          # Total element count of the indices, used as the flattened length.
          flat_len = 1
          for extent in idx_shape:
              flat_len = flat_len * extent
          flat_idx = graph_builder.emit('Reshape', [idx], attrs={'shape': [flat_len]})
          gathered = graph_builder.emit('Gather', [data, flat_idx], attrs={'axis': gather_axis})

          # Splice the indices shape in at `gather_axis`, then drop the original
          # extent of that axis, which now sits immediately after the spliced part.
          out_shape = data.shape.copy()
          out_shape[gather_axis:gather_axis] = idx_shape
          del out_shape[gather_axis + len(idx_shape)]
          return graph_builder.emit('Reshape', [gathered], attrs={'shape': out_shape})