Browse Source

fix issue I4TBJO

r1.7
twc 4 years ago
parent
commit
ba2192eb40
17 changed files with 166 additions and 45 deletions
  1. +18
    -10
      docs/api/api_python/ops/mindspore.ops.matmul.rst
  2. +7
    -7
      mindspore/ccsrc/fl/server/common.h
  3. +3
    -2
      mindspore/ccsrc/fl/server/executor.cc
  4. +2
    -1
      mindspore/ccsrc/fl/server/iteration.cc
  5. +6
    -1
      mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h
  6. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  7. +12
    -6
      mindspore/ccsrc/fl/server/local_meta_store.cc
  8. +14
    -5
      mindspore/ccsrc/fl/server/local_meta_store.h
  9. +40
    -3
      mindspore/ccsrc/fl/server/model_store.cc
  10. +7
    -2
      mindspore/ccsrc/fl/server/model_store.h
  11. +1
    -1
      mindspore/ccsrc/fl/server/server.cc
  12. +3
    -1
      mindspore/ccsrc/pipeline/jit/init.cc
  13. +4
    -0
      mindspore/ccsrc/ps/ps_context.cc
  14. +7
    -1
      mindspore/ccsrc/ps/ps_context.h
  15. +3
    -1
      mindspore/python/mindspore/context.py
  16. +4
    -2
      mindspore/python/mindspore/parallel/_ps_context.py
  17. +34
    -1
      mindspore/python/mindspore/train/serialization.py

+ 18
- 10
docs/api/api_python/ops/mindspore.ops.matmul.rst View File

@@ -1,24 +1,32 @@
mindspore.ops.matmul
=====================

.. py:function:: mindspore.ops.matmul(x1, x2, dtype=None)
.. py:class:: mindspore.ops.MatMul(transpose_a=False, transpose_b=False)

计算两个数组的乘积
将矩阵 `a` 和矩阵 `b` 相乘

.. note::
不支持NumPy参数 `out` 、 `casting` 、 `order` 、 `subok` 、 `signature` 、 `extobj` 。在GPU上支持的数据类型为np.float16和np.float32。在CPU上支持的数据类型为np.float16和np.float32。
.. math::

(Output)_{i j}=\sum_{k=1}^{p} a_{i k} b_{k j}=a_{i 1} b_{1 j}+a_{i 2} b_{2 j}+\cdots+a_{i p} b_{p j}, p\in N

其中, :math:`i,j` 表示输出的第i行和第j列元素。

**参数:**

- **x1** (Tensor) - 输入Tensor,不支持Scalar, `x1` 的最后一维度和 `x2` 的倒数第二维度相等,且 `x1` 和 `x2` 彼此支持广播。
- **x2** (Tensor) - 输入Tensor,不支持Scalar, `x1` 的最后一维度和 `x2` 的倒数第二维度相等,且 `x1` 和 `x2` 彼此支持广播。
- **dtype** (:class:`mindspore.dtype`, optional) - 指定输出Tensor的数据类型,默认:None。
- **transpose_a** (bool) - 如果为True,则在相乘之前转置 `a`。默认值:False。
- **transpose_b** (bool) - 如果为True,则在相乘之前转置 `b`。默认值:False。

**输入:**

- **a** (Tensor) - 要相乘的第一个Tensor。如果 `transpose_a` 为False,则该Tensor的shape为 :math:`(N, C)` ;否则,该Tensor的shape为 :math:`(C, N)` 。
- **b** (Tensor) - 要相乘的第二个Tensor。如果 `transpose_b` 为False,则该Tensor的shape为 :math:`(C, M)` ;否则,该Tensor的shape为 :math:`(M, C)` 。

**输出:**

Tensor或Scalar,矩阵乘积的输出。当 `x1` 和 `x2` 为一维向量时,输出为Scalar。
Tensor,输出Tensor的shape为 :math:`(N, M)`

**异常:**

- **ValueError** - `x1` 的最后一维度和 `x2` 的倒数第二维度不相等,或者输入的是Scalar。
- **ValueError** - `x1` 和 `x2` 彼此不能广播。
- **TypeError** - `transpose_a` 或 `transpose_b` 不是bool。
- **ValueError** - 矩阵 `a` 的列不等于矩阵 `b` 的行。
- **ValueError** - `a` 或 `b` 的维度不等于2。

+ 7
- 7
mindspore/ccsrc/fl/server/common.h View File

@@ -242,24 +242,24 @@ constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";
constexpr auto kCurrentIteration = "current_iteration";
const char PYTHON_MOD_SERIALIZE_MODULE[] = "mindspore.train.serialization";
const char PYTHON_MOD_SAFE_WEIGHT[] = "_save_weight";

// This macro returns the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())

// This method returns the type name string (e.g. "Float32") of the given TypeId.
inline size_t GetTypeIdByte(const TypeId &type) {
inline std::string GetTypeIdByte(const TypeId &type) {
switch (type) {
case kNumberTypeFloat16:
return kNumberTypeFloat16Type;
case kNumberTypeUInt32:
return "Float16";
case kNumberTypeFloat32:
return kNumberTypeFloat32Type;
case kNumberTypeUInt64:
return kNumberTypeUInt64Type;
return "Float32";
case kNumberTypeFloat64:
return "Float64";
default:
MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
return 0;
}
}



+ 3
- 2
mindspore/ccsrc/fl/server/executor.cc View File

@@ -32,7 +32,8 @@ void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_cou
}
aggregation_count_ = aggregation_count;

// Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and optimizers.
// Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and
// optimizers.
bool ret = InitParamAggregator(func_graph);
if (!ret) {
MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
@@ -274,7 +275,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
return false;
}
MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
MS_LOG(INFO) << "Initializing parameter aggregator for param_name " << param_name << " success.";
}
return true;
}


+ 2
- 1
mindspore/ccsrc/fl/server/iteration.cc View File

@@ -565,7 +565,7 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
feature_map[weight_fullname] = weight_size;
}

