|
|
|
@@ -18,7 +18,7 @@ |
|
|
|
#include <string.h> |
|
|
|
#include "nnacl/arithmetic_common.h" |
|
|
|
|
|
|
|
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) { |
|
|
|
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) { |
|
|
|
(void)memcpy(output, input0, para->total_data_size_); |
|
|
|
ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); |
|
|
|
ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); |
|
|
|
@@ -28,8 +28,9 @@ void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceP |
|
|
|
for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { |
|
|
|
float *in_batch = in + batch * para->input_stride_[para->batch_axis_]; |
|
|
|
float *out_batch = out + batch * para->output_stride_[para->batch_axis_]; |
|
|
|
for (int n = 0; n < input1[batch]; ++n) { |
|
|
|
float *in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_]; |
|
|
|
int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch); |
|
|
|
for (int n = 0; n < seq_length; ++n) { |
|
|
|
float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_]; |
|
|
|
float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; |
|
|
|
for (int j = 0; j < para->inner_count_; ++j) { |
|
|
|
(void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); |
|
|
|
|