|
|
|
@@ -53,7 +53,10 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin |
|
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} |
|
|
|
|
|
|
|
_ckpt_mutex = Lock() |
|
|
|
SLICE_SIZE = 512 * 1024 * 1024 |
|
|
|
|
|
|
|
# unit is KB |
|
|
|
SLICE_SIZE = 512 * 1024 |
|
|
|
PROTO_LIMIT_SIZE = 1024 * 1024 * 2 |
|
|
|
TOTAL_SAVE = 1024 * 1024 |
|
|
|
|
|
|
|
|
|
|
|
@@ -126,7 +129,7 @@ def _exec_save(ckpt_file_name, data_list): |
|
|
|
os.remove(ckpt_file_name) |
|
|
|
with open(ckpt_file_name, "ab") as f: |
|
|
|
for name, value in data_list.items(): |
|
|
|
data_size = value[2].nbytes |
|
|
|
data_size = value[2].nbytes / 1024 |
|
|
|
if data_size > SLICE_SIZE: |
|
|
|
slice_count = math.ceil(data_size / SLICE_SIZE) |
|
|
|
param_slice_list = np.array_split(value[2], slice_count) |
|
|
|
@@ -636,6 +639,10 @@ def _export(net, file_name, file_format, *inputs): |
|
|
|
file_name += ".air" |
|
|
|
_executor.export(file_name, graph_id) |
|
|
|
elif file_format == 'ONNX': |
|
|
|
total_size = _calculation_net_size(net) |
|
|
|
if total_size > PROTO_LIMIT_SIZE: |
|
|
|
raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.' |
|
|
|
.format(total_size/1024/1024, PROTO_LIMIT_SIZE/1024/1024)) |
|
|
|
phase_name = 'export.onnx' |
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) |
|
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id) |
|
|
|
@@ -1213,3 +1220,13 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): |
|
|
|
layerwise_parallel = merged_param.layerwise_parallel |
|
|
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) |
|
|
|
return split_param |
|
|
|
|
|
|
|
|
|
|
|
def _calculation_net_size(net): |
|
|
|
"""Calculate the size of parameters in the network.""" |
|
|
|
data_total = 0 |
|
|
|
net_dict = net.parameters_dict() |
|
|
|
for name in net_dict: |
|
|
|
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 |
|
|
|
|
|
|
|
return data_total |