Merge pull request !4995 from chenjianping/lite_dev2tags/v0.7.0-beta
| @@ -17,7 +17,7 @@ | |||
| #include "nnacl/fp32/stack.h" | |||
| #include "nnacl/arithmetic_common.h" | |||
| void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { | |||
| size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) { | |||
| size_t one_input_size = 1; | |||
| for (size_t i = 0; i < shape_size; ++i) { | |||
| one_input_size *= in_shape[i]; | |||
| @@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t | |||
| ComputeStrides(in_shape, in_strides, shape_size); | |||
| size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size; | |||
| size_t copy_size = copy_num * sizeof(float); | |||
| return copy_num; | |||
| } | |||
/*
 * Returns the product of the dimensions of `in_shape` that come before `axis`,
 * i.e. in_shape[0] * in_shape[1] * ... * in_shape[axis - 1].
 *
 * in_shape: array of dimension sizes; only the first `axis` entries are read.
 * axis:     number of leading dimensions to multiply together.
 *
 * Returns 1 (the empty product) when axis <= 0.
 */
size_t GetStackPreAxisCount(const int *in_shape, int axis) {
  size_t pre_axis_count = 1;
  /* The index is a signed int on purpose: comparing a size_t index against `axis` would convert a
   * negative axis to a huge unsigned bound and walk far past the end of in_shape. */
  for (int i = 0; i < axis; ++i) {
    pre_axis_count *= in_shape[i];
  }
  return pre_axis_count;
}
| void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { | |||
| size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size); | |||
| size_t copy_size = copy_num * sizeof(float); | |||
| size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis); | |||
| size_t in_offset = 0; | |||
| size_t out_offset = 0; | |||
| for (size_t i = 0; i < pre_axis_count; ++i) { | |||
| for (size_t j = 0; j < input_num; ++j) { | |||
| memcpy(output + out_offset, inputs[j] + in_offset, copy_size); | |||
| out_offset += copy_num; | |||
| } | |||
| in_offset += copy_num; | |||
| } | |||
| } | |||
| void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, | |||
| int32_t *output) { | |||
| size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size); | |||
| size_t copy_size = copy_num * sizeof(int32_t); | |||
| size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis); | |||
| size_t in_offset = 0; | |||
| size_t out_offset = 0; | |||
| for (size_t i = 0; i < pre_axis_count; ++i) { | |||
| @@ -27,6 +27,8 @@ typedef struct StackParameter { | |||
| extern "C" { | |||
| #endif | |||
| void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output); | |||
| void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, | |||
| int32_t *output); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -56,7 +56,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto input = inputs.at(0); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| auto input0_data_type = input->data_type(); | |||
| outputs[0]->set_data_type(input0_data_type); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| @@ -69,12 +70,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| MS_LOG(ERROR) << "Invalid axis " << GetAxis(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| schema::Format input0_format = input->GetFormat(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (inputs[i]->GetFormat() != input0_format) { | |||
| MS_LOG(ERROR) << "All inputs should have the same format!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto input_shape_tmp = inputs[i]->shape(); | |||
| if (input_shape_tmp.size() != input_shape.size()) { | |||
| MS_LOG(ERROR) << "All input shape size should be the same!"; | |||
| @@ -86,6 +83,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| } | |||
| if (inputs[i]->data_type() != input0_data_type) { | |||
| MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = " | |||
| << inputs[i]->data_type(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| } | |||
| output_shape.insert(output_shape.begin() + axis, inputs.size()); | |||
| outputs[0]->set_shape(output_shape); | |||
| @@ -49,12 +49,21 @@ int StackCPUKernel::Run() { | |||
| } | |||
| size_t inputs_num = in_tensors_.size(); | |||
| auto input0_shape = in_tensors_[0]->shape(); | |||
| auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data()); | |||
| float *inputs[inputs_num]; | |||
| for (size_t i = 0; i < inputs_num; ++i) { | |||
| inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data()); | |||
| if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { | |||
| auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data()); | |||
| float *inputs[inputs_num]; | |||
| for (size_t i = 0; i < inputs_num; ++i) { | |||
| inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data()); | |||
| } | |||
| DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); | |||
| } else { | |||
| auto *output_data = reinterpret_cast<int32_t *>(out_tensors_[0]->Data()); | |||
| int32_t *inputs[inputs_num]; | |||
| for (size_t i = 0; i < inputs_num; ++i) { | |||
| inputs[i] = reinterpret_cast<int32_t *>(in_tensors_[i]->Data()); | |||
| } | |||
| DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); | |||
| } | |||
| DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); | |||
| return RET_OK; | |||
| } | |||
| @@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Stack, CpuStackFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||