@@ -231,7 +231,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
                                             const std::vector<tensor::Tensor *> &out_tensors,
                                             const mindspore::lite::PrimitiveC *primitive, const schema::CNode *cnode) {
   MS_ASSERT(nullptr != primitive);
-  auto data_type = in_tensors.front()->data_type();
+  TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
   if (context_->device_ctx_.type == DT_GPU) {
     desc.arch = kernel::KERNEL_ARCH::kGPU;
@@ -271,6 +271,16 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
   return nullptr;
 }
 
+TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<tensor::Tensor *> &in_tensors) {
+  for (const auto &tensor : in_tensors) {
+    auto dtype = tensor->data_type();
+    if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) {
+      return dtype;
+    }
+  }
+  return kNumberTypeFloat32;
+}
+
 void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
   if (kernel->desc().arch != kernel::KERNEL_ARCH::kCPU) {
     return;
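
A minimal standalone sketch of the selection rule this patch introduces, using stub Tensor/TypeId stand-ins rather than the real mindspore types: with a mixed input list, the old front()->data_type() keys kernel selection off whatever type the first input happens to have (which could be a non-compute type such as INT32), while the new helper picks the first FP32/FP16/INT8 input and falls back to FP32. The Tensor struct and enum values below are illustrative stubs only.

#include <cassert>
#include <vector>

// Stub stand-ins for mindspore's TypeId values and tensor::Tensor, for illustration only.
enum TypeId { kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32 };

struct Tensor {
  TypeId type;
  TypeId data_type() const { return type; }
};

// Same rule as the patched helper: the first FP32/FP16/INT8 input wins, FP32 is the fallback.
TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
  for (const auto &tensor : in_tensors) {
    auto dtype = tensor->data_type();
    if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8) {
      return dtype;
    }
  }
  return kNumberTypeFloat32;
}

int main() {
  Tensor shape{kNumberTypeInt32};      // hypothetical shape/index input
  Tensor feature{kNumberTypeFloat16};  // hypothetical feature-map input
  std::vector<Tensor *> inputs{&shape, &feature};
  assert(inputs.front()->data_type() == kNumberTypeInt32);           // old behaviour: INT32 chosen
  assert(GetFirstFp32Fp16OrInt8Type(inputs) == kNumberTypeFloat16);  // new behaviour: FP16 chosen
  return 0;
}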