| @@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin | |||||
| if (!print.SerializeToOstream(output)) { | if (!print.SerializeToOstream(output)) { | ||||
| MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; | MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; | ||||
| ret_end_thread = true; | ret_end_thread = true; | ||||
| break; | |||||
| } | } | ||||
| print.Clear(); | print.Clear(); | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment, | |||||
| including execution mode, execution backend and other feature switches. | including execution mode, execution backend and other feature switches. | ||||
| """ | """ | ||||
| import os | import os | ||||
| import time | |||||
| import threading | import threading | ||||
| from collections import namedtuple | from collections import namedtuple | ||||
| from types import FunctionType | from types import FunctionType | ||||
| @@ -55,12 +56,20 @@ def _make_directory(path): | |||||
| os.makedirs(path) | os.makedirs(path) | ||||
| real_path = path | real_path = path | ||||
| except PermissionError as e: | except PermissionError as e: | ||||
| logger.error( | |||||
| f"No write permission on the directory `{path}, error = {e}") | |||||
| logger.error(f"No write permission on the directory `{path}, error = {e}") | |||||
| raise ValueError(f"No write permission on the directory `{path}`.") | raise ValueError(f"No write permission on the directory `{path}`.") | ||||
| return real_path | return real_path | ||||
| def _get_print_file_name(file_name): | |||||
| """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds).""" | |||||
| time_second = str(int(time.time())) | |||||
| file_name = file_name + "." + time_second | |||||
| if os.path.exists(file_name): | |||||
| ValueError("This file {} already exists.".format(file_name)) | |||||
| return file_name | |||||
| class _ThreadLocalInfo(threading.local): | class _ThreadLocalInfo(threading.local): | ||||
| """ | """ | ||||
| Thread local Info used for store thread local attributes. | Thread local Info used for store thread local attributes. | ||||
| @@ -381,8 +390,20 @@ class _Context: | |||||
| return None | return None | ||||
| @print_file_path.setter | @print_file_path.setter | ||||
| def print_file_path(self, file): | |||||
| self._context_handle.set_print_file_path(file) | |||||
| def print_file_path(self, file_path): | |||||
| """Add timestamp suffix to file name. Sets print file path.""" | |||||
| print_file_path = os.path.realpath(file_path) | |||||
| if os.path.isdir(print_file_path): | |||||
| raise IOError("Print_file_path should be file path, but got {}.".format(file_path)) | |||||
| if os.path.exists(print_file_path): | |||||
| _path, _file_name = os.path.split(print_file_path) | |||||
| path = _make_directory(_path) | |||||
| file_name = _get_print_file_name(_file_name) | |||||
| full_file_name = os.path.join(path, file_name) | |||||
| else: | |||||
| full_file_name = print_file_path | |||||
| self._context_handle.set_print_file_path(full_file_name) | |||||
| def check_input_format(x): | def check_input_format(x): | ||||
| @@ -575,7 +596,8 @@ def set_context(**kwargs): | |||||
| max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. | max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. | ||||
| The format is "xxGB". Default: "1024GB". | The format is "xxGB". Default: "1024GB". | ||||
| print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | ||||
| a file by default, and turn off printing to the screen. | |||||
| a file by default, and turn off printing to the screen. If the file already exists, add a timestamp | |||||
| suffix to the file. | |||||
| enable_sparse (bool): Whether to enable sparse feature. Default: False. | enable_sparse (bool): Whether to enable sparse feature. Default: False. | ||||
| Raises: | Raises: | ||||
| @@ -302,7 +302,7 @@ def _save_graph(network, file_name): | |||||
| if graph_proto: | if graph_proto: | ||||
| with open(file_name, "wb") as f: | with open(file_name, "wb") as f: | ||||
| f.write(graph_proto) | f.write(graph_proto) | ||||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | |||||
| os.chmod(file_name, stat.S_IRUSR) | |||||
| def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): | def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): | ||||
| @@ -462,19 +462,18 @@ def parse_print(print_file_name): | |||||
| List, element of list is Tensor. | List, element of list is Tensor. | ||||
| Raises: | Raises: | ||||
| ValueError: Print file is incorrect. | |||||
| ValueError: The print file may be empty, please make sure enter the correct file name. | |||||
| """ | """ | ||||
| if not os.path.realpath(print_file_name): | |||||
| raise ValueError("Please input the correct print file name.") | |||||
| print_file_path = os.path.realpath(print_file_name) | |||||
| if os.path.getsize(print_file_name) == 0: | |||||
| if os.path.getsize(print_file_path) == 0: | |||||
| raise ValueError("The print file may be empty, please make sure enter the correct file name.") | raise ValueError("The print file may be empty, please make sure enter the correct file name.") | ||||
| logger.info("Execute load print process.") | logger.info("Execute load print process.") | ||||
| print_list = Print() | print_list = Print() | ||||
| try: | try: | ||||
| with open(print_file_name, "rb") as f: | |||||
| with open(print_file_path, "rb") as f: | |||||
| pb_content = f.read() | pb_content = f.read() | ||||
| print_list.ParseFromString(pb_content) | print_list.ParseFromString(pb_content) | ||||
| except BaseException as e: | except BaseException as e: | ||||
| @@ -118,6 +118,12 @@ def test_variable_memory_max_size(): | |||||
| context.set_context(variable_memory_max_size="3GB") | context.set_context(variable_memory_max_size="3GB") | ||||
| def test_print_file_path(): | |||||
| """test_print_file_path""" | |||||
| with pytest.raises(IOError): | |||||
| context.set_context(print_file_path="./") | |||||
| def test_set_context(): | def test_set_context(): | ||||
| """ test_set_context """ | """ test_set_context """ | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", | ||||
| @@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load | |||||
| _exec_save_checkpoint, export, _save_graph | _exec_save_checkpoint, export, _save_graph | ||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") | |||||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| @@ -374,10 +374,13 @@ def test_print(): | |||||
| def teardown_module(): | def teardown_module(): | ||||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] | |||||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] | |||||
| for item in files: | for item in files: | ||||
| file_name = './' + item | file_name = './' + item | ||||
| if not os.path.exists(file_name): | if not os.path.exists(file_name): | ||||
| continue | continue | ||||
| os.chmod(file_name, stat.S_IWRITE) | os.chmod(file_name, stat.S_IWRITE) | ||||
| os.remove(file_name) | os.remove(file_name) | ||||
| import shutil | |||||
| if os.path.exists('./print'): | |||||
| shutil.rmtree('./print') | |||||