You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transpose_cpu_kernel.cc 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/cpu/transpose_cpu_kernel.h"
  17. #include "device/cpu/cpu_device_address.h"
  18. namespace mindspore {
  19. namespace kernel {
  20. const size_t kMaxDim = 100;
  21. void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
  22. MS_EXCEPTION_IF_NULL(kernel_node);
  23. shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  24. axis_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "perm");
  25. if (shape_.size() != axis_.size()) {
  26. MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal.";
  27. }
  28. }
  29. bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
  30. const std::vector<kernel::AddressPtr> & /*workspace*/,
  31. const std::vector<kernel::AddressPtr> &outputs) {
  32. auto input = reinterpret_cast<float *>(inputs[0]->addr);
  33. auto output = reinterpret_cast<float *>(outputs[0]->addr);
  34. size_t size = IntToSize(inputs[0]->size / sizeof(float));
  35. size_t shape_size = IntToSize(shape_.size());
  36. if (shape_size > kMaxDim) {
  37. MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs.";
  38. }
  39. size_t pos_array[kMaxDim];
  40. size_t size_offset[kMaxDim];
  41. size_offset[0] = size / shape_[0];
  42. for (size_t i = 1; i < shape_size; i++) {
  43. size_offset[i] = size_offset[SizeToInt(i) - 1] / shape_[i];
  44. }
  45. for (size_t position = 0; position < size; position += 1) {
  46. size_t temp_position = position;
  47. pos_array[0] = temp_position / size_offset[0];
  48. for (size_t i = 1; i < shape_size; i++) {
  49. temp_position -= pos_array[SizeToInt(i) - 1] * size_offset[i - 1];
  50. pos_array[i] = temp_position / size_offset[i];
  51. }
  52. size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]];
  53. size_t new_position_size = 1;
  54. for (int j = shape_size - 2; j >= 0; j--) {
  55. new_position_size *= shape_[axis_[j + 1]];
  56. new_position += pos_array[axis_[j]] * new_position_size;
  57. }
  58. output[new_position] = input[position];
  59. }
  60. return true;
  61. }
  62. } // namespace kernel
  63. } // namespace mindspore