Browse Source

!6327 [MSLITE][Develop] fix concat fp16 when tensor is int32

Merge pull request !6327 from sunsuodong/fix_concat_fp16
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e9def9d276
3 changed files with 36 additions and 11 deletions
  1. +20
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc
  2. +2
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h
  3. +14
    -11
      mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc

+ 20
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc View File

@@ -43,4 +43,24 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx) {
}
return fp16_data;
}

bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) {
bool result = false;
for (auto &input : inputs) {
if (input->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
}
if (result) {
return true;
}
for (auto &output : outputs) {
if (output->data_type() == kNumberTypeFloat16) {
result = true;
break;
}
}
return result;
}
} // namespace mindspore::kernel

+ 2
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_

#include <vector>
#include "src/lite_kernel.h"

namespace mindspore::kernel {
@@ -23,6 +24,7 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::Context *ctx)

float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx);

bool IsExistFp16Tensor(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_

+ 14
- 11
mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc View File

@@ -13,12 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>
#include "nnacl/fp16/concat_fp16.h"
#include "src/runtime/kernel/arm/fp16/concat_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "src/runtime/kernel/arm/fp32/concat.h"
#include "nnacl/fp16/concat_fp16.h"
#include "src/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#include "nnacl/fp16/cast_fp16.h"

@@ -142,24 +141,28 @@ int ConcatFp16CPUKernel::Run() {
}

kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const Context *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
if (parameter == nullptr) {
MS_LOG(ERROR) << "Input parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
auto *kernel = new (std::nothrow) ConcatFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel::LiteKernel *kernel = nullptr;
if (IsExistFp16Tensor(inputs, outputs)) {
kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;


Loading…
Cancel
Save