Browse Source

set proto env

tags/v1.1.0
changzherui 5 years ago
parent
commit
376c4aaf26
2 changed files with 13 additions and 14 deletions
  1. +13
    -0
      mindspore/_version_check.py
  2. +0
    -14
      mindspore/train/serialization.py

+ 13
- 0
mindspore/_version_check.py View File

@@ -280,4 +280,17 @@ def check_version_and_env_config():
except ImportError as e: except ImportError as e:
env_checker.check_env(e) env_checker.check_env(e)



def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\
When the parameter is too large, it may cause memory limit error.\
This can be solved by set env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")
elif os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "":
logger.warning("Set the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow.")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"


check_version_and_env_config() check_version_and_env_config()
_set_pb_env()

+ 0
- 14
mindspore/train/serialization.py View File

@@ -49,17 +49,6 @@ _ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024 * 1024 SLICE_SIZE = 512 * 1024 * 1024




def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\
When the parameter is too large, it may cause memory limit error.\
This can be solved by set env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")
else:
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.")


def _special_process_par(par, new_par): def _special_process_par(par, new_par):
""" """
Processes the special condition. Processes the special condition.
@@ -885,6 +874,3 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)


return merged_parameter return merged_parameter


_set_pb_env()

Loading…
Cancel
Save