|
|
|
@@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() { |
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); |
|
|
|
} |
|
|
|
|
|
|
|
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *static_cast<int *>(in_tensors_.at(2)->data_c()); |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|