Browse Source

support control flow

tags/v1.1.0
hangangqiang 5 years ago
parent
commit
6e10a6288a
40 changed files with 2004 additions and 274 deletions
  1. +2
    -1
      mindspore/lite/schema/model.fbs
  2. +4
    -1
      mindspore/lite/schema/ops.fbs
  3. +50
    -5
      mindspore/lite/src/executor.cc
  4. +11
    -0
      mindspore/lite/src/executor.h
  5. +8
    -8
      mindspore/lite/src/inner_context.cc
  6. +8
    -8
      mindspore/lite/src/inner_context.h
  7. +77
    -43
      mindspore/lite/src/lite_kernel.cc
  8. +8
    -6
      mindspore/lite/src/lite_kernel.h
  9. +47
    -8
      mindspore/lite/src/lite_session.cc
  10. +4
    -0
      mindspore/lite/src/lite_session.h
  11. +78
    -0
      mindspore/lite/src/ops/merge.cc
  12. +44
    -0
      mindspore/lite/src/ops/merge.h
  13. +83
    -0
      mindspore/lite/src/ops/partial.cc
  14. +48
    -0
      mindspore/lite/src/ops/partial.h
  15. +35
    -0
      mindspore/lite/src/ops/populate/merge_populate.cc
  16. +44
    -0
      mindspore/lite/src/ops/populate/partial_populate.cc
  17. +36
    -0
      mindspore/lite/src/ops/populate/switch_populate.cc
  18. +9
    -1
      mindspore/lite/src/ops/primitive_c.cc
  19. +75
    -0
      mindspore/lite/src/ops/switch.cc
  20. +47
    -0
      mindspore/lite/src/ops/switch.h
  21. +84
    -0
      mindspore/lite/src/runtime/kernel/arm/base/merge.cc
  22. +47
    -0
      mindspore/lite/src/runtime/kernel/arm/base/merge.h
  23. +115
    -0
      mindspore/lite/src/runtime/kernel/arm/base/switch.cc
  24. +47
    -0
      mindspore/lite/src/runtime/kernel/arm/base/switch.h
  25. +8
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  26. +1
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
  27. +5
    -8
      mindspore/lite/src/runtime/opencl/opencl_executor.cc
  28. +5
    -5
      mindspore/lite/src/runtime/parallel_executor.cc
  29. +238
    -159
      mindspore/lite/src/scheduler.cc
  30. +36
    -15
      mindspore/lite/src/scheduler.h
  31. +6
    -0
      mindspore/lite/src/sub_graph_kernel.cc
  32. +3
    -2
      mindspore/lite/src/sub_graph_kernel.h
  33. +8
    -0
      mindspore/lite/src/tensor.h
  34. +1
    -1
      mindspore/lite/src/train/train_session.cc
  35. +4
    -1
      mindspore/lite/test/CMakeLists.txt
  36. +1
    -1
      mindspore/lite/test/models_tflite_posttraining.cfg
  37. +459
    -0
      mindspore/lite/test/st/control_flow_test.cc
  38. +217
    -0
      mindspore/lite/test/st/sub_graph_test.cc
  39. +0
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  40. +1
    -0
      mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt

+ 2
- 1
mindspore/lite/schema/model.fbs View File

@@ -258,7 +258,8 @@ union PrimitiveType {
SmoothL1LossGrad,
SigmoidCrossEntropyWithLogits,
SigmoidCrossEntropyWithLogitsGrad,
Reciprocal
Reciprocal,
Merge,
}

enum QuantType: int {


+ 4
- 1
mindspore/lite/schema/ops.fbs View File

@@ -1222,4 +1222,7 @@ table SigmoidCrossEntropyWithLogitsGrad {
}

table Reciprocal {
}
}

table Merge {
}

+ 50
- 5
mindspore/lite/src/executor.cc View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#include "mindspore/lite/src/executor.h"
#include "nnacl/pack.h"
#include "src/executor.h"
#include <queue>
#include "include/errorcode.h"

namespace mindspore::lite {
@@ -26,7 +26,7 @@ int Executor::CheckInputs(const std::vector<Tensor *> &in_tensors) {
return RET_ERROR;
}
if (inTensor->data_c() == nullptr) {
MS_LOG(ERROR) << "Graph input tensor data is nullptr";
MS_LOG(ERROR) << "Graph input tensor data is nullptr " << in_tensors;
return RET_ERROR;
}
auto shape = inTensor->shape();
@@ -49,7 +49,52 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
std::queue<kernel::LiteKernel *> kernel_queue;
for (auto kernel : kernels) {
if (kernel->IsReady()) {
kernel_queue.push(kernel);
}
}
while (!kernel_queue.empty()) {
auto cur_kernel = kernel_queue.front();
kernel_queue.pop();
MS_ASSERT(nullptr != cur_kernel);
ret = cur_kernel->PreProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PreProcess kernel failed, name: " << cur_kernel->name();
return ret;
}
ret = cur_kernel->Run(before, after);
if (RET_OK != ret) {
MS_LOG(ERROR) << "run kernel failed, name: " << cur_kernel->name();
return ret;
}
ret = cur_kernel->PostProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PostProcess kernel failed, name: " << cur_kernel->name();
return ret;
}
for (auto &out_kernel : cur_kernel->out_kernels()) {
if (out_kernel->IsReady()) {
kernel_queue.push(out_kernel);
}
}
}
return RET_OK;
}

int CpuExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, const KernelCallBack &before,
const KernelCallBack &after) {
MS_ASSERT(nullptr != allocator);
// not check input for merge. too hard
if (kernels.front()->Type() != schema::PrimitiveType_Merge) {
auto ret = this->CheckInputs(in_tensors);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
}
#ifdef SUPPORT_TRAIN
for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them
out_tensor->set_ref_count(out_tensor->ref_count() + 1);
@@ -57,7 +102,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
#endif
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
ret = kernel->PreProcess();
auto ret = kernel->PreProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name();
return ret;


+ 11
- 0
mindspore/lite/src/executor.h View File

@@ -37,5 +37,16 @@ class Executor {
protected:
static int CheckInputs(const std::vector<Tensor *> &in_tensors);
};

class CpuExecutor : public Executor {
public:
CpuExecutor() = default;
virtual ~CpuExecutor() = default;

int Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
};

} // namespace mindspore::lite
#endif

+ 8
- 8
mindspore/lite/src/inner_context.cc View File

@@ -62,7 +62,7 @@ InnerContext::~InnerContext() {
}
}

