Browse Source

!6885 Add broadcast for mul int8 ops

Merge pull request !6885 from liuwenhao/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3537871d2e
3 changed files with 120 additions and 11 deletions
  1. +43
    -10
      mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
  2. +5
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h
  3. +72
    -0
      mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc

+ 43
- 10
mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc View File

@@ -56,10 +56,50 @@ int MulInt8CPUKernel::Init() {
para_.mul_quant_arg_.shift_left_ = right_shift < 0 ? -right_shift : 0;
para_.mul_quant_arg_.shift_right_ = right_shift > 0 ? right_shift : 0;

return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int MulInt8CPUKernel::ReSize() { return RET_OK; }
int MulInt8CPUKernel::ReSize() {
size_t input0_size = in_tensors_.at(0)->shape().size();
size_t input1_size = in_tensors_.at(1)->shape().size();
size_t output_size = out_tensors_.at(0)->shape().size();
tile_para->ndim_ = output_size;
if (input0_size == input1_size) {
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
} else if (input0_size < input1_size) {
auto fill_dim_num = input1_size - input0_size;
int j = 0;
for (size_t i = 0; i < output_size; i++) {
if (i < fill_dim_num) {
tile_para->in_shape0_[i] = 1;
} else {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(j++);
}
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
} else {
auto fill_dim_num = input0_size - input1_size;
int j = 0;
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
if (i < fill_dim_num) {
tile_para->in_shape1_[i] = 1;
} else {
tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(j++);
}
tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
}
return RET_OK;
}

int MulInt8CPUKernel::Run() {
auto ret = Prepare();
@@ -80,15 +120,8 @@ int MulInt8CPUKernel::Run() {
MS_LOG(ERROR) << "malloc input0_data_ || input1_data_ failed.";
return RET_ERROR;
}
ArithmeticParameter tile_para;
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
for (size_t i = 0; i < tile_para.ndim_; i++) {
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
}
TileDimensionsInt8(static_cast<int8_t *>(in_tensors_.at(0)->MutableData()),
static_cast<int8_t *>(in_tensors_.at(1)->MutableData()), input0_data_, input1_data_, &tile_para);
static_cast<int8_t *>(in_tensors_.at(1)->MutableData()), input0_data_, input1_data_, tile_para);
ret = ParallelLaunch(this->context_->thread_pool_, MulInt8Run, this, thread_count_);
ctx_->allocator->Free(input0_data_);
ctx_->allocator->Free(input1_data_);


+ 5
- 1
mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h View File

@@ -19,6 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/mul_parameter.h"
#include "nnacl/arithmetic_common.h"
#include "src/runtime/runtime_api.h"

namespace mindspore::kernel {
@@ -27,7 +28,9 @@ class MulInt8CPUKernel : public LiteKernel {
explicit MulInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
tile_para = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~MulInt8CPUKernel() override{};

int Init() override;
@@ -37,6 +40,7 @@ class MulInt8CPUKernel : public LiteKernel {

private:
const lite::InnerContext *ctx_;
ArithmeticParameter *tile_para;
MulParameter para_;
int thread_count_;
int64_t elements_num_;


+ 72
- 0
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc View File

@@ -313,4 +313,76 @@ TEST_F(TestMulInt8, Mul_quant1_thread1) {
delete output0_tensor;
delete ctx;
}

TEST_F(TestMulInt8, test) {
std::vector<int8_t> input1 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
std::vector<int> shape1 = {2, 2, 3};
std::vector<int8_t> input2 = {1, 2, 3, 4, 5, 6};
std::vector<int> shape2 = {2, 3};
std::vector<int8_t *> input(2, nullptr);
input[0] = input1.data();
input[1] = input2.data();

int8_t output[12];
std::vector<int> output_shape = {2, 2, 3};

lite::QuantArg input_quant_arg;
input_quant_arg.scale = 1.0;
input_quant_arg.zeroPoint = 0;
lite::QuantArg output_quant_arg;
output_quant_arg.scale = 1.0;
output_quant_arg.zeroPoint = 0;

lite::Tensor *input_tensor1 = new lite::Tensor;
TypeId tid_int8 = kNumberTypeInt8;
input_tensor1->SetData(input1.data());
input_tensor1->set_shape(shape1);
input_tensor1->AddQuantParam(input_quant_arg);
input_tensor1->set_data_type(tid_int8);

lite::Tensor *input_tensor2 = new lite::Tensor;
input_tensor2->SetData(input2.data());
input_tensor2->set_shape(shape2);
input_tensor2->AddQuantParam(input_quant_arg);
input_tensor2->set_data_type(tid_int8);

std::vector<lite::Tensor *> inputs_tensor(2);
inputs_tensor[0] = input_tensor1;
inputs_tensor[1] = input_tensor2;

std::vector<lite::Tensor *> outputs_tensor(1);
lite::Tensor *output0_tensor = new lite::Tensor;
output0_tensor->SetData(output);
output0_tensor->set_shape(output_shape);
output0_tensor->AddQuantParam(output_quant_arg);
output0_tensor->set_data_type(tid_int8);
outputs_tensor[0] = output0_tensor;

MulParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Mul;
lite::InnerContext *ctx = new lite::InnerContext;
ctx->thread_num_ = 2;
ASSERT_EQ(lite::RET_OK, ctx->Init());
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Mul};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto output_tensor_shape = output0_tensor->shape();
ASSERT_EQ(output_tensor_shape, output_shape);
kernel->Run();

std::vector<int8_t> except_result = {1, 4, 9, 16, 25, 36, 7, 16, 27, 40, 55, 72};
PrintData("output data", output, input1.size());
CompareOutputData(output, except_result.data(), input1.size(), 0.000001);
input_tensor1->SetData(nullptr);
input_tensor2->SetData(nullptr);
output0_tensor->SetData(nullptr);
delete input_tensor1;
delete input_tensor2;
delete output0_tensor;
delete ctx;
}

} // namespace mindspore

Loading…
Cancel
Save