| @@ -116,7 +116,7 @@ class MS_API LiteSession { | |||
| /// \param[in] inputs Define the new inputs shape. | |||
| /// | |||
| /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h. | |||
| virtual int Resize(const std::vector<tensor::MSTensor *> &inputs) = 0; | |||
| virtual int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>>& dims) = 0; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -389,34 +389,51 @@ std::unordered_map<std::string, mindspore::tensor::MSTensor *> LiteSession::GetO | |||
| return this->output_tensor_map_; | |||
| } | |||
| int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs) { | |||
| int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, | |||
| const std::vector<std::vector<int>> &dims) { | |||
| if (inputs.size() != inputs_.size()) { | |||
| MS_LOG(ERROR) << "Inputs size " << inputs.size() << " is not equal to " << inputs_.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (dims.size() != inputs.size()) { | |||
| MS_LOG(ERROR) << "Input dims size " << dims.size() << " is not equal to the inputs size " << inputs.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (inputs[i] == nullptr) { | |||
| MS_LOG(ERROR) << "Input tensor is nullptr!"; | |||
| if (inputs[i] != inputs_[i]) { | |||
| MS_LOG(ERROR) << "Input[" << i << "] tensor is not equal to the inputs have been saved!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| inputs_[i]->set_shape(inputs[i]->shape()); | |||
| inputs_[i]->set_shape(dims[i]); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs) { | |||
| std::vector<Tensor *> inputs_old(inputs_); | |||
| auto ret = ResizeInputs(inputs); | |||
| void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) { | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| inputs_[i]->set_shape(dims[i]); | |||
| } | |||
| } | |||
| int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs, | |||
| const std::vector<std::vector<int>> &dims) { | |||
| std::vector<std::vector<int>> old_dims; | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| old_dims.push_back(inputs_[i]->shape()); | |||
| } | |||
| auto ret = ResizeInputs(inputs, dims); | |||
| if (ret != RET_OK) { | |||
| inputs_ = inputs_old; | |||
| ResetInputsShape(old_dims); | |||
| return ret; | |||
| } | |||
| Scheduler scheduler(context_); | |||
| ret = scheduler.ReSizeKernels(kernels_); | |||
| if (ret != RET_OK) { | |||
| inputs_ = inputs_old; | |||
| ResetInputsShape(old_dims); | |||
| auto resize_ret = scheduler.ReSizeKernels(kernels_); | |||
| if (resize_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret; | |||
| @@ -59,7 +59,8 @@ class LiteSession : public session::LiteSession { | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override; | |||
| int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs) override; | |||
| int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs, | |||
| const std::vector<std::vector<int>> &dims) override; | |||
| protected: | |||
| int ConvertTensors(const lite::Model *model); | |||
| @@ -80,7 +81,11 @@ class LiteSession : public session::LiteSession { | |||
| void InitGraphOutputTensorMap(const lite::Model *model); | |||
| int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs); | |||
| int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, | |||
| const std::vector<std::vector<int>> &dims); | |||
| private: | |||
| void ResetInputsShape(const std::vector<std::vector<int>> &dims); | |||
| protected: | |||
| Context *context_ = nullptr; | |||
| @@ -52,6 +52,7 @@ int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors | |||
| } | |||
| int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| bool infer_shape_interrupt = false; | |||
| for (size_t i = 0; i < kernels.size(); ++i) { | |||
| if (kernels[i] == nullptr) { | |||
| MS_LOG(ERROR) << "input kernel is nullptr!"; | |||
| @@ -64,15 +65,25 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| } | |||
| std::vector<Tensor *> &inputs = kernels[i]->in_tensors(); | |||
| std::vector<Tensor *> &outputs = kernels[i]->out_tensors(); | |||
| primitive->SetInferFlag(!infer_shape_interrupt); | |||
| auto ret = primitive->InferShape(inputs, outputs); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "InferShape failed, name: " << kernels[i]->name() << ", ret = " << ret; | |||
| return ret; | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape shouldn't be done before runtime, type:" | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())) | |||
| << "flag set to false."; | |||
| primitive->SetInferFlag(false); | |||
| infer_shape_interrupt = true; | |||
| } else if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "InferShape failed, type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())); | |||
| return RET_INFER_ERR; | |||
| } | |||
| ret = kernels[i]->ReSize(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret; | |||
| return ret; | |||
| if (!infer_shape_interrupt) { | |||
| ret = kernels[i]->ReSize(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret; | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||