int InnerContext::IsValid() {
int InnerContext::IsValid() const {
if (this->device_list_.empty()) {
MS_LOG(ERROR) << "Device list is empty.";
return RET_NOT_SUPPORT;
@@ -86,33 +86,33 @@ int InnerContext::IsValid() {
return RET_OK;
}

bool InnerContext::IsCpuFloat16Enabled() {
bool InnerContext::IsCpuFloat16Enabled() const {
if (!IsCpuEnabled()) {
return false;
}
return GetCpuInfo().enable_float16_;
}

bool InnerContext::IsGpuFloat16Enabled() {
bool InnerContext::IsGpuFloat16Enabled() const {
if (!IsGpuEnabled()) {
return false;
}
return GetGpuInfo().enable_float16_;
}

bool InnerContext::IsCpuEnabled() {
bool InnerContext::IsCpuEnabled() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
}

bool InnerContext::IsGpuEnabled() {
bool InnerContext::IsGpuEnabled() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
}

bool InnerContext::IsNpuEnabled() {
bool InnerContext::IsNpuEnabled() const {
#ifdef SUPPORT_NPU
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
@@ -123,7 +123,7 @@ bool InnerContext::IsNpuEnabled() {
#endif
}

CpuDeviceInfo InnerContext::GetCpuInfo() {
CpuDeviceInfo InnerContext::GetCpuInfo() const {
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
if (iter == this->device_list_.end()) {
@@ -133,7 +133,7 @@ CpuDeviceInfo InnerContext::GetCpuInfo() {
}
}

GpuDeviceInfo InnerContext::GetGpuInfo() {
GpuDeviceInfo InnerContext::GetGpuInfo() const {
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
if (iter == this->device_list_.end()) {


+ 8
- 8
mindspore/lite/src/inner_context.h View File

@@ -33,23 +33,23 @@ struct InnerContext : public Context {

int Init();

bool IsCpuFloat16Enabled();
bool IsCpuFloat16Enabled() const;

bool IsGpuFloat16Enabled();
bool IsGpuFloat16Enabled() const;

bool IsCpuEnabled();
bool IsCpuEnabled() const;

bool IsGpuEnabled();
bool IsGpuEnabled() const;

bool IsNpuEnabled();
bool IsNpuEnabled() const;

CpuDeviceInfo GetCpuInfo();
CpuDeviceInfo GetCpuInfo() const;

GpuDeviceInfo GetGpuInfo();
GpuDeviceInfo GetGpuInfo() const;

NpuDeviceInfo GetNpuInfo() const;

int IsValid();
int IsValid() const;

virtual ~InnerContext();
};


+ 77
- 43
mindspore/lite/src/lite_kernel.cc View File

@@ -41,9 +41,21 @@ void LiteKernel::FreeWorkspace() {
workspace_ = nullptr;
}

void LiteKernel::InitOutTensorRefCount() {
// A kernel is ready to execute once every one of its input tensors is
// available: either the tensor holds constant data, or a predecessor has
// produced it (ref_count >= 1).
bool LiteKernel::IsReady() {
  for (auto *input : this->in_tensors()) {
    const bool available = input->IsConst() || input->ref_count() >= 1;
    if (!available) {
      return false;
    }
  }
  return true;
}

void LiteKernel::InitOutTensorInitRefCount() {
for (auto *tensor : this->out_tensors_) {
tensor->set_ref_count(this->out_kernels_.size());
int init_ref_count = 0;
for (auto *post_kernel : this->out_kernels_) {
init_ref_count +=
std::count_if(post_kernel->in_tensors_.begin(), post_kernel->in_tensors_.end(),
[&tensor](const lite::Tensor *post_kernel_in_tensor) { return post_kernel_in_tensor == tensor; });
}
tensor->set_init_ref_count(init_ref_count);
}
}

@@ -61,15 +73,20 @@ int LiteKernel::DecOutTensorRefCount() {
return 0;
}

int LiteKernel::FreeWorkTensor() const {
for (auto input_kernel : this->in_kernels()) {
MS_ASSERT(input_kernel != nullptr);
if (input_kernel->is_model_output()) {
int LiteKernel::FreeInWorkTensor() const {
for (auto &in_tensor : this->in_tensors_) {
MS_ASSERT(in_tensor != nullptr);
if (in_tensor->IsConst()) {
continue;
}
auto ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
MS_ASSERT(in_tensor->ref_count() > 0);
in_tensor->set_ref_count(in_tensor->ref_count() - 1);
if (in_tensor->ref_count() <= 0) {
auto ret = in_tensor->FreeData();
if (0 != ret) {
MS_LOG(ERROR) << "Free tensor data failed";
return ret;
}
}
}
return RET_OK;
@@ -91,15 +108,12 @@ int LiteKernel::PreProcess() {
}
}

auto outputs = this->out_tensors();
for (auto *output : outputs) {
for (auto *output : this->out_tensors()) {
MS_ASSERT(output != nullptr);

if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
MS_LOG(ERROR) << "The size of output tensor is too big";
return RET_ERROR;
}

auto ret = output->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MallocData failed";
@@ -109,6 +123,28 @@ int LiteKernel::PreProcess() {
return RET_OK;
}

// Per-kernel cleanup executed after Run(); releases input data that has no
// remaining consumer so memory can be reused by later kernels.
int LiteKernel::PostProcess() {
#ifdef SUPPORT_TRAIN
  // Training build: release inputs by decrementing the producing kernels'
  // output ref-counts; tensors that are model outputs are deliberately kept.
  for (auto input_kernel : this->in_kernels()) {
    MS_ASSERT(input_kernel != nullptr);
    if (input_kernel->is_model_output()) {
      continue;
    }
    auto ret = input_kernel->DecOutTensorRefCount();
    if (0 != ret) {
      // Best-effort: a failed decrement is logged as a warning, not an error.
      MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
    }
  }
  return RET_OK;
#else
  // Inference build: reset each output's ref-count (presumably back to its
  // init_ref_count for the next invocation — confirm against ResetRefCount),
  // then free input tensor data whose ref-count has dropped to zero.
  for (auto *output : this->out_tensors()) {
    MS_ASSERT(output != nullptr);
    output->ResetRefCount();
  }
  return FreeInWorkTensor();
#endif
}

int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) {
if (before != nullptr) {
if (!before(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_),
@@ -153,6 +189,28 @@ std::string LiteKernel::ToString() const {
return oss.str();
}

// Rebuilds this kernel's predecessor/successor links by scanning the kernels in
// `scope_kernels`: a kernel producing one of our input tensors becomes an
// in-kernel, and a kernel consuming one of our output tensors becomes an
// out-kernel. Any previously recorded links are discarded first.
void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) {
  // clean io kernels
  this->in_kernels_.clear();
  this->out_kernels_.clear();
  // find io kernels
  for (auto *scope_kernel : scope_kernels) {
    if (scope_kernel == this) {
      continue;
    }
    for (auto *tensor : this->in_tensors_) {
      if (lite::IsContain(scope_kernel->out_tensors(), tensor)) {
        // NOTE(review): a producer matching several of our inputs is added once
        // per matching tensor — confirm AddInKernel de-duplicates, or that
        // duplicate links are benign for downstream consumers.
        this->AddInKernel(scope_kernel);
      }
    }
    for (auto *tensor : this->out_tensors_) {
      if (lite::IsContain(scope_kernel->in_tensors(), tensor)) {
        this->AddOutKernel(scope_kernel);
      }
    }
  }
}

std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> input_kernels;
@@ -202,7 +260,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
if (outer_in_kernels.empty()) {
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
if (!in_kernel_in_tensor->IsConst()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
if (!IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.push_back(in_kernel_in_tensor);
}
}
@@ -219,7 +277,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
auto outer_in_kernel_out_tensors_iter =
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
if (!IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.emplace_back(in_kernel_in_tensor);
}
}
@@ -237,7 +295,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto &out_kernel_out_tensors = output_kernel->out_tensors();
if (outer_out_kernels.empty()) {
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
if (!IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.push_back(out_kernel_out_tensor);
}
}
@@ -253,7 +311,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto outer_out_kernel_in_tensors_iter =
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
if (!IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.emplace_back(out_kernel_out_tensor);
}
}
@@ -299,33 +357,9 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *ke
return RET_OK;
}

void LiteKernelUtil::InitIOKernels(std::vector<kernel::LiteKernel *> &kernels) {
for (auto *kernel : kernels) {
// clean io kernels
kernel->set_in_kernels({});
kernel->set_out_kernels({});
// find io kernels
for (auto *search_kernel : kernels) {
if (search_kernel == kernel) {
continue;
}
for (auto *tensor : kernel->in_tensors()) {
if (lite::IsContain(search_kernel->out_tensors(), tensor)) {
kernel->AddInKernel(search_kernel);
}
}
for (auto *tensor : kernel->out_tensors()) {
if (lite::IsContain(search_kernel->in_tensors(), tensor)) {
kernel->AddOutKernel(search_kernel);
}
}
}
}
}

void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels) {
void LiteKernelUtil::InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels) {
for (auto *kernel : kernels) {
kernel->InitOutTensorRefCount();
kernel->InitOutTensorInitRefCount();
}
}



+ 8
- 6
mindspore/lite/src/lite_kernel.h View File

@@ -87,10 +87,12 @@ class LiteKernel {

virtual int Run(const KernelCallBack &before, const KernelCallBack &after);
// called after Run
virtual int PostProcess() { return FreeWorkTensor(); }
virtual int PostProcess();

virtual int ReSize() { return mindspore::lite::RET_ERROR; }

virtual void FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels);

virtual int Init() { return mindspore::lite::RET_ERROR; }

std::string name() const { return this->name_; }
@@ -154,11 +156,13 @@ class LiteKernel {

const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }

void InitOutTensorRefCount();
virtual bool IsReady();

virtual void InitOutTensorInitRefCount();

int DecOutTensorRefCount();

int FreeWorkTensor() const;
virtual int FreeInWorkTensor() const;

KernelKey desc() const { return desc_; }

@@ -203,8 +207,6 @@ typedef LiteKernel *(*KernelCreator)(const std::vector<lite::Tensor *> &inputs,

class LiteKernelUtil {
public:
static void InitIOKernels(std::vector<kernel::LiteKernel *> &kernels);

static std::vector<kernel::LiteKernel *> SubgraphInputKernels(const std::vector<kernel::LiteKernel *> &kernels);

static std::vector<kernel::LiteKernel *> SubgraphOutputKernels(const std::vector<kernel::LiteKernel *> &kernels);
@@ -215,7 +217,7 @@ class LiteKernelUtil {

static int TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels);

static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels);
static void InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels);

static int SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs);
};


+ 47
- 8
mindspore/lite/src/lite_session.cc View File

@@ -295,6 +295,21 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
}
}

// Bumps the init ref-count of every graph-output tensor by one, so the
// executor's ref-count-based freeing never releases a tensor the user will
// read back as a model output.
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
  MS_ASSERT(model != nullptr);
  // Output indices are taken from the first (main) sub-graph only.
  auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
  for (size_t i = 0; i < graph_out_size; ++i) {
    size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
    MS_ASSERT(graph_out_index < this->tensors_.size());
    auto *out_tensor = this->tensors_.at(graph_out_index);
    if (out_tensor == nullptr) {
      // NOTE(review): returning here leaves any remaining output tensors
      // unadjusted — confirm this all-or-nothing abort is intended.
      MS_LOG(ERROR) << "out_tensor is null!";
      return;
    }
    out_tensor->set_init_ref_count(out_tensor->init_ref_count() + 1);
  }
}

void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphInputTensors(model);
InitGraphInputMSTensors();
@@ -303,6 +318,7 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphOutputNodeMap(model);
InitGraphOutputTensorNames(model);
InitGraphOutputTensorMap(model);
AdjustModelOutputTensorInitRefCount(model);
}

int LiteSession::CompileGraph(Model *model) {
@@ -334,12 +350,9 @@ int LiteSession::CompileGraph(Model *model) {
is_running_.store(false);
return ret;
}

InitGraphInOutTensors(model);

// scheduler kernels
Scheduler scheduler(context_);
ret = scheduler.Schedule(model, &tensors_, &kernels_);
Scheduler scheduler(context_, model, tensors_);
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
is_running_.store(false);
@@ -353,6 +366,7 @@ int LiteSession::CompileGraph(Model *model) {
}
}
#endif
InitGraphInOutTensors(model);
ret = executor_->Prepare(this->kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare executor failed: " << ret;
@@ -563,6 +577,32 @@ void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) {
}
}

// Re-runs shape inference / resize on every scheduled sub-graph kernel after
// the session inputs change shape. Returns RET_OK on success, RET_ERROR on a
// null or non-sub-graph kernel or a hard resize failure.
int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
  // Once one sub-graph reports RET_INFER_INVALID, shape inference is
  // "interrupted": the flag is propagated to all later sub-graphs so they can
  // defer inference instead of failing.
  bool infer_shape_interrupt = false;
  for (auto kernel : kernels) {
    if (kernel == nullptr) {
      MS_LOG(ERROR) << "input kernel is nullptr!";
      return RET_ERROR;
    }
    // After scheduling, every top-level kernel must be a sub-graph wrapper.
    if (kernel->subgraph_type() == kernel::kNotSubGraph) {
      MS_LOG(ERROR) << "All node in graph should be sub_graph";
      return RET_ERROR;
    }
    auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
    auto ret = sub_graph->ReSize(infer_shape_interrupt);
    if (ret == RET_INFER_INVALID) {
      // Not fatal: remember the interruption and keep resizing the rest.
      MS_LOG(INFO) << "InferShape is interrupted";
      infer_shape_interrupt = true;
      continue;
    }
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) {
bool expected = false;
@@ -581,11 +621,10 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
return ret;
}

Scheduler scheduler(context_);
ret = scheduler.ReSizeKernels(kernels_);
ret = ReSizeKernels(kernels_);
if (ret != RET_OK) {
ResetInputsShape(old_dims);
auto resize_ret = scheduler.ReSizeKernels(kernels_);
auto resize_ret = ReSizeKernels(kernels_);
if (resize_ret != RET_OK) {
MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret;
}


+ 4
- 0
mindspore/lite/src/lite_session.h View File

@@ -92,10 +92,14 @@ class LiteSession : public session::LiteSession {

void InitGraphOutputTensorMap(const lite::Model *model);

void AdjustModelOutputTensorInitRefCount(const lite::Model *model);

int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);

int PrepareKernels();

static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);

private:
void ResetInputsShape(const std::vector<std::vector<int>> &dims);



+ 78
- 0
mindspore/lite/src/ops/merge.cc View File

@@ -0,0 +1,78 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/merge.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

// Builds the writable PrimitiveT for a Merge op from an ANF primitive.
// Lazily allocates the PrimitiveT and its (attribute-less) MergeT payload,
// rejects a pre-existing primitive whose value type is not Merge, then fills
// in quantization parameters from `prim`/`inputs`.
int Merge::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Merge;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Merge) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    // Merge carries no attributes; an empty MergeT is all the payload needed.
    this->primitive_->value.value = new (std::nothrow) schema::MergeT();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
  }
  // NOTE(review): return value of PopulaterQuantParam is not checked here —
  // confirm it cannot fail, or propagate its status.
  PopulaterQuantParam(prim, inputs);
  return RET_OK;
}

#else
// Re-serializes a flatbuffer Merge primitive into `fbb`. Merge has no
// attributes, so only the (empty) table and the wrapping Primitive are built.
int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  if (primitive->value_as_Merge() == nullptr) {
    MS_LOG(ERROR) << "value_as_Merge return nullptr";
    return RET_ERROR;
  }
  auto merge_table = schema::CreateMerge(*fbb);
  auto wrapped = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, merge_table.o);
  fbb->Finish(wrapped);
  return RET_OK;
}

// Factory + registry hookup so the runtime can create a Merge wrapper from a
// deserialized schema::Primitive of type PrimitiveType_Merge.
PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Merge>(primitive); }
Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
#endif

// Shape inference for Merge: only the data type of input 0 is propagated to
// output 0 (the actual value/shape selection happens at run time in the Merge
// kernel).
// Returns RET_OK on success, RET_ERROR on a malformed tensor list.
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  MS_ASSERT(outputs_.size() == 1);
  MS_ASSERT(inputs_.size() == 2);
  // MS_ASSERT is compiled out in release builds; guard the dereferences below
  // explicitly so a malformed graph fails cleanly instead of crashing.
  if (inputs_.empty() || outputs_.empty()) {
    MS_LOG(ERROR) << "Merge InferShape: inputs or outputs are empty";
    return RET_ERROR;
  }
  if (inputs_[0] == nullptr || outputs_[0] == nullptr) {
    MS_LOG(ERROR) << "Merge InferShape: tensor is nullptr";
    return RET_ERROR;
  }
  outputs_[0]->set_data_type(inputs_[0]->data_type());

  return RET_OK;
}

} // namespace lite
} // namespace mindspore

