| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -19,8 +19,14 @@ | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/nnacl_utils.h" | |||
| void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides, | |||
| int *outStrides, int *multiple) { | |||
| int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, | |||
| float16_t *out, int size, ArithmeticParameter *param) { | |||
| TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param); | |||
| return ElementAddFp16(tile_in0, tile_in1, out, size); | |||
| } | |||
| void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape, | |||
| const int *inStrides, const int *outStrides, const int *multiple) { | |||
| int srcDimSize = inShape[dim]; | |||
| if (dim == ndim - 1) { | |||
| for (int i = 0; i < multiple[dim]; i++) { | |||
| @@ -37,7 +43,7 @@ void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t | |||
| } | |||
| } | |||
| void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, | |||
| void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, | |||
| ArithmeticParameter *param) { | |||
| CalcMultiplesAndStrides(param); | |||
| TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, | |||
| @@ -219,6 +225,12 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int | |||
| float16x8_t vout = vaddq_f16(vin0, vin1); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| for (; index <= element_size - 4; index += C4NUM) { | |||
| float16x4_t vin0 = vld1_f16(input0 + index); | |||
| float16x4_t vin1 = vld1_f16(input1 + index); | |||
| float16x4_t vout = vadd_f16(vin0, vin1); | |||
| vst1_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = input0[index] + input1[index]; | |||
| @@ -270,6 +282,14 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, | |||
| vout = vmaxq_f16(vout, zeros); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| float16x4_t zeros1 = vdup_n_f16(0.0f); | |||
| for (; index <= element_size - 4; index += C4NUM) { | |||
| float16x4_t vin0 = vld1_f16(input0 + index); | |||
| float16x4_t vin1 = vld1_f16(input1 + index); | |||
| float16x4_t vout = vadd_f16(vin0, vin1); | |||
| vout = vmax_f16(vout, zeros1); | |||
| vst1_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| float16_t res = input0[index] + input1[index]; | |||
| @@ -328,6 +348,15 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, | |||
| vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| float16x4_t zeros1 = vdup_n_f16(0.0); | |||
| float16x4_t bounds1 = vdup_n_f16(6.0); | |||
| for (; index <= element_size - 4; index += C4NUM) { | |||
| float16x4_t vin0 = vld1_f16(input0 + index); | |||
| float16x4_t vin1 = vld1_f16(input1 + index); | |||
| float16x4_t vout = vadd_f16(vin0, vin1); | |||
| vout = vmin_f16(vmax_f16(vout, zeros1), bounds1); | |||
| vst1_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| for (; index < element_size; index++) { | |||
| output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -26,6 +26,12 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape, | |||
| const int *inStrides, const int *outStrides, const int *multiple); | |||
| void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, | |||
| ArithmeticParameter *param); | |||
| int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| ArithmeticParameter *param); | |||
| int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size, | |||
| @@ -84,6 +90,8 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, | |||
| int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, | |||
| float16_t *out, int size, ArithmeticParameter *param); | |||
| int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size); | |||
| @@ -111,8 +119,6 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, | |||
| int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size); | |||
| void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, | |||
| ArithmeticParameter *param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -125,11 +125,6 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(inputs_.at(1) != nullptr); | |||
| MS_ASSERT(inputs_.at(2) != nullptr); | |||
| auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0)); | |||
| if (input0->tensors_data_type() != GetElementDType()) { | |||
| MS_LOG(ERROR) << "op dtype: " << GetElementDType() | |||
| << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| auto get_index = inputs_.at(1); | |||
| MS_ASSERT(get_index != nullptr); | |||
| if (get_index->ElementsNum() != 1) { | |||
| @@ -184,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| MS_LOG(ERROR) << "element_shape_ is not fullyDefined!"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data_type(GetElementDType()); | |||
| output->set_data_type(input0->data_type()); | |||
| output->set_shape(element_shape_); | |||
| } | |||
| output->set_format(input0->GetTensor(index_)->format()); | |||
| @@ -0,0 +1,106 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/runtime/kernel/arm/fp16/bias_fp16.h" | |||
| #include "src/kernel_registry.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_NULL_PTR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_BiasAdd; | |||
| namespace mindspore::kernel { | |||
| int BiasCPUFp16Kernel::ReSize() { | |||
| auto dims = in_tensors_.at(0)->shape(); | |||
| bias_param_->ndim_ = dims.size(); | |||
| if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) { | |||
| MS_LOG(ERROR) << "input shape is invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t i = 0; i < bias_param_->ndim_; i++) { | |||
| bias_param_->in_shape0_[i] = dims[i]; | |||
| bias_param_->in_shape1_[i] = 1; | |||
| bias_param_->out_shape_[i] = dims[i]; | |||
| } | |||
| bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; | |||
| return RET_OK; | |||
| } | |||
| int BiasCPUFp16Kernel::Run() { | |||
| auto in = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData()); | |||
| auto out = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData()); | |||
| size_t data_size = in_tensors_.at(0)->ElementsNum(); | |||
| MS_ASSERT(context_->allocator != nullptr); | |||
| auto *tile_in = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t))); | |||
| auto *tile_bias = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t))); | |||
| if (tile_in == nullptr || tile_bias == nullptr) { | |||
| MS_LOG(ERROR) << "Memory allocation failed"; | |||
| context_->allocator->Free(tile_in); | |||
| context_->allocator->Free(tile_bias); | |||
| return RET_NULL_PTR; | |||
| } | |||
| BroadcastAddFp16(in, bias_data_, tile_in, tile_bias, out, data_size, bias_param_); | |||
| context_->allocator->Free(tile_in); | |||
| context_->allocator->Free(tile_bias); | |||
| return RET_OK; | |||
| } | |||
| BiasCPUFp16Kernel::~BiasCPUFp16Kernel() { | |||
| if ((bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) && bias_data_ != nullptr) { | |||
| free(bias_data_); | |||
| bias_data_ = nullptr; | |||
| } | |||
| } | |||
| int BiasCPUFp16Kernel::Init() { | |||
| auto bias_tensor = in_tensors_.at(1); | |||
| MS_ASSERT(bias_tensor != nullptr); | |||
| bias_data_type_ = bias_tensor->data_type(); | |||
| if (bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) { | |||
| bias_data_ = reinterpret_cast<float16_t *>(malloc(bias_tensor->ElementsNum() * sizeof(float16_t))); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "bias_data_ is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto *bias = reinterpret_cast<float *>(bias_tensor->MutableData()); | |||
| if (bias != nullptr) { | |||
| MS_LOG(ERROR) << "bias is nullptr!"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| for (int i = 0; i < bias_tensor->ElementsNum(); ++i) { | |||
| bias_data_[i] = (float16_t)(bias[i]); | |||
| } | |||
| } else { | |||
| bias_data_ = reinterpret_cast<float16_t *>(bias_tensor->MutableData()); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "bias_data_ is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| } | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUFp16Kernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp16/arithmetic_fp16.h" | |||
| namespace mindspore::kernel { | |||
| class BiasCPUFp16Kernel : public LiteKernel { | |||
| public: | |||
| BiasCPUFp16Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| bias_param_ = reinterpret_cast<ArithmeticParameter *>(parameter); | |||
| } | |||
| ~BiasCPUFp16Kernel() override; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| ArithmeticParameter *bias_param_ = nullptr; | |||
| float16_t *bias_data_ = nullptr; | |||
| TypeId bias_data_type_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_ | |||
| @@ -53,25 +53,24 @@ int TensorListFromTensorCPUKernel::IsCompatibleShape() { | |||
| } | |||
| int TensorListFromTensorCPUKernel::Init() { | |||
| input0_ = in_tensors_[0]; // row tensor | |||
| input1_ = in_tensors_[1]; // element_shape tensor | |||
| output0_ = out_tensors_[0]; | |||
| return IsCompatibleShape(); | |||
| } | |||
| int TensorListFromTensorCPUKernel::ReSize() { | |||
| auto ret = this->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed!"; | |||
| return ret; | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListFromTensorCPUKernel::ReSize() { return RET_OK; } | |||
| int TensorListFromTensorCPUKernel::Run() { | |||
| input0_ = in_tensors_[0]; // row tensor | |||
| input1_ = in_tensors_[1]; // element_shape tensor | |||
| output0_ = out_tensors_[0]; | |||
| if (IsCompatibleShape() != RET_OK) { | |||
| MS_LOG(ERROR) << "IsNotCompatibleShape!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (input0_->shape().size() == 0) { | |||
| MS_LOG(ERROR) << "input0_->shape().size():" << input0_->shape().size() << " must be greater than 0"; | |||
| } | |||
| @@ -86,7 +85,9 @@ int TensorListFromTensorCPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| int devision_dim0 = input0_->ElementsNum() / dim0; | |||
| auto in_ptr = reinterpret_cast<float *>(input0_->data_c()); | |||
| auto data_offset = devision_dim0 * lite::DataTypeSize(dtype_); | |||
| auto in_data = reinterpret_cast<char *>(input0_->data_c()); | |||
| MS_ASSERT(in_data != nullptr); | |||
| // copy data from input0(tensor) to output(tensorlist) vector<*tensor> | |||
| for (int i = 0; i < dim0; ++i) { | |||
| auto out_ptr = output0->GetTensor(i); | |||
| @@ -96,37 +97,17 @@ int TensorListFromTensorCPUKernel::Run() { | |||
| << " must be euqal to devision_dim0:" << devision_dim0; | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(reinterpret_cast<float *>(out_ptr->MutableData()), in_ptr, devision_dim0 * sizeof(float)); | |||
| in_ptr += devision_dim0; | |||
| auto out_data = out_ptr->MutableData(); | |||
| MS_ASSERT(out_data != nullptr); | |||
| memcpy(out_data, in_data, data_offset); | |||
| in_data += data_offset; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuTensorListFromTensorFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *op_parameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (op_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "Input op_parameter is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Input context is nullptr!"; | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_TensorListFromTensor); | |||
| op_parameter->thread_num_ = ctx->thread_num_; | |||
| auto *kernel = new (std::nothrow) TensorListFromTensorCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new TensorListFromTensorCPUKernel fail!"; | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, | |||
| LiteKernelCreator<TensorListFromTensorCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, LiteKernelCreator<TensorListFromTensorCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListFromTensor, | |||
| LiteKernelCreator<TensorListFromTensorCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -21,6 +21,7 @@ | |||
| #include "src/lite_kernel.h" | |||
| #include "src/tensorlist.h" | |||
| #include "schema/model_generated.h" | |||
| #include "nnacl/tensorlist_parameter.h" | |||
| namespace mindspore::kernel { | |||
| class TensorListFromTensorCPUKernel : public LiteKernel { | |||
| @@ -28,7 +29,8 @@ class TensorListFromTensorCPUKernel : public LiteKernel { | |||
| TensorListFromTensorCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive), | |||
| dtype_(reinterpret_cast<TensorListParameter *>(parameter)->element_dtype_) {} | |||
| ~TensorListFromTensorCPUKernel() = default; | |||
| int Init() override; | |||
| @@ -41,6 +43,7 @@ class TensorListFromTensorCPUKernel : public LiteKernel { | |||
| lite::Tensor *output0_ = nullptr; | |||
| lite::Tensor *input0_ = nullptr; | |||
| lite::Tensor *input1_ = nullptr; | |||
| TypeId dtype_ = kTypeUnknown; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -31,11 +31,11 @@ namespace mindspore::kernel { | |||
| int TensorListGetItemCPUKernel::Init() { | |||
| MS_ASSERT(in_tensors_.size() >= 2); | |||
| MS_ASSERT(in_tensors_.at(0) != nullptr); | |||
| auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0)); | |||
| if (dtype_ != input0->tensors_data_type()) { | |||
| MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); | |||
| return RET_ERROR; | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| @@ -45,6 +45,10 @@ int TensorListGetItemCPUKernel::Run() { | |||
| MS_ASSERT(in_tensors_.at(1) != nullptr); | |||
| MS_ASSERT(out_tensors_.at(0) != nullptr); | |||
| auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0)); | |||
| if (dtype_ != input0->tensors_data_type()) { | |||
| MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr); | |||
| index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0]; | |||
| int dim0 = input0->ElementsNum() - 1; | |||
| @@ -66,8 +70,7 @@ int TensorListGetItemCPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| // reset 0 and dtype = dtype_ | |||
| // TODO(DT_VARIANT): dtype = DT_VARIANT is not handle | |||
| // reset data buffer is zero | |||
| auto out_data = out_tensors_[0]->data_c(); | |||
| if (out_data == nullptr) { | |||
| MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr"; | |||
| @@ -80,30 +83,7 @@ int TensorListGetItemCPUKernel::Run() { | |||
| int TensorListGetItemCPUKernel::ReSize() { return RET_OK; } | |||
| kernel::LiteKernel *CpuTensorListGetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *op_parameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (op_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "Input op_parameter is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Input context is nullptr!"; | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_TensorListGetItem); | |||
| auto *kernel = new (std::nothrow) TensorListGetItemCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new TensorListGetItemCPUKernel fail!"; | |||
| free(op_parameter); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -27,7 +27,14 @@ using mindspore::schema::PrimitiveType_TensorListReserve; | |||
| namespace mindspore::kernel { | |||
| int TensorListReserveCPUKernel::Init() { return RET_OK; } | |||
| int TensorListReserveCPUKernel::Init() { | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && element_dtype_ == kNumberTypeFloat32) { | |||
| element_dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListReserveCPUKernel::Run() { | |||
| auto input0 = in_tensors_.at(0); | |||
| @@ -48,5 +55,6 @@ int TensorListReserveCPUKernel::Run() { | |||
| int TensorListReserveCPUKernel::ReSize() { return RET_OK; } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -28,7 +28,14 @@ using mindspore::schema::PrimitiveType_TensorListSetItem; | |||
| namespace mindspore::kernel { | |||
| int TensorListSetItemCPUKernel::Init() { return RET_OK; } | |||
| int TensorListSetItemCPUKernel::Init() { | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int TensorListSetItemCPUKernel::CheckParam() { | |||
| if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { | |||
| @@ -143,5 +150,6 @@ int TensorListSetItemCPUKernel::Run() { | |||
| int TensorListSetItemCPUKernel::ReSize() { return RET_OK; } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -60,6 +60,11 @@ int TensorListStackCPUKernel::Init() { | |||
| MS_ASSERT(input0_ != nullptr); | |||
| output0_ = out_tensors_[0]; | |||
| MS_ASSERT(output0_ != nullptr); | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| @@ -159,17 +164,21 @@ int TensorListStackCPUKernel::Run() { | |||
| MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must be equal to in_ele_num:" << in_ele_num; | |||
| return RET_ERROR; | |||
| } | |||
| auto out_ptr = reinterpret_cast<float *>(output0_->MutableData()); | |||
| auto out_data = reinterpret_cast<char *>(output0_->MutableData()); | |||
| auto unknown_type_offset = TypeUnknownSize * lite::DataTypeSize(dtype_); | |||
| MS_ASSERT(out_data != nullptr); | |||
| for (int i = 0; i < num_element_; ++i) { | |||
| auto in_ptr = input0_->GetTensor(i); | |||
| MS_ASSERT(in_ptr != nullptr); | |||
| if (in_ptr->data_type() != kTypeUnknown) { | |||
| int in_size = in_ptr->ElementsNum(); | |||
| memcpy(out_ptr, in_ptr->data_c(), in_size * sizeof(float)); | |||
| out_ptr += in_size; | |||
| int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_); | |||
| auto in_data = in_ptr->data_c(); | |||
| MS_ASSERT(in_data != nullptr); | |||
| memcpy(out_data, in_data, data_size); | |||
| out_data += data_size; | |||
| } else { | |||
| memset(out_ptr, 0, TypeUnknownSize * sizeof(float)); | |||
| out_ptr += TypeUnknownSize; | |||
| memset(out_data, 0, unknown_type_offset); | |||
| out_data += unknown_type_offset; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -178,5 +187,6 @@ int TensorListStackCPUKernel::Run() { | |||
| int TensorListStackCPUKernel::ReSize() { return RET_OK; } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>) | |||
| } // namespace mindspore::kernel | |||