|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "include/api/serialization.h" |
|
|
|
#include <fstream> |
|
|
|
#include <sstream> |
|
|
|
#include "cxx_api/graph/graph_data.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "mindspore/core/load_mindir/load_model.h" |
|
|
|
@@ -69,62 +70,48 @@ static Buffer ReadFile(const std::string &file) { |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) { |
|
|
|
if (graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Output args graph is nullptr."; |
|
|
|
return kMEInvalidInput; |
|
|
|
} |
|
|
|
|
|
|
|
if (model_type == kMindIR) { |
|
|
|
FuncGraphPtr anf_graph = nullptr; |
|
|
|
try { |
|
|
|
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size); |
|
|
|
} catch (const std::exception &) { |
|
|
|
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) { |
|
|
|
MS_LOG(ERROR) << "Load model failed. The model_data may be encrypted, please pass in correct key."; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Load model failed."; |
|
|
|
} |
|
|
|
return kMEInvalidInput; |
|
|
|
} |
|
|
|
|
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); |
|
|
|
return kSuccess; |
|
|
|
} else if (model_type == kOM) { |
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM)); |
|
|
|
return kSuccess; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type; |
|
|
|
return kMEInvalidInput; |
|
|
|
return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM")); |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, |
|
|
|
const Key &dec_key, const std::vector<char> &dec_mode) { |
|
|
|
std::stringstream err_msg; |
|
|
|
if (graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Output args graph is nullptr."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Output args graph is nullptr."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
if (model_type == kMindIR) { |
|
|
|
FuncGraphPtr anf_graph = nullptr; |
|
|
|
try { |
|
|
|
if (dec_key.len > dec_key.max_key_len) { |
|
|
|
MS_LOG(ERROR) << "The key length exceeds maximum length: 32."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} else if (dec_key.len == 0) { |
|
|
|
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) { |
|
|
|
err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} else { |
|
|
|
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size); |
|
|
|
} |
|
|
|
} else { |
|
|
|
size_t plain_data_size; |
|
|
|
std::string dec_mode_str(dec_mode.begin(), dec_mode.end()); |
|
|
|
auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data), |
|
|
|
data_size, dec_key.key, dec_key.len, dec_mode_str); |
|
|
|
data_size, dec_key.key, dec_key.len, CharToString(dec_mode)); |
|
|
|
if (plain_data == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size); |
|
|
|
} |
|
|
|
} catch (const std::exception &) { |
|
|
|
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); |
|
|
|
@@ -134,78 +121,112 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m |
|
|
|
return kSuccess; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Unsupported ModelType " << model_type; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) { |
|
|
|
return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM")); |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, |
|
|
|
const std::vector<char> &dec_mode) { |
|
|
|
std::stringstream err_msg; |
|
|
|
if (graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Output args graph is nullptr."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Output args graph is nullptr."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
std::string file_path = CharToString(file); |
|
|
|
if (model_type == kMindIR) { |
|
|
|
FuncGraphPtr anf_graph = LoadMindIR(file_path); |
|
|
|
FuncGraphPtr anf_graph; |
|
|
|
if (dec_key.len > dec_key.max_key_len) { |
|
|
|
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} else if (dec_key.len == 0 && IsCipherFile(file_path)) { |
|
|
|
err_msg << "Load model failed. The file may be encrypted, please pass in correct key."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} else { |
|
|
|
anf_graph = LoadMindIR(file_path, false, nullptr, dec_key.len, CharToString(dec_mode)); |
|
|
|
} |
|
|
|
if (anf_graph == nullptr) { |
|
|
|
if (IsCipherFile(file_path)) { |
|
|
|
MS_LOG(ERROR) << "Load model failed. The file may be encrypted, please pass in correct key."; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Load model failed."; |
|
|
|
} |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Load model failed. Please check the valid of dec_key and dec_mode"; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); |
|
|
|
return kSuccess; |
|
|
|
} else if (model_type == kOM) { |
|
|
|
Buffer data = ReadFile(file_path); |
|
|
|
if (data.Data() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Read file " << file_path << " failed."; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Read file " << file_path << " failed."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM)); |
|
|
|
return kSuccess; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Unsupported ModelType " << model_type; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, |
|
|
|
const std::vector<char> &dec_mode) { |
|
|
|
if (graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Output args graph is nullptr."; |
|
|
|
return kMEInvalidInput; |
|
|
|
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type, |
|
|
|
std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) { |
|
|
|
std::stringstream err_msg; |
|
|
|
if (graphs == nullptr) { |
|
|
|
err_msg << "Output args graph is nullptr."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
std::string file_path = CharToString(file); |
|
|
|
if (files.size() == 1) { |
|
|
|
std::vector<Graph> result(files.size()); |
|
|
|
auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode); |
|
|
|
*graphs = std::move(result); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::string> files_path = VectorCharToString(files); |
|
|
|
if (model_type == kMindIR) { |
|
|
|
FuncGraphPtr anf_graph; |
|
|
|
if (dec_key.len > dec_key.max_key_len) { |
|
|
|
MS_LOG(ERROR) << "The key length exceeds maximum length: 32."; |
|
|
|
return kMEInvalidInput; |
|
|
|
} else { |
|
|
|
std::string dec_mode_str(dec_mode.begin(), dec_mode.end()); |
|
|
|
anf_graph = LoadMindIR(file_path, false, dec_key.key, dec_key.len, dec_mode_str); |
|
|
|
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
if (anf_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode"; |
|
|
|
return kMEInvalidInput; |
|
|
|
auto anf_graphs = |
|
|
|
LoadMindIRs(files_path, false, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode)); |
|
|
|
if (anf_graphs.size() != files_path.size()) { |
|
|
|
err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs."; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); |
|
|
|
return kSuccess; |
|
|
|
} else if (model_type == kOM) { |
|
|
|
Buffer data = ReadFile(file_path); |
|
|
|
if (data.Data() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Read file " << file_path << " failed."; |
|
|
|
return kMEInvalidInput; |
|
|
|
std::vector<Graph> results; |
|
|
|
for (size_t i = 0; i < anf_graphs.size(); ++i) { |
|
|
|
if (anf_graphs[i] == nullptr) { |
|
|
|
if (dec_key.len == 0 && IsCipherFile(files_path[i])) { |
|
|
|
err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key."; |
|
|
|
} else { |
|
|
|
err_msg << "Load model " << files_path[i] << " failed."; |
|
|
|
} |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
results.emplace_back(std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR)); |
|
|
|
} |
|
|
|
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM)); |
|
|
|
|
|
|
|
*graphs = std::move(results); |
|
|
|
return kSuccess; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type; |
|
|
|
return kMEInvalidInput; |
|
|
|
err_msg << "Unsupported ModelType " << model_type; |
|
|
|
MS_LOG(ERROR) << err_msg.str(); |
|
|
|
return Status(kMEInvalidInput, err_msg.str()); |
|
|
|
} |
|
|
|
|
|
|
|
Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) { |
|
|
|
|