diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index da8528b659..bb79c75a6d 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -235,7 +235,9 @@ class ModelCheckpoint(Callback): self._last_time_for_keep = time.time() self._last_triggered_step = 0 - Validator.check_file_name_by_regular(prefix) + if not isinstance(prefix, str) or prefix.find('/') >= 0: + raise ValueError("Prefix {} for checkpoint file name invalid, " + "please check and correct it and then continue.".format(prefix)) self._prefix = prefix if directory is not None: diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 4b8025d2fe..46039952ee 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -53,7 +53,10 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} _ckpt_mutex = Lock() -SLICE_SIZE = 512 * 1024 * 1024 + +# unit is KB +SLICE_SIZE = 512 * 1024 +PROTO_LIMIT_SIZE = 1024 * 1024 * 2 TOTAL_SAVE = 1024 * 1024 @@ -126,7 +129,7 @@ def _exec_save(ckpt_file_name, data_list): os.remove(ckpt_file_name) with open(ckpt_file_name, "ab") as f: for name, value in data_list.items(): - data_size = value[2].nbytes + data_size = value[2].nbytes / 1024 if data_size > SLICE_SIZE: slice_count = math.ceil(data_size / SLICE_SIZE) param_slice_list = np.array_split(value[2], slice_count) @@ -636,6 +639,10 @@ def _export(net, file_name, file_format, *inputs): file_name += ".air" _executor.export(file_name, graph_id) elif file_format == 'ONNX': + total_size = _calculation_net_size(net) + if total_size > PROTO_LIMIT_SIZE: + raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.' + .format(total_size/1024/1024, PROTO_LIMIT_SIZE/1024/1024)) phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(net, graph_id) @@ -1213,3 +1220,13 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): layerwise_parallel = merged_param.layerwise_parallel split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) return split_param + + +def _calculation_net_size(net): + """Calculate the size of parameters in the network.""" + data_total = 0 + net_dict = net.parameters_dict() + for name in net_dict: + data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 + + return data_total