Browse Source

!10295 [MS_LITE] fix graph input

From: @YeFeng_24
Reviewed-by: @hangangqiang
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f2b25d4139
5 changed files with 17 additions and 2 deletions
  1. +5
    -0
      mindspore/lite/src/executor.cc
  2. +1
    -1
      mindspore/lite/src/lite_kernel.cc
  3. +6
    -0
      mindspore/lite/src/lite_session.cc
  4. +5
    -0
      mindspore/lite/src/tensor.h
  5. +0
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc

+ 5
- 0
mindspore/lite/src/executor.cc View File

@@ -49,6 +49,11 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
MS_LOG(ERROR) << "CheckInputs failed"; MS_LOG(ERROR) << "CheckInputs failed";
return ret; return ret;
} }
MS_ASSERT(std::all_of(kernels.begin(), kernels.end(), [](kernel::LiteKernel *kernel) {
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().end(), [](Tensor *in_tensor) {
return in_tensor->IsConst() || in_tensor->IsGraphInput() || in_tensor->ref_count() == 0;
});
}));
std::queue<kernel::LiteKernel *> kernel_queue; std::queue<kernel::LiteKernel *> kernel_queue;
for (auto kernel : kernels) { for (auto kernel : kernels) {
if (kernel->IsReady(kernel->in_tensors())) { if (kernel->IsReady(kernel->in_tensors())) {


+ 1
- 1
mindspore/lite/src/lite_kernel.cc View File

@@ -231,7 +231,7 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputNodes(const std::
auto all_input_tensors = kernel->in_tensors(); auto all_input_tensors = kernel->in_tensors();
// remove all const tensor from input tensors // remove all const tensor from input tensors
for (auto iter = all_input_tensors.begin(); iter != all_input_tensors.end();) { for (auto iter = all_input_tensors.begin(); iter != all_input_tensors.end();) {
if ((*iter)->IsConst() || (*iter)->IsGraphInput()) {
if ((*iter)->IsConst()) {
iter = all_input_tensors.erase(iter); iter = all_input_tensors.erase(iter);
} else { } else {
iter++; iter++;


+ 6
- 0
mindspore/lite/src/lite_session.cc View File

@@ -179,6 +179,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
if (IsContain(model_input_indices, i)) { if (IsContain(model_input_indices, i)) {
dst_tensor->set_category(Tensor::GRAPH_INPUT); dst_tensor->set_category(Tensor::GRAPH_INPUT);
} }
if (src_tensor->name() != nullptr) {
dst_tensor->set_tensor_name(src_tensor->name()->str());
}
this->tensors_.emplace_back(dst_tensor); this->tensors_.emplace_back(dst_tensor);
} }
return RET_OK; return RET_OK;
@@ -306,6 +309,9 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
return; return;
} }
this->output_tensor_map_.insert(std::make_pair(std::to_string(graph_out_index), out_tensor)); this->output_tensor_map_.insert(std::make_pair(std::to_string(graph_out_index), out_tensor));
if (!out_tensor->tensor_name().empty()) {
this->output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor));
}
} }
} }




+ 5
- 0
mindspore/lite/src/tensor.h View File

@@ -65,6 +65,10 @@ class Tensor : public mindspore::tensor::MSTensor {


virtual bool operator==(const Tensor &tensor); virtual bool operator==(const Tensor &tensor);


void set_tensor_name(std::string name) { tensor_name_ = name; }

std::string tensor_name() const { return tensor_name_; }

TypeId data_type() const override { return data_type_; } TypeId data_type() const override { return data_type_; }


void set_data_type(TypeId data_type) { data_type_ = data_type; } void set_data_type(TypeId data_type) { data_type_ = data_type; }
@@ -162,6 +166,7 @@ class Tensor : public mindspore::tensor::MSTensor {
} }


protected: protected:
std::string tensor_name_;
void *data_ = nullptr; void *data_ = nullptr;
void *device_data_ = nullptr; void *device_data_ = nullptr;
TypeId data_type_; TypeId data_type_;


+ 0
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc View File

@@ -185,7 +185,6 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input-" + std::to_string(i)); parameter->set_name("graph_input-" + std::to_string(i));
nodes_.insert(std::pair(layer.top(0), parameter)); nodes_.insert(std::pair(layer.top(0), parameter));
return RET_OK;
} }
} }




Loading…
Cancel
Save