+ 44
- 0
mindspore/lite/src/ops/merge.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
#define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {

// Primitive wrapper for the control-flow Merge op. This class carries only
// (de)serialization and shape inference; runtime semantics live in the Merge
// kernel (src/runtime/kernel/arm/base/merge.cc).
class Merge : public PrimitiveC {
 public:
  Merge() = default;
  ~Merge() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Merge, PrimitiveC);
  explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds the writable PrimitiveT (with empty MergeT payload) from an ANF node.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  // Re-serializes a read-only schema::Primitive back into a flatbuffer.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Propagates input 0's data type to output 0; no shape computation.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_

+ 83
- 0
mindspore/lite/src/ops/partial.cc View File

@@ -0,0 +1,83 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/partial.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; }

// Builds the writable PrimitiveT for a Partial op from an ANF primitive.
// Lazily allocates the PrimitiveT and its PartialT attribute payload, and
// rejects a pre-existing primitive whose value type is not Partial.
// Returns RET_OK on success, RET_ERROR on allocation failure or type mismatch.
int Partial::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Partial;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Partial) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::PartialT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    // attr is proven non-null above; the former re-check of value.value after
    // this assignment was unreachable dead code and has been removed.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}

#else

// Re-serializes a flatbuffer Partial primitive into `fbb`, preserving its
// sub-graph index attribute.
int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  const auto *partial_attr = primitive->value_as_Partial();
  if (partial_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Partial return nullptr";
    return RET_ERROR;
  }
  auto partial_table = schema::CreatePartial(*fbb, partial_attr->subGraphIndex());
  auto wrapped = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, partial_table.o);
  fbb->Finish(wrapped);
  return RET_OK;
}

// Index of the sub-graph this Partial op invokes (read-only flatbuffer build).
int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); }

// Factory + registry hookup so the runtime can create a Partial wrapper from a
// deserialized schema::Primitive of type PrimitiveType_Partial.
PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Partial>(primitive); }
Registry PartialRegistry(schema::PrimitiveType_Partial, PartialCreator);

#endif

int Partial::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
} // namespace lite
} // namespace mindspore

+ 48
- 0
mindspore/lite/src/ops/partial.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// Primitive wrapper for the control-flow Partial op, which refers to another
// sub-graph of the model by index (see GetSubGraphIndex). Carries only
// (de)serialization; InferShape is a no-op.
class Partial : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Partial, PrimitiveC);
  Partial() = default;
  explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Builds the writable PrimitiveT (with PartialT payload) from an ANF node.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

#else
  Partial() = default;

  // Re-serializes a read-only schema::Primitive back into a flatbuffer.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // No shape propagation: the referenced sub-graph determines output shapes.
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
  // Index into the model's sub-graph list that this Partial invokes.
  int GetSubGraphIndex() const;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_

+ 35
- 0
mindspore/lite/src/ops/populate/merge_populate.cc View File

@@ -0,0 +1,35 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {

// Creates the OpParameter for a Merge node. Merge has no attributes, so a bare
// OpParameter tagged with the primitive type suffices.
// Returns nullptr when `primitive` is null or allocation fails; the caller
// owns the returned memory (released with free()).
OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) {
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return nullptr;
  }
  auto *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (merge_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc Merge parameter failed.";
    return nullptr;
  }
  memset(merge_parameter, 0, sizeof(OpParameter));
  merge_parameter->type_ = primitive->Type();
  // merge_parameter is already an OpParameter*; no cast needed on return.
  return merge_parameter;
}
Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter);
} // namespace lite
} // namespace mindspore

+ 44
- 0
mindspore/lite/src/ops/populate/partial_populate.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/partial.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
// Kernel-side parameter for a Partial op: the common OpParameter header plus
// the index of the sub-graph to invoke.
typedef struct PartialParameter {
  OpParameter op_parameter_;
  int sub_graph_index_;
} PartialParameter;

// Creates the PartialParameter for a Partial node, copying the sub-graph index
// out of the primitive.
// Returns nullptr when `primitive` is null or allocation fails; the caller
// owns the returned memory (released with free()).
OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) {
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return nullptr;
  }
  auto *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
  if (partial_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc partial parameter failed.";
    return nullptr;
  }
  memset(partial_parameter, 0, sizeof(PartialParameter));
  partial_parameter->op_parameter_.type_ = primitive->Type();

  // NOTE(review): const_cast + downcast mirrors the other populate functions;
  // GetSubGraphIndex is a const accessor, so no mutation actually occurs.
  auto param = reinterpret_cast<mindspore::lite::Partial *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  partial_parameter->sub_graph_index_ = param->GetSubGraphIndex();

  return reinterpret_cast<OpParameter *>(partial_parameter);
}
Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter);
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/src/ops/populate/switch_populate.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/switch.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"

namespace mindspore {
namespace lite {
// Builds the runtime OpParameter for a Switch node.
// Switch carries no attributes, so only the generic header is filled in.
// Returns a heap-allocated parameter (ownership passes to the caller) or
// nullptr on failure.
OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) {
  if (primitive == nullptr) {
    // BUGFIX: guard before dereferencing primitive->Type() below.
    MS_LOG(ERROR) << "primitive is nullptr.";
    return nullptr;
  }
  OpParameter *switch_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (switch_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc SwitchParameter failed.";
    return nullptr;
  }
  memset(switch_parameter, 0, sizeof(OpParameter));
  switch_parameter->type_ = primitive->Type();

  return switch_parameter;
}
Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter);
} // namespace lite
} // namespace mindspore

+ 9
- 1
mindspore/lite/src/ops/primitive_c.cc View File

@@ -155,6 +155,9 @@
#include "src/ops/tensorlistsetitem.h"
#include "src/ops/tensorlistreserve.h"
#include "src/ops/tensorliststack.h"
#include "src/ops/merge.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -925,7 +928,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) TensorListReserve(primitive);
case schema::PrimitiveType_TensorListStack:
return new (std::nothrow) TensorListStack(primitive);

case schema::PrimitiveType_Switch:
return new (std::nothrow) Switch(primitive);
case schema::PrimitiveType_Merge:
return new (std::nothrow) Merge(primitive);
case schema::PrimitiveType_Partial:
return new (std::nothrow) Partial(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);


+ 75
- 0
mindspore/lite/src/ops/switch.cc View File

@@ -0,0 +1,75 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/switch.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Ensures this primitive holds a writeable SwitchT attribute, creating the
// PrimitiveT container and the attribute lazily.
// Returns RET_ERROR on allocation failure or if the existing primitive has a
// non-Switch type; RET_OK otherwise.
int Switch::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Switch;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Switch) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::SwitchT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    // BUGFIX: removed the dead re-check of value.value right after this
    // assignment -- attr was just verified non-null, so it could never fire.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Re-serializes a read-only Switch primitive into the given flatbuffer
// builder. Fails if the source primitive does not actually carry a Switch.
int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(primitive != nullptr);
  MS_ASSERT(fbb != nullptr);
  auto switch_attr = primitive->value_as_Switch();
  if (switch_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Switch return nullptr";
    return RET_ERROR;
  }
  // Switch has no fields of its own; only the type tag is serialized.
  auto value_offset = schema::CreateSwitch(*fbb);
  auto primitive_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, value_offset.o);
  fbb->Finish(primitive_offset);
  return RET_OK;
}

// Factory used by the read-only-primitive registry.
PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Switch>(primitive); }
Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
#endif

// Intentionally a no-op: Switch performs no static shape computation here and
// leaves output tensors untouched. Presumably output shapes are resolved when
// the op actually runs -- TODO(review): confirm against the executor.
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
} // namespace lite
} // namespace mindspore

+ 47
- 0
mindspore/lite/src/ops/switch.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
#define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
// Op wrapper for the control-flow Switch primitive.
class Switch : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Switch, PrimitiveC);
  Switch() = default;
  explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Lazily creates the writeable SwitchT attribute for this primitive.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;

#else
  Switch() = default;

  // Re-serializes the read-only schema primitive into a flatbuffer builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // No-op for Switch (returns RET_OK without touching the tensors).
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_

+ 84
- 0
mindspore/lite/src/runtime/kernel/arm/base/merge.cc View File

@@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/base/merge.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Merge;

namespace mindspore::kernel {
// Merge is considered ready as soon as ANY of its two inputs is either a
// const tensor or has ref_count >= 1.
// NOTE(review): because a const tensor always satisfies this predicate, a
// Merge fed by a const input is schedulable immediately -- the original
// comment ("if one of input of merge is const-tensor, merge is always ready,
// this will cause error.") flags this as a known problem. TODO: confirm the
// intended readiness semantics before changing the predicate.
bool MergeCPUKernel::IsReady() {
  MS_ASSERT(in_tensors().size() == 2);
  return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
    return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
  });
}

// Nothing to prepare for Merge.
int MergeCPUKernel::Init() { return RET_OK; }

// Resizing is not supported for Merge.
int MergeCPUKernel::ReSize() { return RET_ERROR; }

// Copies whichever input currently carries data into the single output.
// (With two branch inputs, any input holding data is copied; if both held
// data, the later one would win -- same as the original implementation.)
int MergeCPUKernel::Run() {
  MS_ASSERT(in_tensors_.size() == 2);
  MS_ASSERT(out_tensors_.size() == 1);
  auto *dst = out_tensors_.front()->data_c();
  MS_ASSERT(dst != nullptr);
  for (auto *in_tensor : in_tensors_) {
    auto *src = in_tensor->data_c();
    if (src == nullptr) {
      continue;  // this branch has not produced data
    }
    memcpy(dst, src, in_tensor->Size());
  }
  return RET_OK;
}

// Registry factory for the CPU Merge kernel.
// On any failure after the parameter null-check, the parameter is freed here
// so the caller does not leak it.
kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs,
                                          const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
                                          const lite::InnerContext *ctx, const KernelKey &desc,
                                          const mindspore::lite::PrimitiveC *primitive) {
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr";
    return nullptr;
  }
  if (desc.type != PrimitiveType_Merge) {
    MS_LOG(ERROR) << "type in desc is not Merge";
    free(parameter);
    return nullptr;
  }
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "ctx is nullptr";
    free(parameter);
    return nullptr;
  }

  auto *merge_kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (merge_kernel == nullptr) {
    MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
    free(parameter);
    return nullptr;
  }
  return merge_kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator)
} // namespace mindspore::kernel

+ 47
- 0
mindspore/lite/src/runtime/kernel/arm/base/merge.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_

#include <vector>
#include "src/lite_kernel.h"

namespace mindspore::kernel {

// Runtime parameter for the Merge op; no fields beyond the generic
// OpParameter header (Merge has no attributes).
typedef struct MergeParameter {
  OpParameter op_parameter_;
} MergeParameter;

// CPU kernel for the control-flow Merge op: forwards the data of whichever
// input branch has produced a value into its single output.
class MergeCPUKernel : public LiteKernel {
 public:
  MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                 const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    // op_parameter_ is owned by the base class; this is just a typed view.
    merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_);
  }
  ~MergeCPUKernel() override {}
  // Ready when any input is const or has a live ref count (see merge.cc).
  bool IsReady() override;
  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  MergeParameter *merge_param_ = nullptr;  // non-owning view of op_parameter_
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_

+ 115
- 0
mindspore/lite/src/runtime/kernel/arm/base/switch.cc View File

@@ -0,0 +1,115 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/base/switch.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Switch;

