@@ -258,7 +258,8 @@ union PrimitiveType {
     SmoothL1LossGrad,
     SigmoidCrossEntropyWithLogits,
     SigmoidCrossEntropyWithLogitsGrad,
-    Reciprocal
+    Reciprocal,
+    Merge,
 }
 
 enum QuantType: int {
@@ -1222,4 +1222,7 @@ table SigmoidCrossEntropyWithLogitsGrad {
 }
 
 table Reciprocal {
 }
+
+table Merge {
+}
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#include "mindspore/lite/src/executor.h"
-#include "nnacl/pack.h"
+#include "src/executor.h"
+#include <queue>
 #include "include/errorcode.h"
 
 namespace mindspore::lite {
@@ -26,7 +26,7 @@ int Executor::CheckInputs(const std::vector<Tensor *> &in_tensors) {
       return RET_ERROR;
     }
     if (inTensor->data_c() == nullptr) {
-      MS_LOG(ERROR) << "Graph input tensor data is nullptr";
+      MS_LOG(ERROR) << "Graph input tensor data is nullptr " << in_tensors;
       return RET_ERROR;
     }
     auto shape = inTensor->shape();
@@ -49,7 +49,52 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
     MS_LOG(ERROR) << "CheckInputs failed";
     return ret;
   }
+  kernel::LiteKernelUtil::InitTensorInitRefCount(kernels);
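+  // Ready-queue execution: a kernel is ready once each of its input tensors is
+  // either const or already produced (ref_count >= 1, see LiteKernel::IsReady);
+  // every finished kernel then enqueues whichever of its out-kernels became ready.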
+  std::queue<kernel::LiteKernel *> kernel_queue;
+  for (auto kernel : kernels) {
+    if (kernel->IsReady()) {
+      kernel_queue.push(kernel);
+    }
+  }
+  while (!kernel_queue.empty()) {
+    auto cur_kernel = kernel_queue.front();
+    kernel_queue.pop();
+    MS_ASSERT(nullptr != cur_kernel);
+    ret = cur_kernel->PreProcess();
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "PreProcess kernel failed, name: " << cur_kernel->name();
+      return ret;
+    }
+    ret = cur_kernel->Run(before, after);
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "run kernel failed, name: " << cur_kernel->name();
+      return ret;
+    }
+    ret = cur_kernel->PostProcess();
+    if (RET_OK != ret) {
+      MS_LOG(ERROR) << "PostProcess kernel failed, name: " << cur_kernel->name();
+      return ret;
+    }
+    for (auto &out_kernel : cur_kernel->out_kernels()) {
+      if (out_kernel->IsReady()) {
+        kernel_queue.push(out_kernel);
+      }
+    }
+  }
+  return RET_OK;
+}
+
+int CpuExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
+                     std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, const KernelCallBack &before,
+                     const KernelCallBack &after) {
+  MS_ASSERT(nullptr != allocator);
+  // Do not check inputs when the first kernel is a Merge; its inputs are too hard to validate here.
+  if (kernels.front()->Type() != schema::PrimitiveType_Merge) {
+    auto ret = this->CheckInputs(in_tensors);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "CheckInputs failed";
+      return ret;
+    }
+  }
 #ifdef SUPPORT_TRAIN
   for (auto out_tensor : out_tensors) {  // increase RefCount of output tensors, such that Run will not free them
     out_tensor->set_ref_count(out_tensor->ref_count() + 1);
@@ -57,7 +102,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
 #endif
   for (auto *kernel : kernels) {
     MS_ASSERT(nullptr != kernel);
-    ret = kernel->PreProcess();
+    auto ret = kernel->PreProcess();
     if (RET_OK != ret) {
       MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name();
       return ret;
@@ -37,5 +37,16 @@ class Executor {
  protected:
   static int CheckInputs(const std::vector<Tensor *> &in_tensors);
 };
+
+class CpuExecutor : public Executor {
+ public:
+  CpuExecutor() = default;
+  virtual ~CpuExecutor() = default;
+
+  int Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
+          std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
+          const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
+};
+
 }  // namespace mindspore::lite
 #endif
@@ -62,7 +62,7 @@ InnerContext::~InnerContext() {
   }
 }
 
-int InnerContext::IsValid() {
+int InnerContext::IsValid() const {
   if (this->device_list_.empty()) {
     MS_LOG(ERROR) << "Device list is empty.";
     return RET_NOT_SUPPORT;
@@ -86,33 +86,33 @@ int InnerContext::IsValid() {
   return RET_OK;
 }
 
-bool InnerContext::IsCpuFloat16Enabled() {
+bool InnerContext::IsCpuFloat16Enabled() const {
   if (!IsCpuEnabled()) {
     return false;
   }
   return GetCpuInfo().enable_float16_;
 }
 
-bool InnerContext::IsGpuFloat16Enabled() {
+bool InnerContext::IsGpuFloat16Enabled() const {
   if (!IsGpuEnabled()) {
     return false;
   }
   return GetGpuInfo().enable_float16_;
 }
 
-bool InnerContext::IsCpuEnabled() {
+bool InnerContext::IsCpuEnabled() const {
   return this->device_list_.end() !=
          std::find_if(this->device_list_.begin(), this->device_list_.end(),
                       [](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
 }
 
-bool InnerContext::IsGpuEnabled() {
+bool InnerContext::IsGpuEnabled() const {
   return this->device_list_.end() !=
          std::find_if(this->device_list_.begin(), this->device_list_.end(),
                       [](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
 }
 
-bool InnerContext::IsNpuEnabled() {
+bool InnerContext::IsNpuEnabled() const {
 #ifdef SUPPORT_NPU
   return this->device_list_.end() !=
          std::find_if(this->device_list_.begin(), this->device_list_.end(),
@@ -123,7 +123,7 @@ bool InnerContext::IsNpuEnabled() {
 #endif
 }
 
-CpuDeviceInfo InnerContext::GetCpuInfo() {
+CpuDeviceInfo InnerContext::GetCpuInfo() const {
   auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
                            [](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
   if (iter == this->device_list_.end()) {
@@ -133,7 +133,7 @@ CpuDeviceInfo InnerContext::GetCpuInfo() {
   }
 }
 
-GpuDeviceInfo InnerContext::GetGpuInfo() {
+GpuDeviceInfo InnerContext::GetGpuInfo() const {
   auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
                            [](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
   if (iter == this->device_list_.end()) {
@@ -33,23 +33,23 @@ struct InnerContext : public Context {
   int Init();
 
-  bool IsCpuFloat16Enabled();
+  bool IsCpuFloat16Enabled() const;
 
-  bool IsGpuFloat16Enabled();
+  bool IsGpuFloat16Enabled() const;
 
-  bool IsCpuEnabled();
+  bool IsCpuEnabled() const;
 
-  bool IsGpuEnabled();
+  bool IsGpuEnabled() const;
 
-  bool IsNpuEnabled();
+  bool IsNpuEnabled() const;
 
-  CpuDeviceInfo GetCpuInfo();
+  CpuDeviceInfo GetCpuInfo() const;
 
-  GpuDeviceInfo GetGpuInfo();
+  GpuDeviceInfo GetGpuInfo() const;
 
   NpuDeviceInfo GetNpuInfo() const;
 
-  int IsValid();
+  int IsValid() const;
 
   virtual ~InnerContext();
 };
@@ -41,9 +41,21 @@ void LiteKernel::FreeWorkspace() {
   workspace_ = nullptr;
 }
 
-void LiteKernel::InitOutTensorRefCount() {
+bool LiteKernel::IsReady() {
+  return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
+    return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
+  });
+}
+
+void LiteKernel::InitOutTensorInitRefCount() {
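+  // Each output tensor's initial ref count is the number of times it appears as
+  // an input of downstream kernels (a kernel consuming it twice counts twice),
+  // replacing the old approximation of out_kernels_.size().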
   for (auto *tensor : this->out_tensors_) {
-    tensor->set_ref_count(this->out_kernels_.size());
+    int init_ref_count = 0;
+    for (auto *post_kernel : this->out_kernels_) {
+      init_ref_count +=
+        std::count_if(post_kernel->in_tensors_.begin(), post_kernel->in_tensors_.end(),
+                      [&tensor](const lite::Tensor *post_kernel_in_tensor) { return post_kernel_in_tensor == tensor; });
+    }
+    tensor->set_init_ref_count(init_ref_count);
   }
 }
@@ -61,15 +73,20 @@ int LiteKernel::DecOutTensorRefCount() {
   return 0;
 }
 
-int LiteKernel::FreeWorkTensor() const {
-  for (auto input_kernel : this->in_kernels()) {
-    MS_ASSERT(input_kernel != nullptr);
-    if (input_kernel->is_model_output()) {
+int LiteKernel::FreeInWorkTensor() const {
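+  // Release each non-const input: decrement its ref count and free the tensor
+  // data once the last consumer has run.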
+  for (auto &in_tensor : this->in_tensors_) {
+    MS_ASSERT(in_tensor != nullptr);
+    if (in_tensor->IsConst()) {
       continue;
     }
-    auto ret = input_kernel->DecOutTensorRefCount();
-    if (0 != ret) {
-      MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
+    MS_ASSERT(in_tensor->ref_count() > 0);
+    in_tensor->set_ref_count(in_tensor->ref_count() - 1);
+    if (in_tensor->ref_count() <= 0) {
+      auto ret = in_tensor->FreeData();
+      if (0 != ret) {
+        MS_LOG(ERROR) << "Free tensor data failed";
+        return ret;
+      }
     }
   }
   return RET_OK;
@@ -91,15 +108,12 @@ int LiteKernel::PreProcess() {
     }
   }
 
-  auto outputs = this->out_tensors();
-  for (auto *output : outputs) {
+  for (auto *output : this->out_tensors()) {
     MS_ASSERT(output != nullptr);
     if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
       MS_LOG(ERROR) << "The size of output tensor is too big";
       return RET_ERROR;
     }
     auto ret = output->MallocData();
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MallocData failed";
@@ -109,6 +123,28 @@ int LiteKernel::PreProcess() {
   return RET_OK;
 }
 
+int LiteKernel::PostProcess() {
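+  // Training builds keep the old path (ref counts tracked on producing kernels);
+  // inference builds reset output ref counts and release consumed inputs.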
+#ifdef SUPPORT_TRAIN
+  for (auto input_kernel : this->in_kernels()) {
+    MS_ASSERT(input_kernel != nullptr);
+    if (input_kernel->is_model_output()) {
+      continue;
+    }
+    auto ret = input_kernel->DecOutTensorRefCount();
+    if (0 != ret) {
+      MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
+    }
+  }
+  return RET_OK;
+#else
+  for (auto *output : this->out_tensors()) {
+    MS_ASSERT(output != nullptr);
+    output->ResetRefCount();
+  }
+  return FreeInWorkTensor();
+#endif
+}
+
 int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) {
   if (before != nullptr) {
     if (!before(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_),
@@ -153,6 +189,28 @@ std::string LiteKernel::ToString() const {
   return oss.str();
 }
 
+void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) {
+  // clean io kernels
+  this->in_kernels_.clear();
+  this->out_kernels_.clear();
+  // find io kernels
+  for (auto *scope_kernel : scope_kernels) {
+    if (scope_kernel == this) {
+      continue;
+    }
+    for (auto *tensor : this->in_tensors_) {
+      if (lite::IsContain(scope_kernel->out_tensors(), tensor)) {
+        this->AddInKernel(scope_kernel);
+      }
+    }
+    for (auto *tensor : this->out_tensors_) {
+      if (lite::IsContain(scope_kernel->in_tensors(), tensor)) {
+        this->AddOutKernel(scope_kernel);
+      }
+    }
+  }
+}
+
 std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
   const std::vector<kernel::LiteKernel *> &kernels) {
   std::vector<kernel::LiteKernel *> input_kernels;
@@ -202,7 +260,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
     if (outer_in_kernels.empty()) {
       for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
         if (!in_kernel_in_tensor->IsConst()) {
-          if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
+          if (!IsContain(input_tensors, in_kernel_in_tensor)) {
             input_tensors.push_back(in_kernel_in_tensor);
           }
         }
@@ -219,7 +277,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
       auto outer_in_kernel_out_tensors_iter =
         std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
       if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
-        if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
+        if (!IsContain(input_tensors, in_kernel_in_tensor)) {
          input_tensors.emplace_back(in_kernel_in_tensor);
        }
      }
@@ -237,7 +295,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
     auto &out_kernel_out_tensors = output_kernel->out_tensors();
     if (outer_out_kernels.empty()) {
       for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
-        if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
+        if (!IsContain(output_tensors, out_kernel_out_tensor)) {
           output_tensors.push_back(out_kernel_out_tensor);
         }
       }
@@ -253,7 +311,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
       auto outer_out_kernel_in_tensors_iter =
         std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
       if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
-        if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
+        if (!IsContain(output_tensors, out_kernel_out_tensor)) {
           output_tensors.emplace_back(out_kernel_out_tensor);
         }
       }
@@ -299,33 +357,9 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *ke
   return RET_OK;
 }
 
-void LiteKernelUtil::InitIOKernels(std::vector<kernel::LiteKernel *> &kernels) {
-  for (auto *kernel : kernels) {
-    // clean io kernels
-    kernel->set_in_kernels({});
-    kernel->set_out_kernels({});
-    // find io kernels
-    for (auto *search_kernel : kernels) {
-      if (search_kernel == kernel) {
-        continue;
-      }
-      for (auto *tensor : kernel->in_tensors()) {
-        if (lite::IsContain(search_kernel->out_tensors(), tensor)) {
-          kernel->AddInKernel(search_kernel);
-        }
-      }
-      for (auto *tensor : kernel->out_tensors()) {
-        if (lite::IsContain(search_kernel->in_tensors(), tensor)) {
-          kernel->AddOutKernel(search_kernel);
-        }
-      }
-    }
-  }
-}
-
-void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels) {
+void LiteKernelUtil::InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels) {
   for (auto *kernel : kernels) {
-    kernel->InitOutTensorRefCount();
+    kernel->InitOutTensorInitRefCount();
   }
 }
@@ -87,10 +87,12 @@ class LiteKernel {
   virtual int Run(const KernelCallBack &before, const KernelCallBack &after);
 
   // called after Run
-  virtual int PostProcess() { return FreeWorkTensor(); }
+  virtual int PostProcess();
 
   virtual int ReSize() { return mindspore::lite::RET_ERROR; }
 
+  virtual void FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels);
+
   virtual int Init() { return mindspore::lite::RET_ERROR; }
 
   std::string name() const { return this->name_; }
@@ -154,11 +156,13 @@ class LiteKernel {
   const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
 
-  void InitOutTensorRefCount();
+  virtual bool IsReady();
+
+  virtual void InitOutTensorInitRefCount();
 
   int DecOutTensorRefCount();
 
-  int FreeWorkTensor() const;
+  virtual int FreeInWorkTensor() const;
 
   KernelKey desc() const { return desc_; }
@@ -203,8 +207,6 @@ typedef LiteKernel *(*KernelCreator)(const std::vector<lite::Tensor *> &inputs,
 
 class LiteKernelUtil {
  public:
-  static void InitIOKernels(std::vector<kernel::LiteKernel *> &kernels);
-
   static std::vector<kernel::LiteKernel *> SubgraphInputKernels(const std::vector<kernel::LiteKernel *> &kernels);
 
   static std::vector<kernel::LiteKernel *> SubgraphOutputKernels(const std::vector<kernel::LiteKernel *> &kernels);
@@ -215,7 +217,7 @@ class LiteKernelUtil {
   static int TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels);
 
-  static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels);
+  static void InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels);
 
   static int SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs);
 };
@@ -295,6 +295,21 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
   }
 }
 
+void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
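+  // Graph output tensors have no downstream kernel counted by
+  // InitOutTensorInitRefCount, so reserve one extra reference for the session
+  // output, keeping PostProcess from freeing their data.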
+  MS_ASSERT(model != nullptr);
+  auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
+  for (size_t i = 0; i < graph_out_size; ++i) {
+    size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
+    MS_ASSERT(graph_out_index < this->tensors_.size());
+    auto *out_tensor = this->tensors_.at(graph_out_index);
+    if (out_tensor == nullptr) {
+      MS_LOG(ERROR) << "out_tensor is null!";
+      return;
+    }
+    out_tensor->set_init_ref_count(out_tensor->init_ref_count() + 1);
+  }
+}
+
 void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
   InitGraphInputTensors(model);
   InitGraphInputMSTensors();
@@ -303,6 +318,7 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
   InitGraphOutputNodeMap(model);
   InitGraphOutputTensorNames(model);
   InitGraphOutputTensorMap(model);
+  AdjustModelOutputTensorInitRefCount(model);
 }
 
 int LiteSession::CompileGraph(Model *model) {
@@ -334,12 +350,9 @@ int LiteSession::CompileGraph(Model *model) {
     is_running_.store(false);
     return ret;
   }
-  InitGraphInOutTensors(model);
 
   // scheduler kernels
-  Scheduler scheduler(context_);
-  ret = scheduler.Schedule(model, &tensors_, &kernels_);
+  Scheduler scheduler(context_, model, tensors_);
+  ret = scheduler.Schedule(&kernels_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
     is_running_.store(false);
@@ -353,6 +366,7 @@ int LiteSession::CompileGraph(Model *model) {
     }
   }
 #endif
+  InitGraphInOutTensors(model);
   ret = executor_->Prepare(this->kernels_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Prepare executor failed: " << ret;
@@ -563,6 +577,32 @@ void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) {
   }
 }
 
+int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
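+  // Resize each sub-graph kernel in order; when a sub-graph cannot infer shapes
+  // yet (RET_INFER_INVALID), record the interruption so the remaining sub-graphs
+  // know the shapes are not final, and keep going.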
+  bool infer_shape_interrupt = false;
+  for (auto kernel : kernels) {
+    if (kernel == nullptr) {
+      MS_LOG(ERROR) << "input kernel is nullptr!";
+      return RET_ERROR;
+    }
+    if (kernel->subgraph_type() == kernel::kNotSubGraph) {
+      MS_LOG(ERROR) << "All nodes in the graph should be sub-graph kernels";
+      return RET_ERROR;
+    }
+    auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
+    auto ret = sub_graph->ReSize(infer_shape_interrupt);
+    if (ret == RET_INFER_INVALID) {
+      MS_LOG(INFO) << "InferShape is interrupted";
+      infer_shape_interrupt = true;
+      continue;
+    }
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
 int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
                         const std::vector<std::vector<int>> &dims) {
   bool expected = false;
@@ -581,11 +621,10 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
     return ret;
   }
 
-  Scheduler scheduler(context_);
-  ret = scheduler.ReSizeKernels(kernels_);
+  ret = ReSizeKernels(kernels_);
   if (ret != RET_OK) {
     ResetInputsShape(old_dims);
-    auto resize_ret = scheduler.ReSizeKernels(kernels_);
+    auto resize_ret = ReSizeKernels(kernels_);
     if (resize_ret != RET_OK) {
       MS_LOG(ERROR) << "restore kernel size failed! ret: " << resize_ret;
     }
@@ -92,10 +92,14 @@ class LiteSession : public session::LiteSession {
   void InitGraphOutputTensorMap(const lite::Model *model);
 
+  void AdjustModelOutputTensorInitRefCount(const lite::Model *model);
+
   int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);
 
   int PrepareKernels();
 
+  static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);
+
  private:
   void ResetInputsShape(const std::vector<std::vector<int>> &dims);
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/merge.h"
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int Merge::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Merge;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Merge) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::MergeT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  PopulaterQuantParam(prim, inputs);
+  return RET_OK;
+}
+#else
+int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_Merge();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_Merge return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreateMerge(*fbb);
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+
+PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Merge>(primitive); }
+Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
+#endif
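+
+// Merge forwards whichever of its two inputs becomes available; only the data
+// type is propagated at infer time, the output shape is resolved at run time.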
+int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
+  MS_ASSERT(outputs_.size() == 1);
+  MS_ASSERT(inputs_.size() == 2);
+  outputs_[0]->set_data_type(inputs_[0]->data_type());
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
+#define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class Merge : public PrimitiveC {
+ public:
+  Merge() = default;
+  ~Merge() = default;
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Merge, PrimitiveC);
+  explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
+#endif
+  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/partial.h"
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; }
+
+int Partial::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Partial;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Partial) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::PartialT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_Partial();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_Partial return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreatePartial(*fbb, attr->subGraphIndex());
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+
+int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); }
+
+PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Partial>(primitive); }
+Registry PartialRegistry(schema::PrimitiveType_Partial, PartialCreator);
+#endif
+
+int Partial::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
+#define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include <memory>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class Partial : public PrimitiveC {
+ public:
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Partial, PrimitiveC);
+  Partial() = default;
+  explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  Partial() = default;
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
+#endif
+  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
+  int GetSubGraphIndex() const;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/primitive_c.h"
+#include "src/ops/populate/populate_register.h"
+
+namespace mindspore {
+namespace lite {
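+// Merge carries no attributes (its schema table is empty), so a bare
+// OpParameter holding only the primitive type is enough.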
+OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) {
+  OpParameter *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
+  if (merge_parameter == nullptr) {
+    MS_LOG(ERROR) << "malloc Merge parameter failed.";
+    return nullptr;
+  }
+  memset(merge_parameter, 0, sizeof(OpParameter));
+  merge_parameter->type_ = primitive->Type();
+  return reinterpret_cast<OpParameter *>(merge_parameter);
+}
+
+Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter);
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/partial.h"
+#include "src/ops/primitive_c.h"
+#include "src/ops/populate/populate_register.h"
+
+namespace mindspore {
+namespace lite {
+typedef struct PartialParameter {
+  OpParameter op_parameter_;
+  int sub_graph_index_;
+} PartialParameter;
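+
+// PartialParameter extends OpParameter with the index of the sub-graph that
+// this Partial node binds, copied from the Partial primitive below.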
+OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) {
+  PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
+  if (partial_parameter == nullptr) {
+    MS_LOG(ERROR) << "malloc partial parameter failed.";
+    return nullptr;
+  }
+  memset(partial_parameter, 0, sizeof(PartialParameter));
+  partial_parameter->op_parameter_.type_ = primitive->Type();
+
+  auto param = reinterpret_cast<mindspore::lite::Partial *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
+  partial_parameter->sub_graph_index_ = param->GetSubGraphIndex();
+
+  return reinterpret_cast<OpParameter *>(partial_parameter);
+}
+
+Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter);
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/switch.h"
+#include "src/ops/primitive_c.h"
+#include "src/ops/populate/populate_register.h"
+
+namespace mindspore {
+namespace lite {
+OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) {
+  OpParameter *switch_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
+  if (switch_parameter == nullptr) {
+    MS_LOG(ERROR) << "malloc SwitchParameter failed.";
+    return nullptr;
+  }
+  memset(switch_parameter, 0, sizeof(OpParameter));
+  switch_parameter->type_ = primitive->Type();
+  return reinterpret_cast<OpParameter *>(switch_parameter);
+}
+
+Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter);
+}  // namespace lite
+}  // namespace mindspore
@@ -155,6 +155,9 @@
 #include "src/ops/tensorlistsetitem.h"
 #include "src/ops/tensorlistreserve.h"
 #include "src/ops/tensorliststack.h"
+#include "src/ops/merge.h"
+#include "src/ops/switch.h"
+#include "src/ops/partial.h"
 
 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -925,7 +928,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
     case schema::PrimitiveType_TensorListReserve:
       return new (std::nothrow) TensorListReserve(primitive);
     case schema::PrimitiveType_TensorListStack:
       return new (std::nothrow) TensorListStack(primitive);
+    case schema::PrimitiveType_Switch:
+      return new (std::nothrow) Switch(primitive);
+    case schema::PrimitiveType_Merge:
+      return new (std::nothrow) Merge(primitive);
+    case schema::PrimitiveType_Partial:
+      return new (std::nothrow) Partial(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/ops/switch.h"
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int Switch::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Switch;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Switch) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::SwitchT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "primitive value is nullptr";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_Switch();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_Switch return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreateSwitch(*fbb);
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+
+PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Switch>(primitive); }
+Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
+#endif
+
+int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
+#define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
+
+#include <vector>
+#include <set>
+#include <cmath>
+#include <memory>
+#include "src/ops/primitive_c.h"
+
+namespace mindspore {
+namespace lite {
+class Switch : public PrimitiveC {
+ public:
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(Switch, PrimitiveC);
+  Switch() = default;
+  explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
+#else
+  Switch() = default;
+  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
+#endif
+  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/base/merge.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Merge;
+
+namespace mindspore::kernel {
+// If one of Merge's inputs is a const tensor, Merge counts as always ready, which can cause errors.
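+// Unlike LiteKernel::IsReady, which requires all inputs (std::all_of), Merge
+// fires as soon as either of its two alternative inputs is available (std::any_of).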
+bool MergeCPUKernel::IsReady() {
+  MS_ASSERT(in_tensors().size() == 2);
+  return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
+    return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
+  });
+}
+
+int MergeCPUKernel::Init() { return RET_OK; }
+
+int MergeCPUKernel::ReSize() { return RET_ERROR; }
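+
+// Copy whichever input currently holds data into the single output; only one of
+// the two alternative inputs is expected to be live at a time.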
+int MergeCPUKernel::Run() {
+  MS_ASSERT(in_tensors_.size() == 2);
+  MS_ASSERT(out_tensors_.size() == 1);
+  auto out_data = out_tensors_.front()->data_c();
+  MS_ASSERT(out_data != nullptr);
+  for (size_t i = 0; i < in_tensors().size(); i++) {
+    if (in_tensors()[i]->data_c() != nullptr) {
+      auto in_data = in_tensors_[i]->data_c();
+      MS_ASSERT(in_data != nullptr);
+      memcpy(out_data, in_data, in_tensors_[i]->Size());
+    }
+  }
+  return RET_OK;
+}
+kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                          const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
+                                          const lite::InnerContext *ctx, const KernelKey &desc,
+                                          const mindspore::lite::PrimitiveC *primitive) {
+  if (parameter == nullptr) {
+    MS_LOG(ERROR) << "parameter is nullptr";
+    return nullptr;
+  }
+  if (desc.type != PrimitiveType_Merge) {
+    MS_LOG(ERROR) << "type in desc is not Merge";
+    free(parameter);
+    return nullptr;
+  }
+  if (ctx == nullptr) {
+    MS_LOG(ERROR) << "ctx is nullptr";
+    free(parameter);
+    return nullptr;
+  }
+  auto *kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
+    free(parameter);
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator)
+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator)
+}  // namespace mindspore::kernel
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
+
+#include <vector>
+#include "src/lite_kernel.h"
+
+namespace mindspore::kernel {
+typedef struct MergeParameter {
+  OpParameter op_parameter_;
+} MergeParameter;
+
+class MergeCPUKernel : public LiteKernel {
+ public:
+  MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                 const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_);
+  }
+  ~MergeCPUKernel() override {}
+  bool IsReady() override;
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  MergeParameter *merge_param_ = nullptr;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
@@ -0,0 +1,115 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/arm/base/switch.h"
+#include "src/kernel_registry.h"
+#include "include/errorcode.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Switch;
+
+namespace mindspore::kernel {
+int SwitchCPUKernel::PostProcess() {
+  auto bool_tensor = in_tensors_.front();
+  MS_ASSERT(bool_tensor != nullptr);
+  MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
+  MS_ASSERT(bool_tensor->shape().size() == 1);
+  MS_ASSERT(bool_tensor->shape().front() == 1);
+  auto *active = static_cast<bool *>(bool_tensor->data_c());
+  if (active == nullptr) {
+    MS_LOG(ERROR) << "data of bool tensor is nullptr";
+    return lite::RET_NULL_PTR;
+  }
| size_t in_index = 1; | |||||
| size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2); | |||||
| while (in_index < in_tensors_.size()) { | |||||
| in_index++; | |||||
| auto out_tensor = out_tensors_.at(out_index++); | |||||
| out_tensor->ResetRefCount(); | |||||
| } | |||||
| return FreeInWorkTensor(); | |||||
| } | |||||
| int SwitchCPUKernel::Init() { return RET_OK; } | |||||
| int SwitchCPUKernel::ReSize() { return RET_ERROR; } | |||||
| // inputs: bool*1 data*n | |||||
| // output: true-data*n, false-data*n | |||||
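| // e.g. n == 2: in_tensors_ = {cond, x, y}, out_tensors_ = {x_t, y_t, x_f, y_f}; if *cond | |||||
| // is true, x and y are copied to outputs 0 and 1, otherwise to outputs 2 and 3. | |||||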
| int SwitchCPUKernel::Run() { | |||||
| MS_ASSERT(in_tensors_.size() >= 2); | |||||
| MS_ASSERT(out_tensors_.size() == 2 * (in_tensors_.size() - 1)); | |||||
| auto bool_tensor = in_tensors_.front(); | |||||
| MS_ASSERT(bool_tensor != nullptr); | |||||
| MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); | |||||
| MS_ASSERT(bool_tensor->shape().size() == 1); | |||||
| MS_ASSERT(bool_tensor->shape().front() == 1); | |||||
| auto active = static_cast<bool *>(bool_tensor->data_c()); | |||||
| if (active == nullptr) { | |||||
| MS_LOG(ERROR) << "data of bool tensor is nullptr"; | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| size_t in_index = 1; | |||||
| size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2); | |||||
| while (in_index < in_tensors_.size()) { | |||||
| auto in_tensor = in_tensors_.at(in_index++); | |||||
| auto out_tensor = out_tensors_.at(out_index++); | |||||
| MS_ASSERT(in_tensor != nullptr); | |||||
| MS_ASSERT(out_tensor != nullptr); | |||||
| auto input = reinterpret_cast<float *>(in_tensor->data_c()); | |||||
| auto output = reinterpret_cast<float *>(out_tensor->data_c()); | |||||
| MS_ASSERT(in_tensor->Size() == out_tensor->Size()); | |||||
| if (input == nullptr || output == nullptr) { | |||||
| MS_LOG(ERROR) << "input tensor or output tensor have not been malloced"; | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| memcpy(output, input, in_tensor->Size()); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, | |||||
| const lite::InnerContext *ctx, const KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| if (desc.type != PrimitiveType_Switch) { | |||||
| MS_LOG(ERROR) << "type in desc is not Switch"; | |||||
| free(parameter); | |||||
| return nullptr; | |||||
| } | |||||
| if (ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "ctx is nullptr"; | |||||
| free(parameter); | |||||
| return nullptr; | |||||
| } | |||||
| auto *kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; | |||||
| free(parameter); | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator) | |||||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| typedef struct SwitchParameter { | |||||
| OpParameter op_parameter_; | |||||
| } SwitchParameter; | |||||
| class SwitchCPUKernel : public LiteKernel { | |||||
| public: | |||||
| SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_); | |||||
| } | |||||
| ~SwitchCPUKernel() override = default; | |||||
| int PostProcess() override; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| SwitchParameter *switch_param_ = nullptr; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ | |||||
| @@ -71,6 +71,14 @@ int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int OpenCLKernel::PostProcess() { | |||||
| for (auto *output : this->out_tensors()) { | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->ResetRefCount(); | |||||
| } | |||||
| return FreeInWorkTensor(); | |||||
| } | |||||
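| // Mirrors the CPU kernels' PostProcess: restore each output's ref count to its initial | |||||
| // value and free the input work tensors, so OpenCL kernels follow the same ref-count | |||||
| // protocol that the control-flow (Merge/Switch) support relies on. | |||||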
| std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() { | std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() { | ||||
| size_t ndim = global_size_.size(); | size_t ndim = global_size_.size(); | ||||
| std::vector<BaseTuningParameter> tuning_params = {}; | std::vector<BaseTuningParameter> tuning_params = {}; | ||||
| @@ -164,6 +164,7 @@ class OpenCLKernel : public LiteKernel { | |||||
| int Prepare() override { return RET_OK; } | int Prepare() override { return RET_OK; } | ||||
| int PreProcess() override { return RET_ERROR; } | int PreProcess() override { return RET_ERROR; } | ||||
| int PostProcess() override; | |||||
| int ReSize() override { return RET_ERROR; } | int ReSize() override { return RET_ERROR; } | ||||
| int Run() override { return RET_ERROR; } | int Run() override { return RET_ERROR; } | ||||
| @@ -36,7 +36,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||||
| if (is_tune) { | if (is_tune) { | ||||
| opencl_runtime_ins->SetProfiling(true); | opencl_runtime_ins->SetProfiling(true); | ||||
| } | } | ||||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||||
| for (auto *kernel : kernels) { | for (auto *kernel : kernels) { | ||||
| MS_ASSERT(kernel); | MS_ASSERT(kernel); | ||||
| CallBackParam callbackParam; | CallBackParam callbackParam; | ||||
| @@ -82,6 +81,11 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||||
| MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); | MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = kernel->PostProcess(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); | |||||
| return ret; | |||||
| } | |||||
| if (profiling_tmp) { | if (profiling_tmp) { | ||||
| MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() | MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() | ||||
| << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; | << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; | ||||
| @@ -92,13 +96,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||||
| MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name(); | MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name(); | ||||
| } | } | ||||
| } | } | ||||
| for (auto input_kernel : kernel->in_kernels()) { | |||||
| MS_ASSERT(input_kernel); | |||||
| ret = input_kernel->DecOutTensorRefCount(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| opencl_runtime_ins->SetProfiling(profiling_tmp); | opencl_runtime_ins->SetProfiling(profiling_tmp); | ||||
| return ret; | return ret; | ||||
| @@ -40,9 +40,9 @@ static int RunKernel(void *data, int index) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| ret = kernel->FreeWorkTensor(); | |||||
| ret = kernel->FreeInWorkTensor(); | |||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << kernel->name(); | |||||
| MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << kernel->name(); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| return 0; | return 0; | ||||
| @@ -62,7 +62,7 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||||
| kernel::LiteKernelUtil::InitTensorInitRefCount(kernels); | |||||
| for (auto kernel : kernels) { | for (auto kernel : kernels) { | ||||
| if (kernel->in_kernels().empty()) { | if (kernel->in_kernels().empty()) { | ||||
| @@ -96,9 +96,9 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor | |||||
| } | } | ||||
| } | } | ||||
| auto ret = completed->FreeWorkTensor(); | |||||
| auto ret = completed->FreeInWorkTensor(); | |||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << completed->name(); | |||||
| MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << completed->name(); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <queue> | #include <queue> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/partial.h" | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/graph_util.h" | #include "src/common/graph_util.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -36,152 +37,255 @@ namespace mindspore::lite { | |||||
| using kernel::KERNEL_ARCH::kCPU; | using kernel::KERNEL_ARCH::kCPU; | ||||
| using kernel::KERNEL_ARCH::kGPU; | using kernel::KERNEL_ARCH::kGPU; | ||||
| using kernel::KERNEL_ARCH::kNPU; | using kernel::KERNEL_ARCH::kNPU; | ||||
| constexpr int kMainSubGraphIndex = 0; | |||||
| int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, | |||||
| std::vector<kernel::LiteKernel *> *kernels) { | |||||
| int ret = InferShape(model, tensors); | |||||
| int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) { | |||||
| if (src_model_ == nullptr) { | |||||
| MS_LOG(ERROR) << "Input model is nullptr"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (src_model_->sub_graphs_.empty()) { | |||||
| MS_LOG(ERROR) << "Model should have a subgraph at least"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_); | |||||
| bool infer_shape_interrupt = false; | |||||
| auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "op infer shape failed."; | MS_LOG(ERROR) << "op infer shape failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = BuildKernels(model, tensors, kernels); | |||||
| ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "init op to kernel failed."; | |||||
| MS_LOG(ERROR) << "Schedule main subgraph to kernels failed."; | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| kernel::LiteKernelUtil::InitIOKernels(*kernels); | |||||
| ret = ConstructSubGraphs(kernels); | |||||
| FindAllInoutKernels(*dst_kernels); | |||||
| ret = ConstructSubGraphs(dst_kernels); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ConstructSubGraphs failed."; | MS_LOG(ERROR) << "ConstructSubGraphs failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| kernel::LiteKernelUtil::InitIOKernels(*kernels); | |||||
| FindAllInoutKernels(*dst_kernels); | |||||
| kernel::LiteKernelUtil::InitTensorInitRefCount(*dst_kernels); | |||||
| MS_LOG(DEBUG) << "schedule kernels success."; | MS_LOG(DEBUG) << "schedule kernels success."; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
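| // Scheduling pipeline: infer shapes from the main subgraph (Partial nodes recurse into | |||||
| // their target subgraphs), schedule the main subgraph to kernels, link the kernels, | |||||
| // group them into subgraph kernels, relink, then seed every tensor's initial ref count. | |||||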
| int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||||
| bool infer_shape_interrupt = false; | |||||
| for (auto kernel : kernels) { | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "input kernel is nullptr!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (kernel->subgraph_type() == kernel::kNotSubGraph) { | |||||
| MS_LOG(ERROR) << "All node in graph should be sub_graph"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||||
| auto ret = sub_graph->ReSize(infer_shape_interrupt); | |||||
| if (ret == RET_INFER_INVALID) { | |||||
| MS_LOG(INFO) << "InferShape is interrupted"; | |||||
| infer_shape_interrupt = true; | |||||
| continue; | |||||
| } | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; | |||||
| return RET_ERROR; | |||||
| void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs, | |||||
| std::vector<Tensor *> *outputs) { | |||||
| MS_ASSERT(inputs != nullptr); | |||||
| MS_ASSERT(outputs != nullptr); | |||||
| auto in_size = node.input_indices_.size(); | |||||
| inputs->reserve(in_size); | |||||
| for (size_t j = 0; j < in_size; ++j) { | |||||
| inputs->emplace_back(src_tensors_.at(node.input_indices_[j])); | |||||
| } | |||||
| auto out_size = node.output_indices_.size(); | |||||
| outputs->reserve(out_size); | |||||
| for (size_t j = 0; j < out_size; ++j) { | |||||
| outputs->emplace_back(src_tensors_.at(node.output_indices_[j])); | |||||
| } | |||||
| } | |||||
| int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||||
| auto primitive = node->primitive_; | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| if (primitive->Type() == schema::PrimitiveType_Partial) { | |||||
| return InferPartialShape(node, infer_shape_interrupt); | |||||
| } | |||||
| std::vector<Tensor *> inputs; | |||||
| std::vector<Tensor *> outputs; | |||||
| FindNodeInoutTensors(*node, &inputs, &outputs); | |||||
| bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { | |||||
| auto shape = tensor->shape(); | |||||
| return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); | |||||
| }); | |||||
| if (!infer_valid) { | |||||
| *infer_shape_interrupt = true; | |||||
| } | |||||
| primitive->set_infer_flag(!(*infer_shape_interrupt)); | |||||
| auto ret = primitive->InferShape(inputs, outputs); | |||||
| if (ret == RET_OK) { | |||||
| for (auto &output : outputs) { | |||||
| if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) { | |||||
| MS_LOG(ERROR) << "The size of output tensor is too big"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| return ret; | |||||
| } | } | ||||
| int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tensors) { | |||||
| MS_ASSERT(model != nullptr); | |||||
| MS_ASSERT(tensors != nullptr); | |||||
| bool infer_shape_interrupt = false; | |||||
| uint32_t kernelCount = model->all_nodes_.size(); | |||||
| for (uint32_t i = 0; i < kernelCount; ++i) { | |||||
| auto node = model->all_nodes_[i]; | |||||
| int Scheduler::InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { | |||||
| MS_ASSERT(src_model_ != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||||
| auto primitive = node->primitive_; | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| if (primitive->Type() != schema::PrimitiveType_Partial) { | |||||
| MS_LOG(ERROR) << "Node is not a partial"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto partial_primitive = reinterpret_cast<lite::Partial *>(node->primitive_); | |||||
| return InferSubGraphShape(partial_primitive->GetSubGraphIndex(), infer_shape_interrupt); | |||||
| } | |||||
| int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt) { | |||||
| MS_ASSERT(infer_shape_interrupt != nullptr); | |||||
| MS_ASSERT(src_model_ != nullptr); | |||||
| MS_ASSERT(!src_model_->sub_graphs_.empty()); | |||||
| MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); | |||||
| auto subgraph = src_model_->sub_graphs_.at(subgraph_index); | |||||
| for (auto node_index : subgraph->node_indices_) { | |||||
| auto node = src_model_->all_nodes_[node_index]; | |||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| std::vector<Tensor *> inputs; | |||||
| std::vector<Tensor *> outputs; | |||||
| auto in_size = node->input_indices_.size(); | |||||
| inputs.reserve(in_size); | |||||
| for (size_t j = 0; j < in_size; ++j) { | |||||
| inputs.emplace_back(tensors->at(node->input_indices_[j])); | |||||
| } | |||||
| auto out_size = node->output_indices_.size(); | |||||
| outputs.reserve(out_size); | |||||
| for (size_t j = 0; j < out_size; ++j) { | |||||
| outputs.emplace_back(tensors->at(node->output_indices_[j])); | |||||
| } | |||||
| auto *primitive = node->primitive_; | auto *primitive = node->primitive_; | ||||
| if (primitive == nullptr) { | if (primitive == nullptr) { | ||||
| MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; | MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { | |||||
| auto shape = tensor->shape(); | |||||
| return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); | |||||
| }); | |||||
| if (!infer_valid) { | |||||
| infer_shape_interrupt = true; | |||||
| } | |||||
| primitive->set_infer_flag(!infer_shape_interrupt); | |||||
| auto ret = primitive->InferShape(inputs, outputs); | |||||
| auto ret = InferNodeShape(node, infer_shape_interrupt); | |||||
| if (ret == RET_INFER_INVALID) { | if (ret == RET_INFER_INVALID) { | ||||
| MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_ | |||||
| MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())) | << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())) | ||||
| << "flag set to false."; | |||||
| << ", set infer flag to false."; | |||||
| primitive->set_infer_flag(false); | primitive->set_infer_flag(false); | ||||
| infer_shape_interrupt = true; | |||||
| *infer_shape_interrupt = true; | |||||
| } else if (ret != RET_OK) { | } else if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " | MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " | ||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | ||||
| return RET_INFER_ERR; | return RET_INFER_ERR; | ||||
| } else { | |||||
| for (auto &output : outputs) { | |||||
| if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) { | |||||
| MS_LOG(ERROR) << "The size of output tensor is too big"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Scheduler::BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors, | |||||
| std::vector<kernel::LiteKernel *> *kernels) { | |||||
| MS_ASSERT(model != nullptr); | |||||
| MS_ASSERT(tensors != nullptr); | |||||
| uint32_t kernelCount = model->all_nodes_.size(); | |||||
| auto graph_output_node_indexes = GetGraphOutputNodes(model); | |||||
| for (uint32_t i = 0; i < kernelCount; ++i) { | |||||
| auto node = model->all_nodes_[i]; | |||||
| MS_ASSERT(node != nullptr); | |||||
| std::vector<Tensor *> inputs; | |||||
| std::vector<Tensor *> outputs; | |||||
| auto in_size = node->input_indices_.size(); | |||||
| inputs.reserve(in_size); | |||||
| for (size_t j = 0; j < in_size; ++j) { | |||||
| inputs.emplace_back(tensors->at(node->input_indices_[j])); | |||||
| kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||||
| const std::vector<Tensor *> &out_tensors, | |||||
| const mindspore::lite::PrimitiveC *primitive, | |||||
| const Model::Node *node) { | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||||
| #if SUPPORT_GPU | |||||
| if (context_->IsGpuEnabled()) { | |||||
| kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; | |||||
| return kernel; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " | |||||
| << node->name_; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| #if SUPPORT_NPU | |||||
| if (context_->IsNpuEnabled()) { | |||||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; | |||||
| return kernel; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " | |||||
| << node->name_; | |||||
| } | } | ||||
| auto out_size = node->output_indices_.size(); | |||||
| outputs.reserve(out_size); | |||||
| for (size_t j = 0; j < out_size; ++j) { | |||||
| outputs.emplace_back(tensors->at(node->output_indices_[j])); | |||||
| } | |||||
| #endif | |||||
| if (mindspore::lite::IsSupportFloat16() && | |||||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||||
| auto *kernel = | |||||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | |||||
| << node->name_; | |||||
| return kernel; | |||||
| } | } | ||||
| } | |||||
| if (data_type == kNumberTypeFloat16) { | |||||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||||
| desc.data_type = kNumberTypeFloat32; | |||||
| } | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||||
| if (kernel != nullptr) { | |||||
| return kernel; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
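| // Backend preference, in order: GPU if enabled, then NPU if enabled, then an fp16 CPU | |||||
| // kernel when supported and requested, and finally the fp32 CPU kernel as fallback. | |||||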
| kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *src_node) { | |||||
| MS_ASSERT(src_model_ != nullptr); | |||||
| MS_ASSERT(src_node != nullptr); | |||||
| auto *primitive = src_node->primitive_; | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| if (primitive->Type() != schema::PrimitiveType_Partial) { | |||||
| return nullptr; | |||||
| } | |||||
| auto partial_primitive = reinterpret_cast<lite::Partial *>(primitive); | |||||
| auto sub_graph_index = partial_primitive->GetSubGraphIndex(); | |||||
| std::vector<kernel::LiteKernel *> sub_kernels; | |||||
| auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; | |||||
| return nullptr; | |||||
| } | |||||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front()); | |||||
| // for kernel::LiteKernelUtil::SubgraphInputTensors in CreateSubGraphKernel | |||||
| FindAllInoutKernels(sub_kernels); | |||||
| auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); | |||||
| if (subgraph == nullptr) { | |||||
| MS_LOG(ERROR) << "Create subgraph from partial failed, name: " << src_node->name_; | |||||
| return nullptr; | |||||
| } | |||||
| subgraph->set_name("subgraph_" + src_node->name_); | |||||
| return subgraph; | |||||
| } | |||||
| kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node) { | |||||
| auto *primitive = src_node->primitive_; | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| std::vector<Tensor *> inputs; | |||||
| std::vector<Tensor *> outputs; | |||||
| FindNodeInoutTensors(*src_node, &inputs, &outputs); | |||||
| auto *kernel = this->FindBackendKernel(inputs, outputs, primitive, src_node); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_ | |||||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | |||||
| return nullptr; | |||||
| } | |||||
| SetKernelTensorDataType(kernel); | |||||
| kernel->set_name(src_node->name_); | |||||
| return kernel; | |||||
| } | |||||
| int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels) { | |||||
| MS_ASSERT(src_model_ != nullptr); | |||||
| MS_ASSERT(!src_model_->sub_graphs_.empty()); | |||||
| MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); | |||||
| MS_ASSERT(dst_kernels != nullptr); | |||||
| MS_ASSERT(dst_kernels->empty()); | |||||
| auto subgraph = src_model_->sub_graphs_.at(subgraph_index); | |||||
| for (auto node_index : subgraph->node_indices_) { | |||||
| auto node = src_model_->all_nodes_[node_index]; | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto *primitive = node->primitive_; | auto *primitive = node->primitive_; | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| auto *kernel = this->ScheduleNode(inputs, outputs, primitive, node); | |||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| if (primitive->Type() == schema::PrimitiveType_Partial) { // sub_graph | |||||
| kernel = SchedulePartialToKernel(node); | |||||
| } else { // kernel | |||||
| kernel = ScheduleNodeToKernel(node); | |||||
| } | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << node->name_ << ", type: " | |||||
| MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << node->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| SetKernelTensorDataType(kernel); | |||||
| kernel->set_name(node->name_); | |||||
| kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i))); | |||||
| kernels->emplace_back(kernel); | |||||
| kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index))); | |||||
| dst_kernels->emplace_back(kernel); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -190,6 +294,11 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||||
| MS_ASSERT(head_kernel != nullptr); | MS_ASSERT(head_kernel != nullptr); | ||||
| MS_ASSERT(sinked_kernel_map != nullptr); | MS_ASSERT(sinked_kernel_map != nullptr); | ||||
| std::vector<kernel::LiteKernel *> sub_kernels; | std::vector<kernel::LiteKernel *> sub_kernels; | ||||
| if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) { | |||||
| (*sinked_kernel_map)[head_kernel] = true; | |||||
| sub_kernels.emplace_back(head_kernel); | |||||
| return sub_kernels; | |||||
| } | |||||
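| // Switch and Merge are kept as single-kernel groups so that control-flow boundaries are | |||||
| // never fused into a larger subgraph; their successors are likewise skipped below. | |||||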
| std::queue<kernel::LiteKernel *> kernel_queue; | std::queue<kernel::LiteKernel *> kernel_queue; | ||||
| kernel_queue.emplace(head_kernel); | kernel_queue.emplace(head_kernel); | ||||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | ||||
| @@ -200,6 +309,10 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||||
| sub_kernels.emplace_back(cur_kernel); | sub_kernels.emplace_back(cur_kernel); | ||||
| auto post_kernels = cur_kernel->out_kernels(); | auto post_kernels = cur_kernel->out_kernels(); | ||||
| for (auto post_kernel : post_kernels) { | for (auto post_kernel : post_kernels) { | ||||
| if (post_kernel->subgraph_type() != kernel::kNotSubGraph || post_kernel->Type() == schema::PrimitiveType_Merge || | |||||
| post_kernel->Type() == schema::PrimitiveType_Switch) { | |||||
| continue; | |||||
| } | |||||
| if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { | if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { | ||||
| auto post_kernel_inputs = post_kernel->in_kernels(); | auto post_kernel_inputs = post_kernel->in_kernels(); | ||||
| if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), | if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), | ||||
| @@ -215,28 +328,41 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels( | |||||
| int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { | int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { | ||||
| auto old_kernels = *kernels; | auto old_kernels = *kernels; | ||||
| kernels->clear(); | kernels->clear(); | ||||
| std::map<const kernel::LiteKernel *, bool> is_kernel_sinked; | |||||
| std::map<const kernel::LiteKernel *, bool> is_kernel_finish; | |||||
| for (auto kernel : old_kernels) { | for (auto kernel : old_kernels) { | ||||
| is_kernel_sinked[kernel] = false; | |||||
| is_kernel_finish[kernel] = false; | |||||
| } | } | ||||
| while (true) { | while (true) { | ||||
| auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) { | auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) { | ||||
| auto kernel_inputs = kernel->in_kernels(); | auto kernel_inputs = kernel->in_kernels(); | ||||
| return !is_kernel_sinked[kernel] && | |||||
| std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | |||||
| [&](kernel::LiteKernel *kernel) { return is_kernel_sinked[kernel]; }); | |||||
| if (is_kernel_finish[kernel]) { | |||||
| return false; | |||||
| } | |||||
| // A Merge is schedulable once either of its two input branches has finished; drop | |||||
| // this special case if the Merge op is ever removed. | |||||
| if (kernel->Type() == schema::PrimitiveType_Merge) { | |||||
| MS_ASSERT(kernel->in_kernels().size() == 2); | |||||
| return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]); | |||||
| } else { | |||||
| return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), | |||||
| [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; }); | |||||
| } | |||||
| }); | }); | ||||
| if (head_kernel_iter == old_kernels.end()) { | if (head_kernel_iter == old_kernels.end()) { | ||||
| break; | break; | ||||
| } | } | ||||
| auto head_kernel = *head_kernel_iter; | auto head_kernel = *head_kernel_iter; | ||||
| if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { | |||||
| is_kernel_finish[head_kernel] = true; | |||||
| kernels->emplace_back(head_kernel); | |||||
| continue; | |||||
| } | |||||
| if (head_kernel->desc().arch == mindspore::kernel::kAPU) { | if (head_kernel->desc().arch == mindspore::kernel::kAPU) { | ||||
| MS_LOG(ERROR) << "Not support APU now"; | MS_LOG(ERROR) << "Not support APU now"; | ||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); | ||||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_sinked); | |||||
| auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish); | |||||
| auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); | auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); | ||||
| if (subgraph == nullptr) { | if (subgraph == nullptr) { | ||||
| MS_LOG(ERROR) << "Create SubGraphKernel failed"; | MS_LOG(ERROR) << "Create SubGraphKernel failed"; | ||||
| @@ -296,60 +422,6 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tensors, | |||||
| const std::vector<Tensor *> &out_tensors, | |||||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *node) { | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); | |||||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||||
| #if SUPPORT_NPU | |||||
| if (context_->IsNpuEnabled()) { | |||||
| kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; | |||||
| return kernel; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " | |||||
| << node->name_; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| #if SUPPORT_GPU | |||||
| if (context_->IsGpuEnabled()) { | |||||
| kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; | |||||
| return kernel; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " | |||||
| << node->name_; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| if (mindspore::lite::IsSupportFloat16() && | |||||
| ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { | |||||
| kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; | |||||
| auto *kernel = | |||||
| KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); | |||||
| if (kernel != nullptr) { | |||||
| MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " | |||||
| << node->name_; | |||||
| return kernel; | |||||
| } | |||||
| } | |||||
| if (data_type == kNumberTypeFloat16) { | |||||
| MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; | |||||
| desc.data_type = kNumberTypeFloat32; | |||||
| } | |||||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||||
| if (kernel != nullptr) { | |||||
| return kernel; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) { | TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) { | ||||
| for (const auto &tensor : in_tensors) { | for (const auto &tensor : in_tensors) { | ||||
| auto dtype = tensor->data_type(); | auto dtype = tensor->data_type(); | ||||
| @@ -411,4 +483,11 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel * | |||||
| } | } | ||||
| return kernel::kNotSubGraph; | return kernel::kNotSubGraph; | ||||
| } | } | ||||
| void Scheduler::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||||
| for (auto *kernel : kernels) { | |||||
| MS_ASSERT(kernel != nullptr); | |||||
| kernel->FindInoutKernels(kernels); | |||||
| } | |||||
| } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -17,6 +17,7 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ | #ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ | ||||
| #define MINDSPORE_LITE_SRC_SCHEDULER_H_ | #define MINDSPORE_LITE_SRC_SCHEDULER_H_ | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include "src/sub_graph_kernel.h" | #include "src/sub_graph_kernel.h" | ||||
| @@ -27,30 +28,47 @@ | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class Scheduler { | class Scheduler { | ||||
| public: | public: | ||||
| explicit Scheduler(const InnerContext *ctx) { context_ = const_cast<InnerContext *>(ctx); } | |||||
| Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> src_tensors) | |||||
| : context_(ctx), src_model_(src_model), src_tensors_(std::move(src_tensors)) {} | |||||
| ~Scheduler() = default; | ~Scheduler() = default; | ||||
| int Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, std::vector<kernel::LiteKernel *> *kernels); | |||||
| static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||||
| protected: | |||||
| kernel::LiteKernel *ScheduleNode(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, | |||||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode); | |||||
| int BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors, | |||||
| std::vector<kernel::LiteKernel *> *kernels); | |||||
| static int InferShape(const lite::Model *model, std::vector<Tensor *> *tensors); | |||||
| int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels); | |||||
| private: | |||||
| void FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs, | |||||
| std::vector<Tensor *> *outputs); | |||||
| // infer shape for a partial node | |||||
| int InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt); | |||||
| // infer shape for a node | |||||
| int InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt); | |||||
| // infer shape for a subgraph | |||||
| int InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt); | |||||
| // schedule a node to kernel according to context and kernels registered | |||||
| kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors, | |||||
| const std::vector<Tensor *> &out_tensors, | |||||
| const mindspore::lite::PrimitiveC *primitive, const Model::Node *node); | |||||
| // schedule a partial node to a subgraph_kernel | |||||
| kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); | |||||
| // schedule a node to a kernel | |||||
| kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node); | |||||
| // schedule a Model::SubGraph into a vector of kernel and subgraph_kernel | |||||
| int ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels); | |||||
| // find in_kernels_ and out_kernels_ for every kernel, subgraph kernel, and the nodes_ inside a subgraph | |||||
| static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels); | |||||
| // vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel> | |||||
| int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels); | int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels); | ||||
| // create subgraph_kernel from a vector of kernel | |||||
| kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, | kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, | ||||
| kernel::SubGraphType type); | kernel::SubGraphType type); | ||||
| std::vector<kernel::LiteKernel *> FindAllSubGraphKernels( | std::vector<kernel::LiteKernel *> FindAllSubGraphKernels( | ||||
| kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map); | kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map); | ||||
| // other methods | |||||
| static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors); | static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors); | ||||
| static void SetKernelTensorDataType(kernel::LiteKernel *kernel); | static void SetKernelTensorDataType(kernel::LiteKernel *kernel); | ||||
| @@ -58,7 +76,10 @@ class Scheduler { | |||||
| static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel); | static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel); | ||||
| protected: | protected: | ||||
| InnerContext *context_ = nullptr; | |||||
| const InnerContext *context_ = nullptr; | |||||
| Model *src_model_ = nullptr; | |||||
| std::vector<Tensor *> src_tensors_; | |||||
| std::vector<size_t> graph_output_node_indexes_; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
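| // A minimal sketch of the new call pattern (caller wiring assumed; the real call site | |||||
| // in LiteSession is not part of this diff): | |||||
| //   Scheduler scheduler(context, model, tensors);  // tensors is a std::vector<Tensor *> | |||||
| //   std::vector<kernel::LiteKernel *> kernels; | |||||
| //   if (scheduler.Schedule(&kernels) != RET_OK) { /* handle the error */ } | |||||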
| @@ -149,6 +149,12 @@ int SubGraphKernel::ReSize(bool is_interrupt) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void SubGraphKernel::InitOutTensorInitRefCount() { | |||||
| for (auto *node : nodes_) { | |||||
| node->InitOutTensorInitRefCount(); | |||||
| } | |||||
| } | |||||
| int CpuSubGraph::Prepare() { | int CpuSubGraph::Prepare() { | ||||
| auto ret = SubGraphKernel::Prepare(); | auto ret = SubGraphKernel::Prepare(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -84,6 +84,8 @@ class SubGraphKernel : public LiteKernel { | |||||
| int ReSize(bool is_interrupt); | int ReSize(bool is_interrupt); | ||||
| void InitOutTensorInitRefCount() override; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::vector<LiteKernel *> nodes() { return this->nodes_; } | std::vector<LiteKernel *> nodes() { return this->nodes_; } | ||||
| @@ -104,11 +106,10 @@ class CpuSubGraph : public SubGraphKernel { | |||||
| const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) | const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) | ||||
| : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { | : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { | ||||
| subgraph_type_ = kCpuFP32SubGraph; | subgraph_type_ = kCpuFP32SubGraph; | ||||
| this->executor_ = new (std::nothrow) mindspore::lite::Executor; | |||||
| this->executor_ = new (std::nothrow) mindspore::lite::CpuExecutor; | |||||
| } | } | ||||
| ~CpuSubGraph() override { delete this->executor_; } | ~CpuSubGraph() override { delete this->executor_; } | ||||
| int Prepare() override; | int Prepare() override; | ||||
| int Init() override { return SubGraphKernel::Init(); } | int Init() override { return SubGraphKernel::Init(); } | ||||
| int PreProcess() override { return SubGraphKernel::PreProcess(); } | int PreProcess() override { return SubGraphKernel::PreProcess(); } | ||||
| @@ -110,8 +110,14 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| size_t ref_count() const { return this->ref_count_; } | size_t ref_count() const { return this->ref_count_; } | ||||
| size_t init_ref_count() const { return this->init_ref_count_; } | |||||
| void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; } | void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; } | ||||
| void set_init_ref_count(size_t ref_count) { this->init_ref_count_ = ref_count; } | |||||
| void ResetRefCount() { this->ref_count_ = this->init_ref_count_; } | |||||
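| // init_ref_count_ holds the consumer count seeded at schedule time (see | |||||
| // InitTensorInitRefCount); ResetRefCount() restores ref_count_ to it so tensors inside | |||||
| // loop bodies re-entered through Merge/Switch can be consumed again next iteration. | |||||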
| void DecRefCount() { this->ref_count_--; } | void DecRefCount() { this->ref_count_--; } | ||||
| std::string ToString() const; | std::string ToString() const; | ||||
| @@ -156,6 +162,8 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| schema::Format format_; | schema::Format format_; | ||||
| Category category_; | Category category_; | ||||
| size_t ref_count_ = 0; | size_t ref_count_ = 0; | ||||
| size_t init_ref_count_ = 0; | |||||
| size_t ready_count_ = 0; | |||||
| std::vector<QuantArg> quant_params_; | std::vector<QuantArg> quant_params_; | ||||
| std::vector<float> quant_clusters_; | std::vector<float> quant_clusters_; | ||||
| mindspore::lite::Allocator *allocator_ = nullptr; | mindspore::lite::Allocator *allocator_ = nullptr; | ||||
| @@ -128,7 +128,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a | |||||
| return lite::RET_NULL_PTR; | return lite::RET_NULL_PTR; | ||||
| } | } | ||||
| auto run_kernel = (train_mode_) ? train_kernels_ : inference_kernels_; | auto run_kernel = (train_mode_) ? train_kernels_ : inference_kernels_; | ||||
| lite::Executor executor; | |||||
| lite::CpuExecutor executor; | |||||
| if (before == nullptr && after == nullptr) { | if (before == nullptr && after == nullptr) { | ||||
| return executor.Run(this->inputs_, this->outputs_, run_kernel, this->context_->allocator.get()); | return executor.Run(this->inputs_, this->outputs_, run_kernel, this->context_->allocator.get()); | ||||
| } else { | } else { | ||||
| @@ -261,6 +261,8 @@ if (ENABLE_CONVERTER) | |||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_SRC} | ${TEST_SRC} | ||||
| ${TEST_DIR}/st/converter_test.cc | ${TEST_DIR}/st/converter_test.cc | ||||
| ${TEST_DIR}/st/control_flow_test.cc | |||||
| ${TEST_DIR}/st/sub_graph_test.cc | |||||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc | ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc | ||||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc | ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc | ||||
| ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc | ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc | ||||
| @@ -300,7 +302,7 @@ endif () | |||||
| add_executable(lite-test ${TEST_SRC}) | add_executable(lite-test ${TEST_SRC}) | ||||
| add_dependencies(lite-test fbs_src) | |||||
| target_link_libraries(lite-test dl ${GTEST_LIBRARY}) | target_link_libraries(lite-test dl ${GTEST_LIBRARY}) | ||||
| if (PLATFORM_ARM64) | if (PLATFORM_ARM64) | ||||
| target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | ||||
| @@ -321,6 +323,7 @@ if (SUPPORT_NPU) | |||||
| target_link_libraries(lite-test npu_kernel_mid) | target_link_libraries(lite-test npu_kernel_mid) | ||||
| endif () | endif () | ||||
| if (ENABLE_CONVERTER) | if (ENABLE_CONVERTER) | ||||
| add_dependencies(lite-test fbs_inner_src) | |||||
| target_link_libraries(lite-test | target_link_libraries(lite-test | ||||
| anf_importer_mid | anf_importer_mid | ||||
| anf_exporter_mid | anf_exporter_mid | ||||
| @@ -1,3 +1,3 @@ | |||||
| mobilenet.tflite 0.5 | mobilenet.tflite 0.5 | ||||
| transformer_20200831_encoder_fp32.tflite 68 | |||||
| transformer_20200831_encoder_fp32.tflite 69 | |||||
| transformer_20200831_decoder_fp32.tflite 35 | transformer_20200831_decoder_fp32.tflite 35 | ||||
| @@ -0,0 +1,459 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/lite_session.h" | |||||
| #include "include/version.h" | |||||
| namespace mindspore { | |||||
| class ControlFlowTest : public mindspore::CommonTest { | |||||
| public: | |||||
| ControlFlowTest() {} | |||||
| }; | |||||
| TEST_F(ControlFlowTest, TestMergeWhileModel) { | |||||
| // make graph | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| MS_LOG(DEBUG) << "make subgraph"; | |||||
| meta_graph->name = "graph"; | |||||
| meta_graph->version = lite::Version(); | |||||
| meta_graph->inputIndex = {0}; | |||||
| meta_graph->outputIndex = {9}; | |||||
| // subgraph 0 : main graph | |||||
| auto sub_graph_0 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_0->name = "main_graph"; | |||||
| // subgraph 1 : cond graph | |||||
| auto sub_graph_1 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_1->name = "cond_graph"; | |||||
| // subgraph 2: body graph | |||||
| auto sub_graph_2 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_2->name = "body_graph"; | |||||
| MS_LOG(DEBUG) << "make subgraph"; | |||||
| // subgraph 0: node 0 before-add-1 | |||||
| auto sub_graph_0_node_0 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_0->inputIndex = {0, 1}; | |||||
| sub_graph_0_node_0->outputIndex = {2}; | |||||
| sub_graph_0_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_0_node_0 = new schema::AddT; | |||||
| primitive_sub_graph_0_node_0->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| sub_graph_0_node_0->primitive->value.value = primitive_sub_graph_0_node_0; | |||||
| sub_graph_0_node_0->name = "before_Add_1"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_0)); | |||||
| sub_graph_0->nodeIndices.push_back(0); | |||||
| MS_LOG(DEBUG) << "node 0"; | |||||
| // subgraph 0: node 1 before-add-2 | |||||
| auto sub_graph_0_node_1 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_1->inputIndex = {2, 3}; | |||||
| sub_graph_0_node_1->outputIndex = {4}; | |||||
| sub_graph_0_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_1->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_0_node_1 = new schema::AddT; | |||||
| primitive_sub_graph_0_node_1->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| sub_graph_0_node_1->primitive->value.value = primitive_sub_graph_0_node_1; | |||||
| sub_graph_0_node_1->name = "before_Add_2"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_1)); | |||||
| sub_graph_0->nodeIndices.push_back(1); | |||||
| MS_LOG(DEBUG) << "node 1"; | |||||
| // subgraph 0: node 2 merge | |||||
| auto sub_graph_0_node_2 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_2->inputIndex = {4, 17}; | |||||
| sub_graph_0_node_2->outputIndex = {16}; | |||||
| sub_graph_0_node_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_2->primitive->value.type = schema::PrimitiveType_Merge; | |||||
| auto primitive_sub_graph_0_node_2 = new schema::MergeT; | |||||
| sub_graph_0_node_2->primitive->value.value = primitive_sub_graph_0_node_2; | |||||
| sub_graph_0_node_2->name = "merge"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_2)); | |||||
| sub_graph_0->nodeIndices.push_back(2); | |||||
| MS_LOG(DEBUG) << "node 2"; | |||||
| // subgraph 0: node 3 partial cond subGraph | |||||
| auto sub_graph_0_node_3 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_3->inputIndex = {16}; | |||||
| sub_graph_0_node_3->outputIndex = {5}; // 5 : bool | |||||
| sub_graph_0_node_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_3->primitive->value.type = schema::PrimitiveType_Partial; | |||||
| auto primitive_sub_graph_0_node_3 = new schema::PartialT; | |||||
| primitive_sub_graph_0_node_3->subGraphIndex = 1; | |||||
| sub_graph_0_node_3->primitive->value.value = primitive_sub_graph_0_node_3; | |||||
| sub_graph_0_node_3->name = "Partial_cond"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_3)); | |||||
| sub_graph_0->nodeIndices.push_back(3); | |||||
| MS_LOG(DEBUG) << "node 2"; | |||||
| // subgraph 0: node 4 switch | |||||
| auto sub_graph_0_node_4 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_4->inputIndex = {5, 16}; // 5 : bool; 16 data | |||||
| sub_graph_0_node_4->outputIndex = {6, 7}; | |||||
| sub_graph_0_node_4->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_4->primitive->value.type = schema::PrimitiveType_Switch; | |||||
| auto primitive_sub_graph_0_node_4 = new schema::SwitchT; | |||||
| sub_graph_0_node_4->primitive->value.value = primitive_sub_graph_0_node_4; | |||||
| sub_graph_0_node_4->name = "Switch"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_4)); | |||||
| sub_graph_0->nodeIndices.push_back(4); | |||||
| MS_LOG(DEBUG) << "node 4"; | |||||
| // subgraph 0: node 5 partial body subgraph | |||||
| auto sub_graph_0_node_5 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_5->inputIndex = {6}; | |||||
| sub_graph_0_node_5->outputIndex = {17}; | |||||
| sub_graph_0_node_5->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_5->primitive->value.type = schema::PrimitiveType_Partial; | |||||
| auto primitive_sub_graph_0_node_5 = new schema::PartialT; | |||||
| primitive_sub_graph_0_node_5->subGraphIndex = 2; | |||||
| sub_graph_0_node_5->primitive->value.value = primitive_sub_graph_0_node_5; | |||||
| sub_graph_0_node_5->name = "Partial_body"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_5)); | |||||
| sub_graph_0->nodeIndices.push_back(5); | |||||
| MS_LOG(DEBUG) << "node 5"; | |||||
| // subgraph 0: node 6 add-after | |||||
| auto sub_graph_0_node_6 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_0_node_6->inputIndex = {7, 8}; | |||||
| sub_graph_0_node_6->outputIndex = {9}; | |||||
| sub_graph_0_node_6->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_0_node_6->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_0_node_6 = new schema::AddT; | |||||
| sub_graph_0_node_6->primitive->value.value = primitive_sub_graph_0_node_6; | |||||
| sub_graph_0_node_6->name = "Add-after"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_6)); | |||||
| sub_graph_0->nodeIndices.push_back(6); | |||||
| MS_LOG(DEBUG) << "node 6"; | |||||
| sub_graph_0->inputIndices = {0}; | |||||
| sub_graph_0->outputIndices = {9}; | |||||
| sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17}; | |||||
| meta_graph->subGraph.push_back(std::move(sub_graph_0)); | |||||
| // subgraph 1: node 0 cond_add | |||||
| auto sub_graph_1_node_0 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_1_node_0->inputIndex = {16, 10}; | |||||
| sub_graph_1_node_0->outputIndex = {11}; | |||||
| sub_graph_1_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_1_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_1_node_0 = new schema::AddT; | |||||
| sub_graph_1_node_0->primitive->value.value = primitive_sub_graph_1_node_0; | |||||
| sub_graph_1_node_0->name = "cond_add"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_0)); | |||||
| sub_graph_1->nodeIndices.push_back(7); | |||||
| MS_LOG(DEBUG) << "node 6"; | |||||
| // subgraph 1 ; node:1 Less cond | |||||
| auto sub_graph_1_node_1 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_1_node_1->inputIndex = {11, 12}; | |||||
| sub_graph_1_node_1->outputIndex = {5}; | |||||
| sub_graph_1_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_1_node_1->primitive->value.type = schema::PrimitiveType_Less; | |||||
| auto primitive_sub_graph_1_node_1 = new schema::LessT; | |||||
| sub_graph_1_node_1->primitive->value.value = primitive_sub_graph_1_node_1; | |||||
| sub_graph_1_node_1->name = "cond_Less"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_1)); | |||||
| sub_graph_1->nodeIndices.push_back(8); | |||||
| MS_LOG(DEBUG) << "node 7"; | |||||
| sub_graph_1->inputIndices = {16}; | |||||
| sub_graph_1->outputIndices = {5}; | |||||
| sub_graph_1->tensorIndices = {16, 10, 11, 12, 5}; | |||||
| meta_graph->subGraph.push_back(std::move(sub_graph_1)); | |||||
| // subgraph 2: node 0 body_add_1 | |||||
| auto sub_graph_2_node_0 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_2_node_0->inputIndex = {6, 13}; | |||||
| sub_graph_2_node_0->outputIndex = {14}; | |||||
| sub_graph_2_node_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_2_node_0->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_2_node_0 = new schema::AddT; | |||||
| sub_graph_2_node_0->primitive->value.value = primitive_sub_graph_2_node_0; | |||||
| sub_graph_2_node_0->name = "body_add_1"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_0)); | |||||
| sub_graph_2->nodeIndices.push_back(9); | |||||
| MS_LOG(DEBUG) << "node 8"; | |||||
| // subgraph 2 ; node:1 body add-2 | |||||
| auto sub_graph_2_node_1 = std::make_unique<schema::CNodeT>(); | |||||
| sub_graph_2_node_1->inputIndex = {14, 15}; | |||||
| sub_graph_2_node_1->outputIndex = {17}; | |||||
| sub_graph_2_node_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| sub_graph_2_node_1->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto primitive_sub_graph_2_node_1 = new schema::AddT; | |||||
| sub_graph_2_node_1->primitive->value.value = primitive_sub_graph_2_node_1; | |||||
| sub_graph_2_node_1->name = "body_add_2"; | |||||
| meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_1)); | |||||
| sub_graph_2->nodeIndices.push_back(10); | |||||
| MS_LOG(DEBUG) << "node 9"; | |||||
| sub_graph_2->inputIndices = {6}; | |||||
| sub_graph_2->outputIndices = {17}; | |||||
| sub_graph_2->tensorIndices = {13, 14, 15, 6, 17}; | |||||
| meta_graph->subGraph.push_back(std::move(sub_graph_2)); | |||||
| // ------- tensor --------- | |||||
| // tensor: 0 before-add input0 <main graph input> | |||||
| auto tensor_0 = std::make_unique<schema::TensorT>(); | |||||
| tensor_0->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_0->format = schema::Format_NHWC; | |||||
| tensor_0->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_0->dims = {1}; | |||||
| tensor_0->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_0)); | |||||
| MS_LOG(DEBUG) << "tensor 0"; | |||||
| // tensor: 1 before-add input1 <const> | |||||
| auto tensor_1 = std::make_unique<schema::TensorT>(); | |||||
| tensor_1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_1->format = schema::Format_NHWC; | |||||
| tensor_1->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_1->dims = {1}; | |||||
| tensor_1->data.resize(sizeof(float) * 1); | |||||
| float input1_data[] = {1}; | |||||
| memcpy(tensor_1->data.data(), input1_data, sizeof(float) * 1); | |||||
| tensor_1->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_1)); | |||||
| MS_LOG(DEBUG) << "tensor 1"; | |||||
| // tensor: 2 before-add output / before_Add_2 input | |||||
| auto tensor_2 = std::make_unique<schema::TensorT>(); | |||||
| tensor_2->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_2->format = schema::Format_NHWC; | |||||
| tensor_2->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_2->dims = {1}; | |||||
| tensor_2->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_2)); | |||||
| MS_LOG(DEBUG) << "tensor 2"; | |||||
| // tensor: 3 before_Add_2 input <const> | |||||
| auto tensor_3 = std::make_unique<schema::TensorT>(); | |||||
| tensor_3->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_3->format = schema::Format_NHWC; | |||||
| tensor_3->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_3->dims = {1}; | |||||
| tensor_3->data.resize(sizeof(float) * 1); | |||||
| float tensor_3_data[] = {1}; | |||||
| memcpy(tensor_3->data.data(), tensor_3_data, sizeof(float) * 1); | |||||
| tensor_3->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_3)); | |||||
| MS_LOG(DEBUG) << "tensor 3"; | |||||
| auto tensor_4 = std::make_unique<schema::TensorT>(); | |||||
| tensor_4->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_4->format = schema::Format_NHWC; | |||||
| tensor_4->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_4->dims = {1}; | |||||
| tensor_4->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_4)); | |||||
| MS_LOG(DEBUG) << "tensor 4"; | |||||
| // tensor: 5 Partial_cond output <bool> | |||||
| auto tensor_5 = std::make_unique<schema::TensorT>(); | |||||
| tensor_5->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_5->format = schema::Format_NHWC; | |||||
| tensor_5->dataType = TypeId::kNumberTypeBool; | |||||
| tensor_5->dims = {1}; | |||||
| tensor_5->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_5)); | |||||
| MS_LOG(DEBUG) << "tensor_4"; | |||||
| // tensor: 6 switch true output | |||||
| auto tensor_6 = std::make_unique<schema::TensorT>(); | |||||
| tensor_6->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_6->format = schema::Format_NHWC; | |||||
| tensor_6->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_6->dims = {1}; | |||||
| tensor_6->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_6)); | |||||
| MS_LOG(DEBUG) << "tensor 6"; | |||||
| // tensor: 7 switch false output | |||||
| auto tensor_7 = std::make_unique<schema::TensorT>(); | |||||
| tensor_7->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_7->format = schema::Format_NHWC; | |||||
| tensor_7->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_7->dims = {1}; | |||||
| tensor_7->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_7)); | |||||
| MS_LOG(DEBUG) << "tensor_7"; | |||||
| // tensor: 6 body-add input ,other input is switch true output | |||||
| auto tensor_8 = std::make_unique<schema::TensorT>(); | |||||
| tensor_8->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_8->format = schema::Format_NHWC; | |||||
| tensor_8->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_8->dims = {1}; | |||||
| tensor_8->data.resize(sizeof(float) * 1); | |||||
| float tensor_8_data[] = {10}; | |||||
| memcpy(tensor_8->data.data(), tensor_8_data, sizeof(float) * 1); | |||||
| tensor_8->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_8)); | |||||
| MS_LOG(DEBUG) << "tensor_8"; | |||||
| auto tensor_9 = std::make_unique<schema::TensorT>(); | |||||
| tensor_9->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_9->format = schema::Format_NHWC; | |||||
| tensor_9->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_9->dims = {1}; | |||||
| tensor_9->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_9)); | |||||
| MS_LOG(DEBUG) << "tensor_9"; | |||||
| // tensor: 7 after-add input ,other input is switch false output | |||||
| auto tensor_10 = std::make_unique<schema::TensorT>(); | |||||
| tensor_10->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_10->format = schema::Format_NHWC; | |||||
| tensor_10->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_10->dims = {1}; | |||||
| tensor_10->data.resize(sizeof(float) * 1); | |||||
| float tensor_10_data[] = {1}; | |||||
| memcpy(tensor_10->data.data(), tensor_10_data, sizeof(float) * 1); | |||||
| tensor_10->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_10)); | |||||
| MS_LOG(DEBUG) << "tensor_10"; | |||||
| // tensor: 8 main graph output | |||||
| auto tensor_11 = std::make_unique<schema::TensorT>(); | |||||
| tensor_11->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_11->format = schema::Format_NHWC; | |||||
| tensor_11->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_11->dims = {1}; | |||||
| tensor_11->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_11)); | |||||
| MS_LOG(DEBUG) << "tensor 11"; | |||||
| // tensor: 12 cond_Less input <const>; the other input is tensor 11 | |||||
| auto tensor_12 = std::make_unique<schema::TensorT>(); | |||||
| tensor_12->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_12->format = schema::Format_NHWC; | |||||
| tensor_12->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_12->dims = {1}; | |||||
| tensor_12->data.resize(sizeof(float) * 1); | |||||
| float tensor_12_data[] = {10}; | |||||
| memcpy(tensor_12->data.data(), tensor_12_data, sizeof(float) * 1); | |||||
| tensor_12->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_12)); | |||||
| MS_LOG(DEBUG) << "tensor_12"; | |||||
| auto tensor_13 = std::make_unique<schema::TensorT>(); | |||||
| tensor_13->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_13->format = schema::Format_NHWC; | |||||
| tensor_13->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_13->dims = {1}; | |||||
| tensor_13->data.resize(sizeof(float) * 1); | |||||
| float tensor_13_data[] = {1}; | |||||
| memcpy(tensor_13->data.data(), tensor_13_data, sizeof(float) * 1); | |||||
| tensor_13->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_13)); | |||||
| MS_LOG(DEBUG) << "tensor_13"; | |||||
| auto tensor_14 = std::make_unique<schema::TensorT>(); | |||||
| tensor_14->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_14->format = schema::Format_NHWC; | |||||
| tensor_14->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_14->dims = {1}; | |||||
| tensor_14->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_14)); | |||||
| MS_LOG(DEBUG) << "tensor 14"; | |||||
| auto tensor_15 = std::make_unique<schema::TensorT>(); | |||||
| tensor_15->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_15->format = schema::Format_NHWC; | |||||
| tensor_15->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_15->dims = {1}; | |||||
| tensor_15->data.resize(sizeof(float) * 1); | |||||
| float tensor_15_data[] = {1}; | |||||
| memcpy(tensor_15->data.data(), tensor_15_data, sizeof(float) * 1); | |||||
| tensor_15->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_15)); | |||||
| MS_LOG(DEBUG) << "tensor_15"; | |||||
| auto tensor_16 = std::make_unique<schema::TensorT>(); | |||||
| tensor_16->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_16->format = schema::Format_NHWC; | |||||
| tensor_16->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_16->dims = {1}; | |||||
| tensor_16->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_16)); | |||||
| MS_LOG(DEBUG) << "tensor_16"; | |||||
| auto tensor_17 = std::make_unique<schema::TensorT>(); | |||||
| tensor_17->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_17->format = schema::Format_NHWC; | |||||
| tensor_17->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_17->dims = {1}; | |||||
| tensor_17->offset = -1; | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_17)); | |||||
| MS_LOG(DEBUG) << "tensor_17"; | |||||
| // ----------------------------------------------------------------------- | |||||
| flatbuffers::FlatBufferBuilder builder(1024); | |||||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); | |||||
| builder.Finish(offset); | |||||
| schema::FinishMetaGraphBuffer(builder, offset); | |||||
| size_t size = builder.GetSize(); | |||||
| const char *content = reinterpret_cast<char *>(builder.GetBufferPointer()); | |||||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(content, size)); | |||||
| ASSERT_NE(model, nullptr); | |||||
| lite::Context context; | |||||
| context.thread_num_ = 2; | |||||
| auto &cpu_device_ctx = context.device_list_[0]; | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = false; | |||||
| auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context)); | |||||
| ASSERT_NE(session, nullptr); | |||||
| auto ret = session->CompileGraph(model.get()); | |||||
| ASSERT_EQ(ret, lite::RET_OK); | |||||
| model->Free(); | |||||
| auto inputs = session->GetInputs(); | |||||
| ASSERT_EQ(inputs.size(), 1); | |||||
| auto input = inputs.front(); | |||||
| ASSERT_NE(input, nullptr); | |||||
| ASSERT_EQ(input->data_type(), kNumberTypeFloat32); | |||||
| ASSERT_EQ(input->shape().size(), 1); | |||||
| ASSERT_EQ(input->shape().at(0), 1); | |||||
| auto in_data = reinterpret_cast<float *>(input->MutableData()); | |||||
| ASSERT_NE(in_data, nullptr); | |||||
| in_data[0] = 1; | |||||
| ret = session->RunGraph(); | |||||
| ASSERT_EQ(ret, lite::RET_OK); | |||||
| auto outputs = session->GetOutputs(); | |||||
| ASSERT_EQ(outputs.size(), 1); | |||||
| auto output = outputs.begin()->second; | |||||
| ASSERT_NE(output, nullptr); | |||||
| ASSERT_EQ(output->data_type(), kNumberTypeFloat32); | |||||
| ASSERT_EQ(output->shape().size(), 1); | |||||
| ASSERT_EQ(output->shape().at(0), 1); | |||||
| auto out_data = reinterpret_cast<float *>(output->MutableData()); | |||||
| ASSERT_NE(out_data, nullptr); | |||||
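| // expected value: 1 (input) + 1 + 1 = 3 before the loop; the body adds 2 per iteration while | |||||
| // (x + 1) < 10, giving 3 -> 5 -> 7 -> 9; the false branch then adds 10, so 9 + 10 = 19 | |||||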
| ASSERT_EQ(out_data[0], 19); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,217 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "mindspore/lite/include/model.h" | |||||
| #include "common/common_test.h" | |||||
| #include "include/lite_session.h" | |||||
| #include "include/context.h" | |||||
| #include "include/model.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/lite_session.h" | |||||
| #include "src/runtime/parallel_executor.h" | |||||
| #include "tools/common/storage.h" | |||||
| #include "include/version.h" | |||||
| namespace mindspore { | |||||
| class SubGraphTest : public mindspore::CommonTest { | |||||
| public: | |||||
| SubGraphTest() {} | |||||
| }; | |||||
| TEST_F(SubGraphTest, RecursiveSubGraphTest) { | |||||
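| // graph under test: the main graph runs Add0, then Partial1 and Partial2 call into | |||||
| // sub_graph_1 and sub_graph_2; each of those runs its own Add and invokes sub_graph_3 (Add3) | |||||
| // through the shared Partial3 node, so subgraph calls are nested two levels deep | |||||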
| // nodes add_0, partial_1, partial_2, partial_3; tensors 0, 1, 2 | |||||
| auto add_0 = std::make_unique<schema::CNodeT>(); | |||||
| add_0->inputIndex = {0, 1}; | |||||
| add_0->outputIndex = {2}; | |||||
| add_0->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| add_0->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto add_0_prim = new schema::AddT; | |||||
| add_0_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| add_0->primitive->value.value = add_0_prim; | |||||
| add_0->name = "Add0"; | |||||
| auto partial_1 = std::make_unique<schema::CNodeT>(); | |||||
| partial_1->inputIndex = {2}; | |||||
| partial_1->outputIndex = {7}; | |||||
| partial_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| partial_1->primitive->value.type = schema::PrimitiveType_Partial; | |||||
| auto partial_1_prim = new schema::PartialT; | |||||
| partial_1_prim->subGraphIndex = 1; | |||||
| partial_1->primitive->value.value = partial_1_prim; | |||||
| partial_1->name = "Partial1"; | |||||
| auto partial_2 = std::make_unique<schema::CNodeT>(); | |||||
| partial_2->inputIndex = {2}; | |||||
| partial_2->outputIndex = {7}; | |||||
| partial_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| partial_2->primitive->value.type = schema::PrimitiveType_Partial; | |||||
| auto partial_2_prim = new schema::PartialT; | |||||
| partial_2_prim->subGraphIndex = 2; | |||||
| partial_2->primitive->value.value = partial_2_prim; | |||||
| partial_2->name = "Partial2"; | |||||
| auto partial_3 = std::make_unique<schema::CNodeT>(); | |||||
| partial_3->inputIndex = {4, 6}; | |||||
| partial_3->outputIndex = {7}; | |||||
| partial_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| partial_3->primitive->value.type = schema::PrimitiveType_Partial; | |||||
| auto partial_3_prim = new schema::PartialT; | |||||
| partial_3_prim->subGraphIndex = 3; | |||||
| partial_3->primitive->value.value = partial_3_prim; | |||||
| partial_3->name = "Partial3"; | |||||
| auto tensor_0 = std::make_unique<schema::TensorT>(); | |||||
| tensor_0->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_0->format = schema::Format_NHWC; | |||||
| tensor_0->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_0->dims = {1, 2}; | |||||
| auto tensor_1 = std::make_unique<schema::TensorT>(); | |||||
| tensor_1->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_1->format = schema::Format_NHWC; | |||||
| tensor_1->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_1->dims = {1, 2}; | |||||
| auto tensor_2 = std::make_unique<schema::TensorT>(); | |||||
| tensor_2->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_2->format = schema::Format_NHWC; | |||||
| tensor_2->dataType = TypeId::kNumberTypeFloat32; | |||||
| auto sub_graph_0 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_0->name = "main_graph"; | |||||
| sub_graph_0->inputIndices = {0}; | |||||
| sub_graph_0->outputIndices = {7}; | |||||
| sub_graph_0->nodeIndices = {0, 1, 2}; | |||||
| sub_graph_0->tensorIndices = {0, 1, 2, 7}; | |||||
| // node add_1; tensors 3, 4 | |||||
| auto add_1 = std::make_unique<schema::CNodeT>(); | |||||
| add_1->inputIndex = {2, 3}; | |||||
| add_1->outputIndex = {4}; | |||||
| add_1->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| add_1->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto add_1_prim = new schema::AddT; | |||||
| add_1_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| add_1->primitive->value.value = add_1_prim; | |||||
| add_1->name = "Add1"; | |||||
| auto tensor_3 = std::make_unique<schema::TensorT>(); | |||||
| tensor_3->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_3->format = schema::Format_NHWC; | |||||
| tensor_3->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_3->dims = {1, 2}; | |||||
| auto tensor_4 = std::make_unique<schema::TensorT>(); | |||||
| tensor_4->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_4->format = schema::Format_NHWC; | |||||
| tensor_4->dataType = TypeId::kNumberTypeFloat32; | |||||
| auto sub_graph_1 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_1->name = "sub_graph_1"; | |||||
| sub_graph_1->inputIndices = {2}; | |||||
| sub_graph_1->outputIndices = {7}; | |||||
| sub_graph_1->nodeIndices = {4, 3}; | |||||
| sub_graph_1->tensorIndices = {2, 3, 4, 7}; | |||||
| // node add_2; tensors 5, 6 | |||||
| auto add_2 = std::make_unique<schema::CNodeT>(); | |||||
| add_2->inputIndex = {2, 5}; | |||||
| add_2->outputIndex = {6}; | |||||
| add_2->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| add_2->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto add_2_prim = new schema::AddT; | |||||
| add_2_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| add_2->primitive->value.value = add_2_prim; | |||||
| add_2->name = "Add2"; | |||||
| auto tensor_5 = std::make_unique<schema::TensorT>(); | |||||
| tensor_5->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| tensor_5->format = schema::Format_NHWC; | |||||
| tensor_5->dataType = TypeId::kNumberTypeFloat32; | |||||
| tensor_5->dims = {1, 2}; | |||||
| auto tensor_6 = std::make_unique<schema::TensorT>(); | |||||
| tensor_6->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_6->format = schema::Format_NHWC; | |||||
| tensor_6->dataType = TypeId::kNumberTypeFloat32; | |||||
| auto sub_graph_2 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_2->name = "sub_graph_2"; | |||||
| sub_graph_2->inputIndices = {2}; | |||||
| sub_graph_2->outputIndices = {7}; | |||||
| sub_graph_2->nodeIndices = {5, 3}; | |||||
| sub_graph_2->tensorIndices = {2, 5, 6, 7}; | |||||
| // node add_3; tensor 7 | |||||
| auto add_3 = std::make_unique<schema::CNodeT>(); | |||||
| add_3->inputIndex = {4, 6}; | |||||
| add_3->outputIndex = {7}; | |||||
| add_3->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| add_3->primitive->value.type = schema::PrimitiveType_Add; | |||||
| auto add_3_prim = new schema::AddT; | |||||
| add_3_prim->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| add_3->primitive->value.value = add_3_prim; | |||||
| add_3->name = "Add3"; | |||||
| auto tensor_7 = std::make_unique<schema::TensorT>(); | |||||
| tensor_7->nodeType = schema::NodeType::NodeType_Parameter; | |||||
| tensor_7->format = schema::Format_NHWC; | |||||
| tensor_7->dataType = TypeId::kNumberTypeFloat32; | |||||
| auto sub_graph_3 = std::make_unique<schema::SubGraphT>(); | |||||
| sub_graph_3->name = "sub_graph_3"; | |||||
| sub_graph_3->inputIndices = {4, 6}; | |||||
| sub_graph_3->outputIndices = {7}; | |||||
| sub_graph_3->nodeIndices = {6}; | |||||
| sub_graph_3->tensorIndices = {4, 6, 7}; | |||||
| // make graph | |||||
| auto meta_graph = std::make_shared<schema::MetaGraphT>(); | |||||
| meta_graph->name = "graph"; | |||||
| meta_graph->nodes.emplace_back(std::move(add_0)); | |||||
| meta_graph->nodes.emplace_back(std::move(partial_1)); | |||||
| meta_graph->nodes.emplace_back(std::move(partial_2)); | |||||
| meta_graph->nodes.emplace_back(std::move(partial_3)); | |||||
| meta_graph->nodes.emplace_back(std::move(add_1)); | |||||
| meta_graph->nodes.emplace_back(std::move(add_2)); | |||||
| meta_graph->nodes.emplace_back(std::move(add_3)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_0)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_1)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_2)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_3)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_4)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_5)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_6)); | |||||
| meta_graph->allTensors.emplace_back(std::move(tensor_7)); | |||||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_0)); | |||||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_1)); | |||||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_2)); | |||||
| meta_graph->subGraph.emplace_back(std::move(sub_graph_3)); | |||||
| meta_graph->version = lite::Version(); | |||||
| // ----------------------------------------------------------------------- | |||||
| // use a relative output path so the test does not depend on a developer-specific directory | |||||
| lite::Storage::Save(*meta_graph, "./recursive_subgraph"); | |||||
| // ----------------------------------------------------------------------- | |||||
| size_t size = 0; | |||||
| char *graph_buf = lite::ReadFile("./recursive_subgraph.ms", &size); | |||||
| ASSERT_NE(graph_buf, nullptr); | |||||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size)); | |||||
| ASSERT_NE(model, nullptr); | |||||
| delete[](graph_buf); | |||||
| lite::Context context; | |||||
| auto &cpu_device_ctx = context.device_list_[0]; | |||||
| cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; | |||||
| context.thread_num_ = 2; | |||||
| auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context)); | |||||
| ASSERT_NE(session, nullptr); | |||||
| auto ret = session->CompileGraph(model.get()); | |||||
| ASSERT_EQ(ret, lite::RET_OK); | |||||
| auto inputs = session->GetInputs(); | |||||
| for (auto *input : inputs) { | |||||
| (void)input->MutableData(); | |||||
| } | |||||
| ret = session->RunGraph(); | |||||
| ASSERT_EQ(ret, lite::RET_OK); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -142,7 +142,6 @@ add_executable(converter_lite | |||||
| ${KERNEL_SRC} | ${KERNEL_SRC} | ||||
| ${LITE_SRC} | ${LITE_SRC} | ||||
| ) | ) | ||||
| add_dependencies(converter_lite tflite_fbs_src) | |||||
| add_dependencies(converter_lite fbs_src) | add_dependencies(converter_lite fbs_src) | ||||
| add_dependencies(converter_lite fbs_inner_src) | add_dependencies(converter_lite fbs_inner_src) | ||||
| @@ -5,4 +5,5 @@ set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID | |||||
| add_library(tflite_parser_mid OBJECT | add_library(tflite_parser_mid OBJECT | ||||
| ${TFLITE_SRC_LIST} | ${TFLITE_SRC_LIST} | ||||
| ) | ) | ||||
| add_dependencies(tflite_parser_mid tflite_fbs_src) | |||||
| target_link_libraries(tflite_parser_mid mindspore::flatbuffers) | target_link_libraries(tflite_parser_mid mindspore::flatbuffers) | ||||