You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

split.c 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/split.h"
  17. #include "nnacl/split_parameter.h"
  18. #include <string.h>
  19. #include "nnacl/errorcode.h"
  20. int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset, int num_unit,
  21. SplitParameter *split_param) {
  22. if (in_data == NULL || out_data == NULL) {
  23. return NNACL_ERR;
  24. }
  25. int num_split = split_param->num_split_;
  26. int *split_sizes = split_param->split_sizes_;
  27. int *strides = split_param->strides_;
  28. int split_dim = split_param->split_dim_;
  29. int in_stride = strides[split_dim];
  30. float *src;
  31. int size_float = (int)(sizeof(float));
  32. int in_stride_bytes = in_stride * size_float;
  33. int split_which;
  34. int split_times;
  35. int stride_per_split = in_stride * input_shape[split_dim];
  36. split_which = offset % num_split;
  37. split_times = offset / num_split;
  38. src = in_data + split_times * stride_per_split;
  39. for (int i = 0; i < split_which; i++) {
  40. src += split_sizes[i] * in_stride;
  41. }
  42. for (int i = offset; i < offset + num_unit; i++) {
  43. split_which = i % num_split;
  44. split_times = i / num_split;
  45. int split_size = split_sizes[split_which];
  46. float *dst = out_data[split_which] + split_times * in_stride * split_size;
  47. (void)memcpy(dst, src, split_size * in_stride_bytes);
  48. src += split_size * in_stride;
  49. }
  50. return NNACL_OK;
  51. }