namespace mindspore::kernel {
int SwitchCPUKernel::PostProcess() {
auto bool_tensor = in_tensors_.front();
MS_ASSERT(bool_tensor != nullptr);
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
MS_ASSERT(bool_tensor->shape().size() == 1);
MS_ASSERT(bool_tensor->shape().front() == 1);
auto *active = static_cast<bool *>(bool_tensor->data_c());
if (active == nullptr) {
MS_LOG(ERROR) << "data of bool tensor is nullptr";
return lite::RET_NULL_PTR;
}
size_t in_index = 1;
size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
while (in_index < in_tensors_.size()) {
in_index++;
auto out_tensor = out_tensors_.at(out_index++);
out_tensor->ResetRefCount();
}
return FreeInWorkTensor();
}

// Nothing to prepare for Switch.
int SwitchCPUKernel::Init() { return RET_OK; }

// Resizing is not supported for Switch.
int SwitchCPUKernel::ReSize() { return RET_ERROR; }

// inputs: bool*1 data*n
// output: true-data*n, false-data*n
// Copies each data input into the corresponding output of the branch chosen
// by the leading bool tensor.
int SwitchCPUKernel::Run() {
  MS_ASSERT(in_tensors_.size() >= 2);
  MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size());
  auto bool_tensor = in_tensors_.front();
  MS_ASSERT(bool_tensor != nullptr);
  MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
  MS_ASSERT(bool_tensor->shape().size() == 1);
  MS_ASSERT(bool_tensor->shape().front() == 1);
  auto *active = static_cast<bool *>(bool_tensor->data_c());
  if (active == nullptr) {
    MS_LOG(ERROR) << "data of bool tensor is nullptr";
    return lite::RET_NULL_PTR;
  }
  // true -> first half of outputs, false -> second half.
  size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
  for (size_t in_index = 1; in_index < in_tensors_.size(); ++in_index, ++out_index) {
    auto in_tensor = in_tensors_.at(in_index);
    auto out_tensor = out_tensors_.at(out_index);
    MS_ASSERT(in_tensor != nullptr);
    MS_ASSERT(out_tensor != nullptr);
    MS_ASSERT(in_tensor->Size() == out_tensor->Size());
    auto *src = in_tensor->data_c();
    auto *dst = out_tensor->data_c();
    if (src == nullptr || dst == nullptr) {
      MS_LOG(ERROR) << "input tensor or output tensor have not been malloced";
      return lite::RET_NULL_PTR;
    }
    // Byte-wise copy; element type does not matter here.
    memcpy(dst, src, in_tensor->Size());
  }
  return RET_OK;
}

// Registry factory for the CPU Switch kernel.
// On any failure after the parameter null-check, the parameter is freed here
// so the caller does not leak it.
kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector<lite::Tensor *> &inputs,
                                           const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
                                           const lite::InnerContext *ctx, const KernelKey &desc,
                                           const mindspore::lite::PrimitiveC *primitive) {
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "parameter is nullptr";
    return nullptr;
  }
  if (desc.type != PrimitiveType_Switch) {
    MS_LOG(ERROR) << "type in desc is not Switch";
    free(parameter);
    return nullptr;
  }
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "ctx is nullptr";
    free(parameter);
    return nullptr;
  }
  auto *switch_kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (switch_kernel == nullptr) {
    MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
    free(parameter);
    return nullptr;
  }
  return switch_kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator)
} // namespace mindspore::kernel

+ 47
- 0
mindspore/lite/src/runtime/kernel/arm/base/switch.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_

#include <vector>
#include "src/lite_kernel.h"

namespace mindspore::kernel {

// Runtime parameter for the Switch op; no fields beyond the generic
// OpParameter header (Switch has no attributes).
typedef struct SwitchParameter {
  OpParameter op_parameter_;
} SwitchParameter;

// CPU kernel for the control-flow Switch op: routes its data inputs to the
// true- or false-branch half of its outputs based on a leading bool tensor.
class SwitchCPUKernel : public LiteKernel {
 public:
  SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                  const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                  const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    // op_parameter_ is owned by the base class; this is just a typed view.
    switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_);
  }
  ~SwitchCPUKernel() override = default;
  // Re-arms ref counts on the selected branch's outputs (see switch.cc).
  int PostProcess() override;
  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  SwitchParameter *switch_param_ = nullptr;  // non-owning view of op_parameter_
};
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_

+ 8
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -71,6 +71,14 @@ int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
return RET_OK;
}

// After Run: restore each output tensor's ref count to its initial value,
// then release the inputs this kernel consumed.
int OpenCLKernel::PostProcess() {
  for (auto *out_tensor : this->out_tensors()) {
    MS_ASSERT(out_tensor != nullptr);
    out_tensor->ResetRefCount();
  }
  return FreeInWorkTensor();
}

std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() {
size_t ndim = global_size_.size();
std::vector<BaseTuningParameter> tuning_params = {};


+ 1
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h View File

@@ -164,6 +164,7 @@ class OpenCLKernel : public LiteKernel {

int Prepare() override { return RET_OK; }
int PreProcess() override { return RET_ERROR; }
int PostProcess() override;
int ReSize() override { return RET_ERROR; }
int Run() override { return RET_ERROR; }



+ 5
- 8
mindspore/lite/src/runtime/opencl/opencl_executor.cc View File

@@ -36,7 +36,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor
if (is_tune) {
opencl_runtime_ins->SetProfiling(true);
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
for (auto *kernel : kernels) {
MS_ASSERT(kernel);
CallBackParam callbackParam;
@@ -82,6 +81,11 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return ret;
}
ret = kernel->PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name();
return ret;
}
if (profiling_tmp) {
MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str()
<< ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms";
@@ -92,13 +96,6 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->name();
}
}
for (auto input_kernel : kernel->in_kernels()) {
MS_ASSERT(input_kernel);
ret = input_kernel->DecOutTensorRefCount();
if (ret != RET_OK) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
}
opencl_runtime_ins->SetProfiling(profiling_tmp);
return ret;


+ 5
- 5
mindspore/lite/src/runtime/parallel_executor.cc View File

@@ -40,9 +40,9 @@ static int RunKernel(void *data, int index) {
return 0;
}

ret = kernel->FreeWorkTensor();
ret = kernel->FreeInWorkTensor();
if (RET_OK != ret) {
MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << kernel->name();
MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << kernel->name();
return ret;
}
return 0;
@@ -62,7 +62,7 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor
return RET_ERROR;
}
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
kernel::LiteKernelUtil::InitTensorInitRefCount(kernels);

for (auto kernel : kernels) {
if (kernel->in_kernels().empty()) {
@@ -96,9 +96,9 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor
}
}

auto ret = completed->FreeWorkTensor();
auto ret = completed->FreeInWorkTensor();
if (RET_OK != ret) {
MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << completed->name();
MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << completed->name();
return ret;
}
}


+ 238
- 159
mindspore/lite/src/scheduler.cc View File

@@ -19,6 +19,7 @@
#include <queue>
#include <string>
#include <vector>
#include "src/ops/partial.h"
#include "include/errorcode.h"
#include "src/common/graph_util.h"
#include "src/common/utils.h"
@@ -36,152 +37,255 @@ namespace mindspore::lite {
using kernel::KERNEL_ARCH::kCPU;
using kernel::KERNEL_ARCH::kGPU;
using kernel::KERNEL_ARCH::kNPU;
constexpr int kMainSubGraphIndex = 0;

int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) {
int ret = InferShape(model, tensors);
int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
if (src_model_ == nullptr) {
MS_LOG(ERROR) << "Input model is nullptr";
return RET_PARAM_INVALID;
}
if (src_model_->sub_graphs_.empty()) {
MS_LOG(ERROR) << "Model should have a subgraph at least";
return RET_PARAM_INVALID;
}

this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_);
bool infer_shape_interrupt = false;
auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt);
if (ret != RET_OK) {
MS_LOG(ERROR) << "op infer shape failed.";
return ret;
}
ret = BuildKernels(model, tensors, kernels);
ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "init op to kernel failed.";
MS_LOG(ERROR) << "Schedule main subgraph to kernels failed.";
return ret;
}

kernel::LiteKernelUtil::InitIOKernels(*kernels);

ret = ConstructSubGraphs(kernels);
FindAllInoutKernels(*dst_kernels);
ret = ConstructSubGraphs(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConstructSubGraphs failed.";
return ret;
}

kernel::LiteKernelUtil::InitIOKernels(*kernels);

FindAllInoutKernels(*dst_kernels);
kernel::LiteKernelUtil::InitTensorInitRefCount(*dst_kernels);
MS_LOG(DEBUG) << "schedule kernels success.";
return RET_OK;
}

int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
bool infer_shape_interrupt = false;
for (auto kernel : kernels) {
if (kernel == nullptr) {
MS_LOG(ERROR) << "input kernel is nullptr!";
return RET_ERROR;
}
if (kernel->subgraph_type() == kernel::kNotSubGraph) {
MS_LOG(ERROR) << "All node in graph should be sub_graph";
return RET_ERROR;
}
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
auto ret = sub_graph->ReSize(infer_shape_interrupt);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape is interrupted";
infer_shape_interrupt = true;
continue;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed";
return RET_ERROR;
// Collects the input and output Tensor pointers of a model node by resolving
// its tensor indices against src_tensors_.
void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs,
                                     std::vector<Tensor *> *outputs) {
  MS_ASSERT(inputs != nullptr);
  MS_ASSERT(outputs != nullptr);
  inputs->reserve(node.input_indices_.size());
  for (auto tensor_index : node.input_indices_) {
    inputs->emplace_back(src_tensors_.at(tensor_index));
  }
  outputs->reserve(node.output_indices_.size());
  for (auto tensor_index : node.output_indices_) {
    outputs->emplace_back(src_tensors_.at(tensor_index));
  }
}

int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt) {
MS_ASSERT(node != nullptr);
MS_ASSERT(infer_shape_interrupt != nullptr);
auto primitive = node->primitive_;
MS_ASSERT(primitive != nullptr);
if (primitive->Type() == schema::PrimitiveType_Partial) {
return InferPartialShape(node, infer_shape_interrupt);
}
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
FindNodeInoutTensors(*node, &inputs, &outputs);
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) {
auto shape = tensor->shape();
return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; });
});
if (!infer_valid) {
*infer_shape_interrupt = true;
}
primitive->set_infer_flag(!(*infer_shape_interrupt));
auto ret = primitive->InferShape(inputs, outputs);
if (ret == RET_OK) {
for (auto &output : outputs) {
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
MS_LOG(ERROR) << "The size of output tensor is too big";
return RET_ERROR;
}
}
}
return RET_OK;
return ret;
}

