save transfer learning inference file

tags/v1.3.0
yoni committed 5 years ago
commit c568969a9e
12 changed files with 320 additions and 61 deletions:

  1. mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py (+1, -1)
  2. mindspore/lite/examples/transfer_learning/src/net_runner.cc (+22, -9)
  3. mindspore/lite/examples/transfer_learning/src/net_runner.h (+2, -1)
  4. mindspore/lite/src/lite_session.h (+1, -0)
  5. mindspore/lite/src/train/train_export.cc (+171, -36)
  6. mindspore/lite/src/train/train_export.h (+25, -10)
  7. mindspore/lite/src/train/train_session.cc (+16, -2)
  8. mindspore/lite/src/train/train_session.h (+9, -0)
  9. mindspore/lite/src/train/train_utils.cc (+9, -0)
  10. mindspore/lite/src/train/train_utils.h (+4, -2)
  11. mindspore/lite/src/train/transfer_session.cc (+57, -0)
  12. mindspore/lite/src/train/transfer_session.h (+3, -0)

mindspore/lite/examples/export_models/models/mobilenetv3_train_export.py (+1, -1)

@@ -26,7 +26,7 @@ context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs

n = mobilenet_v3_small(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')
optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-2, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True,
optimizer = nn.Adam(n.trainable_params(), learning_rate=1e-3, beta1=0.5, beta2=0.7, eps=1e-2, use_locking=True,
use_nesterov=False, weight_decay=0.1, loss_scale=0.3)
net = TrainWrap(n, loss_fn, optimizer)



mindspore/lite/examples/transfer_learning/src/net_runner.cc (+22, -9)

@@ -22,6 +22,7 @@
#include <fstream>
#include <iostream>
#include "include/context.h"
#include "include/lite_session.h"
#include "src/utils.h"

static unsigned int seed = time(NULL);
@@ -122,14 +123,14 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat
return labels_vec;
}

float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const {
float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset,
mindspore::session::LiteSession *session) const {
float accuracy = 0.0;
int tests = dataset.size() / batch_size_;

session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputData(dataset, i);
session_->RunGraph();
session->RunGraph();
auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
MS_ASSERT(outputsv != nullptr);
auto scores = reinterpret_cast<float *>(outputsv->MutableData());
@@ -145,7 +146,6 @@ float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) c
if (labels[b] == max_idx) accuracy += 1.0;
}
}
session_->Train();
accuracy /= static_cast<float>(batch_size_ * tests);
return accuracy;
}
@@ -192,7 +192,9 @@ int NetRunner::TrainLoop() {

std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
if ((i + 1) % 20 == 0) {
float acc = CalculateAccuracy(ds_.test_data());
session_->Eval();
float acc = CalculateAccuracy(ds_.test_data(), session_);
session_->Train();
if (max_acc < acc) max_acc = acc;
std::cout << "accuracy on test data = " << acc << " max accuracy = " << max_acc << std::endl;
if (acc > 0.9) return 0;
@@ -207,26 +209,34 @@ int NetRunner::Main() {
InitDB();

TrainLoop();
float acc = CalculateAccuracy(ds_.val_data());
session_->Eval();
float acc = CalculateAccuracy(ds_.val_data(), session_);
std::cout << "accuracy on validation data = " << acc << std::endl;

if (cycles_ > 0 && head_model_ != nullptr) {
auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
}
if (!save_inference_.empty()) {
int status = session_->ExportInference(save_inference_);
if (status != mindspore::lite::RET_OK) {
std::cout << "Failed to save inference file";
return mindspore::lite::RET_ERROR;
}
}
return 0;
}

void NetRunner::Usage() {
std::cout << "Usage: net_runner -f <.ms head model file> -b <.ms backbone model file> -d <data_dir> "
<< "[-c <num of training cycles>] [-v (verbose mode)] "
<< "[-s <save checkpoint every X iterations>]" << std::endl;
<< "[-s <save checkpoint every X iterations>]"
<< "[-i <save inference file>]" << std::endl;
}

bool NetRunner::ReadArgs(int argc, char *argv[]) {
int opt;
while ((opt = getopt(argc, argv, "b:f:e:d:s:ihc:v")) != -1) {
while ((opt = getopt(argc, argv, "b:f:e:d:s:i:hc:v")) != -1) {
switch (opt) {
case 'b':
ms_backbone_file_ = std::string(optarg);
@@ -246,6 +256,9 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) {
case 's':
save_checkpoint_ = atoi(optarg);
break;
case 'i':
save_inference_ = std::string(optarg);
break;
case 'h':
default:
Usage();
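
The file written through the new -i option is an ordinary inference .ms model, so it can later be loaded for deployment without any of the training or transfer machinery. A minimal sketch, assuming the MindSpore Lite 1.x public C++ API; LoadInferenceModel and the buffer handling are illustrative, not part of this commit:

#include <fstream>
#include <string>
#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"
#include "include/model.h"

// Load the .ms file produced by "net_runner ... -i <file>" into a plain
// inference session; no train/transfer headers are needed on the consumer side.
mindspore::session::LiteSession *LoadInferenceModel(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);
  if (!ifs) return nullptr;
  std::vector<char> buf(static_cast<size_t>(ifs.tellg()));
  ifs.seekg(0);
  ifs.read(buf.data(), static_cast<std::streamsize>(buf.size()));

  auto *model = mindspore::lite::Model::Import(buf.data(), buf.size());
  if (model == nullptr) return nullptr;

  mindspore::lite::Context context;
  auto *session = mindspore::session::LiteSession::CreateSession(&context);
  if (session == nullptr || session->CompileGraph(model) != mindspore::lite::RET_OK) {
    delete session;
    return nullptr;
  }
  return session;  // next: fill GetInputs(), call RunGraph(), read GetOutputs()
}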


mindspore/lite/examples/transfer_learning/src/net_runner.h (+2, -1)

@@ -38,7 +38,7 @@ class NetRunner {
int InitDB();
int TrainLoop();
std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, int serially = -1) const;
float CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const;
float CalculateAccuracy(const std::vector<DataLabelTuple> &dataset, mindspore::session::LiteSession *session) const;
float GetLoss() const;
mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const;

@@ -50,6 +50,7 @@ class NetRunner {
std::string ms_backbone_file_ = "";
std::string ms_head_file_ = "";
std::string data_dir_ = "";
std::string save_inference_ = "";
size_t data_size_ = 0;
size_t batch_size_ = 0;
unsigned int cycles_ = 100;


mindspore/lite/src/lite_session.h (+1, -0)

@@ -135,6 +135,7 @@ class LiteSession : public session::LiteSession {
Model *model_ = nullptr;
std::atomic<bool> is_running_ = false;
bool is_train_session_ = false;
friend class TransferSession;
#if SUPPORT_NPU
NPUManager *npu_manager_ = nullptr;
NPUPassManager *npu_pass_manager_ = nullptr;


mindspore/lite/src/train/train_export.cc (+171, -36)

@@ -115,8 +115,8 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite
return tensorT;
}

mindspore::lite::Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel) {
auto nodes = model_->all_nodes_;
Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model) {
auto nodes = model->all_nodes_;
auto it = std::find_if(nodes.begin(), nodes.end(),
[&kernel](mindspore::lite::Model::Node *n) { return (kernel->name() == n->name_); });
if (it == nodes.end()) {
@@ -127,14 +127,18 @@ mindspore::lite::Model::Node *TrainExport::FindNode(const mindspore::kernel::Lit

std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::LiteKernel *kernel,
std::vector<uint32_t> inputIndex,
std::vector<uint32_t> outputIndex) {
std::vector<uint32_t> outputIndex, const Model *model) {
auto cnodeT = std::make_unique<schema::CNodeT>();
if (cnodeT == nullptr) {
MS_LOG(ERROR) << " cannot allocate node";
return nullptr;
}
cnodeT->inputIndex = inputIndex;
cnodeT->outputIndex = outputIndex;
cnodeT->name = kernel->name();
cnodeT->quantType = GetNodeQuantType(kernel);
// find kernel in model
auto *node = FindNode(kernel);
auto *node = FindNode(kernel, model);
if (node == nullptr) {
MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model";
return nullptr;
@@ -144,28 +148,141 @@ std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel
return cnodeT;
}

int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors,
const std::vector<std::string> &output_names) {
std::map<size_t, size_t> remap;
int TrainExport::LoadModel(void *buf, size_t buf_size) {
flatbuffers::Verifier verify((const uint8_t *)buf, buf_size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "model flatbuffer verify fail";
return RET_ERROR;
}
meta_graph_ = schema::GetMetaGraph(buf)->UnPack();
meta_graph_->outputIndex.clear();
return RET_OK;
}

std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
auto &scTensor = meta_graph_->allTensors.at(id);
auto tensorT = std::make_unique<schema::TensorT>();
if (tensorT == nullptr) {
MS_LOG(ERROR) << "Could not create tensor ";
return nullptr;
}
tensorT->nodeType = scTensor->nodeType;
tensorT->dataType = scTensor->dataType;
std::vector<int32_t> dims;
std::vector<int32_t> val = {0, 2, 3, 1};
if (scTensor->dims.size() == 4) {
for (size_t i = 0; i < val.size(); i++) {
dims.push_back(scTensor->dims.at(val[i]));
}
tensorT->dims = dims;
} else {
tensorT->dims = scTensor->dims;
}
tensorT->format = schema::Format_NHWC;
tensorT->name = scTensor->name + "_post";
tensorT->refCount = 0;
tensorT->offset = 0;
tensorT->enableHuffmanCode = false;
return tensorT;
}

std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) {
auto tensorT = std::make_unique<schema::TensorT>();
if (tensorT == nullptr) {
MS_LOG(ERROR) << "Could not create tensor ";
return nullptr;
}
tensorT->nodeType = lite::NodeType_ValueNode;
tensorT->dataType = TypeId::kNumberTypeInt32;
tensorT->dims = {4};
tensorT->format = schema::Format_NCHW;
tensorT->name = "const-" + std::to_string(last_id);
tensorT->refCount = 0;
tensorT->offset = 0;
tensorT->enableHuffmanCode = false;
int32_t val[] = {0, 2, 3, 1};
uint8_t *valp = reinterpret_cast<uint8_t *>(val);
tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val));
return tensorT;
}

std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex,
std::vector<uint32_t> outputIndex, size_t id) {
auto cnodeT = std::make_unique<schema::CNodeT>();
if (cnodeT == nullptr) {
MS_LOG(ERROR) << "cannot allocate node";
return nullptr;
}
cnodeT->inputIndex = inputIndex;
cnodeT->outputIndex = outputIndex;
cnodeT->name = "transpose-" + std::to_string(id);
cnodeT->quantType = schema::QuantType_QUANT_NONE;
cnodeT->primitive = std::make_unique<schema::PrimitiveT>();
cnodeT->primitive->value.type = schema::PrimitiveType_Transpose;
return cnodeT;
}

int TrainExport::AddTransformNode() {
std::unordered_map<size_t, size_t> reconnect;
size_t last_id = meta_graph_->allTensors.size();
size_t last_node = meta_graph_->nodes.size();
for (auto it : connect_) {
auto tensorConst = CreateTransformConst(last_id);
if (tensorConst == nullptr) {
MS_LOG(ERROR) << "error in create tensor";
return RET_ERROR;
}
meta_graph_->allTensors.emplace_back(std::move(tensorConst)); // last_id
auto tensorT = CreateTransformTensor(it.second);
if (tensorT == nullptr) {
MS_LOG(ERROR) << "error in create tensor";
return RET_ERROR;
}
meta_graph_->allTensors.emplace_back(std::move(tensorT)); // last_id + 1
std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
reconnect[it.first] = last_id + 1;
auto cnode = CreateTransformNode(in_idx, out_idx, last_node);
if (cnode == nullptr) {
MS_LOG(ERROR) << "error in node creation";
return RET_ERROR;
}
meta_graph_->nodes.emplace_back(std::move(cnode));
}
connect_ = reconnect;
return RET_OK;
}

int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors,
const std::vector<std::string> &output_names, const Model *model) {
std::vector<size_t> map_index;
std::set<size_t> out_set;
int tensor_idx = 0;
auto meta_graph = std::make_unique<schema::MetaGraphT>();
meta_graph->fmkType = 3;
meta_graph->name = model_->name_;
meta_graph->version = model_->version_;
int offset = meta_graph_->allTensors.size();
int tensor_idx = offset;

if (meta_graph_ == nullptr) {
int status = ExportInit(model->name_, model->version_);
if (status != RET_OK) {
return status;
}
}
// prepare mapping for connection
for (auto it : connect_) {
remap_[it.first + offset] = it.second;
}

for (const auto kernel : kernels) {
std::vector<uint32_t> in_idx, out_idx;
for (const auto tensor : kernel->in_tensors()) {
size_t id = TSFindTensor(tensors, tensor);
size_t id = TSFindTensor(tensors, tensor) + offset;
if (id == tensors.size()) {
MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
return RET_ERROR;
}
auto it = remap.find(id);
if (it == remap.end()) {
remap[id] = tensor_idx;
auto it = remap_.find(id);
if (it == remap_.end()) {
remap_[id] = tensor_idx;
in_idx.push_back(tensor_idx);
map_index.push_back(id);
tensor_idx++;
@@ -174,14 +291,14 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
}
}
for (const auto tensor : kernel->out_tensors()) {
size_t id = TSFindTensor(tensors, tensor);
size_t id = TSFindTensor(tensors, tensor) + offset;
if (id == tensors.size()) {
MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
return RET_ERROR;
}
auto it = remap.find(id);
if (it == remap.end()) {
remap[id] = tensor_idx;
auto it = remap_.find(id);
if (it == remap_.end()) {
remap_[id] = tensor_idx;
map_index.push_back(id);
out_idx.push_back(tensor_idx);
out_set.insert(tensor_idx);
@@ -191,33 +308,51 @@ int TrainExport::Export(const std::vector<mindspore::kernel::LiteKernel *> &kern
out_set.insert(it->second);
}
}
auto cnode = CreateCNode(kernel, in_idx, out_idx);
meta_graph->nodes.emplace_back(std::move(cnode));
auto cnode = CreateCNode(kernel, in_idx, out_idx, model);
if (cnode == nullptr) {
MS_LOG(ERROR) << "failed to create cnode";
return RET_ERROR;
}
meta_graph_->nodes.emplace_back(std::move(cnode));
}
for (auto id : map_index) {
mindspore::lite::Tensor *tensor = tensors.at(id);
schema::Tensor *scTensor = model_->all_tensors_.at(id);
size_t pid = id - offset;
mindspore::lite::Tensor *tensor = tensors.at(pid);
schema::Tensor *scTensor = model->all_tensors_.at(pid);
auto tensorT = CreateTensor(tensor, scTensor);
// find a tensor which is not an output
if (out_set.find(remap[id]) == out_set.end()) {
if (tensorT == nullptr) {
MS_LOG(ERROR) << "error in tensor creation";
return RET_ERROR;
}
if (out_set.find(remap_[id]) == out_set.end()) {
if ((tensorT->nodeType == NodeType_ValueNode) && (tensorT->data.size() == 0)) {
meta_graph->inputIndex.push_back(remap[id]);
meta_graph_->inputIndex.push_back(remap_[id]);
}
}
// find output tensor
if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) {
meta_graph->outputIndex.push_back(remap[id]);
meta_graph_->outputIndex.push_back(remap_[id]);
}
meta_graph->allTensors.emplace_back(std::move(tensorT));
meta_graph_->allTensors.emplace_back(std::move(tensorT));
}
auto graph = meta_graph.release();
int err = Storage::Save(*graph, file_name_);
if (err != RET_OK) {
MS_LOG(ERROR) << "failed to save flatbuffer file " << file_name_;
return RET_OK;
}

int TrainExport::ExportInit(const std::string model_name, std::string version) {
meta_graph_ = new (std::nothrow) schema::MetaGraphT();
if (meta_graph_ == nullptr) {
MS_LOG(ERROR) << "cannot allocate meta_graph";
return RET_ERROR;
}
delete graph;
return err;
meta_graph_->fmkType = 3;
meta_graph_->name = model_name;
meta_graph_->version = version;
return RET_OK;
}

int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

TrainExport::~TrainExport() { delete meta_graph_; }

} // namespace lite
} // namespace mindspore
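
For context on the transform machinery added above: CreateTransformConst() stores the permutation {0, 2, 3, 1} as a constant tensor, and CreateTransformTensor() applies the same reorder to 4-D shapes; this is the NCHW-to-NHWC bridge that AddTransformNode() inserts between backbone outputs and head inputs. A tiny self-contained sketch of that reorder (the function name is illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

// How the {0, 2, 3, 1} permutation reorders a 4-D shape; non-4-D shapes are
// passed through unchanged, mirroring CreateTransformTensor().
std::vector<int32_t> PermuteNchwToNhwc(const std::vector<int32_t> &dims) {
  const int32_t perm[] = {0, 2, 3, 1};
  if (dims.size() != 4) return dims;
  std::vector<int32_t> out;
  for (int32_t p : perm) out.push_back(dims.at(p));
  return out;
}

int main() {
  auto nhwc = PermuteNchwToNhwc({1, 3, 224, 224});  // NCHW -> 1 224 224 3 (NHWC)
  for (int32_t d : nhwc) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}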

mindspore/lite/src/train/train_export.h (+25, -10)

@@ -18,6 +18,8 @@
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include "schema/inner/model_generated.h"
#include "src/lite_kernel.h"
#include "src/lite_model.h"
@@ -34,23 +36,36 @@ namespace lite {

class TrainExport {
public:
TrainExport(const std::string file_name, const mindspore::lite::Model *model)
: model_(model), file_name_(file_name) {}
virtual ~TrainExport() {}
int Export(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names);
explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
virtual ~TrainExport();
int ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names,
const Model *model);
int ExportInit(const std::string model_name, std::string version);
int SaveToFile();
void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
int LoadModel(void *buf, size_t buf_size);
int AddTransformNode();

protected:
virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);

private:
const Model *model_;
std::string file_name_;
mindspore::lite::Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel);
std::unique_ptr<schema::TensorT> CreateTensor(const mindspore::lite::Tensor *tensor, schema::Tensor *scTensor);
schema::MetaGraphT *meta_graph_ = nullptr;
std::vector<size_t> out_idx_;
std::map<size_t, size_t> remap_;
std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id)
Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model);
std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor);
std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::LiteKernel *kernel,
std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex);

std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex,
const Model *model);
std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex,
std::vector<uint32_t> outputIndex, size_t id);
std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id);
std::unique_ptr<schema::TensorT> CreateTransformConst(size_t last_id);
int AddTransform();
bool NeedQuantization(const mindspore::lite::Tensor *tensor);
virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor);
mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel);
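
Taken together, the header now exposes a two-phase export API: seed the MetaGraphT either from scratch (ExportInit) or from an existing backbone flatbuffer (LoadModel + set_connect + AddTransformNode), then append kernels with ExportNet and write the result with SaveToFile. A condensed sketch of both paths, based on the call sites in train_session.cc and transfer_session.cc below; the free function, its parameters, and the backbone_buf handling are illustrative, not part of this commit:

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/errorcode.h"
#include "src/train/train_export.h"

namespace mindspore {
namespace lite {
int ExportInferenceGraph(const std::string &file_name, const Model *model,
                         const std::vector<kernel::LiteKernel *> &inference_kernels,
                         const std::vector<Tensor *> &tensors,
                         const std::vector<std::string> &output_names,
                         void *backbone_buf = nullptr, size_t backbone_size = 0,
                         const std::unordered_map<size_t, size_t> &connect = {}) {
  TrainExport texport(file_name);  // new single-argument constructor
  // TrainSession path: start from an empty MetaGraphT.
  // TransferSession path: seed the graph with the backbone flatbuffer instead.
  int status = (backbone_buf == nullptr)
                   ? texport.ExportInit(model->name_, model->version_)
                   : texport.LoadModel(backbone_buf, backbone_size);
  if (status != RET_OK) return status;
  if (backbone_buf != nullptr) {
    texport.set_connect(connect);         // backbone tensor id -> head tensor id
    status = texport.AddTransformNode();  // guarded by nchw2nhwc_ in the real call site
    if (status != RET_OK) return status;
  }
  // Append the head/inference kernels on top of whatever is already in the graph.
  status = texport.ExportNet(inference_kernels, tensors, output_names, model);
  if (status != RET_OK) return status;
  return texport.SaveToFile();  // one flatbuffer written via Storage::Save
}
}  // namespace lite
}  // namespace mindspore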


mindspore/lite/src/train/train_session.cc (+16, -2)

@@ -457,8 +457,22 @@ int TrainSession::SetLossName(std::string loss_name) {
int TrainSession::ExportInference(std::string file_name) {
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(file_name, model_);
int status = texport.Export(inference_kernels_, tensors_, GetOutputTensorNames());
TrainExport texport(file_name);
int status = texport.ExportInit(model_->name_, model_->version_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot init export";
return status;
}
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot export Network";
return status;
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;
return status;
}
if (orig_train_state) Train();
return status;
}


mindspore/lite/src/train/train_session.h (+9, -0)

@@ -20,6 +20,7 @@
#include <tuple>
#include <unordered_map>
#include <memory>
#include <map>
#include "include/train/train_session.h"
#include "src/lite_session.h"

@@ -125,6 +126,14 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::
void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels);
int AdminSetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
int OptimizerStep();
int ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
std::vector<kernel::LiteKernel *> run_kernel);
int MixPrecisionExecKernels(const KernelCallBack &before, const KernelCallBack &after,
std::vector<kernel::LiteKernel *> run_kernel);
int CopyTensor(Tensor *tensor, TypeId dst_data_type);
void RestoreTensorData();
void FreeRestoreTensors();
std::map<Tensor *, Tensor *> restored_origin_tensors_;
int virtual_batch_idx_ = 0;
int virtual_batch_multiplier_ = 0;
};


mindspore/lite/src/train/train_utils.cc (+9, -0)

@@ -33,6 +33,15 @@ size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor
return where.size();
}

size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter) {
for (size_t i = 0; i < where.size(); i++) {
if (where[i]->tensor_name() == searchParameter) {
return i;
}
}
return where.size();
}

kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter) {
auto it = std::find_if(where.begin(), where.end(),
[&searchParameter](const kernel::LiteKernel *k) { return (k->name() == searchParameter); });


mindspore/lite/src/train/train_utils.h (+4, -2)

@@ -20,6 +20,7 @@
#include <string>
#include "include/ms_tensor.h"
#include "src/tensor.h"
#include "src/lite_kernel.h"

namespace mindspore {
namespace kernel {
@@ -27,10 +28,11 @@ class LiteKernel;
}

namespace lite {

kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter);
size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter);

size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter);
kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter);
size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter);
float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output);
float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output);



mindspore/lite/src/train/transfer_session.cc (+57, -0)

@@ -34,6 +34,8 @@
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/convolution.h"
#include "nnacl/fp32/pack_fp32.h"
#include "src/train/train_export.h"
#include "src/train/train_utils.h"

namespace mindspore {
namespace lite {
@@ -41,6 +43,7 @@ namespace lite {
TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, lite::Context *context)
: is_valid_(false) {
lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
size_backbone_ = size_backbone;
if (lite_model_ != nullptr) {
std::copy(model_buf_backbone, model_buf_backbone + size_backbone, lite_model_);
backbone_session_ =
@@ -154,6 +157,60 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack
return ret;
}

std::unordered_map<size_t, size_t> TransferSession::ConnectionMap() {
std::unordered_map<size_t, size_t> map;
for (auto &backbone_head_pair : backbone_head_map_) {
auto input = backbone_head_pair.first;
auto output = backbone_head_pair.second;
auto in_id = TSFindTensorByName(tensors_, input->tensor_name());
if (in_id == tensors_.size()) {
MS_LOG(ERROR) << "cannot find input tensor " << input->tensor_name();
map.clear();
return map;
}
auto out_id = TSFindTensorByName(backbone_session_->tensors_, output->tensor_name());
if (out_id == backbone_session_->tensors_.size()) {
MS_LOG(ERROR) << "cannot find input tensor " << output->tensor_name();
map.clear();
return map;
}
map[in_id] = out_id;
}
return map;
}

int TransferSession::ExportInference(std::string file_name) {
bool orig_train_state = IsTrain();
Eval();
TrainExport texport(file_name);
int status = texport.LoadModel(lite_model_, size_backbone_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot init export";
return status;
}
auto connect_map = ConnectionMap();
texport.set_connect(connect_map);
if (nchw2nhwc_) {
status = texport.AddTransformNode();
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot add transform node";
return status;
}
}
status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_);
if (status != RET_OK) {
MS_LOG(ERROR) << "cannot serialize head";
return status;
}
status = texport.SaveToFile();
if (status != RET_OK) {
MS_LOG(ERROR) << "failed to save to " << file_name;
return status;
}
if (orig_train_state) Train();
return status;
}

} // namespace lite

session::TrainSession *session::TrainSession::CreateTransferSession(const char *model_buf_backbone,
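
From application code, the net effect is that a transfer-learning session can now be flushed to a single inference-ready .ms file that fuses backbone, head, and the NCHW-to-NHWC bridges. A minimal sketch against the public session::TrainSession interface used by the example above; SaveForDeployment and its arguments are illustrative:

#include <iostream>
#include <string>
#include "include/errorcode.h"
#include "include/train/train_session.h"

// After training (e.g. on a session created with TrainSession::CreateTransferSession,
// as in net_runner.cc), write the fused inference model to `path`.
int SaveForDeployment(mindspore::session::TrainSession *session, const std::string &path) {
  int status = session->ExportInference(path);  // switches to Eval internally, then restores the original mode
  if (status != mindspore::lite::RET_OK) {
    std::cerr << "Failed to save inference file " << path << std::endl;
  }
  return status;
}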


mindspore/lite/src/train/transfer_session.h (+3, -0)

@@ -61,6 +61,7 @@ class TransferSession : public lite::TrainSession {
mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override;

int CompileTransferGraph();
int ExportInference(std::string file_name) override;

protected:
lite::LiteSession *backbone_session_ = nullptr;
@@ -71,7 +72,9 @@ class TransferSession : public lite::TrainSession {

private:
bool CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask, size_t mask_len);
std::unordered_map<size_t, size_t> ConnectionMap();
bool nchw2nhwc_ = false;
size_t size_backbone_;
};
} // namespace lite
} // namespace mindspore

