Browse Source

stack support one input

tags/v1.0.0
chenjianping 5 years ago
parent
commit
7a8cc39b5c
4 changed files with 12 additions and 1 deletions
  1. +4
    -0
      mindspore/lite/nnacl/fp32/stack.c
  2. +1
    -0
      mindspore/lite/nnacl/fp32/stack.h
  3. +1
    -1
      mindspore/lite/src/ops/stack.cc
  4. +6
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc

+ 4
- 0
mindspore/lite/nnacl/fp32/stack.c View File

@@ -67,3 +67,7 @@ void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape,
in_offset += copy_num;
}
}

void DoStackOneInput(const int8_t *input, int8_t *output, size_t data_size) {
memcpy(output, input, data_size);
}

+ 1
- 0
mindspore/lite/nnacl/fp32/stack.h View File

@@ -29,6 +29,7 @@ extern "C" {
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output);
void DoStackOneInput(const int8_t *input, int8_t *output, size_t data_size);
#ifdef __cplusplus
}
#endif


+ 1
- 1
mindspore/lite/src/ops/stack.cc View File

@@ -58,7 +58,7 @@ int Stack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::

namespace {
constexpr int kStackOutputNum = 1;
constexpr int kStackMinInputNum = 2;
constexpr int kStackMinInputNum = 1;
} // namespace
int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);


+ 6
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc View File

@@ -48,6 +48,12 @@ int StackCPUKernel::Run() {
return ret;
}
size_t inputs_num = in_tensors_.size();
auto input0 = in_tensors_[0];
if (inputs_num == 1) {
auto *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
DoStackOneInput(reinterpret_cast<const int8_t *>(input0->Data()), output_data, input0->Size());
return RET_OK;
}
auto input0_shape = in_tensors_[0]->shape();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());


Loading…
Cancel
Save