|
|
|
@@ -660,13 +660,18 @@ def _save_mindir(net, file_name, *inputs): |
|
|
|
net_dict = net.parameters_dict() |
|
|
|
data_total = 0 |
|
|
|
save_together = True |
|
|
|
for i in net_dict.values(): |
|
|
|
data_total += sys.getsizeof(i.data.asnumpy().tobytes())/1024 |
|
|
|
|
|
|
|
model.ParseFromString(mindir_stream) |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
name = param_proto.name[param_proto.name.find(":") + 1:] |
|
|
|
if name in net_dict.keys(): |
|
|
|
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 |
|
|
|
else: |
|
|
|
raise RuntimeError('Graph parameter: {} Undefined in network.'.format(param_proto.name)) |
|
|
|
if data_total > TOTAL_SAVE: |
|
|
|
save_together = False |
|
|
|
break |
|
|
|
|
|
|
|
model.ParseFromString(mindir_stream) |
|
|
|
if save_together: |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
param_name = param_proto.name[param_proto.name.find(":")+1:] |
|
|
|
@@ -699,19 +704,19 @@ def _save_mindir(net, file_name, *inputs): |
|
|
|
index = 0 |
|
|
|
graphproto = graph_proto() |
|
|
|
data_size = 0 |
|
|
|
|
|
|
|
for name, param in net_dict.items(): |
|
|
|
byte_data = param.data.asnumpy().tobytes() |
|
|
|
data_size += sys.getsizeof(byte_data)/1024 |
|
|
|
parameter = graphproto.parameter.add() |
|
|
|
for param_proto in model.graph.parameter: |
|
|
|
if name == param_proto.name[param_proto.name.find(":") + 1:]: |
|
|
|
parameter = graphproto.parameter.add() |
|
|
|
parameter.name = param_proto.name |
|
|
|
parameter.data_type = param_proto.data_type |
|
|
|
for dim in param_proto.dims: |
|
|
|
parameter.dims.append(dim) |
|
|
|
byte_data = param.data.asnumpy().tobytes() |
|
|
|
parameter.raw_data = byte_data |
|
|
|
data_size += sys.getsizeof(byte_data) / 1024 |
|
|
|
break |
|
|
|
|
|
|
|
parameter.raw_data = byte_data |
|
|
|
if data_size > TOTAL_SAVE: |
|
|
|
data_file_name = data_path + "/" + "data_" + str(index) |
|
|
|
with open(data_file_name, "ab") as f: |
|
|
|
|