Browse Source

modify export onnx limit

pull/15644/head
changzherui 4 years ago
parent
commit
18bbba607f
2 changed files with 22 additions and 3 deletions
  1. +3
    -1
      mindspore/train/callback/_checkpoint.py
  2. +19
    -2
      mindspore/train/serialization.py

+ 3
- 1
mindspore/train/callback/_checkpoint.py View File

@@ -235,7 +235,9 @@ class ModelCheckpoint(Callback):
self._last_time_for_keep = time.time()
self._last_triggered_step = 0

Validator.check_file_name_by_regular(prefix)
if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("Prefix {} for checkpoint file name invalid, "
"please check and correct it and then continue.".format(prefix))
self._prefix = prefix

if directory is not None:


+ 19
- 2
mindspore/train/serialization.py View File

@@ -53,7 +53,10 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

_ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024 * 1024

# unit is KB
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024


@@ -126,7 +129,7 @@ def _exec_save(ckpt_file_name, data_list):
os.remove(ckpt_file_name)
with open(ckpt_file_name, "ab") as f:
for name, value in data_list.items():
data_size = value[2].nbytes
data_size = value[2].nbytes / 1024
if data_size > SLICE_SIZE:
slice_count = math.ceil(data_size / SLICE_SIZE)
param_slice_list = np.array_split(value[2], slice_count)
@@ -636,6 +639,10 @@ def _export(net, file_name, file_format, *inputs):
file_name += ".air"
_executor.export(file_name, graph_id)
elif file_format == 'ONNX':
total_size = _calculation_net_size(net)
if total_size > PROTO_LIMIT_SIZE:
raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.'
.format(total_size/1024/1024, PROTO_LIMIT_SIZE/1024/1024))
phase_name = 'export.onnx'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
@@ -1213,3 +1220,13 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
layerwise_parallel = merged_param.layerwise_parallel
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
return split_param


def _calculation_net_size(net):
"""Calculate the size of parameters in the network."""
data_total = 0
net_dict = net.parameters_dict()
for name in net_dict:
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024

return data_total

Loading…
Cancel
Save