|
|
|
@@ -13,12 +13,11 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include "nnacl/fp16/concat_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/fp16/concat_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/fp16/common_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/fp32/concat.h" |
|
|
|
#include "nnacl/fp16/concat_fp16.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "schema/model_generated.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "nnacl/fp16/cast_fp16.h" |
|
|
|
|
|
|
|
@@ -142,24 +141,28 @@ int ConcatFp16CPUKernel::Run() { |
|
|
|
} |
|
|
|
|
|
|
|
kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, |
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, |
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, |
|
|
|
const Context *ctx, const kernel::KernelKey &desc, |
|
|
|
const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
if (opParameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input opParameter is nullptr!"; |
|
|
|
if (parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input parameter is nullptr!"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_Concat); |
|
|
|
auto *kernel = new (std::nothrow) ConcatFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
kernel::LiteKernel *kernel = nullptr; |
|
|
|
if (IsExistFp16Tensor(inputs, outputs)) { |
|
|
|
kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive); |
|
|
|
} else { |
|
|
|
kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive); |
|
|
|
} |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto ret = kernel->Init(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ |
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_)); |
|
|
|
delete kernel; |
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " |
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return kernel; |
|
|
|
|