Merge pull request !1897 from jinyaohui/printtags/v0.5.0-beta
| @@ -81,6 +81,7 @@ if (ENABLE_DUMP_PROTO) | |||||
| "utils/summary.proto" | "utils/summary.proto" | ||||
| "utils/lineage.proto" | "utils/lineage.proto" | ||||
| "utils/checkpoint.proto" | "utils/checkpoint.proto" | ||||
| "utils/print.proto" | |||||
| ) | ) | ||||
| ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY}) | ||||
| @@ -148,7 +148,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") | .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") | ||||
| .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") | ||||
| .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") | ||||
| .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size."); | |||||
| .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") | |||||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print."); | |||||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | ||||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | ||||
| @@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| profiling_options_ = "training_trace"; | profiling_options_ = "training_trace"; | ||||
| check_bprop_flag_ = false; | check_bprop_flag_ = false; | ||||
| max_device_memory_ = kDefaultMaxDeviceMemory; | max_device_memory_ = kDefaultMaxDeviceMemory; | ||||
| print_file_path_ = ""; | |||||
| } | } | ||||
| std::shared_ptr<MsContext> MsContext::GetInstance() { | std::shared_ptr<MsContext> MsContext::GetInstance() { | ||||
| @@ -151,6 +151,8 @@ class MsContext { | |||||
| std::string profiling_options() const { return profiling_options_; } | std::string profiling_options() const { return profiling_options_; } | ||||
| bool check_bprop_flag() const { return check_bprop_flag_; } | bool check_bprop_flag() const { return check_bprop_flag_; } | ||||
| void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } | void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } | ||||
| void set_print_file_path(const std::string &file) { print_file_path_ = file; } | |||||
| const std::string &print_file_path() const { return print_file_path_; } | |||||
| float max_device_memory() const { return max_device_memory_; } | float max_device_memory() const { return max_device_memory_; } | ||||
| void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } | void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } | ||||
| @@ -196,6 +198,7 @@ class MsContext { | |||||
| std::string profiling_options_; | std::string profiling_options_; | ||||
| bool check_bprop_flag_; | bool check_bprop_flag_; | ||||
| float max_device_memory_; | float max_device_memory_; | ||||
| std::string print_file_path_; | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto2"; | |||||
| package mindspore.prntpb; | |||||
| message TensorProto { | |||||
| // The shape of the tensor. | |||||
| repeated int64 dims = 1; | |||||
| // The type of the tensor. | |||||
| required string tensor_type = 2; | |||||
| // The data of the tensor. | |||||
| required bytes tensor_content = 3; | |||||
| } | |||||
| message Print { | |||||
| message Value { | |||||
| oneof value { | |||||
| string desc = 1; | |||||
| TensorProto tensor = 2; | |||||
| } | |||||
| } | |||||
| repeated Value value = 1; | |||||
| } | |||||
| @@ -47,6 +47,18 @@ static std::map<std::string, size_t> type_size_map = { | |||||
| {"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2}, | {"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2}, | ||||
| {"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}}; | {"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}}; | ||||
| std::string GetParseType(const std::string &tensorType_) { | |||||
| static const std::map<std::string, std::string> print_parse_map = { | |||||
| {"int8_t", "Int8"}, {"uint8_t", "Uint8"}, {"int16_t", "Int16"}, {"uint16_t", "Uint16"}, | |||||
| {"int32_t", "Int32"}, {"uint32_t", "Uint32"}, {"int64_t", "Int64"}, {"uint64_t", "Uint64"}, | |||||
| {"float16", "Float16"}, {"float", "Float32"}, {"double", "Float64"}, {"bool", "Bool"}}; | |||||
| auto type_iter = print_parse_map.find(tensorType_); | |||||
| if (type_iter == print_parse_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "type of tensor need to print is not support " << tensorType_; | |||||
| } | |||||
| return type_iter->second; | |||||
| } | |||||
| bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) { | bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) { | ||||
| if (tensor_shape == nullptr) { | if (tensor_shape == nullptr) { | ||||
| return false; | return false; | ||||
| @@ -141,7 +153,7 @@ void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; | MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; | ||||
| } | } | ||||
| } // namespace mindspore | |||||
| } | |||||
| bool judgeLengthValid(const size_t str_len, const string &tensor_type) { | bool judgeLengthValid(const size_t str_len, const string &tensor_type) { | ||||
| auto type_iter = type_size_map.find(tensor_type); | auto type_iter = type_size_map.find(tensor_type); | ||||
| @@ -200,14 +212,84 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | |||||
| return ret_end_sequence; | return ret_end_sequence; | ||||
| } | } | ||||
| void TensorPrint::operator()() { | |||||
| while (true) { | |||||
| std::vector<tdt::DataItem> bundle; | |||||
| if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { | |||||
| bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print, | |||||
| std::fstream *output) { | |||||
| bool ret_end_sequence = false; | |||||
| for (auto &item : items) { | |||||
| if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) { | |||||
| ret_end_sequence = true; | |||||
| break; | break; | ||||
| } | } | ||||
| if (ConvertDataItem2Tensor(bundle)) { | |||||
| break; | |||||
| prntpb::Print_Value *value = print.add_value(); | |||||
| std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); | |||||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||||
| if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) { | |||||
| if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { | |||||
| MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; | |||||
| } | |||||
| } | |||||
| std::vector<int> tensor_shape; | |||||
| size_t totaldims = 1; | |||||
| if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { | |||||
| MS_LOG(EXCEPTION) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; | |||||
| } | |||||
| if (item.tensorType_ == "string") { | |||||
| std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); | |||||
| value->set_desc(data); | |||||
| } else { | |||||
| auto parse_type = GetParseType(item.tensorType_); | |||||
| prntpb::TensorProto *tensor = value->mutable_tensor(); | |||||
| if (!(item.tensorShape_ == kShapeScalar) && !(item.tensorShape_ == kShapeNone)) { | |||||
| for (const auto &dim : tensor_shape) { | |||||
| tensor->add_dims(static_cast<::google::protobuf::int64>(dim)); | |||||
| } | |||||
| } | |||||
| tensor->set_tensor_type(parse_type); | |||||
| std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); | |||||
| tensor->set_tensor_content(data); | |||||
| } | |||||
| if (!print.SerializeToOstream(output)) { | |||||
| MS_LOG(EXCEPTION) << "Save print file:" << print_file_path << " fail."; | |||||
| } | |||||
| print.Clear(); | |||||
| } | |||||
| return ret_end_sequence; | |||||
| } | |||||
| void TensorPrint::operator()() { | |||||
| prntpb::Print print; | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| std::string print_file_path = ms_context->print_file_path(); | |||||
| if (print_file_path == "") { | |||||
| while (true) { | |||||
| std::vector<tdt::DataItem> bundle; | |||||
| if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { | |||||
| break; | |||||
| } | |||||
| if (ConvertDataItem2Tensor(bundle)) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| std::fstream output(print_file_path, std::ios::out | std::ios::trunc | std::ios::binary); | |||||
| while (true) { | |||||
| std::vector<tdt::DataItem> bundle; | |||||
| if (tdt::TdtHostPopData("_npu_log", bundle) != 0) { | |||||
| break; | |||||
| } | |||||
| if (SaveDataItem2File(bundle, print_file_path, print, &output)) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| output.close(); | |||||
| std::string path_string = print_file_path; | |||||
| if (chmod(common::SafeCStr(path_string), S_IRUSR) == -1) { | |||||
| MS_LOG(ERROR) << "Modify file:" << print_file_path << " to r fail."; | |||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,6 +23,8 @@ | |||||
| #include "tdt/tsd_client.h" | #include "tdt/tsd_client.h" | ||||
| #include "tdt/tdt_host_interface.h" | #include "tdt/tdt_host_interface.h" | ||||
| #include "tdt/data_common.h" | #include "tdt/data_common.h" | ||||
| #include "proto/print.pb.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| class TensorPrint { | class TensorPrint { | ||||
| @@ -346,6 +346,15 @@ class _Context: | |||||
| raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") | ||||
| self._context_handle.set_max_device_memory(max_device_memory_value) | self._context_handle.set_max_device_memory(max_device_memory_value) | ||||
| @property | |||||
| def print_file_path(self): | |||||
| return None | |||||
| @print_file_path.setter | |||||
| def print_file_path(self, file): | |||||
| self._context_handle.set_print_file_path(file) | |||||
| def check_input_format(x): | def check_input_format(x): | ||||
| import re | import re | ||||
| pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' | ||||
| @@ -479,7 +488,7 @@ def reset_auto_parallel_context(): | |||||
| save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, | save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, | ||||
| save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | ||||
| enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | ||||
| check_bprop=bool, max_device_memory=str) | |||||
| check_bprop=bool, max_device_memory=str, print_file_path=str) | |||||
| def set_context(**kwargs): | def set_context(**kwargs): | ||||
| """ | """ | ||||
| Sets context for running environment. | Sets context for running environment. | ||||
| @@ -21,6 +21,7 @@ import mindspore.nn as nn | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.train.checkpoint_pb2 import Checkpoint | from mindspore.train.checkpoint_pb2 import Checkpoint | ||||
| from mindspore.train.print_pb2 import Print | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| @@ -30,11 +31,15 @@ from mindspore._checkparam import check_input_data | |||||
| __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] | __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] | ||||
| tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype.int32, "Int64": mstype.int64, | |||||
| "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64} | |||||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | |||||
| "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, | |||||
| "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64, | |||||
| "Bool": mstype.bool_} | |||||
| tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16, | |||||
| "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, | |||||
| "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} | |||||
| tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64, | |||||
| "Float16": np.float16, "Float32": np.float32, "Float64": np.float64} | |||||
| def _special_process_par(par, new_par): | def _special_process_par(par, new_par): | ||||
| """ | """ | ||||
| @@ -442,3 +447,64 @@ def export(net, *inputs, file_name, file_format='GEIR'): | |||||
| # restore network training mode | # restore network training mode | ||||
| if is_training: | if is_training: | ||||
| net.set_train(mode=True) | net.set_train(mode=True) | ||||
| def parse_print(print_file_name): | |||||
| """ | |||||
| Loads Print data from a specified file. | |||||
| Args: | |||||
| print_file_name (str): The file name of save print data. | |||||
| Returns: | |||||
| List, element of list is Tensor. | |||||
| Raises: | |||||
| ValueError: Print file is incorrect. | |||||
| """ | |||||
| if not os.path.realpath(print_file_name): | |||||
| raise ValueError("Please input the correct print file name.") | |||||
| if os.path.getsize(print_file_name) == 0: | |||||
| raise ValueError("The print file may be empty, please make sure enter the correct file name.") | |||||
| logger.info("Execute load print process.") | |||||
| print_list = Print() | |||||
| try: | |||||
| with open(print_file_name, "rb") as f: | |||||
| pb_content = f.read() | |||||
| print_list.ParseFromString(pb_content) | |||||
| except BaseException as e: | |||||
| logger.error("Failed to read the print file %s, please check the correct of the file.", print_file_name) | |||||
| raise ValueError(e.__str__()) | |||||
| tensor_list = [] | |||||
| try: | |||||
| for print_ in print_list.value: | |||||
| # String type | |||||
| if print_.HasField("desc"): | |||||
| tensor_list.append(print_.desc) | |||||
| elif print_.HasField("tensor"): | |||||
| dims = print_.tensor.dims | |||||
| data_type = print_.tensor.tensor_type | |||||
| data = print_.tensor.tensor_content | |||||
| np_type = tensor_to_np_type[data_type] | |||||
| param_data = np.fromstring(data, np_type) | |||||
| ms_type = tensor_to_ms_type[data_type] | |||||
| param_dim = [] | |||||
| for dim in dims: | |||||
| param_dim.append(dim) | |||||
| if param_dim: | |||||
| param_value = param_data.reshape(param_dim) | |||||
| tensor_list.append(Tensor(param_value, ms_type)) | |||||
| # Scale type | |||||
| else: | |||||
| tensor_list.append(Tensor(param_data, ms_type)) | |||||
| except BaseException as e: | |||||
| logger.error("Failed to load the print file %s.", print_list) | |||||
| raise RuntimeError(e.__str__()) | |||||
| return tensor_list | |||||
| @@ -16,8 +16,9 @@ | |||||
| import os | import os | ||||
| import stat | import stat | ||||
| import time | import time | ||||
| import pytest | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| @@ -33,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load | |||||
| _exec_save_checkpoint, export, _save_graph | _exec_save_checkpoint, export, _save_graph | ||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| @@ -327,8 +328,52 @@ def test_binary_export(): | |||||
| export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY") | export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY") | ||||
| class PrintNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(PrintNet, self).__init__() | |||||
| self.print = P.Print() | |||||
| def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, | |||||
| scale1, scale2): | |||||
| self.print('============tensor int8:==============', int8) | |||||
| self.print('============tensor uint8:==============', uint8) | |||||
| self.print('============tensor int16:==============', int16) | |||||
| self.print('============tensor uint16:==============', uint16) | |||||
| self.print('============tensor int32:==============', int32) | |||||
| self.print('============tensor uint32:==============', uint32) | |||||
| self.print('============tensor int64:==============', int64) | |||||
| self.print('============tensor uint64:==============', uint64) | |||||
| self.print('============tensor float16:==============', flt16) | |||||
| self.print('============tensor float32:==============', flt32) | |||||
| self.print('============tensor float64:==============', flt64) | |||||
| self.print('============tensor bool:==============', bool_) | |||||
| self.print('============tensor scale1:==============', scale1) | |||||
| self.print('============tensor scale2:==============', scale2) | |||||
| return int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2 | |||||
| def test_print(): | |||||
| print_net = PrintNet() | |||||
| int8 = Tensor(np.random.randint(100, size=(10, 10), dtype="int8")) | |||||
| uint8 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint8")) | |||||
| int16 = Tensor(np.random.randint(100, size=(10, 10), dtype="int16")) | |||||
| uint16 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint16")) | |||||
| int32 = Tensor(np.random.randint(100, size=(10, 10), dtype="int32")) | |||||
| uint32 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint32")) | |||||
| int64 = Tensor(np.random.randint(100, size=(10, 10), dtype="int64")) | |||||
| uint64 = Tensor(np.random.randint(100, size=(10, 10), dtype="uint64")) | |||||
| float16 = Tensor(np.random.rand(224, 224).astype(np.float16)) | |||||
| float32 = Tensor(np.random.rand(224, 224).astype(np.float32)) | |||||
| float64 = Tensor(np.random.rand(224, 224).astype(np.float64)) | |||||
| bool_ = Tensor(np.arange(-10, 10, 2).astype(np.bool_)) | |||||
| scale1 = Tensor(np.array(1)) | |||||
| scale2 = Tensor(np.array(0.1)) | |||||
| print_net(int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, bool_, scale1, | |||||
| scale2) | |||||
| def teardown_module(): | def teardown_module(): | ||||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] | |||||
| files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] | |||||
| for item in files: | for item in files: | ||||
| file_name = './' + item | file_name = './' + item | ||||
| if not os.path.exists(file_name): | if not os.path.exists(file_name): | ||||