|
|
|
@@ -13,13 +13,9 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/ops/populate/arithmetic_populate.h" |
|
|
|
#include "src/ops/arithmetic.h" |
|
|
|
#include "src/ops/add.h" |
|
|
|
#include "src/ops/sub.h" |
|
|
|
#include "src/ops/mul.h" |
|
|
|
#include "src/ops/div.h" |
|
|
|
#include "src/ops/eltwise.h" |
|
|
|
#include "src/ops/greater_equal.h" |
|
|
|
#include "src/common/log_adapter.h" |
|
|
|
#include "src/tensor.h" |
|
|
|
#include "src/ops/primitive_c.h" |
|
|
|
@@ -27,6 +23,7 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
|
|
|
|
ArithmeticParameter *PopulateArithmeticCommonPara(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); |
|
|
|
if (param == nullptr) { |
|
|
|
@@ -57,70 +54,6 @@ OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
param->activation_type_ = reinterpret_cast<const mindspore::lite::Add *>(primitive)->GetActivationType(); |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
param->activation_type_ = reinterpret_cast<const mindspore::lite::Sub *>(primitive)->GetActivationType(); |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
param->activation_type_ = reinterpret_cast<const mindspore::lite::Mul *>(primitive)->GetActivationType(); |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
param->activation_type_ = reinterpret_cast<const mindspore::lite::Div *>(primitive)->GetActivationType(); |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto eltwise = reinterpret_cast<const mindspore::lite::Eltwise *>(primitive); |
|
|
|
switch (eltwise->GetMode()) { |
|
|
|
case schema::EltwiseMode_PROD: |
|
|
|
param->op_parameter_.type_ = schema::PrimitiveType_Mul; |
|
|
|
break; |
|
|
|
case schema::EltwiseMode_SUM: |
|
|
|
param->op_parameter_.type_ = schema::PrimitiveType_Add; |
|
|
|
break; |
|
|
|
case schema::EltwiseMode_MAXIMUM: |
|
|
|
param->op_parameter_.type_ = schema::PrimitiveType_Maximum; |
|
|
|
break; |
|
|
|
default: |
|
|
|
free(param); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return reinterpret_cast<OpParameter *>(param); |
|
|
|
} |
|
|
|
|
|
|
|
Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic); |
|
|
|
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); |
|
|
|
Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); |
|
|
|
@@ -135,10 +68,5 @@ Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithme |
|
|
|
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); |
|
|
|
Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); |
|
|
|
Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); |
|
|
|
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter); |
|
|
|
Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter); |
|
|
|
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter); |
|
|
|
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter); |
|
|
|
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter); |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |