You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

slice.py 1.4 kB

1234567891011121314151617181920212223242526272829303132333435
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """generate json desc for slice"""
  16. from ._utils import Expander, ExpanderInfoValidator as VLD
  17. @VLD.check_attrs('begin', 'size')
  18. class Slice(Expander):
  19. """Slice expander"""
  20. def _expand(self, graph_builder):
  21. input_x = self.inputs[0]
  22. begin = self.attrs['begin']
  23. size = self.attrs['size']
  24. end = []
  25. strides = []
  26. for i in range(len(begin)):
  27. strides.append(1)
  28. end.append(begin[i] + size[i])
  29. output = graph_builder.tensor(size, input_x.dtype, input_x.data_format)
  30. graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides})
  31. return output