diff --git a/mindspore/lite/src/executor.h b/mindspore/lite/src/executor.h index e4fb8f8525..ddc54fded9 100644 --- a/mindspore/lite/src/executor.h +++ b/mindspore/lite/src/executor.h @@ -28,7 +28,8 @@ class Executor { Executor() = default; virtual ~Executor() = default; - virtual int Prepare(const std::vector &kernels) { + virtual int Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) { ctx_ = static_cast(kernels[0]->context()); return RET_OK; } diff --git a/mindspore/lite/src/lite_kernel_util.cc b/mindspore/lite/src/lite_kernel_util.cc index 97b662ba5a..e2b99e3d50 100644 --- a/mindspore/lite/src/lite_kernel_util.cc +++ b/mindspore/lite/src/lite_kernel_util.cc @@ -23,11 +23,13 @@ namespace mindspore::kernel { using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; std::vector LiteKernelUtil::SubgraphInputNodes(const std::vector &kernels) { - std::set input_nodes; + std::vector input_nodes; for (const auto &kernel : kernels) { // if kernel has no pre-kernel, kernel is a graph input, it must be a subgraph input if (kernel->in_kernels().empty() && !kernel->in_tensors().empty()) { - input_nodes.insert(kernel); + if (!lite::IsContain(input_nodes, kernel)) { + input_nodes.push_back(kernel); + } continue; } auto all_input_tensors = kernel->in_tensors(); @@ -50,45 +52,49 @@ std::vector LiteKernelUtil::SubgraphInputNodes(const std:: } // if some input tensor is not from kernel in subgraph if (!all_input_tensors.empty()) { - input_nodes.insert(kernel); + if (!lite::IsContain(input_nodes, kernel)) { + input_nodes.push_back(kernel); + } } } - std::vector result; - result.insert(result.end(), input_nodes.begin(), input_nodes.end()); - return result; + return input_nodes; } std::vector LiteKernelUtil::SubgraphOutputNodes( const std::vector &kernels) { - std::set output_nodes; + std::vector output_nodes; // if kernel has no post-kernel, kernel is a graph output, it must be a subgraph output for (const auto &kernel : kernels) { if 
(kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) { - output_nodes.insert(kernel); + if (!lite::IsContain(output_nodes, kernel)) { + output_nodes.push_back(kernel); + } continue; } for (const auto &output : kernel->out_kernels()) { auto out_kernel_in_graph = std::find(kernels.begin(), kernels.end(), output); if (out_kernel_in_graph == kernels.end()) { - output_nodes.insert(kernel); + if (!lite::IsContain(output_nodes, kernel)) { + output_nodes.push_back(kernel); + } break; } } } - std::vector result; - result.insert(result.end(), output_nodes.begin(), output_nodes.end()); - return result; + return output_nodes; } std::vector LiteKernelUtil::SubgraphInputTensors(const std::vector &kernels) { - std::set input_tensors; + std::vector input_tensors; std::vector input_nodes = SubgraphInputNodes(kernels); for (const auto &input_node : input_nodes) { auto &in_node_in_kernels = input_node->in_kernels(); auto &in_node_in_tensors = input_node->in_tensors(); for (auto &in_node_in_tensor : in_node_in_tensors) { if (in_node_in_tensor->IsGraphInput()) { - input_tensors.insert(in_node_in_tensor); + if (!lite::IsContain(input_tensors, in_node_in_tensor)) { + input_tensors.push_back(in_node_in_tensor); + } } } for (auto in_node_in_kernel : in_node_in_kernels) { @@ -101,25 +107,27 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect auto outer_in_kernel_out_tensors_iter = std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_node_in_tensor); if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { - input_tensors.insert(in_node_in_tensor); + if (!lite::IsContain(input_tensors, in_node_in_tensor)) { + input_tensors.push_back(in_node_in_tensor); + } } } } } - std::vector result; - result.insert(result.end(), input_tensors.begin(), input_tensors.end()); - return result; + return input_tensors; } std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vector &kernels) { - 
std::set output_tensors; + std::vector output_tensors; std::vector output_nodes = SubgraphOutputNodes(kernels); for (const auto &output_kernel : output_nodes) { auto &outer_out_kernels = output_kernel->out_kernels(); auto &out_kernel_out_tensors = output_kernel->out_tensors(); for (auto out_kernel_out_tensor : out_kernel_out_tensors) { if (out_kernel_out_tensor->IsGraphOutput()) { - output_tensors.insert(out_kernel_out_tensor); + if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + output_tensors.push_back(out_kernel_out_tensor); + } } } if (!outer_out_kernels.empty()) { @@ -133,15 +141,15 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec auto outer_out_kernel_in_tensors_iter = std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { - output_tensors.insert(out_kernel_out_tensor); + if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + output_tensors.push_back(out_kernel_out_tensor); + } } } } } } - std::vector result; - result.insert(result.end(), output_tensors.begin(), output_tensors.end()); - return result; + return output_tensors; } int LiteKernelUtil::TopologicalSortKernels(std::vector *kernels) { diff --git a/mindspore/lite/src/lite_mindrt.cc b/mindspore/lite/src/lite_mindrt.cc index 33cb7e29e8..968e602664 100644 --- a/mindspore/lite/src/lite_mindrt.cc +++ b/mindspore/lite/src/lite_mindrt.cc @@ -17,15 +17,21 @@ #include #include "src/lite_mindrt.h" #include "mindrt/include/mindrt.hpp" +#include "src/lite_kernel_util.h" +#include "nnacl/partial_fusion_parameter.h" +#include "src/common/tensor_util.h" +#include "nnacl/base/cast_base.h" namespace mindspore::lite { -int LiteOpActor::CompileArrow() { - int outTensorSize = static_cast(kernel_->out_tensors().size()); - for (int i = 0; i < outTensorSize; i++) { + +int LiteOpActor::CompileArrowThroughOutputKernels() { + output_op_arrows_.clear(); + int 
out_tensor_size = static_cast(kernel_->out_tensors().size()); + for (int i = 0; i < out_tensor_size; i++) { for (auto out : kernel_->out_kernels()) { - int inTensorSize = static_cast(out->in_tensors().size()); + int in_tensor_size = static_cast(out->in_tensors().size()); int to_input_index = -1; - for (int j = 0; j < inTensorSize; j++) { + for (int j = 0; j < in_tensor_size; j++) { if (kernel_->out_tensors()[i] == out->in_tensors()[j]) { to_input_index = j; break; @@ -46,18 +52,134 @@ int LiteOpActor::CompileArrow() { return RET_OK; } +int LiteOpActor::CompileArrowThroughPartialCall() { + auto *subgraph_kernel = reinterpret_cast(kernel_); + if (subgraph_kernel == nullptr) { + MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call."; + return RET_OK; + } + for (auto &node : subgraph_kernel->nodes()) { + if (node->Type() != schema::PrimitiveType_Call) { + continue; + } + call_node_ = node; + auto partial_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_PartialFusion); + if (!partial_node) { + continue; + } + partial_node_ = partial_node; + + auto partial_para = reinterpret_cast(partial_node->op_parameter()); + auto out_actor_id = subgraph_index_to_actor.at(partial_para->sub_graph_index_); + kernel_->set_out_tensors(partial_node->in_tensors()); + for (size_t i = 0; i < partial_node->in_tensors().size(); ++i) { + auto arrow = std::make_shared(i, out_actor_id, i); + if (arrow == nullptr) { + MS_LOG(ERROR) << "create OpArrow failed"; + return RET_ERROR; + } + output_op_arrows_.emplace_back(std::move(arrow)); + } + } + + subgraph_kernel->DropNode(partial_node_); + subgraph_kernel->DropNode(call_node_); + return RET_OK; +} + +int LiteOpActor::CompileArrow() { + output_op_arrows_.clear(); + int ret = CompileArrowThroughPartialCall(); + if (ret != RET_OK) { + output_op_arrows_.clear(); + MS_LOG(ERROR) << "CompileArrowThroughPartialCall failed."; + return ret; + } + if (!output_op_arrows_.empty()) { + MS_LOG(INFO) << 
"CompileArrowThroughPartialCall done."; + return RET_OK; + } + ret = CompileArrowThroughOutputKernels(); + if (ret != RET_OK) { + output_op_arrows_.clear(); + MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed."; + return ret; + } + return ret; +} + +int LiteOpActor::CheckInputData() { + if (kernel_->in_tensors().size() != inputs_data_.size()) { + MS_LOG(ERROR) << "kernel:" << kernel_->name() << "inputs_data_.size(): " << inputs_data_.size() + << " vs kernel_->in_tensors().size(): " << kernel_->in_tensors().size() << " are not equal."; + return RET_PARAM_INVALID; + } + + for (size_t i = 0; i < inputs_data_.size(); ++i) { + if (kernel_->in_tensors()[i]->shape() != inputs_data_[i]->shape()) { + MS_LOG(ERROR) << "inputs_data_[" << i << "].shape: " << inputs_data_[i]->shape() << " vs kernel_->in_tensors()[" + << i << "].shape: " << kernel_->in_tensors()[i]->shape() << " are not equal."; + return RET_PARAM_INVALID; + } + } + return RET_OK; +} + +void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) { + memcpy(dst_tensor->MutableData(), src_tensor->data_c(), src_tensor->Size()); + dst_tensor->IncRefCount(); + src_tensor->DecRefCount(); +} +void LiteOpActor::CopyInputData(Tensor *dst_tensor, Tensor *src_tensor) { + CastTensorData(dst_tensor, src_tensor); + dst_tensor->IncRefCount(); + src_tensor->DecRefCount(); +} + +int LiteOpActor::CastTensorData(Tensor *dst, Tensor *src) { + if (dst->shape() != src->shape()) { + MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs " + << "src tensor: " << src->tensor_name() << " shape: " << src->shape(); + return RET_PARAM_INVALID; + } + auto dst_data = dst->MutableData(); + auto src_data = src->MutableData(); + auto src_nums_size = src->ElementsNum(); + auto dst_data_type = static_cast(dst->data_type()); + auto src_data_type = static_cast(src->data_type()); + + if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) { + 
Fp16ToFloat32(static_cast(src_data), static_cast(dst_data), src_nums_size); + } else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) { + Float32ToFp16(static_cast(src_data), static_cast(dst_data), src_nums_size); + } else { + MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type; + return RET_NOT_SUPPORT; + } + return RET_OK; +} + +int LiteOpActor::SetInputData() { + for (size_t i = 0; i < inputs_data_.size(); ++i) { + auto dst_tensor = kernel_->in_tensors()[i]; + auto src_tensor = inputs_data_[i]; + if (src_tensor->data_type() != dst_tensor->data_type()) { + CopyInputData(dst_tensor, src_tensor); + } else { + MoveInputData(dst_tensor, src_tensor); + } + } + return RET_OK; +} + void LiteOpActor::AsyncOutput(OpContext *context) { - for (auto op_arrow : output_op_arrows_) { - auto data = context->output_data_->at(op_arrow->from_output_index_); + for (const auto &op_arrow : output_op_arrows_) { + auto data = outputs_data_.at(op_arrow->from_output_index_); Async(op_arrow->to_op_id_, &mindspore::OpActor::RunOpData, data, context); } - return; } -void LiteOpActor::AddResultIndex(size_t index) { - results_index_.push_back(index); - return; -} +void LiteOpActor::AddResultIndex(size_t index) { results_index_.push_back(index); } void LiteOpActor::SetOutputData(OpContext *context) { for (auto index : results_index_) { @@ -65,30 +187,243 @@ void LiteOpActor::SetOutputData(OpContext *context) { } } -int MindrtInit() { return mindspore::Initialize("tcp://127.0.0.1:8080", "", "", "", 1); } - -void MindrtTerminate(std::vector> actor_list) { - for (auto actor : actor_list) { - mindspore::Terminate(actor->GetAID()); +int LiteOpActor::PrepareOutputData() { + for (auto &arrow : output_op_arrows_) { + auto data = std::make_shared>(arrow->to_op_id_, kernel_->out_tensors().at(arrow->from_output_index_), + static_cast(arrow->to_input_index_)); + outputs_data_.emplace_back(data); } - 
mindspore::TerminateCurThreads(1); - return; + return RET_OK; } std::vector> CreateOpActor(const std::vector &kernels) { std::vector> actors; - for (auto kernel : kernels) { - auto actor = std::make_shared(kernel); - if (actor == nullptr) { - MS_LOG(ERROR) << "create LiteOpActor failed: " << kernel->name(); - actors.clear(); - return actors; + std::unordered_map partial_map{}; + for (size_t i = 0; i < kernels.size(); ++i) { + if ((kernel::LiteKernelUtil::IsSwitchCall(kernels[i]))) { + auto switch_actor = std::make_shared(kernels[i]); + if (switch_actor == nullptr) { + MS_LOG(ERROR) << "create LiteSwitchOpActor failed: " << kernels[i]->name(); + actors.clear(); + return actors; + } + partial_map[i] = switch_actor->GetAID(); + actors.push_back(switch_actor); + } else { + auto actor = std::make_shared(kernels[i]); + if (actor == nullptr) { + MS_LOG(ERROR) << "create LiteOpActor failed: " << kernels[i]->name(); + actors.clear(); + return actors; + } + partial_map[i] = actor->GetAID(); + actors.push_back(actor); } - auto aid = mindspore::Spawn(actor); - actors.push_back(actor); } + for (auto &actor : actors) { + actor->SetPartialMap(partial_map); + auto aid = mindspore::Spawn(actor); + } return actors; } +int LiteSwitchOpActor::CompileTrueBranchArrow() { + true_branch_output_op_arrows_.clear(); + if (true_partial_node_ == nullptr) { + MS_LOG(ERROR) << "true_partial_node_ is nullptr."; + return RET_NULL_PTR; + } + auto true_partial_para = reinterpret_cast(true_partial_node_->op_parameter()); + if (true_partial_para == nullptr) { + MS_LOG(ERROR) << "true_partial_node_->op_parameter() is nullptr."; + return RET_NULL_PTR; + } + auto true_branch_actor_id = subgraph_index_to_actor.at(true_partial_para->sub_graph_index_); + + for (size_t i = 0; i < true_partial_node_->in_tensors().size(); ++i) { + int out_tensor_size = static_cast(kernel_->out_tensors().size()); + for (int j = 0; j < out_tensor_size; ++j) { + if (true_partial_node_->in_tensors()[i] != 
kernel_->out_tensors()[j]) { + continue; + } + auto arrow = std::make_shared(j, true_branch_actor_id, i); + if (arrow == nullptr) { + MS_LOG(ERROR) << "create OpArrow failed"; + return RET_ERROR; + } + true_branch_output_op_arrows_.emplace_back(std::move(arrow)); + } + } + return RET_OK; +} + +int LiteSwitchOpActor::CompileFalseBranchArrow() { + false_branch_output_op_arrows_.clear(); + if (false_partial_node_ == nullptr) { + MS_LOG(ERROR) << "false_partial_node_ is nullptr."; + return RET_NULL_PTR; + } + auto false_partial_para = reinterpret_cast(false_partial_node_->op_parameter()); + if (false_partial_para == nullptr) { + MS_LOG(ERROR) << "false_partial_para->op_parameter() is nullptr."; + return RET_NULL_PTR; + } + auto false_branch_actor_id = subgraph_index_to_actor.at(false_partial_para->sub_graph_index_); + + for (size_t i = 0; i < false_partial_node_->in_tensors().size(); ++i) { + int out_tensor_size = static_cast(kernel_->out_tensors().size()); + for (int j = 0; j < out_tensor_size; ++j) { + if (false_partial_node_->in_tensors()[i] != kernel_->out_tensors()[j]) { + continue; + } + auto arrow = std::make_shared(j, false_branch_actor_id, i); + if (arrow == nullptr) { + MS_LOG(ERROR) << "create OpArrow failed"; + return RET_ERROR; + } + false_branch_output_op_arrows_.emplace_back(std::move(arrow)); + } + } + return RET_OK; +} + +int LiteSwitchOpActor::GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel) { + for (auto &node : subgraph_kernel->nodes()) { + if (node->Type() != schema::PrimitiveType_Call) { + continue; + } + call_node_ = node; + auto switch_node = kernel::LiteKernelUtil::GetInputsSpecificNode(node, schema::PrimitiveType_Switch); + if (!switch_node) { + continue; + } + switch_node_ = switch_node; + if (switch_node->in_kernels().size() != kSwitchInputsSize) { + MS_LOG(ERROR) << "switch input size: " << switch_node->in_kernels().size(); + return RET_MEMORY_FAILED; + } + + bool_node_ = switch_node->in_kernels().at(kSwitchCondInputIndex); + 
true_partial_node_ = switch_node->in_kernels().at(kSwitchTruePartialInputIndex); + false_partial_node_ = switch_node->in_kernels().at(kSwitchFalsePartialInputIndex); + break; + } + return RET_OK; +} + +void LiteSwitchOpActor::AppendOutputTensors() { + output_tensors_.push_back(bool_node_->out_tensors().front()); + for (auto &tensor : true_partial_node_->in_tensors()) { + if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) { + output_tensors_.push_back(tensor); + } + } + for (auto &tensor : false_partial_node_->in_tensors()) { + if (std::find(output_tensors_.begin(), output_tensors_.end(), tensor) == output_tensors_.end()) { + output_tensors_.push_back(tensor); + } + } + kernel_->set_out_tensors(output_tensors_); +} + +int LiteSwitchOpActor::CompileArrowThroughSwitchCall() { + auto *subgraph_kernel = reinterpret_cast(kernel_); + if (subgraph_kernel == nullptr) { + MS_LOG(INFO) << "kernel is not subgraph kernel, no partial call."; + return RET_OK; + } + + int ret = GetSwitchAndCallNode(subgraph_kernel); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GetSwitchAndCallCnode failed."; + return ret; + } + + AppendOutputTensors(); + + ret = CompileTrueBranchArrow(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CompileTrueBranchArrow failed."; + true_branch_output_op_arrows_.clear(); + return ret; + } + + ret = CompileFalseBranchArrow(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CompileFalseBranchArrow failed."; + false_branch_output_op_arrows_.clear(); + true_branch_output_op_arrows_.clear(); + return ret; + } + + subgraph_kernel->DropNode(call_node_); + subgraph_kernel->DropNode(switch_node_); + subgraph_kernel->DropNode(true_partial_node_); + subgraph_kernel->DropNode(false_partial_node_); + + return ret; +} + +int LiteSwitchOpActor::CompileArrow() { + int ret = CompileArrowThroughSwitchCall(); + if (ret != RET_OK) { + true_branch_output_op_arrows_.clear(); + false_branch_output_op_arrows_.clear(); + MS_LOG(ERROR) << 
"CompileArrowThroughSwitchCall failed."; + return ret; + } + if (!true_branch_output_op_arrows_.empty() && !false_branch_output_op_arrows_.empty()) { + MS_LOG(INFO) << "CompileArrowThroughSwitchCall done."; + return RET_OK; + } + ret = CompileArrowThroughOutputKernels(); + if (ret != RET_OK) { + output_op_arrows_.clear(); + MS_LOG(ERROR) << "CompileArrowThroughOutputKernels failed."; + return ret; + } + return ret; +} + +int LiteSwitchOpActor::PrepareOutputData() { + for (auto &arrow : true_branch_output_op_arrows_) { + auto data = std::make_shared>(arrow->to_op_id_, kernel_->out_tensors().at(arrow->from_output_index_), + static_cast(arrow->to_input_index_)); + true_branch_outputs_data_.emplace_back(data); + } + + for (auto &arrow : false_branch_output_op_arrows_) { + auto data = std::make_shared>(arrow->to_op_id_, kernel_->out_tensors().at(arrow->from_output_index_), + static_cast(arrow->to_input_index_)); + false_branch_outputs_data_.emplace_back(data); + } + return RET_OK; +} + +void LiteSwitchOpActor::AsyncTrueBranchOutput(OpContext *context) { + MS_ASSERT(true_branch_output_op_arrows_.size() == true_branch_outputs_data_.size()); + for (size_t i = 0; i < true_branch_output_op_arrows_.size(); ++i) { + auto &data = true_branch_outputs_data_.at(i); + Async(true_branch_output_op_arrows_[i]->to_op_id_, &mindspore::OpActor::RunOpData, data, context); + } +} + +void LiteSwitchOpActor::AsyncFalseBranchOutput(OpContext *context) { + MS_ASSERT(false_branch_output_op_arrows_.size() == false_branch_outputs_data_.size()); + for (size_t i = 0; i < false_branch_output_op_arrows_.size(); ++i) { + auto &data = false_branch_outputs_data_.at(i); + Async(false_branch_output_op_arrows_[i]->to_op_id_, &mindspore::OpActor::RunOpData, data, context); + } +} + +int MindrtInit() { return mindspore::Initialize("tcp://127.0.0.1:8080", "", "", "", 1); } + +void MindrtTerminate(const std::vector> &actor_list) { + for (const auto &actor : actor_list) { + 
mindspore::Terminate(actor->GetAID()); + } + mindspore::TerminateCurThreads(1); +} + } // namespace mindspore::lite diff --git a/mindspore/lite/src/lite_mindrt.h b/mindspore/lite/src/lite_mindrt.h index 422aaaeb95..6bc3c84f65 100644 --- a/mindspore/lite/src/lite_mindrt.h +++ b/mindspore/lite/src/lite_mindrt.h @@ -27,48 +27,81 @@ #include "async/future.h" #include "src/sub_graph_kernel.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { typedef enum { GRAPH, OP_BY_OP } MindRTMode; +const constexpr int kSwitchInputsSize = 3; +const constexpr int kSwitchCondInputIndex = 0; +const constexpr int kSwitchTruePartialInputIndex = 1; +const constexpr int kSwitchFalsePartialInputIndex = 2; class LiteOpActor : public OpActor { public: explicit LiteOpActor(kernel::LiteKernel *kernel) : OpActor(kernel->name()), kernel_(kernel) {} - virtual ~LiteOpActor() = default; - virtual void RunOpData(OpDataPtr inputs, OpContext *context = nullptr) { + ~LiteOpActor() override = default; + void RunOpData(OpDataPtr inputs, OpContext *context = nullptr) override { auto op_uuid = context->sequential_num_; input_op_datas_[op_uuid].push_back(inputs); + inputs_data_.push_back(inputs->data_); if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) { return; } + CpuBindMode cpu_bind_mode = kernel_->context()->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_; BindThreads(static_cast(kernel_->context())->thread_pool_, true, cpu_bind_mode); - auto ret = RunKernel(*(reinterpret_cast(context->kernel_call_back_before_)), - *(reinterpret_cast(context->kernel_call_back_after_))); + int ret = CheckInputData(); + if (ret != RET_OK) { + input_op_datas_.erase(op_uuid); + context->SetFailed(ret); + return; + } + + ret = SetInputData(); + if (ret != RET_OK) { + input_op_datas_.erase(op_uuid); + context->SetFailed(ret); + return; + } + for (auto &arrow : output_op_arrows_) { + kernel_->out_tensors().at(arrow->from_output_index_)->IncRefCount(); + } + + ret = 
RunKernel(*(reinterpret_cast(context->kernel_call_back_before_)), + *(reinterpret_cast(context->kernel_call_back_after_))); if (ret != RET_OK) { input_op_datas_.erase(op_uuid); context->SetFailed(ret); return; } input_op_datas_.erase(op_uuid); + inputs_data_.clear(); AsyncOutput(context); BindThreads(static_cast(kernel_->context())->thread_pool_, false, cpu_bind_mode); SetOutputData(context); + + for (auto &input_data : inputs_data_) { + input_data->DecRefCount(); + } } - void Init() { + int CastTensorData(Tensor *dst, Tensor *src); + void Init() override { auto ret = CompileArrow(); if (ret != RET_OK) { MS_LOG(ERROR) << "CompileArrow failed, name: " << kernel_->name(); // do not support return error } + + ret = PrepareOutputData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PrepareOutputData failed, name: " << kernel_->name(); + // do not support return error + } } - int CompileArrow(); + virtual int CompileArrow(); int RunKernel(const KernelCallBack &before, const KernelCallBack &after) { - int ret; - ret = kernel_->PreProcess(); + int ret = kernel_->PreProcess(); if (RET_OK != ret) { MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel_->name(); return ret; @@ -89,20 +122,124 @@ class LiteOpActor : public OpActor { public: void AddResultIndex(size_t index); + void SetPartialMap(const std::unordered_map &partial_map) { subgraph_index_to_actor = partial_map; } - private: + protected: + int CheckInputData(); + int SetInputData(); + void MoveInputData(Tensor *dst_tensor, Tensor *src_tensor); + void CopyInputData(Tensor *dst_tensor, Tensor *src_tensor); void SetOutputData(OpContext *context); void AsyncOutput(OpContext *context); + int CompileArrowThroughPartialCall(); + int CompileArrowThroughOutputKernels(); + virtual int PrepareOutputData(); kernel::LiteKernel *kernel_; - std::vector results_index_; + std::vector results_index_{}; + std::unordered_map subgraph_index_to_actor{}; + std::vector> outputs_data_{}; + std::vector inputs_data_{}; + + private: + 
kernel::LiteKernel *partial_node_ = nullptr; + kernel::LiteKernel *call_node_ = nullptr; +}; + +class LiteSwitchOpActor : public LiteOpActor { + public: + explicit LiteSwitchOpActor(kernel::LiteKernel *kernel) : LiteOpActor(kernel) {} + ~LiteSwitchOpActor() override = default; + void RunOpData(OpDataPtr inputs, OpContext *context = nullptr) override { + auto op_uuid = context->sequential_num_; + input_op_datas_[op_uuid].push_back(inputs); + inputs_data_.push_back(inputs->data_); + if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) { + return; + } + + int ret = CheckInputData(); + if (ret != RET_OK) { + input_op_datas_.erase(op_uuid); + context->SetFailed(ret); + return; + } + + ret = SetInputData(); + if (ret != RET_OK) { + input_op_datas_.erase(op_uuid); + context->SetFailed(ret); + return; + } + + ret = RunKernel(*(reinterpret_cast(context->kernel_call_back_before_)), + *(reinterpret_cast(context->kernel_call_back_after_))); + if (ret != RET_OK) { + input_op_datas_.erase(op_uuid); + context->SetFailed(ret); + return; + } + input_op_datas_.erase(op_uuid); + inputs_data_.clear(); + + bool *cond = reinterpret_cast(output_tensors_[0]->data()); + if (*cond) { + for (auto &arrow : true_branch_output_op_arrows_) { + kernel_->out_tensors().at(arrow->from_output_index_)->IncRefCount(); + } + AsyncTrueBranchOutput(context); + } else { + for (auto &arrow : false_branch_output_op_arrows_) { + kernel_->out_tensors().at(arrow->from_output_index_)->IncRefCount(); + } + AsyncFalseBranchOutput(context); + } + } + + void Init() override { + auto ret = CompileArrow(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CompileArrow failed, name: " << kernel_->name(); + // do not support return error + } + + ret = PrepareOutputData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PrepareOutputData failed, name: " << kernel_->name(); + // do not support return error + } + } + int CompileArrow() override; + + private: + void AsyncTrueBranchOutput(OpContext *context); + void 
AsyncFalseBranchOutput(OpContext *context); + + int GetSwitchAndCallNode(kernel::SubGraphKernel *subgraph_kernel); + void AppendOutputTensors(); + int CompileTrueBranchArrow(); + int CompileFalseBranchArrow(); + int CompileArrowThroughSwitchCall(); + int PrepareOutputData() override; + + std::vector true_branch_output_op_arrows_; + std::vector false_branch_output_op_arrows_; + + kernel::LiteKernel *bool_node_ = nullptr; + kernel::LiteKernel *true_partial_node_ = nullptr; + kernel::LiteKernel *false_partial_node_ = nullptr; + kernel::LiteKernel *switch_node_ = nullptr; + kernel::LiteKernel *call_node_ = nullptr; + std::vector output_tensors_{}; + + std::vector> true_branch_outputs_data_; + std::vector> false_branch_outputs_data_; }; int MindrtInit(); -void MindrtTerminate(std::vector>); +void MindrtTerminate(const std::vector> &); std::vector> CreateOpActor(const std::vector &kernels); -} // namespace lite -} // namespace mindspore +} // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_LITE_MINDRT_H_ diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 099092c023..ee66e4fd2e 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -357,6 +357,9 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { for (auto *tensor : this->inputs_) { tensor->set_category(Tensor::Category::GRAPH_INPUT); } + for (auto *tensor : this->outputs_) { + tensor->set_category(Tensor::Category::GRAPH_OUTPUT); + } } void LiteSession::FreePackOpWeight(const std::vector &kernels) { @@ -373,7 +376,7 @@ void LiteSession::FreePackOpWeight(const std::vector &kern auto inputs = kernel->in_tensors(); for (auto *tensor : inputs) { MS_ASSERT(tensor != nullptr); - if (!tensor->IsConst()) { + if (!tensor->IsConst() || tensor->init_ref_count() != 1) { continue; } tensor->FreeData(); @@ -455,7 +458,7 @@ int LiteSession::CompileGraph(Model *model) { return RET_ERROR; } - ret = executor_->Prepare(this->kernels_); + 
ret = executor_->Prepare(this->kernels_, this->inputs_, this->outputs_); if (ret != RET_OK) { MS_LOG(ERROR) << "Prepare executor failed: " << ret; is_running_.store(false); diff --git a/mindspore/lite/src/mindrt_executor.cc b/mindspore/lite/src/mindrt_executor.cc index 265c8c53d0..7ce04c2ba1 100644 --- a/mindspore/lite/src/mindrt_executor.cc +++ b/mindspore/lite/src/mindrt_executor.cc @@ -22,40 +22,62 @@ namespace mindspore::lite { -int MindrtExecutor::Prepare(const std::vector &kernels) { +void MindrtExecutor::PrepareInputData(const std::vector &kernels, + const std::vector &inputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t j = 0; j < kernels.size(); ++j) { + if (!kernels[j]->in_kernels().empty()) { + continue; + } + auto in_tensor_size = kernels[j]->in_tensors().size(); + for (size_t k = 0; k < in_tensor_size; ++k) { + if (inputs[i] != kernels[j]->in_tensors()[k]) { + continue; + } + auto data = std::make_shared>(op_actors_[j]->GetAID(), inputs[i], static_cast(k)); + input_data_.emplace_back(data); + } + } + } +} + +void MindrtExecutor::PrepareOutputData(const std::vector &kernels, + const std::vector &outputs) { + for (size_t i = 0; i < outputs.size(); ++i) { + for (size_t j = 0; j < kernels.size(); ++j) { + if (!kernels[j]->out_kernels().empty()) { + continue; + } + auto out_tensor_size = kernels[j]->out_tensors().size(); + for (size_t k = 0; k < out_tensor_size; ++k) { + if (outputs[i] != kernels[j]->out_tensors()[k]) { + continue; + } + auto data = std::make_shared>(op_actors_[j]->GetAID(), outputs[i], static_cast(k)); + op_actors_[j]->AddResultIndex(output_data_.size()); + output_data_.emplace_back(data); + } + } + } +} + +int MindrtExecutor::Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) { auto ret = MindrtInit(); if (ret != RET_OK) { MS_LOG(ERROR) << "MindrtInit failed"; return ret; } - auto kernelSize = kernels.size(); - opActors_ = CreateOpActor(kernels); - if (opActors_.size() != 
kernelSize) { + op_actors_ = CreateOpActor(kernels); + if (op_actors_.size() != kernels.size()) { MS_LOG(ERROR) << "CreateOpActor failed"; return RET_ERROR; } - for (size_t i = 0; i < kernelSize; i++) { - if (kernels[i]->in_kernels().size() == 0) { - auto inTensorSize = kernels[i]->in_tensors().size(); - for (size_t j = 0; j < inTensorSize; j++) { - auto data = - std::make_shared>(opActors_[i]->GetAID(), kernels[i]->in_tensors()[j], static_cast(j)); - inputData_.emplace_back(data); - } - } + PrepareInputData(kernels, inputs); - if (kernels[i]->out_kernels().size() == 0) { - auto outTensorSize = kernels[i]->out_tensors().size(); + PrepareOutputData(kernels, outputs); - for (size_t j = 0; j < outTensorSize; j++) { - auto data = - std::make_shared>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast(j)); - opActors_[i]->AddResultIndex(outputData_.size()); - outputData_.emplace_back(data); - } - } - } return RET_OK; } @@ -71,13 +93,11 @@ int MindrtExecutor::Run(const std::vector &in_tensors, const std::vect } } // clear ref_count - for (auto *kernel : kernels) { - for (auto *tensor : kernel->in_tensors()) { - tensor->set_ref_count(0); - } + for (auto *tensor : kernels.front()->in_tensors()) { + tensor->set_ref_count(0); } - return MindrtRun(inputData_, &outputData_, &before, &after); + return MindrtRun(input_data_, &output_data_, &before, &after); } } // namespace mindspore::lite diff --git a/mindspore/lite/src/mindrt_executor.h b/mindspore/lite/src/mindrt_executor.h index 004e69d385..2adff8de35 100644 --- a/mindspore/lite/src/mindrt_executor.h +++ b/mindspore/lite/src/mindrt_executor.h @@ -29,18 +29,21 @@ namespace mindspore::lite { class MindrtExecutor : public Executor { public: MindrtExecutor() = default; - virtual ~MindrtExecutor() { MindrtTerminate(opActors_); } + virtual ~MindrtExecutor() { MindrtTerminate(op_actors_); } - virtual int Prepare(const std::vector &kernels); + virtual int Prepare(const std::vector &kernels, const std::vector &inputs, + 
const std::vector &outputs); virtual int Run(const std::vector &in_tensors, const std::vector &out_tensors, const std::vector &kernels, mindspore::Allocator *allocator = nullptr, const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr); protected: - std::vector> opActors_; - std::vector> inputData_; - std::vector> outputData_; + void PrepareInputData(const std::vector &kernels, const std::vector &inputs); + void PrepareOutputData(const std::vector &kernels, const std::vector &outputs); + std::vector> op_actors_; + std::vector> input_data_; + std::vector> output_data_; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/runtime/agent/npu/npu_executor.cc b/mindspore/lite/src/runtime/agent/npu/npu_executor.cc index 1389eba21b..e36c35e818 100644 --- a/mindspore/lite/src/runtime/agent/npu/npu_executor.cc +++ b/mindspore/lite/src/runtime/agent/npu/npu_executor.cc @@ -32,7 +32,8 @@ NPUExecutor::~NPUExecutor() { npu_output_tensors_.clear(); } -int NPUExecutor::Prepare(const std::vector &kernels) { +int NPUExecutor::Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) { MS_ASSERT(npu_manager_ != nullptr); this->client_ = npu_manager_->GetClient(model_name_); if (this->client_ == nullptr) { diff --git a/mindspore/lite/src/runtime/agent/npu/npu_executor.h b/mindspore/lite/src/runtime/agent/npu/npu_executor.h index dc65001ffc..46eb6c0c7b 100644 --- a/mindspore/lite/src/runtime/agent/npu/npu_executor.h +++ b/mindspore/lite/src/runtime/agent/npu/npu_executor.h @@ -33,7 +33,8 @@ class NPUExecutor : public Executor { explicit NPUExecutor(const std::string &model_name, NPUManager *npu_manager = nullptr) : model_name_(model_name), npu_manager_(npu_manager) {} ~NPUExecutor() override; - int Prepare(const std::vector &kernels) override; + int Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) override; int Run(const std::vector &in_tensors, const std::vector 
&out_tensors, const std::vector &in_kernels, const std::vector &kernels, diff --git a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc index 6e43f3fb66..5e57c4ad31 100644 --- a/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc +++ b/mindspore/lite/src/runtime/agent/npu/subgraph_npu_kernel.cc @@ -226,7 +226,7 @@ int SubGraphNpuKernel::Init() { } int SubGraphNpuKernel::Prepare() { - if (executor_->Prepare(nodes_) != RET_OK) { + if (executor_->Prepare(nodes_, in_tensors_, out_tensors_) != RET_OK) { MS_LOG(ERROR) << "NPU executor prepare failed."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h index b136ba4c09..6a27fb0ef2 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_executor.h @@ -31,7 +31,10 @@ class OpenCLExecutor : public Executor { ~OpenCLExecutor() override = default; - int Prepare(const std::vector &kernels) override { return RET_OK; } + int Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) override { + return RET_OK; + } int Run(const std::vector &inputs, const std::vector &outputs, const std::vector &kernels, mindspore::Allocator *allocator = nullptr, diff --git a/mindspore/lite/src/runtime/parallel_executor.cc b/mindspore/lite/src/runtime/parallel_executor.cc index 302ce7c8a3..5b9f57552b 100644 --- a/mindspore/lite/src/runtime/parallel_executor.cc +++ b/mindspore/lite/src/runtime/parallel_executor.cc @@ -21,7 +21,8 @@ namespace mindspore::lite { ParallelExecutor::~ParallelExecutor() { DestroyThreadPool(thread_pool_); } -int ParallelExecutor::Prepare(const std::vector &kernels) { +int ParallelExecutor::Prepare(const std::vector &kernels, + const std::vector &inputs, const std::vector &outputs) { thread_pool_ = CreateLiteThreadPool(max_thread_num_, NO_BIND); if 
(thread_pool_ == nullptr) { MS_LOG(ERROR) << "Memory error: fail to new ThreadPool"; diff --git a/mindspore/lite/src/runtime/parallel_executor.h b/mindspore/lite/src/runtime/parallel_executor.h index 12b65ddfcd..6c679096b3 100644 --- a/mindspore/lite/src/runtime/parallel_executor.h +++ b/mindspore/lite/src/runtime/parallel_executor.h @@ -31,7 +31,8 @@ class ParallelExecutor : public Executor { ParallelExecutor() = default; ~ParallelExecutor() override; - int Prepare(const std::vector &kernels) override; + int Prepare(const std::vector &kernels, const std::vector &inputs, + const std::vector &outputs) override; int Run(const std::vector &in_tensors, const std::vector &out_tensors, const std::vector &kernels, mindspore::Allocator *allocator = nullptr,