|
|
|
@@ -35,30 +35,15 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso |
|
|
|
auto input_tensor = inputs.at(kInputIndex); |
|
|
|
auto data_type = input_tensor->data_type(); |
|
|
|
kernel::LiteKernel *kernel = nullptr; |
|
|
|
switch (data_type) { |
|
|
|
case kNumberTypeInt8: |
|
|
|
case kNumberTypeUInt8: { |
|
|
|
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "kernel is nullptr."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
case kNumberTypeFloat32: { |
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "kernel is nullptr."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
default: |
|
|
|
break; |
|
|
|
if (data_type == kNumberTypeInt8 || data_type == kNumberTypeUInt8) { |
|
|
|
kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
} else { |
|
|
|
kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
} |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "kernel is nullptr."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = kernel->Init(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
delete kernel; |
|
|
|
|