int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tensors) {
MS_ASSERT(model != nullptr);
MS_ASSERT(tensors != nullptr);
bool infer_shape_interrupt = false;
uint32_t kernelCount = model->all_nodes_.size();
for (uint32_t i = 0; i < kernelCount; ++i) {
auto node = model->all_nodes_[i];
// Infers shapes for a Partial node by recursively running shape inference on
// the subgraph it references.
// Returns RET_PARAM_INVALID if the node is not actually a Partial.
int Scheduler::InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(node != nullptr);
  MS_ASSERT(infer_shape_interrupt != nullptr);
  auto primitive = node->primitive_;
  MS_ASSERT(primitive != nullptr);
  if (primitive->Type() != schema::PrimitiveType_Partial) {
    MS_LOG(ERROR) << "Node is not a partial";
    return RET_PARAM_INVALID;
  }
  // Delegate to the referenced subgraph's shape inference.
  auto partial_primitive = reinterpret_cast<lite::Partial *>(node->primitive_);
  return InferSubGraphShape(partial_primitive->GetSubGraphIndex(), infer_shape_interrupt);
}

int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt) {
MS_ASSERT(infer_shape_interrupt != nullptr);
MS_ASSERT(src_model_ != nullptr);
MS_ASSERT(!src_model_->sub_graphs_.empty());
MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index);
auto subgraph = src_model_->sub_graphs_.at(subgraph_index);
for (auto node_index : subgraph->node_indices_) {
auto node = src_model_->all_nodes_[node_index];
MS_ASSERT(node != nullptr);
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
auto in_size = node->input_indices_.size();
inputs.reserve(in_size);
for (size_t j = 0; j < in_size; ++j) {
inputs.emplace_back(tensors->at(node->input_indices_[j]));
}
auto out_size = node->output_indices_.size();
outputs.reserve(out_size);
for (size_t j = 0; j < out_size; ++j) {
outputs.emplace_back(tensors->at(node->output_indices_[j]));
}
auto *primitive = node->primitive_;
if (primitive == nullptr) {
MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!";
return RET_ERROR;
}
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) {
auto shape = tensor->shape();
return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; });
});
if (!infer_valid) {
infer_shape_interrupt = true;
}
primitive->set_infer_flag(!infer_shape_interrupt);
auto ret = primitive->InferShape(inputs, outputs);
auto ret = InferNodeShape(node, infer_shape_interrupt);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_
MS_LOG(INFO) << "InferShape interrupted, name: " << node->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))
<< "flag set to false.";
<< ", set infer flag to false.";
primitive->set_infer_flag(false);
infer_shape_interrupt = true;
*infer_shape_interrupt = true;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
return RET_INFER_ERR;
} else {
for (auto &output : outputs) {
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
MS_LOG(ERROR) << "The size of output tensor is too big";
return RET_ERROR;
}
}
}
}

return RET_OK;
}

int Scheduler::BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) {
MS_ASSERT(model != nullptr);
MS_ASSERT(tensors != nullptr);
uint32_t kernelCount = model->all_nodes_.size();
auto graph_output_node_indexes = GetGraphOutputNodes(model);
for (uint32_t i = 0; i < kernelCount; ++i) {
auto node = model->all_nodes_[i];
MS_ASSERT(node != nullptr);
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
auto in_size = node->input_indices_.size();
inputs.reserve(in_size);
for (size_t j = 0; j < in_size; ++j) {
inputs.emplace_back(tensors->at(node->input_indices_[j]));
kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors,
const mindspore::lite::PrimitiveC *primitive,
const Model::Node *node) {
MS_ASSERT(primitive != nullptr);
TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
#if SUPPORT_GPU
if (context_->IsGpuEnabled()) {
kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type};
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_;
return kernel;
} else {
MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " "
<< node->name_;
}
}
#endif
#if SUPPORT_NPU
if (context_->IsNpuEnabled()) {
kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_;
return kernel;
} else {
MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " "
<< node->name_;
}
auto out_size = node->output_indices_.size();
outputs.reserve(out_size);
for (size_t j = 0; j < out_size; ++j) {
outputs.emplace_back(tensors->at(node->output_indices_[j]));
}
#endif
if (mindspore::lite::IsSupportFloat16() &&
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
auto *kernel =
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
<< node->name_;
return kernel;
}
}
if (data_type == kNumberTypeFloat16) {
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32;
}
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
if (kernel != nullptr) {
return kernel;
}
return nullptr;
}

// Schedules a Partial node: compiles the subgraph it references into kernels
// and wraps them into a single subgraph kernel.
// Returns the new subgraph kernel, or nullptr on failure.
kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *src_node) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(src_node != nullptr);
  auto *primitive = src_node->primitive_;
  MS_ASSERT(primitive != nullptr);
  if (primitive->Type() != schema::PrimitiveType_Partial) {
    return nullptr;
  }
  auto partial_primitive = reinterpret_cast<lite::Partial *>(primitive);
  auto sub_graph_index = partial_primitive->GetSubGraphIndex();
  std::vector<kernel::LiteKernel *> sub_kernels;
  auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_;
    return nullptr;
  }
  if (sub_kernels.empty()) {
    // BUGFIX: guard sub_kernels.front() below against an empty subgraph.
    MS_LOG(ERROR) << "Schedule partial got no kernel, name: " << src_node->name_;
    return nullptr;
  }
  auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front());
  // for kernel::LiteKernelUtil::SubgraphInputTensors in CreateSubGraphKernel
  FindAllInoutKernels(sub_kernels);
  auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type);
  if (subgraph == nullptr) {
    // BUGFIX: CreateSubGraphKernel can fail; set_name on nullptr would crash.
    MS_LOG(ERROR) << "Create subgraph kernel failed, name: " << src_node->name_;
    return nullptr;
  }
  subgraph->set_name("subgraph_" + src_node->name_);
  return subgraph;
}

// Schedules a single (non-Partial) node: resolves its in/out tensors, picks a
// backend kernel for it, and tags the kernel with the node's name.
// Returns the kernel, or nullptr if no backend implementation was found.
kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node) {
  auto *primitive = src_node->primitive_;
  MS_ASSERT(primitive != nullptr);
  std::vector<Tensor *> inputs;
  std::vector<Tensor *> outputs;
  FindNodeInoutTensors(*src_node, &inputs, &outputs);
  auto *kernel = this->FindBackendKernel(inputs, outputs, primitive, src_node);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_
                  << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
    return nullptr;
  }
  SetKernelTensorDataType(kernel);
  kernel->set_name(src_node->name_);
  return kernel;
}

// Schedule every node of the Model sub-graph at subgraph_index into kernels,
// appending the results to *dst_kernels. Partial nodes expand recursively
// into subgraph kernels; every other node maps to one backend kernel.
// Returns RET_OK on success, RET_ERROR if any node fails to schedule.
int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels) {
  MS_ASSERT(src_model_ != nullptr);
  MS_ASSERT(!src_model_->sub_graphs_.empty());
  MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index);
  MS_ASSERT(dst_kernels != nullptr);
  MS_ASSERT(dst_kernels->empty());
  auto subgraph = src_model_->sub_graphs_.at(subgraph_index);
  for (auto node_index : subgraph->node_indices_) {
    auto node = src_model_->all_nodes_[node_index];
    MS_ASSERT(node != nullptr);
    auto *primitive = node->primitive_;
    MS_ASSERT(primitive != nullptr);
    kernel::LiteKernel *kernel = nullptr;
    if (primitive->Type() == schema::PrimitiveType_Partial) {  // sub_graph
      kernel = SchedulePartialToKernel(node);
    } else {  // kernel
      kernel = ScheduleNodeToKernel(node);
    }
    if (kernel == nullptr) {
      MS_LOG(ERROR) << "schedule node return nullptr, name: " << node->name_ << ", type: "
                    << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
      return RET_ERROR;
    }
    // Mark kernels whose node produces a graph output so their tensors survive.
    kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index)));
    dst_kernels->emplace_back(kernel);
  }
  return RET_OK;
}

@@ -190,6 +294,11 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
MS_ASSERT(head_kernel != nullptr);
MS_ASSERT(sinked_kernel_map != nullptr);
std::vector<kernel::LiteKernel *> sub_kernels;
if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) {
(*sinked_kernel_map)[head_kernel] = true;
sub_kernels.emplace_back(head_kernel);
return sub_kernels;
}
std::queue<kernel::LiteKernel *> kernel_queue;
kernel_queue.emplace(head_kernel);
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
@@ -200,6 +309,10 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
sub_kernels.emplace_back(cur_kernel);
auto post_kernels = cur_kernel->out_kernels();
for (auto post_kernel : post_kernels) {
if (post_kernel->subgraph_type() != kernel::kNotSubGraph || post_kernel->Type() == schema::PrimitiveType_Merge ||
post_kernel->Type() == schema::PrimitiveType_Switch) {
continue;
}
if (cur_sub_graph_type == mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) {
auto post_kernel_inputs = post_kernel->in_kernels();
if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(),
@@ -215,28 +328,41 @@ std::vector<kernel::LiteKernel *> Scheduler::FindAllSubGraphKernels(
int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
auto old_kernels = *kernels;
kernels->clear();
std::map<const kernel::LiteKernel *, bool> is_kernel_sinked;
std::map<const kernel::LiteKernel *, bool> is_kernel_finish;
for (auto kernel : old_kernels) {
is_kernel_sinked[kernel] = false;
is_kernel_finish[kernel] = false;
}

while (true) {
auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) {
auto kernel_inputs = kernel->in_kernels();
return !is_kernel_sinked[kernel] &&
std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
[&](kernel::LiteKernel *kernel) { return is_kernel_sinked[kernel]; });
if (is_kernel_finish[kernel]) {
return false;
}
// when merge is removed, this if is removed automatically
if (kernel->Type() == schema::PrimitiveType_Merge) {
MS_ASSERT(kernel->in_kernels().size() == 2);
return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]);
} else {
return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
[&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
}
});
if (head_kernel_iter == old_kernels.end()) {
break;
}
auto head_kernel = *head_kernel_iter;
if (head_kernel->subgraph_type() != kernel::kNotSubGraph) {
is_kernel_finish[head_kernel] = true;
kernels->emplace_back(head_kernel);
continue;
}
if (head_kernel->desc().arch == mindspore::kernel::kAPU) {
MS_LOG(ERROR) << "Not support APU now";
return RET_NOT_SUPPORT;
}
auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel);
auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_sinked);
auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish);
auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type);
if (subgraph == nullptr) {
MS_LOG(ERROR) << "Create SubGraphKernel failed";
@@ -296,60 +422,6 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel
return nullptr;
}

// Pick a concrete kernel implementation for a node from the registry.
// Preference order: NPU (if compiled in and enabled) -> GPU (if compiled in
// and enabled) -> CPU fp16 (when supported and allowed) -> CPU with the
// tensors' own data type (falling back from fp16 to fp32 if needed).
kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tensors,
                                            const std::vector<Tensor *> &out_tensors,
                                            const mindspore::lite::PrimitiveC *primitive, const Model::Node *node) {
  MS_ASSERT(primitive != nullptr);
  TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
  kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
  auto *registry = KernelRegistry::GetInstance();
#if SUPPORT_NPU
  if (context_->IsNpuEnabled()) {
    kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
    auto *npu_kernel = registry->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc);
    if (npu_kernel != nullptr) {
      MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_;
      return npu_kernel;
    }
    MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " "
                  << node->name_;
  }
#endif
#if SUPPORT_GPU
  if (context_->IsGpuEnabled()) {
    kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type};
    auto *gpu_kernel = registry->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc);
    if (gpu_kernel != nullptr) {
      MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_;
      return gpu_kernel;
    }
    MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " "
                  << node->name_;
  }
#endif
  const bool try_fp16 =
    mindspore::lite::IsSupportFloat16() &&
    ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16);
  if (try_fp16) {
    kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
    auto *fp16_kernel = registry->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
    if (fp16_kernel != nullptr) {
      MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
                    << node->name_;
      return fp16_kernel;
    }
  }
  if (data_type == kNumberTypeFloat16) {
    MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
    desc.data_type = kNumberTypeFloat32;
  }
  // Last resort: a plain CPU kernel with the (possibly downgraded) desc;
  // nullptr here means no backend at all could serve this node.
  return registry->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
}

TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors) {
for (const auto &tensor : in_tensors) {
auto dtype = tensor->data_type();
@@ -411,4 +483,11 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel *
}
return kernel::kNotSubGraph;
}

// Let every kernel in `kernels` locate its producer/consumer kernels within
// the same list (fills each kernel's in_kernels_/out_kernels_ links).
void Scheduler::FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels) {
  for (auto *kernel : kernels) {
    MS_ASSERT(kernel != nullptr);
    kernel->FindInoutKernels(kernels);
  }
}
} // namespace mindspore::lite

