Browse Source

add protobuf env

tags/v1.1.0
changzherui 5 years ago
parent
commit
e57c631103
1 changed files with 13 additions and 0 deletions
  1. +13
    -0
      mindspore/train/serialization.py

+ 13
- 0
mindspore/train/serialization.py View File

@@ -47,6 +47,16 @@ _ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024 * 1024


def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\
When the parameter is too large, it may cause memory limit error.")
else:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")


def _special_process_par(par, new_par):
"""
Processes the special condition.
@@ -785,3 +795,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

return merged_parameter


_set_pb_env()

Loading…
Cancel
Save