@@ -940,6 +940,14 @@ int ElementLogicalAndInt(const int *input0, const int *input1, int *output, cons
  return NNACL_OK;
}
int ElementLogicalAndBool(const bool *input0, const bool *input1, bool *output, const int element_size) {
  int index = 0;
  for (; index < element_size; index++) {
    output[index] = (bool)((bool)(input0[index]) & (bool)(input1[index]));
  }
  return NNACL_OK;
}
int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size) {
  ElementSub(input0, input1, output, element_size);
  return ElementMul(output, output, output, element_size);
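For orientation, a quick sketch of the new kernel's contract (the buffers here are invented for illustration):

    // Element-wise logical AND over bool buffers; returns NNACL_OK on success.
    bool a[4] = {true, true, false, false};
    bool b[4] = {true, false, true, false};
    bool out[4] = {false, false, false, false};
    int ret = ElementLogicalAndBool(a, b, out, 4);  // out becomes {true, false, false, false}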
@@ -93,6 +93,7 @@ int BroadcastDiv(const float *input0, const float *input1, float *tile_input0, f
int ElementLogicalAnd(const float *input0, const float *input1, float *output, const int element_size);
int ElementLogicalAndInt(const int *input0, const int *input1, int *output, const int element_size);
int ElementLogicalAndBool(const bool *input0, const bool *input1, bool *output, const int element_size);
int BroadcastLogicalAnd(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
                        int element_size, ArithmeticParameter *param);
@@ -43,9 +43,9 @@ void LiteKernel::FreeWorkspace() {
}
#endif
bool LiteKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
  return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
    if (IsContain(scope_tensors, kernel_in_tensor)) {
      return (kernel_in_tensor->IsConst() || kernel_in_tensor->IsGraphInput() || kernel_in_tensor->ref_count() >= 1);
  return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) {
    if (IsContain(scope_tensors, in_tensor)) {
      return in_tensor->IsReady();
    } else {
      return true;
    }
@@ -66,13 +66,9 @@ void LiteKernel::InitOutTensorInitRefCount() {
int LiteKernel::DecOutTensorRefCount() {
  for (auto *tensor : this->out_tensors_) {
    tensor->DecRefCount();
    tensor->set_ref_count(tensor->ref_count() - 1);
    if (0 >= tensor->ref_count()) {
      auto ret = tensor->FreeData();
      if (0 != ret) {
        MS_LOG(ERROR) << "Free tensor data failed";
        return ret;
      }
      tensor->FreeData();
    }
  }
  return 0;
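The dropped block is folded into Tensor::DecRefCount(), defined later in this patch. Its observable behavior, as a sketch using only APIs visible in this diff:

    lite::Tensor t(kNumberTypeFloat32, {2, 2});
    t.MallocData();
    t.set_ref_count(2);
    t.DecRefCount();  // ref_count 2 -> 1, buffer kept
    t.DecRefCount();  // ref_count 1 -> 0, buffer freed (const tensors and graph inputs are exempt)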
@@ -81,18 +77,10 @@ int LiteKernel::DecOutTensorRefCount() {
int LiteKernel::FreeInWorkTensor() const {
  for (auto &in_tensor : this->in_tensors_) {
    MS_ASSERT(in_tensor != nullptr);
    if (in_tensor->IsConst() || in_tensor->IsGraphInput()) {
    if (in_tensor->root_tensor() == in_tensor) {
      continue;
    }
    MS_ASSERT(in_tensor->ref_count() > 0);
    in_tensor->set_ref_count(in_tensor->ref_count() - 1);
    if (in_tensor->ref_count() <= 0) {
      auto ret = in_tensor->FreeData();
      if (0 != ret) {
        MS_LOG(ERROR) << "Free tensor data failed";
        return ret;
      }
    }
    in_tensor->DecRefCount();
  }
  return RET_OK;
}
@@ -157,7 +157,7 @@ class LiteKernel {
  const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
  virtual bool IsReady(const std::vector<lite::Tensor *> &scope_tensors);
  virtual bool IsReady(const std::vector<lite::Tensor *> &in_tensor);
  virtual void InitOutTensorInitRefCount();
@@ -143,7 +143,7 @@ lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) {
      }
    }
  }
  lite::Tensor *dst_tensor = nullptr;
  lite::Tensor *dst_tensor;
  if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) {
    dst_tensor = new (std::nothrow) TensorList(shape, std::vector<int>(), src_category);
  } else {
@@ -18,6 +18,7 @@
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
@@ -72,8 +73,29 @@ int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
    return RET_INFER_INVALID;
  }
  for (size_t i = 0; i < inputs_.size() / 2; i++) {
    outputs_[i]->set_data_type(inputs_[i]->data_type());
    outputs_[i]->set_shape(inputs_[i]->shape());
    auto *input = inputs_[i];
    auto *output = outputs_[i];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output->set_data_type(input->data_type());
    output->set_shape(input->shape());
    output->set_format(input->format());
    auto data_type = input->data_type();
    if (data_type != kObjectTypeTensorType) {
      continue;
    } else {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_tensorlist = reinterpret_cast<TensorList *>(output);
      output_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
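Both this hunk and the Switch one below repeat the same guarded downcast. As a hypothetical helper (not part of the patch, shown only to make the invariant explicit):

    // A Tensor may be treated as a TensorList only when data_type() says so.
    TensorList *AsTensorList(Tensor *t) {
      return (t != nullptr && t->data_type() == kObjectTypeTensorType) ? reinterpret_cast<TensorList *>(t) : nullptr;
    }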
@@ -48,7 +48,6 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive
  memset(split_param->split_sizes_, 0, split_param->num_split_ * sizeof(int));
  auto split_sizes_vector_ = param->size_splits();
  MS_ASSERT(split_sizes_vector_.size() == split_param->num_split_);
  for (size_t i = 0; i < split_sizes_vector_.size(); i++) {
    split_param->split_sizes_[i] = split_sizes_vector_[i];
  }
@@ -19,6 +19,7 @@
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
@@ -76,12 +77,37 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
    return RET_INFER_INVALID;
  }
  for (size_t i = 0; i < outputs_.size() / 2; i++) {
    outputs_[i]->set_data_type(inputs_[i + 1]->data_type());
    outputs_[i + outputs_.size() / 2]->set_data_type(inputs_[i + 1]->data_type());
    outputs_[i]->set_shape(inputs_[i + 1]->shape());
    outputs_[i + outputs_.size() / 2]->set_shape(inputs_[i + 1]->shape());
    outputs_[i]->set_format(inputs_[i + 1]->format());
    outputs_[i + outputs_.size() / 2]->set_format(inputs_[i + 1]->format());
    auto *input = inputs_[i + 1];
    auto *output_true = outputs_[i];
    auto *output_false = outputs_[i + outputs_.size() / 2];
    if (input == nullptr) {
      MS_LOG(ERROR) << "input tensor is nullptr";
      return RET_ERROR;
    }
    if (output_true == nullptr || output_false == nullptr) {
      MS_LOG(ERROR) << "output tensor is nullptr";
      return RET_ERROR;
    }
    output_true->set_data_type(input->data_type());
    output_false->set_data_type(input->data_type());
    output_true->set_shape(input->shape());
    output_false->set_shape(input->shape());
    output_true->set_format(input->format());
    output_false->set_format(input->format());
    auto data_type = input->data_type();
    if (data_type != kObjectTypeTensorType) {
      continue;
    } else {
      auto input_tensorlist = reinterpret_cast<TensorList *>(input);
      auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true);
      auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false);
      output_true_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_false_tensorlist->set_element_shape(input_tensorlist->element_shape());
      output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
      output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
      output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
    }
  }
  return RET_OK;
}
@@ -136,7 +136,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
    MS_LOG(ERROR) << "index_:" << index_ << " must be in [0, " << input0->ElementsNum() - 1 << "]";
    return RET_ERROR;
  }
  auto tensor_index = input0->GetTensorIndex(index_);
  auto tensor_index = input0->GetTensor(index_);
  MS_ASSERT(tensor_index != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
@@ -159,7 +159,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
  }
  if (!IsFullyDefined(element_shape_)) {
    for (int i = 0; i < input0->ElementsNum(); ++i) {
      auto input = input0->GetTensorIndex(i);
      auto input = input0->GetTensor(i);
      MS_ASSERT(input != nullptr);
      if (input->data_type() != kTypeUnknown) {
        status = MergeShape(input->shape());
@@ -140,7 +140,7 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
  } else {
    output0->set_shape(input0->shape());
    for (int i = 0; i < input0->ElementsNum(); ++i) {
      auto src_ptr = input0->GetTensorIndex(i);
      auto src_ptr = input0->GetTensor(i);
      if (src_ptr == nullptr) {
        MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!";
        return RET_ERROR;
@@ -133,6 +133,7 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
    return RET_NULL_PTR;
  }
  auto ele_shape_ptr = reinterpret_cast<int *>(ele_shape->data_c());
  output_shape_.clear();
  for (int i = 0; i < ele_shape->ElementsNum(); ++i) {
    output_shape_.push_back(ele_shape_ptr[i]);
  }
@@ -148,7 +149,7 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
  }
  if (!IsFullyDefined(input0->element_shape())) {
    for (int i = 0; i < input0->ElementsNum(); ++i) {
      auto tensor_ele = input0->GetTensorIndex(i);
      auto tensor_ele = input0->GetTensor(i);
      MS_ASSERT(tensor_ele != nullptr);
      if (tensor_ele->data_type() != kTypeUnknown) {
        status = MergeShape(tensor_ele->shape());
@@ -62,11 +62,7 @@ int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<
      memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size());
      in_tensors[index]->set_ref_count(in_tensors[index]->ref_count() - 1);
      if (in_tensors[index]->ref_count() <= 0) {
        auto ret = in_tensors[index]->FreeData();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Free tensor data failed";
          return RET_ERROR;
        }
        in_tensors[index]->FreeData();
      }
      break;
    }
@@ -0,0 +1,108 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/base/carry_data.h"
#include "include/errorcode.h"
#include "src/tensorlist.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
                              std::vector<lite::Tensor *>::iterator dst_end,
                              std::vector<lite::Tensor *>::iterator src_begin,
                              std::vector<lite::Tensor *>::iterator src_limit) {
  for (auto dst_iter = dst_begin, src_iter = src_begin; dst_iter != dst_end; dst_iter++, src_iter++) {
    if (src_iter == src_limit) {
      MS_LOG(ERROR) << "out of range of input tensor";
      return RET_ERROR;
    }
    auto *dst_tensor = *dst_iter;
    auto *src_tensor = *src_iter;
    if (dst_tensor == nullptr || src_tensor == nullptr) {
      MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
      return RET_ERROR;
    }
    lite::STATUS ret;
    if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
      ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor),
                               reinterpret_cast<lite::TensorList *>(src_tensor));
    } else {
      ret = MoveTensorData(dst_tensor, src_tensor);
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Move data failed : " << ret;
      return ret;
    }
  }
  return RET_OK;
}
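For orientation, the kernels below drive MoveData like this (mirroring MergeCPUKernel::Run further down; nothing here is new API):

    // Inside a CarryDataKernel subclass: forward every input tensor to the matching output.
    auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(),
                        this->in_tensors_.begin(), this->in_tensors_.end());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "carry data error : " << ret;
    }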
int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) {
  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
      !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) {
    MS_LOG(ERROR) << "input tensor and output tensor are incompatible";
    return RET_ERROR;
  }
  if (src_tensor->root_tensor() == nullptr) {
    if (src_tensor->IsConst() || src_tensor->IsGraphInput() || src_tensor->ref_count() > 1) {
      auto dst_data = dst_tensor->MutableData();
      if (dst_data == nullptr) {
        MS_LOG(ERROR) << "data of dst tensor is nullptr";
        return RET_ERROR;
      }
      auto src_data = src_tensor->data_c();
      MS_ASSERT(src_data != nullptr);
      memcpy(dst_data, src_data, dst_tensor->Size());
    } else {
      dst_tensor->FreeData();
      dst_tensor->set_data(src_tensor->data_c());
      src_tensor->set_data(nullptr);
    }
  } else {
    auto ret = dst_tensor->set_root_tensor(src_tensor->root_tensor());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set root tensor for tensor(" << dst_tensor->tensor_name() << ") failed";
      return ret;
    }
  }
  return RET_OK;
}
int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
  // shape may change, because tensors_.size() can change in RunGraph
  if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() ||
      !(dst_tensor->element_shape() == src_tensor->element_shape() ||
        (dst_tensor->element_shape().empty() && src_tensor->element_shape().empty())) ||
      dst_tensor->tensors_data_type() != src_tensor->tensors_data_type()) {
    MS_LOG(ERROR) << "input tensorlist and output tensorlist are incompatible";
    return RET_ERROR;
  }
  if (src_tensor->root_tensor() == nullptr) {
    dst_tensor->CopyTensorList(*src_tensor, false);
    src_tensor->set_tensors({});
  } else {
    dst_tensor->set_shape(src_tensor->shape());
    auto ret = dst_tensor->set_root_tensor(src_tensor->root_tensor());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Set root tensor for tensor(" << dst_tensor->tensor_name() << ") failed";
      return ret;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::kernel
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
namespace mindspore::kernel {
class CarryDataKernel : public LiteKernel {
 public:
  CarryDataKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                  const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~CarryDataKernel() override = default;
 protected:
  int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
               std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
  static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
  static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CARRY_DATA_H_
@@ -18,6 +18,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/tensorlist.h"
#include "src/common/utils.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -28,104 +29,97 @@ namespace mindspore::kernel {
int MergeCPUKernel::FreeInWorkTensor() const {
  for (auto &in_tensor : this->in_tensors_) {
    MS_ASSERT(in_tensor != nullptr);
    if (in_tensor->IsConst() || in_tensor->IsGraphInput()) {
    if (in_tensor->root_tensor() == in_tensor) {
      continue;
    }
    if (in_tensor->ref_count() > 0) {
      in_tensor->set_ref_count(in_tensor->ref_count() - 1);
      if (in_tensor->ref_count() <= 0) {
        auto ret = in_tensor->FreeData();
        if (0 != ret) {
          MS_LOG(ERROR) << "Free tensor data failed";
          return ret;
        }
      }
    }
    in_tensor->DecRefCount();
  }
  return RET_OK;
}
// If one of merge's inputs is a const tensor, merge would always look ready; that would cause errors.
bool MergeCPUKernel::IsReady(const std::vector<lite::Tensor *> &scope_tensors) {
  MS_ASSERT(in_tensors().size() == 2 * out_tensors().size());
  return std::all_of(this->in_tensors().begin(), this->in_tensors().begin() + in_tensors().size() / 2,
                     [&](lite::Tensor *kernel_in_tensor) {
                       return kernel_in_tensor->IsConst() || kernel_in_tensor->IsGraphInput() ||
                              kernel_in_tensor->ref_count() >= 1;
                     }) ||
         std::all_of(this->in_tensors().begin() + in_tensors().size() / 2, this->in_tensors().end(),
                     [&](lite::Tensor *kernel_in_tensor) {
                       return kernel_in_tensor->IsConst() || kernel_in_tensor->IsGraphInput() ||
                              kernel_in_tensor->ref_count() >= 1 ||
                              (kernel_in_tensor->data_type() == kObjectTypeTensorType);
                     });
  auto ready_part = FindReadyPart(scope_tensors);
  return ready_part == LEFT_INPUT_PART || ready_part == RIGHT_INPUT_PART;
}
int MergeCPUKernel::Init() { return RET_OK; }
int MergeCPUKernel::Init() {
  MS_ASSERT(in_tensors_.size() == 2 * out_tensors_.size());
  size_t stride = in_tensors_.size() / 2;
  for (size_t i = 0; i < in_tensors_.size() / 2; i++) {
    MS_ASSERT(in_tensors_[i] != nullptr);
    MS_ASSERT(in_tensors_[i + stride] != nullptr);
    if (in_tensors_[i] == in_tensors_[i + stride]) {
      auto ret = in_tensors_[i]->set_root_tensor(in_tensors_[i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set root tensor for tensor(" << in_tensors_[i]->tensor_name() << ") failed";
        return ret;
      }
      ret = in_tensors_[i + stride]->set_root_tensor(in_tensors_[i + stride]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set root tensor for tensor(" << in_tensors_[i + stride]->tensor_name() << ") failed";
        return ret;
      }
    }
  }
  return RET_OK;
}
int MergeCPUKernel::ReSize() { return RET_OK; }
bool MergeCPUKernel::PartialInputReady(int num_begin, int num_end) {
InputPart MergeCPUKernel::FindReadyPart(const std::vector<lite::Tensor *> &scope_tensors) {
  MS_ASSERT(in_tensors_.size() == 2 * out_tensors_.size());
  bool result = (std::all_of(this->in_tensors().begin() + num_begin, this->in_tensors().begin() + num_end,
                             [&](lite::Tensor *kernel_in_tensor) {
                               return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1 ||
                                      kernel_in_tensor->IsGraphInput() ||
                                      kernel_in_tensor->data_type() == kObjectTypeTensorType;
                             })) &&
                std::all_of(this->in_tensors_.begin() + num_begin, this->in_tensors_.begin() + num_end,
                            [&](lite::Tensor *in_tensor) {
                              if (in_tensor->data_type() != kObjectTypeTensorType) {
                                return in_tensor->data_c() != nullptr;
                              } else {
                                return true;
                              }
                            });
  return result;
  bool is_root_tensor_ready =
      std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *in_tensor) {
        // a tensor outside scope_tensors is of no concern
        if (!IsContain(scope_tensors, in_tensor)) {
          return true;
        }
        // a tensor that is not its own root tensor is of no concern
        if (in_tensor->root_tensor() == nullptr || in_tensor->root_tensor() != in_tensor) {
          return true;
        }
        return in_tensor->IsReady();
      });
  // check that all root tensors are ready
  if (!is_root_tensor_ready) {
    return UNKNOWN_INPUT_PART;
  }
  // check whether one half of merge's input tensors is ready:
  // a tensor outside scope_tensors is ignored;
  // a tensor inside scope_tensors must be ready
  if (std::all_of(
          this->in_tensors().begin() + in_tensors().size() / 2, this->in_tensors().end(),
          [&](lite::Tensor *in_tensor) { return !IsContain(scope_tensors, in_tensor) || in_tensor->IsReady(); })) {
    return RIGHT_INPUT_PART;
  }
  if (std::all_of(
          this->in_tensors().begin(), this->in_tensors().begin() + in_tensors().size() / 2,
          [&](lite::Tensor *in_tensor) { return !IsContain(scope_tensors, in_tensor) || in_tensor->IsReady(); })) {
    return LEFT_INPUT_PART;
  }
  return UNKNOWN_INPUT_PART;
}
int MergeCPUKernel::Run() {
  MS_ASSERT(in_tensors_.size() == 2 * out_tensors_.size());
  int in_tesnor_part_one = 0;
  int in_tensor_part_two = in_tensors_.size() / 2;
  int in_tensor_part_three = in_tensors_.size();
  if (PartialInputReady(in_tesnor_part_one, in_tensor_part_two)) {
    for (size_t i = 0; i < out_tensors().size(); i++) {
      auto out_data = out_tensors_[i]->data_c();
      auto in_data = in_tensors_[i]->data_c();
      if (in_tensors_[i]->data_type() == kObjectTypeTensorType) {
        auto in_tensor_list = reinterpret_cast<lite::TensorList *>(in_tensors_[i]);
        auto out_tensor_list = reinterpret_cast<lite::TensorList *>(out_tensors_[i]);
        if (std::any_of(in_tensor_list->tensors().begin(), in_tensor_list->tensors().end(),
                        [&](lite::Tensor *tensor) { return tensor->data_c() == nullptr; })) {
          continue;
        }
        *out_tensor_list = *in_tensor_list;
        continue;
      }
      MS_ASSERT(in_data != nullptr);
      MS_ASSERT(out_data != nullptr);
      memcpy(out_data, in_data, in_tensors_[i]->Size());
  auto ready_part = FindReadyPart(this->in_tensors_);
  if (ready_part == LEFT_INPUT_PART) {
    auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), this->in_tensors_.begin(),
                        this->in_tensors_.end());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "carry data error : " << ret;
      return ret;
    }
  }
  if (PartialInputReady(in_tensor_part_two, in_tensor_part_three)) {
    for (size_t i = 0; i < out_tensors().size(); i++) {
      auto out_data = out_tensors_[i]->data_c();
      auto in_data = in_tensors_[i + in_tensor_part_two]->data_c();
      if (in_tensors_[i]->data_type() == kObjectTypeTensorType) {
        auto in_tensor_list = reinterpret_cast<lite::TensorList *>(in_tensors_[i + in_tensor_part_two]);
        auto out_tensor_list = reinterpret_cast<lite::TensorList *>(out_tensors_[i]);
        if (std::any_of(in_tensor_list->tensors().begin(), in_tensor_list->tensors().end(),
                        [&](lite::Tensor *tensor) { return tensor->data_c() == nullptr; })) {
          continue;
        }
        *out_tensor_list = *in_tensor_list;
        continue;
      }
      MS_ASSERT(in_data != nullptr);
      MS_ASSERT(out_data != nullptr);
      memcpy(out_data, in_data, in_tensors_[i]->Size());
  } else if (ready_part == RIGHT_INPUT_PART) {
    auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(),
                        (this->in_tensors_.begin() + in_tensors_.size() / 2), this->in_tensors_.end());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "carry data error : " << ret;
      return ret;
    }
  } else {
    MS_LOG(ERROR) << "no input part of merge is ready";
    return RET_ERROR;
  }
  return RET_OK;
}
@@ -17,21 +17,28 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/carry_data.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
namespace mindspore::kernel {
class MergeCPUKernel : public LiteKernel {
enum InputPart { UNKNOWN_INPUT_PART, LEFT_INPUT_PART, RIGHT_INPUT_PART };
class MergeCPUKernel : public CarryDataKernel {
 public:
  MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                 const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~MergeCPUKernel() override {}
      : CarryDataKernel(parameter, inputs, outputs, ctx, primitive) {}
  bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override;
  ~MergeCPUKernel() override = default;
  int FreeInWorkTensor() const override;
  int Init() override;
  int ReSize() override;
  int Run() override;
  int FreeInWorkTensor() const override;
  bool IsReady(const std::vector<lite::Tensor *> &scope_tensors) override;
 private:
  InputPart FindReadyPart(const std::vector<lite::Tensor *> &scope_tensors);
 private:
  bool PartialInputReady(int num_begin, int num_end);
@@ -43,6 +43,16 @@ int SwitchCPUKernel::PostProcess() {
    auto out_tensor = out_tensors_.at(out_index++);
    out_tensor->ResetRefCount();
  }
  if (!*active) {
    for (auto &in_tensor : this->in_tensors_) {
      MS_ASSERT(in_tensor != nullptr);
      auto root_tensor = in_tensor->root_tensor();
      if (root_tensor == nullptr) {
        continue;
      }
      root_tensor->DecRefCount();
    }
  }
  return FreeInWorkTensor();
}
@@ -64,29 +74,20 @@ int SwitchCPUKernel::Run() {
    MS_LOG(ERROR) << "data of bool tensor is nullptr";
    return lite::RET_NULL_PTR;
  }
  size_t in_index = 1;
  size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
  while (in_index < in_tensors_.size()) {
    auto in_tensor = in_tensors_.at(in_index++);
    auto out_tensor = out_tensors_.at(out_index++);
    // copy for tensorlist
    if (in_tensor->data_type() == kObjectTypeTensorType) {
      auto in_tensor_list = reinterpret_cast<lite::TensorList *>(in_tensor);
      auto out_tensor_list = reinterpret_cast<lite::TensorList *>(out_tensor);
      *out_tensor_list = *in_tensor_list;
      continue;
  if (*active) {
    auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.begin() + out_tensors_.size() / 2,
                        this->in_tensors_.begin() + 1, this->in_tensors_.end());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "carry data error : " << ret;
      return ret;
    }
    // copy for tensor
    MS_ASSERT(in_tensor != nullptr);
    MS_ASSERT(out_tensor != nullptr);
    auto input = in_tensor->data_c();
    auto output = out_tensor->data_c();
    MS_ASSERT(in_tensor->Size() == out_tensor->Size());
    if (input == nullptr || output == nullptr) {
      MS_LOG(ERROR) << "input tensor or output tensor have not been malloced";
      return lite::RET_NULL_PTR;
  } else {
    auto ret = MoveData(this->out_tensors_.begin() + out_tensors_.size() / 2, this->out_tensors_.end(),
                        this->in_tensors_.begin() + 1, this->in_tensors_.end());
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "carry data error : " << ret;
      return ret;
    }
    memcpy(output, input, in_tensor->Size());
  }
  return RET_OK;
}
@@ -17,30 +17,22 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
#include <vector>
#include "src/runtime/kernel/arm/base/carry_data.h"
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
namespace mindspore::kernel {
typedef struct SwitchParameter {
  OpParameter op_parameter_;
} SwitchParameter;
class SwitchCPUKernel : public LiteKernel {
class SwitchCPUKernel : public CarryDataKernel {
 public:
  SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                  const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_);
  }
      : CarryDataKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~SwitchCPUKernel() override = default;
  int PostProcess() override;
  int Init() override;
  int ReSize() override;
  int Run() override;
 private:
  SwitchParameter *switch_param_ = nullptr;
};
}  // namespace mindspore::kernel
@@ -157,6 +157,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
    case PrimitiveType_LogicalAnd:
      arithmetic_run_ = ElementLogicalAnd;
      arithmetic_run_int_ = ElementLogicalAndInt;
      arithmetic_run_bool_ = ElementLogicalAndBool;
      break;
    case PrimitiveType_LogicalOr:
      arithmetic_run_ = ElementLogicalOr;
@@ -295,6 +296,8 @@ void ArithmeticCPUKernel::InitParam() {
  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
    data_type_ = kDataTypeFloat;
  } else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
    data_type_ = KDataTypeBool;
  } else {
    data_type_ = kDataTypeInt;
  }
@@ -419,6 +422,10 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
    error_code = arithmetic_run_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
                                 reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
                                 reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
  } else if (data_type_ == KDataTypeBool) {
    error_code = arithmetic_run_bool_(reinterpret_cast<bool *>(input0_ptr_) + stride * task_id,
                                      reinterpret_cast<bool *>(input1_ptr_) + stride * task_id,
                                      reinterpret_cast<bool *>(out_tensors_[0]->data_c()) + stride * task_id, count);
  } else {
    error_code = arithmetic_run_int_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
                                     reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
@@ -50,6 +50,7 @@ class ArithmeticCPUKernel : public LiteKernel {
  typedef int (*ArithmeticIntRun)(const int *input0, const int *input1, int *output, const int element_size);
  typedef int (*ArithmeticOptIntRun)(const int *input0, const int *input1, int *output, const int element_size,
                                     const ArithmeticParameter *param);
  typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
 public:
  ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -91,6 +92,7 @@ class ArithmeticCPUKernel : public LiteKernel {
  ArithmeticOptRun arithmetic_opt_run_ = nullptr;
  ArithmeticIntRun arithmetic_run_int_ = nullptr;
  ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
  ArithmeticBoolRun arithmetic_run_bool_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
@@ -89,7 +89,7 @@ int TensorListFromTensorCPUKernel::Run() {
  auto in_ptr = reinterpret_cast<float *>(input0_->data_c());
  // copy data from input0 (tensor) to output (tensorlist) vector<*tensor>
  for (int i = 0; i < dim0; ++i) {
    auto out_ptr = output0->GetTensorIndex(i);
    auto out_ptr = output0->GetTensor(i);
    MS_ASSERT(out_ptr != nullptr);
    if (out_ptr->ElementsNum() != devision_dim0) {
      MS_LOG(ERROR) << "tensors_[" << i << "].ElementsNum():" << out_ptr->ElementsNum()
@@ -49,7 +49,7 @@ int TensorListGetItemCPUKernel::Init() {
int TensorListGetItemCPUKernel::Run() {
  auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
  auto src_ptr = input0->GetTensorIndex(index_);
  auto src_ptr = input0->GetTensor(index_);
  MS_ASSERT(src_ptr != nullptr);
  if (src_ptr->data_type() != kTypeUnknown) {
    if (src_ptr->ElementsNum() != out_tensors_[0]->ElementsNum()) {
@@ -57,7 +57,7 @@ int TensorListGetItemCPUKernel::Run() {
                    << " must be equal to out_tensors_[0]->ElementsNum():" << out_tensors_[0]->ElementsNum();
      return RET_ERROR;
    }
    auto status = out_tensors_[0]->CopyTensorData(*src_ptr);
    auto status = lite::Tensor::CopyTensorData(*src_ptr, out_tensors_[0]);
    if (status == RET_ERROR) {
      MS_LOG(ERROR) << "copy tensor data failed!";
      return RET_ERROR;
@@ -59,23 +59,41 @@ int TensorListSetItemCPUKernel::Run() {
  MS_ASSERT(output0_ != nullptr);
  // copy each tensor in tensors_
  for (int i = 0; i < output0_->ElementsNum(); ++i) {
    auto dst = output0_->GetTensorIndex(i);
    MS_ASSERT(dst != nullptr);
    auto src = input0_->GetTensorIndex(i);
    if (i == index_) {
      // copy input2_ data buff
      src = input2_;
    }
    MS_ASSERT(src != nullptr);
    if (src->data_type() != kTypeUnknown) {
      if (src->Size() != dst->Size()) {
        MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
        return RET_ERROR;
      auto dst = output0_->GetTensor(i);
      if (dst == nullptr) {
        dst = lite::Tensor::CopyTensor(*input2_, true);
        auto &tensors = output0_->tensors();
        tensors.emplace_back(dst);
      } else {
        dst->set_data_type(input2_->data_type());
        dst->set_shape(input2_->shape());
        dst->set_format(input2_->format());
        dst->set_category(input2_->category());
        dst->set_root_tensor(input2_->root_tensor());
        dst->set_tensor_name(input2_->tensor_name());
        dst->set_quant_clusters(input2_->quant_clusters());
        auto ret = lite::Tensor::CopyTensorData(*input2_, dst);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "CopyTensorData[" << i << "] failed!";
          return RET_ERROR;
        }
      }
      auto ret = dst->CopyTensorData(*src);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CopyTensorData[" << i << "] failed!";
        return RET_ERROR;
    } else {
      auto src = input0_->GetTensor(i);
      auto dst = output0_->GetTensor(i);
      MS_ASSERT(src != nullptr);
      MS_ASSERT(dst != nullptr);
      if (src->data_type() != kTypeUnknown) {
        if (src->Size() != dst->Size()) {
          MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
          return RET_ERROR;
        }
        auto ret = lite::Tensor::CopyTensorData(*src, dst);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "CopyTensorData[" << i << "] failed!";
          return RET_ERROR;
        }
      }
    }
  }
@@ -79,6 +79,7 @@ int TensorListStackCPUKernel::MergeElementShape() {
    return RET_ERROR;
  }
  auto ele_shape_data = reinterpret_cast<int *>(in_tensors_[1]->data_c());
  output_shape_.clear();
  for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
    output_shape_.push_back(ele_shape_data[i]);
  }
@@ -94,7 +95,7 @@ int TensorListStackCPUKernel::MergeElementShape() {
  }
  if (!IsFullyDefined(input0_->element_shape())) {
    for (int i = 0; i < input0_->ElementsNum(); ++i) {  // visit every tensor in the tensorlist
      auto tensor_ele = input0_->GetTensorIndex(i);
      auto tensor_ele = input0_->GetTensor(i);
      MS_ASSERT(tensor_ele != nullptr);
      if (tensor_ele->data_type() != kTypeUnknown) {
        status = MergeSubShape(tensor_ele->shape());
@@ -150,7 +151,7 @@ int TensorListStackCPUKernel::Run() {
  }
  auto out_ptr = reinterpret_cast<float *>(output0_->MutableData());
  for (int i = 0; i < num_element_; ++i) {
    auto in_ptr = input0_->GetTensorIndex(i);
    auto in_ptr = input0_->GetTensor(i);
    MS_ASSERT(in_ptr != nullptr);
    if (in_ptr->data_type() != kTypeUnknown) {
      int in_size = in_ptr->ElementsNum();
@@ -115,11 +115,7 @@ int SubGraphKernel::ReSize(bool is_interrupt) {
    std::vector<lite::Tensor *> inputs = kernel->in_tensors();
    std::vector<lite::Tensor *> outputs = kernel->out_tensors();
    for (auto &output : outputs) {
      auto ret = output->FreeData();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "FreeData failed";
        return RET_ERROR;
      }
      output->FreeData();
    }
    primitive->set_infer_flag(!is_interrupt);
    auto ret = primitive->InferShape(inputs, outputs);
@@ -52,13 +52,13 @@ struct DataStore {
class SubGraphKernel : public LiteKernel {
 public:
  explicit SubGraphKernel(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                          const std::vector<LiteKernel *> &in_kernels, const std::vector<LiteKernel *> &out_kernels,
                          std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
  SubGraphKernel(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                 std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
                 std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : LiteKernel(nullptr, inputs, outputs, ctx, nullptr),
        nodes_(std::move(nodes)),
        in_nodes_(in_kernels),
        out_nodes_(out_kernels) {
        in_nodes_(std::move(in_kernels)),
        out_nodes_(std::move(out_kernels)) {
    subgraph_type_ = kCpuFP32SubGraph;
  }
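The signature change (const reference to by-value plus std::move) is the standard sink-argument idiom: callers with rvalues avoid a copy entirely, callers with lvalues pay exactly one. In isolation:

    #include <utility>
    #include <vector>
    struct Holder {
      explicit Holder(std::vector<int> v) : v_(std::move(v)) {}  // sink parameter: copy once or move
      std::vector<int> v_;
    };
    // Holder h1(my_vec);             // one copy
    // Holder h2(std::move(my_vec));  // no copy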
@@ -109,20 +109,20 @@ class SubGraphKernel : public LiteKernel {
  std::vector<LiteKernel *> nodes() { return this->nodes_; }
 protected:
  std::vector<LiteKernel *> nodes_;
  std::vector<LiteKernel *> nodes_{};
  // entry nodes in nodes
  std::vector<LiteKernel *> in_nodes_;
  std::vector<LiteKernel *> in_nodes_{};
  // exit nodes in nodes
  std::vector<LiteKernel *> out_nodes_;
  std::vector<LiteKernel *> out_nodes_{};
  mindspore::lite::Executor *executor_ = nullptr;
};
class CpuSubGraph : public SubGraphKernel {
 public:
  explicit CpuSubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                       const std::vector<LiteKernel *> &in_kernels, const std::vector<LiteKernel *> &out_kernels,
                       const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
      : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
  CpuSubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
              std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
              std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : SubGraphKernel(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) {
    subgraph_type_ = kCpuFP32SubGraph;
  }
@@ -139,10 +139,10 @@ class CpuSubGraph : public SubGraphKernel {
class CpuFp32SubGraph : public CpuSubGraph {
 public:
  explicit CpuFp32SubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                           const std::vector<LiteKernel *> &in_kernels, const std::vector<LiteKernel *> &out_kernels,
                           const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
      : CpuSubGraph(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
  CpuFp32SubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                  std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
                  std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : CpuSubGraph(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) {
    subgraph_type_ = kCpuFP32SubGraph;
    this->name_ = "CpuFP32SubGraph";
  }
@@ -159,10 +159,10 @@ class CpuFp32SubGraph : public CpuSubGraph {
class CpuFp16SubGraph : public CpuSubGraph {
 public:
  explicit CpuFp16SubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                           const std::vector<LiteKernel *> &in_kernels, const std::vector<LiteKernel *> &out_kernels,
                           const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
      : CpuSubGraph(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
  CpuFp16SubGraph(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                  std::vector<LiteKernel *> in_kernels, std::vector<LiteKernel *> out_kernels,
                  std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
      : CpuSubGraph(inputs, outputs, std::move(in_kernels), std::move(out_kernels), std::move(nodes), ctx) {
    subgraph_type_ = kCpuFP16SubGraph;
    this->name_ = "CpuFP16SubGraph";
  }
@@ -180,7 +180,7 @@ class CpuFp16SubGraph : public CpuSubGraph {
  void FreeOriginInputData();
 private:
  std::vector<DataStore *> origin_input_data_;
  std::vector<DataStore *> origin_input_data_{};
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_SUB_GRAPH_H
@@ -29,48 +29,53 @@ namespace lite {
Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const schema::Format &format, Category category)
    : data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}
Tensor::Tensor(const Tensor &tensor) {
  auto ret = CopyTensor(tensor, true);
  if (0 != ret) {
    MS_LOG(ERROR) << "CopyTensorData error";
int Tensor::CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor) {
  if (dst_tensor == nullptr) {
    MS_LOG(ERROR) << "dst_tensor is nullptr";
    return RET_PARAM_INVALID;
  }
}
int Tensor::CopyTensorData(const Tensor &src_tensor) {
  if (src_tensor.data_ == nullptr) {
    MS_LOG(ERROR) << "data of src tensor is nullptr";
    return RET_PARAM_INVALID;
  }
  size_t data_size = this->Size();
  MS_ASSERT(data_size == src_tensor.Size());
  if (this->data_ == nullptr) {
  size_t data_size = dst_tensor->Size();
  if (data_size != src_tensor.Size()) {
    MS_LOG(ERROR) << "Size of dst tensor is not compatible with src tensor";
    return RET_ERROR;
  }
  if (dst_tensor->data_ == nullptr) {
    if (data_size > kMaxMallocSize) {
      MS_LOG(ERROR) << "Malloc size is too big while copying data, " << data_size << " bytes";
      return RET_ERROR;
    }
    this->data_ = malloc(data_size);
    if (this->data_ == nullptr) {
    dst_tensor->data_ = malloc(data_size);
    if (dst_tensor->data_ == nullptr) {
      MS_LOG(ERROR) << "Malloc memory failed";
      return RET_ERROR;
    }
  }
  memcpy(this->data_, src_tensor.data_, data_size);
  memcpy(dst_tensor->data_, src_tensor.data_, data_size);
  return RET_OK;
}
int Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data) {
  this->data_type_ = src_tensor.data_type_;
  this->shape_ = src_tensor.shape_;
  this->category_ = src_tensor.category_;
  this->format_ = src_tensor.format_;
Tensor *Tensor::CopyTensor(const Tensor &src_tensor, bool copy_data) {
  auto *result = new (std::nothrow) Tensor;
  if (result == nullptr) {
    MS_LOG(ERROR) << "New tensor failed";
    return nullptr;
  }
  result->data_type_ = src_tensor.data_type_;
  result->shape_ = src_tensor.shape_;
  result->category_ = src_tensor.category_;
  result->format_ = src_tensor.format_;
  if (copy_data) {
    auto ret = CopyTensorData(src_tensor);
    if (0 != ret) {
    auto ret = CopyTensorData(src_tensor, result);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "CopyTensorData error";
      return RET_ERROR;
      return nullptr;
    }
  }
  return RET_OK;
  return result;
}
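With both helpers now static, a caller owns the clone and copies between existing tensors explicitly. A minimal sketch using only this file's API:

    lite::Tensor src(kNumberTypeFloat32, {2, 2});
    src.MallocData();  // fill via src.MutableData() as needed
    lite::Tensor dst(kNumberTypeFloat32, {2, 2});
    if (lite::Tensor::CopyTensorData(src, &dst) != RET_OK) {
      MS_LOG(ERROR) << "copy failed";
    }
    lite::Tensor *clone = lite::Tensor::CopyTensor(src, true);  // heap copy, caller owns it
    delete clone;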
Tensor::~Tensor() {
@@ -84,18 +89,6 @@ Tensor::~Tensor() {
  }
}
Tensor &Tensor::operator=(const Tensor &tensor) {
  if (&tensor == this) {
    return *this;
  }
  auto ret = CopyTensor(tensor, true);
  if (0 != ret) {
    MS_LOG(ERROR) << "CopyTensorData error";
    MS_ASSERT(false);
  }
  return *this;
}
bool Tensor::operator==(const Tensor &tensor) {
  return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_;
}
@@ -283,6 +276,25 @@ std::string Tensor::ToString() const {
  return oss.str();
}
int Tensor::set_root_tensor(Tensor *tensor) {
  this->root_tensor_ = tensor;
  if (this->root_tensor_ == this) {
    return RET_OK;
  }
  if (this->root_tensor_ == nullptr) {
    MS_LOG(ERROR) << "root tensor is nullptr";
    return RET_NULL_PTR;
  }
  this->shape_ = this->root_tensor_->shape_;
  this->format_ = this->root_tensor_->format_;
  this->data_type_ = this->root_tensor_->data_type_;
  this->allocator_ = this->root_tensor_->allocator_;
  this->category_ = this->root_tensor_->category_;
  this->quant_params_ = this->root_tensor_->quant_params_;
  this->quant_clusters_ = this->root_tensor_->quant_clusters_;
  return RET_OK;
}
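A small sketch of the aliasing this enables (values invented; only members shown in this patch are used):

    lite::Tensor root(kNumberTypeFloat32, {4});
    root.MallocData();
    lite::Tensor alias(kNumberTypeFloat32, {4});
    alias.set_root_tensor(&root);  // alias copies root's shape/format/type/quant params
    // data_c() resolves through the root, so both names observe one buffer:
    MS_ASSERT(alias.data_c() == root.data_c());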
int Tensor::MallocData(const mindspore::lite::Allocator *allocator) {
  if (nullptr != this->data_) {
    return RET_OK;
@@ -303,9 +315,9 @@ int Tensor::MallocData(const mindspore::lite::Allocator *allocator) {
  return RET_OK;
}
int Tensor::FreeData() {
void Tensor::FreeData() {
  if (nullptr == this->data_) {
    return RET_OK;
    return;
  }
  if (nullptr == allocator_) {
    free(this->data_);
@@ -314,10 +326,19 @@ int Tensor::FreeData() {
    allocator_->Free(this->data_);
    this->data_ = nullptr;
  }
  return RET_OK;
}
void *Tensor::MutableData() {
  if (this->root_tensor_ != nullptr) {
    if (this->root_tensor_ != this && this->root_tensor_->data_ == nullptr) {
      MS_LOG(ERROR) << "root tensor has not been malloced";
      return nullptr;
    } else if (this->root_tensor_ != this && this->root_tensor_->data_ != nullptr) {
      return this->root_tensor_->data_;
    } else {
      // malloc self
    }
  }
  if (this->data_ == nullptr) {
    auto ret = this->MallocData();
    if (ret != 0) {
@@ -328,6 +349,17 @@ void *Tensor::MutableData() {
  return this->data_;
}
void Tensor::DecRefCount() {
  if (this->IsConst() || this->IsGraphInput()) {
    return;
  }
  this->ref_count_--;
  if (this->ref_count_ <= 0) {
    FreeData();
    this->ref_count_ = 0;
  }
}
void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push_back(quant_arg); }
std::vector<QuantArg> Tensor::quant_params() const { return this->quant_params_; }
@@ -53,15 +53,19 @@ class Tensor : public mindspore::tensor::MSTensor {
  Tensor(TypeId data_type, std::vector<int> shape, const schema::Format &format = schema::Format::Format_NHWC,
         Category category = VAR);
  Tensor(const Tensor &tensor);
  Tensor(const Tensor &tensor) = delete;
  ~Tensor() override;
  Tensor(Tensor &&other) = delete;
  Tensor &operator=(const Tensor &tensor) = delete;
  int CopyTensorData(const Tensor &srcTensor);
  Tensor &operator=(Tensor &&src) = delete;
  int CopyTensor(const Tensor &srcTensor, bool copyData = false);
  ~Tensor() override;
  Tensor &operator=(const Tensor &tensor);
  static int CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor);
  static Tensor *CopyTensor(const Tensor &src_tensor, bool copy_data = false);
  virtual bool operator==(const Tensor &tensor);
@@ -99,11 +103,16 @@ class Tensor : public mindspore::tensor::MSTensor {
  virtual int MallocData(const mindspore::lite::Allocator *allocator = nullptr);
  virtual int FreeData();
  virtual void FreeData();
  void *MutableData() override;
  virtual void *data_c() const { return data_; }
  virtual void *data_c() const {
    if (this->root_tensor_ != nullptr) {
      return this->root_tensor_->data_;
    }
    return data_;
  }
  virtual void set_data(void *data) { this->data_ = data; }
@@ -125,7 +134,7 @@ class Tensor : public mindspore::tensor::MSTensor {
  void ResetRefCount() { this->ref_count_ = this->init_ref_count_; }
  void DecRefCount() { this->ref_count_--; }
  void DecRefCount();
  std::string ToString() const;
@@ -151,6 +160,14 @@ class Tensor : public mindspore::tensor::MSTensor {
    }
  }
  virtual int set_root_tensor(Tensor *tensor);
  Tensor *root_tensor() const { return this->root_tensor_; }
  bool IsReady() const {
    return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || this->ref_count_ >= 1;
  }
 private:
  template <typename T>
  std::string DataToString(void *data, size_t data_number) const {
@@ -168,7 +185,6 @@ class Tensor : public mindspore::tensor::MSTensor {
 protected:
  std::string tensor_name_;
  void *data_ = nullptr;
  void *device_data_ = nullptr;
  TypeId data_type_;
  std::vector<int> shape_;
  schema::Format format_;
@@ -178,6 +194,7 @@ class Tensor : public mindspore::tensor::MSTensor {
  std::vector<QuantArg> quant_params_;
  std::vector<float> quant_clusters_;
  mindspore::lite::Allocator *allocator_ = nullptr;
  Tensor *root_tensor_ = nullptr;
};
inline size_t DataTypeSize(const TypeId type) {
@@ -14,38 +14,26 @@
 * limitations under the License.
 */
#include "src/tensorlist.h"
#include <utility>
#include "include/ms_tensor.h"
#include "src/common/log_adapter.h"
#include "schema/model_generated.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
TensorList::TensorList(std::vector<int> shape, std::vector<int> element_shape, Category category)
    : Tensor(kObjectTypeTensorType, shape, schema::Format::Format_NHWC, category), element_shape_(element_shape) {}
    : Tensor(kObjectTypeTensorType, std::move(shape), schema::Format::Format_NHWC, category),
      element_shape_(std::move(element_shape)) {}
TensorList::~TensorList() {
  if (!this->tensors_.empty()) {
    this->FreeData();
    this->TensorList::FreeData();
    this->FreeTensorListData();
  }
}
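The qualified this->TensorList::FreeData() call is deliberate: during ~TensorList() the dynamic type is already TensorList, so an unqualified call would resolve to the same override anyway, but naming the class pins the call at compile time and makes the virtual-call-in-destructor intent explicit. The pattern in miniature:

    struct Base {
      virtual ~Base() = default;
      virtual void FreeData() {}
    };
    struct Derived : Base {
      void FreeData() override {}
      ~Derived() override { Derived::FreeData(); }  // qualified: resolved statically
    };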
TensorList &TensorList::operator=(const TensorList &src) {
  if (&src == this) {
    return *this;
  }
  auto ret = CopyTensorList(src, true);
  if (ret == RET_ERROR) {
    MS_LOG(ERROR) << "CopyTensorList error!";
    MS_ASSERT(false);
  }
  return *this;
}
int TensorList::CopyTensorList(const TensorList &src, bool copy_data) {
  this->data_type_ = src.data_type_;
  this->tensors_data_type_ = src.tensors_data_type_;
@@ -59,6 +47,10 @@ int TensorList::CopyTensorList(const TensorList &src, bool copy_data) {
      return RET_ERROR;
    }
  } else {
    for (auto tensor : this->tensors()) {
      delete tensor;
    }
    this->tensors_.clear();
    // each tensor in tensors_ will share the same memory space.
    this->tensors_ = src.tensors_;
  }
@@ -69,17 +61,20 @@ int TensorList::CopyTensorData(const TensorList &src) {
  if (src.tensors_.empty()) {
    return RET_OK;
  }
  for (auto tensor : this->tensors()) {
    delete tensor;
  }
  this->tensors_.clear();
  for (int i = 0; i < this->ElementsNum(); ++i) {
    if (src.tensors_[i] == nullptr) {
      MS_LOG(ERROR) << "src tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    auto dst_tensor = new (std::nothrow) Tensor;
    auto dst_tensor = Tensor::CopyTensor(*src.tensors_[i]);
    if (dst_tensor == nullptr) {
      MS_LOG(ERROR) << "CopyTensorData: new tensor[" << i << "] failed!";
      return RET_ERROR;
    }
    *reinterpret_cast<Tensor *>(dst_tensor) = *src.tensors_[i];
    this->tensors_.push_back(dst_tensor);
  }
  return RET_OK;
@@ -143,17 +138,11 @@ int TensorList::MallocData(const mindspore::lite::Allocator *allocator) {
  return RET_OK;
}
int TensorList::FreeData() {
void TensorList::FreeData() {
  // free the data buffer of each tensor in tensors_
  if (this->tensors_.empty()) {
    return RET_OK;
  for (auto tensor : tensors_) {
    tensor->FreeData();
  }
  for (int i = 0; i < this->ElementsNum(); ++i) {
    if (this->tensors_[i] != nullptr) {
      this->tensors_[i]->FreeData();
    }
  }
  return RET_OK;
}
int TensorList::FreeTensorListData() {
@@ -171,7 +160,7 @@ int TensorList::FreeTensorListData() {
  return RET_OK;
}
| int TensorList::SetTensorIndex(int index, Tensor *src_tensor) { | |||
| int TensorList::SetTensor(int index, Tensor *src_tensor) { | |||
| // You can use this function to modify the tensor stored at tensors_[index]. | |||
| if (src_tensor->data_type() != this->tensors_data_type_) { | |||
| MS_LOG(ERROR) << "src_tensor->data_type():" << src_tensor->data_type() | |||
| @@ -183,15 +172,13 @@ int TensorList::SetTensorIndex(int index, Tensor *src_tensor) { | |||
| return RET_ERROR; | |||
| } | |||
| auto dst_tensor = this->tensors_[index]; | |||
| if (dst_tensor != nullptr) { // free original tensor data | |||
| delete dst_tensor; | |||
| } | |||
| this->tensors_[index] = new (std::nothrow) Tensor; | |||
| // free original tensor data | |||
| delete dst_tensor; | |||
| this->tensors_[index] = Tensor::CopyTensor(*src_tensor); | |||
| if (this->tensors_[index] == nullptr) { | |||
| MS_LOG(ERROR) << "SetTensorIndex: new tensor is failed!"; | |||
| MS_LOG(ERROR) << "SetTensor: new tensor is failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| *this->tensors_[index] = *src_tensor; | |||
| return RET_OK; | |||
| } | |||
| @@ -211,9 +198,40 @@ int TensorList::CheckTensorListParam() { | |||
| return RET_OK; | |||
| } | |||
| Tensor *TensorList::GetTensorIndex(int index) { | |||
| int TensorList::set_root_tensor(Tensor *tensor) { | |||
| auto ret = Tensor::set_root_tensor(tensor); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| if (this->data_type_ != kObjectTypeTensorType) { | |||
| return RET_OK; | |||
| } | |||
| auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_); | |||
| if (root_tensorlist == nullptr) { | |||
| MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist"; | |||
| return RET_INFER_INVALID; | |||
| } | |||
| this->element_shape_ = root_tensorlist->element_shape_; | |||
| this->max_elements_num_ = root_tensorlist->max_elements_num_; | |||
| this->tensors_data_type_ = root_tensorlist->tensors_data_type_; | |||
| return RET_OK; | |||
| } | |||
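The new `set_root_tensor` override runs the base-class bookkeeping first and only then mirrors the list-specific metadata from the root. The shape of that pattern, reduced to a sketch with simplified names and return codes:

```cpp
#include <vector>

struct TensorSketch {
  virtual ~TensorSketch() = default;
  virtual int set_root_tensor(TensorSketch *t) {
    root_ = t;
    return 0;  // stand-in for RET_OK
  }
  TensorSketch *root_ = nullptr;
};

struct TensorListSketch : TensorSketch {
  int set_root_tensor(TensorSketch *t) override {
    int ret = TensorSketch::set_root_tensor(t);  // base bookkeeping first
    if (ret != 0) return ret;
    auto *root = dynamic_cast<TensorListSketch *>(root_);
    if (root == nullptr) return -1;  // the root must itself be a list
    element_shape_ = root->element_shape_;  // mirror list metadata
    return 0;
  }
  std::vector<int> element_shape_;
};
```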
| Tensor *TensorList::GetTensor(int index) { | |||
| // Return the pointer stored at tensors_[index]; the caller may modify the tensor through it. | |||
| if (index < 0 || index >= static_cast<int>(tensors_.size())) { | |||
| if (this->root_tensor_ != nullptr) { | |||
| if (this->data_type_ != kObjectTypeTensorType) { | |||
| MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist"; | |||
| return nullptr; | |||
| } | |||
| auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_); | |||
| if (index < 0 || index >= static_cast<int>(root_tensorlist->tensors_.size())) { | |||
| MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; | |||
| return nullptr; | |||
| } | |||
| return root_tensorlist->tensors_[index]; | |||
| } | |||
| if (index < 0 || index >= static_cast<int>(this->tensors_.size())) { | |||
| MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; | |||
| return nullptr; | |||
| } | |||
| @@ -264,5 +282,4 @@ STATUS TensorList::Decode(const int *data) { | |||
| bool TensorList::IsConst() const { return this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
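Both files also collapse the nested namespace blocks into a C++17 nested namespace definition; the two spellings are equivalent:

```cpp
// Pre-C++17 spelling:
namespace mindspore {
namespace lite {
// ... declarations ...
}  // namespace lite
}  // namespace mindspore

// C++17 spelling used after this change:
namespace mindspore::lite {
// ... declarations ...
}  // namespace mindspore::lite
```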
| @@ -25,8 +25,7 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/tensor.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| /** | |||
| * Tensorlist is a container of vector, in which each element is a tensor object. | |||
| * Member objects: | |||
| @@ -64,17 +63,9 @@ class TensorList : public Tensor { | |||
| ~TensorList() override; | |||
| // **Note**: This is a shallow copy: src and dst share the same Tensor objects in tensors_. | |||
| // If you do not want to share that memory, use "operator=". | |||
| TensorList(const TensorList &other) | |||
| : Tensor(other.data_type_, other.shape()), | |||
| tensors_(other.tensors_), | |||
| tensors_data_type_(other.tensors_data_type_), | |||
| element_shape_(other.element_shape_), | |||
| max_elements_num_(other.max_elements_num_) {} | |||
| TensorList(const TensorList &other) = delete; | |||
| // deep copy of the tensorlist memory | |||
| TensorList &operator=(const TensorList &tl); | |||
| TensorList &operator=(const TensorList &tl) = delete; | |||
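Deleting the copy constructor and copy assignment turns every accidental copy of a `TensorList` into a compile-time error instead of a silent shallow alias. For example:

```cpp
class NonCopyable {
 public:
  NonCopyable() = default;
  NonCopyable(const NonCopyable &) = delete;
  NonCopyable &operator=(const NonCopyable &) = delete;
};

int main() {
  NonCopyable a;
  // NonCopyable b = a;   // error: use of deleted copy constructor
  // NonCopyable c; c = a;  // error: use of deleted copy assignment
  return 0;
}
```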
| void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; } | |||
| @@ -90,15 +81,15 @@ class TensorList : public Tensor { | |||
| int FreeTensorListData(); | |||
| int FreeData() override; | |||
| void FreeData() override; | |||
| int CopyTensorList(const TensorList &src, bool copy_data); | |||
| int CopyTensorData(const TensorList &src); | |||
| int SetTensorIndex(int index, Tensor *); | |||
| int SetTensor(int index, Tensor *src_tensor); | |||
| Tensor *GetTensorIndex(int index); | |||
| Tensor *GetTensor(int index); | |||
| void set_tensors_data_type(TypeId type) { tensors_data_type_ = type; } | |||
| @@ -106,6 +97,8 @@ class TensorList : public Tensor { | |||
| std::vector<Tensor *> &tensors() { return tensors_; } | |||
| void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; } | |||
| int CheckTensorListParam(); | |||
| bool IsCompatibleShape(const std::vector<int> &shape); | |||
| @@ -116,18 +109,19 @@ class TensorList : public Tensor { | |||
| bool IsConst() const override; | |||
| int set_root_tensor(Tensor *tensor) override; | |||
| protected: | |||
| // The following functions are intentionally stubbed out: a TensorList has no flat data buffer of its own. | |||
| void set_data(void *data) override { return; } | |||
| void set_data(void *data) override {} | |||
| void *data_c() const override { return nullptr; } | |||
| void *MutableData() override { return nullptr; } | |||
| size_t Size() const override { return 0; } | |||
| std::vector<Tensor *> tensors_; | |||
| TypeId tensors_data_type_; | |||
| std::vector<int> element_shape_; | |||
| std::vector<Tensor *> tensors_{}; | |||
| TypeId tensors_data_type_ = kTypeUnknown; | |||
| std::vector<int> element_shape_{}; | |||
| int max_elements_num_ = -1; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_TENSORLIST_H_ | |||
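The members also gain default member initializers, so every constructor starts from a well-defined state even if its init list omits a field. A minimal sketch of the pattern (the values are stand-ins, not the real enums):

```cpp
#include <vector>

class FieldsSketch {
  // NSDMIs guarantee deterministic defaults for every constructor.
  std::vector<int *> tensors_{};      // empty list of element pointers
  int tensors_data_type_ = 0;         // stand-in for kTypeUnknown
  std::vector<int> element_shape_{};  // empty shape until inferred
  int max_elements_num_ = -1;         // -1: capacity not yet known
};
```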
| @@ -43,23 +43,23 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std | |||
| // simulating benchmark: session_->CompileGraph() -> ConvertTensors() | |||
| MS_LOG(DEBUG) << "create Tensors & init weight data"; | |||
| std::vector<Tensor> tensors; | |||
| std::vector<std::shared_ptr<Tensor>> tensors; | |||
| // firstly, create all Tensors | |||
| tensors.reserve(input_infos.size()); // capacity() starts at 0, so reserve() up front to avoid reallocations | |||
| for (auto input_info : input_infos) { | |||
| auto &shape = std::get<0>(input_info); | |||
| auto category = std::get<2>(input_info); | |||
| auto data_type = std::get<3>(input_info); | |||
| tensors.emplace_back(data_type, shape, Format_NHWC, category); | |||
| tensors.emplace_back(std::make_shared<Tensor>(data_type, shape, Format_NHWC, category)); | |||
| } | |||
| // secondly, init weight Tensor's data | |||
| std::vector<Tensor *> kernel_inputs; | |||
| std::vector<Tensor *> subgraph_inputs; | |||
| std::map<Tensor *, float *> subgraph_inputs_data; | |||
| for (int i = 0; i < tensors.size(); ++i) { | |||
| auto *tensor = &tensors[i]; | |||
| auto tensor = tensors[i]; | |||
| auto *input_data = std::get<1>(input_infos[i]); | |||
| kernel_inputs.push_back(tensor); | |||
| kernel_inputs.push_back(tensor.get()); | |||
| if (tensor->category() != VAR) { // tensor is weight | |||
| // simulating src/lite_session.cc:WeightTensorNeedCopy() | |||
| if (packed_op.count(primitive_type)) { | |||
| @@ -69,8 +69,8 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std | |||
| } | |||
| } else { | |||
| EXPECT_TRUE(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32); | |||
| subgraph_inputs.push_back(tensor); | |||
| subgraph_inputs_data[tensor] = reinterpret_cast<float *>(input_data); | |||
| subgraph_inputs.push_back(tensor.get()); | |||
| subgraph_inputs_data[tensor.get()] = reinterpret_cast<float *>(input_data); | |||
| } | |||
| } | |||
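The test helper switches `tensors` from a vector of values to a vector of `shared_ptr`. With `std::vector<Tensor>`, every reallocation can move the elements and dangle the raw `Tensor*` already handed to `kernel_inputs`; the heap objects behind `shared_ptr` never move. A reduced illustration:

```cpp
#include <memory>
#include <vector>

struct Tensor { int id = 0; };

int main() {
  std::vector<std::shared_ptr<Tensor>> tensors;
  tensors.push_back(std::make_shared<Tensor>());
  Tensor *raw = tensors[0].get();  // raw handle, as in kernel_inputs
  for (int i = 0; i < 1000; ++i) {
    tensors.push_back(std::make_shared<Tensor>());  // vector may reallocate
  }
  // raw still points at a live Tensor: only the shared_ptrs moved.
  // With std::vector<Tensor>, raw could dangle after the push_backs.
  return raw->id;
}
```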
| @@ -115,7 +115,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std | |||
| // simulating benchmark: model->Free(), clear weight data in input_infos | |||
| std::vector<std::unique_ptr<uint8_t[]>> saved_weights; | |||
| for (int i = 0; i < tensors.size(); ++i) { | |||
| auto *tensor = &tensors[i]; | |||
| auto &tensor = tensors[i]; | |||
| if (tensor->category() != VAR) { | |||
| saved_weights.emplace_back(new uint8_t[tensor->Size()]); | |||
| auto *weight_data = std::get<1>(input_infos[i]); | |||
| @@ -143,12 +143,12 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std | |||
| MS_LOG(DEBUG) << "release resources"; | |||
| for (auto &tensor : tensors) { | |||
| if (tensor.category() != VAR && packed_op.count(primitive_type)) { | |||
| tensor.set_data(nullptr); | |||
| if (tensor->category() != VAR && packed_op.count(primitive_type)) { | |||
| tensor->set_data(nullptr); | |||
| } | |||
| } | |||
| for (int i = 0, j = 0; i < tensors.size(); ++i) { // restore weight data to input_infos | |||
| auto *tensor = &tensors[i]; | |||
| auto &tensor = tensors[i]; | |||
| if (tensor->category() != VAR) { | |||
| auto *weight_data = std::get<1>(input_infos[i]); | |||
| memcpy(weight_data, saved_weights[j++].get(), tensor->Size()); | |||