| @@ -64,6 +64,14 @@ class MS_API Model { | |||
| /// \return Status. | |||
| Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims); | |||
| /// \brief Change the size and/or content of the model's weight tensors. | |||
| /// | |||
| /// \param[in] new_weights A vector of tensors carrying the new shapes and data to use in the model. | |||
| /// If a tensor's data pointer is null, the data of the original tensor is copied into the new one. | |||
| /// | |||
| /// \return Status. | |||
| Status UpdateWeights(const std::vector<MSTensor> &new_weights); | |||
| /// \brief Inference model. | |||
| /// | |||
| /// \param[in] inputs A vector where model inputs are arranged in sequence. | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | |||
| #define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | |||
| #ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_ | |||
| #define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_ | |||
| #ifndef NOT_USE_STL | |||
| #include <unordered_map> | |||
| @@ -190,6 +190,14 @@ class MS_API LiteSession { | |||
| return mindspore::lite::RET_ERROR; | |||
| } | |||
| /// \brief Change the size and/or content of weight tensors. | |||
| /// | |||
| /// \param[in] new_weights A vector of tensors carrying the new shapes and data to use in the model. | |||
| /// If a tensor's data pointer is null, the data of the original tensor is copied into the new one. | |||
| /// | |||
| /// \return STATUS as an error code of operation (defined in errorcode.h); this default implementation always returns RET_ERROR. | |||
| virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; } | |||
| /// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction | |||
| /// | |||
| /// \return a vector of output tensors (MindSpore Lite MSTensor). | |||
| @@ -233,4 +241,4 @@ class MS_API LiteSession { | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H | |||
| #endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_ | |||
| @@ -102,6 +102,14 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std: | |||
| return impl_->Resize(inputs, dims); | |||
| } | |||
| // Forwards the weight update to the underlying implementation object. | |||
| // Returns kLiteNullptr (after logging) when the model has no implementation attached. | |||
| Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) { | |||
|   if (impl_ != nullptr) { | |||
|     return impl_->UpdateWeights(new_weights); | |||
|   } | |||
|   MS_LOG(ERROR) << "Model implement is null."; | |||
|   return kLiteNullptr; | |||
| } | |||
| Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, | |||
| const MSKernelCallBack &before, const MSKernelCallBack &after) { | |||
| if (impl_ == nullptr) { | |||
| @@ -559,6 +559,29 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector< | |||
| return static_cast<StatusCode>(ret); | |||
| } | |||
| // Translates the public MSTensor weights into the session's internal tensor pointers and | |||
| // hands them to the session's UpdateWeights(). | |||
| // | |||
| // Returns kLiteNullptr when no session exists, kLiteInputParamInvalid when new_weights is empty, | |||
| // kLiteInputTensorError when any tensor has no implementation or no underlying lite tensor, | |||
| // and otherwise the session's return code converted to a StatusCode. | |||
| Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) { | |||
|   if (session_ == nullptr) { | |||
|     MS_LOG(ERROR) << "Session is null."; | |||
|     return kLiteNullptr; | |||
|   } | |||
|   if (new_weights.empty()) { | |||
|     MS_LOG(ERROR) << "New weights are empty."; | |||
|     return kLiteInputParamInvalid; | |||
|   } | |||
|   std::vector<tensor::MSTensor *> inner_weights; | |||
|   inner_weights.reserve(new_weights.size()); | |||
|   // Iterate by const reference: copying an MSTensor per element is unnecessary here. | |||
|   for (const auto &weight : new_weights) { | |||
|     tensor::MSTensor *lite_tensor = (weight.impl_ == nullptr) ? nullptr : weight.impl_->lite_tensor(); | |||
|     if (lite_tensor == nullptr) { | |||
|       MS_LOG(ERROR) << "Input tensor " << weight.Name() << " is null."; | |||
|       return kLiteInputTensorError; | |||
|     } | |||
|     inner_weights.push_back(lite_tensor); | |||
|   } | |||
|   auto ret = session_->UpdateWeights(inner_weights); | |||
|   return static_cast<StatusCode>(ret); | |||
| } | |||
| session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) { | |||
| auto session = new (std::nothrow) lite::LiteSession(); | |||
| if (session == nullptr) { | |||
| @@ -63,6 +63,7 @@ class ModelImpl { | |||
| const std::shared_ptr<Context> &model_context); | |||
| Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context); | |||
| Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims); | |||
| // Passes the lite tensors behind new_weights to the session's UpdateWeights(). | |||
| Status UpdateWeights(const std::vector<MSTensor> &new_weights); | |||
| Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before, | |||
| const MSKernelCallBack &after); | |||
| @@ -91,7 +91,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType ty | |||
| return nullptr; | |||
| } | |||
| if (data_len > 0 && data == nullptr) { | |||
| MS_LOG(ERROR) << "Mull data ptr of tensor."; | |||
| MS_LOG(ERROR) << "Null data ptr of tensor."; | |||
| return nullptr; | |||
| } | |||
| auto impl = Impl::CreateTensorImpl(CharToString(name), type, shape, nullptr, data_len); | |||
| @@ -28,7 +28,7 @@ using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits; | |||
| namespace mindspore::kernel { | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } | |||
| // Resizing re-runs Prepare() and returns its result. | |||
| // NOTE(review): Prepare() is not visible here - confirm it does not call ReSize(), otherwise this recurses. | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return Prepare(); } | |||
| int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, | |||
| float *output) const { | |||
| @@ -50,8 +50,6 @@ int StridedSliceGradCPUKernel::Prepare() { | |||
| MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| FillEmptyDims(); | |||
| FillOutputDim(); | |||
| return ReSize(); | |||
| } | |||
| @@ -113,7 +111,11 @@ void StridedSliceGradCPUKernel::FillOutputDim() { | |||
| } | |||
| } | |||
| int StridedSliceGradCPUKernel::ReSize() { return RET_OK; } | |||
| // Calls FillEmptyDims() and then FillOutputDim(), and always returns RET_OK. | |||
| // Prepare() ends by returning ReSize(), so both fill steps still run during preparation. | |||
| // NOTE(review): presumably this refreshes the cached dims after a shape change - confirm against the helpers. | |||
| int StridedSliceGradCPUKernel::ReSize() { | |||
| FillEmptyDims(); | |||
| FillOutputDim(); | |||
| return RET_OK; | |||
| } | |||
| int StridedSliceGradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) { | |||
| CHECK_NULL_RETURN(cdata); | |||
| @@ -176,6 +176,89 @@ int TrainSession::InitCallBack() { | |||
| return RET_OK; | |||
| } | |||
| // Gives a constant weight tensor the shape and data of new_tensor. | |||
| // | |||
| // orig_tensor: the session tensor to modify; it must be a CONST_TENSOR of the same data type as new_tensor. | |||
| // new_tensor:  carries the new shape and, optionally, the new data. When its data pointer is null, | |||
| //              a buffer is obtained via MutableData() and filled from the original data, | |||
| //              repeating the original bytes cyclically if the new tensor is larger. | |||
| // | |||
| // Returns RET_OK on success, RET_PARAM_INVALID for a null argument or a data type mismatch, | |||
| // and RET_ERROR for a non-constant tensor, missing source data or a failed allocation. | |||
| static int ReshapeWeightTensor(Tensor *orig_tensor, tensor::MSTensor *new_tensor) { | |||
|   if (orig_tensor == nullptr || new_tensor == nullptr) { | |||
|     MS_LOG(ERROR) << "Tensor is nullptr"; | |||
|     return RET_PARAM_INVALID; | |||
|   } | |||
|   if (orig_tensor->data_type() != new_tensor->data_type()) { | |||
|     MS_LOG(ERROR) << "Cannot reshape tensor of different type: " << new_tensor->tensor_name(); | |||
|     return RET_PARAM_INVALID; | |||
|   } | |||
|   if (orig_tensor->category() != lite::Category::CONST_TENSOR) { | |||
|     MS_LOG(ERROR) << "Cannot reshape non const tensor: " << new_tensor->tensor_name(); | |||
|     return RET_ERROR; | |||
|   } | |||
|   uint8_t *new_data = reinterpret_cast<uint8_t *>(new_tensor->data()); | |||
|   if (new_data == nullptr) { | |||
|     // No replacement data was supplied, so the original data is the only source. | |||
|     // Validate it before allocating anything in the caller's tensor. | |||
|     const size_t orig_size = orig_tensor->Size(); | |||
|     const uint8_t *orig_data = reinterpret_cast<uint8_t *>(orig_tensor->data()); | |||
|     if (orig_size == 0 || orig_data == nullptr) { | |||
|       MS_LOG(ERROR) << "Operation failed: Both new tensors and original one have no data"; | |||
|       return RET_ERROR; | |||
|     } | |||
|     new_data = reinterpret_cast<uint8_t *>(new_tensor->MutableData()); | |||
|     if (new_data == nullptr) { | |||
|       MS_LOG(ERROR) << "Allocation of Data Failed: " << new_tensor->tensor_name(); | |||
|       return RET_ERROR; | |||
|     } | |||
|     // Copy byte by byte, wrapping around the original buffer when the new one is larger. | |||
|     const size_t new_size = new_tensor->Size(); | |||
|     for (size_t loc = 0; loc < new_size; loc++) { | |||
|       new_data[loc] = orig_data[loc % orig_size]; | |||
|     } | |||
|   } | |||
|   // Drop the old buffer, adopt the new shape, then allocate and fill a buffer of the new size. | |||
|   orig_tensor->FreeData(); | |||
|   orig_tensor->set_data(nullptr); | |||
|   orig_tensor->set_shape(new_tensor->shape()); | |||
|   uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData()); | |||
|   if (dst_data == nullptr) { | |||
|     MS_LOG(ERROR) << "Allocation of Data Failed"; | |||
|     return RET_ERROR; | |||
|   } | |||
|   std::copy(new_data, new_data + orig_tensor->Size(), dst_data); | |||
|   return RET_OK; | |||
| } | |||
| // Replaces the shape and/or data of the session's weight tensors with the given tensors, | |||
| // matching them by tensor name, then resizes the kernels and re-enters train mode. | |||
| // | |||
| // modify_tensors: the replacement tensors; each must be non-null and name a tensor of this session. | |||
| // | |||
| // All inputs are validated and paired with their targets before anything is modified, so an | |||
| // unknown name or a null entry no longer leaves the session with only some weights replaced. | |||
| // | |||
| // Returns RET_OK on success, RET_PARAM_INVALID for a null entry, RET_ERROR when a name is not | |||
| // found, or the error code of the failing reshape / kernel resize / Train() / Eval() step. | |||
| int TrainSession::UpdateWeights(std::vector<tensor::MSTensor *> modify_tensors) { | |||
|   for (auto modify : modify_tensors) { | |||
|     if (modify == nullptr) { | |||
|       MS_LOG(ERROR) << "Tensor is nullptr"; | |||
|       return RET_PARAM_INVALID; | |||
|     } | |||
|   } | |||
|   // Pair every session tensor with the first replacement carrying the same name. | |||
|   std::vector<Tensor *> targets; | |||
|   std::vector<tensor::MSTensor *> sources; | |||
|   targets.reserve(modify_tensors.size()); | |||
|   sources.reserve(modify_tensors.size()); | |||
|   for (auto tensor : tensors_) { | |||
|     for (auto modify : modify_tensors) { | |||
|       if (modify->tensor_name() == tensor->tensor_name()) { | |||
|         targets.push_back(tensor); | |||
|         sources.push_back(modify); | |||
|         break; | |||
|       } | |||
|     } | |||
|   } | |||
|   if (targets.size() != modify_tensors.size()) { | |||
|     MS_LOG(ERROR) << "Did not find all the given tensors in the model"; | |||
|     return RET_ERROR; | |||
|   } | |||
|   for (size_t i = 0; i < targets.size(); i++) { | |||
|     auto reshape_ret = ReshapeWeightTensor(targets[i], sources[i]); | |||
|     if (reshape_ret != RET_OK) { | |||
|       return reshape_ret; | |||
|     } | |||
|   } | |||
|   auto ret = ReSizeKernels(kernels_); | |||
|   if (ret != RET_OK) { | |||
|     MS_LOG(ERROR) << "Resize kernels fail!"; | |||
|     return ret; | |||
|   } | |||
|   // Remember the current mode so that it can be restored after the forced Train() call. | |||
|   bool is_eval = IsEval(); | |||
|   ret = Train();  // This will trigger proper Allocation of static data; | |||
|   if (ret != RET_OK) { | |||
|     MS_LOG(ERROR) << "General failure occurred during Update of Weights"; | |||
|     return ret; | |||
|   } | |||
|   if (is_eval) { | |||
|     ret = Eval(); | |||
|   } | |||
|   return ret; | |||
| } | |||
| int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) { | |||
| if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK; | |||
| OptAllocator allocator; | |||
| @@ -199,8 +282,12 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) | |||
| } | |||
| } | |||
| // Set Tensor data | |||
| auto size = allocator.total_size(); | |||
| if (size > tensors_data_size_) { | |||
| free(tensors_data_); | |||
| tensors_data_ = nullptr; | |||
| } | |||
| if (tensors_data_ == nullptr) { | |||
| auto size = allocator.total_size(); | |||
| auto buf = malloc(size); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "cannot allocate buffer size" << size; | |||
| @@ -209,6 +296,7 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) | |||
| StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get()); | |||
| alloc->SetContex(buf, size); | |||
| tensors_data_ = buf; | |||
| tensors_data_size_ = size; | |||
| } | |||
| for (auto kernel : train_kernels_) { | |||
| for (auto tensor : kernel->out_tensors()) { | |||
| @@ -85,6 +85,7 @@ class TrainSession : virtual public lite::LiteSession { | |||
| return lite::LiteSession::GetOutputByTensorName(tensor_name); | |||
| } | |||
| int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override; | |||
| // Overrides the LiteSession default: updates the weight tensors whose names match those in new_weights. | |||
| int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) override; | |||
| std::vector<tensor::MSTensor *> GetPredictions() const override { | |||
| std::vector<tensor::MSTensor *> outputs; | |||
| @@ -166,6 +167,7 @@ class TrainSession : virtual public lite::LiteSession { | |||
| SchedCallBack sched_mix_precision_callback_; | |||
| bool train_mode_ = false; | |||
| void *tensors_data_ = nullptr; | |||
| // Size in bytes of the buffer held in tensors_data_; size_t so that large allocation sizes are not truncated. | |||
| size_t tensors_data_size_ = 0; | |||
| std::shared_ptr<Allocator> allocator_; | |||
| }; | |||
| @@ -229,4 +229,29 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) { | |||
| train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; | |||
| ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); | |||
| } | |||
| #define NUM_OF_CLASSES 10 | |||
| #define FEATURE_SIZE 10 | |||
| // Exercises Model::UpdateWeights on a training model: | |||
| //   1. an empty weight list must be rejected, | |||
| //   2. a tensor named "fc4.weight" must be rejected, | |||
| //   3. a data-less tensor named "fc3.bias" must be accepted. | |||
| // Despite the test name, the last step checks the success path. | |||
| TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) { | |||
|   Model model; | |||
|   Graph graph; | |||
|   auto context = std::make_shared<Context>(); | |||
|   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>(); | |||
|   cpu_context->SetEnableFP16(true); | |||
|   context->MutableDeviceInfo().push_back(cpu_context); | |||
|   auto train_cfg = std::make_shared<TrainCfg>(); | |||
|   train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true; | |||
|   ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess); | |||
|   ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess); | |||
|   std::vector<mindspore::MSTensor> changes; | |||
|   ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess); | |||
|   // CreateTensor returns a heap pointer that may be null; check it, copy the handle into the | |||
|   // vector, and release the pointer immediately so nothing leaks if a later assertion fails. | |||
|   auto *fc4_weight = | |||
|     MSTensor::CreateTensor("fc4.weight", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0); | |||
|   ASSERT_TRUE(fc4_weight != nullptr); | |||
|   changes.push_back(*fc4_weight); | |||
|   MSTensor::DestroyTensorPtr(fc4_weight); | |||
|   ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess); | |||
|   changes.clear(); | |||
|   auto *fc3_bias = | |||
|     MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0); | |||
|   ASSERT_TRUE(fc3_bias != nullptr); | |||
|   changes.push_back(*fc3_bias); | |||
|   MSTensor::DestroyTensorPtr(fc3_bias); | |||
|   ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess); | |||
| } | |||
| } // namespace mindspore | |||