You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reverse_sequence.c 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/reverse_sequence.h"
  17. #include <string.h>
  18. #include "nnacl/arithmetic_common.h"
  19. void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) {
  20. (void)memcpy(output, input0, para->total_data_size_);
  21. ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_);
  22. ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_);
  23. for (int i = 0; i < para->outer_count_; ++i) {
  24. float *in = input0 + i * para->outer_stride_;
  25. float *out = output + i * para->outer_stride_;
  26. for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) {
  27. float *in_batch = in + batch * para->input_stride_[para->batch_axis_];
  28. float *out_batch = out + batch * para->output_stride_[para->batch_axis_];
  29. int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch);
  30. for (int n = 0; n < seq_length; ++n) {
  31. float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_];
  32. float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_];
  33. for (int j = 0; j < para->inner_count_; ++j) {
  34. (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_);
  35. }
  36. }
  37. }
  38. }
  39. }