| @@ -47,6 +47,16 @@ _ckpt_mutex = Lock() | |||||
| SLICE_SIZE = 512 * 1024 * 1024 | SLICE_SIZE = 512 * 1024 * 1024 | ||||
| def _set_pb_env(): | |||||
| """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" | |||||
| if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": | |||||
| logger.warning("Current env variable `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp`,\ | |||||
| When the parameter is too large, it may cause memory limit error.") | |||||
| else: | |||||
| os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | |||||
| logger.debug("Set the `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python`.") | |||||
| def _special_process_par(par, new_par): | def _special_process_par(par, new_par): | ||||
| """ | """ | ||||
| Processes the special condition. | Processes the special condition. | ||||
| @@ -785,3 +795,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) | merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) | ||||
| return merged_parameter | return merged_parameter | ||||
| _set_pb_env() | |||||