You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stack.c 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp32/stack.h"
  17. #include "nnacl/arithmetic_common.h"
  /**
   * Number of contiguous elements each input contributes per outer (pre-axis) step.
   *
   * This is the product of in_shape[axis .. shape_size - 1]. A non-positive axis
   * starts at dimension 0, giving the size of the whole input. An axis at or past
   * shape_size gives 1 (empty product).
   *
   * @param axis       stacking axis; values <= 0 are treated as 0.
   * @param in_shape   dimensions of one input; only read.
   * @param shape_size number of entries in in_shape.
   * @return element count of one contiguous chunk.
   */
  size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
    // Computed directly rather than through a fixed-size int[4] strides buffer.
    // That buffer could not hold the strides of inputs with more than four
    // dimensions, and it accumulated them in int.
    size_t first_dim = axis > 0 ? (size_t)axis : 0;
    size_t copy_num = 1;
    for (size_t i = first_dim; i < shape_size; ++i) {
      copy_num *= (size_t)in_shape[i];
    }
    return copy_num;
  }
  /**
   * Product of the dimensions in front of the stacking axis,
   * i.e. in_shape[0] * ... * in_shape[axis - 1].
   *
   * @param in_shape dimensions of one input; must have at least axis entries.
   * @param axis     stacking axis; values <= 0 yield 1 (empty product).
   * @return number of outer (pre-axis) iterations.
   */
  size_t GetStackPreAxisCount(const int *in_shape, int axis) {
    size_t pre_axis_count = 1;
    // Signed index on purpose: a size_t index would convert a negative axis to a
    // huge unsigned bound and walk off the end of in_shape.
    for (int i = 0; i < axis; ++i) {
      pre_axis_count *= (size_t)in_shape[i];
    }
    return pre_axis_count;
  }
  /**
   * Stacks input_num float buffers into output along axis.
   *
   * For each of the GetStackPreAxisCount() outer steps, one chunk of
   * GetStackCopyNum() elements is taken from every input in order and appended
   * to output. output must therefore hold
   * outer_steps * input_num * chunk_elems floats.
   * Every input is read using the same in_shape; presumably all inputs must share
   * that shape (verify against callers).
   */
  void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
    const size_t chunk_elems = GetStackCopyNum(axis, in_shape, shape_size);
    const size_t outer_steps = GetStackPreAxisCount(in_shape, axis);
    float *dst = output;
    for (size_t step = 0; step < outer_steps; ++step) {
      // Every input advances by exactly one chunk per outer step.
      const size_t src_base = step * chunk_elems;
      for (size_t k = 0; k < input_num; ++k) {
        memcpy(dst, inputs[k] + src_base, chunk_elems * sizeof(float));
        dst += chunk_elems;
      }
    }
  }
  /**
   * int32_t counterpart of DoStack: interleaves chunks from each input into output.
   *
   * For each outer step (GetStackPreAxisCount), the next GetStackCopyNum()
   * elements of inputs[0], inputs[1], ..., inputs[input_num - 1] are appended
   * to output in that order. output must hold
   * outer_total * input_num * elems_per_chunk values.
   */
  void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
                    int32_t *output) {
    const size_t elems_per_chunk = GetStackCopyNum(axis, in_shape, shape_size);
    const size_t bytes_per_chunk = elems_per_chunk * sizeof(int32_t);
    const size_t outer_total = GetStackPreAxisCount(in_shape, axis);
    int32_t *write_ptr = output;
    size_t read_offset = 0;
    size_t outer = 0;
    while (outer < outer_total) {
      for (size_t src = 0; src < input_num; ++src) {
        memcpy(write_ptr, inputs[src] + read_offset, bytes_per_chunk);
        write_ptr += elems_per_chunk;
      }
      read_offset += elems_per_chunk;
      ++outer;
    }
  }
  /**
   * Copies data_size bytes from input to output unchanged.
   * The regions must not overlap (memcpy semantics).
   */
  void DoStackOneInput(const int8_t *input, int8_t *output, size_t data_size) {
    const size_t byte_count = data_size;
    (void)memcpy(output, input, byte_count);
  }