Browse Source

[MSLITE] support mod op

tags/v1.1.0
ling 5 years ago
parent
commit
f54ec95bca
4 changed files with 61 additions and 5 deletions
  1. +42
    -0
      mindspore/lite/nnacl/fp32/arithmetic_fp32.c
  2. +7
    -0
      mindspore/lite/nnacl/fp32/arithmetic_fp32.h
  3. +11
    -5
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  4. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h

+ 42
- 0
mindspore/lite/nnacl/fp32/arithmetic_fp32.c View File

@@ -850,6 +850,48 @@ int BroadcastFloorMod(const float *input0, const float *input1, float *tile_inpu
return ElementFloorMod(tile_input0, tile_input1, output, element_size);
}

/**
 * Element-wise floating-point remainder.
 *
 * Writes output[i] = fmodf(input0[i], input1[i]) for every i in [0, element_size).
 * As with the C library remainder, a non-zero result has the sign of the dividend
 * input0[i]; a zero divisor produces NaN.
 *
 * @param input0       dividends, at least element_size floats
 * @param input1       divisors, at least element_size floats
 * @param output       destination, at least element_size floats
 * @param element_size number of elements to process
 * @return NNACL_OK (this function has no failure path)
 */
int ElementMod(const float *input0, const float *input1, float *output, const int element_size) {
  for (int i = 0; i < element_size; i++) {
    /* fmodf stays in single precision; fmod would promote both operands to double
     * and narrow the (identical, exactly representable) result back to float. */
    output[i] = fmodf(input0[i], input1[i]);
  }
  return NNACL_OK;
}

/**
 * Element-wise integer remainder.
 *
 * Writes output[i] = input0[i] % input1[i] (truncated remainder, sign follows the
 * dividend) for every i in [0, element_size).
 *
 * Two divisor values are handled explicitly so the result is always defined:
 *   - divisor == 0  : the remainder is mathematically undefined; 0 is written.
 *   - divisor == -1 : the remainder is always 0; computing INT_MIN % -1 directly
 *                     would overflow, so 0 is written without evaluating it.
 *
 * @param input0       dividends, at least element_size ints
 * @param input1       divisors, at least element_size ints
 * @param output       destination, at least element_size ints
 * @param element_size number of elements to process
 * @return NNACL_OK (this function has no failure path)
 */
int ElementModInt(const int *input0, const int *input1, int *output, const int element_size) {
  for (int i = 0; i < element_size; i++) {
    const int divisor = input1[i];
    /* Pure integer arithmetic: the previous fmod() round trip through double turned a
     * zero divisor into NaN, and converting NaN back to int is undefined behavior. */
    output[i] = (divisor == 0 || divisor == -1) ? 0 : input0[i] % divisor;
  }
  return NNACL_OK;
}

/**
 * Floating-point remainder where one operand is a single value.
 *
 * If param->in_elements_num0_ == 1, input0[0] is the dividend for every element:
 *   output[i] = fmodf(input0[0], input1[i]).
 * Otherwise input1[0] is the divisor for every element:
 *   output[i] = fmodf(input0[i], input1[0]).
 *
 * @param input0       dividends (only element 0 is read in the first case)
 * @param input1       divisors (only element 0 is read in the second case)
 * @param output       destination, at least element_size floats
 * @param element_size number of output elements to produce
 * @param param        arithmetic parameters; only in_elements_num0_ is read here
 * @return NNACL_OK (this function has no failure path)
 */
int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size,
                  const ArithmeticParameter *param) {
  /* fmodf keeps the computation in single precision instead of promoting to double. */
  if (param->in_elements_num0_ == 1) {
    for (int index = 0; index < element_size; index++) {
      output[index] = fmodf(input0[0], input1[index]);
    }
  } else {
    for (int index = 0; index < element_size; index++) {
      output[index] = fmodf(input0[index], input1[0]);
    }
  }
  return NNACL_OK;
}

/* Truncated integer remainder that is defined for every divisor: a zero divisor and
 * a divisor of -1 (where INT_MIN % -1 would overflow) both yield 0. */
static int OptModIntRemainder(int dividend, int divisor) {
  return (divisor == 0 || divisor == -1) ? 0 : dividend % divisor;
}

/**
 * Integer remainder where one operand is a single value.
 *
 * If param->in_elements_num0_ == 1, input0[0] is the dividend for every element:
 *   output[i] = input0[0] % input1[i].
 * Otherwise input1[0] is the divisor for every element:
 *   output[i] = input0[i] % input1[0].
 *
 * The remainder is computed with integer arithmetic (sign follows the dividend).
 * A divisor of 0 or -1 produces 0; the earlier fmod()-based code turned a zero
 * divisor into NaN, whose conversion back to int is undefined behavior.
 *
 * @param input0       dividends (only element 0 is read in the first case)
 * @param input1       divisors (only element 0 is read in the second case)
 * @param output       destination, at least element_size ints
 * @param element_size number of output elements to produce
 * @param param        arithmetic parameters; only in_elements_num0_ is read here
 * @return NNACL_OK (this function has no failure path)
 */
int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size,
                     const ArithmeticParameter *param) {
  if (param->in_elements_num0_ == 1) {
    for (int index = 0; index < element_size; index++) {
      output[index] = OptModIntRemainder(input0[0], input1[index]);
    }
  } else {
    for (int index = 0; index < element_size; index++) {
      output[index] = OptModIntRemainder(input0[index], input1[0]);
    }
  }
  return NNACL_OK;
}

int ElementFloorDiv(const float *input0, const float *input1, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = floorf(input0[i] / input1[i]);


+ 7
- 0
mindspore/lite/nnacl/fp32/arithmetic_fp32.h View File

@@ -118,6 +118,13 @@ int ElementFloorModInt(const int *input0, const int *input1, int *output, const
int BroadcastFloorMod(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
int element_size, ArithmeticParameter *param);

/* Mod: element-wise remainder of input0[i] divided by input1[i], written to output[i]
 * for element_size elements. Float and int variants. */
int ElementMod(const float *input0, const float *input1, float *output, const int element_size);
int ElementModInt(const int *input0, const int *input1, int *output, const int element_size);
/* Mod with one single-value operand: when param->in_elements_num0_ == 1, input0[0] is
 * paired with every element of input1; otherwise input1[0] is paired with every
 * element of input0. */
int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param);
int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param);

int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size);
int BroadcastSquaredDifference(const float *input0, const float *input1, float *tile_input0, float *tile_input1,
float *output, int element_size, ArithmeticParameter *param);


+ 11
- 5
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -209,6 +209,10 @@ void ArithmeticCPUKernel::InitRunFunction() {
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_Mod:
arithmetic_run_ = ElementMod;
arithmetic_run_int_ = ElementModInt;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
@@ -302,6 +306,11 @@ void ArithmeticCPUKernel::InitOptRunFunction() {
break;
}
break;
case PrimitiveType_Mod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMod;
arithmetic_opt_run_int_ = ElementOptModInt;
break;
default:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
@@ -534,16 +543,15 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::Tenso
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
// Mod is registered for float32 and int32 data types, both through CpuArithmeticFp32KernelCreator.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
@@ -551,8 +559,6 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32Ker
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h View File

@@ -35,6 +35,7 @@ using mindspore::schema::PrimitiveType_LogicalAnd;
using mindspore::schema::PrimitiveType_LogicalOr;
using mindspore::schema::PrimitiveType_Maximum;
using mindspore::schema::PrimitiveType_Minimum;
using mindspore::schema::PrimitiveType_Mod;
using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_NotEqual;
using mindspore::schema::PrimitiveType_RealDiv;


Loading…
Cancel
Save