@@ -35,7 +35,7 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   std::vector<int> a_shape = input0->shape();
   std::vector<int> b_shape = input1->shape();
-  if (a_shape.size() < 3 || b_shape.size() < 3) {
+  if (a_shape.size() < 2 || b_shape.size() < 2) {
     MS_LOG(ERROR) << "inputs shape is invalid";
     return RET_INPUT_TENSOR_ERROR;
   }
@@ -24,24 +24,20 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
   MS_ASSERT(this->primitive != nullptr);
   auto x_tensor = inputs[0];
   MS_ASSERT(x_tensor != nullptr);
-  auto exp_tensor = inputs[1];
-  MS_ASSERT(exp_tensor != nullptr);
+  tensor::Tensor *exp_tensor = nullptr;
+  if (inputs.size() == 2) {
+    exp_tensor = inputs[1];
+    MS_ASSERT(exp_tensor != nullptr);
+  }
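+  // single-input case: the exponent comes from the op's scalar power attribute rather than a second tensor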
   auto output_tensor = outputs[0];
   MS_ASSERT(output_tensor != nullptr);
-  if (inputs.size() < 2) {
-    MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
-    MS_LOG(ERROR) << "Power inputs shape is not equal!";
-    return RET_INPUT_TENSOR_ERROR;
+  if (exp_tensor) {
+    if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
+      MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
+      return RET_INPUT_TENSOR_ERROR;
+    }
   }
-  int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies<int>());
-  if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
-    MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
-    return RET_INPUT_TENSOR_ERROR;
-  }
   output_tensor->SetFormat(x_tensor->GetFormat());
   output_tensor->set_shape(x_tensor->shape());
   output_tensor->set_data_type(x_tensor->data_type());
@@ -69,4 +69,5 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
 }
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulKernelCreator)
 }  // namespace mindspore::kernel
@@ -34,15 +34,15 @@ int MatmulCPUKernel::ReSize() { return RET_OK; }
 int MatmulCPUKernel::Init() {
   int batch = 1;
-  auto x_shape = inputs_[0]->shape();
-  auto o_shape = outputs_[0]->shape();
-  for (int i = 0; i < x_shape.size() - 2; ++i) {
-    batch *= x_shape[i];
+  auto a_shape = inputs_[0]->shape();
+  auto c_shape = outputs_[0]->shape();
+  for (int i = 0; i < a_shape.size() - 2; ++i) {
+    batch *= a_shape[i];
   }
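+  // batch is the product of all leading dimensions; the last two are the matrix row/col dims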
   params_->batch = batch;
-  params_->row_ = o_shape[o_shape.size() - 2];
-  params_->col_ = o_shape[o_shape.size() - 1];
-  params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
+  params_->row_ = c_shape[c_shape.size() - 2];
+  params_->col_ = c_shape[c_shape.size() - 1];
+  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
   thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
@@ -51,15 +51,19 @@ int PowerCPUKernel::Run() {
 int PowerCPUKernel::RunImpl(int task_id) {
   auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
-  auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
   auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
   auto size = inputs_[0]->ElementsNum();
   int stride = UP_DIV(size, thread_count_);
   int len = MSMIN(stride, size - stride * task_id);
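+  // each task handles a contiguous chunk of `stride` elements; the last chunk may be shorter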
-  bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
+  float *exp_addr = nullptr;
+  bool broadcast = true;
+  if (inputs_.size() == 2) {
+    exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
+    broadcast = false;
+  }
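+  // broadcast: every element is raised to the same scalar exponent power_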
   float *cur_exp;
   if (broadcast) {
-    cur_exp = exp_addr;
+    cur_exp = &power_;
   } else {
     cur_exp = exp_addr + stride * task_id;
   }
@@ -73,8 +77,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
                                               const kernel::KernelKey &desc) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Power);
-  auto *kernel =
-    new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PowerCPUKernel fail!";
     return nullptr;
@@ -30,6 +30,7 @@ class PowerCPUKernel : public LiteKernel {
       : LiteKernel(param, inputs, outputs),
         ctx_(ctx),
         thread_count_(ctx->thread_num_),
+        power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
         scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
         shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
   ~PowerCPUKernel() override = default;
@@ -42,6 +43,7 @@ class PowerCPUKernel : public LiteKernel {
  private:
   const lite::Context *ctx_;
   int thread_count_;
+  float power_;
   float scale_;
   float shift_;
 };
@@ -1,9 +1,9 @@
 #ifdef __aarch64__
     .text
    .align 5
-    .global MatMulFloatNeon64
+    .global MatmulFloatNeon64
 #ifndef __APPLE__
-    .type MatMulFloatNeon64, %function
+    .type MatmulFloatNeon64, %function
 #endif
 // A: LM [row_8 * depth] col_8_major
@@ -46,41 +46,39 @@
 // accumulators 8x8 block
 /////////////////////////////////////////////////////////////////////////////////
 //
-// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col)
+// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
 // x0: a
 // x1: b
 // x2: c
 // x3: bias
-// v0.s[0]: maxf
-// v1.s[0]: minf
-// w4: depth
-// w5: row
-// w6: col
+// w4: act_type
+// w5: depth
+// w6: row
+// w7: col
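+// act_type encoding (see the Activation dispatch below): 0 = none, 1 = ReLU, 2 = ReLU6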
-MatMulFloatNeon64:
+MatmulFloatNeon64:
     sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
-    mov w7, v0.s[0]
-    mov w8, v1.s[0]
-    mov w9, 0           // rm col offset
-    mov w10, 0          // lm row offset
+    mov w9, #0          // rm col offset
+    mov w10, #0         // lm row offset
     mov w18, #32        // sizeof(float)*8
-    mul w15, w4, w18    // the stride of lm/rm: sizeof(float)*8*depth
+    mul w15, w5, w18    // the stride of lm/rm: sizeof(float)*8*depth
+    mov x11, x3         // bias flag
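+    // x11 preserves the original bias pointer for the null test at Bias; x14 is advanced by the bias loads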
 L1:
-    cmp w9, w6
+    cmp w9, w7
     beq End1
-    mov w10, 0          // reset lm row offset
+    mov w10, #0         // reset lm row offset
     mov x12, x0         // reload lm ptr
-    mov x14, x3         // reload bias ptr
 L2:
     cmp w10, w6
     beq End2
-    mov w13, w4         // reload depth
+    mov x16, x1         // reload rm ptr
+    mov w13, w5         // reload depth
+    mov x14, x3         // reload bias ptr
     dup v16.4s, wzr
     dup v17.4s, wzr
    dup v18.4s, wzr
@@ -103,7 +101,7 @@ OptLoopMul4:
     blt CommLoopMul
     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v8.4s, v9.4s}, [x1], #32
+    ld1 {v8.4s, v9.4s}, [x16], #32
     fmla v16.4s, v8.4s, v0.s[0]
     fmla v17.4s, v9.4s, v0.s[0]
    fmla v18.4s, v8.4s, v0.s[1]
@@ -112,7 +110,7 @@ OptLoopMul4:
     fmla v21.4s, v9.4s, v0.s[2]
     fmla v22.4s, v8.4s, v0.s[3]
     fmla v23.4s, v9.4s, v0.s[3]
-    ld1 {v10.4s, v11.4s}, [x1], #32
+    ld1 {v10.4s, v11.4s}, [x16], #32
     fmla v24.4s, v8.4s, v1.s[0]
     fmla v25.4s, v9.4s, v1.s[0]
    fmla v26.4s, v8.4s, v1.s[1]
@@ -130,7 +128,7 @@ OptLoopMul4:
     fmla v21.4s, v11.4s, v2.s[2]
     fmla v22.4s, v10.4s, v2.s[3]
     fmla v23.4s, v11.4s, v2.s[3]
-    ld1 {v12.4s, v13.4s}, [x1], #32
+    ld1 {v12.4s, v13.4s}, [x16], #32
     fmla v24.4s, v10.4s, v3.s[0]
     fmla v25.4s, v11.4s, v3.s[0]
    fmla v26.4s, v10.4s, v3.s[1]
@@ -153,7 +151,7 @@ OptLoopMul4:
     fmla v25.4s, v13.4s, v5.s[0]
     fmla v26.4s, v12.4s, v5.s[1]
     fmla v27.4s, v13.4s, v5.s[1]
-    ld1 {v14.4s, v15.4s}, [x1], #32
+    ld1 {v14.4s, v15.4s}, [x16], #32
     fmla v28.4s, v12.4s, v5.s[2]
     fmla v29.4s, v13.4s, v5.s[2]
    fmla v30.4s, v12.4s, v5.s[3]
@@ -182,7 +180,7 @@ CommLoopMul:
     blt Bias
     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v2.4s, v3.4s}, [x1], #32
+    ld1 {v2.4s, v3.4s}, [x16], #32
     fmla v16.4s, v2.4s, v0.s[0]
     fmla v17.4s, v3.4s, v0.s[0]
    fmla v18.4s, v2.4s, v0.s[1]
@@ -203,8 +201,7 @@ CommLoopMul:
     b CommLoopMul
 Bias:
-    cmp x3, #0
-    beq Relu
+    cbz x11, Activation
     ld1 {v0.4s}, [x14], #16
     ld1 {v1.4s}, [x14], #16
    fadd v16.4s, v16.4s, v0.4s
@@ -224,9 +221,34 @@ Bias:
     fadd v30.4s, v30.4s, v0.4s
     fadd v31.4s, v31.4s, v1.4s
+Activation:
+    cmp w4, #2
+    beq Relu6
+    cmp w4, #1
+    beq Relu
+    b TransToOut
+Relu6:
+    mov w8, #6
+    dup v15.4s, w8
+    scvtf v15.4s, v15.4s
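+    // v15 = {6.0f, 6.0f, 6.0f, 6.0f}: the ReLU6 upper bound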
+    fmin v16.4s, v16.4s, v15.4s
+    fmin v17.4s, v17.4s, v15.4s
+    fmin v18.4s, v18.4s, v15.4s
+    fmin v19.4s, v19.4s, v15.4s
+    fmin v20.4s, v20.4s, v15.4s
+    fmin v21.4s, v21.4s, v15.4s
+    fmin v22.4s, v22.4s, v15.4s
+    fmin v23.4s, v23.4s, v15.4s
+    fmin v24.4s, v24.4s, v15.4s
+    fmin v25.4s, v25.4s, v15.4s
+    fmin v26.4s, v26.4s, v15.4s
+    fmin v27.4s, v27.4s, v15.4s
+    fmin v28.4s, v28.4s, v15.4s
+    fmin v29.4s, v29.4s, v15.4s
+    fmin v30.4s, v30.4s, v15.4s
+    fmin v31.4s, v31.4s, v15.4s
 Relu:
-    dup v15.4s, w7
-    dup v14.4s, w8
+    dup v14.4s, wzr
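+    // v14 = 0.0f: the ReLU lower bound; Relu6 falls through here so it gets clamped from below too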
     fmax v16.4s, v16.4s, v14.4s
     fmax v17.4s, v17.4s, v14.4s
    fmax v18.4s, v18.4s, v14.4s
@@ -244,24 +266,6 @@ Relu:
     fmax v30.4s, v30.4s, v14.4s
     fmax v31.4s, v31.4s, v14.4s
-    fmin v16.4s, v16.4s, v15.4s
-    fmin v17.4s, v17.4s, v15.4s
-    fmin v18.4s, v18.4s, v15.4s
-    fmin v19.4s, v19.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v21.4s, v21.4s, v15.4s
-    fmin v22.4s, v22.4s, v15.4s
-    fmin v23.4s, v23.4s, v15.4s
-    fmin v24.4s, v24.4s, v15.4s
-    fmin v25.4s, v25.4s, v15.4s
-    fmin v26.4s, v26.4s, v15.4s
-    fmin v27.4s, v27.4s, v15.4s
-    fmin v28.4s, v28.4s, v15.4s
-    fmin v29.4s, v29.4s, v15.4s
-    fmin v30.4s, v30.4s, v15.4s
-    fmin v31.4s, v31.4s, v15.4s
 TransToOut:
     st1 {v16.4s}, [x2], #16
    st1 {v17.4s}, [x2], #16
@@ -280,11 +284,13 @@ TransToOut:
     st1 {v30.4s}, [x2], #16
     st1 {v31.4s}, [x2], #16
-    add w10, w10, #8    // lhs row offset + 8
+    add w10, w10, #8    // lm row offset + 8
     b L2
 End2:
-    add w9, w9, #8      // rhs col offset + 8
+    add w9, w9, #8      // rm col offset + 8
+    add x1, x1, x15     // rm ptr + stride
+    add x3, x3, x18     // bias ptr + stride
     b L1
 End1:
@@ -42,7 +42,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
       float *dst_c = dst_r + ci * C8NUM;
       /* 8x4 row-major to col-major */
-#ifdef ENABLE_NEON
+#ifdef ENABLE_ARM64
       size_t stride = col * 4;
       asm volatile(
        "mov x10, %[src_c]\n"
@@ -156,6 +156,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActT
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
             int col_8_) {
+#ifdef __aarch64__
+  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
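+  // (int)act_type maps onto the assembly kernel's w4 act_type argument (0 none / 1 ReLU / 2 ReLU6)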
+#else
   MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
   return;
+#endif
 }
@@ -32,8 +32,8 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
 extern "C" {
 #endif
 #ifdef __aarch64__
-void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
-                       int row, int col);
+void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col);
 #endif
 #ifdef __cplusplus
 }
@@ -157,10 +157,10 @@ inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32
 // quantize from float to int8
 inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
-    int r = (int)round(input_data[i] / scale + zero_point);
-    int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r;
+    int q = (int)round(input_data[i] / scale + zero_point);
+    q = q > CHAR_MAX ? CHAR_MAX : q;
     q = q < CHAR_MIN ? CHAR_MIN : q;
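+    // clamping in int before the cast saturates out-of-range values instead of letting them wrap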
-    output_data[i] = q;
+    output_data[i] = (int8_t)q;
   }
 }
@@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
     0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
     -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
     0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
-  std::vector<int> a_shape = {1, 2, 8};
-  std::vector<int> b_shape = {1, 8, 3};
-  std::vector<int> c_shape = {1, 2, 3};
+  std::vector<int> a_shape = {2, 8};
+  std::vector<int> b_shape = {8, 3};
+  std::vector<int> c_shape = {2, 3};
   int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
   auto ctx = new lite::Context;
-  ctx->thread_num_ = 2;
+  ctx->thread_num_ = 1;
   auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   mm->Init();
   mm->Run();
   float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779,
                      -0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
 }
+TEST_F(TestMatMulFp32, simple2) {
+  std::vector<lite::tensor::Tensor *> inputs_;
+  std::vector<lite::tensor::Tensor *> outputs_;
+  auto matmul_param = new MatMulParameter();
+  matmul_param->a_transpose_ = false;
+  matmul_param->b_transpose_ = false;
+  matmul_param->has_bias_ = false;
+  float a[25 * 12] = {
+    1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5, 4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3, 0, 4, 2, 3, 2, 9,
+    8, 9, 5, 4, 4, 9, 7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10, 1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
+    3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6, 1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3, 10, 4, 7, 7, 0, 5,
+    6, 5, 4, 6, 5, 5, 3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7, 6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
+    7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6, 9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6, 4, 10, 9, 8, 10, 5,
+    2, 3, 8, 10, 0, 8, 2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0, 1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
+    2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10, 4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5, 5, 1, 5, 5, 5, 1,
+    8, 2, 6, 9, 10, 4, 9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7, 7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
+    2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5, 4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2, 10, 5, 9, 8, 1, 4,
+    7, 9, 6, 5, 0, 3, 6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4, 5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
+  float b[12 * 36] = {
+    6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10, 9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3,
+    10, 4, 2, 7, 4, 3, 10, 5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8, 5, 2, 1, 7,
+    5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5, 0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10,
+    7, 6, 4, 10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10, 8, 3, 3, 4, 6, 2, 6, 0,
+    4, 0, 3, 4, 1, 0, 3, 10, 5, 4, 0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0, 1,
+    10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3, 3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5,
+    7, 3, 3, 1, 1, 10, 5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2, 7, 10, 9, 7, 10,
+    2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0, 5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9,
+    5, 5, 6, 0, 4, 1, 2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8, 1, 0, 10, 8, 8, 1, 8, 0, 10,
+    3, 1, 7, 0, 10, 5, 0, 2, 8, 4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4, 8, 10,
+    0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1, 1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5,
+    9, 7, 4, 10, 10, 3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2, 0, 8, 7, 6, 6, 6,
+    5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9, 6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2,
+    7, 2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7, 6, 9, 1, 10, 4, 10, 5, 10, 0, 9,
+    5, 1, 6, 2, 9, 9, 8, 8, 10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2};
+  std::vector<int> a_shape = {25, 12};
+  std::vector<int> b_shape = {12, 36};
+  std::vector<int> c_shape = {25, 36};
+  int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
+  auto ctx = new lite::Context;
+  ctx->thread_num_ = 2;
+  auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
+  mm->Init();
+  mm->Run();
+  float correct[] = {
+    263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307, 329, 314, 391, 261, 400,
+    280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393, 226, 285, 134, 209, 264, 163, 281, 212, 219, 171,
+    221, 228, 227, 131, 289, 196, 204, 270, 238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280,
+    142, 229, 256, 270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301, 296, 254,
+    316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292, 278, 388, 199, 324, 252, 336, 385,
+    300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366, 293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262,
+    249, 309, 258, 280, 286, 411, 268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295,
+    250, 241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273, 199, 253, 246, 315,
+    307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296, 300, 217, 274, 196, 285, 264, 351, 339, 312,
+    289, 338, 282, 256, 274, 214, 243, 228, 302, 276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177,
+    405, 236, 294, 293, 264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374, 306,
+    350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359, 335, 286, 333, 263, 327, 275,
+    448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397, 259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220,
+    207, 233, 173, 326, 316, 223, 301, 237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233,
+    189, 237, 246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278, 254, 178, 215,
+    164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308, 298, 409, 306, 359, 359, 273, 388, 291,
+    301, 281, 239, 395, 323, 290, 505, 398, 370, 381, 365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354,
+    365, 284, 309, 338, 469, 285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330,
+    257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357, 328, 481, 196, 345, 276,
+    338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354, 353, 340, 363, 385, 414, 387, 541, 528, 412, 515,
+    486, 298, 320, 438, 254, 361, 454, 494, 120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264,
+    169, 137, 219, 165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 152, 218, 226, 273,
+    205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282, 274, 238, 358, 207, 333, 270, 345,
+    345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338, 319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351,
+    311, 204, 421, 319, 294, 348, 328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381,
+    378, 200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216, 225, 155, 223, 203,
+    203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349, 286, 389, 180, 262, 306, 275, 269, 284, 257,
+    239, 256, 262, 270, 189, 410, 306, 302, 297, 244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355,
+    258, 252, 286, 446, 274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328, 342,
+    247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416, 384, 533, 202, 336, 369, 322,
+    449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363, 355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449,
+    390, 345, 444, 378, 307, 454, 542, 356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323,
+    311, 348, 320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389, 211, 314, 186,
+    248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249, 283, 241, 331, 210, 318, 191, 341, 330,
+    331, 323, 278, 289, 255, 259, 294, 174, 280, 323, 295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302,
+    265, 457, 396, 297, 366, 350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428,
+    383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379, 296, 262, 406, 270, 346,
+    346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493, 286, 348, 185, 240, 287, 214, 312, 265, 237, 218,
+    261, 316, 279, 186, 377, 319, 279, 304, 281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239,
+    243, 319, 346};
+  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
+  delete mm;
+  for (auto t : inputs_) delete t;
+  for (auto t : outputs_) delete t;
+}
@@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) {
   mm->Run();
   float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) {
                      8.869029998779297, 25.034008026123047};
   float *output = reinterpret_cast<float *>(outputs_[0]->Data());
-  for (int i = 0; i < 18; ++i) printf("%f ", output[i]);
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;