| @@ -8,8 +8,10 @@ file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) | |||
| file(GLOB KERNEL_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32_grad/*.c | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/fp32/*.cc | |||
| ) | |||
| list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c) | |||
| @@ -84,7 +84,7 @@ typedef struct LiteSession { | |||
| /// \param[in] inputs Define the new inputs shape. | |||
| /// | |||
| /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. | |||
| int Resize(const TensorPtrVector &inputs); | |||
| int Resize(const TensorPtrVector &inputs, Int32VectorVector dims); | |||
| } LiteSession; | |||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | |||
| @@ -27,5 +27,6 @@ using String = std::string; | |||
| using StringVector = std::vector<std::string>; | |||
| using ShapeVector = std::vector<int>; | |||
| using NodePtrVector = std::vector<struct Node *>; | |||
| using Int32Vector = std::vector<int32_t>; | |||
| using Int32VectorVector = std::vector<Int32Vector>; | |||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_ | |||
| @@ -27,6 +27,183 @@ enum NodeType { | |||
| NodeType_MAX = NodeType_CNode | |||
| }; | |||
| enum KernelType { | |||
| Concat, | |||
| SoftMax, | |||
| Activation, | |||
| Conv2D, | |||
| FusedBatchNorm, | |||
| BatchNorm, | |||
| BiasAdd, | |||
| Pooling, | |||
| ROIPooling, | |||
| DepthwiseConv2D, | |||
| DeDepthwiseConv2D, | |||
| Resize, | |||
| DetectionPostProcess, | |||
| FullConnection, | |||
| Mean, | |||
| DeConv2D, | |||
| Scale, | |||
| Reshape, | |||
| Eltwise, | |||
| NetOutput, | |||
| Add, | |||
| Sub, | |||
| MatMul, | |||
| StridedSlice, | |||
| Power, | |||
| Slice, | |||
| Stack, | |||
| Mul, | |||
| RealDiv, | |||
| Pad, | |||
| Maximum, | |||
| Minimum, | |||
| PReLU, | |||
| LeakyReLU, | |||
| ArgMax, | |||
| ArgMin, | |||
| Exp, | |||
| Crop, | |||
| Range, | |||
| Rsqrt, | |||
| ExpandDims, | |||
| Tile, | |||
| Cast, | |||
| Shape, | |||
| Nchw2Nhwc, | |||
| Nhwc2Nchw, | |||
| QuantDTypeCast, | |||
| Split, | |||
| Permute, | |||
| FakeQuantWithMinMaxVars, | |||
| Equal, | |||
| Less, | |||
| Greater, | |||
| NotEqual, | |||
| LessEqual, | |||
| GreaterEqual, | |||
| Min, | |||
| Floor, | |||
| Abs, | |||
| Neg, | |||
| Cos, | |||
| Sin, | |||
| Sqrt, | |||
| Square, | |||
| Constant, | |||
| Log, | |||
| Tan, | |||
| Atan, | |||
| Asin, | |||
| Clip, | |||
| Transpose, | |||
| Squeeze, | |||
| Unsqueeze, | |||
| Upsample, | |||
| Dropout, | |||
| Broadcast, | |||
| BroadcastTo, | |||
| Lrn, | |||
| ZerosLike, | |||
| TopK, | |||
| SpaceToDepth, | |||
| SpaceToBatch, | |||
| SparseToDense, | |||
| ReverseSequence, | |||
| Rank, | |||
| Gather, | |||
| GatherNd, | |||
| Fill, | |||
| Elu, | |||
| DepthToSpace, | |||
| BatchToSpace, | |||
| AddN, | |||
| Ceil, | |||
| EmbeddingLookup, | |||
| EmbeddingLookupSparse, | |||
| FloorDiv, | |||
| FloorMod, | |||
| L2Norm, | |||
| LocalResponseNormalization, | |||
| MatrixDiag, | |||
| Reduce, | |||
| Reverse, | |||
| Round, | |||
| Select, | |||
| Scatter, | |||
| ScatterND, | |||
| ConstantOfShape, | |||
| Unique, | |||
| Unstack, | |||
| LogicalAnd, | |||
| LogicalOr, | |||
| LogicalXor, | |||
| LogicalNot, | |||
| OnnxInt8Quantize, | |||
| OnnxInt8Dequantize, | |||
| FakeQuantWithMinMax, | |||
| FakeQuantWithMinMaxPerChannel, | |||
| BatchNormFold, | |||
| MulFold, | |||
| AddFold, | |||
| SquaredDifference, | |||
| Flatten, | |||
| FlattenGrad, | |||
| TupleGetItem, | |||
| Div, | |||
| Where, | |||
| OneHot, | |||
| Lstm, | |||
| Conv2DGradFilter, | |||
| Conv2DGradInput, | |||
| PoolingGrad, | |||
| BNGrad, | |||
| BNGradInput, | |||
| ApplyMomentum, | |||
| BiasGrad, | |||
| SoftmaxCrossEntropy, | |||
| AddGrad, | |||
| SubGrad, | |||
| MulGrad, | |||
| DivGrad, | |||
| PowerGrad, | |||
| ActivationGrad, | |||
| PriorBox, | |||
| SpaceToBatchND, | |||
| Depend, | |||
| Return, | |||
| MakeTuple, | |||
| ToFormat, | |||
| Proposal, | |||
| Custom, | |||
| BlackBox, | |||
| NegGrad, | |||
| LogGrad, | |||
| BatchToSpaceND, | |||
| }; | |||
| enum ActivationType { | |||
| NO_ACTIVATION = 0, | |||
| RELU = 1, | |||
| SIGMOID = 2, | |||
| RELU6 = 3, | |||
| ELU = 4, | |||
| LEAKY_RELU = 5, | |||
| ABS = 6, | |||
| RELU1 = 7, | |||
| SOFTSIGN = 8, | |||
| SOFTPLUS = 9, | |||
| TANH = 10, | |||
| SELU = 11, | |||
| HSWISH = 12, | |||
| HSIGMOID = 13, | |||
| THRESHOLDRELU = 14, | |||
| LINEAR = 15, | |||
| UNKNOW = 16 | |||
| }; | |||
| typedef struct Node { | |||
| String name_; | |||
| NodeType node_type_; | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "internal/src/kernel/fp32/activation.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "nnacl/fp32/activation.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "nnacl/errorcode.h" | |||
| int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator) { | |||
| ActivationParameter *param = (ActivationParameter *)node->primitive_; | |||
| int ret = RET_OK; | |||
| size_t length = in_tensors[0]->ElementsNum(); | |||
| float *input_addr = (float *)in_tensors[0]->data_; | |||
| float *output_addr = (float *)out_tensors[0]->data_; | |||
| if (param->type_ == ActivationType::RELU) { | |||
| ret = Fp32Relu(input_addr, length, output_addr); | |||
| } else if (param->type_ == ActivationType::SIGMOID) { | |||
| ret = Sigmoid(input_addr, length, output_addr); | |||
| } else if (param->type_ == ActivationType::RELU6) { | |||
| ret = Fp32Relu6(input_addr, length, output_addr); | |||
| } else if (param->type_ == ActivationType::LEAKY_RELU) { | |||
| float alpha = param->alpha_; | |||
| ret = LRelu(input_addr, length, output_addr, alpha); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport activation type " << param->type_; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (ret != NNACL_OK) { | |||
| MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||
| #include "internal/include/model.h" | |||
| #include "src/runtime/allocator.h" | |||
| int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator); | |||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_ | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "internal/src/kernel/fp32/arithmetic_self.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "nnacl/fp32/arithmetic_self.h" | |||
| int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator) { | |||
| size_t data_size = in_tensors[0]->ElementsNum(); | |||
| OpParameter *param = node->primitive_; | |||
| int ret; | |||
| if (param->type_ == KernelType::Log) { | |||
| ret = ElementLog((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); | |||
| } else if (param->type_ == KernelType::Neg) { | |||
| ret = ElementNegative((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (ret != NNACL_OK) { | |||
| MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||
| #include "internal/include/model.h" | |||
| #include "src/runtime/allocator.h" | |||
| int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator); | |||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ | |||
| @@ -0,0 +1,145 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "internal/src/kernel/fp32/matmul.h" | |||
| #include "nnacl/fp32/matmul.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| typedef struct MatMulCPUKernelData { | |||
| float *a_c12_ptr_; | |||
| float *b_r8_ptr_; | |||
| float *bias_ptr_; | |||
| } MatMulCPUKernelData; | |||
| void MatMulInitMatrixA(float *src_ptr, float *dst_ptr, MatMulParameter *params) { | |||
| for (int i = 0; i < params->batch; i++) { | |||
| float *src = src_ptr + i * params->deep_ * params->row_; | |||
| float *dst = dst_ptr + i * params->deep_ * params->row_12_; | |||
| if (params->a_transpose_) { | |||
| RowMajor2Row12Major(src, dst, params->deep_, params->row_); | |||
| } else { | |||
| RowMajor2Col12Major(src, dst, params->row_, params->deep_); | |||
| } | |||
| } | |||
| } | |||
| void MatMulInitMatrixB(float *src_ptr, float *dst_ptr, MatMulParameter *params) { | |||
| for (int i = 0; i < params->batch; i++) { | |||
| float *src = src_ptr + i * params->deep_ * params->col_; | |||
| float *dst = dst_ptr + i * params->deep_ * params->col_8_; | |||
| if (params->b_transpose_) { | |||
| RowMajor2Col8Major(src, dst, params->col_, params->deep_); | |||
| } else { | |||
| RowMajor2Row8Major(src, dst, params->deep_, params->col_); | |||
| } | |||
| } | |||
| } | |||
| void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::Allocator *allocator) { | |||
| if (kernel_data == NULL) { | |||
| return; | |||
| } | |||
| if (kernel_data->a_c12_ptr_ != NULL) { | |||
| allocator->Free(kernel_data->a_c12_ptr_); | |||
| kernel_data->a_c12_ptr_ = NULL; | |||
| } | |||
| if (kernel_data->b_r8_ptr_ != NULL) { | |||
| allocator->Free(kernel_data->b_r8_ptr_); | |||
| kernel_data->b_r8_ptr_ = NULL; | |||
| } | |||
| if (kernel_data->bias_ptr_ != NULL) { | |||
| allocator->Free(kernel_data->bias_ptr_); | |||
| kernel_data->bias_ptr_ = NULL; | |||
| } | |||
| free(kernel_data); | |||
| } | |||
| int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator) { | |||
| if (in_tensors[0]->data_ == NULL || in_tensors[1]->data_ ==NULL) { | |||
| MS_LOG(ERROR) << "input data is NULL!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (allocator == NULL) { | |||
| MS_LOG(ERROR) << "input allocator is NULL!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| int batch = 1; | |||
| std::vector<int> a_shape = in_tensors[0]->shape_; | |||
| std::vector<int> c_shape = out_tensors[0]->shape_; | |||
| if (in_tensors.size() == 3) { | |||
| std::vector<int> bias_shape = in_tensors[2]->shape_; | |||
| if (bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) { | |||
| MS_LOG(ERROR) << "The bias' dimension is not equal with column"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < a_shape.size() - 2; ++i) { | |||
| batch *= a_shape[i]; | |||
| } | |||
| MatMulParameter *params = (MatMulParameter *)node->primitive_; | |||
| params->batch = batch; | |||
| params->row_ = c_shape[c_shape.size() - 2]; | |||
| params->col_ = c_shape[c_shape.size() - 1]; | |||
| params->deep_ = params->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | |||
| params->row_12_ = UP_ROUND(params->row_, C12NUM); | |||
| params->col_8_ = UP_ROUND(params->col_, 8); | |||
| MatMulCPUKernelData *kernel_data = (MatMulCPUKernelData *)malloc(sizeof(MatMulCPUKernelData)); | |||
| kernel_data->a_c12_ptr_ | |||
| = reinterpret_cast<float *>(allocator->Malloc(params->batch * params->row_12_ * params->deep_ * sizeof(float))); | |||
| if (kernel_data->a_c12_ptr_ == NULL) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(kernel_data->a_c12_ptr_, 0, params->row_12_ * params->deep_ * sizeof(float)); | |||
| kernel_data->b_r8_ptr_ | |||
| = reinterpret_cast<float *>(allocator->Malloc(params->batch * params->col_8_ * params->deep_ * sizeof(float))); | |||
| if (kernel_data->b_r8_ptr_ == NULL) { | |||
| FreeMatMulKernelData(kernel_data, allocator); | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(kernel_data->b_r8_ptr_, 0, params->col_8_ * params->deep_ * sizeof(float)); | |||
| MatMulInitMatrixA((float *)in_tensors[0]->data_, kernel_data->a_c12_ptr_, params); | |||
| MatMulInitMatrixB((float *)in_tensors[1]->data_, kernel_data->b_r8_ptr_, params); | |||
| kernel_data->bias_ptr_ = (float *)(allocator->Malloc(params->col_8_ * sizeof(float))); | |||
| if (kernel_data->bias_ptr_ == NULL) { | |||
| FreeMatMulKernelData(kernel_data, allocator); | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(kernel_data->bias_ptr_, 0, params->col_8_ * sizeof(float)); | |||
| if (in_tensors.size() == 3) { | |||
| memcpy(kernel_data->bias_ptr_, in_tensors[2]->data_, params->col_ * sizeof(float)); | |||
| } | |||
| auto c_src = (float *)out_tensors[0]->data_; | |||
| for (int i = 0; i < params->batch; ++i) { | |||
| float *a_ptr = kernel_data->a_c12_ptr_ + i * params->row_12_ * params->deep_; | |||
| float *b_ptr = kernel_data->b_r8_ptr_ + i * params->deep_ * params->col_8_; | |||
| float *c_ptr = c_src + i * params->row_ * params->col_; | |||
| MatMulOpt(a_ptr, b_ptr, c_ptr, kernel_data->bias_ptr_, ActType_No, params->deep_, params->row_, params->col_, | |||
| params->col_, OutType_Nhwc); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||
| #include "internal/include/model.h" | |||
| #include "src/runtime/allocator.h" | |||
| int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator); | |||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "internal/src/kernel/fp32_grad/activation_grad.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "nnacl/fp32_grad/activation_grad.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "nnacl/errorcode.h" | |||
| int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator) { | |||
| ActivationGradParameter *param = (ActivationGradParameter *)node->primitive_; | |||
| int ret = RET_OK; | |||
| size_t length = in_tensors[0]->ElementsNum(); | |||
| float *dy_data = (float *)in_tensors[0]->data_; | |||
| float *x_data = (float *)in_tensors[1]->data_; | |||
| float *dx_data = (float *)(float *)out_tensors[0]->data_; | |||
| if (param->type_ == ActivationType::RELU) { | |||
| ret = ReluGrad(dy_data, x_data, length, dx_data); | |||
| } else if (param->type_ == ActivationType::SIGMOID) { | |||
| ret = SigmoidGrad(dy_data, x_data, length, dx_data); | |||
| } else if (param->type_ == ActivationType::RELU6) { | |||
| ret = Relu6Grad(dy_data, x_data, length, dx_data); | |||
| } else if (param->type_ == ActivationType::LEAKY_RELU) { | |||
| float alpha = param->alpha_; | |||
| ret = LReluGrad(dy_data, x_data, length, dx_data, alpha); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport activation type " << param->type_; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (ret != NNACL_OK) { | |||
| MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| #include "internal/include/model.h" | |||
| #include "src/runtime/allocator.h" | |||
| int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator); | |||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_ | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "nnacl/fp32/arithmetic_self.h" | |||
| #include "nnacl/fp32/arithmetic.h" | |||
| int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator) { | |||
| size_t data_size = in_tensors[0]->ElementsNum(); | |||
| OpParameter *param = node->primitive_; | |||
| float *dy_data = (float *)in_tensors[0]->data_; | |||
| float *x_data = (float *)in_tensors[1]->data_; | |||
| float *dx_data = (float *)(float *)out_tensors[0]->data_; | |||
| int ret; | |||
| if (param->type_ == KernelType::LogGrad) { | |||
| ret = ElementDiv(dy_data, x_data, dx_data, data_size); | |||
| } else if (param->type_ == KernelType::NegGrad) { | |||
| ret = ElementNegative(dy_data, dx_data, data_size); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (ret != NNACL_OK) { | |||
| MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||
| #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||
| #include "internal/include/model.h" | |||
| #include "src/runtime/allocator.h" | |||
| int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node, | |||
| mindspore::lite::Allocator *allocator); | |||
| #endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ | |||
| @@ -17,6 +17,13 @@ | |||
| #include "internal/include/model.h" | |||
| #include "internal/include/ms_tensor.h" | |||
| #include "src/runtime/allocator.h" | |||
| #include "internal/include/errorcode.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "internal/src/kernel/fp32/activation.h" | |||
| #include "internal/src/kernel/fp32/arithmetic_self.h" | |||
| #include "internal/src/kernel/fp32/matmul.h" | |||
| #include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h" | |||
| #include "internal/src/kernel/fp32_grad/activation_grad.h" | |||
| static Context *g_Ctx; | |||
| static Model *g_Model; | |||
| @@ -58,11 +65,56 @@ TensorPtrVector LiteSession::GetOutputs() const { | |||
| int LiteSession::RunGraph() { | |||
| // invoke nnacl kernel | |||
| return 0; | |||
| NodePtrVector nodes = g_Model->nodes_; | |||
| size_t nodes_size = nodes.size(); | |||
| for (size_t i = 0; i < nodes_size; ++i) { | |||
| auto node = nodes[i]; | |||
| if (node->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "node's primitive is NULL!"; | |||
| return RET_ERROR; | |||
| } | |||
| TensorPtrVector in_tensors; | |||
| for (size_t j = 0; j < node->input_indices_.size(); ++j) { | |||
| in_tensors.push_back(g_Model->all_tensors_[node->input_indices_[j]]); | |||
| } | |||
| TensorPtrVector out_tensors; | |||
| for (size_t j = 0; j < node->output_indices_.size(); ++j) { | |||
| out_tensors.push_back(g_Model->all_tensors_[node->output_indices_[j]]); | |||
| } | |||
| int type = node->primitive_->type_; | |||
| int ret = RET_ERROR; | |||
| switch (type) { | |||
| case KernelType::MatMul: | |||
| ret = DoMatMul(in_tensors, out_tensors, node, &allocator); | |||
| break; | |||
| case KernelType::Activation: | |||
| ret = DoActivation(in_tensors, out_tensors, node, &allocator); | |||
| break; | |||
| case KernelType::Log: | |||
| case KernelType::Neg: | |||
| ret = DoArithmeticSelf(in_tensors, out_tensors, node, &allocator); | |||
| break; | |||
| case KernelType::LogGrad: | |||
| case KernelType::NegGrad: | |||
| ret = DoArithmeticGradSelf(in_tensors, out_tensors, node, &allocator); | |||
| break; | |||
| case KernelType::ActivationGrad: | |||
| ret = DoActivationGrad(in_tensors, out_tensors, node, &allocator); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport kernel type: " << type; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "run kernel fail!ret: " << ret; | |||
| return ret; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| StringVector LiteSession::GetOutputTensorNames() const { return StringVector(); } | |||
| MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { return NULL; } | |||
| int LiteSession::Resize(const TensorPtrVector &inputs) { return 0; } | |||
| int LiteSession::Resize(const TensorPtrVector &inputs, Int32VectorVector dims) { return 0; } | |||