You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stack_fp16.c 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/fp16/stack_fp16.h"
  17. #include "nnacl/arithmetic_common.h"
  /**
   * Number of contiguous elements copied from each input per outer iteration
   * when stacking along `axis`: the product of in_shape[axis .. shape_size).
   *
   * For axis <= 0 this is the element count of one whole input; when
   * axis == shape_size the product is empty and the result is 1.
   *
   * The product is accumulated directly rather than through a fixed-size
   * stride array, so inputs of any rank are handled without writing past a
   * stack buffer.
   *
   * @param axis       stack axis; values <= 0 select the whole input.
   * @param in_shape   dimensions of one input; only read here (the pointer type
   *                   is left non-const so the signature matches its declaration).
   * @param shape_size number of entries in in_shape.
   * @return number of elements in one copy block.
   */
  size_t Fp16GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
    size_t copy_num = 1;
    const size_t first = axis > 0 ? (size_t)axis : 0;
    for (size_t i = first; i < shape_size; ++i) {
      copy_num *= (size_t)in_shape[i];
    }
    return copy_num;
  }
  /**
   * Product of the leading dimensions in_shape[0 .. axis): the number of outer
   * iterations when stacking along `axis`.
   *
   * A non-positive axis yields 1 (empty product). The loop index is an int so
   * the comparison with `axis` stays signed; comparing against a size_t index
   * would convert a negative axis to a huge value and read past in_shape.
   *
   * @param in_shape dimensions of one input; must hold at least `axis` entries.
   * @param axis     stack axis.
   * @return number of outer slices before the stack axis.
   */
  size_t Fp16GetStackPreAxisCount2(const int *in_shape, int axis) {
    size_t pre_axis_count = 1;
    for (int i = 0; i < axis; ++i) {
      pre_axis_count *= (size_t)in_shape[i];
    }
    return pre_axis_count;
  }
  35. void DoStackFp16(const float16_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
  36. float16_t *output) {
  37. size_t copy_num = Fp16GetStackCopyNum(axis, in_shape, shape_size);
  38. size_t copy_size = copy_num * sizeof(float16_t);
  39. size_t pre_axis_count = Fp16GetStackPreAxisCount2(in_shape, axis);
  40. size_t in_offset = 0;
  41. size_t out_offset = 0;
  42. for (size_t i = 0; i < pre_axis_count; ++i) {
  43. for (size_t j = 0; j < input_num; ++j) {
  44. memcpy(output + out_offset, inputs[j] + in_offset, copy_size);
  45. out_offset += copy_num;
  46. }
  47. in_offset += copy_num;
  48. }
  49. }