Browse Source

[MSLITE][DEVELOP] add input and output tensors for DelegateModel

tags/v1.4.0
yangruoqi713 4 years ago
parent
commit
93be34cacc
4 changed files with 30 additions and 8 deletions
  1. +11
    -4
      mindspore/lite/include/delegate.h
  2. +1
    -1
      mindspore/lite/src/lite_session.cc
  3. +11
    -1
      mindspore/lite/src/scheduler.cc
  4. +7
    -2
      mindspore/lite/src/scheduler.h

+ 11
- 4
mindspore/lite/include/delegate.h View File

@@ -29,9 +29,10 @@ using KernelIter = std::vector<kernel::Kernel *>::iterator;
class DelegateModel {
public:
/// \brief Constructor of MindSpore Lite DelegateModel.
DelegateModel(std::vector<kernel::Kernel *> *kernels,
const std::map<kernel::Kernel *, const schema::Primitive *> primitives)
: kernels_(kernels), primitives_(primitives) {}
DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<tensor::MSTensor *> &inputs,
const std::vector<tensor::MSTensor *> &outputs,
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives)
: kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives) {}

/// \brief Destructor of MindSpore Lite DelegateModel.
~DelegateModel() = default;
@@ -61,9 +62,15 @@ class DelegateModel {
/// \return The next iterator after graph_kernel, point to the next kernel that is not visited.
KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel);

const std::vector<mindspore::tensor::MSTensor *> &inputs() { return this->inputs_; }

const std::vector<mindspore::tensor::MSTensor *> &outputs() { return this->outputs_; }

protected:
std::vector<kernel::Kernel *> *kernels_;
const std::map<kernel::Kernel *, const schema::Primitive *> primitives_;
const std::vector<mindspore::tensor::MSTensor *> &inputs_;
const std::vector<mindspore::tensor::MSTensor *> &outputs_;
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives_;
};

typedef void (*DelegateHook)(std::shared_ptr<Delegate> delegate);


+ 1
- 1
mindspore/lite/src/lite_session.cc View File

@@ -501,7 +501,7 @@ int LiteSession::CompileGraph(Model *model) {
return ret;
}
// scheduler kernels
Scheduler scheduler(context_, model, &tensors_, is_train_session_, delegate_);
Scheduler scheduler(context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
scheduler.SetupSchedulerCb(std::move(sched_cb_));
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {


+ 11
- 1
mindspore/lite/src/scheduler.cc View File

@@ -152,7 +152,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
for (size_t i = 0; i < dst_kernels->size(); i++) {
kernels.push_back((*dst_kernels)[i]->kernel());
}
DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, primitives_);

std::vector<tensor::MSTensor *> input_ms_tensors;
input_ms_tensors.resize(inputs_.size());
(void)std::transform(inputs_.begin(), inputs_.end(), input_ms_tensors.begin(),
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });
std::vector<tensor::MSTensor *> output_ms_tensors;
output_ms_tensors.resize(outputs_.size());
(void)std::transform(outputs_.begin(), outputs_.end(), output_ms_tensors.begin(),
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });

DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, input_ms_tensors, output_ms_tensors, primitives_);
if (model == nullptr) {
MS_LOG(ERROR) << "New delegate model failed.";
return RET_NULL_PTR;


+ 7
- 2
mindspore/lite/src/scheduler.h View File

@@ -31,11 +31,14 @@
namespace mindspore::lite {
class Scheduler {
public:
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
std::shared_ptr<Delegate> delegate = nullptr)
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors,
bool is_train_session, std::shared_ptr<Delegate> delegate = nullptr)
: context_(ctx),
src_model_(src_model),
src_tensors_(src_tensors),
inputs_(input_tensors),
outputs_(output_tensors),
is_train_session_(is_train_session),
delegate_(delegate) {}
~Scheduler() = default;
@@ -113,6 +116,8 @@ class Scheduler {
const InnerContext *context_ = nullptr;
Model *src_model_ = nullptr;
std::vector<Tensor *> *src_tensors_;
const std::vector<Tensor *> &inputs_;
const std::vector<Tensor *> &outputs_;
std::vector<size_t> graph_output_node_indexes_;
std::map<int, OpParameter *> op_parameters_;
bool is_train_session_ = false;


Loading…
Cancel
Save