| @@ -1,24 +1,32 @@ | |||
| mindspore.ops.matmul | |||
| ===================== | |||
| .. py:function:: mindspore.ops.matmul(x1, x2, dtype=None) | |||
| .. py:class:: mindspore.ops.MatMul(transpose_a=False, transpose_b=False) | |||
| 计算两个数组的乘积。 | |||
| 将矩阵 `a` 和矩阵 `b` 相乘。 | |||
| .. note:: | |||
| 不支持NumPy参数 `out` 、 `casting` 、 `order` 、 `subok` 、 `signature` 、 `extobj` 。在GPU上支持的数据类型为np.float16和np.float32。在CPU上支持的数据类型为np.float16和np.float32。 | |||
| .. math:: | |||
| (Output)_{i j}=\sum_{k=1}^{p} a_{i k} b_{k j}=a_{i 1} b_{1 j}+a_{i 2} b_{2 j}+\cdots+a_{i p} b_{p j}, p\in N | |||
| 其中, :math:`i,j` 表示输出的第i行和第j列元素。 | |||
| **参数:** | |||
| - **x1** (Tensor) - 输入Tensor,不支持Scalar, `x1` 的最后一维度和 `x2` 的倒数第二维度相等,且 `x1` 和 `x2` 彼此支持广播。 | |||
| - **x2** (Tensor) - 输入Tensor,不支持Scalar, `x1` 的最后一维度和 `x2` 的倒数第二维度相等,且 `x1` 和 `x2` 彼此支持广播。 | |||
| - **dtype** (:class:`mindspore.dtype`, optional) - 指定输出Tensor的数据类型。默认值:None。 | |||
| - **transpose_a** (bool) - 如果为True,则在相乘之前转置 `a`。默认值:False。 | |||
| - **transpose_b** (bool) - 如果为True,则在相乘之前转置 `b`。默认值:False。 | |||
| **输入:** | |||
| - **a** (Tensor) - 要相乘的第一个Tensor。如果 `transpose_a` 为False,则该Tensor的shape为 :math:`(N, C)` ;否则,该Tensor的shape为 :math:`(C, N)` 。 | |||
| - **b** (Tensor) - 要相乘的第二个Tensor。如果 `transpose_b` 为False,则该Tensor的shape为 :math:`(C, M)` ;否则,该Tensor的shape为 :math:`(M, C)` 。 | |||
| **输出:** | |||
| Tensor或Scalar,矩阵乘积的输出。当 `x1` 和 `x2` 为一维向量时,输出为Scalar。 | |||
| Tensor,输出Tensor的shape为 :math:`(N, M)` 。 | |||
| **异常:** | |||
| - **ValueError** - `x1` 的最后一维度和 `x2` 的倒数第二维度不相等,或者输入的是Scalar。 | |||
| - **ValueError** - `x1` 和 `x2` 彼此不能广播。 | |||
| - **TypeError** - `transpose_a` 或 `transpose_b` 不是bool。 | |||
| - **ValueError** - 矩阵 `a` 的列不等于矩阵 `b` 的行。 | |||
| - **ValueError** - `a` 或 `b` 的维度不等于2。 | |||
| @@ -242,24 +242,24 @@ constexpr auto kCtxGetKeysClientList = "get_keys_client_list"; | |||
| constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; | |||
| constexpr auto kCtxCipherPrimer = "cipher_primer"; | |||
| constexpr auto kCurrentIteration = "current_iteration"; | |||
| const char PYTHON_MOD_SERIALIZE_MODULE[] = "mindspore.train.serialization"; | |||
| const char PYTHON_MOD_SAFE_WEIGHT[] = "_save_weight"; | |||
| // This macro returns the current timestamp in milliseconds. | |||
| #define CURRENT_TIME_MILLI \ | |||
| std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) | |||
| // This method returns the size in bytes of the given TypeId. | |||
| inline size_t GetTypeIdByte(const TypeId &type) { | |||
| inline std::string GetTypeIdByte(const TypeId &type) { | |||
| switch (type) { | |||
| case kNumberTypeFloat16: | |||
| return kNumberTypeFloat16Type; | |||
| case kNumberTypeUInt32: | |||
| return "Float16"; | |||
| case kNumberTypeFloat32: | |||
| return kNumberTypeFloat32Type; | |||
| case kNumberTypeUInt64: | |||
| return kNumberTypeUInt64Type; | |||
| return "Float32"; | |||
| case kNumberTypeFloat64: | |||
| return "Float64"; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "TypeId " << type << " not supported."; | |||
| return 0; | |||
| } | |||
| } | |||
| @@ -32,7 +32,8 @@ void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_cou | |||
| } | |||
| aggregation_count_ = aggregation_count; | |||
| // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and | |||
| // optimizers. | |||
| bool ret = InitParamAggregator(func_graph); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed."; | |||
| @@ -274,7 +275,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed."; | |||
| return false; | |||
| } | |||
| MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success."; | |||
| MS_LOG(INFO) << "Initializing parameter aggregator for param_name " << param_name << " success."; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -565,7 +565,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) { | |||
| feature_map[weight_fullname] = weight_size; | |||
| } | |||
| if (LocalMetaStore::GetInstance().verifyFeatureMap(feature_map)) { | |||
| if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) { | |||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||
| iteration_result_ = IterationResult::kSuccess; | |||
| MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished."; | |||
| @@ -867,6 +867,7 @@ void Iteration::InitGlobalIterTimer(const TimeOutCb &timeout_cb) { | |||
| global_iteration_time_window_ = ps::PSContext::instance()->global_iteration_time_window(); | |||
| global_iter_timer_ = std::make_shared<IterationTimer>(); | |||
| MS_LOG(INFO) << "Global iteration time window is: " << global_iteration_time_window_; | |||
| // Set the timeout callback for the timer. | |||
| global_iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void { | |||
| std::string reason = "Global Iteration " + std::to_string(iteration_num_) + | |||
| @@ -65,6 +65,11 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>()); | |||
| size_t new_weight_size = weight_size; | |||
| Feature feature; | |||
| feature.weight_shape = weight_shape; | |||
| feature.weight_size = weight_size; | |||
| feature.weight_type = GetTypeIdByte(kNumberTypeFloat32); | |||
| input_size_list_.push_back(weight_size); | |||
| input_size_list_.push_back(sizeof(size_t)); | |||
| input_size_list_.push_back(new_weight_size); | |||
| @@ -76,7 +81,7 @@ class FedAvgKernel : public AggregationKernelMod { | |||
| MS_EXCEPTION_IF_NULL(weight_node); | |||
| name_ = cnode_name + "." + weight_node->fullname_with_scope(); | |||
| LocalMetaStore::GetInstance().put_feature_map(weight_node->fullname_with_scope(), weight_size); | |||
| LocalMetaStore::GetInstance().put_aggregation_feature_map(weight_node->fullname_with_scope(), feature); | |||
| MS_LOG(INFO) << "Aggregate Weight full name is " << weight_node->fullname_with_scope() << ", weight byte size is " | |||
| << weight_size; | |||
| GenerateReuseKernelNodeInfo(); | |||
| @@ -210,7 +210,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel | |||
| feature_map[weight_full_name] = weight_size; | |||
| } | |||
| if (!LocalMetaStore::GetInstance().verifyFeatureMap(feature_map)) { | |||
| if (!LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) { | |||
| auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); | |||
| std::string reason = "Verify model feature map failed, retry later at time: " + std::to_string(next_req_time); | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time)); | |||
| @@ -41,23 +41,29 @@ const size_t LocalMetaStore::curr_iter_num() { | |||
| return curr_iter_num_; | |||
| } | |||
| const void LocalMetaStore::put_feature_map(const std::string &name, const size_t &size) { feature_map_[name] = size; } | |||
| const void LocalMetaStore::put_aggregation_feature_map(const std::string &name, const Feature &feature) { | |||
| if (aggregation_feature_map_.count(name) > 0) { | |||
| MS_LOG(WARNING) << "Put feature " << name << " failed."; | |||
| return; | |||
| } | |||
| aggregation_feature_map_[name] = feature; | |||
| } | |||
| std::unordered_map<std::string, size_t> &LocalMetaStore::feature_map() { return feature_map_; } | |||
| std::unordered_map<std::string, Feature> &LocalMetaStore::aggregation_feature_map() { return aggregation_feature_map_; } | |||
| bool LocalMetaStore::verifyFeatureMap(const std::unordered_map<std::string, size_t> &model) { | |||
| bool LocalMetaStore::verifyAggregationFeatureMap(const std::unordered_map<std::string, size_t> &model) { | |||
| // feature map size in Hybrid training is not equal with upload model size | |||
| if (model.size() > feature_map_.size()) { | |||
| if (model.size() > aggregation_feature_map_.size()) { | |||
| return false; | |||
| } | |||
| for (const auto &weight : model) { | |||
| std::string weight_name = weight.first; | |||
| size_t weight_size = weight.second; | |||
| if (feature_map_.count(weight_name) == 0) { | |||
| if (aggregation_feature_map_.count(weight_name) == 0) { | |||
| return false; | |||
| } | |||
| if (weight_size != feature_map_[weight_name]) { | |||
| if (weight_size != aggregation_feature_map_[weight_name].weight_size) { | |||
| return false; | |||
| } | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <any> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "fl/server/common.h" | |||
| @@ -29,6 +30,14 @@ namespace server { | |||
| // LocalMetaStore class is used for metadata storage of this server process. | |||
| // For example, the current iteration number, time windows for round kernels, etc. | |||
| // LocalMetaStore is threadsafe. | |||
| struct Feature { | |||
| std::vector<size_t> weight_shape; | |||
| std::string weight_type; | |||
| size_t weight_size; | |||
| std::vector<float> weight_data; | |||
| }; | |||
| class LocalMetaStore { | |||
| public: | |||
| static LocalMetaStore &GetInstance() { | |||
| @@ -70,11 +79,11 @@ class LocalMetaStore { | |||
| void set_curr_iter_num(size_t num); | |||
| const size_t curr_iter_num(); | |||
| const void put_feature_map(const std::string &name, const size_t &size); | |||
| const void put_aggregation_feature_map(const std::string &name, const Feature &feature); | |||
| std::unordered_map<std::string, size_t> &feature_map(); | |||
| std::unordered_map<std::string, Feature> &aggregation_feature_map(); | |||
| bool verifyFeatureMap(const std::unordered_map<std::string, size_t> &model); | |||
| bool verifyAggregationFeatureMap(const std::unordered_map<std::string, size_t> &model); | |||
| private: | |||
| LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {} | |||
| @@ -88,8 +97,8 @@ class LocalMetaStore { | |||
| std::mutex mtx_; | |||
| size_t curr_iter_num_{0}; | |||
| // feature_map_ stores model meta data with weight name and size. | |||
| std::unordered_map<std::string, size_t> feature_map_; | |||
| // aggregation_feature_map_ stores model meta data with weight name and size which will be Aggregated. | |||
| std::unordered_map<std::string, Feature> aggregation_feature_map_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace fl | |||
| @@ -19,16 +19,17 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "fl/server/executor.h" | |||
| #include "pipeline/jit/parse/parse.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| namespace server { | |||
| void ModelStore::Initialize(uint32_t max_count) { | |||
| void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) { | |||
| if (!Executor::GetInstance().initialized()) { | |||
| MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; | |||
| return; | |||
| } | |||
| rank_id_ = rank_id; | |||
| max_model_count_ = max_count; | |||
| initial_model_ = AssignNewModelMemory(); | |||
| iteration_to_model_[kInitIterationNum] = initial_model_; | |||
| @@ -84,6 +85,7 @@ void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin | |||
| } | |||
| iteration_to_model_[iteration] = memory_register; | |||
| OnIterationUpdate(); | |||
| SaveCheckpoint(iteration, new_model); | |||
| return; | |||
| } | |||
| @@ -241,6 +243,41 @@ void ModelStore::OnIterationUpdate() { | |||
| << ", total add and sub reference count: " << total_add_reference_count << ", " | |||
| << total_sub_reference_count; | |||
| } | |||
| void ModelStore::SaveCheckpoint(size_t iteration, const std::map<std::string, AddressPtr> &model) { | |||
| if (rank_id_ != kLeaderServerRank) { | |||
| MS_LOG(INFO) << "Only leader server will save the weight."; | |||
| return; | |||
| } | |||
| std::unordered_map<std::string, Feature> &aggregation_feature_map = | |||
| LocalMetaStore::GetInstance().aggregation_feature_map(); | |||
| namespace python_adapter = mindspore::python_adapter; | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_SERIALIZE_MODULE); | |||
| py::dict dict_data = py::dict(); | |||
| for (const auto &weight : model) { | |||
| std::string weight_fullname = weight.first; | |||
| float *weight_data = reinterpret_cast<float *>(weight.second->addr); | |||
| size_t weight_data_size = weight.second->size / sizeof(float); | |||
| Feature aggregation_feature = aggregation_feature_map[weight_fullname]; | |||
| std::vector<float> weight_data_vec(weight_data, weight_data + weight_data_size); | |||
| py::list data_list; | |||
| data_list.append(aggregation_feature.weight_type); | |||
| data_list.append(aggregation_feature.weight_shape); | |||
| data_list.append(weight_data_vec); | |||
| data_list.append(weight_data_size); | |||
| dict_data[py::str(weight_fullname)] = data_list; | |||
| } | |||
| std::string checkpoint_dir = ps::PSContext::instance()->checkpoint_dir(); | |||
| std::string fl_name = ps::PSContext::instance()->fl_name(); | |||
| python_adapter::CallPyModFn(mod, PYTHON_MOD_SAFE_WEIGHT, py::str(checkpoint_dir), py::str(fl_name), | |||
| py::str(std::to_string(iteration)), dict_data); | |||
| } | |||
| } // namespace server | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| } // namespace mindspore | |||
| @@ -21,9 +21,11 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "fl/server/common.h" | |||
| #include "fl/server/memory_register.h" | |||
| #include "fl/server/executor.h" | |||
| #include "fl/server/local_meta_store.h" | |||
| namespace mindspore { | |||
| namespace fl { | |||
| @@ -44,7 +46,7 @@ class ModelStore { | |||
| } | |||
| // Initialize ModelStore with max count of models need to be stored. | |||
| void Initialize(uint32_t max_count = 3); | |||
| void Initialize(uint32_t rank_id, uint32_t max_count = 3); | |||
| // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already | |||
| // max_model_count_, the earliest model will be replaced. | |||
| @@ -75,6 +77,8 @@ class ModelStore { | |||
| ModelStore(const ModelStore &) = delete; | |||
| ModelStore &operator=(const ModelStore &) = delete; | |||
| void SaveCheckpoint(size_t iteration, const std::map<std::string, AddressPtr> &model); | |||
| // To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ * | |||
| // model_size_. | |||
| std::shared_ptr<MemoryRegister> AssignNewModelMemory(); | |||
| @@ -91,6 +95,7 @@ class ModelStore { | |||
| // The number of all models stored is max_model_count_. | |||
| std::mutex model_mtx_; | |||
| std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; | |||
| uint32_t rank_id_; | |||
| struct HttpResponseModelCache { | |||
| std::string round_name; // startFlJob, getModel | |||
| @@ -108,4 +113,4 @@ class ModelStore { | |||
| } // namespace server | |||
| } // namespace fl | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_ | |||
| #endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_ | |||
| @@ -406,7 +406,7 @@ void Server::InitExecutor() { | |||
| // so the required_cnt of these kernels must be the same as executor_threshold_. | |||
| MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; | |||
| Executor::GetInstance().Initialize(func_graph, executor_threshold_); | |||
| ModelStore::GetInstance().Initialize(); | |||
| ModelStore::GetInstance().Initialize(server_node_->rank_id()); | |||
| // init weight memory to 0 after get model | |||
| Executor::GetInstance().ResetAggregationStatus(); | |||
| return; | |||
| @@ -504,7 +504,9 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.") | |||
| .def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window, | |||
| "Set global iteration time window.") | |||
| .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window."); | |||
| .def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.") | |||
| .def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.") | |||
| .def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory."); | |||
| (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); | |||
| (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); | |||
| (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted"); | |||
| @@ -549,5 +549,9 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio | |||
| } | |||
| uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; } | |||
| std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; } | |||
| void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -230,6 +230,9 @@ class PSContext { | |||
| void set_global_iteration_time_window(const uint64_t &global_iteration_time_window); | |||
| uint64_t global_iteration_time_window() const; | |||
| std::string checkpoint_dir() const; | |||
| void set_checkpoint_dir(const std::string &checkpoint_dir); | |||
| private: | |||
| PSContext() | |||
| : ps_enabled_(false), | |||
| @@ -282,7 +285,8 @@ class PSContext { | |||
| client_password_(""), | |||
| server_password_(""), | |||
| http_url_prefix_(""), | |||
| global_iteration_time_window_(21600000) {} | |||
| global_iteration_time_window_(3600000), | |||
| checkpoint_dir_("") {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| bool is_pserver_; | |||
| @@ -415,6 +419,8 @@ class PSContext { | |||
| // The time window of startFLJob round in millisecond. | |||
| uint64_t global_iteration_time_window_; | |||
| // directory of server checkpoint | |||
| std::string checkpoint_dir_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -1123,7 +1123,9 @@ def set_fl_context(**kwargs): | |||
| http_url_prefix (string): The http url prefix for http server. | |||
| Default: "". | |||
| global_iteration_time_window (unsigned long): The global iteration time window for one iteration | |||
| with rounds(ms). Default: 21600000. | |||
| with rounds(ms). Default: 3600000. | |||
| checkpoint_dir (string): The Server model checkpoint directory. If no checkpoint dir is set, | |||
| the startup script directory is used by default. Default: "". | |||
| Raises: | |||
| ValueError: If input key is not the attribute in federated learning mode context. | |||
| @@ -77,7 +77,8 @@ _set_ps_context_func_map = { | |||
| "sign_eps": ps_context().set_sign_eps, | |||
| "sign_thr_ratio": ps_context().set_sign_thr_ratio, | |||
| "sign_global_lr": ps_context().set_sign_global_lr, | |||
| "sign_dim_out": ps_context().set_sign_dim_out | |||
| "sign_dim_out": ps_context().set_sign_dim_out, | |||
| "checkpoint_dir": ps_context().set_checkpoint_dir, | |||
| } | |||
| _get_ps_context_func_map = { | |||
| @@ -124,7 +125,8 @@ _get_ps_context_func_map = { | |||
| "sign_eps": ps_context().sign_eps, | |||
| "sign_thr_ratio": ps_context().sign_thr_ratio, | |||
| "sign_global_lr": ps_context().sign_global_lr, | |||
| "sign_dim_out": ps_context().sign_dim_out | |||
| "sign_dim_out": ps_context().sign_dim_out, | |||
| "checkpoint_dir": ps_context().checkpoint_dir | |||
| } | |||
| _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", | |||
| @@ -155,6 +155,39 @@ def _type_convert(param, new_param, strict_load): | |||
| return False | |||
def _save_weight(checkpoint_dir, model_name, iteration, params):
    """Save model weight into a checkpoint file and drop stale checkpoints.

    Args:
        checkpoint_dir (str): Directory in which checkpoint files are stored.
        model_name (str): Prefix used to build the checkpoint file name.
        iteration (str): Current iteration number, embedded in the file name.
        params (dict): Maps parameter name to a list of
            [weight_type, weight_shape, weight_data, weight_size].
    """
    logger.debug(f"Checkpoint dir is: '{checkpoint_dir}'")
    if not os.path.exists(checkpoint_dir):
        logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
        return
    file_prefix = model_name + "_iteration_"
    ckpt_file_save_name = file_prefix + iteration + ".ckpt"
    # Collect stale checkpoints before writing the new one, but never schedule
    # the file we are about to write for removal: the original collected every
    # name with the prefix, so re-saving the same iteration deleted the freshly
    # written checkpoint.
    exist_ckpt_file_list = [name for name in os.listdir(checkpoint_dir)
                            if name.startswith(file_prefix) and name != ckpt_file_save_name]
    param_dict = OrderedDict()
    for key, value in params.items():
        weight_type, weight_shape, weight_data, weight_size = value
        weight_np = np.array(weight_data, dtype=weight_type.lower())
        logger.debug(f"weight_type: '{weight_type}', weight_shape: '{weight_shape}', weight_size: "
                     f"'{weight_size}', weight_np.nbytes: '{weight_np.nbytes}'")
        param_dict[key] = [weight_shape, weight_type, weight_np]
    ckpt_file_save_path = os.path.join(checkpoint_dir, ckpt_file_save_name)
    _exec_save(ckpt_file_save_path, param_dict)
    for exist_ckpt_name in exist_ckpt_file_list:
        os.remove(os.path.join(checkpoint_dir, exist_ckpt_name))
    logger.info(f"Save weight to checkpoint file path '{ckpt_file_save_path}' success.")
| def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): | |||
| """Execute the process of saving checkpoint into file.""" | |||
| try: | |||
| @@ -1003,7 +1036,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs): | |||
| if 'dataset' in kwargs.keys() and kwargs.get('dataset') is not None: | |||
| check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset) | |||
| dataset = kwargs['dataset'] | |||
| dataset = kwargs.get('dataset') | |||
| _save_dataset_to_mindir(model, dataset) | |||
| save_together = _save_together(net_dict, model) | |||