| @@ -8,8 +8,10 @@ file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) | |||||
| file(GLOB KERNEL_SRC | file(GLOB KERNEL_SRC | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c | ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c | ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32_grad/*.c | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c | ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c | ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/fp32/*.cc | |||||
| ) | ) | ||||
| list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c) | list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c) | ||||
| @@ -84,7 +84,7 @@ typedef struct LiteSession { | |||||
| /// \param[in] inputs Define the new inputs shape. | /// \param[in] inputs Define the new inputs shape. | ||||
| /// | /// | ||||
| /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. | /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. | ||||
| int Resize(const TensorPtrVector &inputs); | |||||
| int Resize(const TensorPtrVector &inputs, Int32VectorVector dims); | |||||
| } LiteSession; | } LiteSession; | ||||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | ||||
| @@ -27,5 +27,6 @@ using String = std::string; | |||||
| using StringVector = std::vector<std::string>; | using StringVector = std::vector<std::string>; | ||||
| using ShapeVector = std::vector<int>; | using ShapeVector = std::vector<int>; | ||||
| using NodePtrVector = std::vector<struct Node *>; | using NodePtrVector = std::vector<struct Node *>; | ||||
| using Int32Vector = std::vector<int32_t>; | |||||
| using Int32VectorVector = std::vector<Int32Vector>; | |||||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_ | #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_ | ||||
| @@ -27,6 +27,183 @@ enum NodeType { | |||||
| NodeType_MAX = NodeType_CNode | NodeType_MAX = NodeType_CNode | ||||
| }; | }; | ||||
/// Operator kinds known to the internal runtime. LiteSession::RunGraph() compares
/// node->primitive_->type_ against these values to pick the kernel to run.
///
/// No enumerator has an explicit value, so they are numbered consecutively from 0 (Concat)
/// in declaration order: inserting, removing or reordering an entry renumbers every entry after it.
/// NOTE(review): presumably this order mirrors an externally defined primitive-type numbering —
/// confirm before editing the list.
| enum KernelType { | |||||
| Concat, | |||||
| SoftMax, | |||||
| Activation, | |||||
| Conv2D, | |||||
| FusedBatchNorm, | |||||
| BatchNorm, | |||||
| BiasAdd, | |||||
| Pooling, | |||||
| ROIPooling, | |||||
| DepthwiseConv2D, | |||||
| DeDepthwiseConv2D, | |||||
| Resize, | |||||
| DetectionPostProcess, | |||||
| FullConnection, | |||||
| Mean, | |||||
| DeConv2D, | |||||
| Scale, | |||||
| Reshape, | |||||
| Eltwise, | |||||
| NetOutput, | |||||
| Add, | |||||
| Sub, | |||||
| MatMul, | |||||
| StridedSlice, | |||||
| Power, | |||||
| Slice, | |||||
| Stack, | |||||
| Mul, | |||||
| RealDiv, | |||||
| Pad, | |||||
| Maximum, | |||||
| Minimum, | |||||
| PReLU, | |||||
| LeakyReLU, | |||||
| ArgMax, | |||||
| ArgMin, | |||||
| Exp, | |||||
| Crop, | |||||
| Range, | |||||
| Rsqrt, | |||||
| ExpandDims, | |||||
| Tile, | |||||
| Cast, | |||||
| Shape, | |||||
| Nchw2Nhwc, | |||||
| Nhwc2Nchw, | |||||
| QuantDTypeCast, | |||||
| Split, | |||||
| Permute, | |||||
| FakeQuantWithMinMaxVars, | |||||
| Equal, | |||||
| Less, | |||||
| Greater, | |||||
| NotEqual, | |||||
| LessEqual, | |||||
| GreaterEqual, | |||||
| Min, | |||||
| Floor, | |||||
| Abs, | |||||
| Neg, | |||||
| Cos, | |||||
| Sin, | |||||
| Sqrt, | |||||
| Square, | |||||
| Constant, | |||||
| Log, | |||||
| Tan, | |||||
| Atan, | |||||
| Asin, | |||||
| Clip, | |||||
| Transpose, | |||||
| Squeeze, | |||||
| Unsqueeze, | |||||
| Upsample, | |||||
| Dropout, | |||||
| Broadcast, | |||||
| BroadcastTo, | |||||
| Lrn, | |||||
| ZerosLike, | |||||
| TopK, | |||||
| SpaceToDepth, | |||||
| SpaceToBatch, | |||||
| SparseToDense, | |||||
| ReverseSequence, | |||||
| Rank, | |||||
| Gather, | |||||
| GatherNd, | |||||
| Fill, | |||||
| Elu, | |||||
| DepthToSpace, | |||||
| BatchToSpace, | |||||
| AddN, | |||||
| Ceil, | |||||
| EmbeddingLookup, | |||||
| EmbeddingLookupSparse, | |||||
| FloorDiv, | |||||
| FloorMod, | |||||
| L2Norm, | |||||
| LocalResponseNormalization, | |||||
| MatrixDiag, | |||||
| Reduce, | |||||
| Reverse, | |||||
| Round, | |||||
| Select, | |||||
| Scatter, | |||||
| ScatterND, | |||||
| ConstantOfShape, | |||||
| Unique, | |||||
| Unstack, | |||||
| LogicalAnd, | |||||
| LogicalOr, | |||||
| LogicalXor, | |||||
| LogicalNot, | |||||
| OnnxInt8Quantize, | |||||
| OnnxInt8Dequantize, | |||||
| FakeQuantWithMinMax, | |||||
| FakeQuantWithMinMaxPerChannel, | |||||
| BatchNormFold, | |||||
| MulFold, | |||||
| AddFold, | |||||
| SquaredDifference, | |||||
| Flatten, | |||||
| FlattenGrad, | |||||
| TupleGetItem, | |||||
| Div, | |||||
| Where, | |||||
| OneHot, | |||||
| Lstm, | |||||
| Conv2DGradFilter, | |||||
| Conv2DGradInput, | |||||
| PoolingGrad, | |||||
| BNGrad, | |||||
| BNGradInput, | |||||
| ApplyMomentum, | |||||
| BiasGrad, | |||||
| SoftmaxCrossEntropy, | |||||
| AddGrad, | |||||
| SubGrad, | |||||
| MulGrad, | |||||
| DivGrad, | |||||
| PowerGrad, | |||||
| ActivationGrad, | |||||
| PriorBox, | |||||
| SpaceToBatchND, | |||||
| Depend, | |||||
| Return, | |||||
| MakeTuple, | |||||
| ToFormat, | |||||
| Proposal, | |||||
| Custom, | |||||
| BlackBox, | |||||
| NegGrad, | |||||
| LogGrad, | |||||
| BatchToSpaceND, | |||||
| }; | |||||
/// Activation function selector with explicitly pinned values 0..16.
/// The fp32 kernels DoActivation() and DoActivationGrad() compare their parameter's type_
/// against RELU, SIGMOID, RELU6 and LEAKY_RELU only; every other value is rejected there
/// with RET_PARAM_INVALID.
/// Note: the last enumerator is spelled UNKNOW (not UNKNOWN).
| enum ActivationType { | |||||
| NO_ACTIVATION = 0, | |||||
| RELU = 1, | |||||
| SIGMOID = 2, | |||||
| RELU6 = 3, | |||||
| ELU = 4, | |||||
| LEAKY_RELU = 5, | |||||
| ABS = 6, | |||||
| RELU1 = 7, | |||||
| SOFTSIGN = 8, | |||||
| SOFTPLUS = 9, | |||||
| TANH = 10, | |||||
| SELU = 11, | |||||
| HSWISH = 12, | |||||
| HSIGMOID = 13, | |||||
| THRESHOLDRELU = 14, | |||||
| LINEAR = 15, | |||||
| UNKNOW = 16 | |||||
| }; | |||||
| typedef struct Node { | typedef struct Node { | ||||
| String name_; | String name_; | ||||
| NodeType node_type_; | NodeType node_type_; | ||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "internal/src/kernel/fp32/activation.h" | |||||
| #include "internal/include/errorcode.h" | |||||
| #include "internal/include/ms_tensor.h" | |||||
| #include "nnacl/fp32/activation.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "nnacl/errorcode.h" | |||||
| int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator) { | |||||
| ActivationParameter *param = (ActivationParameter *)node->primitive_; | |||||
| int ret = RET_OK; | |||||
| size_t length = in_tensors[0]->ElementsNum(); | |||||
| float *input_addr = (float *)in_tensors[0]->data_; | |||||
| float *output_addr = (float *)out_tensors[0]->data_; | |||||
| if (param->type_ == ActivationType::RELU) { | |||||
| ret = Fp32Relu(input_addr, length, output_addr); | |||||
| } else if (param->type_ == ActivationType::SIGMOID) { | |||||
| ret = Sigmoid(input_addr, length, output_addr); | |||||
| } else if (param->type_ == ActivationType::RELU6) { | |||||
| ret = Fp32Relu6(input_addr, length, output_addr); | |||||
| } else if (param->type_ == ActivationType::LEAKY_RELU) { | |||||
| float alpha = param->alpha_; | |||||
| ret = LRelu(input_addr, length, output_addr, alpha); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport activation type " << param->type_; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (ret != NNACL_OK) { | |||||
| MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||||
| #include "internal/include/model.h" | |||||
| #include "src/runtime/allocator.h" | |||||
/// \brief Apply the fp32 activation selected by the node's ActivationParameter to in_tensors[0],
///        writing the result to out_tensors[0].
///
/// \param[in] in_tensors Input tensors; element 0 is the data to activate.
/// \param[in] out_tensors Output tensors; element 0 receives the result.
/// \param[in] node Node whose primitive_ carries the activation type (and alpha_ for LEAKY_RELU).
/// \param[in] allocator Not used by the implementation.
///
/// \return RET_OK on success, RET_PARAM_INVALID for an unsupported activation type, RET_ERROR when
///         the underlying nnacl routine fails.
| int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator); | |||||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "internal/src/kernel/fp32/arithmetic_self.h" | |||||
| #include "internal/include/errorcode.h" | |||||
| #include "internal/include/ms_tensor.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "nnacl/fp32/arithmetic_self.h" | |||||
| int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator) { | |||||
| size_t data_size = in_tensors[0]->ElementsNum(); | |||||
| OpParameter *param = node->primitive_; | |||||
| int ret; | |||||
| if (param->type_ == KernelType::Log) { | |||||
| ret = ElementLog((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); | |||||
| } else if (param->type_ == KernelType::Neg) { | |||||
| ret = ElementNegative((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (ret != NNACL_OK) { | |||||
| MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||||
| #include "internal/include/model.h" | |||||
| #include "src/runtime/allocator.h" | |||||
/// \brief Run the unary fp32 element-wise kernel selected by node->primitive_->type_
///        (KernelType::Log or KernelType::Neg) from in_tensors[0] to out_tensors[0].
///
/// \param[in] in_tensors Input tensors; element 0 is the operand.
/// \param[in] out_tensors Output tensors; element 0 receives the result.
/// \param[in] node Node whose primitive selects the operation.
/// \param[in] allocator Not used by the implementation.
///
/// \return RET_OK on success, RET_PARAM_INVALID for an unsupported kernel type, RET_ERROR when the
///         underlying nnacl routine fails.
| int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator); | |||||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||||
| @@ -0,0 +1,145 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "internal/src/kernel/fp32/matmul.h" | |||||
| #include "nnacl/fp32/matmul.h" | |||||
| #include "internal/include/errorcode.h" | |||||
| #include "internal/include/ms_tensor.h" | |||||
| #include "utils/log_adapter.h" | |||||
// Scratch buffers used by DoMatMul() for a single call. The struct holds raw pointers only;
// FreeMatMulKernelData() releases the buffers through the allocator and then frees the struct.
| typedef struct MatMulCPUKernelData { | |||||
// Packed copy of input A filled by MatMulInitMatrixA(); allocated as batch * row_12_ * deep_ floats.
| float *a_c12_ptr_; | |||||
// Packed copy of input B filled by MatMulInitMatrixB(); allocated as batch * col_8_ * deep_ floats.
| float *b_r8_ptr_; | |||||
// Bias of col_8_ floats: zero-filled, then overwritten with the optional third input's first col_ values.
| float *bias_ptr_; | |||||
| } MatMulCPUKernelData; | |||||
| void MatMulInitMatrixA(float *src_ptr, float *dst_ptr, MatMulParameter *params) { | |||||
| for (int i = 0; i < params->batch; i++) { | |||||
| float *src = src_ptr + i * params->deep_ * params->row_; | |||||
| float *dst = dst_ptr + i * params->deep_ * params->row_12_; | |||||
| if (params->a_transpose_) { | |||||
| RowMajor2Row12Major(src, dst, params->deep_, params->row_); | |||||
| } else { | |||||
| RowMajor2Col12Major(src, dst, params->row_, params->deep_); | |||||
| } | |||||
| } | |||||
| } | |||||
| void MatMulInitMatrixB(float *src_ptr, float *dst_ptr, MatMulParameter *params) { | |||||
| for (int i = 0; i < params->batch; i++) { | |||||
| float *src = src_ptr + i * params->deep_ * params->col_; | |||||
| float *dst = dst_ptr + i * params->deep_ * params->col_8_; | |||||
| if (params->b_transpose_) { | |||||
| RowMajor2Col8Major(src, dst, params->col_, params->deep_); | |||||
| } else { | |||||
| RowMajor2Row8Major(src, dst, params->deep_, params->col_); | |||||
| } | |||||
| } | |||||
| } | |||||
| void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::Allocator *allocator) { | |||||
| if (kernel_data == NULL) { | |||||
| return; | |||||
| } | |||||
| if (kernel_data->a_c12_ptr_ != NULL) { | |||||
| allocator->Free(kernel_data->a_c12_ptr_); | |||||
| kernel_data->a_c12_ptr_ = NULL; | |||||
| } | |||||
| if (kernel_data->b_r8_ptr_ != NULL) { | |||||
| allocator->Free(kernel_data->b_r8_ptr_); | |||||
| kernel_data->b_r8_ptr_ = NULL; | |||||
| } | |||||
| if (kernel_data->bias_ptr_ != NULL) { | |||||
| allocator->Free(kernel_data->bias_ptr_); | |||||
| kernel_data->bias_ptr_ = NULL; | |||||
| } | |||||
| free(kernel_data); | |||||
| } | |||||
| int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator) { | |||||
| if (in_tensors[0]->data_ == NULL || in_tensors[1]->data_ ==NULL) { | |||||
| MS_LOG(ERROR) << "input data is NULL!"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (allocator == NULL) { | |||||
| MS_LOG(ERROR) << "input allocator is NULL!"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| int batch = 1; | |||||
| std::vector<int> a_shape = in_tensors[0]->shape_; | |||||
| std::vector<int> c_shape = out_tensors[0]->shape_; | |||||
| if (in_tensors.size() == 3) { | |||||
| std::vector<int> bias_shape = in_tensors[2]->shape_; | |||||
| if (bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) { | |||||
| MS_LOG(ERROR) << "The bias' dimension is not equal with column"; | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < a_shape.size() - 2; ++i) { | |||||
| batch *= a_shape[i]; | |||||
| } | |||||
| MatMulParameter *params = (MatMulParameter *)node->primitive_; | |||||
| params->batch = batch; | |||||
| params->row_ = c_shape[c_shape.size() - 2]; | |||||
| params->col_ = c_shape[c_shape.size() - 1]; | |||||
| params->deep_ = params->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | |||||
| params->row_12_ = UP_ROUND(params->row_, C12NUM); | |||||
| params->col_8_ = UP_ROUND(params->col_, 8); | |||||
| MatMulCPUKernelData *kernel_data = (MatMulCPUKernelData *)malloc(sizeof(MatMulCPUKernelData)); | |||||
| kernel_data->a_c12_ptr_ | |||||
| = reinterpret_cast<float *>(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float))); | |||||
| if (kernel_data->a_c12_ptr_ == NULL) { | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(kernel_data->a_c12_ptr_, 0, params->row_12_ * params->deep_ * sizeof(float)); | |||||
| kernel_data->b_r8_ptr_ | |||||
| = reinterpret_cast<float *>(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float))); | |||||
| if (kernel_data->b_r8_ptr_ == NULL) { | |||||
| FreeMatMulKernelData(kernel_data, allocator); | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(kernel_data->b_r8_ptr_, 0, params->col_8_ * params->deep_ * sizeof(float)); | |||||
| MatMulInitMatrixA((float *)in_tensors[0]->data_, kernel_data->a_c12_ptr_, params); | |||||
| MatMulInitMatrixB((float *)in_tensors[1]->data_, kernel_data->b_r8_ptr_, params); | |||||
| kernel_data->bias_ptr_ = (float *)(allocator->Malloc(params->col_8_ * sizeof(float))); | |||||
| if (kernel_data->bias_ptr_ == NULL) { | |||||
| FreeMatMulKernelData(kernel_data, allocator); | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(kernel_data->bias_ptr_, 0, params->col_8_ * sizeof(float)); | |||||
| if (in_tensors.size() == 3) { | |||||
| memcpy(kernel_data->bias_ptr_, in_tensors[2]->data_, params->col_ * sizeof(float)); | |||||
| } | |||||
| auto c_src = (float *)out_tensors[0]->data_; | |||||
| for (int i = 0; i < params->batch; ++i) { | |||||
| float *a_ptr = kernel_data->a_c12_ptr_ + i * params->row_12_ * params->deep_; | |||||
| float *b_ptr = kernel_data->b_r8_ptr_ + i * params->deep_ * params->col_8_; | |||||
| float *c_ptr = c_src + i * params->row_ * params->col_; | |||||
| MatMulOpt(a_ptr, b_ptr, c_ptr, kernel_data->bias_ptr_, ActType_No, params->deep_, params->row_, params->col_, | |||||
| params->col_, OutType_Nhwc); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||||
| #include "internal/include/model.h" | |||||
| #include "src/runtime/allocator.h" | |||||
/// \brief Compute the fp32 batched matrix product of in_tensors[0] and in_tensors[1] into out_tensors[0];
///        when a third input tensor is present it is added as bias.
///
/// \param[in] in_tensors Two or three tensors: A, B and an optional bias.
/// \param[in] out_tensors Output tensors; element 0 receives the product.
/// \param[in] node Node whose primitive_ is a MatMulParameter; its shape fields are rewritten by the call.
/// \param[in] allocator Allocator for the temporary packed buffers; must not be NULL.
///
/// \return RET_OK on success, otherwise RET_PARAM_INVALID, RET_INPUT_TENSOR_ERROR or RET_MEMORY_FAILED.
| int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator); | |||||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "internal/src/kernel/fp32_grad/activation_grad.h" | |||||
| #include "internal/include/errorcode.h" | |||||
| #include "internal/include/ms_tensor.h" | |||||
| #include "nnacl/fp32_grad/activation_grad.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "nnacl/errorcode.h" | |||||
| int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator) { | |||||
| ActivationGradParameter *param = (ActivationGradParameter *)node->primitive_; | |||||
| int ret = RET_OK; | |||||
| size_t length = in_tensors[0]->ElementsNum(); | |||||
| float *dy_data = (float *)in_tensors[0]->data_; | |||||
| float *x_data = (float *)in_tensors[1]->data_; | |||||
| float *dx_data = (float *)(float *)out_tensors[0]->data_; | |||||
| if (param->type_ == ActivationType::RELU) { | |||||
| ret = ReluGrad(dy_data, x_data, length, dx_data); | |||||
| } else if (param->type_ == ActivationType::SIGMOID) { | |||||
| ret = SigmoidGrad(dy_data, x_data, length, dx_data); | |||||
| } else if (param->type_ == ActivationType::RELU6) { | |||||
| ret = Relu6Grad(dy_data, x_data, length, dx_data); | |||||
| } else if (param->type_ == ActivationType::LEAKY_RELU) { | |||||
| float alpha = param->alpha_; | |||||
| ret = LReluGrad(dy_data, x_data, length, dx_data, alpha); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport activation type " << param->type_; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (ret != NNACL_OK) { | |||||
| MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| #include "internal/include/model.h" | |||||
| #include "src/runtime/allocator.h" | |||||
/// \brief Compute the fp32 gradient of the activation selected by the node's ActivationGradParameter.
///
/// \param[in] in_tensors Input tensors; element 0 is used as dy and element 1 as x.
/// \param[in] out_tensors Output tensors; element 0 receives dx.
/// \param[in] node Node whose primitive_ carries the activation type (and alpha_ for LEAKY_RELU).
/// \param[in] allocator Not used by the implementation.
///
/// \return RET_OK on success, RET_PARAM_INVALID for an unsupported activation type, RET_ERROR when
///         the underlying nnacl routine fails.
| int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator); | |||||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h" | |||||
| #include "internal/include/errorcode.h" | |||||
| #include "internal/include/ms_tensor.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "nnacl/fp32/arithmetic_self.h" | |||||
| #include "nnacl/fp32/arithmetic.h" | |||||
| int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator) { | |||||
| size_t data_size = in_tensors[0]->ElementsNum(); | |||||
| OpParameter *param = node->primitive_; | |||||
| float *dy_data = (float *)in_tensors[0]->data_; | |||||
| float *x_data = (float *)in_tensors[1]->data_; | |||||
| float *dx_data = (float *)(float *)out_tensors[0]->data_; | |||||
| int ret; | |||||
| if (param->type_ == KernelType::LogGrad) { | |||||
| ret = ElementDiv(dy_data, x_data, dx_data, data_size); | |||||
| } else if (param->type_ == KernelType::NegGrad) { | |||||
| ret = ElementNegative(dy_data, dx_data, data_size); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (ret != NNACL_OK) { | |||||
| MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||||
| #include "internal/include/model.h" | |||||
| #include "src/runtime/allocator.h" | |||||
/// \brief Compute the fp32 gradient selected by node->primitive_->type_: KernelType::LogGrad
///        (element-wise dy / x) or KernelType::NegGrad (element-wise negation of dy).
///
/// \param[in] in_tensors Input tensors; element 0 is used as dy and element 1 as x.
/// \param[in] out_tensors Output tensors; element 0 receives dx.
/// \param[in] node Node whose primitive selects the operation.
/// \param[in] allocator Not used by the implementation.
///
/// \return RET_OK on success, RET_PARAM_INVALID for an unsupported kernel type, RET_ERROR when the
///         underlying nnacl routine fails.
| int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||||
| mindspore::lite::Allocator *allocator); | |||||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||||
| @@ -17,6 +17,13 @@ | |||||
| #include "internal/include/model.h" | #include "internal/include/model.h" | ||||
| #include "internal/include/ms_tensor.h" | #include "internal/include/ms_tensor.h" | ||||
| #include "src/runtime/allocator.h" | #include "src/runtime/allocator.h" | ||||
| #include "internal/include/errorcode.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "internal/src/kernel/fp32/activation.h" | |||||
| #include "internal/src/kernel/fp32/arithmetic_self.h" | |||||
| #include "internal/src/kernel/fp32/matmul.h" | |||||
| #include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h" | |||||
| #include "internal/src/kernel/fp32_grad/activation_grad.h" | |||||
| static Context *g_Ctx; | static Context *g_Ctx; | ||||
| static Model *g_Model; | static Model *g_Model; | ||||
| @@ -58,11 +65,56 @@ TensorPtrVector LiteSession::GetOutputs() const { | |||||
| int LiteSession::RunGraph() { | int LiteSession::RunGraph() { | ||||
| // invoke nnacl kernel | // invoke nnacl kernel | ||||
| return 0; | |||||
| NodePtrVector nodes = g_Model->nodes_; | |||||
| size_t nodes_size = nodes.size(); | |||||
| for (size_t i = 0; i < nodes_size; ++i) { | |||||
| auto node = nodes[i]; | |||||
| if (node->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "node's primitive is NULL!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| TensorPtrVector in_tensors; | |||||
| for (size_t j = 0; j < node->input_indices_.size(); ++j) { | |||||
| in_tensors.push_back(g_Model->all_tensors_[node->input_indices_[j]]); | |||||
| } | |||||
| TensorPtrVector out_tensors; | |||||
| for (size_t j = 0; j < node->output_indices_.size(); ++j) { | |||||
| out_tensors.push_back(g_Model->all_tensors_[node->output_indices_[j]]); | |||||
| } | |||||
| int type = node->primitive_->type_; | |||||
| int ret = RET_ERROR; | |||||
| switch (type) { | |||||
| case KernelType::MatMul: | |||||
| ret = DoMatMul(in_tensors, out_tensors, node, &allocator); | |||||
| break; | |||||
| case KernelType::Activation: | |||||
| ret = DoActivation(in_tensors, out_tensors, node, &allocator); | |||||
| break; | |||||
| case KernelType::Log: | |||||
| case KernelType::Neg: | |||||
| ret = DoArithmeticSelf(in_tensors, out_tensors, node, &allocator); | |||||
| break; | |||||
| case KernelType::LogGrad: | |||||
| case KernelType::NegGrad: | |||||
| ret = DoArithmeticGradSelf(in_tensors, out_tensors, node, &allocator); | |||||
| break; | |||||
| case KernelType::ActivationGrad: | |||||
| ret = DoActivationGrad(in_tensors, out_tensors, node, &allocator); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupport kernel type: " << type; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "run kernel fail!ret: " << ret; | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | } | ||||
| StringVector LiteSession::GetOutputTensorNames() const { return StringVector(); } | StringVector LiteSession::GetOutputTensorNames() const { return StringVector(); } | ||||
| MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { return NULL; } | MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { return NULL; } | ||||
| int LiteSession::Resize(const TensorPtrVector &inputs) { return 0; } | |||||
// Placeholder: ignores both `inputs` and `dims` and always returns 0; no tensor is resized here.
| int LiteSession::Resize(const TensorPtrVector &inputs, Int32VectorVector dims) { return 0; } | |||||