Browse Source

fix tensor data type wrong

tags/v1.2.0-rc1
zhaozhenlong 4 years ago
parent
commit
674b7c4cbd
2 changed files with 10 additions and 0 deletions
  1. +5
    -0
      mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc
  2. +5
    -0
      mindspore/lite/src/scheduler.cc

+ 5
- 0
mindspore/lite/src/runtime/agent/npu/optimizer/npu_fusion_pass.cc View File

@@ -204,6 +204,11 @@ int NPUFusionPass::FormatFusion(kernel::LiteKernel *kernel) {
if (trans_kernel->out_kernels().empty()) {
// kernel is a trans kernel, it's input kernel num and input tensor num must be 1
kernel->in_kernels()[0]->set_out_tensors({trans_kernel->out_tensors()[0]});
// in fp16 mode, tensor data type fp16 need to be changed back.
auto tensor = kernel->in_kernels()[0]->out_tensors()[0];
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
for (const auto &post_kernel : trans_kernel->out_kernels()) {
// update tensor


+ 5
- 0
mindspore/lite/src/scheduler.cc View File

@@ -212,6 +212,11 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (desc.data_type == kNumberTypeFloat16) {
desc.data_type = kNumberTypeFloat32;
}
for (auto tensor : in_tensors) {
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
}
}
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc);
if (kernel != nullptr) {


Loading…
Cancel
Save