|
|
|
@@ -47,6 +47,16 @@ _ckpt_mutex = Lock() |
|
|
|
SLICE_SIZE = 512 * 1024 * 1024 |
|
|
|
|
|
|
|
|
|
|
|
def _set_pb_env(): |
|
|
|
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" |
|
|
|
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": |
|
|
|
logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\ |
|
|
|
When the parameter is too large, it may cause memory limit error.") |
|
|
|
else: |
|
|
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" |
|
|
|
logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") |
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par): |
|
|
|
""" |
|
|
|
Processes the special condition. |
|
|
|
@@ -785,3 +795,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): |
|
|
|
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) |
|
|
|
|
|
|
|
return merged_parameter |
|
|
|
|
|
|
|
|
|
|
|
_set_pb_env() |