
Fix matmul asm bugs

tags/v0.7.0-beta
zhanyuan committed 5 years ago · commit b99d8590a1
11 changed files with 187 additions and 90 deletions
  1. mindspore/lite/src/ops/matmul.cc  +1 -1
  2. mindspore/lite/src/ops/power.cc  +10 -14
  3. mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc  +1 -0
  4. mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc  +7 -7
  5. mindspore/lite/src/runtime/kernel/arm/fp32/power.cc  +8 -5
  6. mindspore/lite/src/runtime/kernel/arm/fp32/power.h  +2 -0
  7. mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s  +54 -48
  8. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc  +5 -2
  9. mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h  +2 -2
  10. mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h  +3 -3
  11. mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc  +94 -8

mindspore/lite/src/ops/matmul.cc  (+1 -1)

@@ -35,7 +35,7 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor


   std::vector<int> a_shape = input0->shape();
   std::vector<int> b_shape = input1->shape();
-  if (a_shape.size() < 3 || b_shape.size() < 3) {
+  if (a_shape.size() < 2 || b_shape.size() < 2) {
     MS_LOG(ERROR) << "inputs shape is invalid";
     return RET_INPUT_TENSOR_ERROR;
   }
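
The relaxed check admits plain 2-D operands (the reworked "simple" unit test below feeds {2, 8} x {8, 3}). As a rough sketch, assuming the usual convention that only the last two dimensions take part in the multiplication and the leading dimensions of A act as batch dimensions, the inferred output shape behaves like this (illustration only, not the InferShape body):

    #include <vector>

    std::vector<int> InferMatMulShape(const std::vector<int> &a, const std::vector<int> &b,
                                      bool a_transpose, bool b_transpose) {
      // assumes a.size() >= 2 && b.size() >= 2, which is what the patched check enforces
      int row = a_transpose ? a[a.size() - 1] : a[a.size() - 2];
      int col = b_transpose ? b[b.size() - 2] : b[b.size() - 1];
      std::vector<int> c(a.begin(), a.end() - 2);  // leading (batch) dims, empty for 2-D inputs
      c.push_back(row);
      c.push_back(col);
      return c;  // {2, 8} x {8, 3} -> {2, 3}
    }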


mindspore/lite/src/ops/power.cc  (+10 -14)

@@ -24,24 +24,20 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
   MS_ASSERT(this->primitive != nullptr);
   auto x_tensor = inputs[0];
   MS_ASSERT(x_tensor != nullptr);
-  auto exp_tensor = inputs[1];
-  MS_ASSERT(exp_tensor != nullptr);
+  tensor::Tensor *exp_tensor = nullptr;
+  if (inputs.size() == 2) {
+    exp_tensor = inputs[1];
+    MS_ASSERT(exp_tensor != nullptr);
+  }
   auto output_tensor = outputs[0];
   MS_ASSERT(output_tensor != nullptr);
-  if (inputs.size() < 2) {
-    MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
-    MS_LOG(ERROR) << "Power inputs shape is not equal!";
-    return RET_INPUT_TENSOR_ERROR;
+  if (exp_tensor) {
+    if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
+      MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
+      return RET_INPUT_TENSOR_ERROR;
+    }
   }
 
-  int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies<int>());
-  if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
-    MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
-    return RET_INPUT_TENSOR_ERROR;
-  }
   output_tensor->SetFormat(x_tensor->GetFormat());
   output_tensor->set_shape(x_tensor->shape());
   output_tensor->set_data_type(x_tensor->data_type());
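
After this change the exponent input is optional: with one input the scalar power_ attribute supplies the exponent (see the fp32 kernel further down), and with two inputs the exponent tensor must match the base tensor in both shape and data type; the old size-1 exponent special case is gone. A hedged summary of the accepted forms, with hypothetical helper names:

    #include <vector>

    // Illustration only; the real check is the InferShape code above.
    bool PowerInputsValid(const std::vector<int> &x_shape, int x_dtype,
                          const std::vector<int> *exp_shape, const int *exp_dtype) {
      if (exp_shape == nullptr) return true;                  // one input: exponent comes from power_
      return *exp_shape == x_shape && *exp_dtype == x_dtype;  // two inputs: exact match required
    }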


mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc  (+1 -0)

@@ -69,4 +69,5 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulKernelCreator)
 }  // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc  (+7 -7)

@@ -34,15 +34,15 @@ int MatmulCPUKernel::ReSize() { return RET_OK; }


 int MatmulCPUKernel::Init() {
   int batch = 1;
-  auto x_shape = inputs_[0]->shape();
-  auto o_shape = outputs_[0]->shape();
-  for (int i = 0; i < x_shape.size() - 2; ++i) {
-    batch *= x_shape[i];
+  auto a_shape = inputs_[0]->shape();
+  auto c_shape = outputs_[0]->shape();
+  for (int i = 0; i < a_shape.size() - 2; ++i) {
+    batch *= a_shape[i];
   }
   params_->batch = batch;
-  params_->row_ = o_shape[o_shape.size() - 2];
-  params_->col_ = o_shape[o_shape.size() - 1];
-  params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
+  params_->row_ = c_shape[c_shape.size() - 2];
+  params_->col_ = c_shape[c_shape.size() - 1];
+  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
   thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
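
A worked example of the renamed bookkeeping, matching the new 2-D-friendly shapes (illustration only; it mirrors the Init() arithmetic above):

    #include <vector>

    struct MatmulDims { int batch, row, col, deep; };

    MatmulDims ComputeDims(const std::vector<int> &a_shape, const std::vector<int> &c_shape,
                           bool a_transpose) {
      MatmulDims d{1, 0, 0, 0};
      for (size_t i = 0; i + 2 < a_shape.size(); ++i) d.batch *= a_shape[i];  // leading dims of A
      d.row = c_shape[c_shape.size() - 2];
      d.col = c_shape[c_shape.size() - 1];
      d.deep = a_transpose ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
      return d;  // a = {3, 2, 8}, c = {3, 2, 3} -> batch 3, row 2, col 3, deep 8
    }

row_8_ and col_8_ then round row and col up to the next multiple of 8 (both become 8 in this example), which is the 8x8 tile size the arm64 kernel works in.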


mindspore/lite/src/runtime/kernel/arm/fp32/power.cc  (+8 -5)

@@ -51,15 +51,19 @@ int PowerCPUKernel::Run() {


 int PowerCPUKernel::RunImpl(int task_id) {
   auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
-  auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
   auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
   auto size = inputs_[0]->ElementsNum();
   int stride = UP_DIV(size, thread_count_);
   int len = MSMIN(stride, size - stride * task_id);
-  bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
+  float *exp_addr = nullptr;
+  bool broadcast = true;
+  if (inputs_.size() == 2) {
+    exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
+    broadcast = false;
+  }
   float *cur_exp;
   if (broadcast) {
-    cur_exp = exp_addr;
+    cur_exp = &power_;
   } else {
     cur_exp = exp_addr + stride * task_id;
   }
@@ -73,8 +77,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
                                               const kernel::KernelKey &desc) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Power);
-  auto *kernel =
-      new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PowerCPUKernel fail!";
     return nullptr;
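
With a single input the kernel now broadcasts the scalar power_ attribute instead of dereferencing a missing second tensor. For reference, a scalar sketch of the per-element computation, assuming the conventional Power-layer formula out = (scale * x + shift) ^ exponent (the actual arithmetic lives in the nnacl Power op, which is not part of this diff):

    #include <cmath>

    // Hedged reference loop, not the optimized kernel.
    void PowerRef(const float *x, float *out, int len, const float *exp_tensor,
                  float power, float scale, float shift) {
      for (int i = 0; i < len; ++i) {
        float e = (exp_tensor != nullptr) ? exp_tensor[i] : power;  // broadcast power_ when there is no exponent input
        out[i] = std::pow(scale * x[i] + shift, e);
      }
    }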


mindspore/lite/src/runtime/kernel/arm/fp32/power.h  (+2 -0)

@@ -30,6 +30,7 @@ class PowerCPUKernel : public LiteKernel {
       : LiteKernel(param, inputs, outputs),
         ctx_(ctx),
         thread_count_(ctx->thread_num_),
+        power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
         scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
         shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
   ~PowerCPUKernel() override = default;
@@ -42,6 +43,7 @@ class PowerCPUKernel : public LiteKernel {
  private:
   const lite::Context *ctx_;
   int thread_count_;
+  float power_;
   float scale_;
   float shift_;
 };


mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s  (+54 -48)

@@ -1,9 +1,9 @@
 #ifdef __aarch64__
 .text
 .align 5
-.global MatMulFloatNeon64
+.global MatmulFloatNeon64
 #ifndef __APPLE__
-.type MatMulFloatNeon64, %function
+.type MatmulFloatNeon64, %function
 #endif
 
 // A: LM [row_8 * depth] col_8_major
@@ -46,41 +46,39 @@
 // accumulators 8x8 block
 /////////////////////////////////////////////////////////////////////////////////
 //
-// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col)
+// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
 // x0: a
 // x1: b
 // x2: c
 // x3: bias
-// v0.s[0]: maxf
-// v1.s[0]: minf
-// w4: depth
-// w5: row
-// w6: col
+// w4: act_type
+// w5: depth
+// w6: row
+// w7: col
 
-MatMulFloatNeon64:
+MatmulFloatNeon64:
   sub sp, sp, #128
   st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
   st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
 
-  mov w7, v0.s[0]
-  mov w8, v1.s[0]
-  mov w9, 0    // rm col offset
-  mov w10, 0   // lm row offset
+  mov w9, #0   // rm col offset
+  mov w10, #0  // lm row offset
   mov w18, #32 // sizeof(float)*8
-  mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth
+  mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
+  mov x11, x3  // bias flag
 L1:
-  cmp w9, w6
+  cmp w9, w7
   beq End1
 
-  mov w10, 0  // reset lm row offset
+  mov w10, #0 // reset lm row offset
   mov x12, x0 // reload lm ptr
-  mov x14, x3 // reload bias ptr
 L2:
   cmp w10, w6
   beq End2
 
-  mov w13, w4 // reload depth
+  mov x16, x1 // reload rm ptr
+  mov w13, w5 // reload depth
+  mov x14, x3 // reload bias ptr
   dup v16.4s, wzr
   dup v17.4s, wzr
   dup v18.4s, wzr
@@ -103,7 +101,7 @@ OptLoopMul4:
   blt CommLoopMul
 
   ld1 {v0.4s, v1.4s}, [x12], #32
-  ld1 {v8.4s, v9.4s}, [x1], #32
+  ld1 {v8.4s, v9.4s}, [x16], #32
   fmla v16.4s, v8.4s, v0.s[0]
   fmla v17.4s, v9.4s, v0.s[0]
   fmla v18.4s, v8.4s, v0.s[1]
@@ -112,7 +110,7 @@ OptLoopMul4:
   fmla v21.4s, v9.4s, v0.s[2]
   fmla v22.4s, v8.4s, v0.s[3]
   fmla v23.4s, v9.4s, v0.s[3]
-  ld1 {v10.4s, v11.4s}, [x1], #32
+  ld1 {v10.4s, v11.4s}, [x16], #32
   fmla v24.4s, v8.4s, v1.s[0]
   fmla v25.4s, v9.4s, v1.s[0]
   fmla v26.4s, v8.4s, v1.s[1]
@@ -130,7 +128,7 @@ OptLoopMul4:
   fmla v21.4s, v11.4s, v2.s[2]
   fmla v22.4s, v10.4s, v2.s[3]
   fmla v23.4s, v11.4s, v2.s[3]
-  ld1 {v12.4s, v13.4s}, [x1], #32
+  ld1 {v12.4s, v13.4s}, [x16], #32
   fmla v24.4s, v10.4s, v3.s[0]
   fmla v25.4s, v11.4s, v3.s[0]
   fmla v26.4s, v10.4s, v3.s[1]
@@ -153,7 +151,7 @@ OptLoopMul4:
   fmla v25.4s, v13.4s, v5.s[0]
   fmla v26.4s, v12.4s, v5.s[1]
   fmla v27.4s, v13.4s, v5.s[1]
-  ld1 {v14.4s, v15.4s}, [x1], #32
+  ld1 {v14.4s, v15.4s}, [x16], #32
   fmla v28.4s, v12.4s, v5.s[2]
   fmla v29.4s, v13.4s, v5.s[2]
   fmla v30.4s, v12.4s, v5.s[3]
@@ -182,7 +180,7 @@ CommLoopMul:
   blt Bias
 
   ld1 {v0.4s, v1.4s}, [x12], #32
-  ld1 {v2.4s, v3.4s}, [x1], #32
+  ld1 {v2.4s, v3.4s}, [x16], #32
   fmla v16.4s, v2.4s, v0.s[0]
   fmla v17.4s, v3.4s, v0.s[0]
   fmla v18.4s, v2.4s, v0.s[1]
@@ -203,8 +201,7 @@ CommLoopMul:
   b CommLoopMul
 
 Bias:
-  cmp x3, #0
-  beq Relu
+  cbz x11, Activation
   ld1 {v0.4s}, [x14], #16
   ld1 {v1.4s}, [x14], #16
   fadd v16.4s, v16.4s, v0.4s
@@ -224,9 +221,34 @@ Bias:
   fadd v30.4s, v30.4s, v0.4s
   fadd v31.4s, v31.4s, v1.4s
 
+Activation:
+  cmp w4, #2
+  beq Relu6
+  cmp w4, #1
+  beq Relu
+  b TransToOut
+Relu6:
+  mov w8, #6
+  dup v15.4s, w8
+  scvtf v15.4s, v15.4s
+  fmin v16.4s, v16.4s, v15.4s
+  fmin v17.4s, v17.4s, v15.4s
+  fmin v18.4s, v18.4s, v15.4s
+  fmin v19.4s, v19.4s, v15.4s
+  fmin v20.4s, v20.4s, v15.4s
+  fmin v21.4s, v21.4s, v15.4s
+  fmin v22.4s, v22.4s, v15.4s
+  fmin v23.4s, v23.4s, v15.4s
+  fmin v24.4s, v24.4s, v15.4s
+  fmin v25.4s, v25.4s, v15.4s
+  fmin v26.4s, v26.4s, v15.4s
+  fmin v27.4s, v27.4s, v15.4s
+  fmin v28.4s, v28.4s, v15.4s
+  fmin v29.4s, v29.4s, v15.4s
+  fmin v30.4s, v30.4s, v15.4s
+  fmin v31.4s, v31.4s, v15.4s
 Relu:
-  dup v15.4s, w7
-  dup v14.4s, w8
+  dup v14.4s, wzr
   fmax v16.4s, v16.4s, v14.4s
   fmax v17.4s, v17.4s, v14.4s
   fmax v18.4s, v18.4s, v14.4s
@@ -244,24 +266,6 @@ Relu:
   fmax v30.4s, v30.4s, v14.4s
   fmax v31.4s, v31.4s, v14.4s
 
-  fmin v16.4s, v16.4s, v15.4s
-  fmin v17.4s, v17.4s, v15.4s
-  fmin v18.4s, v18.4s, v15.4s
-  fmin v19.4s, v19.4s, v15.4s
-  fmin v20.4s, v20.4s, v15.4s
-  fmin v20.4s, v20.4s, v15.4s
-  fmin v21.4s, v21.4s, v15.4s
-  fmin v22.4s, v22.4s, v15.4s
-  fmin v23.4s, v23.4s, v15.4s
-  fmin v24.4s, v24.4s, v15.4s
-  fmin v25.4s, v25.4s, v15.4s
-  fmin v26.4s, v26.4s, v15.4s
-  fmin v27.4s, v27.4s, v15.4s
-  fmin v28.4s, v28.4s, v15.4s
-  fmin v29.4s, v29.4s, v15.4s
-  fmin v30.4s, v30.4s, v15.4s
-  fmin v31.4s, v31.4s, v15.4s
-
 TransToOut:
   st1 {v16.4s}, [x2], #16
   st1 {v17.4s}, [x2], #16
@@ -280,11 +284,13 @@ TransToOut:
   st1 {v30.4s}, [x2], #16
   st1 {v31.4s}, [x2], #16
 
-  add w10, w10, #8 // lhs row offset + 8
+  add w10, w10, #8 // lm row offset + 8
   b L2
 
 End2:
-  add w9, w9, #8  // rhs col offset + 8
+  add w9, w9, #8  // rm col offset + 8
+  add x1, x1, x15 // rm ptr + stride
+  add x3, x3, x18 // bias ptr + stride
   b L1
 
 End1:
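
The act_type operand replaces the old maxf/minf clamp registers. In C terms the new epilogue behaves roughly as follows: 1 means ReLU, 2 means ReLU6, and any other value (no activation) leaves the 8x8 block untouched, matching the cmp w4, #2 / cmp w4, #1 branches above. This is an illustration, not the nnacl code:

    #include <algorithm>

    // Reference for the Activation/Relu6/Relu epilogue above.
    void ApplyActType(float *block, int count, int act_type) {
      for (int i = 0; i < count; ++i) {
        if (act_type == 2) block[i] = std::min(block[i], 6.0f);                   // ReLU6 upper bound
        if (act_type == 1 || act_type == 2) block[i] = std::max(block[i], 0.0f);  // shared ReLU lower bound
      }
    }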


mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc  (+5 -2)

@@ -42,7 +42,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
       float *dst_c = dst_r + ci * C8NUM;
 
       /* 8x4 row-major to col-major */
-#ifdef ENABLE_NEON
+#ifdef ENABLE_ARM64
       size_t stride = col * 4;
       asm volatile(
         "mov x10, %[src_c]\n"
@@ -156,6 +156,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActT


 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
             int col_8_) {
+#ifdef __aarch64__
+  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
+#else
   MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
-  return;
+#endif
 }

mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h  (+2 -2)

@@ -32,8 +32,8 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
extern "C" { extern "C" {
#endif #endif
#ifdef __aarch64__ #ifdef __aarch64__
void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
int row, int col);
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col);
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
} }


mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h  (+3 -3)

@@ -157,10 +157,10 @@ inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32
 // quantize from float to int8
 inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
-    int r = (int)round(input_data[i] / scale + zero_point);
-    int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r;
+    int q = (int)round(input_data[i] / scale + zero_point);
+    q = q > CHAR_MAX ? CHAR_MAX : q;
     q = q < CHAR_MIN ? CHAR_MIN : q;
-    output_data[i] = q;
+    output_data[i] = (int8_t)q;
   }
 }
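
The point of the change is to clamp in int range before narrowing, so out-of-range values saturate instead of wrapping. With the old code on a signed-char platform, a rounded value of -300 was cast straight to int8_t (typically wrapping to -44) and then slipped past the CHAR_MIN check; clamping first saturates it to CHAR_MIN. A minimal illustration of the fixed ordering:

    #include <climits>
    #include <cmath>
    #include <cstdint>

    // Same clamping order as the fixed Quantize() above.
    int8_t QuantizeOne(float value, float scale, int zero_point) {
      int q = static_cast<int>(std::round(value / scale + zero_point));
      q = q > CHAR_MAX ? CHAR_MAX : q;
      q = q < CHAR_MIN ? CHAR_MIN : q;
      return static_cast<int8_t>(q);
    }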




mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc  (+94 -8)

@@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
   0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
   -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
   0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
-  std::vector<int> a_shape = {1, 2, 8};
-  std::vector<int> b_shape = {1, 8, 3};
-  std::vector<int> c_shape = {1, 2, 3};
+  std::vector<int> a_shape = {2, 8};
+  std::vector<int> b_shape = {8, 3};
+  std::vector<int> c_shape = {2, 3};
   int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
   auto ctx = new lite::Context;
-  ctx->thread_num_ = 2;
+  ctx->thread_num_ = 1;
   auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   mm->Init();
   mm->Run();
   float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779,
                      -0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete matmul_param;
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
}

TEST_F(TestMatMulFp32, simple2) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
matmul_param->a_transpose_ = false;
matmul_param->b_transpose_ = false;
matmul_param->has_bias_ = false;
float a[25 * 12] = {
1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5, 4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3, 0, 4, 2, 3, 2, 9,
8, 9, 5, 4, 4, 9, 7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10, 1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6, 1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3, 10, 4, 7, 7, 0, 5,
6, 5, 4, 6, 5, 5, 3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7, 6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6, 9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6, 4, 10, 9, 8, 10, 5,
2, 3, 8, 10, 0, 8, 2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0, 1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10, 4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5, 5, 1, 5, 5, 5, 1,
8, 2, 6, 9, 10, 4, 9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7, 7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5, 4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2, 10, 5, 9, 8, 1, 4,
7, 9, 6, 5, 0, 3, 6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4, 5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
float b[12 * 36] = {
6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10, 9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3,
10, 4, 2, 7, 4, 3, 10, 5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8, 5, 2, 1, 7,
5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5, 0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10,
7, 6, 4, 10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10, 8, 3, 3, 4, 6, 2, 6, 0,
4, 0, 3, 4, 1, 0, 3, 10, 5, 4, 0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0, 1,
10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3, 3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5,
7, 3, 3, 1, 1, 10, 5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2, 7, 10, 9, 7, 10,
2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0, 5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9,
5, 5, 6, 0, 4, 1, 2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8, 1, 0, 10, 8, 8, 1, 8, 0, 10,
3, 1, 7, 0, 10, 5, 0, 2, 8, 4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4, 8, 10,
0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1, 1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5,
9, 7, 4, 10, 10, 3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2, 0, 8, 7, 6, 6, 6,
5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9, 6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2,
7, 2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7, 6, 9, 1, 10, 4, 10, 5, 10, 0, 9,
5, 1, 6, 2, 9, 9, 8, 8, 10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2};
std::vector<int> a_shape = {25, 12};
std::vector<int> b_shape = {12, 36};
std::vector<int> c_shape = {25, 36};
int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
mm->Init();
mm->Run();
float correct[] = {
263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307, 329, 314, 391, 261, 400,
280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393, 226, 285, 134, 209, 264, 163, 281, 212, 219, 171,
221, 228, 227, 131, 289, 196, 204, 270, 238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280,
142, 229, 256, 270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301, 296, 254,
316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292, 278, 388, 199, 324, 252, 336, 385,
300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366, 293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262,
249, 309, 258, 280, 286, 411, 268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295,
250, 241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273, 199, 253, 246, 315,
307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296, 300, 217, 274, 196, 285, 264, 351, 339, 312,
289, 338, 282, 256, 274, 214, 243, 228, 302, 276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177,
405, 236, 294, 293, 264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374, 306,
350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359, 335, 286, 333, 263, 327, 275,
448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397, 259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220,
207, 233, 173, 326, 316, 223, 301, 237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233,
189, 237, 246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278, 254, 178, 215,
164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308, 298, 409, 306, 359, 359, 273, 388, 291,
301, 281, 239, 395, 323, 290, 505, 398, 370, 381, 365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354,
365, 284, 309, 338, 469, 285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330,
257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357, 328, 481, 196, 345, 276,
338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354, 353, 340, 363, 385, 414, 387, 541, 528, 412, 515,
486, 298, 320, 438, 254, 361, 454, 494, 120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264,
169, 137, 219, 165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 152, 218, 226, 273,
205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282, 274, 238, 358, 207, 333, 270, 345,
345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338, 319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351,
311, 204, 421, 319, 294, 348, 328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381,
378, 200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216, 225, 155, 223, 203,
203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349, 286, 389, 180, 262, 306, 275, 269, 284, 257,
239, 256, 262, 270, 189, 410, 306, 302, 297, 244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355,
258, 252, 286, 446, 274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328, 342,
247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416, 384, 533, 202, 336, 369, 322,
449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363, 355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449,
390, 345, 444, 378, 307, 454, 542, 356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323,
311, 348, 320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389, 211, 314, 186,
248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249, 283, 241, 331, 210, 318, 191, 341, 330,
331, 323, 278, 289, 255, 259, 294, 174, 280, 323, 295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302,
265, 457, 396, 297, 366, 350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428,
383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379, 296, 262, 406, 270, 346,
346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493, 286, 348, 185, 240, 287, 214, 312, 265, 237, 218,
261, 316, 279, 186, 377, 319, 279, 304, 281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239,
243, 319, 346};
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) {
   mm->Run();
   float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) {
   8.869029998779297, 25.034008026123047};
 
   float *output = reinterpret_cast<float *>(outputs_[0]->Data());
-  for (int i = 0; i < 18; ++i) printf("%f ", output[i]);
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;

