| @@ -258,7 +258,8 @@ union PrimitiveType { | |||
| SmoothL1LossGrad, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SigmoidCrossEntropyWithLogitsGrad, | |||
| Reciprocal | |||
| Reciprocal, | |||
| Merge, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -1222,4 +1222,7 @@ table SigmoidCrossEntropyWithLogitsGrad { | |||
| } | |||
| table Reciprocal { | |||
| } | |||
| table Merge { | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/src/executor.h" | |||
| #include "nnacl/pack.h" | |||
| #include "src/executor.h" | |||
| #include <queue> | |||
| #include "include/errorcode.h" | |||
| namespace mindspore::lite { | |||
| @@ -26,7 +26,7 @@ int Executor::CheckInputs(const std::vector<Tensor *> &in_tensors) { | |||
| return RET_ERROR; | |||
| } | |||
| if (inTensor->data_c() == nullptr) { | |||
| MS_LOG(ERROR) << "Graph input tensor data is nullptr"; | |||
| MS_LOG(ERROR) << "Graph input tensor data is nullptr " << in_tensors; | |||
| return RET_ERROR; | |||
| } | |||
| auto shape = inTensor->shape(); | |||
| @@ -49,7 +49,52 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_ | |||
| MS_LOG(ERROR) << "CheckInputs failed"; | |||
| return ret; | |||
| } | |||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||
| std::queue<kernel::LiteKernel *> kernel_queue; | |||
| for (auto kernel : kernels) { | |||
| if (kernel->IsReady()) { | |||
| kernel_queue.push(kernel); | |||
| } | |||
| } | |||
| while (!kernel_queue.empty()) { | |||
| auto cur_kernel = kernel_queue.front(); | |||
| kernel_queue.pop(); | |||
| MS_ASSERT(nullptr != cur_kernel); | |||
| ret = cur_kernel->PreProcess(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "PreProcess kernel failed, name: " << cur_kernel->name(); | |||
| return ret; | |||
| } | |||
| ret = cur_kernel->Run(before, after); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "run kernel failed, name: " << cur_kernel->name(); | |||
| return ret; | |||
| } | |||
| ret = cur_kernel->PostProcess(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << cur_kernel->name(); | |||
| return ret; | |||
| } | |||
| for (auto &out_kernel : cur_kernel->out_kernels()) { | |||
| if (out_kernel->IsReady()) { | |||
| kernel_queue.push(out_kernel); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
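The rewritten `Executor::Run` above replaces a flat kernel loop with a dataflow traversal: kernels whose inputs are already available seed a queue, and each finished kernel pushes any consumer that just became ready. Below is a minimal standalone sketch of that ready-queue pattern; the `Node` type, its pending-input counter, and the cycle check are illustrative assumptions, not the real `LiteKernel` API.

```cpp
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical node type standing in for a kernel.
struct Node {
  std::vector<Node *> consumers;  // nodes fed by this node's outputs
  int pending_inputs = 0;         // producers that have not run yet
  bool IsReady() const { return pending_inputs == 0; }
  void Run() { /* compute */ }
};

// Execute nodes in dataflow order; returns false if some node never became
// ready (for example, when the graph contains a cycle).
inline bool RunDataflow(std::vector<Node *> &nodes) {
  std::queue<Node *> ready;
  for (Node *n : nodes) {  // seed with nodes whose inputs already exist
    if (n->IsReady()) ready.push(n);
  }
  std::size_t executed = 0;
  while (!ready.empty()) {
    Node *cur = ready.front();
    ready.pop();
    cur->Run();
    ++executed;
    for (Node *c : cur->consumers) {   // a consumer becomes ready once all
      if (--c->pending_inputs == 0) {  // of its producers have finished
        ready.push(c);
      }
    }
  }
  return executed == nodes.size();
}
```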
| int CpuExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, | |||
| std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, const KernelCallBack &before, | |||
| const KernelCallBack &after) { | |||
| MS_ASSERT(nullptr != allocator); | |||
| // Skip the input check when the graph starts with a Merge kernel: its inputs cannot be validated up front. | |||
| if (kernels.front()->Type() != schema::PrimitiveType_Merge) { | |||
| auto ret = this->CheckInputs(in_tensors); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "CheckInputs failed"; | |||
| return ret; | |||
| } | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them | |||
| out_tensor->set_ref_count(out_tensor->ref_count() + 1); | |||
| @@ -57,7 +102,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_ | |||
| #endif | |||
| for (auto *kernel : kernels) { | |||
| MS_ASSERT(nullptr != kernel); | |||
| ret = kernel->PreProcess(); | |||
| auto ret = kernel->PreProcess(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| @@ -37,5 +37,16 @@ class Executor { | |||
| protected: | |||
| static int CheckInputs(const std::vector<Tensor *> &in_tensors); | |||
| }; | |||
| class CpuExecutor : public Executor { | |||
| public: | |||
| CpuExecutor() = default; | |||
| virtual ~CpuExecutor() = default; | |||
| int Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, | |||
| std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr, | |||
| const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif | |||
| @@ -62,7 +62,7 @@ InnerContext::~InnerContext() { | |||
| } | |||
| } | |||
| int InnerContext::IsValid() { | |||
| int InnerContext::IsValid() const { | |||
| if (this->device_list_.empty()) { | |||
| MS_LOG(ERROR) << "Device list is empty."; | |||
| return RET_NOT_SUPPORT; | |||
| @@ -86,33 +86,33 @@ int InnerContext::IsValid() { | |||
| return RET_OK; | |||
| } | |||
| bool InnerContext::IsCpuFloat16Enabled() { | |||
| bool InnerContext::IsCpuFloat16Enabled() const { | |||
| if (!IsCpuEnabled()) { | |||
| return false; | |||
| } | |||
| return GetCpuInfo().enable_float16_; | |||
| } | |||
| bool InnerContext::IsGpuFloat16Enabled() { | |||
| bool InnerContext::IsGpuFloat16Enabled() const { | |||
| if (!IsGpuEnabled()) { | |||
| return false; | |||
| } | |||
| return GetGpuInfo().enable_float16_; | |||
| } | |||
| bool InnerContext::IsCpuEnabled() { | |||
| bool InnerContext::IsCpuEnabled() const { | |||
| return this->device_list_.end() != | |||
| std::find_if(this->device_list_.begin(), this->device_list_.end(), | |||
| [](const DeviceContext &device) { return device.device_type_ == DT_CPU; }); | |||
| } | |||
| bool InnerContext::IsGpuEnabled() { | |||
| bool InnerContext::IsGpuEnabled() const { | |||
| return this->device_list_.end() != | |||
| std::find_if(this->device_list_.begin(), this->device_list_.end(), | |||
| [](const DeviceContext &device) { return device.device_type_ == DT_GPU; }); | |||
| } | |||
| bool InnerContext::IsNpuEnabled() { | |||
| bool InnerContext::IsNpuEnabled() const { | |||
| #ifdef SUPPORT_NPU | |||
| return this->device_list_.end() != | |||
| std::find_if(this->device_list_.begin(), this->device_list_.end(), | |||
| @@ -123,7 +123,7 @@ bool InnerContext::IsNpuEnabled() { | |||
| #endif | |||
| } | |||
| CpuDeviceInfo InnerContext::GetCpuInfo() { | |||
| CpuDeviceInfo InnerContext::GetCpuInfo() const { | |||
| auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(), | |||
| [](const DeviceContext &device) { return device.device_type_ == DT_CPU; }); | |||
| if (iter == this->device_list_.end()) { | |||
| @@ -133,7 +133,7 @@ CpuDeviceInfo InnerContext::GetCpuInfo() { | |||
| } | |||
| } | |||
| GpuDeviceInfo InnerContext::GetGpuInfo() { | |||
| GpuDeviceInfo InnerContext::GetGpuInfo() const { | |||
| auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(), | |||
| [](const DeviceContext &device) { return device.device_type_ == DT_GPU; }); | |||
| if (iter == this->device_list_.end()) { | |||
| @@ -33,23 +33,23 @@ struct InnerContext : public Context { | |||
| int Init(); | |||
| bool IsCpuFloat16Enabled(); | |||
| bool IsCpuFloat16Enabled() const; | |||
| bool IsGpuFloat16Enabled(); | |||
| bool IsGpuFloat16Enabled() const; | |||
| bool IsCpuEnabled(); | |||
| bool IsCpuEnabled() const; | |||
| bool IsGpuEnabled(); | |||
| bool IsGpuEnabled() const; | |||
| bool IsNpuEnabled(); | |||
| bool IsNpuEnabled() const; | |||
| CpuDeviceInfo GetCpuInfo(); | |||
| CpuDeviceInfo GetCpuInfo() const; | |||
| GpuDeviceInfo GetGpuInfo(); | |||
| GpuDeviceInfo GetGpuInfo() const; | |||
| NpuDeviceInfo GetNpuInfo() const; | |||
| int IsValid(); | |||
| int IsValid() const; | |||
| virtual ~InnerContext(); | |||
| }; | |||
| @@ -41,9 +41,21 @@ void LiteKernel::FreeWorkspace() { | |||
| workspace_ = nullptr; | |||
| } | |||
| void LiteKernel::InitOutTensorRefCount() { | |||
| bool LiteKernel::IsReady() { | |||
| return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) { | |||
| return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1; | |||
| }); | |||
| } | |||
| void LiteKernel::InitOutTensorInitRefCount() { | |||
| for (auto *tensor : this->out_tensors_) { | |||
| tensor->set_ref_count(this->out_kernels_.size()); | |||
| int init_ref_count = 0; | |||
| for (auto *post_kernel : this->out_kernels_) { | |||
| init_ref_count += | |||
| std::count_if(post_kernel->in_tensors_.begin(), post_kernel->in_tensors_.end(), | |||
| [&tensor](const lite::Tensor *post_kernel_in_tensor) { return post_kernel_in_tensor == tensor; }); | |||
| } | |||
| tensor->set_init_ref_count(init_ref_count); | |||
| } | |||
| } | |||
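`InitOutTensorInitRefCount` sets each output tensor's initial reference count to the number of consumer input slots that alias it, counting duplicates when one consumer reads the same tensor twice. A hypothetical illustration of just that counting rule, with plain `const int *` standing in for `lite::Tensor *`:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Initial ref count = number of consumer input slots aliasing the tensor,
// duplicates included.
int InitRefCount(const int *tensor, const std::vector<std::vector<const int *>> &consumer_inputs) {
  int refs = 0;
  for (const auto &inputs : consumer_inputs) {
    refs += static_cast<int>(std::count(inputs.begin(), inputs.end(), tensor));
  }
  return refs;
}

int main() {
  int t = 0;  // stand-in tensor
  // kernel A reads t once, kernel B reads t twice -> initial ref count 3
  assert(InitRefCount(&t, {{&t}, {&t, &t}}) == 3);
  return 0;
}
```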
| @@ -61,15 +73,20 @@ int LiteKernel::DecOutTensorRefCount() { | |||
| return 0; | |||
| } | |||
| int LiteKernel::FreeWorkTensor() const { | |||
| for (auto input_kernel : this->in_kernels()) { | |||
| MS_ASSERT(input_kernel != nullptr); | |||
| if (input_kernel->is_model_output()) { | |||
| int LiteKernel::FreeInWorkTensor() const { | |||
| for (auto &in_tensor : this->in_tensors_) { | |||
| MS_ASSERT(in_tensor != nullptr); | |||
| if (in_tensor->IsConst()) { | |||
| continue; | |||
| } | |||
| auto ret = input_kernel->DecOutTensorRefCount(); | |||
| if (0 != ret) { | |||
| MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed"; | |||
| MS_ASSERT(in_tensor->ref_count() > 0); | |||
| in_tensor->set_ref_count(in_tensor->ref_count() - 1); | |||
| if (in_tensor->ref_count() <= 0) { | |||
| auto ret = in_tensor->FreeData(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "Free tensor data failed"; | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
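`FreeInWorkTensor` decrements each non-const input's reference count and frees the tensor's data once the count reaches zero, that is, once every consumer in the graph has finished with it. A tiny sketch of that release rule, using a hypothetical `Buf` type rather than the real tensor class:

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical buffer with the same release rule as FreeInWorkTensor:
// each consumer decrements the count; data is freed on the last release.
struct Buf {
  void *data = nullptr;
  int ref_count = 0;
  void Release() {
    assert(ref_count > 0);
    if (--ref_count == 0) {
      free(data);  // last consumer frees the data
      data = nullptr;
    }
  }
};

int main() {
  Buf b{malloc(16), 2};  // two consumers
  b.Release();           // first consumer: data kept
  assert(b.data != nullptr);
  b.Release();           // second consumer: data freed
  assert(b.data == nullptr);
  return 0;
}
```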
| @@ -91,15 +108,12 @@ int LiteKernel::PreProcess() { | |||
| } | |||
| } | |||
| auto outputs = this->out_tensors(); | |||
| for (auto *output : outputs) { | |||
| for (auto *output : this->out_tensors()) { | |||
| MS_ASSERT(output != nullptr); | |||
| if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) { | |||
| MS_LOG(ERROR) << "The size of output tensor is too big"; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = output->MallocData(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MallocData failed"; | |||
| @@ -109,6 +123,28 @@ int LiteKernel::PreProcess() { | |||
| return RET_OK; | |||
| } | |||
| int LiteKernel::PostProcess() { | |||
| #ifdef SUPPORT_TRAIN | |||
| for (auto input_kernel : this->in_kernels()) { | |||
| MS_ASSERT(input_kernel != nullptr); | |||
| if (input_kernel->is_model_output()) { | |||
| continue; | |||
| } | |||
| auto ret = input_kernel->DecOutTensorRefCount(); | |||
| if (0 != ret) { | |||
| MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed"; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| #else | |||
| for (auto *output : this->out_tensors()) { | |||
| MS_ASSERT(output != nullptr); | |||
| output->ResetRefCount(); | |||
| } | |||
| return FreeInWorkTensor(); | |||
| #endif | |||
| } | |||
| int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) { | |||
| if (before != nullptr) { | |||
| if (!before(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_), | |||
| @@ -153,6 +189,28 @@ std::string LiteKernel::ToString() const { | |||
| return oss.str(); | |||
| } | |||
| void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) { | |||
| // clean io kernels | |||
| this->in_kernels_.clear(); | |||
| this->out_kernels_.clear(); | |||
| // find io kernels | |||
| for (auto *scope_kernel : scope_kernels) { | |||
| if (scope_kernel == this) { | |||
| continue; | |||
| } | |||
| for (auto *tensor : this->in_tensors_) { | |||
| if (lite::IsContain(scope_kernel->out_tensors(), tensor)) { | |||
| this->AddInKernel(scope_kernel); | |||
| } | |||
| } | |||
| for (auto *tensor : this->out_tensors_) { | |||
| if (lite::IsContain(scope_kernel->in_tensors(), tensor)) { | |||
| this->AddOutKernel(scope_kernel); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels( | |||
| const std::vector<kernel::LiteKernel *> &kernels) { | |||
| std::vector<kernel::LiteKernel *> input_kernels; | |||
| @@ -202,7 +260,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||
| if (outer_in_kernels.empty()) { | |||
| for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { | |||
| if (!in_kernel_in_tensor->IsConst()) { | |||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| if (!IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| input_tensors.push_back(in_kernel_in_tensor); | |||
| } | |||
| } | |||
| @@ -219,7 +277,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||
| auto outer_in_kernel_out_tensors_iter = | |||
| std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); | |||
| if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { | |||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| if (!IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| input_tensors.emplace_back(in_kernel_in_tensor); | |||
| } | |||
| } | |||
| @@ -237,7 +295,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||
| auto &out_kernel_out_tensors = output_kernel->out_tensors(); | |||
| if (outer_out_kernels.empty()) { | |||
| for (auto out_kernel_out_tensor : out_kernel_out_tensors) { | |||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| if (!IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| output_tensors.push_back(out_kernel_out_tensor); | |||
| } | |||
| } | |||
| @@ -253,7 +311,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||
| auto outer_out_kernel_in_tensors_iter = | |||
| std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); | |||
| if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { | |||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| if (!IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| output_tensors.emplace_back(out_kernel_out_tensor); | |||
| } | |||
| } | |||
| @@ -299,33 +357,9 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *ke | |||
| return RET_OK; | |||
| } | |||
| void LiteKernelUtil::InitIOKernels(std::vector<kernel::LiteKernel *> &kernels) { | |||
| for (auto *kernel : kernels) { | |||
| // clean io kernels | |||
| kernel->set_in_kernels({}); | |||
| kernel->set_out_kernels({}); | |||
| // find io kernels | |||
| for (auto *search_kernel : kernels) { | |||
| if (search_kernel == kernel) { | |||
| continue; | |||
| } | |||
| for (auto *tensor : kernel->in_tensors()) { | |||
| if (lite::IsContain(search_kernel->out_tensors(), tensor)) { | |||
| kernel->AddInKernel(search_kernel); | |||
| } | |||
| } | |||
| for (auto *tensor : kernel->out_tensors()) { | |||
| if (lite::IsContain(search_kernel->in_tensors(), tensor)) { | |||
| kernel->AddOutKernel(search_kernel); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels) { | |||
| void LiteKernelUtil::InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels) { | |||
| for (auto *kernel : kernels) { | |||
| kernel->InitOutTensorRefCount(); | |||
| kernel->InitOutTensorInitRefCount(); | |||
| } | |||
| } | |||
| @@ -87,10 +87,12 @@ class LiteKernel { | |||
| virtual int Run(const KernelCallBack &before, const KernelCallBack &after); | |||
| // called after Run | |||
| virtual int PostProcess() { return FreeWorkTensor(); } | |||
| virtual int PostProcess(); | |||
| virtual int ReSize() { return mindspore::lite::RET_ERROR; } | |||
| virtual void FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels); | |||
| virtual int Init() { return mindspore::lite::RET_ERROR; } | |||
| std::string name() const { return this->name_; } | |||
| @@ -154,11 +156,13 @@ class LiteKernel { | |||
| const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; } | |||
| void InitOutTensorRefCount(); | |||
| virtual bool IsReady(); | |||
| virtual void InitOutTensorInitRefCount(); | |||
| int DecOutTensorRefCount(); | |||
| int FreeWorkTensor() const; | |||
| virtual int FreeInWorkTensor() const; | |||
| KernelKey desc() const { return desc_; } | |||
| @@ -203,8 +207,6 @@ typedef LiteKernel *(*KernelCreator)(const std::vector<lite::Tensor *> &inputs, | |||
| class LiteKernelUtil { | |||
| public: | |||
| static void InitIOKernels(std::vector<kernel::LiteKernel *> &kernels); | |||
| static std::vector<kernel::LiteKernel *> SubgraphInputKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| static std::vector<kernel::LiteKernel *> SubgraphOutputKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| @@ -215,7 +217,7 @@ class LiteKernelUtil { | |||
| static int TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels); | |||
| static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels); | |||
| static void InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels); | |||
| static int SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs); | |||
| }; | |||
| @@ -295,6 +295,21 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) { | |||
| } | |||
| } | |||
| void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) { | |||
| MS_ASSERT(model != nullptr); | |||
| auto graph_out_size = model->sub_graphs_.front()->output_indices_.size(); | |||
| for (size_t i = 0; i < graph_out_size; ++i) { | |||
| size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i]; | |||
| MS_ASSERT(graph_out_index < this->tensors_.size()); | |||
| auto *out_tensor = this->tensors_.at(graph_out_index); | |||
| if (out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "out_tensor is null!"; | |||
| return; | |||
| } | |||
| out_tensor->set_init_ref_count(out_tensor->init_ref_count() + 1); | |||
| } | |||
| } | |||
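`AdjustModelOutputTensorInitRefCount` bumps the initial reference count of every graph output tensor by one. The effect, sketched below with assumed numbers, is that after all in-graph consumers release the tensor its count stays at one, so `FreeInWorkTensor` never frees data the caller still needs.

```cpp
#include <cassert>

int main() {
  int consumers = 2;                   // assumed in-graph readers of the output
  int init_ref_count = consumers + 1;  // +1 applied because it is a graph output
  int ref_count = init_ref_count;
  for (int i = 0; i < consumers; ++i) {
    --ref_count;                       // each consumer releases the tensor once
  }
  assert(ref_count == 1);              // data still alive for the caller
  return 0;
}
```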
| void LiteSession::InitGraphInOutTensors(const lite::Model *model) { | |||
| InitGraphInputTensors(model); | |||
| InitGraphInputMSTensors(); | |||
| @@ -303,6 +318,7 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { | |||
| InitGraphOutputNodeMap(model); | |||
| InitGraphOutputTensorNames(model); | |||
| InitGraphOutputTensorMap(model); | |||
| AdjustModelOutputTensorInitRefCount(model); | |||
| } | |||
| int LiteSession::CompileGraph(Model *model) { | |||
| @@ -334,12 +350,9 @@ int LiteSession::CompileGraph(Model *model) { | |||
| is_running_.store(false); | |||
| return ret; | |||
| } | |||
| InitGraphInOutTensors(model); | |||
| // scheduler kernels | |||
| Scheduler scheduler(context_); | |||
| ret = scheduler.Schedule(model, &tensors_, &kernels_); | |||
| Scheduler scheduler(context_, model, tensors_); | |||
| ret = scheduler.Schedule(&kernels_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Schedule kernels failed: " << ret; | |||
| is_running_.store(false); | |||
| @@ -353,6 +366,7 @@ int LiteSession::CompileGraph(Model *model) { | |||
| } | |||
| } | |||
| #endif | |||
| InitGraphInOutTensors(model); | |||
| ret = executor_->Prepare(this->kernels_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare executor failed: " << ret; | |||
| @@ -563,6 +577,32 @@ void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) { | |||
| } | |||
| } | |||
| int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| bool infer_shape_interrupt = false; | |||
| for (auto kernel : kernels) { | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "input kernel is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernel->subgraph_type() == kernel::kNotSubGraph) { | |||
| MS_LOG(ERROR) << "All node in graph should be sub_graph"; | |||
| return RET_ERROR; | |||
| } | |||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||
| auto ret = sub_graph->ReSize(infer_shape_interrupt); | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape is interrupted"; | |||
| infer_shape_interrupt = true; | |||
| continue; | |||
| } | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
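`ReSizeKernels` tolerates `RET_INFER_INVALID`: an interrupted shape inference is recorded and forwarded to the remaining subgraphs rather than aborting the resize, while any other failure still returns immediately. A small sketch of that control pattern, using a hypothetical `Resizable` interface instead of the real `SubGraphKernel`:

```cpp
#include <vector>

enum Status { kOk, kInferInvalid, kError };

// Hypothetical stand-in for kernel::SubGraphKernel.
struct Resizable {
  virtual Status ReSize(bool infer_interrupted) = 0;
  virtual ~Resizable() = default;
};

// Mirror of the control flow in ReSizeKernels: an invalid shape inference is
// remembered and passed along, any other failure aborts.
inline Status ResizeAll(const std::vector<Resizable *> &graphs) {
  bool interrupted = false;
  for (Resizable *g : graphs) {
    Status s = g->ReSize(interrupted);
    if (s == kInferInvalid) {  // shapes unknown yet: remember and continue
      interrupted = true;
      continue;
    }
    if (s != kOk) {
      return kError;
    }
  }
  return kOk;
}
```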
| int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs, | |||
| const std::vector<std::vector<int>> &dims) { | |||
| bool expected = false; | |||
| @@ -581,11 +621,10 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs | |||
| return ret; | |||
| } | |||
| Scheduler scheduler(context_); | |||
| ret = scheduler.ReSizeKernels(kernels_); | |||
| ret = ReSizeKernels(kernels_); | |||
| if (ret != RET_OK) { | |||
| ResetInputsShape(old_dims); | |||
| auto resize_ret = scheduler.ReSizeKernels(kernels_); | |||
| auto resize_ret = ReSizeKernels(kernels_); | |||
| if (resize_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret; | |||
| } | |||
| @@ -92,10 +92,14 @@ class LiteSession : public session::LiteSession { | |||
| void InitGraphOutputTensorMap(const lite::Model *model); | |||
| void AdjustModelOutputTensorInitRefCount(const lite::Model *model); | |||
| int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims); | |||
| int PrepareKernels(); | |||
| static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| private: | |||
| void ResetInputsShape(const std::vector<std::vector<int>> &dims); | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/merge.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Merge::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Merge; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Merge) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::MergeT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| PopulaterQuantParam(prim, inputs); | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Merge(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Merge return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateMerge(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Merge>(primitive); } | |||
| Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); | |||
| #endif | |||
| int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(outputs_.size() == 1); | |||
| MS_ASSERT(inputs_.size() == 2); | |||
| outputs_[0]->set_data_type(inputs_[0]->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Merge : public PrimitiveC { | |||
| public: | |||
| Merge() = default; | |||
| ~Merge() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Merge, PrimitiveC); | |||
| explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/partial.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; } | |||
| int Partial::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Partial; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Partial) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::PartialT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Partial(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Partial return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreatePartial(*fbb, attr->subGraphIndex()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); } | |||
| PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Partial>(primitive); } | |||
| Registry PartialRegistry(schema::PrimitiveType_Partial, PartialCreator); | |||
| #endif | |||
| int Partial::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Partial : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Partial, PrimitiveC); | |||
| Partial() = default; | |||
| explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Partial() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| int GetSubGraphIndex() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| OpParameter *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||
| if (merge_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "malloc Merge parameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(merge_parameter, 0, sizeof(OpParameter)); | |||
| merge_parameter->type_ = primitive->Type(); | |||
| return reinterpret_cast<OpParameter *>(merge_parameter); | |||
| } | |||
| Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/partial.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| typedef struct PartialParameter { | |||
| OpParameter op_parameter_; | |||
| int sub_graph_index_; | |||
| } PartialParameter; | |||
| OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter))); | |||
| if (partial_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "malloc partial parameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(partial_parameter, 0, sizeof(PartialParameter)); | |||
| partial_parameter->op_parameter_.type_ = primitive->Type(); | |||
| auto param = reinterpret_cast<mindspore::lite::Partial *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| partial_parameter->sub_graph_index_ = param->GetSubGraphIndex(); | |||
| return reinterpret_cast<OpParameter *>(partial_parameter); | |||
| } | |||
| Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/switch.h" | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| OpParameter *switch_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||
| if (switch_parameter == nullptr) { | |||
| MS_LOG(ERROR) << "malloc SwitchParameter failed."; | |||
| return nullptr; | |||
| } | |||
| memset(switch_parameter, 0, sizeof(OpParameter)); | |||
| switch_parameter->type_ = primitive->Type(); | |||
| return reinterpret_cast<OpParameter *>(switch_parameter); | |||
| } | |||
| Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -155,6 +155,9 @@ | |||
| #include "src/ops/tensorlistsetitem.h" | |||
| #include "src/ops/tensorlistreserve.h" | |||
| #include "src/ops/tensorliststack.h" | |||
| #include "src/ops/merge.h" | |||
| #include "src/ops/switch.h" | |||
| #include "src/ops/partial.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -925,7 +928,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) TensorListReserve(primitive); | |||
| case schema::PrimitiveType_TensorListStack: | |||
| return new (std::nothrow) TensorListStack(primitive); | |||
| case schema::PrimitiveType_Switch: | |||
| return new (std::nothrow) Switch(primitive); | |||
| case schema::PrimitiveType_Merge: | |||
| return new (std::nothrow) Merge(primitive); | |||
| case schema::PrimitiveType_Partial: | |||
| return new (std::nothrow) Partial(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/switch.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Switch::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Switch; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Switch) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SwitchT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Switch(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Switch return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateSwitch(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Switch>(primitive); } | |||
| Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator); | |||
| #endif | |||
| int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Switch : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Switch, PrimitiveC); | |||
| Switch() = default; | |||
| explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Switch() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ | |||
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/base/merge.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Merge; | |||
| namespace mindspore::kernel { | |||
| // If one of Merge's inputs is a const tensor, Merge is always considered ready, which can cause errors. | |||
| bool MergeCPUKernel::IsReady() { | |||
| MS_ASSERT(in_tensors().size() == 2); | |||
| return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) { | |||
| return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1; | |||
| }); | |||
| } | |||
| int MergeCPUKernel::Init() { return RET_OK; } | |||
| int MergeCPUKernel::ReSize() { return RET_ERROR; } | |||
| int MergeCPUKernel::Run() { | |||
| MS_ASSERT(in_tensors_.size() == 2); | |||
| MS_ASSERT(out_tensors_.size() == 1); | |||
| auto out_data = out_tensors_.front()->data_c(); | |||
| MS_ASSERT(out_data != nullptr); | |||
| for (size_t i = 0; i < in_tensors().size(); i++) { | |||
| if (in_tensors()[i]->data_c() != nullptr) { | |||
| auto in_data = in_tensors_[i]->data_c(); | |||
| MS_ASSERT(in_data != nullptr); | |||
| memcpy(out_data, in_data, in_tensors_[i]->Size()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
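`MergeCPUKernel::Run` forwards whichever input currently holds data to the single output, which is what makes Merge the join point of a conditional branch. A hypothetical, self-contained walk-through of that copy rule with raw float buffers:

```cpp
#include <cstdio>
#include <cstring>

int main() {
  float branch_a[2] = {1.0f, 2.0f};  // the taken branch produced data
  float *branch_b = nullptr;         // the untaken branch produced nothing
  float out[2] = {0.0f, 0.0f};
  const float *inputs[2] = {branch_a, branch_b};
  for (const float *in : inputs) {
    if (in != nullptr) {             // mirrors the data_c() != nullptr test
      std::memcpy(out, in, sizeof(out));
    }
  }
  std::printf("%g %g\n", out[0], out[1]);  // prints: 1 2
  return 0;
}
```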
| kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||
| const lite::InnerContext *ctx, const KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "parameter is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (desc.type != PrimitiveType_Merge) { | |||
| MS_LOG(ERROR) << "type in desc is not Merge"; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "ctx is nullptr"; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| typedef struct MergeParameter { | |||
| OpParameter op_parameter_; | |||
| } MergeParameter; | |||
| class MergeCPUKernel : public LiteKernel { | |||
| public: | |||
| MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_); | |||
| } | |||
| ~MergeCPUKernel() override {} | |||
| bool IsReady() override; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| MergeParameter *merge_param_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/base/switch.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Switch; | |||
| namespace mindspore::kernel { | |||
| int SwitchCPUKernel::PostProcess() { | |||
| auto bool_tensor = in_tensors_.front(); | |||
| MS_ASSERT(bool_tensor != nullptr); | |||
| MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); | |||
| MS_ASSERT(bool_tensor->shape().size() == 1); | |||
| MS_ASSERT(bool_tensor->shape().front() == 1); | |||
| auto *active = static_cast<bool *>(bool_tensor->data_c()); | |||
| if (active == nullptr) { | |||
| MS_LOG(ERROR) << "data of bool tensor is nullptr"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| size_t in_index = 1; | |||
| size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2); | |||
| while (in_index < in_tensors_.size()) { | |||
| in_index++; | |||
| auto out_tensor = out_tensors_.at(out_index++); | |||
| out_tensor->ResetRefCount(); | |||
| } | |||
| return FreeInWorkTensor(); | |||
| } | |||
| int SwitchCPUKernel::Init() { return RET_OK; } | |||
| int SwitchCPUKernel::ReSize() { return RET_ERROR; } | |||
| // inputs: one bool tensor followed by n data tensors | |||
| // outputs: n true-branch tensors followed by n false-branch tensors | |||
| int SwitchCPUKernel::Run() { | |||
| MS_ASSERT(in_tensors_.size() >= 2); | |||
| MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size()); | |||
| auto bool_tensor = in_tensors_.front(); | |||
| MS_ASSERT(bool_tensor != nullptr); | |||
| MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); | |||
| MS_ASSERT(bool_tensor->shape().size() == 1); | |||
| MS_ASSERT(bool_tensor->shape().front() == 1); | |||
| auto active = static_cast<bool *>(bool_tensor->data_c()); | |||
| if (active == nullptr) { | |||
| MS_LOG(ERROR) << "data of bool tensor is nullptr"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| size_t in_index = 1; | |||
| size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2); | |||
| while (in_index < in_tensors_.size()) { | |||
| auto in_tensor = in_tensors_.at(in_index++); | |||
| auto out_tensor = out_tensors_.at(out_index++); | |||
| MS_ASSERT(in_tensor != nullptr); | |||
| MS_ASSERT(out_tensor != nullptr); | |||
| auto input = reinterpret_cast<float *>(in_tensor->data_c()); | |||
| auto output = reinterpret_cast<float *>(out_tensor->data_c()); | |||
| MS_ASSERT(in_tensor->Size() == out_tensor->Size()); | |||
| if (input == nullptr || output == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor or output tensor have not been malloced"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| memcpy(output, input, in_tensor->Size()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
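`Run` (and `PostProcess` above it) maps data input i, with the inputs laid out as [bool, d1..dn], to output i-1 on the true branch or n + i - 1 on the false branch, since the first half of `out_tensors_` is the true branch. A hypothetical check of that index arithmetic, read off the code above rather than any documented API:

```cpp
#include <cassert>
#include <cstddef>

// Inputs: [bool, d1..dn]; outputs: n true-branch tensors then n false-branch
// tensors.
std::size_t SwitchOutIndex(bool active, std::size_t data_input_index, std::size_t num_data_inputs) {
  const std::size_t base = active ? 0 : num_data_inputs;  // out_tensors_.size() / 2
  return base + (data_input_index - 1);                   // data inputs start at index 1
}

int main() {
  assert(SwitchOutIndex(true, 1, 3) == 0);   // d1, condition true
  assert(SwitchOutIndex(false, 1, 3) == 3);  // d1, condition false
  assert(SwitchOutIndex(false, 3, 3) == 5);  // d3, condition false
  return 0;
}
```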
| kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||
| const lite::InnerContext *ctx, const KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "parameter is nullptr"; | |||
| return nullptr; | |||
| } | |||
| if (desc.type != PrimitiveType_Switch) { | |||
| MS_LOG(ERROR) << "type in desc is not Switch"; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "ctx is nullptr"; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; | |||
| free(parameter); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| typedef struct SwitchParameter { | |||
| OpParameter op_parameter_; | |||
| } SwitchParameter; | |||
| class SwitchCPUKernel : public LiteKernel { | |||
| public: | |||
| SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_); | |||
| } | |||
| ~SwitchCPUKernel() override = default; | |||
| int PostProcess() override; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| SwitchParameter *switch_param_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||
| @@ -71,6 +71,14 @@ int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| return RET_OK; | |||
| } | |||
| int OpenCLKernel::PostProcess() { | |||
| for (auto *output : this->out_tensors()) { | |||
| MS_ASSERT(output != nullptr); | |||
| output->ResetRefCount(); | |||
| } | |||
| return FreeInWorkTensor(); | |||
| } | |||
| std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() { | |||
| size_t ndim = global_size_.size(); | |||
| std::vector<BaseTuningParameter> tuning_params = {}; | |||
| @@ -164,6 +164,7 @@ class OpenCLKernel : public LiteKernel { | |||
| int Prepare() override { return RET_OK; } | |||
| int PreProcess() override { return RET_ERROR; } | |||
| int PostProcess() override; | |||
| int ReSize() override { return RET_ERROR; } | |||
| int Run() override { return RET_ERROR; } | |||
| @@ -36,7 +36,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||
| if (is_tune) { | |||
| opencl_runtime_ins->SetProfiling(true); | |||
| } | |||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||
| for (auto *kernel : kernels) { | |||
| MS_ASSERT(kernel); | |||
| CallBackParam callbackParam; | |||
| @@ -82,6 +81,11 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||
| MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| ret = kernel->PostProcess(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| if (profiling_tmp) { | |||
| MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() | |||
| << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; | |||
| @@ -92,13 +96,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||
| MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name(); | |||
| } | |||
| } | |||
| for (auto input_kernel : kernel->in_kernels()) { | |||
| MS_ASSERT(input_kernel); | |||
| ret = input_kernel->DecOutTensorRefCount(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; | |||
| } | |||
| } | |||
| } | |||
| opencl_runtime_ins->SetProfiling(profiling_tmp); | |||
| return ret; | |||
| @@ -40,9 +40,9 @@ static int RunKernel(void *data, int index) { | |||
| return 0; | |||
| } | |||
| ret = kernel->FreeWorkTensor(); | |||
| ret = kernel->FreeInWorkTensor(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << kernel->name(); | |||
| MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| return 0; | |||
| @@ -62,7 +62,7 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||
| kernel::LiteKernelUtil::InitTensorInitRefCount(kernels); | |||
| for (auto kernel : kernels) { | |||
| if (kernel->in_kernels().empty()) { | |||
| @@ -96,9 +96,9 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor | |||
| } | |||
| } | |||
| auto ret = completed->FreeWorkTensor(); | |||
| auto ret = completed->FreeInWorkTensor(); | |||
| if (RET_OK != ret) { | |||
| MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << completed->name(); | |||
| MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << completed->name(); | |||
| return ret; | |||
| } | |||
| } | |||
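Both call sites above move from `FreeWorkTensor` to `FreeInWorkTensor`, which releases only the kernel's own input work tensors instead of walking producer outputs. A minimal sketch of the expected semantics, assuming the ref-count API added to `Tensor` later in this patch plus hypothetical `IsConst()`/`FreeData()` helpers:

```cpp
// Sketch: each non-const input loses one consumer; once nobody still needs
// the tensor in this pass, its buffer can be reclaimed. IsConst()/FreeData()
// are assumed Tensor methods, used here only for illustration.
int LiteKernel::FreeInWorkTensor() const {
  for (auto *in_tensor : this->in_tensors()) {
    MS_ASSERT(in_tensor != nullptr);
    if (in_tensor->IsConst()) {
      continue;  // constants belong to the model and are never freed here
    }
    in_tensor->DecRefCount();
    if (in_tensor->ref_count() == 0) {
      in_tensor->FreeData();
    }
  }
  return RET_OK;
}
```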
| @@ -19,6 +19,7 @@ | |||
| #include <queue> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "src/ops/partial.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/graph_util.h" | |||
| #include "src/common/utils.h" | |||
| @@ -36,152 +37,255 @@ namespace mindspore::lite { | |||
| using kernel::KERNEL_ARCH::kCPU; | |||
| using kernel::KERNEL_ARCH::kGPU; | |||
| using kernel::KERNEL_ARCH::kNPU; | |||
| constexpr int kMainSubGraphIndex = 0; | |||
| int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, | |||
| std::vector<kernel::LiteKernel *> *kernels) { | |||
| int ret = InferShape(model, tensors); | |||
| int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) { | |||
| if (src_model_ == nullptr) { | |||
| MS_LOG(ERROR) << "Input model is nullptr"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (src_model_->sub_graphs_.empty()) { | |||
| MS_LOG(ERROR) << "Model should have a subgraph at least"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_); | |||
| bool infer_shape_interrupt = false; | |||
| auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "op infer shape failed."; | |||
| return ret; | |||
| } | |||
| ret = BuildKernels(model, tensors, kernels); | |||
| ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "init op to kernel failed."; | |||
| MS_LOG(ERROR) << "Schedule main subgraph to kernels failed."; | |||
| return ret; | |||
| } | |||
| kernel::LiteKernelUtil::InitIOKernels(*kernels); | |||
| ret = ConstructSubGraphs(kernels); | |||
| FindAllInoutKernels(*dst_kernels); | |||
| ret = ConstructSubGraphs(dst_kernels); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ConstructSubGraphs failed."; | |||
| return ret; | |||
| } | |||
| kernel::LiteKernelUtil::InitIOKernels(*kernels); | |||
| FindAllInoutKernels(*dst_kernels); | |||
| kernel::LiteKernelUtil::InitTensorInitRefCount(*dst_kernels); | |||
| MS_LOG(DEBUG) << "schedule kernels success."; | |||
| return RET_OK; | |||
| } | |||
| int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| bool infer_shape_interrupt = false; | |||
| for (auto kernel : kernels) { | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "input kernel is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (kernel->subgraph_type() == kernel::kNotSubGraph) { | |||
| MS_LOG(ERROR) << "All node in graph should be sub_graph"; | |||
| return RET_ERROR; | |||
| } | |||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||
| auto ret = sub_graph->ReSize(infer_shape_interrupt); | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape is interrupted"; | |||
| infer_shape_interrupt = true; | |||
| continue; | |||
| } | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; | |||
| return RET_ERROR; | |||
| void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs, | |||
| std::vector<Tensor *> *outputs) { | |||
| MS_ASSERT(inputs != nullptr); | |||
| MS_ASSERT(outputs != nullptr); | |||
| auto in_size = node.input_indices_.size(); | |||
| inputs->reserve(in_size); | |||
| for (size_t j = 0; j < in_size; ++j) { | |||
| inputs->emplace_back(src_tensors_.at(node.input_indices_[j])); | |||
| } | |||
| auto out_size = node.output_indices_.size(); | |||
| outputs->reserve(out_size); | |||
| for (size_t j = 0; j < out_size; ++j) { | |||
| outputs->emplace_back(src_tensors_.at(node.output_indices_[j])); | |||
| } | |||
| } | |||
| int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||
| auto primitive = node->primitive_; | |||
| MS_ASSERT(primitive != nullptr); | |||
| if (primitive->Type() == schema::PrimitiveType_Partial) { | |||
| return InferPartialShape(node, infer_shape_interrupt); | |||
| } | |||
| std::vector<Tensor *> inputs; | |||
| std::vector<Tensor *> outputs; | |||
| FindNodeInoutTensors(*node, &inputs, &outputs); | |||
| bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { | |||
| auto shape = tensor->shape(); | |||
| return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); | |||
| }); | |||
| if (!infer_valid) { | |||
| *infer_shape_interrupt = true; | |||
| } | |||
| primitive->set_infer_flag(!(*infer_shape_interrupt)); | |||
| auto ret = primitive->InferShape(inputs, outputs); | |||
| if (ret == RET_OK) { | |||
| for (auto &output : outputs) { | |||
| if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) { | |||
| MS_LOG(ERROR) << "The size of output tensor is too big"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| return ret; | |||
| } | |||
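The `std::all_of` check above is the crux of deferred shape inference: a node is only inferable once every input dimension is known. Pulled out as a standalone helper for readability (a sketch, not part of the patch):

```cpp
// True only when no input tensor still carries an unknown (-1) dimension;
// otherwise InferNodeShape flags inference as interrupted and defers the
// node's shape to runtime.
bool AllInputDimsKnown(const std::vector<Tensor *> &inputs) {
  return std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) {
    auto shape = tensor->shape();
    return std::all_of(shape.begin(), shape.end(),
                       [](const int dim) { return dim != -1; });
  });
}
```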
| int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tensors) { | |||
| MS_ASSERT(model != nullptr); | |||
| MS_ASSERT(tensors != nullptr); | |||
| bool infer_shape_interrupt = false; | |||
| uint32_t kernelCount = model->all_nodes_.size(); | |||
| for (uint32_t i = 0; i < kernelCount; ++i) { | |||
| auto node = model->all_nodes_[i]; | |||
| int Scheduler::InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { | |||
| MS_ASSERT(src_model_ != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||
| auto primitive = node->primitive_; | |||
| MS_ASSERT(primitive != nullptr); | |||
| if (primitive->Type() != schema::PrimitiveType_Partial) { | |||
| MS_LOG(ERROR) << "Node is not a partial"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto partial_primitive = reinterpret_cast<lite::Partial *>(node->primitive_); | |||
| return InferSubGraphShape(partial_primitive->GetSubGraphIndex(), infer_shape_interrupt); | |||
| } | |||
| int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt) { | |||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||
| MS_ASSERT(src_model_ != nullptr); | |||
| MS_ASSERT(!src_model_->sub_graphs_.empty()); | |||
| MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); | |||
| auto subgraph = src_model_->sub_graphs_.at(subgraph_index); | |||
| for (auto node_index : subgraph->node_indices_) { | |||
| auto node = src_model_->all_nodes_[node_index]; | |||
| MS_ASSERT(node != nullptr); | |||
| std::vector<Tensor *> inputs; | |||
| std::vector<Tensor *> outputs; | |||
| auto in_size = node->input_indices_.size(); | |||
| inputs.reserve(in_size); | |||
| for (size_t j = 0; j < in_size; ++j) { | |||
| inputs.emplace_back(tensors->at(node->input_indices_[j])); | |||
| } | |||
| auto out_size = node->output_indices_.size(); | |||
| outputs.reserve(out_size); | |||
| for (size_t j = 0; j < out_size; ++j) { | |||
| outputs.emplace_back(tensors->at(node->output_indices_[j])); | |||
| } | |||
| auto *primitive = node->primitive_; | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; | |||
| return RET_ERROR; | |||
| } | |||
| bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { | |||
| auto shape = tensor->shape(); | |||
| return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); | |||
| }); | |||
| if (!infer_valid) { | |||
| infer_shape_interrupt = true; | |||
| } | |||
| primitive->set_infer_flag(!infer_shape_interrupt); | |||
| auto ret = primitive->InferShape(inputs, outputs); | |||
| auto ret = InferNodeShape(node, infer_shape_interrupt); | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_ | |||
| MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())) | |||
| << "flag set to false."; | |||
| << ", set infer flag to false."; | |||
| primitive->set_infer_flag(false); | |||
| infer_shape_interrupt = true; | |||
| *infer_shape_interrupt = true; | |||
| } else if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | |||
| return RET_INFER_ERR; | |||
| } else { | |||
| for (auto &output : outputs) { | |||
| if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) { | |||
| MS_LOG(ERROR) << "The size of output tensor is too big"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int Scheduler::BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors, | |||
| std::vector<kernel::LiteKernel *> *kernels) { | |||
| MS_ASSERT(model != nullptr); | |||
| MS_ASSERT(tensors != nullptr); | |||
| uint32_t kernelCount = model->all_nodes_.size(); | |||
| auto graph_output_node_indexes = GetGraphOutputNodes(model); | |||
| for (uint32_t i = 0; i < kernelCount; ++i) { | |||
| auto node = model->all_nodes_[i]; | |||
| MS_ASSERT(node != nullptr); | |||
| std::vector<Tensor *> inputs; | |||
| std::vector<Tensor *> outputs; | |||
| auto in_size = node->input_indices_.size(); | |||
| inputs.reserve(in_size); | |||
| for (size_t j = 0; j < in_size; ++j) { | |||
| inputs.emplace_back(tensors->at(node->input_indices_[j])); | |||
| kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, | |||
| const mindspore::lite::PrimitiveC *primitive, | |||
| const Model::Node *node) { | |||
| MS_ASSERT(primitive != nullptr); | |||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||
| #if SUPPORT_GPU | |||
| if (context_->IsGpuEnabled()) { | |||
| kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; | |||
| return kernel; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " | |||
| << node->name_; | |||
| } | |||
| } | |||
| #endif | |||
| #if SUPPORT_NPU | |||
| if (context_->IsNpuEnabled()) { | |||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; | |||
| return kernel; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " | |||
| << node->name_; | |||
| } | |||
| auto out_size = node->output_indices_.size(); | |||
| outputs.reserve(out_size); | |||
| for (size_t j = 0; j < out_size; ++j) { | |||
| outputs.emplace_back(tensors->at(node->output_indices_[j])); | |||
| } | |||
| #endif | |||
| if (mindspore::lite::IsSupportFloat16() && | |||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| auto *kernel = | |||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | |||
| << node->name_; | |||
| return kernel; | |||
| } | |||
| } | |||
| if (data_type == kNumberTypeFloat16) { | |||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| } | |||
| return nullptr; | |||
| } | |||
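`FindBackendKernel` tries backends in a fixed order and silently falls back on each registry miss. Condensed, the selection logic is equivalent to the sketch below, where `TryGetKernel` is hypothetical shorthand for the `KernelRegistry::GetInstance()->GetKernel(...)` calls and the boolean flags stand for the context queries:

```cpp
// Hypothetical condensation of the fallback chain: GPU, then NPU, then an
// fp16 CPU kernel when eligible, and finally a plain fp32 CPU kernel.
kernel::LiteKernel *kernel = nullptr;
if (gpu_enabled) kernel = TryGetKernel(kGPU, data_type);
if (kernel == nullptr && npu_enabled) kernel = TryGetKernel(kNPU, data_type);
if (kernel == nullptr && fp16_eligible) kernel = TryGetKernel(kCPU, kNumberTypeFloat16);
if (kernel == nullptr) kernel = TryGetKernel(kCPU, kNumberTypeFloat32);  // last resort
return kernel;  // may still be nullptr if no backend matched
```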
| kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *src_node) { | |||
| MS_ASSERT(src_model_ != nullptr); | |||
| MS_ASSERT(src_node != nullptr); | |||
| auto *primitive = src_node->primitive_; | |||
| MS_ASSERT(primitive != nullptr); | |||
| if (primitive->Type() != schema::PrimitiveType_Partial) { | |||
| return nullptr; | |||
| } | |||
| auto partial_primitive = reinterpret_cast<lite::Partial *>(primitive); | |||
| auto sub_graph_index = partial_primitive->GetSubGraphIndex(); | |||
| std::vector<kernel::LiteKernel *> sub_kernels; | |||
| auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; | |||
| return nullptr; | |||
| } | |||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front()); | |||
| // for kernel::LiteKernelUtil::SubgraphInputTensors in CreateSubGraphKernel | |||
| FindAllInoutKernels(sub_kernels); | |||
| auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); | |||
| if (subgraph == nullptr) { | |||
| MS_LOG(ERROR) << "CreateSubGraphKernel failed, name: " << src_node->name_; | |||
| return nullptr; | |||
| } | |||
| subgraph->set_name("subgraph_" + src_node->name_); | |||
| return subgraph; | |||
| } | |||
| kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node) { | |||
| auto *primitive = src_node->primitive_; | |||
| MS_ASSERT(primitive != nullptr); | |||
| std::vector<Tensor *> inputs; | |||
| std::vector<Tensor *> outputs; | |||
| FindNodeInoutTensors(*src_node, &inputs, &outputs); | |||
| auto *kernel = this->FindBackendKernel(inputs, outputs, primitive, src_node); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_ | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | |||
| return nullptr; | |||
| } | |||
| SetKernelTensorDataType(kernel); | |||
| kernel->set_name(src_node->name_); | |||
| return kernel; | |||
| } | |||
| int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels) { | |||
| MS_ASSERT(src_model_ != nullptr); | |||
| MS_ASSERT(!src_model_->sub_graphs_.empty()); | |||
| MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); | |||
| MS_ASSERT(dst_kernels != nullptr); | |||
| MS_ASSERT(dst_kernels->empty()); | |||
| auto subgraph = src_model_->sub_graphs_.at(subgraph_index); | |||
| for (auto node_index : subgraph->node_indices_) { | |||
| auto node = src_model_->all_nodes_[node_index]; | |||
| MS_ASSERT(node != nullptr); | |||
| auto *primitive = node->primitive_; | |||
| MS_ASSERT(primitive != nullptr); | |||
| auto *kernel = this->ScheduleNode(inputs, outputs, primitive, node); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (primitive->Type() == schema::PrimitiveType_Partial) { // sub_graph | |||
| kernel = SchedulePartialToKernel(node); | |||
| } else { // kernel | |||
| kernel = ScheduleNodeToKernel(node); | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << node->name_ << ", type: " | |||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << node->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | |||
| return RET_ERROR; | |||
| } | |||
| SetKernelTensorDataType(kernel); | |||
| kernel->set_name(node->name_); | |||
| kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i))); | |||
| kernels->emplace_back(kernel); | |||
| kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index))); | |||
| dst_kernels->emplace_back(kernel); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -190,6 +294,11 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||
| MS_ASSERT(head_kernel != nullptr); | |||
| MS_ASSERT(sinked_kernel_map != nullptr); | |||
| std::vector<kernel::LiteKernel *> sub_kernels; | |||
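| // Switch and Merge act as subgraph boundaries: each one becomes its own | |||
| // single-kernel subgraph instead of being fused into a larger subgraph. | |||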
| if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) { | |||
| (*sinked_kernel_map)[head_kernel] = true; | |||
| sub_kernels.emplace_back(head_kernel); | |||
| return sub_kernels; | |||
| } | |||
| std::queue<kernel::LiteKernel *> kernel_queue; | |||
| kernel_queue.emplace(head_kernel); | |||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | |||
| @@ -200,6 +309,10 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||
| sub_kernels.emplace_back(cur_kernel); | |||
| auto post_kernels = cur_kernel->out_kernels(); | |||
| for (auto post_kernel : post_kernels) { | |||
| if (post_kernel->subgraph_type() != kernel::kNotSubGraph || post_kernel->Type() == schema::PrimitiveType_Merge || | |||
| post_kernel->Type() == schema::PrimitiveType_Switch) { | |||
| continue; | |||
| } | |||
| if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { | |||
| auto post_kernel_inputs = post_kernel->in_kernels(); | |||
| if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), | |||
| @@ -215,28 +328,41 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||
| int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { | |||
| auto old_kernels = *kernels; | |||
| kernels->clear(); | |||
| std::map<const kernel::LiteKernel *, bool> is_kernel_sinked; | |||
| std::map<const kernel::LiteKernel *, bool> is_kernel_finish; | |||
| for (auto kernel : old_kernels) { | |||
| is_kernel_sinked[kernel] = false; | |||
| is_kernel_finish[kernel] = false; | |||
| } | |||
| while (true) { | |||
| auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) { | |||
| auto kernel_inputs = kernel->in_kernels(); | |||
| return !is_kernel_sinked[kernel] && | |||
| std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | |||
| [&](kernel::LiteKernel *kernel) { return is_kernel_sinked[kernel]; }); | |||
| if (is_kernel_finish[kernel]) { | |||
| return false; | |||
| } | |||
| // when Merge is removed, this special case can be removed as well | |||
| if (kernel->Type() == schema::PrimitiveType_Merge) { | |||
| MS_ASSERT(kernel->in_kernels().size() == 2); | |||
| return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]); | |||
| } else { | |||
| return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | |||
| [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; }); | |||
| } | |||
| }); | |||
| if (head_kernel_iter == old_kernels.end()) { | |||
| break; | |||
| } | |||
| auto head_kernel = *head_kernel_iter; | |||
| if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { | |||
| is_kernel_finish[head_kernel] = true; | |||
| kernels->emplace_back(head_kernel); | |||
| continue; | |||
| } | |||
| if (head_kernel->desc().arch == mindspore::kernel::kAPU) { | |||
| MS_LOG(ERROR) << "Not support APU now"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | |||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_sinked); | |||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish); | |||
| auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); | |||
| if (subgraph == nullptr) { | |||
| MS_LOG(ERROR) << "Create SubGraphKernel failed"; | |||
| @@ -296,60 +422,6 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel | |||
| return nullptr; | |||
| } | |||
| kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, | |||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *node) { | |||
| MS_ASSERT(primitive != nullptr); | |||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||
| #if SUPPORT_NPU | |||
| if (context_->IsNpuEnabled()) { | |||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; | |||
| return kernel; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " | |||
| << node->name_; | |||
| } | |||
| } | |||
| #endif | |||
| #if SUPPORT_GPU | |||
| if (context_->IsGpuEnabled()) { | |||
| kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; | |||
| return kernel; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " | |||
| << node->name_; | |||
| } | |||
| } | |||
| #endif | |||
| if (mindspore::lite::IsSupportFloat16() && | |||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| auto *kernel = | |||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | |||
| << node->name_; | |||
| return kernel; | |||
| } | |||
| } | |||
| if (data_type == kNumberTypeFloat16) { | |||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| if (kernel != nullptr) { | |||
| return kernel; | |||
| } | |||
| return nullptr; | |||
| } | |||
| TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) { | |||
| for (const auto &tensor : in_tensors) { | |||
| auto dtype = tensor->data_type(); | |||
| @@ -411,4 +483,11 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel * | |||
| } | |||
| return kernel::kNotSubGraph; | |||
| } | |||
| void Scheduler::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| for (auto *kernel : kernels) { | |||
| MS_ASSERT(kernel != nullptr); | |||
| kernel->FindInoutKernels(kernels); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ | |||
| #define MINDSPORE_LITE_SRC_SCHEDULER_H_ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "src/sub_graph_kernel.h" | |||
| @@ -27,30 +28,47 @@ | |||
| namespace mindspore::lite { | |||
| class Scheduler { | |||
| public: | |||
| explicit Scheduler(const InnerContext *ctx) { context_ = const_cast<InnerContext *>(ctx); } | |||
| Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> src_tensors) | |||
| : context_(ctx), src_model_(src_model), src_tensors_(std::move(src_tensors)) {} | |||
| ~Scheduler() = default; | |||
| int Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, std::vector<kernel::LiteKernel *> *kernels); | |||
| static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| protected: | |||
| kernel::LiteKernel *ScheduleNode(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode); | |||
| int BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors, | |||
| std::vector<kernel::LiteKernel *> *kernels); | |||
| static int InferShape(const lite::Model *model, std::vector<Tensor *> *tensors); | |||
| int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels); | |||
| private: | |||
| void FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs, | |||
| std::vector<Tensor *> *outputs); | |||
| // infer shape for a partial node | |||
| int InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt); | |||
| // infer shape for a node | |||
| int InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt); | |||
| // infer shape for a subgraph | |||
| int InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt); | |||
| // schedule a node to kernel according to context and kernels registered | |||
| kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||
| const std::vector<Tensor *> &out_tensors, | |||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *node); | |||
| // schedule a partial node to a subgraph_kernel | |||
| kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); | |||
| // schedule a node to a kernel | |||
| kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node); | |||
| // schedule a Model::SubGraph into a vector of kernel and subgraph_kernel | |||
| int ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels); | |||
| // find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph | |||
| static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||
| // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel> | |||
| int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels); | |||
| // create subgraph_kernel from a vector of kernel | |||
| kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, | |||
| kernel::SubGraphType type); | |||
| std::vector<kernel::LiteKernel *> FindAllSubGraphKernels( | |||
| kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map); | |||
| // other methods | |||
| static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors); | |||
| static void SetKernelTensorDataType(kernel::LiteKernel *kernel); | |||
| @@ -58,7 +76,10 @@ class Scheduler { | |||
| static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel); | |||
| protected: | |||
| InnerContext *context_ = nullptr; | |||
| const InnerContext *context_ = nullptr; | |||
| Model *src_model_ = nullptr; | |||
| std::vector<Tensor *> src_tensors_; | |||
| std::vector<size_t> graph_output_node_indexes_; | |||
| }; | |||
| } // namespace mindspore::lite | |||
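With this change the scheduler receives the model and tensors once, at construction, instead of per call. A hedged usage sketch of the new entry point (setup of `context`, `model`, and `src_tensors` is assumed):

```cpp
// Lower the model's main subgraph (plus any Partial-referenced subgraphs)
// into a vector of kernels and subgraph kernels.
Scheduler scheduler(context, model, src_tensors);
std::vector<kernel::LiteKernel *> kernels;
auto ret = scheduler.Schedule(&kernels);
if (ret != RET_OK) {
  MS_LOG(ERROR) << "Schedule failed, ret: " << ret;
}
```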
| @@ -149,6 +149,12 @@ int SubGraphKernel::ReSize(bool is_interrupt) { | |||
| return RET_OK; | |||
| } | |||
| void SubGraphKernel::InitOutTensorInitRefCount() { | |||
| for (auto *node : nodes_) { | |||
| node->InitOutTensorInitRefCount(); | |||
| } | |||
| } | |||
| int CpuSubGraph::Prepare() { | |||
| auto ret = SubGraphKernel::Prepare(); | |||
| if (ret != RET_OK) { | |||
| @@ -84,6 +84,8 @@ class SubGraphKernel : public LiteKernel { | |||
| int ReSize(bool is_interrupt); | |||
| void InitOutTensorInitRefCount() override; | |||
| std::string ToString() const override; | |||
| std::vector<LiteKernel *> nodes() { return this->nodes_; } | |||
| @@ -104,11 +106,10 @@ class CpuSubGraph : public SubGraphKernel { | |||
| const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) | |||
| : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { | |||
| subgraph_type_ = kCpuFP32SubGraph; | |||
| this->executor_ = new (std::nothrow) mindspore::lite::Executor; | |||
| this->executor_ = new (std::nothrow) mindspore::lite::CpuExecutor; | |||
| } | |||
| ~CpuSubGraph() override { delete this->executor_; } | |||
| int Prepare() override; | |||
| int Init() override { return SubGraphKernel::Init(); } | |||
| int PreProcess() override { return SubGraphKernel::PreProcess(); } | |||
| @@ -110,8 +110,14 @@ class Tensor : public mindspore::tensor::MSTensor { | |||
| size_t ref_count() const { return this->ref_count_; } | |||
| size_t init_ref_count() const { return this->init_ref_count_; } | |||
| void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; } | |||
| void set_init_ref_count(size_t ref_count) { this->init_ref_count_ = ref_count; } | |||
| void ResetRefCount() { this->ref_count_ = this->init_ref_count_; } | |||
| void DecRefCount() { this->ref_count_--; } | |||
| std::string ToString() const; | |||
| @@ -156,6 +162,8 @@ class Tensor : public mindspore::tensor::MSTensor { | |||
| schema::Format format_; | |||
| Category category_; | |||
| size_t ref_count_ = 0; | |||
| size_t init_ref_count_ = 0; | |||
| size_t ready_count_ = 0; | |||
| std::vector<QuantArg> quant_params_; | |||
| std::vector<float> quant_clusters_; | |||
| mindspore::lite::Allocator *allocator_ = nullptr; | |||
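The new `init_ref_count_` field lets a tensor's consumer count be restored between passes, which is what makes repeated kernel execution (the Merge re-entry in a loop body) possible. A sketch of the intended lifecycle; `num_consumers` and the call sites stand in for scheduler/executor code:

```cpp
// Lifecycle of the ref-count API added above, condensed into one function
// purely for illustration.
void RefCountLifecycle(mindspore::lite::Tensor *tensor, size_t num_consumers) {
  tensor->set_init_ref_count(num_consumers);  // once, at schedule time
  tensor->ResetRefCount();  // producer PostProcess: ref_count_ = init_ref_count_
  tensor->DecRefCount();    // one consumer has finished with the tensor
  if (tensor->ref_count() == 0) {
    // no consumers left in this pass; the buffer may be reclaimed
  }
}
```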
| @@ -128,7 +128,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto run_kernel = (train_mode_) ? train_kernels_ : inference_kernels_; | |||
| lite::Executor executor; | |||
| lite::CpuExecutor executor; | |||
| if (before == nullptr && after == nullptr) { | |||
| return executor.Run(this->inputs_, this->outputs_, run_kernel, this->context_->allocator.get()); | |||
| } else { | |||
| @@ -261,6 +261,8 @@ if (ENABLE_CONVERTER) | |||
| set(TEST_SRC | |||
| ${TEST_SRC} | |||
| ${TEST_DIR}/st/converter_test.cc | |||
| ${TEST_DIR}/st/control_flow_test.cc | |||
| ${TEST_DIR}/st/sub_graph_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc | |||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc | |||
| @@ -300,7 +302,7 @@ endif () | |||
| add_executable(lite-test ${TEST_SRC}) | |||
| add_dependencies(lite-test fbs_src) | |||
| target_link_libraries(lite-test dl ${GTEST_LIBRARY}) | |||
| if (PLATFORM_ARM64) | |||
| target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | |||
| @@ -321,6 +323,7 @@ if (SUPPORT_NPU) | |||
| target_link_libraries(lite-test npu_kernel_mid) | |||
| endif () | |||
| if (ENABLE_CONVERTER) | |||
| add_dependencies(lite-test fbs_inner_src) | |||
| target_link_libraries(lite-test | |||
| anf_importer_mid | |||
| anf_exporter_mid | |||
| @@ -1,3 +1,3 @@ | |||
| mobilenet.tflite 0.5 | |||
| transformer_20200831_encoder_fp32.tflite 68 | |||
| transformer_20200831_encoder_fp32.tflite 69 | |||
| transformer_20200831_decoder_fp32.tflite 35 | |||
| @@ -0,0 +1,459 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "mindspore/lite/include/model.h" | |||
| #include "common/common_test.h" | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/lite_session.h" | |||
| #include "include/version.h" | |||
| namespace mindspore { | |||
| class ControlFlowTest : public mindspore::CommonTest { | |||
| public: | |||
| ControlFlowTest() {} | |||
| }; | |||
| TEST_F(ControlFlowTest, TestMergeWhileModel) { | |||
| // make graph | |||
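| // Topology under test (a while loop lowered to control-flow ops): | |||
| //   before_Add_1 -> before_Add_2 -> Merge -> Partial(cond) -> Switch | |||
| //   Switch(true)  -> Partial(body) -> back into Merge | |||
| //   Switch(false) -> Add-after -> graph output | |||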
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||
| MS_LOG(DEBUG) << "make subgraph"; | |||
| meta_graph->name = "graph"; | |||
| meta_graph->version = lite::Version(); | |||
| meta_graph->inputIndex = {0}; | |||
| meta_graph->outputIndex = {9}; | |||
| // subgraph 0 : main graph | |||
| auto sub_graph_0 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_0->name = "main_graph"; | |||
| // subgraph 1 : cond graph | |||
| auto sub_graph_1 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_1->name = "cond_graph"; | |||
| // subgraph 2: body graph | |||
| auto sub_graph_2 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_2->name = "body_graph"; | |||
| MS_LOG(DEBUG) << "make subgraph"; | |||
| // subgraph 0: node 0 before-add-1 | |||
| auto sub_graph_0_node_0 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_0->inputIndex = {0, 1}; | |||
| sub_graph_0_node_0->outputIndex = {2}; | |||
| sub_graph_0_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_0_node_0 = new schema::AddT; | |||
| primitive_sub_graph_0_node_0->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| sub_graph_0_node_0->primitive->value.value = primitive_sub_graph_0_node_0; | |||
| sub_graph_0_node_0->name = "before_Add_1"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_0)); | |||
| sub_graph_0->nodeIndices.push_back(0); | |||
| MS_LOG(DEBUG) << "node 0"; | |||
| // subgraph 0: node 1 before-add-2 | |||
| auto sub_graph_0_node_1 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_1->inputIndex = {2, 3}; | |||
| sub_graph_0_node_1->outputIndex = {4}; | |||
| sub_graph_0_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_1->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_0_node_1 = new schema::AddT; | |||
| primitive_sub_graph_0_node_1->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| sub_graph_0_node_1->primitive->value.value = primitive_sub_graph_0_node_1; | |||
| sub_graph_0_node_1->name = "before_Add_2"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_1)); | |||
| sub_graph_0->nodeIndices.push_back(1); | |||
| MS_LOG(DEBUG) << "node 1"; | |||
| // subgraph 0: node 2 merge | |||
| auto sub_graph_0_node_2 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_2->inputIndex = {4, 17}; | |||
| sub_graph_0_node_2->outputIndex = {16}; | |||
| sub_graph_0_node_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_2->primitive->value.type = schema::PrimitiveType_Merge; | |||
| auto primitive_sub_graph_0_node_2 = new schema::MergeT; | |||
| sub_graph_0_node_2->primitive->value.value = primitive_sub_graph_0_node_2; | |||
| sub_graph_0_node_2->name = "merge"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_2)); | |||
| sub_graph_0->nodeIndices.push_back(2); | |||
| MS_LOG(DEBUG) << "node 2"; | |||
| // subgraph 0: node 3 partial cond subGraph | |||
| auto sub_graph_0_node_3 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_3->inputIndex = {16}; | |||
| sub_graph_0_node_3->outputIndex = {5}; // 5 : bool | |||
| sub_graph_0_node_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_3->primitive->value.type = schema::PrimitiveType_Partial; | |||
| auto primitive_sub_graph_0_node_3 = new schema::PartialT; | |||
| primitive_sub_graph_0_node_3->subGraphIndex = 1; | |||
| sub_graph_0_node_3->primitive->value.value = primitive_sub_graph_0_node_3; | |||
| sub_graph_0_node_3->name = "Partial_cond"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_3)); | |||
| sub_graph_0->nodeIndices.push_back(3); | |||
| MS_LOG(DEBUG) << "node 2"; | |||
| // subgraph 0: node 4 switch | |||
| auto sub_graph_0_node_4 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_4->inputIndex = {5, 16}; // 5 : bool; 16 : data | |||
| sub_graph_0_node_4->outputIndex = {6, 7}; | |||
| sub_graph_0_node_4->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_4->primitive->value.type = schema::PrimitiveType_Switch; | |||
| auto primitive_sub_graph_0_node_4 = new schema::SwitchT; | |||
| sub_graph_0_node_4->primitive->value.value = primitive_sub_graph_0_node_4; | |||
| sub_graph_0_node_4->name = "Switch"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_4)); | |||
| sub_graph_0->nodeIndices.push_back(4); | |||
| MS_LOG(DEBUG) << "node 4"; | |||
| // subgraph 0: node 5 partial body subgraph | |||
| auto sub_graph_0_node_5 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_5->inputIndex = {6}; | |||
| sub_graph_0_node_5->outputIndex = {17}; | |||
| sub_graph_0_node_5->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_5->primitive->value.type = schema::PrimitiveType_Partial; | |||
| auto primitive_sub_graph_0_node_5 = new schema::PartialT; | |||
| primitive_sub_graph_0_node_5->subGraphIndex = 2; | |||
| sub_graph_0_node_5->primitive->value.value = primitive_sub_graph_0_node_5; | |||
| sub_graph_0_node_5->name = "Partial_body"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_5)); | |||
| sub_graph_0->nodeIndices.push_back(5); | |||
| MS_LOG(DEBUG) << "node 5"; | |||
| // subgraph 0: node 6 add-after | |||
| auto sub_graph_0_node_6 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_0_node_6->inputIndex = {7, 8}; | |||
| sub_graph_0_node_6->outputIndex = {9}; | |||
| sub_graph_0_node_6->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_0_node_6->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_0_node_6 = new schema::AddT; | |||
| sub_graph_0_node_6->primitive->value.value = primitive_sub_graph_0_node_6; | |||
| sub_graph_0_node_6->name = "Add-after"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_6)); | |||
| sub_graph_0->nodeIndices.push_back(6); | |||
| MS_LOG(DEBUG) << "node 6"; | |||
| sub_graph_0->inputIndices = {0}; | |||
| sub_graph_0->outputIndices = {9}; | |||
| sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17}; | |||
| meta_graph->subGraph.push_back(std::move(sub_graph_0)); | |||
| // subgraph 1; node 0: cond add | |||
| auto sub_graph_1_node_0 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_1_node_0->inputIndex = {16, 10}; | |||
| sub_graph_1_node_0->outputIndex = {11}; | |||
| sub_graph_1_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_1_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_1_node_0 = new schema::AddT; | |||
| sub_graph_1_node_0->primitive->value.value = primitive_sub_graph_1_node_0; | |||
| sub_graph_1_node_0->name = "cond_add"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_0)); | |||
| sub_graph_1->nodeIndices.push_back(7); | |||
| MS_LOG(DEBUG) << "node 6"; | |||
| // subgraph 1; node 1: cond Less | |||
| auto sub_graph_1_node_1 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_1_node_1->inputIndex = {11, 12}; | |||
| sub_graph_1_node_1->outputIndex = {5}; | |||
| sub_graph_1_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_1_node_1->primitive->value.type = schema::PrimitiveType_Less; | |||
| auto primitive_sub_graph_1_node_1 = new schema::LessT; | |||
| sub_graph_1_node_1->primitive->value.value = primitive_sub_graph_1_node_1; | |||
| sub_graph_1_node_1->name = "cond_Less"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_1)); | |||
| sub_graph_1->nodeIndices.push_back(8); | |||
| MS_LOG(DEBUG) << "node 7"; | |||
| sub_graph_1->inputIndices = {16}; | |||
| sub_graph_1->outputIndices = {5}; | |||
| sub_graph_1->tensorIndices = {16, 10, 11, 12, 5}; | |||
| meta_graph->subGraph.push_back(std::move(sub_graph_1)); | |||
| // subgraph 2; node 0: body add-1 | |||
| auto sub_graph_2_node_0 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_2_node_0->inputIndex = {6, 13}; | |||
| sub_graph_2_node_0->outputIndex = {14}; | |||
| sub_graph_2_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_2_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_2_node_0 = new schema::AddT; | |||
| sub_graph_2_node_0->primitive->value.value = primitive_sub_graph_2_node_0; | |||
| sub_graph_2_node_0->name = "body_add_1"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_0)); | |||
| sub_graph_2->nodeIndices.push_back(9); | |||
| MS_LOG(DEBUG) << "node 8"; | |||
| // subgraph 2; node 1: body add-2 | |||
| auto sub_graph_2_node_1 = std::make_unique<schema::CNodeT>(); | |||
| sub_graph_2_node_1->inputIndex = {14, 15}; | |||
| sub_graph_2_node_1->outputIndex = {17}; | |||
| sub_graph_2_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| sub_graph_2_node_1->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto primitive_sub_graph_2_node_1 = new schema::AddT; | |||
| sub_graph_2_node_1->primitive->value.value = primitive_sub_graph_2_node_1; | |||
| sub_graph_2_node_1->name = "body_add_2"; | |||
| meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_1)); | |||
| sub_graph_2->nodeIndices.push_back(10); | |||
| MS_LOG(DEBUG) << "node 9"; | |||
| sub_graph_2->inputIndices = {6}; | |||
| sub_graph_2->outputIndices = {17}; | |||
| sub_graph_2->tensorIndices = {13, 14, 15, 6, 17}; | |||
| meta_graph->subGraph.push_back(std::move(sub_graph_2)); | |||
| // ------- tensor --------- | |||
| // tensor: 0 before-add input0 <main graph input> | |||
| auto tensor_0 = std::make_unique<schema::TensorT>(); | |||
| tensor_0->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_0->format = schema::Format_NHWC; | |||
| tensor_0->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_0->dims = {1}; | |||
| tensor_0->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_0)); | |||
| MS_LOG(DEBUG) << "tensor 0"; | |||
| // tensor: 1 before-add input1 <const> | |||
| auto tensor_1 = std::make_unique<schema::TensorT>(); | |||
| tensor_1->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_1->format = schema::Format_NHWC; | |||
| tensor_1->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_1->dims = {1}; | |||
| tensor_1->data.resize(sizeof(float) * 1); | |||
| float input1_data[] = {1}; | |||
| memcpy(tensor_1->data.data(), input1_data, sizeof(float) * 1); | |||
| tensor_1->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_1)); | |||
| MS_LOG(DEBUG) << "tensor 1"; | |||
| // tensor: 2 before-add-1 output / before-add-2 input | |||
| auto tensor_2 = std::make_unique<schema::TensorT>(); | |||
| tensor_2->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_2->format = schema::Format_NHWC; | |||
| tensor_2->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_2->dims = {1}; | |||
| tensor_2->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_2)); | |||
| MS_LOG(DEBUG) << "tensor 2"; | |||
| // tensor: 3 before-add-2 input <const> | |||
| auto tensor_3 = std::make_unique<schema::TensorT>(); | |||
| tensor_3->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_3->format = schema::Format_NHWC; | |||
| tensor_3->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_3->dims = {1}; | |||
| tensor_3->data.resize(sizeof(float) * 1); | |||
| float tensor_3_data[] = {1}; | |||
| memcpy(tensor_3->data.data(), tensor_3_data, sizeof(float) * 1); | |||
| tensor_3->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_3)); | |||
| MS_LOG(DEBUG) << "tensor 3"; | |||
| auto tensor_4 = std::make_unique<schema::TensorT>(); | |||
| tensor_4->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_4->format = schema::Format_NHWC; | |||
| tensor_4->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_4->dims = {1}; | |||
| tensor_4->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_4)); | |||
| MS_LOG(DEBUG) << "tensor 4"; | |||
| // tensor: 5 partial output <bool> | |||
| auto tensor_5 = std::make_unique<schema::TensorT>(); | |||
| tensor_5->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_5->format = schema::Format_NHWC; | |||
| tensor_5->dataType = TypeId::kNumberTypeBool; | |||
| tensor_5->dims = {1}; | |||
| tensor_5->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_5)); | |||
| MS_LOG(DEBUG) << "tensor_4"; | |||
| // tensor: 6 switch true output | |||
| auto tensor_6 = std::make_unique<schema::TensorT>(); | |||
| tensor_6->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_6->format = schema::Format_NHWC; | |||
| tensor_6->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_6->dims = {1}; | |||
| tensor_6->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_6)); | |||
| MS_LOG(DEBUG) << "tensor 6"; | |||
| // tensor: 7 switch false output | |||
| auto tensor_7 = std::make_unique<schema::TensorT>(); | |||
| tensor_7->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_7->format = schema::Format_NHWC; | |||
| tensor_7->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_7->dims = {1}; | |||
| tensor_7->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_7)); | |||
| MS_LOG(DEBUG) << "tensor_7"; | |||
| // tensor: 8 after-add input <const>, other input is switch false output (tensor 7) | |||
| auto tensor_8 = std::make_unique<schema::TensorT>(); | |||
| tensor_8->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_8->format = schema::Format_NHWC; | |||
| tensor_8->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_8->dims = {1}; | |||
| tensor_8->data.resize(sizeof(float) * 1); | |||
| float tensor_8_data[] = {10}; | |||
| memcpy(tensor_8->data.data(), tensor_8_data, sizeof(float) * 1); | |||
| tensor_8->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_8)); | |||
| MS_LOG(DEBUG) << "tensor_8"; | |||
| auto tensor_9 = std::make_unique<schema::TensorT>(); | |||
| tensor_9->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_9->format = schema::Format_NHWC; | |||
| tensor_9->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_9->dims = {1}; | |||
| tensor_9->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_9)); | |||
| MS_LOG(DEBUG) << "tensor_9"; | |||
| // tensor: 10 cond-add input <const>, other input is merge output (tensor 16) | |||
| auto tensor_10 = std::make_unique<schema::TensorT>(); | |||
| tensor_10->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_10->format = schema::Format_NHWC; | |||
| tensor_10->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_10->dims = {1}; | |||
| tensor_10->data.resize(sizeof(float) * 1); | |||
| float tensor_10_data[] = {1}; | |||
| memcpy(tensor_10->data.data(), tensor_10_data, sizeof(float) * 1); | |||
| tensor_10->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_10)); | |||
| MS_LOG(DEBUG) << "tensor_10"; | |||
| // tensor: 11 cond-add output / cond-Less input | |||
| auto tensor_11 = std::make_unique<schema::TensorT>(); | |||
| tensor_11->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_11->format = schema::Format_NHWC; | |||
| tensor_11->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_11->dims = {1}; | |||
| tensor_11->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_11)); | |||
| MS_LOG(DEBUG) << "tensor 11"; | |||
| // tensor: 12 cond-Less input <const>, other input is tensor 11 | |||
| auto tensor_12 = std::make_unique<schema::TensorT>(); | |||
| tensor_12->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_12->format = schema::Format_NHWC; | |||
| tensor_12->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_12->dims = {1}; | |||
| tensor_12->data.resize(sizeof(float) * 1); | |||
| float tensor_12_data[] = {10}; | |||
| memcpy(tensor_12->data.data(), tensor_12_data, sizeof(float) * 1); | |||
| tensor_12->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_12)); | |||
| MS_LOG(DEBUG) << "tensor_12"; | |||
| auto tensor_13 = std::make_unique<schema::TensorT>(); | |||
| tensor_13->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_13->format = schema::Format_NHWC; | |||
| tensor_13->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_13->dims = {1}; | |||
| tensor_13->data.resize(sizeof(float) * 1); | |||
| float tensor_13_data[] = {1}; | |||
| memcpy(tensor_13->data.data(), tensor_13_data, sizeof(float) * 1); | |||
| tensor_13->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_13)); | |||
| MS_LOG(DEBUG) << "tensor_13"; | |||
| auto tensor_14 = std::make_unique<schema::TensorT>(); | |||
| tensor_14->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_14->format = schema::Format_NHWC; | |||
| tensor_14->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_14->dims = {1}; | |||
| tensor_14->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_14)); | |||
| MS_LOG(DEBUG) << "tensor 14"; | |||
| auto tensor_15 = std::make_unique<schema::TensorT>(); | |||
| tensor_15->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_15->format = schema::Format_NHWC; | |||
| tensor_15->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_15->dims = {1}; | |||
| tensor_15->data.resize(sizeof(float) * 1); | |||
| float tensor_15_data[] = {1}; | |||
| memcpy(tensor_15->data.data(), tensor_15_data, sizeof(float) * 1); | |||
| tensor_15->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_15)); | |||
| MS_LOG(DEBUG) << "tensor_15"; | |||
| auto tensor_16 = std::make_unique<schema::TensorT>(); | |||
| tensor_16->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_16->format = schema::Format_NHWC; | |||
| tensor_16->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_16->dims = {1}; | |||
| tensor_16->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_16)); | |||
| MS_LOG(DEBUG) << "tensor_16"; | |||
| auto tensor_17 = std::make_unique<schema::TensorT>(); | |||
| tensor_17->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_17->format = schema::Format_NHWC; | |||
| tensor_17->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_17->dims = {1}; | |||
| tensor_17->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_17)); | |||
| MS_LOG(DEBUG) << "tensor_17"; | |||
| // ----------------------------------------------------------------------- | |||
| flatbuffers::FlatBufferBuilder builder(1024); | |||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); | |||
| builder.Finish(offset); | |||
| schema::FinishMetaGraphBuffer(builder, offset); | |||
| size_t size = builder.GetSize(); | |||
| const char *content = reinterpret_cast<char *>(builder.GetBufferPointer()); | |||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(content, size)); | |||
| ASSERT_NE(model, nullptr); | |||
| lite::Context context; | |||
| context.thread_num_ = 2; | |||
| auto &cpu_device_ctx = context.device_list_[0]; | |||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; | |||
| cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = false; | |||
| auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context)); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->CompileGraph(model.get()); | |||
| ASSERT_EQ(ret, lite::RET_OK); | |||
| model->Free(); | |||
| auto inputs = session->GetInputs(); | |||
| ASSERT_EQ(inputs.size(), 1); | |||
| auto input = inputs.front(); | |||
| ASSERT_NE(input, nullptr); | |||
| ASSERT_EQ(input->data_type(), kNumberTypeFloat32); | |||
| ASSERT_EQ(input->shape().size(), 1); | |||
| ASSERT_EQ(input->shape().at(0), 1); | |||
| auto in_data = reinterpret_cast<float *>(input->MutableData()); | |||
| ASSERT_NE(in_data, nullptr); | |||
| in_data[0] = 1; | |||
| ret = session->RunGraph(); | |||
| ASSERT_EQ(ret, lite::RET_OK); | |||
| auto outputs = session->GetOutputs(); | |||
| ASSERT_EQ(outputs.size(), 1); | |||
| auto output = outputs.begin()->second; | |||
| ASSERT_NE(output, nullptr); | |||
| ASSERT_EQ(output->data_type(), kNumberTypeFloat32); | |||
| ASSERT_EQ(output->shape().size(), 1); | |||
| ASSERT_EQ(output->shape().at(0), 1); | |||
| auto out_data = reinterpret_cast<float *>(output->MutableData()); | |||
| ASSERT_NE(out_data, nullptr); | |||
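| // expected: 1 (+1) (+1) = 3; body adds 2 while (x + 1 < 10): 3 -> 5 -> 7 -> 9; finally 9 + 10 = 19 | |||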
| ASSERT_EQ(out_data[0], 19); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,217 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "mindspore/lite/include/model.h" | |||
| #include "common/common_test.h" | |||
| #include "include/lite_session.h" | |||
| #include "include/context.h" | |||
| #include "include/model.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/lite_session.h" | |||
| #include "src/runtime/parallel_executor.h" | |||
| #include "tools/common/storage.h" | |||
| #include "include/version.h" | |||
| namespace mindspore { | |||
| class SubGraphTest : public mindspore::CommonTest { | |||
| public: | |||
| SubGraphTest() {} | |||
| }; | |||
| TEST_F(SubGraphTest, RecursiveSubGraphTest) { | |||
| // nodes: Add0, Partial1, Partial2, Partial3; tensors: 0, 1, 2 | |||
| auto add_0 = std::make_unique<schema::CNodeT>(); | |||
| add_0->inputIndex = {0, 1}; | |||
| add_0->outputIndex = {2}; | |||
| add_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| add_0->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto add_0_prim = new schema::AddT; | |||
| add_0_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| add_0->primitive->value.value = add_0_prim; | |||
| add_0->name = "Add0"; | |||
| auto partial_1 = std::make_unique<schema::CNodeT>(); | |||
| partial_1->inputIndex = {2}; | |||
| partial_1->outputIndex = {7}; | |||
| partial_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| partial_1->primitive->value.type = schema::PrimitiveType_Partial; | |||
| auto partial_1_prim = new schema::PartialT; | |||
| partial_1_prim->subGraphIndex = 1; | |||
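| // a Partial node transfers execution to the sub-graph at this index | |||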
| partial_1->primitive->value.value = partial_1_prim; | |||
| partial_1->name = "Partial1"; | |||
| auto partial_2 = std::make_unique<schema::CNodeT>(); | |||
| partial_2->inputIndex = {2}; | |||
| partial_2->outputIndex = {7}; | |||
| partial_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| partial_2->primitive->value.type = schema::PrimitiveType_Partial; | |||
| auto partial_2_prim = new schema::PartialT; | |||
| partial_2_prim->subGraphIndex = 2; | |||
| partial_2->primitive->value.value = partial_2_prim; | |||
| partial_2->name = "Partial2"; | |||
| auto partial_3 = std::make_unique<schema::CNodeT>(); | |||
| partial_3->inputIndex = {4, 6}; | |||
| partial_3->outputIndex = {7}; | |||
| partial_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| partial_3->primitive->value.type = schema::PrimitiveType_Partial; | |||
| auto partial_3_prim = new schema::PartialT; | |||
| partial_3_prim->subGraphIndex = 3; | |||
| partial_3->primitive->value.value = partial_3_prim; | |||
| partial_3->name = "Partial3"; | |||
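| // tensor definitions: nodeType Parameter marks a non-constant tensor, | |||
| // ValueNode a constant one | |||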
| auto tensor_0 = std::make_unique<schema::TensorT>(); | |||
| tensor_0->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_0->format = schema::Format_NHWC; | |||
| tensor_0->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_0->dims = {1, 2}; | |||
| auto tensor_1 = std::make_unique<schema::TensorT>(); | |||
| tensor_1->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_1->format = schema::Format_NHWC; | |||
| tensor_1->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_1->dims = {1, 2}; | |||
| auto tensor_2 = std::make_unique<schema::TensorT>(); | |||
| tensor_2->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_2->format = schema::Format_NHWC; | |||
| tensor_2->dataType = TypeId::kNumberTypeFloat32; | |||
| auto sub_graph_0 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_0->name = "main_graph"; | |||
| sub_graph_0->inputIndices = {0}; | |||
| sub_graph_0->outputIndices = {7}; | |||
| sub_graph_0->nodeIndices = {0, 1, 2}; | |||
| sub_graph_0->tensorIndices = {0, 1, 2, 7}; | |||
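| // tensor 7 is the Partial nodes' output tensor and the graph's overall output | |||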
| // sub_graph_1 body: Add1 and tensors 3, 4 | |||
| auto add_1 = std::make_unique<schema::CNodeT>(); | |||
| add_1->inputIndex = {2, 3}; | |||
| add_1->outputIndex = {4}; | |||
| add_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| add_1->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto add_1_prim = new schema::AddT; | |||
| add_1_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| add_1->primitive->value.value = add_1_prim; | |||
| add_1->name = "Add1"; | |||
| auto tensor_3 = std::make_unique<schema::TensorT>(); | |||
| tensor_3->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_3->format = schema::Format_NHWC; | |||
| tensor_3->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_3->dims = {1, 2}; | |||
| auto tensor_4 = std::make_unique<schema::TensorT>(); | |||
| tensor_4->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_4->format = schema::Format_NHWC; | |||
| tensor_4->dataType = TypeId::kNumberTypeFloat32; | |||
| auto sub_graph_1 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_1->name = "sub_graph_1"; | |||
| sub_graph_1->inputIndices = {2}; | |||
| sub_graph_1->outputIndices = {7}; | |||
| sub_graph_1->nodeIndices = {4, 3}; | |||
| sub_graph_1->tensorIndices = {2, 3, 4, 7}; | |||
| // sub_graph_2 body: Add2 and tensors 5, 6 | |||
| auto add_2 = std::make_unique<schema::CNodeT>(); | |||
| add_2->inputIndex = {2, 5}; | |||
| add_2->outputIndex = {6}; | |||
| add_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| add_2->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto add_2_prim = new schema::AddT; | |||
| add_2_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| add_2->primitive->value.value = add_2_prim; | |||
| add_2->name = "Add2"; | |||
| auto tensor_5 = std::make_unique<schema::TensorT>(); | |||
| tensor_5->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| tensor_5->format = schema::Format_NHWC; | |||
| tensor_5->dataType = TypeId::kNumberTypeFloat32; | |||
| tensor_5->dims = {1, 2}; | |||
| auto tensor_6 = std::make_unique<schema::TensorT>(); | |||
| tensor_6->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_6->format = schema::Format_NHWC; | |||
| tensor_6->dataType = TypeId::kNumberTypeFloat32; | |||
| auto sub_graph_2 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_2->name = "sub_graph_2"; | |||
| sub_graph_2->inputIndices = {2}; | |||
| sub_graph_2->outputIndices = {7}; | |||
| sub_graph_2->nodeIndices = {5, 3}; | |||
| sub_graph_2->tensorIndices = {2, 5, 6, 7}; | |||
| // sub_graph_3 body: Add3 and tensor 7 | |||
| auto add_3 = std::make_unique<schema::CNodeT>(); | |||
| add_3->inputIndex = {4, 6}; | |||
| add_3->outputIndex = {7}; | |||
| add_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| add_3->primitive->value.type = schema::PrimitiveType_Add; | |||
| auto add_3_prim = new schema::AddT; | |||
| add_3_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| add_3->primitive->value.value = add_3_prim; | |||
| add_3->name = "Add3"; | |||
| auto tensor_7 = std::make_unique<schema::TensorT>(); | |||
| tensor_7->nodeType = schema::NodeType::NodeType_Parameter; | |||
| tensor_7->format = schema::Format_NHWC; | |||
| tensor_7->dataType = TypeId::kNumberTypeFloat32; | |||
| auto sub_graph_3 = std::make_unique<schema::SubGraphT>(); | |||
| sub_graph_3->name = "sub_graph_3"; | |||
| sub_graph_3->inputIndices = {4, 6}; | |||
| sub_graph_3->outputIndices = {7}; | |||
| sub_graph_3->nodeIndices = {6}; | |||
| sub_graph_3->tensorIndices = {4, 6, 7}; | |||
| // assemble the MetaGraph: nodes, tensors and sub-graphs in index order | |||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||
| meta_graph->name = "graph"; | |||
| meta_graph->nodes.emplace_back(std::move(add_0)); | |||
| meta_graph->nodes.emplace_back(std::move(partial_1)); | |||
| meta_graph->nodes.emplace_back(std::move(partial_2)); | |||
| meta_graph->nodes.emplace_back(std::move(partial_3)); | |||
| meta_graph->nodes.emplace_back(std::move(add_1)); | |||
| meta_graph->nodes.emplace_back(std::move(add_2)); | |||
| meta_graph->nodes.emplace_back(std::move(add_3)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_0)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_1)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_2)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_3)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_4)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_5)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_6)); | |||
| meta_graph->allTensors.emplace_back(std::move(tensor_7)); | |||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_0)); | |||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_1)); | |||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_2)); | |||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_3)); | |||
| meta_graph->version = lite::Version(); | |||
| // ----------------------------------------------------------------------- | |||
| // save to a relative path so the test does not depend on one machine's layout | |||
| lite::Storage::Save(*meta_graph, "./recursive_subgraph"); | |||
| // ----------------------------------------------------------------------- | |||
| size_t size = 0; | |||
| char *graph_buf = lite::ReadFile("./recursive_subgraph.ms", &size); | |||
| ASSERT_NE(graph_buf, nullptr); | |||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size)); | |||
| ASSERT_NE(model, nullptr); | |||
| delete[] graph_buf; | |||
| lite::Context context; | |||
| auto &cpu_device_ctx = context.device_list_[0]; | |||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; | |||
| context.thread_num_ = 2; | |||
| auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context)); | |||
| ASSERT_NE(session, nullptr); | |||
| auto ret = session->CompileGraph(model.get()); | |||
| ASSERT_EQ(ret, lite::RET_OK); | |||
| auto inputs = session->GetInputs(); | |||
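| // allocate input buffers; values stay uninitialized since only the RunGraph status is checked | |||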
| for (auto *input : inputs) { | |||
| (void)input->MutableData(); | |||
| } | |||
| ret = session->RunGraph(); | |||
| ASSERT_EQ(ret, lite::RET_OK); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -142,7 +142,6 @@ add_executable(converter_lite | |||
| ${KERNEL_SRC} | |||
| ${LITE_SRC} | |||
| ) | |||
| add_dependencies(converter_lite tflite_fbs_src) | |||
| add_dependencies(converter_lite fbs_src) | |||
| add_dependencies(converter_lite fbs_inner_src) | |||
| @@ -5,4 +5,5 @@ set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID | |||
| add_library(tflite_parser_mid OBJECT | |||
| ${TFLITE_SRC_LIST} | |||
| ) | |||
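| # build the generated tflite flatbuffer sources first (dependency moved here from converter_lite) | |||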
| add_dependencies(tflite_parser_mid tflite_fbs_src) | |||
| target_link_libraries(tflite_parser_mid mindspore::flatbuffers) | |||