| @@ -26,7 +26,7 @@ context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs | |||
| n = mobilenet_v3_small(num_classes=10) | |||
| loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean') | |||
| optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-2, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True, | |||
| optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-3, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True, | |||
| use_nesterov=False, weight_decay=0.1, loss_scale=0.3) | |||
| net = TrainWrap(n, loss_fn, optimizer) | |||
| @@ -22,6 +22,7 @@ | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include "include/context.h" | |||
| #include "include/lite_session.h" | |||
| #include "src/utils.h" | |||
| static unsigned int seed = time(NULL); | |||
| @@ -122,14 +123,14 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat | |||
| return labels_vec; | |||
| } | |||
| float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const { | |||
| float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset, | |||
| mindspore::session::LiteSession *session) const { | |||
| float accuracy = 0.0; | |||
| int tests = dataset.size() / batch_size_; | |||
| session_->Eval(); | |||
| for (int i = 0; i < tests; i++) { | |||
| auto labels = FillInputData(dataset, i); | |||
| session_->RunGraph(); | |||
| session->RunGraph(); | |||
| auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_); | |||
| MS_ASSERT(outputsv != nullptr); | |||
| auto scores = reinterpret_cast<float *>(outputsv->MutableData()); | |||
| @@ -145,7 +146,6 @@ float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) c | |||
| if (labels[b] == max_idx) accuracy += 1.0; | |||
| } | |||
| } | |||
| session_->Train(); | |||
| accuracy /= static_cast<float>(batch_size_ * tests); | |||
| return accuracy; | |||
| } | |||
| @@ -192,7 +192,9 @@ int NetRunner::TrainLoop() { | |||
| std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl; | |||
| if ((i + 1) % 20 == 0) { | |||
| float acc = CalculateAccuracy(ds_.test_data()); | |||
| session_->Eval(); | |||
| float acc = CalculateAccuracy(ds_.test_data(), session_); | |||
| session_->Train(); | |||
| if (max_acc < acc) max_acc = acc; | |||
| std::cout << "accuracy on test data = " << acc << " max accuracy = " << max_acc << std::endl; | |||
| if (acc > 0.9) return 0; | |||
| @@ -207,26 +209,34 @@ int NetRunner::Main() { | |||
| InitDB(); | |||
| TrainLoop(); | |||
| float acc = CalculateAccuracy(ds_.val_data()); | |||
| session_->Eval(); | |||
| float acc = CalculateAccuracy(ds_.val_data(), session_); | |||
| std::cout << "accuracy on validation data = " << acc << std::endl; | |||
| if (cycles_ > 0 && head_model_ != nullptr) { | |||
| auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms"; | |||
| mindspore::lite::Model::Export(head_model_, trained_fn.c_str()); | |||
| } | |||
| if (!save_inference_.empty()) { | |||
| int status = session_->ExportInference(save_inference_); | |||
| if (status != mindspore::lite::RET_OK) { | |||
| std::cout << "Failed to save inference file"; | |||
| return mindspore::lite::RET_ERROR; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| void NetRunner::Usage() { | |||
| std::cout << "Usage: net_runner -f <.ms head model file> -b <.ms backbone model file> -d <data_dir> " | |||
| << "[-c <num of training cycles>] [-v (verbose mode)] " | |||
| << "[-s <save checkpoint every X iterations>]" << std::endl; | |||
| << "[-s <save checkpoint every X iterations>]" | |||
| << "[-i <save inference file>]" << std::endl; | |||
| } | |||
| bool NetRunner::ReadArgs(int argc, char *argv[]) { | |||
| int opt; | |||
| while ((opt = getopt(argc, argv, "b:f:e:d:s:ihc:v")) != -1) { | |||
| while ((opt = getopt(argc, argv, "b:f:e:d:s:i:hc:v")) != -1) { | |||
| switch (opt) { | |||
| case 'b': | |||
| ms_backbone_file_ = std::string(optarg); | |||
| @@ -246,6 +256,9 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) { | |||
| case 's': | |||
| save_checkpoint_ = atoi(optarg); | |||
| break; | |||
| case 'i': | |||
| save_inference_ = std::string(optarg); | |||
| break; | |||
| case 'h': | |||
| default: | |||
| Usage(); | |||
| @@ -38,7 +38,7 @@ class NetRunner { | |||
| int InitDB(); | |||
| int TrainLoop(); | |||
| std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, int serially = -1) const; | |||
| float CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const; | |||
| float CalculateAccuracy(const std::vector<DataLabelTuple> &dataset, mindspore::session::LiteSession *session) const; | |||
| float GetLoss() const; | |||
| mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const; | |||
| @@ -50,6 +50,7 @@ class NetRunner { | |||
| std::string ms_backbone_file_ = ""; | |||
| std::string ms_head_file_ = ""; | |||
| std::string data_dir_ = ""; | |||
| std::string save_inference_ = ""; | |||
| size_t data_size_ = 0; | |||
| size_t batch_size_ = 0; | |||
| unsigned int cycles_ = 100; | |||
| @@ -135,6 +135,7 @@ class LiteSession : public session::LiteSession { | |||
| Model *model_ = nullptr; | |||
| std::atomic<bool> is_running_ = false; | |||
| bool is_train_session_ = false; | |||
| friend class TransferSession; | |||
| #if SUPPORT_NPU | |||
| NPUManager *npu_manager_ = nullptr; | |||
| NPUPassManager *npu_pass_manager_ = nullptr; | |||
| @@ -115,8 +115,8 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite | |||
| return tensorT; | |||
| } | |||
| mindspore::lite::Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel) { | |||
| auto nodes = model_->all_nodes_; | |||
| Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model) { | |||
| auto nodes = model->all_nodes_; | |||
| auto it = std::find_if(nodes.begin(), nodes.end(), | |||
| [&kernel](mindspore::lite::Model::Node *n) { return (kernel->name() == n->name_); }); | |||
| if (it == nodes.end()) { | |||
| @@ -127,14 +127,18 @@ mindspore::lite::Model::Node *TrainExport::FindNode(const mindspore::kernel::Lit | |||
| std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::LiteKernel *kernel, | |||
| std::vector<uint32_t> inputIndex, | |||
| std::vector<uint32_t> outputIndex) { | |||
| std::vector<uint32_t> outputIndex, const Model *model) { | |||
| auto cnodeT = std::make_unique<schema::CNodeT>(); | |||
| if (cnodeT == nullptr) { | |||
| MS_LOG(ERROR) << " cannot allocate node"; | |||
| return nullptr; | |||
| } | |||
| cnodeT->inputIndex = inputIndex; | |||
| cnodeT->outputIndex = outputIndex; | |||
| cnodeT->name = kernel->name(); | |||
| cnodeT->quantType = GetNodeQuantType(kernel); | |||
| // find kernel in model | |||
| auto *node = FindNode(kernel); | |||
| auto *node = FindNode(kernel, model); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model"; | |||
| return nullptr; | |||
| @@ -144,28 +148,141 @@ std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel | |||
| return cnodeT; | |||
| } | |||
| int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kernels, | |||
| const std::vector<mindspore::lite::Tensor *> &tensors, | |||
| const std::vector<std::string> &output_names) { | |||
| std::map<size_t, size_t> remap; | |||
| int TrainExport::LoadModel(void *buf, size_t buf_size) { | |||
| flatbuffers::Verifier verify((const uint8_t *)buf, buf_size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "model flatbuffer verify fail"; | |||
| return RET_ERROR; | |||
| } | |||
| meta_graph_ = schema::GetMetaGraph(buf)->UnPack(); | |||
| meta_graph_->outputIndex.clear(); | |||
| return RET_OK; | |||
| } | |||
| std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) { | |||
| auto &scTensor = meta_graph_->allTensors.at(id); | |||
| auto tensorT = std::make_unique<schema::TensorT>(); | |||
| if (tensorT == nullptr) { | |||
| MS_LOG(ERROR) << "Could not create tensor "; | |||
| return nullptr; | |||
| } | |||
| tensorT->nodeType = scTensor->nodeType; | |||
| tensorT->dataType = scTensor->dataType; | |||
| std::vector<int32_t> dims; | |||
| std::vector<int32_t> val = {0, 2, 3, 1}; | |||
| if (scTensor->dims.size() == 4) { | |||
| for (size_t i = 0; i < val.size(); i++) { | |||
| dims.push_back(scTensor->dims.at(val[i])); | |||
| } | |||
| tensorT->dims = dims; | |||
| } else { | |||
| tensorT->dims = scTensor->dims; | |||
| } | |||
| tensorT->format = schema::Format_NHWC; | |||
| tensorT->name = scTensor->name + "_post"; | |||
| tensorT->refCount = 0; | |||
| tensorT->offset = 0; | |||
| tensorT->enableHuffmanCode = false; | |||
| return tensorT; | |||
| } | |||
| std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) { | |||
| auto tensorT = std::make_unique<schema::TensorT>(); | |||
| if (tensorT == nullptr) { | |||
| MS_LOG(ERROR) << "Could not create tensor "; | |||
| return nullptr; | |||
| } | |||
| tensorT->nodeType = lite::NodeType_ValueNode; | |||
| tensorT->dataType = TypeId::kNumberTypeInt32; | |||
| tensorT->dims = {4}; | |||
| tensorT->format = schema::Format_NCHW; | |||
| tensorT->name = "const-" + std::to_string(last_id); | |||
| tensorT->refCount = 0; | |||
| tensorT->offset = 0; | |||
| tensorT->enableHuffmanCode = false; | |||
| int32_t val[] = {0, 2, 3, 1}; | |||
| uint8_t *valp = reinterpret_cast<uint8_t *>(val); | |||
| tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val)); | |||
| return tensorT; | |||
| } | |||
| std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex, | |||
| std::vector<uint32_t> outputIndex, size_t id) { | |||
| auto cnodeT = std::make_unique<schema::CNodeT>(); | |||
| if (cnodeT == nullptr) { | |||
| MS_LOG(ERROR) << "cannot allocate node"; | |||
| return nullptr; | |||
| } | |||
| cnodeT->inputIndex = inputIndex; | |||
| cnodeT->outputIndex = outputIndex; | |||
| cnodeT->name = "transpose-" + std::to_string(id); | |||
| cnodeT->quantType = schema::QuantType_QUANT_NONE; | |||
| cnodeT->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| cnodeT->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| return cnodeT; | |||
| } | |||
| int TrainExport::AddTransformNode() { | |||
| std::unordered_map<size_t, size_t> reconnect; | |||
| size_t last_id = meta_graph_->allTensors.size(); | |||
| size_t last_node = meta_graph_->nodes.size(); | |||
| for (auto it : connect_) { | |||
| auto tensorConst = CreateTransformConst(last_id); | |||
| if (tensorConst == nullptr) { | |||
| MS_LOG(ERROR) << "error in create tensor"; | |||
| return RET_ERROR; | |||
| } | |||
| meta_graph_->allTensors.emplace_back(std::move(tensorConst)); // last_id | |||
| auto tensorT = CreateTransformTensor(it.second); | |||
| if (tensorT == nullptr) { | |||
| MS_LOG(ERROR) << "error in create tensor"; | |||
| return RET_ERROR; | |||
| } | |||
| meta_graph_->allTensors.emplace_back(std::move(tensorT)); // last_id + 1 | |||
| std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)}; | |||
| std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)}; | |||
| reconnect[it.first] = last_id + 1; | |||
| auto cnode = CreateTransformNode(in_idx, out_idx, last_node); | |||
| if (cnode == nullptr) { | |||
| MS_LOG(ERROR) << "error in node creation"; | |||
| return RET_ERROR; | |||
| } | |||
| meta_graph_->nodes.emplace_back(std::move(cnode)); | |||
| } | |||
| connect_ = reconnect; | |||
| return RET_OK; | |||
| } | |||
| int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels, | |||
| const std::vector<mindspore::lite::Tensor *> &tensors, | |||
| const std::vector<std::string> &output_names, const Model *model) { | |||
| std::vector<size_t> map_index; | |||
| std::set<size_t> out_set; | |||
| int tensor_idx = 0; | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| meta_graph->fmkType = 3; | |||
| meta_graph->name = model_->name_; | |||
| meta_graph->version = model_->version_; | |||
| int offset = meta_graph_->allTensors.size(); | |||
| int tensor_idx = offset; | |||
| if (meta_graph_ == nullptr) { | |||
| int status = ExportInit(model->name_, model->version_); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| // prepare mapping for connection | |||
| for (auto it : connect_) { | |||
| remap_[it.first + offset] = it.second; | |||
| } | |||
| for (const auto kernel : kernels) { | |||
| std::vector<uint32_t> in_idx, out_idx; | |||
| for (const auto tensor : kernel->in_tensors()) { | |||
| size_t id = TSFindTensor(tensors, tensor); | |||
| size_t id = TSFindTensor(tensors, tensor) + offset; | |||
| if (id == tensors.size()) { | |||
| MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model"; | |||
| return RET_ERROR; | |||
| } | |||
| auto it = remap.find(id); | |||
| if (it == remap.end()) { | |||
| remap[id] = tensor_idx; | |||
| auto it = remap_.find(id); | |||
| if (it == remap_.end()) { | |||
| remap_[id] = tensor_idx; | |||
| in_idx.push_back(tensor_idx); | |||
| map_index.push_back(id); | |||
| tensor_idx++; | |||
| @@ -174,14 +291,14 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern | |||
| } | |||
| } | |||
| for (const auto tensor : kernel->out_tensors()) { | |||
| size_t id = TSFindTensor(tensors, tensor); | |||
| size_t id = TSFindTensor(tensors, tensor) + offset; | |||
| if (id == tensors.size()) { | |||
| MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model"; | |||
| return RET_ERROR; | |||
| } | |||
| auto it = remap.find(id); | |||
| if (it == remap.end()) { | |||
| remap[id] = tensor_idx; | |||
| auto it = remap_.find(id); | |||
| if (it == remap_.end()) { | |||
| remap_[id] = tensor_idx; | |||
| map_index.push_back(id); | |||
| out_idx.push_back(tensor_idx); | |||
| out_set.insert(tensor_idx); | |||
| @@ -191,33 +308,51 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern | |||
| out_set.insert(it->second); | |||
| } | |||
| } | |||
| auto cnode = CreateCNode(kernel, in_idx, out_idx); | |||
| meta_graph->nodes.emplace_back(std::move(cnode)); | |||
| auto cnode = CreateCNode(kernel, in_idx, out_idx, model); | |||
| if (cnode == nullptr) { | |||
| MS_LOG(ERROR) << "failed to create cnode"; | |||
| return RET_ERROR; | |||
| } | |||
| meta_graph_->nodes.emplace_back(std::move(cnode)); | |||
| } | |||
| for (auto id : map_index) { | |||
| mindspore::lite::Tensor *tensor = tensors.at(id); | |||
| schema::Tensor *scTensor = model_->all_tensors_.at(id); | |||
| size_t pid = id - offset; | |||
| mindspore::lite::Tensor *tensor = tensors.at(pid); | |||
| schema::Tensor *scTensor = model->all_tensors_.at(pid); | |||
| auto tensorT = CreateTensor(tensor, scTensor); | |||
| // find a tensor which is not an output | |||
| if (out_set.find(remap[id]) == out_set.end()) { | |||
| if (tensorT == nullptr) { | |||
| MS_LOG(ERROR) << "error in tensor creation"; | |||
| return RET_ERROR; | |||
| } | |||
| if (out_set.find(remap_[id]) == out_set.end()) { | |||
| if ((tensorT->nodeType == NodeType_ValueNode) && (tensorT->data.size() == 0)) { | |||
| meta_graph->inputIndex.push_back(remap[id]); | |||
| meta_graph_->inputIndex.push_back(remap_[id]); | |||
| } | |||
| } | |||
| // find output tensor | |||
| if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) { | |||
| meta_graph->outputIndex.push_back(remap[id]); | |||
| meta_graph_->outputIndex.push_back(remap_[id]); | |||
| } | |||
| meta_graph->allTensors.emplace_back(std::move(tensorT)); | |||
| meta_graph_->allTensors.emplace_back(std::move(tensorT)); | |||
| } | |||
| auto graph = meta_graph.release(); | |||
| int err = Storage::Save(*graph, file_name_); | |||
| if (err != RET_OK) { | |||
| MS_LOG(ERROR) << "failed to save flatbuffer file " << file_name_; | |||
| return RET_OK; | |||
| } | |||
| int TrainExport::ExportInit(const std::string model_name, std::string version) { | |||
| meta_graph_ = new (std::nothrow) schema::MetaGraphT(); | |||
| if (meta_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "cannot allocate meta_graph"; | |||
| return RET_ERROR; | |||
| } | |||
| delete graph; | |||
| return err; | |||
| meta_graph_->fmkType = 3; | |||
| meta_graph_->name = model_name; | |||
| meta_graph_->version = version; | |||
| return RET_OK; | |||
| } | |||
| int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); } | |||
| TrainExport::~TrainExport() { delete meta_graph_; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/lite_model.h" | |||
| @@ -34,23 +36,36 @@ namespace lite { | |||
| class TrainExport { | |||
| public: | |||
| TrainExport(const std::string file_name, const mindspore::lite::Model *model) | |||
| : model_(model), file_name_(file_name) {} | |||
| virtual ~TrainExport() {} | |||
| int Export(const std::vector<mindspore::kernel::LiteKernel *> &kernels, | |||
| const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names); | |||
| explicit TrainExport(const std::string file_name) : file_name_(file_name) {} | |||
| virtual ~TrainExport(); | |||
| int ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels, | |||
| const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names, | |||
| const Model *model); | |||
| int ExportInit(const std::string model_name, std::string version); | |||
| int SaveToFile(); | |||
| void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } | |||
| int LoadModel(void *buf, size_t buf_size); | |||
| int AddTransformNode(); | |||
| protected: | |||
| virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor); | |||
| private: | |||
| const Model *model_; | |||
| std::string file_name_; | |||
| mindspore::lite::Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel); | |||
| std::unique_ptr<schema::TensorT> CreateTensor(const mindspore::lite::Tensor *tensor, schema::Tensor *scTensor); | |||
| schema::MetaGraphT *meta_graph_ = nullptr; | |||
| std::vector<size_t> out_idx_; | |||
| std::map<size_t, size_t> remap_; | |||
| std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id) | |||
| Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model); | |||
| std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor); | |||
| std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::LiteKernel *kernel, | |||
| std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex); | |||
| std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex, | |||
| const Model *model); | |||
| std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex, | |||
| std::vector<uint32_t> outputIndex, size_t id); | |||
| std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id); | |||
| std::unique_ptr<schema::TensorT> CreateTransformConst(size_t last_id); | |||
| int AddTransform(); | |||
| bool NeedQuantization(const mindspore::lite::Tensor *tensor); | |||
| virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor); | |||
| mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel); | |||
| @@ -457,8 +457,22 @@ int TrainSession::SetLossName(std::string loss_name) { | |||
| int TrainSession::ExportInference(std::string file_name) { | |||
| bool orig_train_state = IsTrain(); | |||
| Eval(); | |||
| TrainExport texport(file_name, model_); | |||
| int status = texport.Export(inference_kernels_, tensors_, GetOutputTensorNames()); | |||
| TrainExport texport(file_name); | |||
| int status = texport.ExportInit(model_->name_, model_->version_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cannot init export"; | |||
| return status; | |||
| } | |||
| status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cannot export Network"; | |||
| return status; | |||
| } | |||
| status = texport.SaveToFile(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "failed to save to " << file_name; | |||
| return status; | |||
| } | |||
| if (orig_train_state) Train(); | |||
| return status; | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <tuple> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <map> | |||
| #include "include/train/train_session.h" | |||
| #include "src/lite_session.h" | |||
| @@ -125,6 +126,14 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||
| void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels); | |||
| int AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); | |||
| int OptimizerStep(); | |||
| int ExecKernels(const KernelCallBack &before, const KernelCallBack &after, | |||
| std::vector<kernel::LiteKernel *> run_kernel); | |||
| int MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after, | |||
| std::vector<kernel::LiteKernel *> run_kernel); | |||
| int CopyTensor(Tensor *tensor, TypeId dst_data_type); | |||
| void RestoreTensorData(); | |||
| void FreeRestoreTensors(); | |||
| std::map<Tensor *, Tensor *> restored_origin_tensors_; | |||
| int virtual_batch_idx_ = 0; | |||
| int virtual_batch_multiplier_ = 0; | |||
| }; | |||
| @@ -33,6 +33,15 @@ size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor | |||
| return where.size(); | |||
| } | |||
| size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter) { | |||
| for (size_t i = 0; i < where.size(); i++) { | |||
| if (where[i]->tensor_name() == searchParameter) { | |||
| return i; | |||
| } | |||
| } | |||
| return where.size(); | |||
| } | |||
| kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter) { | |||
| auto it = std::find_if(where.begin(), where.end(), | |||
| [&searchParameter](const kernel::LiteKernel *k) { return (k->name() == searchParameter); }); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <string> | |||
| #include "include/ms_tensor.h" | |||
| #include "src/tensor.h" | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -27,10 +28,11 @@ class LiteKernel; | |||
| } | |||
| namespace lite { | |||
| kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter); | |||
| size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter); | |||
| size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter); | |||
| kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter); | |||
| size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter); | |||
| float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output); | |||
| float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output); | |||
| @@ -34,6 +34,8 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/convolution.h" | |||
| #include "nnacl/fp32/pack_fp32.h" | |||
| #include "src/train/train_export.h" | |||
| #include "src/train/train_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -41,6 +43,7 @@ namespace lite { | |||
| TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context) | |||
| : is_valid_(false) { | |||
| lite_model_ = reinterpret_cast<char *>(malloc(size_backbone)); | |||
| size_backbone_ = size_backbone; | |||
| if (lite_model_ != nullptr) { | |||
| std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_); | |||
| backbone_session_ = | |||
| @@ -154,6 +157,60 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack | |||
| return ret; | |||
| } | |||
| std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() { | |||
| std::unordered_map<size_t, size_t> map; | |||
| for (auto &backbone_head_pair : backbone_head_map_) { | |||
| auto input = backbone_head_pair.first; | |||
| auto output = backbone_head_pair.second; | |||
| auto in_id = TSFindTensorByName(tensors_, input->tensor_name()); | |||
| if (in_id == tensors_.size()) { | |||
| MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name(); | |||
| map.clear(); | |||
| return map; | |||
| } | |||
| auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name()); | |||
| if (out_id == backbone_session_->tensors_.size()) { | |||
| MS_LOG(ERROR) << "cannot find input tensor " << output->tensor_name(); | |||
| map.clear(); | |||
| return map; | |||
| } | |||
| map[in_id] = out_id; | |||
| } | |||
| return map; | |||
| } | |||
| int TransferSession::ExportInference(std::string file_name) { | |||
| bool orig_train_state = IsTrain(); | |||
| Eval(); | |||
| TrainExport texport(file_name); | |||
| int status = texport.LoadModel(lite_model_, size_backbone_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cannot init export"; | |||
| return status; | |||
| } | |||
| auto connect_map = ConnectionMap(); | |||
| texport.set_connect(connect_map); | |||
| if (nchw2nhwc_) { | |||
| status = texport.AddTransformNode(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cannot add transform node"; | |||
| return status; | |||
| } | |||
| } | |||
| status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "cannot serialize head"; | |||
| return status; | |||
| } | |||
| status = texport.SaveToFile(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "failed to save to " << file_name; | |||
| return status; | |||
| } | |||
| if (orig_train_state) Train(); | |||
| return status; | |||
| } | |||
| } // namespace lite | |||
| session::TrainSession *session::TrainSession::CreateTransferSession(const char *model_buf_backbone, | |||
| @@ -61,6 +61,7 @@ class TransferSession : public lite::TrainSession { | |||
| mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override; | |||
| int CompileTransferGraph(); | |||
| int ExportInference(std::string file_name) override; | |||
| protected: | |||
| lite::LiteSession *backbone_session_ = nullptr; | |||
| @@ -71,7 +72,9 @@ class TransferSession : public lite::TrainSession { | |||
| private: | |||
| bool CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask, size_t mask_len); | |||
| std::unordered_map<size_t, size_t> ConnectionMap(); | |||
| bool nchw2nhwc_ = false; | |||
| size_t size_backbone_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||