if (LocalMetaStore::GetInstance().verifyFeatureMap(feature_map)) {
if (LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) {
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
iteration_result_ = IterationResult::kSuccess;
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
@@ -867,6 +867,7 @@ void Iteration::InitGlobalIterTimer(const TimeOutCb &timeout_cb) {
global_iteration_time_window_ = ps::PSContext::instance()->global_iteration_time_window();
global_iter_timer_ = std::make_shared<IterationTimer>();

MS_LOG(INFO) << "Global iteration time window is: " << global_iteration_time_window_;
// Set the timeout callback for the timer.
global_iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void {
std::string reason = "Global Iteration " + std::to_string(iteration_num_) +


+ 6
- 1
mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h View File

@@ -65,6 +65,11 @@ class FedAvgKernel : public AggregationKernelMod {
std::accumulate(weight_shape.begin(), weight_shape.end(), sizeof(T), std::multiplies<size_t>());
size_t new_weight_size = weight_size;

Feature feature;
feature.weight_shape = weight_shape;
feature.weight_size = weight_size;
feature.weight_type = GetTypeIdByte(kNumberTypeFloat32);

input_size_list_.push_back(weight_size);
input_size_list_.push_back(sizeof(size_t));
input_size_list_.push_back(new_weight_size);
@@ -76,7 +81,7 @@ class FedAvgKernel : public AggregationKernelMod {
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();

LocalMetaStore::GetInstance().put_feature_map(weight_node->fullname_with_scope(), weight_size);
LocalMetaStore::GetInstance().put_aggregation_feature_map(weight_node->fullname_with_scope(), feature);
MS_LOG(INFO) << "Aggregate Weight full name is " << weight_node->fullname_with_scope() << ", weight byte size is "
<< weight_size;
GenerateReuseKernelNodeInfo();


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -210,7 +210,7 @@ ResultCode UpdateModelKernel::VerifyUpdateModel(const schema::RequestUpdateModel
feature_map[weight_full_name] = weight_size;
}

if (!LocalMetaStore::GetInstance().verifyFeatureMap(feature_map)) {
if (!LocalMetaStore::GetInstance().verifyAggregationFeatureMap(feature_map)) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::string reason = "Verify model feature map failed, retry later at time: " + std::to_string(next_req_time);
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time));


+ 12
- 6
mindspore/ccsrc/fl/server/local_meta_store.cc View File

@@ -41,23 +41,29 @@ const size_t LocalMetaStore::curr_iter_num() {
return curr_iter_num_;
}

const void LocalMetaStore::put_feature_map(const std::string &name, const size_t &size) { feature_map_[name] = size; }
// Register the aggregation metadata for a weight. Each name is registered at
// most once: a repeated name is rejected with a warning and the stored entry
// is left untouched.
const void LocalMetaStore::put_aggregation_feature_map(const std::string &name, const Feature &feature) {
  const bool inserted = aggregation_feature_map_.emplace(name, feature).second;
  if (!inserted) {
    MS_LOG(WARNING) << "Put feature " << name << " failed.";
  }
}

std::unordered_map<std::string, size_t> &LocalMetaStore::feature_map() { return feature_map_; }
std::unordered_map<std::string, Feature> &LocalMetaStore::aggregation_feature_map() { return aggregation_feature_map_; }

bool LocalMetaStore::verifyFeatureMap(const std::unordered_map<std::string, size_t> &model) {
bool LocalMetaStore::verifyAggregationFeatureMap(const std::unordered_map<std::string, size_t> &model) {
// feature map size in Hybrid training is not equal with upload model size
if (model.size() > feature_map_.size()) {
if (model.size() > aggregation_feature_map_.size()) {
return false;
}

for (const auto &weight : model) {
std::string weight_name = weight.first;
size_t weight_size = weight.second;
if (feature_map_.count(weight_name) == 0) {
if (aggregation_feature_map_.count(weight_name) == 0) {
return false;
}
if (weight_size != feature_map_[weight_name]) {
if (weight_size != aggregation_feature_map_[weight_name].weight_size) {
return false;
}
}


+ 14
- 5
mindspore/ccsrc/fl/server/local_meta_store.h View File

@@ -20,6 +20,7 @@
#include <any>
#include <mutex>
#include <string>
#include <vector>
#include <unordered_map>
#include "fl/server/common.h"

@@ -29,6 +30,14 @@ namespace server {
// LocalMetaStore class is used for metadata storage of this server process.
// For example, the current iteration number, time windows for round kernels, etc.
// LocalMetaStore is threadsafe.

// Metadata (and optionally raw values) of one trainable weight that takes
// part in server-side aggregation.
struct Feature {
// Per-axis dimensions of the weight tensor.
std::vector<size_t> weight_shape;
// Type name string of the weight, e.g. "Float32" — presumably produced by
// GetTypeIdByte; confirm the full set of values against common.h.
std::string weight_type;
// Total size of the weight buffer in bytes.
size_t weight_size;
// Raw weight values; not populated in the code visible here — verify
// against callers before relying on it.
std::vector<float> weight_data;
};

class LocalMetaStore {
public:
static LocalMetaStore &GetInstance() {
@@ -70,11 +79,11 @@ class LocalMetaStore {
void set_curr_iter_num(size_t num);
const size_t curr_iter_num();

const void put_feature_map(const std::string &name, const size_t &size);
const void put_aggregation_feature_map(const std::string &name, const Feature &feature);

std::unordered_map<std::string, size_t> &feature_map();
std::unordered_map<std::string, Feature> &aggregation_feature_map();

bool verifyFeatureMap(const std::unordered_map<std::string, size_t> &model);
bool verifyAggregationFeatureMap(const std::unordered_map<std::string, size_t> &model);

private:
LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {}
@@ -88,8 +97,8 @@ class LocalMetaStore {
std::mutex mtx_;
size_t curr_iter_num_{0};

// feature_map_ stores model meta data with weight name and size.
std::unordered_map<std::string, size_t> feature_map_;
// aggregation_feature_map_ stores the meta data (name, shape, type and size) of each weight to be aggregated.
std::unordered_map<std::string, Feature> aggregation_feature_map_;
};
} // namespace server
} // namespace fl


+ 40
- 3
mindspore/ccsrc/fl/server/model_store.cc View File

@@ -19,16 +19,17 @@
#include <string>
#include <memory>
#include "fl/server/executor.h"
#include "pipeline/jit/parse/parse.h"

namespace mindspore {
namespace fl {
namespace server {
void ModelStore::Initialize(uint32_t max_count) {
void ModelStore::Initialize(uint32_t rank_id, uint32_t max_count) {
if (!Executor::GetInstance().initialized()) {
MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage.";
return;
}
rank_id_ = rank_id;
max_model_count_ = max_count;
initial_model_ = AssignNewModelMemory();
iteration_to_model_[kInitIterationNum] = initial_model_;
@@ -84,6 +85,7 @@ void ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
}
iteration_to_model_[iteration] = memory_register;
OnIterationUpdate();
SaveCheckpoint(iteration, new_model);
return;
}

@@ -241,6 +243,41 @@ void ModelStore::OnIterationUpdate() {
<< ", total add and sub reference count: " << total_add_reference_count << ", "
<< total_sub_reference_count;
}

// Persist the aggregated model of the given iteration as a checkpoint by
// calling the Python helper `_save_weight` in mindspore.train.serialization.
// Only the leader server (rank kLeaderServerRank) writes checkpoints; other
// ranks return immediately. `model` maps weight full names to their raw
// device-side buffers (AddressPtr).
void ModelStore::SaveCheckpoint(size_t iteration, const std::map<std::string, AddressPtr> &model) {
  if (rank_id_ != kLeaderServerRank) {
    MS_LOG(INFO) << "Only leader server will save the weight.";
    return;
  }
  // Shape/type metadata registered earlier (e.g. by the aggregation kernels)
  // for each weight name.
  std::unordered_map<std::string, Feature> &aggregation_feature_map =
    LocalMetaStore::GetInstance().aggregation_feature_map();

  namespace python_adapter = mindspore::python_adapter;
  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_SERIALIZE_MODULE);

  // Build {weight_name: [type, shape, data, size]} — this layout must match
  // what the Python-side _save_weight unpacks.
  py::dict dict_data = py::dict();
  for (const auto &weight : model) {
    std::string weight_fullname = weight.first;
    // Buffers are interpreted as float32 — assumes every aggregated weight is
    // float; TODO confirm for non-float types.
    float *weight_data = reinterpret_cast<float *>(weight.second->addr);
    size_t weight_data_size = weight.second->size / sizeof(float);
    // NOTE(review): operator[] default-inserts an empty Feature when the name
    // is missing from the map, and this copies the Feature (including its
    // weight_data vector) by value — consider a checked lookup by const ref.
    Feature aggregation_feature = aggregation_feature_map[weight_fullname];

    std::vector<float> weight_data_vec(weight_data, weight_data + weight_data_size);

    py::list data_list;
    data_list.append(aggregation_feature.weight_type);
    data_list.append(aggregation_feature.weight_shape);
    data_list.append(weight_data_vec);
    data_list.append(weight_data_size);
    dict_data[py::str(weight_fullname)] = data_list;
  }

  std::string checkpoint_dir = ps::PSContext::instance()->checkpoint_dir();
  std::string fl_name = ps::PSContext::instance()->fl_name();

  // Hand off to Python: _save_weight(checkpoint_dir, fl_name, iteration, data).
  python_adapter::CallPyModFn(mod, PYTHON_MOD_SAFE_WEIGHT, py::str(checkpoint_dir), py::str(fl_name),
                              py::str(std::to_string(iteration)), dict_data);
}
} // namespace server
} // namespace fl
} // namespace mindspore
} // namespace mindspore

+ 7
- 2
mindspore/ccsrc/fl/server/model_store.h View File

@@ -21,9 +21,11 @@
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "fl/server/common.h"
#include "fl/server/memory_register.h"
#include "fl/server/executor.h"
#include "fl/server/local_meta_store.h"

namespace mindspore {
namespace fl {
@@ -44,7 +46,7 @@ class ModelStore {
}

// Initialize ModelStore with max count of models need to be stored.
void Initialize(uint32_t max_count = 3);
void Initialize(uint32_t rank_id, uint32_t max_count = 3);

// Store the model of the given iteration. The model is acquired from Executor. If the current model count is already
// max_model_count_, the earliest model will be replaced.
@@ -75,6 +77,8 @@ class ModelStore {
ModelStore(const ModelStore &) = delete;
ModelStore &operator=(const ModelStore &) = delete;

void SaveCheckpoint(size_t iteration, const std::map<std::string, AddressPtr> &model);

// To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ *
// model_size_.
std::shared_ptr<MemoryRegister> AssignNewModelMemory();
@@ -91,6 +95,7 @@ class ModelStore {
// The number of all models stored is max_model_count_.
std::mutex model_mtx_;
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
uint32_t rank_id_;

struct HttpResponseModelCache {
std::string round_name; // startFlJob, getModel
@@ -108,4 +113,4 @@ class ModelStore {
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_

+ 1
- 1
mindspore/ccsrc/fl/server/server.cc View File

@@ -406,7 +406,7 @@ void Server::InitExecutor() {
// so the required_cnt of these kernels must be the same as executor_threshold_.
MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
Executor::GetInstance().Initialize(func_graph, executor_threshold_);
ModelStore::GetInstance().Initialize();
ModelStore::GetInstance().Initialize(server_node_->rank_id());
// init weight memory to 0 after get model
Executor::GetInstance().ResetAggregationStatus();
return;


+ 3
- 1
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -504,7 +504,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.")
.def("set_global_iteration_time_window", &PSContext::set_global_iteration_time_window,
"Set global iteration time window.")
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.");
.def("global_iteration_time_window", &PSContext::global_iteration_time_window, "Get global iteration time window.")
.def("set_checkpoint_dir", &PSContext::set_checkpoint_dir, "Set server checkpoint directory.")
.def("checkpoint_dir", &PSContext::checkpoint_dir, "Server checkpoint directory.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");


+ 4
- 0
mindspore/ccsrc/ps/ps_context.cc View File

@@ -549,5 +549,9 @@ void PSContext::set_global_iteration_time_window(const uint64_t &global_iteratio
}

uint64_t PSContext::global_iteration_time_window() const { return global_iteration_time_window_; }

std::string PSContext::checkpoint_dir() const { return checkpoint_dir_; }

void PSContext::set_checkpoint_dir(const std::string &checkpoint_dir) { checkpoint_dir_ = checkpoint_dir; }
} // namespace ps
} // namespace mindspore

+ 7
- 1
mindspore/ccsrc/ps/ps_context.h View File

@@ -230,6 +230,9 @@ class PSContext {
void set_global_iteration_time_window(const uint64_t &global_iteration_time_window);
uint64_t global_iteration_time_window() const;

std::string checkpoint_dir() const;
void set_checkpoint_dir(const std::string &checkpoint_dir);

private:
PSContext()
: ps_enabled_(false),
@@ -282,7 +285,8 @@ class PSContext {
client_password_(""),
server_password_(""),
http_url_prefix_(""),
global_iteration_time_window_(21600000) {}
global_iteration_time_window_(3600000),
checkpoint_dir_("") {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@@ -415,6 +419,8 @@ class PSContext {

// The time window of startFLJob round in millisecond.
uint64_t global_iteration_time_window_;
// directory of server checkpoint
std::string checkpoint_dir_;
};
} // namespace ps
} // namespace mindspore


+ 3
- 1
mindspore/python/mindspore/context.py View File

@@ -1123,7 +1123,9 @@ def set_fl_context(**kwargs):
http_url_prefix (string): The http url prefix for http server.
Default: "".
global_iteration_time_window (unsigned long): The global iteration time window for one iteration
with rounds(ms). Default: 21600000.
with rounds(ms). Default: 3600000.
checkpoint_dir (string): The Server model checkpoint directory. If no checkpoint dir is set,
the startup script directory is used by default. Default: "".
Raises:
ValueError: If input key is not the attribute in federated learning mode context.



+ 4
- 2
mindspore/python/mindspore/parallel/_ps_context.py View File

@@ -77,7 +77,8 @@ _set_ps_context_func_map = {
"sign_eps": ps_context().set_sign_eps,
"sign_thr_ratio": ps_context().set_sign_thr_ratio,
"sign_global_lr": ps_context().set_sign_global_lr,
"sign_dim_out": ps_context().set_sign_dim_out
"sign_dim_out": ps_context().set_sign_dim_out,
"checkpoint_dir": ps_context().set_checkpoint_dir,
}

_get_ps_context_func_map = {
@@ -124,7 +125,8 @@ _get_ps_context_func_map = {
"sign_eps": ps_context().sign_eps,
"sign_thr_ratio": ps_context().sign_thr_ratio,
"sign_global_lr": ps_context().sign_global_lr,
"sign_dim_out": ps_context().sign_dim_out
"sign_dim_out": ps_context().sign_dim_out,
"checkpoint_dir": ps_context().checkpoint_dir
}

_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",


+ 34
- 1
mindspore/python/mindspore/train/serialization.py View File

@@ -155,6 +155,39 @@ def _type_convert(param, new_param, strict_load):
return False


def _save_weight(checkpoint_dir, model_name, iteration, params):
    """Save model weights of one iteration into a checkpoint file.

    `params` maps each weight name to a list [type, shape, data, size]; after
    the new checkpoint is written, older checkpoints of the same model found
    beforehand are removed. If `checkpoint_dir` does not exist, only a warning
    is logged.
    """
    logger.debug(f"Checkpoint dir is: '{checkpoint_dir}'")
    if not os.path.exists(checkpoint_dir):
        logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
        return

    # Snapshot the previously saved checkpoints before writing the new one.
    file_prefix = model_name + "_iteration_"
    stale_ckpt_names = [name for name in os.listdir(checkpoint_dir)
                        if name.startswith(file_prefix)]

    param_dict = OrderedDict()
    for key, value in params.items():
        weight_type = value[0]
        weight_shape = value[1]
        weight_data = value[2]
        weight_size = value[3]
        weight_np = np.array(weight_data, dtype=weight_type.lower())
        logger.debug(f"weight_type: '{weight_type}', weight_shape: '{weight_shape}', weight_size: "
                     f"'{weight_size}', weight_np.nbytes: '{weight_np.nbytes}'")

        param_dict[key] = [weight_shape, weight_type, weight_np]

    ckpt_file_save_path = os.path.join(checkpoint_dir, file_prefix + iteration + ".ckpt")

    _exec_save(ckpt_file_save_path, param_dict)

    # Drop the superseded checkpoints only after the new one was written.
    for stale_name in stale_ckpt_names:
        os.remove(os.path.join(checkpoint_dir, stale_name))
    logger.info(f"Save weight to checkpoint file path '{ckpt_file_save_path}' success.")

def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
"""Execute the process of saving checkpoint into file."""
try:
@@ -1003,7 +1036,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):

if 'dataset' in kwargs.keys() and kwargs.get('dataset') is not None:
check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
dataset = kwargs['dataset']
dataset = kwargs.get('dataset')
_save_dataset_to_mindir(model, dataset)

save_together = _save_together(net_dict, model)


Loading…
Cancel
Save