From 5f35574930189ad6736ee1bdc2bed3da614a0812 Mon Sep 17 00:00:00 2001
From: chenjianping
Date: Tue, 22 Sep 2020 10:27:19 +0800
Subject: [PATCH] lite/internal support allocator

---
 mindspore/lite/include/lite_session.h | 2 +-
 mindspore/lite/include/model.h | 2 +-
 mindspore/lite/internal/CMakeLists.txt | 14 +-
 .../lite/internal/include/lite_session.h | 1 +
 mindspore/lite/internal/include/model.h | 310 +++++++++---------
 mindspore/lite/internal/include/ms_tensor.h | 2 +-
 mindspore/lite/internal/include/string.h | 10 +-
 mindspore/lite/internal/include/vector.h | 5 +
 mindspore/lite/internal/src/allocator.cc | 220 +++++++++++++
 mindspore/lite/internal/src/allocator.h | 60 ++++
 mindspore/lite/internal/src/common/vector.cc | 38 ++-
 .../internal/src/kernel/fp32/activation.h | 2 +-
 .../internal/src/kernel/fp32/arithmetic.cc | 33 +-
 .../internal/src/kernel/fp32/arithmetic.h | 2 +-
 .../src/kernel/fp32/arithmetic_self.cc | 4 +-
 .../src/kernel/fp32/arithmetic_self.h | 2 +-
 .../lite/internal/src/kernel/fp32/bias_add.cc | 13 +-
 .../lite/internal/src/kernel/fp32/bias_add.h | 2 +-
 .../lite/internal/src/kernel/fp32/matmul.cc | 49 +--
 .../lite/internal/src/kernel/fp32/matmul.h | 2 +-
 .../lite/internal/src/kernel/fp32/reduce.cc | 75 ++---
 .../lite/internal/src/kernel/fp32/reduce.h | 2 +-
 .../src/kernel/fp32_grad/activation_grad.h | 2 +-
 .../kernel/fp32_grad/arithmetic_self_grad.cc | 4 +-
 .../kernel/fp32_grad/arithmetic_self_grad.h | 2 +-
 mindspore/lite/internal/src/lite_log.h | 8 +-
 mindspore/lite/internal/src/lite_session.cc | 53 +--
 mindspore/lite/internal/src/ms_tensor.cc | 15 +-
 mindspore/lite/nnacl/fp32/matmul.c | 35 ++
 mindspore/lite/nnacl/fp32/matmul.h | 5 +
 .../lite/test/ut/internal/CMakeLists.txt | 10 +-
 .../lite/test/ut/internal/allocator_test.cc | 99 ++++++
 mindspore/lite/test/ut/internal/infer_test.cc | 8 +-
 .../lite/test/ut/internal/vector_test.cc | 54 +++
 34 files changed, 817 insertions(+), 328 deletions(-)
 create mode 100644 mindspore/lite/internal/src/allocator.cc
 create mode 100644 mindspore/lite/internal/src/allocator.h
 create mode 100644 mindspore/lite/test/ut/internal/allocator_test.cc
 create mode 100644 mindspore/lite/test/ut/internal/vector_test.cc

diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h
index 28222de1e1..81ed1b020e 100644
--- a/mindspore/lite/include/lite_session.h
+++ b/mindspore/lite/include/lite_session.h
@@ -114,7 +114,7 @@ class MS_API LiteSession {
   /// \brief Resize inputs shape.
   ///
   /// \param[in] inputs Define the inputs of the model.
-  /// \param[in] inputs Define the inputs new shape.
+  /// \param[in] dims Define the inputs new shape.
   ///
   /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
   virtual int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) = 0;
diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
index 621f94b513..1da850eb48 100644
--- a/mindspore/lite/include/model.h
+++ b/mindspore/lite/include/model.h
@@ -48,7 +48,7 @@ struct Model {
   /// \brief Free meta graph temporary buffer
   virtual void Free();
 
-  /// \brief Free all temporay buffer
+  /// \brief Free all temporary buffers, e.g. nodes in the model.
void Destroy(); /// \brief Model destruct, free all memory diff --git a/mindspore/lite/internal/CMakeLists.txt b/mindspore/lite/internal/CMakeLists.txt index 86292f709e..6237f6c369 100644 --- a/mindspore/lite/internal/CMakeLists.txt +++ b/mindspore/lite/internal/CMakeLists.txt @@ -1,8 +1,10 @@ cmake_minimum_required(VERSION 3.14) project (Lite_Internal) set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../) - +set(CMAKE_CXX_COMPILER ${CMAKE_C_COMPILER}) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions") include_directories(${TOP_DIR}) +add_compile_definitions(ENABLE_NNACL_INFER_SHAPE) file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/arithmetic_common.c @@ -26,13 +28,11 @@ endif () list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c) set(CCSRC + ${CMAKE_CURRENT_SOURCE_DIR}/src/common/vector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/src/common/string.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/src/allocator.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/ms_tensor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/src/common/string.cc - ${CMAKE_CURRENT_SOURCE_DIR}/src/common/vector.cc - ${TOP_DIR}/src/common/log_adapter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../core/gvar/logging_level.cc - ${TOP_DIR}/src/runtime/allocator.cc ) if (PLATFORM_ARM64) @@ -43,6 +43,4 @@ if (PLATFORM_ARM64) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) add_library(mslite_internal SHARED ${CCSRC} ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) - target_link_libraries(mslite_internal log) endif() - diff --git a/mindspore/lite/internal/include/lite_session.h b/mindspore/lite/internal/include/lite_session.h index c1bf5242df..670232d75a 100644 --- a/mindspore/lite/internal/include/lite_session.h +++ b/mindspore/lite/internal/include/lite_session.h @@ -82,6 +82,7 @@ typedef struct LiteSession { /// \brief Resize inputs shape. /// /// \param[in] inputs Define the new inputs shape. + /// \param[in] dims Define the inputs new shape. /// /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. 
int Resize(const TensorPtrVector &inputs, const Int32VectorVector &dims); diff --git a/mindspore/lite/internal/include/model.h b/mindspore/lite/internal/include/model.h index ba926a8cb5..dbe3d31b72 100644 --- a/mindspore/lite/internal/include/model.h +++ b/mindspore/lite/internal/include/model.h @@ -28,161 +28,161 @@ enum NodeType { }; enum KernelType : int { - Concat = 0, - SoftMax, - Activation, - Conv2D, - FusedBatchNorm, - BatchNorm, - BiasAdd, - Pooling, - ROIPooling, - DepthwiseConv2D, - DeDepthwiseConv2D, - Resize, - DetectionPostProcess, - FullConnection, - Mean, - DeConv2D, - Scale, - Reshape, - Eltwise, - NetOutput, - Add, - Sub, - MatMul, - StridedSlice, - Power, - Slice, - Stack, - Mul, - RealDiv, - Pad, - Maximum, - Minimum, - PReLU, - LeakyReLU, - ArgMax, - ArgMin, - Exp, - Crop, - Range, - Rsqrt, - ExpandDims, - Tile, - Cast, - Shape, - Nchw2Nhwc, - Nhwc2Nchw, - QuantDTypeCast, - Split, - Permute, - FakeQuantWithMinMaxVars, - Equal, - Less, - Greater, - NotEqual, - LessEqual, - GreaterEqual, - Min, - Floor, - Abs, - Neg, - Cos, - Sin, - Sqrt, - Square, - Constant, - Log, - Tan, - Atan, - Asin, - Clip, - Transpose, - Squeeze, - Unsqueeze, - Upsample, - Dropout, - Broadcast, - BroadcastTo, - Lrn, - ZerosLike, - TopK, - SpaceToDepth, - SpaceToBatch, - SparseToDense, - ReverseSequence, - Rank, - Gather, - GatherNd, - Fill, - Elu, - DepthToSpace, - BatchToSpace, - AddN, - Ceil, - EmbeddingLookup, - EmbeddingLookupSparse, - FloorDiv, - FloorMod, - L2Norm, - LocalResponseNormalization, - MatrixDiag, - Reduce, - Reverse, - Round, - Select, - Scatter, - ScatterND, - ConstantOfShape, - Unique, - Unstack, - LogicalAnd, - LogicalOr, - LogicalXor, - LogicalNot, - OnnxInt8Quantize, - OnnxInt8Dequantize, - FakeQuantWithMinMax, - FakeQuantWithMinMaxPerChannel, - BatchNormFold, - MulFold, - AddFold, - SquaredDifference, - Flatten, - FlattenGrad, - TupleGetItem, - Div, - Where, - OneHot, - Lstm, - Conv2DGradFilter, - Conv2DGradInput, - PoolingGrad, - BNGrad, - BNGradInput, - ApplyMomentum, - BiasGrad, - SoftmaxCrossEntropy, - AddGrad, - SubGrad, - MulGrad, - DivGrad, - PowerGrad, - ActivationGrad, - PriorBox, - SpaceToBatchND, - Depend, - Return, - MakeTuple, - ToFormat, - Proposal, - Custom, - BlackBox, - NegGrad, - LogGrad, - BatchToSpaceND, - END, + KernelType_Concat = 0, + KernelType_SoftMax, + KernelType_Activation, + KernelType_Conv2D, + KernelType_FusedBatchNorm, + KernelType_BatchNorm, + KernelType_BiasAdd, + KernelType_Pooling, + KernelType_ROIPooling, + KernelType_DepthwiseConv2D, + KernelType_DeDepthwiseConv2D, + KernelType_Resize, + KernelType_DetectionPostProcess, + KernelType_FullConnection, + KernelType_Mean, + KernelType_DeConv2D, + KernelType_Scale, + KernelType_Reshape, + KernelType_Eltwise, + KernelType_NetOutput, + KernelType_Add, + KernelType_Sub, + KernelType_MatMul, + KernelType_StridedSlice, + KernelType_Power, + KernelType_Slice, + KernelType_Stack, + KernelType_Mul, + KernelType_RealDiv, + KernelType_Pad, + KernelType_Maximum, + KernelType_Minimum, + KernelType_PReLU, + KernelType_LeakyReLU, + KernelType_ArgMax, + KernelType_ArgMin, + KernelType_Exp, + KernelType_Crop, + KernelType_Range, + KernelType_Rsqrt, + KernelType_ExpandDims, + KernelType_Tile, + KernelType_Cast, + KernelType_Shape, + KernelType_Nchw2Nhwc, + KernelType_Nhwc2Nchw, + KernelType_QuantDTypeCast, + KernelType_Split, + KernelType_Permute, + KernelType_FakeQuantWithMinMaxVars, + KernelType_Equal, + KernelType_Less, + KernelType_Greater, + KernelType_NotEqual, + KernelType_LessEqual, + 
KernelType_GreaterEqual, + KernelType_Min, + KernelType_Floor, + KernelType_Abs, + KernelType_Neg, + KernelType_Cos, + KernelType_Sin, + KernelType_Sqrt, + KernelType_Square, + KernelType_Constant, + KernelType_Log, + KernelType_Tan, + KernelType_Atan, + KernelType_Asin, + KernelType_Clip, + KernelType_Transpose, + KernelType_Squeeze, + KernelType_Unsqueeze, + KernelType_Upsample, + KernelType_Dropout, + KernelType_Broadcast, + KernelType_BroadcastTo, + KernelType_Lrn, + KernelType_ZerosLike, + KernelType_TopK, + KernelType_SpaceToDepth, + KernelType_SpaceToBatch, + KernelType_SparseToDense, + KernelType_ReverseSequence, + KernelType_Rank, + KernelType_Gather, + KernelType_GatherNd, + KernelType_Fill, + KernelType_Elu, + KernelType_DepthToSpace, + KernelType_BatchToSpace, + KernelType_AddN, + KernelType_Ceil, + KernelType_EmbeddingLookup, + KernelType_EmbeddingLookupSparse, + KernelType_FloorDiv, + KernelType_FloorMod, + KernelType_L2Norm, + KernelType_LocalResponseNormalization, + KernelType_MatrixDiag, + KernelType_Reduce, + KernelType_Reverse, + KernelType_Round, + KernelType_Select, + KernelType_Scatter, + KernelType_ScatterND, + KernelType_ConstantOfShape, + KernelType_Unique, + KernelType_Unstack, + KernelType_LogicalAnd, + KernelType_LogicalOr, + KernelType_LogicalXor, + KernelType_LogicalNot, + KernelType_OnnxInt8Quantize, + KernelType_OnnxInt8Dequantize, + KernelType_FakeQuantWithMinMax, + KernelType_FakeQuantWithMinMaxPerChannel, + KernelType_BatchNormFold, + KernelType_MulFold, + KernelType_AddFold, + KernelType_SquaredDifference, + KernelType_Flatten, + KernelType_FlattenGrad, + KernelType_TupleGetItem, + KernelType_Div, + KernelType_Where, + KernelType_OneHot, + KernelType_Lstm, + KernelType_Conv2DGradFilter, + KernelType_Conv2DGradInput, + KernelType_PoolingGrad, + KernelType_BNGrad, + KernelType_BNGradInput, + KernelType_ApplyMomentum, + KernelType_BiasGrad, + KernelType_SoftmaxCrossEntropy, + KernelType_AddGrad, + KernelType_SubGrad, + KernelType_MulGrad, + KernelType_DivGrad, + KernelType_PowerGrad, + KernelType_ActivationGrad, + KernelType_PriorBox, + KernelType_SpaceToBatchND, + KernelType_Depend, + KernelType_Return, + KernelType_MakeTuple, + KernelType_ToFormat, + KernelType_Proposal, + KernelType_Custom, + KernelType_BlackBox, + KernelType_NegGrad, + KernelType_LogGrad, + KernelType_BatchToSpaceND, + KernelType_END, }; enum ActivationType { diff --git a/mindspore/lite/internal/include/ms_tensor.h b/mindspore/lite/internal/include/ms_tensor.h index e261e3ab91..8b6aac7f52 100644 --- a/mindspore/lite/internal/include/ms_tensor.h +++ b/mindspore/lite/internal/include/ms_tensor.h @@ -107,7 +107,7 @@ typedef struct MSTensor { TypeId data_type_; Format format_ = Format_NHWC; Category category_ = VAR; - ShapeVector shape_ = {}; + ShapeVector shape_; size_t refCount = 0; int32_t Batch() const; diff --git a/mindspore/lite/internal/include/string.h b/mindspore/lite/internal/include/string.h index 3e92a466ea..177c111ac0 100644 --- a/mindspore/lite/internal/include/string.h +++ b/mindspore/lite/internal/include/string.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef INTERNAL_SRC_STRING_H_ -#define INTERNAL_SRC_STRING_H_ +#ifndef MINDSPORE_LITE_INTERNAL_SRC_STRING_H_ +#define MINDSPORE_LITE_INTERNAL_SRC_STRING_H_ #include #include @@ -34,8 +34,8 @@ typedef struct String { char &at(size_t pos); const char &at(size_t pos) const; - char &operator[](size_t pos); - const char &operator[](size_t pos) const; + inline char &operator[](size_t pos); + inline const char &operator[](size_t pos) const; char *data() noexcept; const char *data() const noexcept; const char *c_str() const noexcept; @@ -97,4 +97,4 @@ String to_String(float value); String to_String(double value); String to_String(long double value); -#endif // INTERNAL_SRC_STRING_H_ +#endif // MINDSPORE_LITE_INTERNAL_SRC_STRING_H_ diff --git a/mindspore/lite/internal/include/vector.h b/mindspore/lite/internal/include/vector.h index bb0cbc9020..775400652c 100644 --- a/mindspore/lite/internal/include/vector.h +++ b/mindspore/lite/internal/include/vector.h @@ -17,6 +17,7 @@ #define MINDSPORE_LITE_INTERNAL_INCLUDE_VECTOR_H #include +#include #include "internal/include/string.h" #define DEFAULT_CAPACITY 1 @@ -44,6 +45,8 @@ class Vector { void push_back(const T &elem); + void push_back(T &&); + void pop_back(); void insert(const T &elem, size_t index); @@ -87,6 +90,8 @@ class Vector { void resize(size_t size); void reserve(size_t capacity); + + Vector &operator=(const Vector &v); }; template diff --git a/mindspore/lite/internal/src/allocator.cc b/mindspore/lite/internal/src/allocator.cc new file mode 100644 index 0000000000..b65101ad60 --- /dev/null +++ b/mindspore/lite/internal/src/allocator.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "internal/src/allocator.h"
+#include <stdlib.h>
+#include "internal/src/lite_log.h"
+
+namespace mindspore::lite {
+namespace {
+constexpr size_t kMaxMallocSize = 2000 * 1024 * 1024;
+constexpr int kBlockSize = 1024;
+constexpr size_t kBlockLimit = (kBlockSize << (kBlockRange - 1));
+
+int SizeToIndex(size_t size) {
+  if (size > kBlockLimit) {
+    return -1;
+  }
+  int index = 0;
+  for (int i = 0; i < kBlockRange; ++i) {
+    if ((size & (kBlockSize << i))) {
+      index = i;
+    }
+  }
+  if (size > (size_t)(kBlockSize << index)) {
+    index += 1;
+  }
+  return index;
+}
+
+void PopMemNode(MemNode **head) {
+  if (*head == nullptr) {
+    return;
+  }
+  MemNode *next = (*head)->next_;
+  (*head) = next;
+  if (*head != nullptr) {
+    (*head)->pre_ = nullptr;
+  }
+}
+
+void PushMemNode(MemNode **head, MemNode *node) {
+  if (node == nullptr) {
+    return;
+  }
+  if (*head == nullptr) {
+    *head = node;
+    return;
+  }
+  (*head)->pre_ = node;
+  node->next_ = *head;
+  node->pre_ = nullptr;
+  *head = node;
+}
+
+void RemoveMemNode(MemNode **head, MemNode *node) {
+  if (node == nullptr) {
+    return;
+  }
+  if ((*head) == node) {
+    *head = node->next_;
+    if (*head != nullptr) {
+      (*head)->pre_ = nullptr;
+    }
+  } else {
+    MemNode *node_pre = node->pre_;
+    node_pre->next_ = node->next_;
+    node->next_ = nullptr;
+    node->pre_ = nullptr;
+  }
+}
+
+void FreeNodesList(MemNode *head) {
+  MemNode *node = head;
+  while (node != nullptr) {
+    MemNode *next = node->next_;
+    free(node);
+    node = next;
+  }
+}
+}  // namespace
+
+Allocator::Allocator() {
+  for (int i = 0; i < kBlockRange; ++i) {
+    allocated_list_[i] = nullptr;
+    free_list_[i] = nullptr;
+  }
+}
+
+Allocator::~Allocator() { Clear(); }
+
+void Allocator::SetContext(const AllocatorContext &ctx) {
+  lock_flag_ = ctx.lock_flag_;
+}
+
+void Allocator::Lock() {
+  if (lock_flag_) {
+    pthread_mutex_lock(&lock_);
+  }
+}
+
+void Allocator::UnLock() {
+  if (lock_flag_) {
+    pthread_mutex_unlock(&lock_);
+  }
+}
+
+void *Allocator::Malloc(size_t size) {
+  if (size > kMaxMallocSize) {
+    LITE_ERROR_LOG("MallocData out of max_size, size: %zu", size);
+    return nullptr;
+  }
+  void *result = nullptr;
+  int index = SizeToIndex(size);
+  if (index < 0) {
+    MemNode *node = (MemNode *)malloc(sizeof(MemNode) + size);
+    if (node == nullptr) {
+      LITE_ERROR_LOG("malloc large MemNode fail, size: %zu", (size + sizeof(MemNode)));
+      return result;
+    }
+    node->size_ = size;
+    result = (char *)node + sizeof(MemNode);
+    Lock();
+    PushMemNode(&large_mem_list_, node);
+    UnLock();
+    return result;
+  }
+  Lock();
+  size_t size_apply = (kBlockSize << index);
+  if (free_list_[index] != nullptr) {
+    MemNode *free_node = free_list_[index];
+    PopMemNode(&free_list_[index]);
+    PushMemNode(&allocated_list_[index], free_node);
+    result = (char *)free_node + sizeof(MemNode);
+    UnLock();
+    return result;
+  } else {
+    MemNode *new_node = (MemNode *)malloc(sizeof(MemNode) + size_apply);
+    if (new_node == nullptr) {
+      UnLock();
+      LITE_LOG_ERROR("malloc MemNode fail!");
+      return nullptr;
+    }
+    new_node->size_ = size;
+    PushMemNode(&allocated_list_[index], new_node);
+    result = (char *)new_node + sizeof(MemNode);
+    UnLock();
+    return result;
+  }
+}
+
+void Allocator::Free(void *buf) {
+  if (buf == nullptr) {
+    return;
+  }
+
+  MemNode *node = (MemNode *)((char *)buf - sizeof(MemNode));
+  size_t buf_size = node->size_;
+  Lock();
+  if (buf_size > kBlockLimit) {
+    RemoveMemNode(&large_mem_list_, node);
+    free(node);
+  } else {
+    int index = SizeToIndex(buf_size);
+    RemoveMemNode(&allocated_list_[index], node);
+    PushMemNode(&free_list_[index], node);
+  }
+  UnLock();
+}
+
+size_t Allocator::GetTotalSize() {
+  Lock();
+  size_t total_size = 0;
+  for (int i = 0; i < kBlockRange; ++i) {
+    MemNode *node = allocated_list_[i];
+    while (node != nullptr) {
+      total_size += node->size_;
+      node = node->next_;
+    }
+
+    node = free_list_[i];
+    while (node != nullptr) {
+      total_size += node->size_;
+      node = node->next_;
+    }
+  }
+  MemNode *node = large_mem_list_;
+  while (node != nullptr) {
+    total_size += node->size_;
+    node = node->next_;
+  }
+  UnLock();
+  return total_size;
+}
+
+void Allocator::Clear() {
+  Lock();
+  for (int i = 0; i < kBlockRange; ++i) {
+    FreeNodesList(allocated_list_[i]);
+    allocated_list_[i] = nullptr;
+
+    FreeNodesList(free_list_[i]);
+    free_list_[i] = nullptr;
+  }
+  FreeNodesList(large_mem_list_);
+  UnLock();
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/internal/src/allocator.h b/mindspore/lite/internal/src/allocator.h
new file mode 100644
index 0000000000..c4aa53dd86
--- /dev/null
+++ b/mindspore/lite/internal/src/allocator.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_INTERNAL_SRC_ALLOCATOR_H_
+#define MINDSPORE_LITE_INTERNAL_SRC_ALLOCATOR_H_
+
+#include <pthread.h>
+#include <stddef.h>
+#include "internal/include/string.h"
+
+namespace mindspore::lite {
+constexpr int kBlockRange = 9;
+
+typedef struct AllocatorContext {
+  bool lock_flag_;
+} AllocatorContext;
+
+typedef struct MemNode {
+  MemNode *pre_ = nullptr;
+  MemNode *next_ = nullptr;
+  size_t size_ = 0;
+} MemNode;
+
+
+class Allocator {
+ public:
+  Allocator();
+  ~Allocator();
+  void SetContext(const AllocatorContext &ctx);
+  void *Malloc(size_t size);
+  void Free(void *ptr);
+  void Clear();
+  size_t GetTotalSize();
+
+ private:
+  void Lock();
+  void UnLock();
+
+  bool lock_flag_ = false;
+  pthread_mutex_t lock_ = PTHREAD_MUTEX_INITIALIZER;
+  MemNode *large_mem_list_ = nullptr;
+  MemNode *allocated_list_[kBlockRange];
+  MemNode *free_list_[kBlockRange];
+};
+}  // namespace mindspore::lite
+
+#endif  // MINDSPORE_LITE_INTERNAL_SRC_ALLOCATOR_H_
diff --git a/mindspore/lite/internal/src/common/vector.cc b/mindspore/lite/internal/src/common/vector.cc
index 1196d04d93..5e9839145d 100644
--- a/mindspore/lite/internal/src/common/vector.cc
+++ b/mindspore/lite/internal/src/common/vector.cc
@@ -52,9 +52,28 @@ Vector<T>::Vector(const Vector<T> &vec) {
   memcpy(data_, vec.data_, size_ * elem_size_);
 }
 
+template <typename T>
+Vector<T> &Vector<T>::operator=(const Vector<T> &vec) {
+  if (this == &vec) {
+    return *this;
+  }
+  if (data_ != nullptr) {
+    free(data_);
+  }
+  size_ = vec.size_;
+  elem_size_ = sizeof(T);
+  capacity_ = vec.capacity_;
+  data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
+  if (data_ == nullptr) {
+    MS_C_EXCEPTION("malloc data failed");
+  }
+  memcpy(data_, vec.data_, size_ * elem_size_);
+  return *this;
+}
+
 template <typename T>
 Vector<T>::~Vector() {
-  if (data_) {
+  if (data_ != nullptr) {
     free(data_);
   }
 }
@@ -62,7 +81,7 @@ Vector<T>::~Vector() {
 template <typename T>
 void Vector<T>::clear() {
   size_ = 0;
-  if (data_) {
+  if (data_ != nullptr) {
     free(data_);
     data_ = nullptr;
   }
@@ -83,6 +102,21 @@ void Vector<T>::push_back(const T &elem) {
   ++size_;
 }
 
+template <typename T>
+void Vector<T>::push_back(T &&elem) {
+  if (data_ == nullptr) {
+    data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_));
+    if (data_ == nullptr) {
+      MS_C_EXCEPTION("malloc data failed");
+    }
+  } else if (size_ == capacity_) {
+    resize(size_ + 1);
+    --size_;
+  }
+  memcpy(data_ + size_, &elem, elem_size_);
+  ++size_;
+}
+
 template <typename T>
 void Vector<T>::pop_back() {
   if (size_ > 0) {
diff --git a/mindspore/lite/internal/src/kernel/fp32/activation.h b/mindspore/lite/internal/src/kernel/fp32/activation.h
index dbb5c1b79e..5918d00ffd 100644
--- a/mindspore/lite/internal/src/kernel/fp32/activation.h
+++ b/mindspore/lite/internal/src/kernel/fp32/activation.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
 
 #include "internal/include/model.h"
-#include "src/runtime/allocator.h"
+#include "internal/src/allocator.h"
 
 int DoActivationInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param);
 int DoActivation(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node,
diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic.cc b/mindspore/lite/internal/src/kernel/fp32/arithmetic.cc
index 119ea96bbb..af0d1c9953 100644
--- a/mindspore/lite/internal/src/kernel/fp32/arithmetic.cc
+++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic.cc
@@ -19,7 +19,6 @@
 #include "internal/include/model.h"
 #include "internal/include/ms_tensor.h"
 #include "internal/include/lite_utils.h"
-#include "src/runtime/allocator.h"
 #include "nnacl/arithmetic_common.h"
 #include "nnacl/fp32/arithmetic.h"
 
@@ -47,14 +46,14 @@ int BroadcastRun(float *input0, float *input1, float *output, int dim, int out_c
 int CalBroadCasting(const TensorPtrVector &in_tensors, int *outside, int *break_pos, ArithmeticParameter *params) {
   params->broadcasting_ = false;
-  for (int i = 0; i < params->ndim_; i++) {
+  for (size_t i = 0; i < params->ndim_; ++i) {
     if (params->in_shape0_[i] != params->in_shape1_[i]) {
       if (params->in_shape0_[i] == 1) {
         params->out_shape_[i] = params->in_shape1_[i];
       } else if (params->in_shape1_[i] == 1) {
         params->out_shape_[i] = params->in_shape0_[i];
       } else {
-        LITE_ERROR_LOG("shapes of input tensors can not be broadCasted");
+        LITE_LOG_ERROR("shapes of input tensors can not be broadCasted");
         return RET_INPUT_TENSOR_ERROR;
       }
       params->broadcasting_ = true;
@@ -100,11 +99,11 @@ int RunArithmetic(const TensorPtrVector &in_tensors, const TensorPtrVector &out_
 
 int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param) {
   if (in_tensors.size() != 2 || in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
-    LITE_ERROR_LOG("input tensors num not correct or input data is NULL!")
+    LITE_LOG_ERROR("input tensors num not correct or input data is NULL!");
     return RET_INPUT_TENSOR_ERROR;
   }
   if (out_tensors.size() != 1) {
-    LITE_ERROR_LOG("output tensors num not correct!")
+    LITE_LOG_ERROR("output tensors num not correct!");
     return RET_ERROR;
   }
   ShapeVector in_shape0 = in_tensors[0]->shape_;
@@ -116,7 +115,7 @@ int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVec
     arithmeticParameter->ndim_ = ndim1;
     int fill_dim_num = ndim1 - ndim0;
     int j = 0;
-    for (size_t i = 0; i < ndim1; i++) {
+    for (int i = 0; i < ndim1; ++i) {
       if (i < fill_dim_num) {
         arithmeticParameter->in_shape0_[i] = 1;
       } else {
@@ -128,7 +127,7 @@ int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVec
     arithmeticParameter->ndim_ = ndim0;
     int fill_dim_num = ndim0 - ndim1;
     int j = 0;
-    for (size_t i = 0; i < ndim0; i++) {
+    for (int i = 0; i < ndim0; ++i) {
       if (i < fill_dim_num) {
         arithmeticParameter->in_shape1_[i] = 1;
       } else {
@@ -138,20 +137,20 @@ int DoArithmeticInferShape(const TensorPtrVec
         arithmeticParameter->in_shape1_[i] = in_shape1[j++];
       }
     }
   } else {
     arithmeticParameter->ndim_ = ndim0;
-    for (size_t i = 0; i < ndim0; i++) {
+    for (int i = 0; i < ndim0; ++i) {
       arithmeticParameter->in_shape0_[i] = in_shape0[i];
       arithmeticParameter->in_shape1_[i] = in_shape1[i];
     }
   }
   ShapeVector out_shape;
-  for (int i = 0; i < arithmeticParameter->ndim_; i++) {
+  for (size_t i = 0; i < arithmeticParameter->ndim_; ++i) {
     if (arithmeticParameter->in_shape0_[i] != arithmeticParameter->in_shape1_[i]) {
       if (arithmeticParameter->in_shape0_[i] == 1) {
         out_shape.push_back(arithmeticParameter->in_shape1_[i]);
       } else if (arithmeticParameter->in_shape1_[i] == 1) {
         out_shape.push_back(arithmeticParameter->in_shape0_[i]);
       } else {
-        LITE_ERROR_LOG("shapes of input tensors can not be broadcasted!")
+        LITE_LOG_ERROR("shapes of input tensors can not be broadcasted!");
         return RET_INPUT_TENSOR_ERROR;
       }
     } else {
@@ -165,7 +164,7 @@ int DoArithmeticInferShape(const TensorPtrVec
 }
 
 int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, ArithmeticParameter *params) {
-  if (kernel_type == KernelType::Mul) {
+  if (kernel_type == KernelType::KernelType_Mul) {
     if (params->activation_type_ == ActivationType::RELU) {
       *arithmetic_run = ElementMulRelu;
     } else if (params->activation_type_ == ActivationType::RELU6) {
@@ -174,14 +173,14 @@ int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, Arithmeti
       *arithmetic_run = ElementMul;
     }
   } else {
-    LITE_ERROR_LOG("unsupported operator type");
+    LITE_LOG_ERROR("unsupported operator type");
     return RET_ERROR;
   }
   return RET_OK;
 }
 
 int ChooseOptKernel(const int kernel_type, ArithmeticOptRun *arithmetic_opt_run, ArithmeticParameter *params) {
-  if (kernel_type == KernelType::Mul) {
+  if (kernel_type == KernelType::KernelType_Mul) {
     if (params->activation_type_ == ActivationType::RELU) {
       *arithmetic_opt_run = ElementOptMulRelu;
     } else if (params->activation_type_ == ActivationType::RELU6) {
@@ -190,7 +189,7 @@ int ChooseOptKernel(const int kernel_type, ArithmeticOptRun *arithmetic_opt_run,
       *arithmetic_opt_run = ElementOptMul;
     }
   } else {
-    LITE_INFO_LOG("kernel not have opt version");
+    LITE_LOG_INFO("kernel not have opt version");
   }
   return RET_OK;
 }
@@ -198,15 +197,15 @@ int ChooseOptKernel(const int kernel_type, ArithmeticOptRun *arithmetic_opt_run,
 int DoArithmetic(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node,
                  mindspore::lite::Allocator *allocator) {
   if (in_tensors.size() != 2 || in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
-    LITE_ERROR_LOG("input tensors num not correct or input data is NULL!")
+    LITE_LOG_ERROR("input tensors num not correct or input data is NULL!");
     return RET_INPUT_TENSOR_ERROR;
   }
   if (out_tensors.size() != 1 || out_tensors[0]->data_ == NULL) {
-    LITE_ERROR_LOG("output tensors num not correct or output data is NULL!")
+    LITE_LOG_ERROR("output tensors num not correct or output data is NULL!");
     return RET_ERROR;
   }
   if (allocator == NULL) {
-    LITE_ERROR_LOG("allocator is NULL!")
+    LITE_LOG_ERROR("allocator is NULL!");
     return RET_ERROR;
   }
   ArithmeticParameter *params = reinterpret_cast<ArithmeticParameter *>(node->primitive_);
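
Note on the kernel/allocator contract above: DoArithmetic, DoBiasAdd and DoReduce all receive the new internal mindspore::lite::Allocator and are expected to Malloc() scratch buffers and Free() them on every exit path. Free() does not return small blocks to the system; it parks them on a per-size free list for reuse (only blocks above kBlockLimit are released immediately), and Clear() or the destructor releases everything else. A minimal sketch of that pattern, assuming only the Allocator API added in internal/src/allocator.h; DoSomeKernel is a hypothetical kernel, not part of this patch:

#include "internal/src/allocator.h"

// Hypothetical kernel body illustrating the scratch-buffer pattern used by
// the kernels in this patch.
int DoSomeKernel(size_t element_num, mindspore::lite::Allocator *allocator) {
  // A 40-byte request maps to index 0 (the 1 KB block list) via SizeToIndex();
  // anything above kBlockLimit (256 KB) goes to the large-memory list instead.
  float *scratch = reinterpret_cast<float *>(allocator->Malloc(element_num * sizeof(float)));
  if (scratch == nullptr) {
    return -1;  // RET_ERROR in internal/include/errorcode.h
  }
  // ... run the nnacl compute routine on scratch ...
  // Free() moves the node from allocated_list_ to free_list_; the bytes stay
  // cached inside the allocator until Clear() runs.
  allocator->Free(scratch);
  return 0;  // RET_OK
}
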
diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic.h b/mindspore/lite/internal/src/kernel/fp32/arithmetic.h index 42cad0df65..e178be4297 100644 --- a/mindspore/lite/internal/src/kernel/fp32/arithmetic.h +++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic.h @@ -18,7 +18,7 @@ #include "internal/include/model.h" #include "internal/include/lite_utils.h" -#include "src/runtime/allocator.h" +#include "internal/src/allocator.h" #include "nnacl/arithmetic_common.h" int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param); diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc index 601ee3b406..8f01195ac0 100644 --- a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc +++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.cc @@ -31,9 +31,9 @@ int DoArithmeticSelf(const TensorPtrVector &in_tensors, const TensorPtrVector &o size_t data_size = in_tensors[0]->ElementsNum(); OpParameter *param = node->primitive_; int ret; - if (param->type_ == KernelType::Log) { + if (param->type_ == KernelType::KernelType_Log) { ret = ElementLog((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); - } else if (param->type_ == KernelType::Neg) { + } else if (param->type_ == KernelType::KernelType_Neg) { ret = ElementNegative((float *)in_tensors[0]->data_, (float *)out_tensors[0]->data_, data_size); } else { LITE_ERROR_LOG("Unsupport kernel type: %d", param->type_); diff --git a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h index 37b23c81fc..08b1d3a78c 100644 --- a/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h +++ b/mindspore/lite/internal/src/kernel/fp32/arithmetic_self.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_ #include "internal/include/model.h" -#include "src/runtime/allocator.h" +#include "internal/src/allocator.h" int DoArithmeticSelfInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param); diff --git a/mindspore/lite/internal/src/kernel/fp32/bias_add.cc b/mindspore/lite/internal/src/kernel/fp32/bias_add.cc index 2cd63f40f3..e8d507533b 100644 --- a/mindspore/lite/internal/src/kernel/fp32/bias_add.cc +++ b/mindspore/lite/internal/src/kernel/fp32/bias_add.cc @@ -17,7 +17,6 @@ #include "internal/include/model.h" #include "internal/include/ms_tensor.h" #include "internal/include/lite_utils.h" -#include "src/runtime/allocator.h" #include "internal/src/lite_log.h" #include "internal/include/errorcode.h" #include "nnacl/arithmetic_common.h" @@ -25,11 +24,11 @@ int DoBiasAddInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param) { if (in_tensors.size() != 2 || in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) { - LITE_ERROR_LOG("input tensors num not correct or input data is NULL!") + LITE_LOG_ERROR("input tensors num not correct or input data is NULL!"); return RET_INPUT_TENSOR_ERROR; } if (out_tensors.size() != 1) { - LITE_ERROR_LOG("output tensors num not correct!") + LITE_LOG_ERROR("output tensors num not correct!"); return RET_ERROR; } out_tensors[0]->shape_ = in_tensors[0]->shape_; @@ -41,15 +40,15 @@ int DoBiasAddInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector int DoBiasAdd(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node, 
              mindspore::lite::Allocator *allocator) {
   if (in_tensors.size() != 2 || in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
-    LITE_ERROR_LOG("input tensors num not correct or input data is NULL!")
+    LITE_LOG_ERROR("input tensors num not correct or input data is NULL!");
     return RET_INPUT_TENSOR_ERROR;
   }
   if (out_tensors.size() != 1 || out_tensors[0]->data_ == NULL) {
-    LITE_ERROR_LOG("output tensors num not correct or output data is NULL!")
+    LITE_LOG_ERROR("output tensors num not correct or output data is NULL!");
     return RET_ERROR;
   }
   if (allocator == NULL) {
-    LITE_ERROR_LOG("allocator is NULL!")
+    LITE_LOG_ERROR("allocator is NULL!");
     return RET_ERROR;
   }
   ArithmeticParameter *params = reinterpret_cast<ArithmeticParameter *>(node->primitive_);
@@ -70,7 +69,7 @@ int DoBiasAdd(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tens
   float *tile_in = reinterpret_cast<float *>(allocator->Malloc(data_size * sizeof(float)));
   float *tile_bias = reinterpret_cast<float *>(allocator->Malloc(data_size * sizeof(float)));
   if (tile_in == NULL || tile_bias == NULL) {
-    LITE_ERROR_LOG("Memory allocation failed!")
+    LITE_LOG_ERROR("Memory allocation failed!");
     allocator->Free(tile_in);
     allocator->Free(tile_bias);
     return RET_ERROR;
diff --git a/mindspore/lite/internal/src/kernel/fp32/bias_add.h b/mindspore/lite/internal/src/kernel/fp32/bias_add.h
index c368b2c388..0eabc970c6 100644
--- a/mindspore/lite/internal/src/kernel/fp32/bias_add.h
+++ b/mindspore/lite/internal/src/kernel/fp32/bias_add.h
@@ -18,7 +18,7 @@
 
 #include "internal/include/model.h"
 #include "internal/include/lite_utils.h"
-#include "src/runtime/allocator.h"
+#include "internal/src/allocator.h"
 
 int DoBiasAddInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param);
 
diff --git a/mindspore/lite/internal/src/kernel/fp32/matmul.cc b/mindspore/lite/internal/src/kernel/fp32/matmul.cc
index 1f4694ac91..9342fbe3b9 100644
--- a/mindspore/lite/internal/src/kernel/fp32/matmul.cc
+++ b/mindspore/lite/internal/src/kernel/fp32/matmul.cc
@@ -71,14 +71,7 @@ void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::All
   free(kernel_data);
 }
 
-static void SwapDims(Int32Vector *dims, int index1, int index2) {
-  int tmp = dims->at(index1);
-  dims->at(index1) = dims->at(index2);
-  dims->at(index2) = tmp;
-}
-
 int DoMatMulInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param) {
-  MS_ASSERT(this->primitive_ != nullptr);
   TensorPtr input0 = in_tensors.at(0);
   MS_ASSERT(input0 != nullptr);
   TensorPtr input1 = in_tensors.at(1);
@@ -86,31 +79,21 @@ int DoMatMulInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector
   TensorPtr output = out_tensors.at(0);
   MS_ASSERT(output != nullptr);
 
-  output->data_type_ = input0->data_type_;
-  output->format_ = input0->format_;
-
-  Int32Vector a_shape = input0->shape_;
-  Int32Vector b_shape = input1->shape_;
-  if (a_shape.size() < 2 || b_shape.size() < 2) {
-    LITE_ERROR_LOG("inputs shape is invalid");
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  for (size_t i = 0; i < a_shape.size() - 2; ++i) {
-    if (a_shape[i] != b_shape[i]) {
-      LITE_ERROR_LOG("Op MatMul's dimensions must be equal");
-      return RET_INPUT_TENSOR_ERROR;
-    }
-  }
-
-  MatMulParameter *matmul_param = (MatMulParameter *)param;
-  if (matmul_param->a_transpose_) {
-    SwapDims(&a_shape, a_shape.size() - 1, a_shape.size() - 2);
-  }
-  if (matmul_param->b_transpose_) {
-    SwapDims(&b_shape, b_shape.size() - 1, b_shape.size() - 2);
-  }
-  output->shape_ = a_shape;
-  output->shape_.at(a_shape.size() - 1) = b_shape.at(b_shape.size() - 1);
+  int in_datatype[2] = {input0->data_type_, input1->data_type_};
+  int in_format[2] = {static_cast<int>(input0->format_), static_cast<int>(input1->format_)};
+  size_t dim_size[2] = {input0->shape_.size(), input1->shape_.size()};
+  int *in_shape[2] = {input0->shape_.data(), input1->shape_.data()};
+  int out_format;
+  int out_datatype;
+  output->shape_.resize(dim_size[0]);
+  int ret = MatMulInferShape(in_shape, 2, dim_size, output->shape_.data(), in_format, &out_format, in_datatype,
+                             &out_datatype, param);
+  if (ret != NNACL_OK) {
+    LITE_ERROR_LOG("matmul infershape fail! ret: %d", ret);
+    return RET_ERROR;
+  }
+  output->format_ = static_cast<Format>(out_format);
+  output->data_type_ = static_cast<TypeId>(out_datatype);
   return RET_OK;
 }
 
@@ -149,7 +132,7 @@ int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso
 
   MatMulCPUKernelData *kernel_data = (MatMulCPUKernelData *)malloc(sizeof(MatMulCPUKernelData));
   if (kernel_data == NULL) {
-    LITE_ERROR_LOG("Malloc MatMulCPUKernelData failed");
+    LITE_LOG_ERROR("Malloc MatMulCPUKernelData failed");
     return RET_MEMORY_FAILED;
   }
   kernel_data->a_c12_ptr_
diff --git a/mindspore/lite/internal/src/kernel/fp32/matmul.h b/mindspore/lite/internal/src/kernel/fp32/matmul.h
index a0b98dacc7..5fbca43c16 100644
--- a/mindspore/lite/internal/src/kernel/fp32/matmul.h
+++ b/mindspore/lite/internal/src/kernel/fp32/matmul.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
 
 #include "internal/include/model.h"
-#include "src/runtime/allocator.h"
+#include "internal/src/allocator.h"
 
 int DoMatMulInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param);
 int DoMatMul(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node,
diff --git a/mindspore/lite/internal/src/kernel/fp32/reduce.cc b/mindspore/lite/internal/src/kernel/fp32/reduce.cc
index 235b611993..6a37674a4e 100644
--- a/mindspore/lite/internal/src/kernel/fp32/reduce.cc
+++ b/mindspore/lite/internal/src/kernel/fp32/reduce.cc
@@ -15,10 +15,8 @@
  */
 
 #include "internal/src/kernel/fp32/reduce.h"
-#include <vector>
 #include "internal/include/model.h"
 #include "internal/include/lite_utils.h"
-#include "src/runtime/allocator.h"
 #include "internal/src/lite_log.h"
 #include "internal/include/errorcode.h"
 #include "nnacl/reduce_parameter.h"
@@ -27,16 +25,8 @@ typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
                        float *dst_data, const int tid, const int thread_num);
 
-int MallocTmpBuffer(std::vector<float *> *data_buffers, const ShapeVector &shape, const int *axes, const int num_axes,
+int MallocTmpBuffer(float *data_buffers[], const ShapeVector &shape, const int *axes, const int num_axes,
                     mindspore::lite::Allocator *allocator) {
-  for (int i = 0; i < data_buffers->size(); ++i) {
-    if (data_buffers->at(i) != NULL) {
-      free(data_buffers->at(i));
-      data_buffers->at(i) = NULL;
-    }
-  }
-  data_buffers->clear();
-
   ShapeVector input_shape = shape;
   const int rank = input_shape.size();
   for (auto i = 0; i < num_axes - 1; i++) {
@@ -48,39 +38,39 @@ int MallocTmpBuffer(std::vector<float *> *data_buffers, const ShapeVector &shape, const i
       }
     }
     float *buffer = reinterpret_cast<float *>(allocator->Malloc(size * sizeof(float)));
-    if (buffer == NULL) {
-      LITE_ERROR_LOG("Memory allocation failed!")
+    if (buffer == nullptr) {
+      LITE_LOG_ERROR("Memory allocation failed!");
       return RET_ERROR;
     }
-    data_buffers->emplace_back(buffer);
+    data_buffers[i] = buffer;
     input_shape[axis] = 1;
   }
   return RET_OK;
 }
 
-void FreeTmpBuffer(std::vector<float *> *data_buffers,
mindspore::lite::Allocator *allocator) { +void FreeTmpBuffer(float *data_buffers[], int size, mindspore::lite::Allocator *allocator) { if (data_buffers == nullptr) { return; } - for (int i = 0; i < data_buffers->size(); ++i) { - allocator->Free(data_buffers->at(i)); + for (int i = 0; i < size; ++i) { + allocator->Free(data_buffers[i]); + data_buffers[i] = nullptr; } - data_buffers->clear(); } -int RunReduce(Reducer reducer, std::vector data_buffers, float *in_data, float *out_data, Int32Vector axes, +int RunReduce(Reducer reducer, float *data_buffers[], float *in_data, float *out_data, ReduceParameter *params, ShapeVector shape) { int rank = shape.size(); float *dst_data = NULL; float *src_data = in_data; ShapeVector tmp_shape = shape; - for (size_t i = 0; i < axes.size(); ++i) { - if (i != axes.size() - 1) { + for (int i = 0; i < params->num_axes_; ++i) { + if (i != params->num_axes_ - 1) { dst_data = data_buffers[i]; } else { dst_data = out_data; } - int axis = axes[i]; + int axis = params->axes_[i]; int outer_size = 1; for (int j = 0; j < axis; j++) { outer_size *= tmp_shape[j]; @@ -92,7 +82,7 @@ int RunReduce(Reducer reducer, std::vector data_buffers, float *in_data int axis_size = tmp_shape[axis]; int error_code = reducer(outer_size, inner_size, axis_size, src_data, dst_data, 0, 1); if (error_code != RET_OK) { - LITE_ERROR_LOG("Reduce run error!") + LITE_LOG_ERROR("Reduce run error!"); return RET_ERROR; } tmp_shape[axis] = 1; @@ -103,11 +93,11 @@ int RunReduce(Reducer reducer, std::vector data_buffers, float *in_data int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param) { if (in_tensors.size() != 1 || in_tensors[0]->data_ == NULL) { - LITE_ERROR_LOG("input tensors num not correct or input data is NULL!") + LITE_LOG_ERROR("input tensors num not correct or input data is NULL!"); return RET_INPUT_TENSOR_ERROR; } if (out_tensors.size() != 1) { - LITE_ERROR_LOG("output tensors num not correct!") + LITE_LOG_ERROR("output tensors num not correct!"); return RET_ERROR; } @@ -121,7 +111,7 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector int actual_axes_num = num_axes; for (int i = 0; i < num_axes; ++i) { if (reduceParameter->axes_[i] < -rank || reduceParameter->axes_[i] >= rank) { - LITE_ERROR_LOG("reduce_sum got invalid axis!") + LITE_LOG_ERROR("reduce_sum got invalid axis!"); return RET_ERROR; } if (reduceParameter->axes_[i] < 0) { @@ -132,7 +122,7 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector } if (reduceParameter->reduce_to_end_) { if (num_axes != 1) { - LITE_ERROR_LOG("Reduce when reduce_to_end, num of axis should be 1!") + LITE_LOG_ERROR("Reduce when reduce_to_end, num of axis should be 1!"); return RET_ERROR; } int begin_axis = axes[0]; @@ -144,14 +134,14 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector if (num_axes == 0) { axes.resize(rank); - for (size_t i = 0; i < rank; i++) { + for (auto i = 0; i < rank; ++i) { axes[i] = i; if (keep_dims) { out_shape.push_back(1); } } reduceParameter->num_axes_ = axes.size(); - for (int i = 0; i < axes.size(); ++i) { + for (size_t i = 0; i < axes.size(); ++i) { reduceParameter->axes_[i] = axes[i]; } out_tensors[0]->shape_ = out_shape; @@ -160,9 +150,9 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector return RET_OK; } // reduce on selected axes - for (size_t i = 0; i < rank; i++) { + for (auto i = 0; i < rank; ++i) { bool reduce_axis = false; - for (size_t 
idx = 0; idx < num_axes; ++idx) {
+    for (auto idx = 0; idx < num_axes; ++idx) {
       if (axes[idx] == i) {
         reduce_axis = true;
         break;
       }
@@ -177,7 +167,7 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector
     }
   }
   reduceParameter->num_axes_ = axes.size();
-  for (int i = 0; i < axes.size(); ++i) {
+  for (size_t i = 0; i < axes.size(); ++i) {
     reduceParameter->axes_[i] = axes[i];
   }
   out_tensors[0]->shape_ = out_shape;
@@ -189,15 +179,15 @@ int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector
 int DoReduce(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node,
              mindspore::lite::Allocator *allocator) {
   if (in_tensors.size() != 1 || in_tensors[0]->data_ == NULL) {
-    LITE_ERROR_LOG("input tensors num not correct or input data is NULL!")
+    LITE_LOG_ERROR("input tensors num not correct or input data is NULL!");
     return RET_INPUT_TENSOR_ERROR;
   }
   if (out_tensors.size() != 1 || out_tensors[0]->data_ == NULL) {
-    LITE_ERROR_LOG("output tensors num not correct or output data is NULL!")
+    LITE_LOG_ERROR("output tensors num not correct or output data is NULL!");
     return RET_ERROR;
   }
   if (allocator == NULL) {
-    LITE_ERROR_LOG("allocator is NULL!")
+    LITE_LOG_ERROR("allocator is NULL!");
     return RET_ERROR;
   }
 
@@ -209,21 +199,18 @@ int DoReduce(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tenso
     reducer = ReduceMean;
   }
 
-  std::vector<float *> data_buffers;
-  int status = MallocTmpBuffer(&data_buffers, in_tensors[0]->shape_, params->axes_, params->num_axes_, allocator);
+  int buf_num = params->num_axes_ - 1;
+  float *data_buffers[buf_num];
+  int status = MallocTmpBuffer(data_buffers, in_tensors[0]->shape_, params->axes_, params->num_axes_, allocator);
   if (status != RET_OK) {
-    FreeTmpBuffer(&data_buffers, allocator);
+    FreeTmpBuffer(data_buffers, buf_num, allocator);
     return status;
   }
 
-  Int32Vector axes;
-  for (int i = 0; i < params->num_axes_; ++i) {
-    axes.push_back(params->axes_[i]);
-  }
   status = RunReduce(reducer, data_buffers, reinterpret_cast<float *>(in_tensors[0]->data_),
-                     reinterpret_cast<float *>(out_tensors[0]->data_), axes, in_tensors[0]->shape_);
+                     reinterpret_cast<float *>(out_tensors[0]->data_), params, in_tensors[0]->shape_);
 
-  FreeTmpBuffer(&data_buffers, allocator);
+  FreeTmpBuffer(data_buffers, buf_num, allocator);
 
   if (status != RET_OK) {
     return RET_ERROR;
diff --git a/mindspore/lite/internal/src/kernel/fp32/reduce.h b/mindspore/lite/internal/src/kernel/fp32/reduce.h
index 2372b9fb7e..cadd555d20 100644
--- a/mindspore/lite/internal/src/kernel/fp32/reduce.h
+++ b/mindspore/lite/internal/src/kernel/fp32/reduce.h
@@ -19,7 +19,7 @@
 #include "internal/include/model.h"
 #include "internal/include/ms_tensor.h"
 #include "internal/include/lite_utils.h"
-#include "src/runtime/allocator.h"
+#include "internal/src/allocator.h"
 
 int DoReduceInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param);
 
diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h
index 8e3eddc3b3..9b60374428 100644
--- a/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h
+++ b/mindspore/lite/internal/src/kernel/fp32_grad/activation_grad.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
 
 #include "internal/include/model.h"
-#include "src/runtime/allocator.h"
+#include "internal/src/allocator.h"
 
 int DoActivationGradInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors,
                                OpParameter
*param); diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc index 7e44dd33a0..07da37ab54 100644 --- a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.cc @@ -35,9 +35,9 @@ int DoArithmeticSelfGrad(const TensorPtrVector &in_tensors, const TensorPtrVecto float *x_data = reinterpret_cast(in_tensors[1]->data_); float *dx_data = reinterpret_cast(out_tensors[0]->data_); int ret; - if (param->type_ == KernelType::LogGrad) { + if (param->type_ == KernelType::KernelType_LogGrad) { ret = ElementDiv(dy_data, x_data, dx_data, data_size); - } else if (param->type_ == KernelType::NegGrad) { + } else if (param->type_ == KernelType::KernelType_NegGrad) { ret = ElementNegative(dy_data, dx_data, data_size); } else { LITE_ERROR_LOG("Unsupport kernel type: %d", param->type_); diff --git a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h index 8f1d06aae2..38070b62fb 100644 --- a/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h +++ b/mindspore/lite/internal/src/kernel/fp32_grad/arithmetic_self_grad.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_ #include "internal/include/model.h" -#include "src/runtime/allocator.h" +#include "internal/src/allocator.h" int DoArithmeticSelfGradInferShape(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param); diff --git a/mindspore/lite/internal/src/lite_log.h b/mindspore/lite/internal/src/lite_log.h index d8fcdae594..ae42a87fe5 100644 --- a/mindspore/lite/internal/src/lite_log.h +++ b/mindspore/lite/internal/src/lite_log.h @@ -18,15 +18,18 @@ #define MINDSPORE_LITE_INTERNAL_SRC_LITE_LOG_H_ #include -#ifdef DEBUG +#include +#ifndef Release #include #endif -#ifdef DEBUG +#ifndef Release #define LITE_DEBUG_LOG(format, ...) \ printf("[DEBUG] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_INFO_LOG(format, ...) \ printf("[INFO] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) +#define LITE_LOG_INFO(...) \ + printf("[INFO] [%s %s] [%s] [%d] %s\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_WARNING_LOG(format, ...) \ printf("[WARNING] [%s %s] [%s] [%d] " format "\n", __DATE__, __TIME__, __FILE__, __LINE__, __VA_ARGS__) #define LITE_ERROR_LOG(format, ...) \ @@ -40,6 +43,7 @@ #else #define LITE_DEBUG_LOG(...) #define LITE_INFO_LOG(...) +#define LITE_LOG_INFO(...) #define LITE_WARNING_LOG(...) #define LITE_ERROR_LOG(...) #define LITE_LOG_ERROR(...) 
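
A side note on why so many call sites above change from LITE_ERROR_LOG("...") to LITE_LOG_ERROR("..."): LITE_ERROR_LOG(format, ...) appends __VA_ARGS__ after the fixed printf arguments, so calling it with a bare string literal and no variadic arguments leaves a dangling comma in the expansion and fails to compile, while the LITE_LOG_ERROR/LITE_LOG_INFO variants take the message as their only argument. A short sketch of the intended split, assuming the macro definitions above and a debug (non-Release) build; Demo is illustrative only:

#include "internal/src/lite_log.h"

void Demo(int ret, size_t size) {
  // Formatted variants: a format string plus at least one argument.
  LITE_ERROR_LOG("run kernel fail! ret: %d", ret);
  LITE_ERROR_LOG("MallocData out of max_size, size: %zu", size);
  // Plain-string variants: exactly one argument, no format expansion needed.
  LITE_LOG_ERROR("allocator is NULL!");
  LITE_LOG_INFO("kernel not have opt version");
}
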
diff --git a/mindspore/lite/internal/src/lite_session.cc b/mindspore/lite/internal/src/lite_session.cc index 87c56dde0e..4a3196197e 100644 --- a/mindspore/lite/internal/src/lite_session.cc +++ b/mindspore/lite/internal/src/lite_session.cc @@ -16,26 +16,30 @@ #include "internal/include/lite_session.h" #include "internal/include/model.h" #include "internal/include/ms_tensor.h" -#include "src/runtime/allocator.h" +#include "internal/src/allocator.h" #include "internal/include/errorcode.h" #include "internal/src/lite_log.h" #include "internal/src/kernel/fp32/activation.h" #include "internal/src/kernel/fp32/arithmetic_self.h" #include "internal/src/kernel/fp32/matmul.h" +#include "internal/src/kernel/fp32/arithmetic.h" +#include "internal/src/kernel/fp32/bias_add.h" +#ifdef SUPPORT_TRAIN #include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h" #include "internal/src/kernel/fp32_grad/activation_grad.h" +#endif static Context *g_ctx; static Model *g_model; static LiteSession g_session; -static mindspore::lite::DefaultAllocator g_allocator; +static mindspore::lite::Allocator g_allocator; static bool g_infershape_interrupt = false; static bool g_first_load = true; typedef int (*InferShape)(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, OpParameter *param); typedef int (*RunKernel)(const TensorPtrVector &in_tensors, const TensorPtrVector &out_tensors, Node *node, mindspore::lite::Allocator *allocator); -static InferShape g_infershape_funcs[KernelType::END]; -static RunKernel g_runkernel_funcs[KernelType::END]; +static InferShape g_infershape_funcs[KernelType::KernelType_END]; +static RunKernel g_runkernel_funcs[KernelType::KernelType_END]; static int ModelInferShape() { NodePtrVector nodes = g_model->nodes_; @@ -43,7 +47,7 @@ static int ModelInferShape() { for (size_t i = 0; i < nodes_size; ++i) { auto node = nodes[i]; if (node->primitive_ == NULL) { - LITE_ERROR_LOG("node's primitive is NULL!"); + LITE_LOG_ERROR("node's primitive is NULL!"); return RET_ERROR; } TensorPtrVector in_tensors; @@ -75,22 +79,27 @@ static int ModelInferShape() { static void InitFuncs() { if (g_first_load) { - g_infershape_funcs[KernelType::MatMul] = DoMatMulInferShape; - g_infershape_funcs[KernelType::Activation] = DoActivationInferShape; - g_infershape_funcs[KernelType::Log] = DoArithmeticSelfInferShape; - g_infershape_funcs[KernelType::Neg] = DoArithmeticSelfInferShape; - - g_runkernel_funcs[KernelType::MatMul] = DoMatMul; - g_runkernel_funcs[KernelType::Activation] = DoActivation; - g_runkernel_funcs[KernelType::Log] = DoArithmeticSelf; - g_runkernel_funcs[KernelType::Neg] = DoArithmeticSelf; - + g_infershape_funcs[KernelType::KernelType_MatMul] = DoMatMulInferShape; + g_infershape_funcs[KernelType::KernelType_Activation] = DoActivationInferShape; + g_infershape_funcs[KernelType::KernelType_Log] = DoArithmeticSelfInferShape; + g_infershape_funcs[KernelType::KernelType_Neg] = DoArithmeticSelfInferShape; + g_infershape_funcs[KernelType::KernelType_Mul] = DoArithmeticInferShape; + g_infershape_funcs[KernelType::KernelType_BiasAdd] = DoBiasAddInferShape; + + g_runkernel_funcs[KernelType::KernelType_MatMul] = DoMatMul; + g_runkernel_funcs[KernelType::KernelType_Activation] = DoActivation; + g_runkernel_funcs[KernelType::KernelType_Log] = DoArithmeticSelf; + g_runkernel_funcs[KernelType::KernelType_Neg] = DoArithmeticSelf; + g_runkernel_funcs[KernelType::KernelType_Mul] = DoArithmetic; + g_runkernel_funcs[KernelType::KernelType_BiasAdd] = DoBiasAdd; #ifdef SUPPORT_TRAIN - 
g_infershape_funcs[KernelType::ActivationGrad] = DoActivationGradInferShape; + g_infershape_funcs[KernelType::KernelType_ActivationGrad] = DoActivationGradInferShape; + g_infershape_funcs[KernelType::KernelType_NegGrad] = DoArithmeticSelfGradInferShape; + g_infershape_funcs[KernelType::KernelType_LogGrad] = DoArithmeticSelfGradInferShape; - g_runkernel_funcs[KernelType::NegGrad] = DoArithmeticSelfGrad; - g_runkernel_funcs[KernelType::ActivationGrad] = DoActivationGrad; - g_runkernel_funcs[KernelType::LogGrad] = DoArithmeticSelfGrad; + g_runkernel_funcs[KernelType::KernelType_NegGrad] = DoArithmeticSelfGrad; + g_runkernel_funcs[KernelType::KernelType_ActivationGrad] = DoActivationGrad; + g_runkernel_funcs[KernelType::KernelType_LogGrad] = DoArithmeticSelfGrad; #endif g_first_load = false; } @@ -155,7 +164,7 @@ int LiteSession::RunGraph() { for (size_t i = 0; i < nodes_size; ++i) { auto node = nodes[i]; if (node->primitive_ == nullptr) { - LITE_ERROR_LOG("node's primitive is NULL!"); + LITE_LOG_ERROR("node's primitive is NULL!"); return RET_ERROR; } TensorPtrVector in_tensors; @@ -182,7 +191,7 @@ int LiteSession::RunGraph() { for (size_t j = 0; j < out_tensors.size(); ++j) { out_tensors[j]->data_ = g_allocator.Malloc(out_tensors[j]->Size()); if (out_tensors[j]->data_ == NULL) { - LITE_ERROR_LOG("Malloc data for out tensor fail!"); + LITE_LOG_ERROR("Malloc data for out tensor fail!"); return RET_NULL_PTR; } } @@ -194,7 +203,7 @@ int LiteSession::RunGraph() { int ret = (*run_kernel)(in_tensors, out_tensors, node, &g_allocator); if (ret != RET_OK) { - LITE_ERROR_LOG("run kernel fail!ret: ", ret); + LITE_ERROR_LOG("run kernel fail!ret: %d", ret); return ret; } } diff --git a/mindspore/lite/internal/src/ms_tensor.cc b/mindspore/lite/internal/src/ms_tensor.cc index 682f8a789e..2bcf31b460 100644 --- a/mindspore/lite/internal/src/ms_tensor.cc +++ b/mindspore/lite/internal/src/ms_tensor.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include "internal/include/string.h" #include "internal/include/vector.h" #include "internal/include/ms_tensor.h" @@ -85,7 +84,7 @@ size_t MSTensor::Size() const { size = sizeof(bool); break; default: - std::cout << "Not support the type: " << this->data_type_; + LITE_ERROR_LOG("Not support the type: %d", this->data_type_); return 0; } size *= (format_ == Format::Format_NC4HW4 || format_ == Format::Format_NHWC4) ? 
ElementsC4Num() : ElementsNum(); @@ -94,7 +93,7 @@ size_t MSTensor::Size() const { } int32_t MSTensor::Batch() const { if (this->shape_.size() != 4 && this->shape_.size() != 2) { - std::cout << "Unsupported tensor shape: " << this->shape_.size(); + LITE_ERROR_LOG("Unsupported tensor shape: %zu", this->shape_.size()); return -1; } switch (this->format_) { @@ -115,14 +114,14 @@ int32_t MSTensor::Batch() const { case Format::Format_CKHW: return this->shape_[1]; default: - // std::cout << "Unsupported format: " << EnumNameFormat(this->format_); + LITE_ERROR_LOG("Unsupported format: %d", this->format_); return -1; } } int32_t MSTensor::Channel() const { if (this->shape_.size() != 4 && this->shape_.size() != 2) { - std::cout << "Unsupported tensor shape: " << this->shape_.size(); + LITE_ERROR_LOG("Unsupported tensor shape: %zu", this->shape_.size()); return -1; } switch (this->format_) { @@ -149,7 +148,7 @@ int32_t MSTensor::Channel() const { int32_t MSTensor::Height() const { if (this->shape_.size() != 4 && this->shape_.size() != 2) { - std::cout << "Unsupported tensor shape: " << this->shape_.size(); + LITE_ERROR_LOG("Unsupported tensor shape: %zu", this->shape_.size()); return -1; } switch (this->format_) { @@ -169,14 +168,14 @@ int32_t MSTensor::Height() const { case Format::Format_HW4: return this->shape_[0]; default: - // std::cout << "Unsupported format: " << EnumNameFormat(this->format_); + LITE_ERROR_LOG("Unsupported format: %d", this->format_); return -1; } } int32_t MSTensor::Width() const { if (this->shape_.size() != 4 && this->shape_.size() != 2) { - std::cout << "Unsupported tensor shape: " << this->shape_.size(); + LITE_ERROR_LOG("Unsupported tensor shape: %zu", this->shape_.size()); return -1; } switch (this->format_) { diff --git a/mindspore/lite/nnacl/fp32/matmul.c b/mindspore/lite/nnacl/fp32/matmul.c index add3522821..b3484b4687 100644 --- a/mindspore/lite/nnacl/fp32/matmul.c +++ b/mindspore/lite/nnacl/fp32/matmul.c @@ -483,3 +483,38 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); #endif } + +#ifdef ENABLE_NNACL_INFER_SHAPE +static void SwapDims(int *dims, int index1, int index2) { + int tmp = dims[index1]; + dims[index1] = dims[index2]; + dims[index2] = tmp; +} + +int MatMulInferShape(int **in_shape, int in_num, size_t *dim_size, int *out_shape, int *in_format, + int *out_format, int *in_datatype, int *out_datatype, OpParameter *param) { + *out_datatype = in_datatype[0]; + *out_format = in_format[0]; + if (dim_size[0] < 2 || dim_size[1] < 2) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < dim_size[0] - 2; ++i) { + if (in_shape[0][i] != in_shape[1][i]) { + return NNACL_PARAM_INVALID; + } + } + MatMulParameter *matmul_param = (MatMulParameter *)param; + if (matmul_param->a_transpose_) { + SwapDims(in_shape[0], dim_size[0] - 1, dim_size[0] - 2); + } + if (matmul_param->b_transpose_) { + SwapDims(in_shape[1], dim_size[1] - 1, dim_size[1] - 2); + } + for (int i = 0; i < dim_size[0] - 1; ++i) { + out_shape[i] = in_shape[0][i]; + } + out_shape[dim_size[0] - 1] = in_shape[1][dim_size[1] - 1]; + return NNACL_OK; +} +#endif diff --git a/mindspore/lite/nnacl/fp32/matmul.h b/mindspore/lite/nnacl/fp32/matmul.h index cc002a20ae..be981a1a33 100644 --- a/mindspore/lite/nnacl/fp32/matmul.h +++ b/mindspore/lite/nnacl/fp32/matmul.h @@ -44,6 +44,11 @@ void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int de void MatmulFloatNeon32Opt(const float *a, const 
float *b, float *c, const float *bias, int act_type, int depth, int row, int col, size_t stride, size_t write_nhwc, size_t write_c4);
 #endif
+
+#ifdef ENABLE_NNACL_INFER_SHAPE
+int MatMulInferShape(int **in_shape, int in_num, size_t *dim_size, int *out_shape, int *in_format,
+                     int *out_format, int *in_datatype, int *out_datatype, OpParameter *param);
+#endif
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/test/ut/internal/CMakeLists.txt b/mindspore/lite/test/ut/internal/CMakeLists.txt
index 48ef01de61..d70413d56b 100644
--- a/mindspore/lite/test/ut/internal/CMakeLists.txt
+++ b/mindspore/lite/test/ut/internal/CMakeLists.txt
@@ -37,10 +37,11 @@ endif()
 
 ### runtime framework
 set(TEST_LITE_SRC
     ${LITE_DIR}/internal/src/lite_session.cc
-    ${LITE_DIR}/src/runtime/allocator.cc
+    ${LITE_DIR}/internal/src/allocator.cc
     ${LITE_DIR}/internal/src/ms_tensor.cc
     ${LITE_DIR}/internal/src/common/string.cc
+    ${LITE_DIR}/internal/src/common/vector.cc
     ${TOP_DIR}/mindspore/core/utils/log_adapter.cc
     ${TOP_DIR}/mindspore/core/gvar/logging_level.cc
     )
@@ -65,10 +66,3 @@ set(TEST_SRC
 
 add_executable(lite-test-internal ${TEST_SRC})
 target_link_libraries(lite-test-internal dl ${GTEST_LIBRARY})
-if (PLATFORM_ARM64)
-    target_link_libraries(lite-test-internal mslite_internal)
-endif()
-
-if (PLATFORM_ARM32 OR PLATFORM_ARM64)
-    target_link_libraries(lite-test-internal log)
-endif()
diff --git a/mindspore/lite/test/ut/internal/allocator_test.cc b/mindspore/lite/test/ut/internal/allocator_test.cc
new file mode 100644
index 0000000000..5f63808961
--- /dev/null
+++ b/mindspore/lite/test/ut/internal/allocator_test.cc
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/common_test.h"
+#include "internal/include/model.h"
+#include "internal/include/errorcode.h"
+#include "nnacl/op_base.h"
+#undef private
+#define private public
+#include "internal/src/allocator.h"
+#undef private
+
+namespace mindspore {
+class AllocatorTest : public mindspore::CommonTest {
+ public:
+  AllocatorTest() {}
+};
+
+TEST_F(AllocatorTest, AllocatorTest1) {
+  lite::DefaultAllocator allocator;
+  constexpr int data1_size = 10 * sizeof(float);
+  ASSERT_EQ(allocator.allocated_list_[0], nullptr);
+  float *data1 = reinterpret_cast<float *>(allocator.Malloc(data1_size));
+  ASSERT_NE(data1, nullptr);
+  ASSERT_NE(allocator.allocated_list_[0], nullptr);
+
+  ASSERT_EQ(allocator.free_list_[0], nullptr);
+  allocator.Free(data1);
+  ASSERT_EQ(allocator.allocated_list_[0], nullptr);
+  ASSERT_NE(allocator.free_list_[0], nullptr);
+}
+
+TEST_F(AllocatorTest, AllocatorTest2) {
+  lite::DefaultAllocator allocator;
+  constexpr int data1_size = 10 * sizeof(float);
+  ASSERT_EQ(allocator.allocated_list_[0], nullptr);
+  float *data1 = reinterpret_cast<float *>(allocator.Malloc(data1_size));
+  ASSERT_NE(data1, nullptr);
+  ASSERT_NE(allocator.allocated_list_[0], nullptr);
+
+  constexpr int data2_size = (1024 << lite::kBlockRange);
+  ASSERT_EQ(allocator.large_mem_list_, nullptr);
+  float *data2 = reinterpret_cast<float *>(allocator.Malloc(data2_size));
+  ASSERT_NE(data2, nullptr);
+  ASSERT_NE(allocator.large_mem_list_, nullptr);
+
+  constexpr int data3_size = (1024 << 3);
+  ASSERT_EQ(allocator.allocated_list_[3], nullptr);
+  float *data3 = reinterpret_cast<float *>(allocator.Malloc(data3_size));
+  ASSERT_NE(data3, nullptr);
+  ASSERT_NE(allocator.allocated_list_[3], nullptr);
+
+  int expect_total_size = data1_size + data2_size + data3_size;
+  size_t total_size = allocator.GetTotalSize();
+  ASSERT_EQ(total_size, expect_total_size);
+
+  allocator.Clear();
+  total_size = allocator.GetTotalSize();
+  ASSERT_EQ(total_size, 0);
+}
+
+TEST_F(AllocatorTest, AllocatorTest3) {
+  lite::DefaultAllocator allocator;
+  constexpr int data1_size = 10 * sizeof(float);
+  ASSERT_EQ(allocator.allocated_list_[0], nullptr);
+  float *data1 = reinterpret_cast<float *>(allocator.Malloc(data1_size));
+  ASSERT_NE(data1, nullptr);
+  ASSERT_NE(allocator.allocated_list_[0], nullptr);
+
+  constexpr int data2_size = 11 * sizeof(float);
+  float *data2 = reinterpret_cast<float *>(allocator.Malloc(data2_size));
+  ASSERT_NE(data2, nullptr);
+
+  constexpr int data3_size = 12 * sizeof(float);
+  float *data3 = reinterpret_cast<float *>(allocator.Malloc(data3_size));
+  ASSERT_NE(data3, nullptr);
+
+  int expect_total_size = data1_size + data2_size + data3_size;
+  size_t total_size = allocator.GetTotalSize();
+  ASSERT_EQ(total_size, expect_total_size);
+
+  allocator.Free(data2);
+  total_size = allocator.GetTotalSize();
+  ASSERT_EQ(total_size, expect_total_size);
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/test/ut/internal/infer_test.cc b/mindspore/lite/test/ut/internal/infer_test.cc
index d9e98c0d66..aa3805c851 100644
--- a/mindspore/lite/test/ut/internal/infer_test.cc
+++ b/mindspore/lite/test/ut/internal/infer_test.cc
@@ -42,13 +42,17 @@ TEST_F(InferTest, TestSession) {
   node.primitive_ = &prim;
   node.input_indices_.push_back(0);
   node.output_indices_.push_back(1);
-  ShapeVector shape = {1, 1, 1, 10};
+  ShapeVector shape(4);
+  shape[0] = 1;
+  shape[1] = 1;
+  shape[2] = 1;
+  shape[3] = 10;
   MSTensor *in = CreateTensor(kNumberTypeFloat32, shape);
   model.all_tensors_.push_back(in);
   model.input_indices_.push_back(0);
 
   MSTensor *out = CreateTensor(kNumberTypeFloat32, shape);
-  model.all_tensors_.emplace_back(out);
+  model.all_tensors_.push_back(out);
   model.output_indices_.push_back(1);
 
   LiteSession session;
diff --git a/mindspore/lite/test/ut/internal/vector_test.cc b/mindspore/lite/test/ut/internal/vector_test.cc
new file mode 100644
index 0000000000..1c8732bedf
--- /dev/null
+++ b/mindspore/lite/test/ut/internal/vector_test.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstddef>
+#include <utility>
+#include "common/common_test.h"
+#include "internal/include/vector.h"
+#include "nnacl/op_base.h"
+
+namespace mindspore {
+class VectorTest : public mindspore::CommonTest {
+ public:
+  VectorTest() {}
+};
+
+void CheckArrValue(Vector<int> arr) {
+  for (size_t i = 0; i < arr.size(); ++i) {
+    ASSERT_EQ(arr[i], i);
+  }
+}
+
+TEST_F(VectorTest, VectorTest1) {
+  constexpr int kLen1 = 10;
+  Vector<int> arr1(kLen1);
+  for (int i = 0; i < kLen1; ++i) {
+    arr1[i] = i;
+  }
+  Vector<int> arr2 = arr1;
+  ASSERT_EQ(arr2.size(), kLen1);
+  for (int i = 0; i < kLen1; ++i) {
+    ASSERT_EQ(arr2[i], i);
+  }
+
+  Vector<int> arr3;
+  for (int i = 0; i < kLen1; ++i) {
+    arr3.push_back(std::move(arr1[i]));
+  }
+  CheckArrValue(arr3);
+}
+
+}  // namespace mindspore
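
Reviewer sketch (not part of the patch): a minimal, hypothetical driver for the new
MatMulInferShape entry point, to make the calling convention concrete. It assumes
ENABLE_NNACL_INFER_SHAPE is defined (the internal CMakeLists above adds it), that
nnacl/fp32/matmul.h pulls in MatMulParameter, and that MatMulParameter embeds
OpParameter as its first member, as is the pattern elsewhere in nnacl; the
format/datatype values are placeholders. Note the side effect: when a transpose
flag is set, the function swaps the last two dims of that input shape in place.

#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include "nnacl/errorcode.h"
#include "nnacl/fp32/matmul.h"

int main(void) {
  /* A: 4x8, B: 8x16 -> expected output 4x16 (no transpose flags set). */
  int a_dims[2] = {4, 8};
  int b_dims[2] = {8, 16};
  int *in_shape[2] = {a_dims, b_dims};
  size_t dim_size[2] = {2, 2};
  int out_shape[2] = {0, 0};
  int in_format[2] = {0, 0};    /* placeholder format codes */
  int in_datatype[2] = {0, 0};  /* placeholder datatype codes */
  int out_format = 0;
  int out_datatype = 0;

  MatMulParameter param;
  memset(&param, 0, sizeof(param));
  param.a_transpose_ = false;
  param.b_transpose_ = false;

  /* Cast is valid only if OpParameter is MatMulParameter's first member (assumption). */
  int ret = MatMulInferShape(in_shape, 2, dim_size, out_shape, in_format, &out_format,
                             in_datatype, &out_datatype, (OpParameter *)&param);
  if (ret != NNACL_OK) {
    printf("infer shape failed: %d\n", ret);
    return 1;
  }
  printf("out shape: %d x %d\n", out_shape[0], out_shape[1]);  /* prints 4 x 16 */
  return 0;
}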