Browse Source

reverse_seqence seq_lengths support int64

tags/v1.0.0
sunsuodong 5 years ago
parent
commit
bcaceb57a1
3 changed files with 10 additions and 6 deletions
  1. +4
    -3
      mindspore/lite/nnacl/reverse_sequence.c
  2. +2
    -1
      mindspore/lite/nnacl/reverse_sequence.h
  3. +4
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc

+ 4
- 3
mindspore/lite/nnacl/reverse_sequence.c View File

@@ -18,7 +18,7 @@
#include <string.h>
#include "nnacl/arithmetic_common.h"

void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) {
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para) {
(void)memcpy(output, input0, para->total_data_size_);
ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_);
ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_);
@@ -28,8 +28,9 @@ void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceP
for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) {
float *in_batch = in + batch * para->input_stride_[para->batch_axis_];
float *out_batch = out + batch * para->output_stride_[para->batch_axis_];
for (int n = 0; n < input1[batch]; ++n) {
float *in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_];
int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch);
for (int n = 0; n < seq_length; ++n) {
float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_];
float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_];
for (int j = 0; j < para->inner_count_; ++j) {
(void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_);


+ 2
- 1
mindspore/lite/nnacl/reverse_sequence.h View File

@@ -34,12 +34,13 @@ typedef struct ReverseSequenceParameter {
int inner_stride_;
int copy_byte_size_;
int total_data_size_;
bool is_seq_length_int32_;
} ReverseSequenceParameter;

#ifdef __cplusplus
extern "C" {
#endif
void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para);
void ReverseSequence(float *input0, void *input1, float *output, ReverseSequenceParameter *para);
#ifdef __cplusplus
}
#endif


+ 4
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc View File

@@ -93,9 +93,11 @@ int ReverseSequenceCPUKernel::Run() {
return ret;
}
float *input0 = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
int *input1 = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
void *input1 = in_tensors_.at(1)->MutableData();
float *output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
ReverseSequence(input0, input1, output, reinterpret_cast<ReverseSequenceParameter *>(op_parameter_));
ReverseSequenceParameter *param = reinterpret_cast<ReverseSequenceParameter *>(op_parameter_);
param->is_seq_length_int32_ = in_tensors_.at(1)->data_type() == kNumberTypeInt32;
ReverseSequence(input0, input1, output, param);
return RET_OK;
}



Loading…
Cancel
Save