| @@ -27,24 +27,24 @@ | |||
| #include "include/api/dual_abi_helper.h" | |||
| namespace mindspore { | |||
| using Key = struct Key { | |||
| constexpr char kDecModeAesGcm[] = "AES-GCM"; | |||
| struct MS_API Key { | |||
| const size_t max_key_len = 32; | |||
| size_t len; | |||
| unsigned char key[32]; | |||
| Key() : len(0) {} | |||
| Key(const char *dec_key, size_t key_len); | |||
| }; | |||
| class MS_API Serialization { | |||
| public: | |||
| static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph); | |||
| inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, | |||
| const Key &dec_key, const std::string &dec_mode); | |||
| inline static Status Load(const std::string &file, ModelType model_type, Graph *graph); | |||
| inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key, | |||
| const std::string &dec_mode); | |||
| const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); | |||
| inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {}, | |||
| const std::string &dec_mode = kDecModeAesGcm); | |||
| inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs, | |||
| const Key &dec_key = {}, const std::string &dec_mode = "AES-GCM"); | |||
| static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters); | |||
| const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); | |||
| static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model); | |||
| static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); | |||
| static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); | |||
| @@ -64,10 +64,6 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m | |||
| return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode)); | |||
| } | |||
| Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) { | |||
| return Load(StringToChar(file), model_type, graph); | |||
| } | |||
| Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key, | |||
| const std::string &dec_mode) { | |||
| return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode)); | |||
| @@ -35,6 +35,7 @@ enum ModelType : uint32_t { | |||
| kAIR = 1, | |||
| kOM = 2, | |||
| kONNX = 3, | |||
| kFlatBuffer = 4, | |||
| // insert new data type here | |||
| kUnknownType = 0xFFFFFFFF | |||
| }; | |||
| @@ -79,8 +79,20 @@ static Buffer ReadFile(const std::string &file) { | |||
| return buffer; | |||
| } | |||
| Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) { | |||
| return Load(model_data, data_size, model_type, graph, Key{}, StringToChar("AES-GCM")); | |||
| Key::Key(const char *dec_key, size_t key_len) { | |||
| len = 0; | |||
| if (key_len >= max_key_len) { | |||
| MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len; | |||
| return; | |||
| } | |||
| auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len); | |||
| if (sec_ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret; | |||
| return; | |||
| } | |||
| len = key_len; | |||
| } | |||
| Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, | |||
| @@ -137,7 +149,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m | |||
| } | |||
| Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) { | |||
| return Load(file, model_type, graph, Key{}, StringToChar("AES-GCM")); | |||
| return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm)); | |||
| } | |||
| Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, | |||
| @@ -256,11 +268,6 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp | |||
| return Status(kMEInvalidInput, err_msg.str()); | |||
| } | |||
| Status Serialization::LoadCheckPoint(const std::string &, std::map<std::string, Buffer> *) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return kMEFailed; | |||
| } | |||
| Status Serialization::SetParameters(const std::map<std::string, Buffer> &, Model *) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return kMEFailed; | |||
| @@ -24,10 +24,27 @@ | |||
| #include "include/model.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "src/cxx_api/graph/graph_data.h" | |||
| #include "src/cxx_api/model/model_impl.h" | |||
| #include "src/common/log_adapter.h" | |||
| namespace mindspore { | |||
| Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) { | |||
| Key::Key(const char *dec_key, size_t key_len) { | |||
| len = 0; | |||
| if (key_len >= max_key_len) { | |||
| MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len; | |||
| return; | |||
| } | |||
| memcpy(key, dec_key, key_len); | |||
| len = key_len; | |||
| } | |||
| Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, | |||
| const Key &dec_key, const std::vector<char> &dec_mode) { | |||
| if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return kLiteError; | |||
| } | |||
| if (model_data == nullptr) { | |||
| MS_LOG(ERROR) << "model data is nullptr."; | |||
| return kLiteNullptr; | |||
| @@ -40,6 +57,7 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m | |||
| MS_LOG(ERROR) << "Unsupported IR."; | |||
| return kLiteInputParamInvalid; | |||
| } | |||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size)); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "New model failed."; | |||
| @@ -54,28 +72,47 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m | |||
| return kSuccess; | |||
| } | |||
| Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, | |||
| const Key &dec_key, const std::vector<char> &dec_mode) { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return kLiteError; | |||
| } | |||
| Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, | |||
| const std::vector<char> &dec_mode) { | |||
| if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return kLiteError; | |||
| } | |||
| Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return kLiteError; | |||
| if (graph == nullptr) { | |||
| MS_LOG(ERROR) << "graph is nullptr."; | |||
| return kLiteNullptr; | |||
| } | |||
| if (model_type != kFlatBuffer) { | |||
| MS_LOG(ERROR) << "Unsupported IR."; | |||
| return kLiteInputParamInvalid; | |||
| } | |||
| std::string filename = file.data(); | |||
| if (filename.substr(filename.find_last_of(".") + 1) != "ms") { | |||
| filename = filename + ".ms"; | |||
| } | |||
| auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str())); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "New model failed."; | |||
| return kLiteNullptr; | |||
| } | |||
| auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model)); | |||
| if (graph_data == nullptr) { | |||
| MS_LOG(ERROR) << "New graph data failed."; | |||
| return kLiteMemoryFailed; | |||
| } | |||
| *graph = Graph(graph_data); | |||
| return kSuccess; | |||
| } | |||
| Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key, | |||
| const std::vector<char> &dec_mode) { | |||
| Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type, | |||
| std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) { | |||
| MS_LOG(ERROR) << "Unsupported Feature."; | |||
| return kLiteError; | |||
| } | |||
| Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return kMEFailed; | |||
| } | |||
| Status Serialization::SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model) { | |||
| MS_LOG(ERROR) << "Unsupported feature."; | |||
| return kMEFailed; | |||
| @@ -46,11 +46,8 @@ TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) { | |||
| TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) { | |||
| Graph graph; | |||
| std::string key_str = "0123456789ABCDEF"; | |||
| Key key; | |||
| memcpy(key.key, key_str.c_str(), key_str.size()); | |||
| key.len = key_str.size(); | |||
| ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph, | |||
| key, "AES-GCM") == kSuccess); | |||
| Key(key_str.c_str(), key_str.size()), kDecModeAesGcm) == kSuccess); | |||
| } | |||
| TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) { | |||
| @@ -65,21 +62,16 @@ TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) { | |||
| TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) { | |||
| Graph graph; | |||
| std::string key_str = "WRONG_KEY"; | |||
| Key key; | |||
| memcpy(key.key, key_str.c_str(), key_str.size()); | |||
| key.len = key_str.size(); | |||
| auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph, | |||
| key, "AES-GCM"); | |||
| Key(key_str.c_str(), key_str.size()), kDecModeAesGcm); | |||
| ASSERT_TRUE(status != kSuccess); | |||
| } | |||
| TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) { | |||
| Graph graph; | |||
| std::string key_str = "WRONG_KEY"; | |||
| Key key; | |||
| memcpy(key.key, key_str.c_str(), key_str.size()); | |||
| key.len = key_str.size(); | |||
| auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM"); | |||
| auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, | |||
| Key(key_str.c_str(), key_str.size()), kDecModeAesGcm); | |||
| ASSERT_TRUE(status != kSuccess); | |||
| } | |||