@@ -725,6 +725,24 @@ int ElementSub(const float *input0, const float *input1, float *output, const in
   return NNACL_OK;
 }
 
+int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
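+  // NEON fast path: subtract four int32 lanes per iteration.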
+  for (; index <= element_size - 4; index += C4NUM) {
+    int32x4_t vin0 = vld1q_s32(input0 + index);
+    int32x4_t vin1 = vld1q_s32(input1 + index);
+    int32x4_t vout = vsubq_s32(vin0, vin1);
+    vst1q_s32(output + index, vout);
+  }
+#endif
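+  // Scalar tail: covers the remaining elements, and the whole range when NEON is disabled.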
+  for (; index < element_size; index++) {
+    output[index] = input0[index] - input1[index];
+  }
+  return NNACL_OK;
+}
+
 int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size) {
   int index = 0;
 #ifdef ENABLE_NEON
@@ -77,6 +77,7 @@ int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t
                      int element_size, ArithmeticParameter *param);
 int ElementSub(const float *input0, const float *input1, float *output, const int element_size);
+int ElementSubInt(const int *input0, const int *input1, int *output, const int element_size);
 int ElementSubRelu(const float *input0, const float *input1, float *output, const int element_size);
 int ElementSubRelu6(const float *input0, const float *input1, float *output, const int element_size);
 int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
@@ -363,6 +363,9 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC
 REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
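+// Register Sub under both integer dtype aliases, matching the Add registrations above.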
+REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
@@ -97,6 +97,8 @@ class ArithmeticCPUKernel : public LiteKernel {
         break;
       default:
         arithmetic_run_ = ElementSub;
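+        // Int32 counterpart of arithmetic_run_ for integer input tensors.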
+        arithmetic_run_int_ = ElementSubInt;
         break;
     }
     break;
@@ -285,9 +285,9 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
     }
   }
   // update nodes indexes
-  for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) {
+  for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
     // update nodes input indexes
-    UpdateNodeIndex((*nodeIter).get(), deleteIdx);
+    UpdateNodeIndex((*node_iter).get(), deleteIdx);
   }
   // update deleteTensorIdx
   for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
@@ -374,10 +374,10 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla
     MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
     return graphT->nodes.end();
   }
-  auto nodeIter = graphT->nodes.begin() + existNodeIdx;
-  MS_ASSERT(nodeIter != graphT->nodes.begin());
-  MS_ASSERT((*nodeIter) != nullptr);
-  return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode);
+  auto node_iter = graphT->nodes.begin() + existNodeIdx;
+  MS_ASSERT(node_iter != graphT->nodes.begin());
+  MS_ASSERT((*node_iter) != nullptr);
+  return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode);
 }
 
 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
@@ -131,33 +131,33 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   const auto &onnx_conv_weight = onnx_node.input(1);
   if (onnx_node.op_type() == "Conv") {
-    auto nodeIter =
+    auto node_iter =
       std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                    [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
-    if (nodeIter == onnx_graph.initializer().end()) {
+    if (node_iter == onnx_graph.initializer().end()) {
       MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
     } else {
       std::vector<int> weight_shape;
-      auto size = (*nodeIter).dims_size();
+      auto size = (*node_iter).dims_size();
       weight_shape.reserve(size);
       for (int i = 0; i < size; ++i) {
-        weight_shape.emplace_back((*nodeIter).dims(i));
+        weight_shape.emplace_back((*node_iter).dims(i));
       }
       attr->channelOut = weight_shape[0];
       attr->channelIn = weight_shape[1] * attr->group;
     }
   } else {
-    auto nodeIter =
+    auto node_iter =
       std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
                    [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; });
-    if (nodeIter == onnx_graph.node().end()) {
+    if (node_iter == onnx_graph.node().end()) {
       MS_LOG(ERROR) << "can not find node: " << onnx_conv_weight;
       return RET_ERROR;
     }
     std::vector<int> dims;
-    auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(),
+    auto iter = std::find_if((*node_iter).attribute().begin(), (*node_iter).attribute().end(),
                              [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
-    if (iter != (*nodeIter).attribute().end()) {
+    if (iter != (*node_iter).attribute().end()) {
       if (iter->ints().begin() == nullptr || iter->ints().end() == nullptr) {
         MS_LOG(ERROR) << "dims insert failed";
         return RET_ERROR;
@@ -133,18 +133,18 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
   }
   const auto &onnx_conv_weight = onnx_node.input(1);
-  auto nodeIter =
+  auto node_iter =
     std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                  [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
-  if (nodeIter == onnx_graph.initializer().end()) {
+  if (node_iter == onnx_graph.initializer().end()) {
     MS_LOG(ERROR) << "not find node: " << onnx_conv_weight.c_str();
     return RET_ERROR;
   }
   std::vector<int> weight_shape;
-  auto size = (*nodeIter).dims_size();
+  auto size = (*node_iter).dims_size();
   weight_shape.reserve(size);
   for (int i = 0; i < size; ++i) {
-    weight_shape.emplace_back((*nodeIter).dims(i));
+    weight_shape.emplace_back((*node_iter).dims(i));
   }
   if (weight_shape.size() != 4) {
     MS_LOG(ERROR) << "weight_shape.size() should be 4, but is " << weight_shape.size();
@@ -41,14 +41,14 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
   std::vector<int> dst_shape;
   const auto &onnx_expand_power = onnx_node.input(1);
-  auto nodeIter =
+  auto node_iter =
     std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
                  [onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
-  if (nodeIter == onnx_graph.node().end()) {
+  if (node_iter == onnx_graph.node().end()) {
     MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
     return RET_ERROR;
   }
-  for (const auto &attrPower : nodeIter->attribute()) {
+  for (const auto &attrPower : node_iter->attribute()) {
     if (attrPower.name() == "value") {
       const auto &t = attrPower.t();
       auto *dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data());