From 9531d25baecde9fc9ca343428555d382e0a96bea Mon Sep 17 00:00:00 2001 From: chenjianping Date: Sat, 12 Sep 2020 10:12:16 +0800 Subject: [PATCH] support SiteAI --- mindspore/lite/internal/CMakeLists.txt | 2 + .../lite/internal/include/lite_session.h | 2 +- mindspore/lite/internal/include/lite_utils.h | 3 +- mindspore/lite/internal/include/model.h | 177 ++++++++++++++++++ .../internal/src/kernel/fp32/activation.cc | 49 +++++ .../internal/src/kernel/fp32/activation.h | 26 +++ .../src/kernel/fp32/arithmetic_self.cc | 41 ++++ .../src/kernel/fp32/arithmetic_self.h | 26 +++ .../lite/internal/src/kernel/fp32/matmul.cc | 145 ++++++++++++++ .../lite/internal/src/kernel/fp32/matmul.h | 26 +++ .../src/kernel/fp32_grad/activation_grad.cc | 50 +++++ .../src/kernel/fp32_grad/activation_grad.h | 26 +++ .../kernel/fp32_grad/arithmetic_self_grad.cc | 45 +++++ .../kernel/fp32_grad/arithmetic_self_grad.h | 26 +++ mindspore/lite/internal/src/lite_session.cc | 56 +++++- 15 files changed, 696 insertions(+), 4 deletions(-) create mode 100644 mindspore/lite/internal/src/kernel/fp32/activation.cc create mode 100644 mindspore/lite/internal/src/kernel/fp32/activation.h create mode 100644 mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc create mode 100644 mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h create mode 100644 mindspore/lite/internal/src/kernel/fp32/matmul.cc create mode 100644 mindspore/lite/internal/src/kernel/fp32/matmul.h create mode 100644 mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.cc create mode 100644 mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h create mode 100644 mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc create mode 100644 mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h diff --git a/mindspore/lite/internal/CMakeLists.txt b/mindspore/lite/internal/CMakeLists.txt index 1f55bb53a5..77706f721e 100644 --- a/mindspore/lite/internal/CMakeLists.txt +++ 
b/mindspore/lite/internal/CMakeLists.txt @@ -8,8 +8,10 @@ file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c + ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32_grad/*.c ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c + ${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/fp32/*.cc ) list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c) diff --git a/mindspore/lite/internal/include/lite_session.h b/mindspore/lite/internal/include/lite_session.h index 7d264e27fe..55d2095a27 100644 --- a/mindspore/lite/internal/include/lite_session.h +++ b/mindspore/lite/internal/include/lite_session.h @@ -84,7 +84,7 @@ typedef struct LiteSession { /// \param[in] inputs Define the new inputs shape. /// /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. - int Resize(const TensorPtrVector &inputs); + int Resize(const TensorPtrVector &inputs, Int32VectorVector dims); } LiteSession; #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H diff --git a/mindspore/lite/internal/include/lite_utils.h b/mindspore/lite/internal/include/lite_utils.h index 663fd2bb5a..5661b75f0c 100644 --- a/mindspore/lite/internal/include/lite_utils.h +++ b/mindspore/lite/internal/include/lite_utils.h @@ -27,5 +27,6 @@ using String = std::string; using StringVector = std::vector; using ShapeVector = std::vector; using NodePtrVector = std::vector; - +using Int32Vector = std::vector; +using Int32VectorVector = std::vector; #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_ diff --git a/mindspore/lite/internal/include/model.h b/mindspore/lite/internal/include/model.h index eb1b791d36..eafbda9e77 100644 --- a/mindspore/lite/internal/include/model.h +++ b/mindspore/lite/internal/include/model.h @@ -27,6 +27,183 @@ enum NodeType { NodeType_MAX = NodeType_CNode }; +enum KernelType { /* Operator kinds; LiteSession::RunGraph switches on node->primitive_->type_ using these values. Enumerators are implicitly numbered from 0 in declaration order, so inserting or reordering entries changes every later value. */ + Concat, + SoftMax, + Activation, + 
Conv2D, + FusedBatchNorm, + BatchNorm, + BiasAdd, + Pooling, + ROIPooling, + DepthwiseConv2D, + DeDepthwiseConv2D, + Resize, + DetectionPostProcess, + FullConnection, + Mean, + DeConv2D, + Scale, + Reshape, + Eltwise, + NetOutput, + Add, + Sub, + MatMul, + StridedSlice, + Power, + Slice, + Stack, + Mul, + RealDiv, + Pad, + Maximum, + Minimum, + PReLU, + LeakyReLU, + ArgMax, + ArgMin, + Exp, + Crop, + Range, + Rsqrt, + ExpandDims, + Tile, + Cast, + Shape, + Nchw2Nhwc, + Nhwc2Nchw, + QuantDTypeCast, + Split, + Permute, + FakeQuantWithMinMaxVars, + Equal, + Less, + Greater, + NotEqual, + LessEqual, + GreaterEqual, + Min, + Floor, + Abs, + Neg, + Cos, + Sin, + Sqrt, + Square, + Constant, + Log, + Tan, + Atan, + Asin, + Clip, + Transpose, + Squeeze, + Unsqueeze, + Upsample, + Dropout, + Broadcast, + BroadcastTo, + Lrn, + ZerosLike, + TopK, + SpaceToDepth, + SpaceToBatch, + SparseToDense, + ReverseSequence, + Rank, + Gather, + GatherNd, + Fill, + Elu, + DepthToSpace, + BatchToSpace, + AddN, + Ceil, + EmbeddingLookup, + EmbeddingLookupSparse, + FloorDiv, + FloorMod, + L2Norm, + LocalResponseNormalization, + MatrixDiag, + Reduce, + Reverse, + Round, + Select, + Scatter, + ScatterND, + ConstantOfShape, + Unique, + Unstack, + LogicalAnd, + LogicalOr, + LogicalXor, + LogicalNot, + OnnxInt8Quantize, + OnnxInt8Dequantize, + FakeQuantWithMinMax, + FakeQuantWithMinMaxPerChannel, + BatchNormFold, + MulFold, + AddFold, + SquaredDifference, + Flatten, + FlattenGrad, + TupleGetItem, + Div, + Where, + OneHot, + Lstm, + Conv2DGradFilter, + Conv2DGradInput, + PoolingGrad, + BNGrad, + BNGradInput, + ApplyMomentum, + BiasGrad, + SoftmaxCrossEntropy, + AddGrad, + SubGrad, + MulGrad, + DivGrad, + PowerGrad, + ActivationGrad, + PriorBox, + SpaceToBatchND, + Depend, + Return, + MakeTuple, + ToFormat, + Proposal, + Custom, + BlackBox, + NegGrad, + LogGrad, + BatchToSpaceND, +}; + +enum ActivationType { /* Activation selector compared against the parameter's type_ field by DoActivation and DoActivationGrad; those kernels handle only RELU, SIGMOID, RELU6 and LEAKY_RELU. */ + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + 
ABS = 6,
+  RELU1 = 7,
+  SOFTSIGN = 8,
+  SOFTPLUS = 9,
+  TANH = 10,
+  SELU = 11,
+  HSWISH = 12,
+  HSIGMOID = 13,
+  THRESHOLDRELU = 14,
+  LINEAR = 15,
+  UNKNOW = 16
+};
+
 typedef struct Node {
   String name_;
   NodeType node_type_;
diff --git a/mindspore/lite/internal/src/kernel/fp32/activation.cc b/mindspore/lite/internal/src/kernel/fp32/activation.cc
new file mode 100644
index 0000000000..d2eb03175f
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/activation.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "internal/src/kernel/fp32/activation.h"
+#include "internal/include/errorcode.h"
+#include "internal/include/ms_tensor.h"
+#include "nnacl/fp32/activation.h"
+#include "utils/log_adapter.h"
+#include "nnacl/errorcode.h"
+
+/// \brief Apply an element-wise fp32 activation to the first input tensor.
+///
+/// Supported ActivationType values are RELU, SIGMOID, RELU6 and LEAKY_RELU (which reads param->alpha_).
+///
+/// \param[in] in_tensors Input tensors; only in_tensors[0] is read.
+/// \param[in] out_tensors Output tensors; the result is written to out_tensors[0].
+/// \param[in] node Graph node whose primitive_ is interpreted as an ActivationParameter.
+/// \param[in] allocator Not used by this kernel.
+///
+/// \return RET_OK on success, RET_PARAM_INVALID for missing or NULL arguments and for unsupported activation
+///         types, RET_ERROR when the nnacl routine reports a failure.
+int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                 mindspore::lite::Allocator *allocator) {
+  if (node == nullptr || node->primitive_ == nullptr) {
+    MS_LOG(ERROR) << "node or node's primitive is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
+    MS_LOG(ERROR) << "input or output tensor is missing!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
+    MS_LOG(ERROR) << "input or output data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  ActivationParameter *param = reinterpret_cast<ActivationParameter *>(node->primitive_);
+  size_t length = in_tensors[0]->ElementsNum();
+  float *input_addr = reinterpret_cast<float *>(in_tensors[0]->data_);
+  float *output_addr = reinterpret_cast<float *>(out_tensors[0]->data_);
+  int ret = NNACL_OK;
+  if (param->type_ == ActivationType::RELU) {
+    ret = Fp32Relu(input_addr, length, output_addr);
+  } else if (param->type_ == ActivationType::SIGMOID) {
+    ret = Sigmoid(input_addr, length, output_addr);
+  } else if (param->type_ == ActivationType::RELU6) {
+    ret = Fp32Relu6(input_addr, length, output_addr);
+  } else if (param->type_ == ActivationType::LEAKY_RELU) {
+    ret = LRelu(input_addr, length, output_addr, param->alpha_);
+  } else {
+    MS_LOG(ERROR) << "Unsupport activation type " << param->type_;
+    return RET_PARAM_INVALID;
+  }
+  if (ret != NNACL_OK) {
+    MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
diff --git a/mindspore/lite/internal/src/kernel/fp32/activation.h b/mindspore/lite/internal/src/kernel/fp32/activation.h
new file mode 100644
index 0000000000..50f28f26e1
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/activation.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
+
+#include "internal/include/model.h"
+#include "src/runtime/allocator.h"
+
+int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                 mindspore::lite::Allocator *allocator);
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc
new file mode 100644
index 0000000000..9ca58508fa
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "internal/src/kernel/fp32/arithmetic_self.h"
+#include "internal/include/errorcode.h"
+#include "internal/include/ms_tensor.h"
+#include "utils/log_adapter.h"
+#include "nnacl/fp32/arithmetic_self.h"
+#include "nnacl/errorcode.h"
+
+/// \brief Run an element-wise unary fp32 operator selected by the node's primitive type.
+///
+/// Supported KernelType values are Log and Neg.
+///
+/// \param[in] in_tensors Input tensors; only in_tensors[0] is read.
+/// \param[in] out_tensors Output tensors; the result is written to out_tensors[0].
+/// \param[in] node Graph node whose primitive_->type_ selects the operator.
+/// \param[in] allocator Not used by this kernel.
+///
+/// \return RET_OK on success, RET_PARAM_INVALID for missing or NULL arguments and for unsupported kernel types,
+///         RET_ERROR when the nnacl routine reports a failure.
+int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                     mindspore::lite::Allocator *allocator) {
+  if (node == nullptr || node->primitive_ == nullptr) {
+    MS_LOG(ERROR) << "node or node's primitive is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
+    MS_LOG(ERROR) << "input or output tensor is missing!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
+    MS_LOG(ERROR) << "input or output data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  OpParameter *param = node->primitive_;
+  size_t data_size = in_tensors[0]->ElementsNum();
+  float *input_addr = reinterpret_cast<float *>(in_tensors[0]->data_);
+  float *output_addr = reinterpret_cast<float *>(out_tensors[0]->data_);
+  int ret = NNACL_OK;
+  if (param->type_ == KernelType::Log) {
+    ret = ElementLog(input_addr, output_addr, data_size);
+  } else if (param->type_ == KernelType::Neg) {
+    ret = ElementNegative(input_addr, output_addr, data_size);
+  } else {
+    MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_;
+    return RET_PARAM_INVALID;
+  }
+  if (ret != NNACL_OK) {
+    MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h
new file mode 100644
index 0000000000..3d4c285a6d
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
+
+#include "internal/include/model.h"
+#include "src/runtime/allocator.h"
+
+int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                     mindspore::lite::Allocator *allocator);
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
diff --git a/mindspore/lite/internal/src/kernel/fp32/matmul.cc b/mindspore/lite/internal/src/kernel/fp32/matmul.cc
new file mode 100644
index 0000000000..d525359bc4
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/matmul.cc
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "internal/src/kernel/fp32/matmul.h"
+#include <cstdlib>
+#include <cstring>
+#include "nnacl/fp32/matmul.h"
+#include "internal/include/errorcode.h"
+#include "internal/include/ms_tensor.h"
+#include "utils/log_adapter.h"
+
+// Scratch buffers owned by one DoMatMul call.
+typedef struct MatMulCPUKernelData {
+  float *a_c12_ptr_;  // matrix A repacked by MatMulInitMatrixA
+  float *b_r8_ptr_;   // matrix B repacked by MatMulInitMatrixB
+  float *bias_ptr_;   // bias copied into a zero-filled buffer of col_8_ floats
+} MatMulCPUKernelData;
+
+// Repack every batch of matrix A from src_ptr into dst_ptr; a_transpose_ selects the nnacl packing routine.
+void MatMulInitMatrixA(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
+  for (int i = 0; i < params->batch; i++) {
+    float *src = src_ptr + i * params->deep_ * params->row_;
+    float *dst = dst_ptr + i * params->deep_ * params->row_12_;
+    if (params->a_transpose_) {
+      RowMajor2Row12Major(src, dst, params->deep_, params->row_);
+    } else {
+      RowMajor2Col12Major(src, dst, params->row_, params->deep_);
+    }
+  }
+}
+
+// Repack every batch of matrix B from src_ptr into dst_ptr; b_transpose_ selects the nnacl packing routine.
+void MatMulInitMatrixB(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
+  for (int i = 0; i < params->batch; i++) {
+    float *src = src_ptr + i * params->deep_ * params->col_;
+    float *dst = dst_ptr + i * params->deep_ * params->col_8_;
+    if (params->b_transpose_) {
+      RowMajor2Col8Major(src, dst, params->col_, params->deep_);
+    } else {
+      RowMajor2Row8Major(src, dst, params->deep_, params->col_);
+    }
+  }
+}
+
+// Release the packed buffers through allocator, then the kernel_data struct itself (it is allocated with malloc).
+// NULL members are skipped, so a partially filled kernel_data is safe as long as its unset members are NULL.
+void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::Allocator *allocator) {
+  if (kernel_data == NULL) {
+    return;
+  }
+  if (kernel_data->a_c12_ptr_ != NULL) {
+    allocator->Free(kernel_data->a_c12_ptr_);
+    kernel_data->a_c12_ptr_ = NULL;
+  }
+  if (kernel_data->b_r8_ptr_ != NULL) {
+    allocator->Free(kernel_data->b_r8_ptr_);
+    kernel_data->b_r8_ptr_ = NULL;
+  }
+  if (kernel_data->bias_ptr_ != NULL) {
+    allocator->Free(kernel_data->bias_ptr_);
+    kernel_data->bias_ptr_ = NULL;
+  }
+  free(kernel_data);
+}
+
+/// \brief Multiply in_tensors[0] by in_tensors[1] (fp32), add the optional bias, and store into out_tensors[0].
+///
+/// row_ and col_ are taken from the last two output dimensions, deep_ from the shape of input A, and every
+/// leading dimension of A is folded into the batch count. The packed scratch buffers come from allocator and
+/// are released before returning.
+///
+/// \param[in] in_tensors Matrix A, matrix B and, when exactly three tensors are given, the bias.
+/// \param[in] out_tensors Output tensors; the product is written to out_tensors[0].
+/// \param[in] node Graph node whose primitive_ is interpreted as a MatMulParameter; its shape fields are updated.
+/// \param[in] allocator Allocator used for the packed scratch buffers.
+///
+/// \return RET_OK on success, RET_PARAM_INVALID for missing or NULL arguments, RET_INPUT_TENSOR_ERROR for
+///         unusable shapes or bias, RET_MEMORY_FAILED when a scratch buffer cannot be allocated.
+int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+             mindspore::lite::Allocator *allocator) {
+  if (node == NULL || node->primitive_ == NULL) {
+    MS_LOG(ERROR) << "node or node's primitive is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors.size() < 2 || out_tensors.empty() || in_tensors[0] == NULL || in_tensors[1] == NULL ||
+      out_tensors[0] == NULL) {
+    MS_LOG(ERROR) << "input or output tensor is missing!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
+    MS_LOG(ERROR) << "input data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (out_tensors[0]->data_ == NULL) {
+    MS_LOG(ERROR) << "output data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (allocator == NULL) {
+    MS_LOG(ERROR) << "input allocator is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  const auto &a_shape = in_tensors[0]->shape_;
+  const auto &c_shape = out_tensors[0]->shape_;
+  // The code below indexes the last two dimensions, and "size() - 2" would wrap around for smaller ranks.
+  if (a_shape.size() < 2 || c_shape.size() < 2) {
+    MS_LOG(ERROR) << "matmul input and output must have at least 2 dimensions";
+    return RET_INPUT_TENSOR_ERROR;
+  }
+  const bool has_bias = in_tensors.size() == 3;
+  if (has_bias) {
+    if (in_tensors[2] == NULL || in_tensors[2]->data_ == NULL || in_tensors[2]->shape_.empty()) {
+      MS_LOG(ERROR) << "bias tensor is invalid";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+    const auto &bias_shape = in_tensors[2]->shape_;
+    if (bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) {
+      MS_LOG(ERROR) << "The bias' dimension is not equal with column";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  }
+  int batch = 1;
+  for (size_t i = 0; i < a_shape.size() - 2; ++i) {
+    batch *= a_shape[i];
+  }
+
+  MatMulParameter *params = reinterpret_cast<MatMulParameter *>(node->primitive_);
+  params->batch = batch;
+  params->row_ = c_shape[c_shape.size() - 2];
+  params->col_ = c_shape[c_shape.size() - 1];
+  params->deep_ = params->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
+  params->row_12_ = UP_ROUND(params->row_, C12NUM);
+  params->col_8_ = UP_ROUND(params->col_, 8);
+
+  auto *kernel_data = static_cast<MatMulCPUKernelData *>(malloc(sizeof(MatMulCPUKernelData)));
+  if (kernel_data == NULL) {
+    MS_LOG(ERROR) << "malloc kernel data failed";
+    return RET_MEMORY_FAILED;
+  }
+  // malloc leaves the members indeterminate; clear them so FreeMatMulKernelData only frees what was allocated.
+  kernel_data->a_c12_ptr_ = NULL;
+  kernel_data->b_r8_ptr_ = NULL;
+  kernel_data->bias_ptr_ = NULL;
+
+  // Zero every batch of each packed buffer, not just the first one, so padding never holds indeterminate values.
+  const size_t a_size = static_cast<size_t>(params->batch) * params->row_12_ * params->deep_ * sizeof(float);
+  kernel_data->a_c12_ptr_ = reinterpret_cast<float *>(allocator->Malloc(a_size));
+  if (kernel_data->a_c12_ptr_ == NULL) {
+    FreeMatMulKernelData(kernel_data, allocator);
+    return RET_MEMORY_FAILED;
+  }
+  memset(kernel_data->a_c12_ptr_, 0, a_size);
+
+  const size_t b_size = static_cast<size_t>(params->batch) * params->col_8_ * params->deep_ * sizeof(float);
+  kernel_data->b_r8_ptr_ = reinterpret_cast<float *>(allocator->Malloc(b_size));
+  if (kernel_data->b_r8_ptr_ == NULL) {
+    FreeMatMulKernelData(kernel_data, allocator);
+    return RET_MEMORY_FAILED;
+  }
+  memset(kernel_data->b_r8_ptr_, 0, b_size);
+
+  const size_t bias_size = static_cast<size_t>(params->col_8_) * sizeof(float);
+  kernel_data->bias_ptr_ = reinterpret_cast<float *>(allocator->Malloc(bias_size));
+  if (kernel_data->bias_ptr_ == NULL) {
+    FreeMatMulKernelData(kernel_data, allocator);
+    return RET_MEMORY_FAILED;
+  }
+  memset(kernel_data->bias_ptr_, 0, bias_size);
+  if (has_bias) {
+    memcpy(kernel_data->bias_ptr_, in_tensors[2]->data_, params->col_ * sizeof(float));
+  }
+
+  MatMulInitMatrixA(reinterpret_cast<float *>(in_tensors[0]->data_), kernel_data->a_c12_ptr_, params);
+  MatMulInitMatrixB(reinterpret_cast<float *>(in_tensors[1]->data_), kernel_data->b_r8_ptr_, params);
+  float *c_src = reinterpret_cast<float *>(out_tensors[0]->data_);
+  for (int i = 0; i < params->batch; ++i) {
+    float *a_ptr = kernel_data->a_c12_ptr_ + i * params->row_12_ * params->deep_;
+    float *b_ptr = kernel_data->b_r8_ptr_ + i * params->deep_ * params->col_8_;
+    float *c_ptr = c_src + i * params->row_ * params->col_;
+    MatMulOpt(a_ptr, b_ptr, c_ptr, kernel_data->bias_ptr_, ActType_No, params->deep_, params->row_, params->col_,
+              params->col_, OutType_Nhwc);
+  }
+  // The packed buffers are only needed during this call; release them so repeated runs do not leak.
+  FreeMatMulKernelData(kernel_data, allocator);
+  return RET_OK;
+}
diff --git a/mindspore/lite/internal/src/kernel/fp32/matmul.h b/mindspore/lite/internal/src/kernel/fp32/matmul.h
new file mode 100644
index 0000000000..3d9a701392
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32/matmul.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
+
+#include "internal/include/model.h"
+#include "src/runtime/allocator.h"
+
+int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+             mindspore::lite::Allocator *allocator);
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.cc b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.cc
new file mode 100644
index 0000000000..0d9d60caa8
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "internal/src/kernel/fp32_grad/activation_grad.h"
+#include "internal/include/errorcode.h"
+#include "internal/include/ms_tensor.h"
+#include "nnacl/fp32_grad/activation_grad.h"
+#include "utils/log_adapter.h"
+#include "nnacl/errorcode.h"
+
+/// \brief Compute the fp32 gradient of an element-wise activation.
+///
+/// Supported ActivationType values are RELU, SIGMOID, RELU6 and LEAKY_RELU (which reads param->alpha_).
+///
+/// \param[in] in_tensors Input tensors; in_tensors[0] is passed to nnacl as dy and in_tensors[1] as x.
+/// \param[in] out_tensors Output tensors; dx is written to out_tensors[0].
+/// \param[in] node Graph node whose primitive_ is interpreted as an ActivationGradParameter.
+/// \param[in] allocator Not used by this kernel.
+///
+/// \return RET_OK on success, RET_PARAM_INVALID for missing or NULL arguments and for unsupported activation
+///         types, RET_ERROR when the nnacl routine reports a failure.
+int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                     mindspore::lite::Allocator *allocator) {
+  if (node == nullptr || node->primitive_ == nullptr) {
+    MS_LOG(ERROR) << "node or node's primitive is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors.size() < 2 || out_tensors.empty() || in_tensors[0] == nullptr || in_tensors[1] == nullptr ||
+      out_tensors[0] == nullptr) {
+    MS_LOG(ERROR) << "input or output tensor is missing!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors[0]->data_ == nullptr || in_tensors[1]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
+    MS_LOG(ERROR) << "input or output data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  ActivationGradParameter *param = reinterpret_cast<ActivationGradParameter *>(node->primitive_);
+  size_t length = in_tensors[0]->ElementsNum();
+  float *dy_data = reinterpret_cast<float *>(in_tensors[0]->data_);
+  float *x_data = reinterpret_cast<float *>(in_tensors[1]->data_);
+  float *dx_data = reinterpret_cast<float *>(out_tensors[0]->data_);
+  int ret = NNACL_OK;
+  if (param->type_ == ActivationType::RELU) {
+    ret = ReluGrad(dy_data, x_data, length, dx_data);
+  } else if (param->type_ == ActivationType::SIGMOID) {
+    ret = SigmoidGrad(dy_data, x_data, length, dx_data);
+  } else if (param->type_ == ActivationType::RELU6) {
+    ret = Relu6Grad(dy_data, x_data, length, dx_data);
+  } else if (param->type_ == ActivationType::LEAKY_RELU) {
+    ret = LReluGrad(dy_data, x_data, length, dx_data, param->alpha_);
+  } else {
+    MS_LOG(ERROR) << "Unsupport activation type " << param->type_;
+    return RET_PARAM_INVALID;
+  }
+  if (ret != NNACL_OK) {
+    MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h
new file mode 100644
index 0000000000..fee6b5ec49
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
+
+#include "internal/include/model.h"
+#include "src/runtime/allocator.h"
+
+int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                     mindspore::lite::Allocator *allocator);
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc
new file mode 100644
index 0000000000..80356b4400
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
+#include "internal/include/errorcode.h"
+#include "internal/include/ms_tensor.h"
+#include "utils/log_adapter.h"
+#include "nnacl/fp32/arithmetic_self.h"
+#include "nnacl/fp32/arithmetic.h"
+#include "nnacl/errorcode.h"
+
+/// \brief Compute the fp32 gradient of an element-wise unary operator.
+///
+/// LogGrad calls ElementDiv on (dy, x); NegGrad calls ElementNegative on dy and reads only in_tensors[0].
+///
+/// \param[in] in_tensors Input tensors; in_tensors[0] is dy and, for LogGrad, in_tensors[1] is x.
+/// \param[in] out_tensors Output tensors; dx is written to out_tensors[0].
+/// \param[in] node Graph node whose primitive_->type_ selects the operator (LogGrad or NegGrad).
+/// \param[in] allocator Not used by this kernel.
+///
+/// \return RET_OK on success, RET_PARAM_INVALID for missing or NULL arguments and for unsupported kernel types,
+///         RET_ERROR when the nnacl routine reports a failure.
+int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                         mindspore::lite::Allocator *allocator) {
+  if (node == nullptr || node->primitive_ == nullptr) {
+    MS_LOG(ERROR) << "node or node's primitive is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
+    MS_LOG(ERROR) << "input or output tensor is missing!";
+    return RET_PARAM_INVALID;
+  }
+  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
+    MS_LOG(ERROR) << "input or output data is NULL!";
+    return RET_PARAM_INVALID;
+  }
+  OpParameter *param = node->primitive_;
+  size_t data_size = in_tensors[0]->ElementsNum();
+  float *dy_data = reinterpret_cast<float *>(in_tensors[0]->data_);
+  float *dx_data = reinterpret_cast<float *>(out_tensors[0]->data_);
+  int ret = NNACL_OK;
+  if (param->type_ == KernelType::LogGrad) {
+    // Only LogGrad needs the forward input x, so it is validated here rather than for every kernel type.
+    if (in_tensors.size() < 2 || in_tensors[1] == nullptr || in_tensors[1]->data_ == nullptr) {
+      MS_LOG(ERROR) << "input x of LogGrad is missing!";
+      return RET_PARAM_INVALID;
+    }
+    ret = ElementDiv(dy_data, reinterpret_cast<float *>(in_tensors[1]->data_), dx_data, data_size);
+  } else if (param->type_ == KernelType::NegGrad) {
+    ret = ElementNegative(dy_data, dx_data, data_size);
+  } else {
+    MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_;
+    return RET_PARAM_INVALID;
+  }
+  if (ret != NNACL_OK) {
+    MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret;
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h
new file mode 100644
index 0000000000..952ab2bc7b
--- /dev/null
+++ b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
+
+#include "internal/include/model.h"
+#include "src/runtime/allocator.h"
+
+int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
+                         mindspore::lite::Allocator *allocator);
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
diff --git a/mindspore/lite/internal/src/lite_session.cc b/mindspore/lite/internal/src/lite_session.cc
index 7c8b928040..d791529971 100644
--- a/mindspore/lite/internal/src/lite_session.cc
+++ b/mindspore/lite/internal/src/lite_session.cc
@@ -17,6 +17,13 @@
 #include "internal/include/model.h"
 #include "internal/include/ms_tensor.h"
 #include "src/runtime/allocator.h"
+#include "internal/include/errorcode.h"
+#include "utils/log_adapter.h"
+#include "internal/src/kernel/fp32/activation.h"
+#include "internal/src/kernel/fp32/arithmetic_self.h"
+#include "internal/src/kernel/fp32/matmul.h"
+#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
+#include "internal/src/kernel/fp32_grad/activation_grad.h"
 
 static Context *g_Ctx;
 static Model *g_Model;
@@ -58,11 +65,74 @@ TensorPtrVector LiteSession::GetOutputs() const {
 
 int LiteSession::RunGraph() {
   // invoke nnacl kernel
-  return 0;
+  // Runs every node of g_Model in the order stored in nodes_, dispatching on the primitive type to the matching
+  // fp32 kernel. Stops at the first failing node and returns its error code; returns RET_OK when all nodes ran.
+  if (g_Model == nullptr) {
+    MS_LOG(ERROR) << "model is NULL!";
+    return RET_ERROR;
+  }
+  const NodePtrVector &nodes = g_Model->nodes_;
+  const size_t tensors_size = g_Model->all_tensors_.size();
+  for (size_t i = 0; i < nodes.size(); ++i) {
+    auto node = nodes[i];
+    if (node == nullptr || node->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "node or node's primitive is NULL!";
+      return RET_ERROR;
+    }
+    TensorPtrVector in_tensors;
+    in_tensors.reserve(node->input_indices_.size());
+    for (size_t j = 0; j < node->input_indices_.size(); ++j) {
+      const size_t index = node->input_indices_[j];
+      if (index >= tensors_size) {
+        MS_LOG(ERROR) << "input tensor index " << index << " is out of range!";
+        return RET_ERROR;
+      }
+      in_tensors.push_back(g_Model->all_tensors_[index]);
+    }
+    TensorPtrVector out_tensors;
+    out_tensors.reserve(node->output_indices_.size());
+    for (size_t j = 0; j < node->output_indices_.size(); ++j) {
+      const size_t index = node->output_indices_[j];
+      if (index >= tensors_size) {
+        MS_LOG(ERROR) << "output tensor index " << index << " is out of range!";
+        return RET_ERROR;
+      }
+      out_tensors.push_back(g_Model->all_tensors_[index]);
+    }
+    int type = node->primitive_->type_;
+    int ret = RET_ERROR;
+    switch (type) {
+      case KernelType::MatMul:
+        ret = DoMatMul(in_tensors, out_tensors, node, &allocator);
+        break;
+      case KernelType::Activation:
+        ret = DoActivation(in_tensors, out_tensors, node, &allocator);
+        break;
+      case KernelType::Log:
+      case KernelType::Neg:
+        ret = DoArithmeticSelf(in_tensors, out_tensors, node, &allocator);
+        break;
+      case KernelType::LogGrad:
+      case KernelType::NegGrad:
+        ret = DoArithmeticGradSelf(in_tensors, out_tensors, node, &allocator);
+        break;
+      case KernelType::ActivationGrad:
+        ret = DoActivationGrad(in_tensors, out_tensors, node, &allocator);
+        break;
+      default:
+        MS_LOG(ERROR) << "Unsupport kernel type: " << type;
+        return RET_PARAM_INVALID;
+    }
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "run kernel fail!ret: " << ret;
+      return ret;
+    }
+  }
+  return RET_OK;
 }
 
 StringVector LiteSession::GetOutputTensorNames() const { return StringVector(); }
 
 MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { return NULL; }
 
-int LiteSession::Resize(const TensorPtrVector &inputs) { return 0; }
+int LiteSession::Resize(const TensorPtrVector &inputs, Int32VectorVector dims) { return 0; }