+ 36
- 15
mindspore/lite/src/scheduler.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_
#define MINDSPORE_LITE_SRC_SCHEDULER_H_

#include <utility>
#include <vector>
#include <map>
#include "src/sub_graph_kernel.h"
@@ -27,30 +28,47 @@
namespace mindspore::lite {
class Scheduler {
public:
explicit Scheduler(const InnerContext *ctx) { context_ = const_cast<InnerContext *>(ctx); }
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> src_tensors)
: context_(ctx), src_model_(src_model), src_tensors_(std::move(src_tensors)) {}
~Scheduler() = default;

int Schedule(const lite::Model *model, std::vector<Tensor *> *tensors, std::vector<kernel::LiteKernel *> *kernels);

static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);

protected:
kernel::LiteKernel *ScheduleNode(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode);

int BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels);

static int InferShape(const lite::Model *model, std::vector<Tensor *> *tensors);

int Schedule(std::vector<kernel::LiteKernel *> *dst_kernels);

private:
void FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs,
std::vector<Tensor *> *outputs);
// infer shape for a partial node
int InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt);
// infer shape for a node
int InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt);
// infer shape for a subgraph
int InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt);

// schedule a node to kernel according to context and kernels registered
kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors,
const mindspore::lite::PrimitiveC *primitive, const Model::Node *node);
// schedule a partial node to a subgraph_kernel
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
// schedule a node to a kernel
kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node);
// schedule a Model::SubGraph into a vector of kernel and subgraph_kernel
int ScheduleSubGraphToKernels(size_t subgraph_index, std::vector<kernel::LiteKernel *> *dst_kernels);

// find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph
static void FindAllInoutKernels(const std::vector<kernel::LiteKernel *> &kernels);

// vector<LiteKernel/SubGraphKernel> --> vector<SubGraphKernel>
int ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels);

// create subgraph_kernel from a vector of kernel
kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
kernel::SubGraphType type);

std::vector<kernel::LiteKernel *> FindAllSubGraphKernels(
kernel::LiteKernel *head_kernel, std::map<const kernel::LiteKernel *, bool> *sinked_kernel_map);

// other methods
static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_tensors);

static void SetKernelTensorDataType(kernel::LiteKernel *kernel);
@@ -58,7 +76,10 @@ class Scheduler {
static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel);

protected:
InnerContext *context_ = nullptr;
const InnerContext *context_ = nullptr;
Model *src_model_ = nullptr;
std::vector<Tensor *> src_tensors_;
std::vector<size_t> graph_output_node_indexes_;
};
} // namespace mindspore::lite



+ 6
- 0
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -149,6 +149,12 @@ int SubGraphKernel::ReSize(bool is_interrupt) {
return RET_OK;
}

// Propagate output-tensor init-ref-count setup to every node in this subgraph.
void SubGraphKernel::InitOutTensorInitRefCount() {
  for (auto *node : nodes_) {
    node->InitOutTensorInitRefCount();
  }
}

int CpuSubGraph::Prepare() {
auto ret = SubGraphKernel::Prepare();
if (ret != RET_OK) {


+ 3
- 2
mindspore/lite/src/sub_graph_kernel.h View File

@@ -84,6 +84,8 @@ class SubGraphKernel : public LiteKernel {

int ReSize(bool is_interrupt);

void InitOutTensorInitRefCount() override;

std::string ToString() const override;

std::vector<LiteKernel *> nodes() { return this->nodes_; }
@@ -104,11 +106,10 @@ class CpuSubGraph : public SubGraphKernel {
const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
: SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
subgraph_type_ = kCpuFP32SubGraph;
this->executor_ = new (std::nothrow) mindspore::lite::Executor;
this->executor_ = new (std::nothrow) mindspore::lite::CpuExecutor;
}

~CpuSubGraph() override { delete this->executor_; }

int Prepare() override;
int Init() override { return SubGraphKernel::Init(); }
int PreProcess() override { return SubGraphKernel::PreProcess(); }


+ 8
- 0
mindspore/lite/src/tensor.h View File

@@ -110,8 +110,14 @@ class Tensor : public mindspore::tensor::MSTensor {

size_t ref_count() const { return this->ref_count_; }

size_t init_ref_count() const { return this->init_ref_count_; }

void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; }

void set_init_ref_count(size_t ref_count) { this->init_ref_count_ = ref_count; }

void ResetRefCount() { this->ref_count_ = this->init_ref_count_; }

void DecRefCount() { this->ref_count_--; }

std::string ToString() const;
@@ -156,6 +162,8 @@ class Tensor : public mindspore::tensor::MSTensor {
schema::Format format_;
Category category_;
size_t ref_count_ = 0;
size_t init_ref_count_ = 0;
size_t ready_count_ = 0;
std::vector<QuantArg> quant_params_;
std::vector<float> quant_clusters_;
mindspore::lite::Allocator *allocator_ = nullptr;


+ 1
- 1
mindspore/lite/src/train/train_session.cc View File

@@ -128,7 +128,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a
return lite::RET_NULL_PTR;
}
auto run_kernel = (train_mode_) ? train_kernels_ : inference_kernels_;
lite::Executor executor;
lite::CpuExecutor executor;
if (before == nullptr && after == nullptr) {
return executor.Run(this->inputs_, this->outputs_, run_kernel, this->context_->allocator.get());
} else {


+ 4
- 1
mindspore/lite/test/CMakeLists.txt View File

@@ -261,6 +261,8 @@ if (ENABLE_CONVERTER)
set(TEST_SRC
${TEST_SRC}
${TEST_DIR}/st/converter_test.cc
${TEST_DIR}/st/control_flow_test.cc
${TEST_DIR}/st/sub_graph_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc
@@ -300,7 +302,7 @@ endif ()


add_executable(lite-test ${TEST_SRC})
add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl ${GTEST_LIBRARY})
if (PLATFORM_ARM64)
target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid)
@@ -321,6 +323,7 @@ if (SUPPORT_NPU)
target_link_libraries(lite-test npu_kernel_mid)
endif ()
if (ENABLE_CONVERTER)
add_dependencies(lite-test fbs_inner_src)
target_link_libraries(lite-test
anf_importer_mid
anf_exporter_mid


+ 1
- 1
mindspore/lite/test/models_tflite_posttraining.cfg View File

@@ -1,3 +1,3 @@
mobilenet.tflite 0.5
transformer_20200831_encoder_fp32.tflite 68
transformer_20200831_encoder_fp32.tflite 69
transformer_20200831_decoder_fp32.tflite 35

+ 459
- 0
mindspore/lite/test/st/control_flow_test.cc View File

@@ -0,0 +1,459 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cmath>
#include <memory>
#include "schema/inner/model_generated.h"
#include "mindspore/lite/include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "include/version.h"

namespace mindspore {
// Test fixture for control-flow (Merge/Switch/Partial) model tests.
class ControlFlowTest : public mindspore::CommonTest {
 public:
  ControlFlowTest() {}
};

TEST_F(ControlFlowTest, TestMergeWhileModel) {
// make graph
auto meta_graph = std::make_shared<schema::MetaGraphT>();
MS_LOG(DEBUG) << "make subgraph";
meta_graph->name = "graph";
meta_graph->version = lite::Version();
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {9};
// subgraph 0 : main graph
auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
sub_graph_0->name = "main_graph";

// subgraph 1 : cond graph
auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
sub_graph_1->name = "cond_graph";

// subgraph 2: body graph
auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
sub_graph_2->name = "body_graph";

MS_LOG(DEBUG) << "make subgraph";

// subgraph 0: node 0 before-add-1
auto sub_graph_0_node_0 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_0->inputIndex = {0, 1};
sub_graph_0_node_0->outputIndex = {2};
sub_graph_0_node_0->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_0->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_0_node_0 = new schema::AddT;
primitive_sub_graph_0_node_0->activationType = schema::ActivationType_NO_ACTIVATION;
sub_graph_0_node_0->primitive->value.value = primitive_sub_graph_0_node_0;
sub_graph_0_node_0->name = "before_Add_1";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_0));
sub_graph_0->nodeIndices.push_back(0);
MS_LOG(DEBUG) << "node 0";

// subgraph 0: node 1 before-add-1
auto sub_graph_0_node_1 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_1->inputIndex = {2, 3};
sub_graph_0_node_1->outputIndex = {4};
sub_graph_0_node_1->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_1->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_0_node_1 = new schema::AddT;
primitive_sub_graph_0_node_1->activationType = schema::ActivationType_NO_ACTIVATION;
sub_graph_0_node_1->primitive->value.value = primitive_sub_graph_0_node_1;
sub_graph_0_node_1->name = "before_Add_2";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_1));
sub_graph_0->nodeIndices.push_back(1);
MS_LOG(DEBUG) << "node 1";

// subgraph 0: node 2 merge
auto sub_graph_0_node_2 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_2->inputIndex = {4, 17};
sub_graph_0_node_2->outputIndex = {16};
sub_graph_0_node_2->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_2->primitive->value.type = schema::PrimitiveType_Merge;
auto primitive_sub_graph_0_node_2 = new schema::MergeT;
sub_graph_0_node_2->primitive->value.value = primitive_sub_graph_0_node_2;
sub_graph_0_node_2->name = "merge";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_2));
sub_graph_0->nodeIndices.push_back(2);
MS_LOG(DEBUG) << "node 2";

// subgraph 0: node 3 partial cond subGraph
auto sub_graph_0_node_3 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_3->inputIndex = {16};
sub_graph_0_node_3->outputIndex = {5}; // 5 : bool
sub_graph_0_node_3->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_3->primitive->value.type = schema::PrimitiveType_Partial;
auto primitive_sub_graph_0_node_3 = new schema::PartialT;
primitive_sub_graph_0_node_3->subGraphIndex = 1;
sub_graph_0_node_3->primitive->value.value = primitive_sub_graph_0_node_3;
sub_graph_0_node_3->name = "Partial_cond";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_3));
sub_graph_0->nodeIndices.push_back(3);
MS_LOG(DEBUG) << "node 2";

// subgraph 0: node 4 switch
auto sub_graph_0_node_4 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_4->inputIndex = {5, 16}; // 5 : bool; 16 data
sub_graph_0_node_4->outputIndex = {6, 7};
sub_graph_0_node_4->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_4->primitive->value.type = schema::PrimitiveType_Switch;
auto primitive_sub_graph_0_node_4 = new schema::SwitchT;
sub_graph_0_node_4->primitive->value.value = primitive_sub_graph_0_node_4;
sub_graph_0_node_4->name = "Switch";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_4));
sub_graph_0->nodeIndices.push_back(4);
MS_LOG(DEBUG) << "node 4";

// subgraph 0: node 5 partial body subgraph
auto sub_graph_0_node_5 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_5->inputIndex = {6};
sub_graph_0_node_5->outputIndex = {17};
sub_graph_0_node_5->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_5->primitive->value.type = schema::PrimitiveType_Partial;
auto primitive_sub_graph_0_node_5 = new schema::PartialT;
primitive_sub_graph_0_node_5->subGraphIndex = 2;
sub_graph_0_node_5->primitive->value.value = primitive_sub_graph_0_node_5;
sub_graph_0_node_5->name = "Partial_body";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_5));
sub_graph_0->nodeIndices.push_back(5);
MS_LOG(DEBUG) << "node 5";

