Merge pull request !4662 from 张学同/to_merge (tags/v0.7.0-beta)
| @@ -82,6 +82,11 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | |||
| return nullptr; | |||
| } | |||
| int index = GetCreatorFuncIndex(desc); | |||
| if (index >= array_size_) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||
| << desc.type; | |||
| return nullptr; | |||
| } | |||
| auto it = creator_arrays_[index]; | |||
| if (it != nullptr) { | |||
| return it; | |||
| @@ -91,9 +96,9 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { | |||
| int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { | |||
| int index; | |||
| int device_index = static_cast<int>(desc.arch); | |||
| int dType_index = static_cast<int>(desc.data_type); | |||
| int op_index = static_cast<int>(desc.type); | |||
| int device_index = static_cast<int>(desc.arch) - kKernelArch_MIN; | |||
| int dType_index = static_cast<int>(desc.data_type) - kNumberTypeBegin; | |||
| int op_index = static_cast<int>(desc.type) - PrimitiveType_MIN; | |||
| index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index; | |||
| return index; | |||
| } | |||
| @@ -115,6 +120,11 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c | |||
| } | |||
| KernelKey desc = {arch, data_type, op_type}; | |||
| int index = GetCreatorFuncIndex(desc); | |||
| if (index >= array_size_) { | |||
| MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " | |||
| << desc.type; | |||
| return; | |||
| } | |||
| creator_arrays_[index] = creator; | |||
| } | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp16/reshape_fp16.h" | |||
| #include <vector> | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #include "nnacl/reshape.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Reshape; | |||
| namespace mindspore::kernel { | |||
| int ReshapeCPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | |||
| return ret; | |||
| } | |||
| auto in_tensor = in_tensors_.at(kInputIndex); | |||
| auto out_tensor = out_tensors_.at(kOutputIndex); | |||
| auto input_ptr = in_tensor->Data(); | |||
| auto output_ptr = out_tensor->Data(); | |||
| size_t data_size = out_tensor->Size(); | |||
| auto in_datatype = in_tensor->data_type(); | |||
| auto out_datatype = out_tensor->data_type(); | |||
| if (in_datatype != out_datatype) { | |||
| if (in_datatype == kNumberTypeFloat32 && out_datatype == kNumberTypeFloat16) { | |||
| input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t)); | |||
| if (input_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "malloc in tensor fail!"; | |||
| return mindspore::lite::RET_MEMORY_FAILED; | |||
| } | |||
| Float32ToFloat16(reinterpret_cast<float *>(in_tensor->Data()), reinterpret_cast<float16_t *>(input_ptr), | |||
| in_tensor->ElementsNum()); | |||
| } else if ((in_datatype == kNumberTypeFloat16 && out_datatype == kNumberTypeFloat32)) { | |||
| input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float)); | |||
| if (input_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "malloc in tensor fail!"; | |||
| return mindspore::lite::RET_MEMORY_FAILED; | |||
| } | |||
| Float16ToFloat32(reinterpret_cast<float16_t *>(in_tensor->Data()), reinterpret_cast<float *>(input_ptr), | |||
| in_tensor->ElementsNum()); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported data type, in_datatype: " << in_datatype << ",out_datatype: " << out_datatype; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| Reshape(input_ptr, output_ptr, data_size); | |||
| if (in_datatype != out_datatype) { | |||
| context_->allocator->Free(input_ptr); | |||
| } | |||
| return RET_OK; | |||
| } // namespace mindspore::kernel | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "include/context.h" | |||
| #include "src/runtime/kernel/arm/fp32/reshape.h" | |||
| using mindspore::lite::Context; | |||
| namespace mindspore::kernel { | |||
| class ReshapeFp16CPUKernel : public ReshapeCPUKernel { | |||
| public: | |||
| ReshapeFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : ReshapeCPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~ReshapeFp16CPUKernel() = default; | |||
| int Run() override; | |||
| private: | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_ | |||