Browse Source

eltwise_int8

tags/v1.1.0
sunsuodong 5 years ago
parent
commit
102571e29f
9 changed files with 121 additions and 270 deletions
  1. +5
    -5
      mindspore/lite/src/ops/arithmetic.h
  2. +4
    -4
      mindspore/lite/src/ops/eltwise.h
  3. +0
    -46
      mindspore/lite/src/ops/populate/add_populate.cc
  4. +98
    -16
      mindspore/lite/src/ops/populate/arithmetic_populate.cc
  5. +0
    -47
      mindspore/lite/src/ops/populate/div_populate.cc
  6. +0
    -52
      mindspore/lite/src/ops/populate/eltwise_populate.cc
  7. +0
    -48
      mindspore/lite/src/ops/populate/mul_populate.cc
  8. +0
    -47
      mindspore/lite/src/ops/populate/sub_populate.cc
  9. +14
    -5
      mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc

+ 5
- 5
mindspore/lite/src/ops/arithmetic.h View File

@@ -39,11 +39,11 @@ class Arithmetic : public PrimitiveC {
}
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }
int NDims() { return this->ndim_; }
std::vector<int> InShape0() { return this->in_shape0_; }
std::vector<int> InShape1() { return this->in_shape1_; }
std::vector<int> OutputShape() { return this->out_shape_; }
bool Broadcasting() const { return this->broadcasting_; }
int NDims() const { return this->ndim_; }
std::vector<int> InShape0() const { return this->in_shape0_; }
std::vector<int> InShape1() const { return this->in_shape1_; }
std::vector<int> OutputShape() const { return this->out_shape_; }

protected:
bool broadcasting_ = false;


+ 4
- 4
mindspore/lite/src/ops/eltwise.h View File

@@ -21,20 +21,20 @@
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#include "src/ops/arithmetic.h"

namespace mindspore {
namespace lite {
class Eltwise : public PrimitiveC {
class Eltwise : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Eltwise, PrimitiveC);
MS_DECLARE_PARENT(Eltwise, Arithmetic);
Eltwise() = default;
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
explicit Eltwise(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
void SetMode(int mode);

#else
Eltwise() = default;

int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetMode() const;


+ 0
- 46
mindspore/lite/src/ops/populate/add_populate.cc View File

@@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/add.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"

namespace mindspore {
namespace lite {
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
} // namespace lite
} // namespace mindspore

+ 98
- 16
mindspore/lite/src/ops/populate/arithmetic_populate.cc View File

@@ -13,8 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/arithmetic.h"
#include "src/ops/add.h"
#include "src/ops/sub.h"
#include "src/ops/mul.h"
#include "src/ops/div.h"
#include "src/ops/eltwise.h"
#include "src/ops/greater_equal.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#include "src/ops/primitive_c.h"
@@ -22,27 +27,98 @@

namespace mindspore {
namespace lite {
ArithmeticParameter *PopulateArithmeticCommonPara(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ArithmeticParameter));
param->op_parameter_.type_ = primitive->Type();
param->broadcasting_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->Broadcasting();
param->ndim_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->NDims();
param->activation_type_ = 0;

auto tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape0();
memcpy(param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape1();
memcpy(param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->OutputShape();
memcpy(param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return param;
}

OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
param->activation_type_ = reinterpret_cast<const mindspore::lite::Add *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
param->activation_type_ = reinterpret_cast<const mindspore::lite::Sub *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}

arithmetic_param->activation_type_ = 0;
OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
param->activation_type_ = reinterpret_cast<const mindspore::lite::Mul *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}

auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
param->activation_type_ = reinterpret_cast<const mindspore::lite::Div *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}

OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
auto eltwise = reinterpret_cast<const mindspore::lite::Eltwise *>(primitive);
switch (eltwise->GetMode()) {
case schema::EltwiseMode_PROD:
param->op_parameter_.type_ = schema::PrimitiveType_Mul;
break;
case schema::EltwiseMode_SUM:
param->op_parameter_.type_ = schema::PrimitiveType_Add;
break;
case schema::EltwiseMode_MAXIMUM:
param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
break;
default:
free(param);
return nullptr;
}
return reinterpret_cast<OpParameter *>(param);
}

Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic);
@@ -51,6 +127,7 @@ Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
@@ -58,5 +135,10 @@ Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithme
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic);
Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic);
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter);
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter);
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter);
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter);
} // namespace lite
} // namespace mindspore

+ 0
- 47
mindspore/lite/src/ops/populate/div_populate.cc View File

@@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/div.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Div *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter);

} // namespace lite
} // namespace mindspore

+ 0
- 52
mindspore/lite/src/ops/populate/eltwise_populate.cc View File

@@ -1,52 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/eltwise.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
auto eltwise = reinterpret_cast<mindspore::lite::Eltwise *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
switch (eltwise->GetMode()) {
case schema::EltwiseMode_PROD:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul;
break;
case schema::EltwiseMode_SUM:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add;
break;
case schema::EltwiseMode_MAXIMUM:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
break;
default:
free(arithmetic_param);
return nullptr;
}
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter);

} // namespace lite
} // namespace mindspore

+ 0
- 48
mindspore/lite/src/ops/populate/mul_populate.cc View File

@@ -1,48 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/mul.h"
#include "nnacl/arithmetic_common.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Mul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter);

} // namespace lite
} // namespace mindspore

+ 0
- 47
mindspore/lite/src/ops/populate/sub_populate.cc View File

@@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/sub.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"

namespace mindspore {
namespace lite {

OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Sub *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter);
} // namespace lite
} // namespace mindspore

+ 14
- 5
mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc View File

@@ -15,6 +15,8 @@
*/

#include "src/runtime/kernel/arm/int8/arithmetic_int8.h"
#include "src/runtime/kernel/arm/int8/add_int8.h"
#include "src/runtime/kernel/arm/int8/mul_int8.h"
#include "nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@@ -27,11 +29,14 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_PARAM_INVALID;

using mindspore::schema::PrimitiveType_Add;
using mindspore::schema::PrimitiveType_Eltwise;
using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_Greater;
using mindspore::schema::PrimitiveType_GreaterEqual;
using mindspore::schema::PrimitiveType_Less;
using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_NotEqual;

namespace mindspore::kernel {
@@ -159,11 +164,15 @@ kernel::LiteKernel *CpuArithmeticInt8KernelCreator(const std::vector<lite::Tenso
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "Input parameter is null!";
return nullptr;
kernel::LiteKernel *kernel = nullptr;
if (desc.type == PrimitiveType_Eltwise && static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Add) {
kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive);
} else if (desc.type == PrimitiveType_Eltwise &&
static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Mul) {
kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
}
auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_;
free(parameter);
@@ -185,5 +194,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Less, CpuArithmeticInt8KernelCre
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LessEqual, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Greater, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GreaterEqual, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Eltwise, CpuArithmeticInt8KernelCreator)
} // namespace mindspore::kernel

Loading…
Cancel
Save