// subgraph 0: node 6 add-after
auto sub_graph_0_node_6 = std::make_unique<schema::CNodeT>();
sub_graph_0_node_6->inputIndex = {7, 8};
sub_graph_0_node_6->outputIndex = {9};
sub_graph_0_node_6->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_0_node_6->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_0_node_6 = new schema::AddT;
sub_graph_0_node_6->primitive->value.value = primitive_sub_graph_0_node_6;
sub_graph_0_node_6->name = "Add-after";
meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_6));
sub_graph_0->nodeIndices.push_back(6);
MS_LOG(DEBUG) << "node 6";

sub_graph_0->inputIndices = {0};
sub_graph_0->outputIndices = {9};
sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17};

meta_graph->subGraph.push_back(std::move(sub_graph_0));

// subgraph 1 ; node:0 add cond
auto sub_graph_1_node_0 = std::make_unique<schema::CNodeT>();
sub_graph_1_node_0->inputIndex = {16, 10};
sub_graph_1_node_0->outputIndex = {11};
sub_graph_1_node_0->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_1_node_0->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_1_node_0 = new schema::AddT;
sub_graph_1_node_0->primitive->value.value = primitive_sub_graph_1_node_0;
sub_graph_1_node_0->name = "cond_add";
meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_0));
sub_graph_1->nodeIndices.push_back(7);
MS_LOG(DEBUG) << "node 6";

// subgraph 1 ; node:1 Less cond
auto sub_graph_1_node_1 = std::make_unique<schema::CNodeT>();
sub_graph_1_node_1->inputIndex = {11, 12};
sub_graph_1_node_1->outputIndex = {5};
sub_graph_1_node_1->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_1_node_1->primitive->value.type = schema::PrimitiveType_Less;
auto primitive_sub_graph_1_node_1 = new schema::LessT;
sub_graph_1_node_1->primitive->value.value = primitive_sub_graph_1_node_1;
sub_graph_1_node_1->name = "cond_Less";
meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_1));
sub_graph_1->nodeIndices.push_back(8);
MS_LOG(DEBUG) << "node 7";

sub_graph_1->inputIndices = {16};
sub_graph_1->outputIndices = {5};
sub_graph_1->tensorIndices = {16, 10, 11, 12, 5};
meta_graph->subGraph.push_back(std::move(sub_graph_1));

// subgraph 2 ; node:0 body add-1
auto sub_graph_2_node_0 = std::make_unique<schema::CNodeT>();
sub_graph_2_node_0->inputIndex = {6, 13};
sub_graph_2_node_0->outputIndex = {14};
sub_graph_2_node_0->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_2_node_0->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_2_node_0 = new schema::AddT;
sub_graph_2_node_0->primitive->value.value = primitive_sub_graph_2_node_0;
sub_graph_2_node_0->name = "body_add_1";
meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_0));
sub_graph_2->nodeIndices.push_back(9);
MS_LOG(DEBUG) << "node 8";

// subgraph 2 ; node:1 body add-2
auto sub_graph_2_node_1 = std::make_unique<schema::CNodeT>();
sub_graph_2_node_1->inputIndex = {14, 15};
sub_graph_2_node_1->outputIndex = {17};
sub_graph_2_node_1->primitive = std::make_unique<schema::PrimitiveT>();
sub_graph_2_node_1->primitive->value.type = schema::PrimitiveType_Add;
auto primitive_sub_graph_2_node_1 = new schema::AddT;
sub_graph_2_node_1->primitive->value.value = primitive_sub_graph_2_node_1;
sub_graph_2_node_1->name = "body_add_2";
meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_1));
sub_graph_2->nodeIndices.push_back(10);
MS_LOG(DEBUG) << "node 9";

sub_graph_2->inputIndices = {6};
sub_graph_2->outputIndices = {17};
sub_graph_2->tensorIndices = {13, 14, 15, 6, 17};

meta_graph->subGraph.push_back(std::move(sub_graph_2));

// ------- tensor ---------
// tensor: 0 before-add input0 <main graph input>
auto tensor_0 = std::make_unique<schema::TensorT>();
tensor_0->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_0->format = schema::Format_NHWC;
tensor_0->dataType = TypeId::kNumberTypeFloat32;
tensor_0->dims = {1};
tensor_0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_0));
MS_LOG(DEBUG) << "tensor 0";

// tensor: 1 before-add input1 <const>
auto tensor_1 = std::make_unique<schema::TensorT>();
tensor_1->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_1->format = schema::Format_NHWC;
tensor_1->dataType = TypeId::kNumberTypeFloat32;
tensor_1->dims = {1};
tensor_1->data.resize(sizeof(float) * 1);
float input1_data[] = {1};
memcpy(tensor_1->data.data(), input1_data, sizeof(float) * 1);
tensor_1->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_1));
MS_LOG(DEBUG) << "tensor 1";

// tensor: 2 before-add output/partial input
auto tensor_2 = std::make_unique<schema::TensorT>();
tensor_2->nodeType = schema::NodeType::NodeType_Parameter;
tensor_2->format = schema::Format_NHWC;
tensor_2->dataType = TypeId::kNumberTypeFloat32;
tensor_2->dims = {1};
tensor_2->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_2));
MS_LOG(DEBUG) << "tensor 2";

// tensor: 3 before-add input1 <const>
auto tensor_3 = std::make_unique<schema::TensorT>();
tensor_3->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_3->format = schema::Format_NHWC;
tensor_3->dataType = TypeId::kNumberTypeFloat32;
tensor_3->dims = {1};
tensor_3->data.resize(sizeof(float) * 1);
float tensor_3_data[] = {1};
memcpy(tensor_3->data.data(), tensor_3_data, sizeof(float) * 1);
tensor_3->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_3));
MS_LOG(DEBUG) << "tensor 3";

auto tensor_4 = std::make_unique<schema::TensorT>();
tensor_4->nodeType = schema::NodeType::NodeType_Parameter;
tensor_4->format = schema::Format_NHWC;
tensor_4->dataType = TypeId::kNumberTypeFloat32;
tensor_4->dims = {1};
tensor_4->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_4));
MS_LOG(DEBUG) << "tensor 4";

// tensor :5 partial output <bool>
auto tensor_5 = std::make_unique<schema::TensorT>();
tensor_5->nodeType = schema::NodeType::NodeType_Parameter;
tensor_5->format = schema::Format_NHWC;
tensor_5->dataType = TypeId::kNumberTypeBool;
tensor_5->dims = {1};
tensor_5->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_5));
MS_LOG(DEBUG) << "tensor_4";

// tensor: 6 switch true output
auto tensor_6 = std::make_unique<schema::TensorT>();
tensor_6->nodeType = schema::NodeType::NodeType_Parameter;
tensor_6->format = schema::Format_NHWC;
tensor_6->dataType = TypeId::kNumberTypeFloat32;
tensor_6->dims = {1};
tensor_6->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_6));
MS_LOG(DEBUG) << "tensor 6";

// tensor: 5 switch False output
auto tensor_7 = std::make_unique<schema::TensorT>();
tensor_7->nodeType = schema::NodeType::NodeType_Parameter;
tensor_7->format = schema::Format_NHWC;
tensor_7->dataType = TypeId::kNumberTypeFloat32;
tensor_7->dims = {1};
tensor_7->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_7));
MS_LOG(DEBUG) << "tensor_7";

// tensor: 6 body-add input ,other input is switch true output
auto tensor_8 = std::make_unique<schema::TensorT>();
tensor_8->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_8->format = schema::Format_NHWC;
tensor_8->dataType = TypeId::kNumberTypeFloat32;
tensor_8->dims = {1};
tensor_8->data.resize(sizeof(float) * 1);
float tensor_8_data[] = {10};
memcpy(tensor_8->data.data(), tensor_8_data, sizeof(float) * 1);
tensor_8->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_8));
MS_LOG(DEBUG) << "tensor_8";

auto tensor_9 = std::make_unique<schema::TensorT>();
tensor_9->nodeType = schema::NodeType::NodeType_Parameter;
tensor_9->format = schema::Format_NHWC;
tensor_9->dataType = TypeId::kNumberTypeFloat32;
tensor_9->dims = {1};
tensor_9->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_9));
MS_LOG(DEBUG) << "tensor_9";

// tensor: 7 after-add input ,other input is switch false output
auto tensor_10 = std::make_unique<schema::TensorT>();
tensor_10->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_10->format = schema::Format_NHWC;
tensor_10->dataType = TypeId::kNumberTypeFloat32;
tensor_10->dims = {1};
tensor_10->data.resize(sizeof(float) * 1);
float tensor_10_data[] = {1};
memcpy(tensor_10->data.data(), tensor_10_data, sizeof(float) * 1);
tensor_10->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_10));
MS_LOG(DEBUG) << "tensor_10";

// tensor: 8 main graph output
auto tensor_11 = std::make_unique<schema::TensorT>();
tensor_11->nodeType = schema::NodeType::NodeType_Parameter;
tensor_11->format = schema::Format_NHWC;
tensor_11->dataType = TypeId::kNumberTypeFloat32;
tensor_11->dims = {1};
tensor_11->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_11));
MS_LOG(DEBUG) << "tensor 11";

// tensor: 9 cond-Less input, other input is tensor 2
auto tensor_12 = std::make_unique<schema::TensorT>();
tensor_12->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_12->format = schema::Format_NHWC;
tensor_12->dataType = TypeId::kNumberTypeFloat32;
tensor_12->dims = {1};
tensor_12->data.resize(sizeof(float) * 1);
float tensor_12_data[] = {10};
memcpy(tensor_12->data.data(), tensor_12_data, sizeof(float) * 1);
tensor_12->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_12));
MS_LOG(DEBUG) << "tensor_12";

auto tensor_13 = std::make_unique<schema::TensorT>();
tensor_13->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_13->format = schema::Format_NHWC;
tensor_13->dataType = TypeId::kNumberTypeFloat32;
tensor_13->dims = {1};
tensor_13->data.resize(sizeof(float) * 1);
float tensor_13_data[] = {1};
memcpy(tensor_13->data.data(), tensor_13_data, sizeof(float) * 1);
tensor_13->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_13));
MS_LOG(DEBUG) << "tensor_13";

auto tensor_14 = std::make_unique<schema::TensorT>();
tensor_14->nodeType = schema::NodeType::NodeType_Parameter;
tensor_14->format = schema::Format_NHWC;
tensor_14->dataType = TypeId::kNumberTypeFloat32;
tensor_14->dims = {1};
tensor_14->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_14));
MS_LOG(DEBUG) << "tensor 14";

auto tensor_15 = std::make_unique<schema::TensorT>();
tensor_15->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_15->format = schema::Format_NHWC;
tensor_15->dataType = TypeId::kNumberTypeFloat32;
tensor_15->dims = {1};
tensor_15->data.resize(sizeof(float) * 1);
float tensor_15_data[] = {1};
memcpy(tensor_15->data.data(), tensor_15_data, sizeof(float) * 1);
tensor_15->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_15));
MS_LOG(DEBUG) << "tensor_15";

auto tensor_16 = std::make_unique<schema::TensorT>();
tensor_16->nodeType = schema::NodeType::NodeType_Parameter;
tensor_16->format = schema::Format_NHWC;
tensor_16->dataType = TypeId::kNumberTypeFloat32;
tensor_16->dims = {1};
tensor_16->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_16));
MS_LOG(DEBUG) << "tensor_16";

auto tensor_17 = std::make_unique<schema::TensorT>();
tensor_17->nodeType = schema::NodeType::NodeType_Parameter;
tensor_17->format = schema::Format_NHWC;
tensor_17->dataType = TypeId::kNumberTypeFloat32;
tensor_17->dims = {1};
tensor_17->offset = -1;
meta_graph->allTensors.emplace_back(std::move(tensor_17));
MS_LOG(DEBUG) << "tensor_17";
// -----------------------------------------------------------------------

flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize();
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());

auto model = std::shared_ptr<lite::Model>(lite::Model::Import(content, size));
ASSERT_NE(model, nullptr);
lite::Context context;
context.thread_num_ = 2;
auto &cpu_device_ctx = context.device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = false;
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model.get());
ASSERT_EQ(ret, lite::RET_OK);
model->Free();
auto inputs = session->GetInputs();
ASSERT_EQ(inputs.size(), 1);
auto input = inputs.front();
ASSERT_NE(input, nullptr);
ASSERT_EQ(input->data_type(), kNumberTypeFloat32);
ASSERT_EQ(input->shape().size(), 1);
ASSERT_EQ(input->shape().at(0), 1);
auto in_data = reinterpret_cast<float *>(input->MutableData());
ASSERT_NE(in_data, nullptr);
in_data[0] = 1;
ret = session->RunGraph();
ASSERT_EQ(ret, lite::RET_OK);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto output = outputs.begin()->second;
ASSERT_NE(output, nullptr);
ASSERT_EQ(output->data_type(), kNumberTypeFloat32);
ASSERT_EQ(output->shape().size(), 1);
ASSERT_EQ(output->shape().at(0), 1);
auto out_data = reinterpret_cast<float *>(output->MutableData());
ASSERT_NE(out_data, nullptr);
ASSERT_EQ(out_data[0], 19);
}
} // namespace mindspore

+ 217
- 0
mindspore/lite/test/st/sub_graph_test.cc View File

@@ -0,0 +1,217 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cmath>
#include <memory>
#include "schema/inner/model_generated.h"
#include "mindspore/lite/include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/model.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/runtime/parallel_executor.h"
#include "tools/common/storage.h"
#include "include/version.h"

namespace mindspore {
// Test fixture for sub-graph (control-flow / Partial-node) compile-and-run tests.
class SubGraphTest : public mindspore::CommonTest {
 public:
  // No per-test state; defaulted constructor (modernize-use-equals-default).
  SubGraphTest() = default;
};

TEST_F(SubGraphTest, RecursiveSubGraphTest) {
// add0 partial1 2 3 tensor0 1 2
auto add_0 = std::make_unique<schema::CNodeT>();
add_0->inputIndex = {0, 1};
add_0->outputIndex = {2};
add_0->primitive = std::make_unique<schema::PrimitiveT>();
add_0->primitive->value.type = schema::PrimitiveType_Add;
auto add_0_prim = new schema::AddT;
add_0_prim->activationType = schema::ActivationType_NO_ACTIVATION;
add_0->primitive->value.value = add_0_prim;
add_0->name = "Add0";
auto partial_1 = std::make_unique<schema::CNodeT>();
partial_1->inputIndex = {2};
partial_1->outputIndex = {7};
partial_1->primitive = std::make_unique<schema::PrimitiveT>();
partial_1->primitive->value.type = schema::PrimitiveType_Partial;
auto partial_1_prim = new schema::PartialT;
partial_1_prim->subGraphIndex = 1;
partial_1->primitive->value.value = partial_1_prim;
partial_1->name = "Partial1";
auto partial_2 = std::make_unique<schema::CNodeT>();
partial_2->inputIndex = {2};
partial_2->outputIndex = {7};
partial_2->primitive = std::make_unique<schema::PrimitiveT>();
partial_2->primitive->value.type = schema::PrimitiveType_Partial;
auto partial_2_prim = new schema::PartialT;
partial_2_prim->subGraphIndex = 2;
partial_2->primitive->value.value = partial_2_prim;
partial_2->name = "Partial2";
auto partial_3 = std::make_unique<schema::CNodeT>();
partial_3->inputIndex = {4, 6};
partial_3->outputIndex = {7};
partial_3->primitive = std::make_unique<schema::PrimitiveT>();
partial_3->primitive->value.type = schema::PrimitiveType_Partial;
auto partial_3_prim = new schema::PartialT;
partial_3_prim->subGraphIndex = 3;
partial_3->primitive->value.value = partial_3_prim;
partial_3->name = "Partial3";
auto tensor_0 = std::make_unique<schema::TensorT>();
tensor_0->nodeType = schema::NodeType::NodeType_Parameter;
tensor_0->format = schema::Format_NHWC;
tensor_0->dataType = TypeId::kNumberTypeFloat32;
tensor_0->dims = {1, 2};
auto tensor_1 = std::make_unique<schema::TensorT>();
tensor_1->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_1->format = schema::Format_NHWC;
tensor_1->dataType = TypeId::kNumberTypeFloat32;
tensor_1->dims = {1, 2};
auto tensor_2 = std::make_unique<schema::TensorT>();
tensor_2->nodeType = schema::NodeType::NodeType_Parameter;
tensor_2->format = schema::Format_NHWC;
tensor_2->dataType = TypeId::kNumberTypeFloat32;
auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
sub_graph_0->name = "main_graph";
sub_graph_0->inputIndices = {0};
sub_graph_0->outputIndices = {7};
sub_graph_0->nodeIndices = {0, 1, 2};
sub_graph_0->tensorIndices = {0, 1, 2, 7};
// add1 tensor3 4
auto add_1 = std::make_unique<schema::CNodeT>();
add_1->inputIndex = {2, 3};
add_1->outputIndex = {4};
add_1->primitive = std::make_unique<schema::PrimitiveT>();
add_1->primitive->value.type = schema::PrimitiveType_Add;
auto add_1_prim = new schema::AddT;
add_1_prim->activationType = schema::ActivationType_NO_ACTIVATION;
add_1->primitive->value.value = add_1_prim;
add_1->name = "Add1";
auto tensor_3 = std::make_unique<schema::TensorT>();
tensor_3->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_3->format = schema::Format_NHWC;
tensor_3->dataType = TypeId::kNumberTypeFloat32;
tensor_3->dims = {1, 2};
auto tensor_4 = std::make_unique<schema::TensorT>();
tensor_4->nodeType = schema::NodeType::NodeType_Parameter;
tensor_4->format = schema::Format_NHWC;
tensor_4->dataType = TypeId::kNumberTypeFloat32;
auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
sub_graph_1->name = "sub_graph_1";
sub_graph_1->inputIndices = {2};
sub_graph_1->outputIndices = {7};
sub_graph_1->nodeIndices = {4, 3};
sub_graph_1->tensorIndices = {2, 3, 4, 7};
// add2 tensor5 6
auto add_2 = std::make_unique<schema::CNodeT>();
add_2->inputIndex = {2, 5};
add_2->outputIndex = {6};
add_2->primitive = std::make_unique<schema::PrimitiveT>();
add_2->primitive->value.type = schema::PrimitiveType_Add;
auto add_2_prim = new schema::AddT;
add_2_prim->activationType = schema::ActivationType_NO_ACTIVATION;
add_2->primitive->value.value = add_2_prim;
add_2->name = "Add2";
auto tensor_5 = std::make_unique<schema::TensorT>();
tensor_5->nodeType = schema::NodeType::NodeType_ValueNode;
tensor_5->format = schema::Format_NHWC;
tensor_5->dataType = TypeId::kNumberTypeFloat32;
tensor_5->dims = {1, 2};
auto tensor_6 = std::make_unique<schema::TensorT>();
tensor_6->nodeType = schema::NodeType::NodeType_Parameter;
tensor_6->format = schema::Format_NHWC;
tensor_6->dataType = TypeId::kNumberTypeFloat32;
auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
sub_graph_2->name = "sub_graph_2";
sub_graph_2->inputIndices = {2};
sub_graph_2->outputIndices = {7};
sub_graph_2->nodeIndices = {5, 3};
sub_graph_2->tensorIndices = {2, 5, 6, 7};
// add3 tensor7
auto add_3 = std::make_unique<schema::CNodeT>();
add_3->inputIndex = {4, 6};
add_3->outputIndex = {7};
add_3->primitive = std::make_unique<schema::PrimitiveT>();
add_3->primitive->value.type = schema::PrimitiveType_Add;
auto add_3_prim = new schema::AddT;
add_3_prim->activationType = schema::ActivationType_NO_ACTIVATION;
add_3->primitive->value.value = add_3_prim;
add_3->name = "Add3";
auto tensor_7 = std::make_unique<schema::TensorT>();
tensor_7->nodeType = schema::NodeType::NodeType_Parameter;
tensor_7->format = schema::Format_NHWC;
tensor_7->dataType = TypeId::kNumberTypeFloat32;
auto sub_graph_3 = std::make_unique<schema::SubGraphT>();
sub_graph_3->name = "sub_graph_3";
sub_graph_3->inputIndices = {4, 6};
sub_graph_3->outputIndices = {7};
sub_graph_3->nodeIndices = {6};
sub_graph_3->tensorIndices = {4, 6, 7};

// make graph
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
meta_graph->nodes.emplace_back(std::move(add_0));
meta_graph->nodes.emplace_back(std::move(partial_1));
meta_graph->nodes.emplace_back(std::move(partial_2));
meta_graph->nodes.emplace_back(std::move(partial_3));
meta_graph->nodes.emplace_back(std::move(add_1));
meta_graph->nodes.emplace_back(std::move(add_2));
meta_graph->nodes.emplace_back(std::move(add_3));
meta_graph->allTensors.emplace_back(std::move(tensor_0));
meta_graph->allTensors.emplace_back(std::move(tensor_1));
meta_graph->allTensors.emplace_back(std::move(tensor_2));
meta_graph->allTensors.emplace_back(std::move(tensor_3));
meta_graph->allTensors.emplace_back(std::move(tensor_4));
meta_graph->allTensors.emplace_back(std::move(tensor_5));
meta_graph->allTensors.emplace_back(std::move(tensor_6));
meta_graph->allTensors.emplace_back(std::move(tensor_7));
meta_graph->subGraph.emplace_back(std::move(sub_graph_0));
meta_graph->subGraph.emplace_back(std::move(sub_graph_1));
meta_graph->subGraph.emplace_back(std::move(sub_graph_2));
meta_graph->subGraph.emplace_back(std::move(sub_graph_3));
meta_graph->version = lite::Version();
// -----------------------------------------------------------------------
lite::Storage::Save(*meta_graph,
"/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph");
// -----------------------------------------------------------------------
size_t size = 0;
char *graph_buf = lite::ReadFile(
"/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph.ms", &size);
ASSERT_NE(graph_buf, nullptr);

auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
ASSERT_NE(model, nullptr);
delete[](graph_buf);
lite::Context context;
auto &cpu_device_ctx = context.device_list_[0];
cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
context.thread_num_ = 2;
auto session = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(&context));
ASSERT_NE(session, nullptr);
auto ret = session->CompileGraph(model.get());
ASSERT_EQ(ret, lite::RET_OK);
auto inputs = session->GetInputs();
for (auto *input : inputs) {
(void)input->MutableData();
}
ret = session->RunGraph();
ASSERT_EQ(ret, lite::RET_OK);
}
} // namespace mindspore

+ 0
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -142,7 +142,6 @@ add_executable(converter_lite
${KERNEL_SRC}
${LITE_SRC}
)
add_dependencies(converter_lite tflite_fbs_src)
add_dependencies(converter_lite fbs_src)
add_dependencies(converter_lite fbs_inner_src)



+ 1
- 0
mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt View File

@@ -5,4 +5,5 @@ set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID
add_library(tflite_parser_mid OBJECT
${TFLITE_SRC_LIST}
)
add_dependencies(tflite_parser_mid tflite_fbs_src)
target_link_libraries(tflite_parser_mid mindspore::flatbuffers)

Loading…
Cancel
Save