| @@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin | |||
| if (!print.SerializeToOstream(output)) { | |||
| MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; | |||
| ret_end_thread = true; | |||
| break; | |||
| } | |||
| print.Clear(); | |||
| } | |||
| @@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment, | |||
| including execution mode, execution backend and other feature switches. | |||
| """ | |||
| import os | |||
| import time | |||
| import threading | |||
| from collections import namedtuple | |||
| from types import FunctionType | |||
| @@ -55,12 +56,20 @@ def _make_directory(path): | |||
| os.makedirs(path) | |||
| real_path = path | |||
| except PermissionError as e: | |||
| logger.error( | |||
| f"No write permission on the directory `{path}, error = {e}") | |||
| logger.error(f"No write permission on the directory `{path}, error = {e}") | |||
| raise ValueError(f"No write permission on the directory `{path}`.") | |||
| return real_path | |||
| def _get_print_file_name(file_name): | |||
| """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds).""" | |||
| time_second = str(int(time.time())) | |||
| file_name = file_name + "." + time_second | |||
| if os.path.exists(file_name): | |||
| ValueError("This file {} already exists.".format(file_name)) | |||
| return file_name | |||
| class _ThreadLocalInfo(threading.local): | |||
| """ | |||
| Thread local Info used for store thread local attributes. | |||
| @@ -381,8 +390,20 @@ class _Context: | |||
| return None | |||
| @print_file_path.setter | |||
| def print_file_path(self, file): | |||
| self._context_handle.set_print_file_path(file) | |||
| def print_file_path(self, file_path): | |||
| """Add timestamp suffix to file name. Sets print file path.""" | |||
| print_file_path = os.path.realpath(file_path) | |||
| if os.path.isdir(print_file_path): | |||
| raise IOError("Print_file_path should be file path, but got {}.".format(file_path)) | |||
| if os.path.exists(print_file_path): | |||
| _path, _file_name = os.path.split(print_file_path) | |||
| path = _make_directory(_path) | |||
| file_name = _get_print_file_name(_file_name) | |||
| full_file_name = os.path.join(path, file_name) | |||
| else: | |||
| full_file_name = print_file_path | |||
| self._context_handle.set_print_file_path(full_file_name) | |||
| def check_input_format(x): | |||
| @@ -575,7 +596,8 @@ def set_context(**kwargs): | |||
| max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. | |||
| The format is "xxGB". Default: "1024GB". | |||
| print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | |||
| a file by default, and turn off printing to the screen. | |||
| a file by default, and turn off printing to the screen. If the file already exists, add a timestamp | |||
| suffix to the file. | |||
| enable_sparse (bool): Whether to enable sparse feature. Default: False. | |||
| Raises: | |||
| @@ -302,7 +302,7 @@ def _save_graph(network, file_name): | |||
| if graph_proto: | |||
| with open(file_name, "wb") as f: | |||
| f.write(graph_proto) | |||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| os.chmod(file_name, stat.S_IRUSR) | |||
| def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): | |||
| @@ -462,19 +462,18 @@ def parse_print(print_file_name): | |||
| List, element of list is Tensor. | |||
| Raises: | |||
| ValueError: Print file is incorrect. | |||
| ValueError: The print file may be empty, please make sure enter the correct file name. | |||
| """ | |||
| if not os.path.realpath(print_file_name): | |||
| raise ValueError("Please input the correct print file name.") | |||
| print_file_path = os.path.realpath(print_file_name) | |||
| if os.path.getsize(print_file_name) == 0: | |||
| if os.path.getsize(print_file_path) == 0: | |||
| raise ValueError("The print file may be empty, please make sure enter the correct file name.") | |||
| logger.info("Execute load print process.") | |||
| print_list = Print() | |||
| try: | |||
| with open(print_file_name, "rb") as f: | |||
| with open(print_file_path, "rb") as f: | |||
| pb_content = f.read() | |||
| print_list.ParseFromString(pb_content) | |||
| except BaseException as e: | |||
| @@ -118,6 +118,12 @@ def test_variable_memory_max_size(): | |||
| context.set_context(variable_memory_max_size="3GB") | |||
| def test_print_file_path(): | |||
| """test_print_file_path""" | |||
| with pytest.raises(IOError): | |||
| context.set_context(print_file_path="./") | |||
| def test_set_context(): | |||
| """ test_set_context """ | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", | |||
| @@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load | |||
| _exec_save_checkpoint, export, _save_graph | |||
| from ..ut_filter import non_graph_engine | |||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") | |||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") | |||
| class Net(nn.Cell): | |||
| @@ -374,10 +374,13 @@ def test_print(): | |||
| def teardown_module(): | |||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] | |||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] | |||
| for item in files: | |||
| file_name = './' + item | |||
| if not os.path.exists(file_name): | |||
| continue | |||
| os.chmod(file_name, stat.S_IWRITE) | |||
| os.remove(file_name) | |||
| import shutil | |||
| if os.path.exists('./print'): | |||
| shutil.rmtree('./print') | |||