| @@ -1013,7 +1013,9 @@ bool SetMindIRGraphAction(const ResourcePtr &res) { | |||
| }); | |||
| if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) { | |||
| MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before." | |||
| << " Please check the args is same with export.\n" | |||
| << "Please check the args is same with export.\n" | |||
| << "The export input argument size : " << func_args.size() << "\n" | |||
| << "The load input argument size : " << broaded_args.size() << "\n" | |||
| << "Export input args info:" << abstract::ArgsToString(func_args) << "\n" | |||
| << "The input args info:" << abstract::ArgsToString(broaded_args); | |||
| } | |||
| @@ -69,11 +69,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction { | |||
| std::string ToString() const override; | |||
| AbstractFunctionPtr GetUnique() override { | |||
| MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; | |||
| AbstractFunctionPtr result; | |||
| return result; | |||
| } | |||
| AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } | |||
| /// \brief Check whether the input AbstractFunction is in AbstractFuncUnion. | |||
| /// | |||
| @@ -90,11 +86,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction { | |||
| std::size_t hash() const override; | |||
| AbstractFunctionPtr Copy() const override { | |||
| MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; | |||
| AbstractFunctionPtr result; | |||
| return result; | |||
| } | |||
| AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } | |||
| private: | |||
| AbstractFuncAtomPtrList func_list_; | |||
| @@ -64,8 +64,6 @@ AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -189,8 +187,7 @@ AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const Primit | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -233,8 +230,6 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -56,16 +56,16 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||
| class RegisterStandardPrimitiveEvalHelper { | |||
| public: | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl, | |||
| const InferValueImpl &infer_value_impl, const bool is_wight_list = true) { | |||
| const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; | |||
| const InferValueImpl &infer_value_impl, const bool is_white_list = true) { | |||
| const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_white_list}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterStandardPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \ | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_white_list) \ | |||
| static auto helper_##name = \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \ | |||
| std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \ | |||
| auto out = std::make_shared<name>(); \ | |||
| return out; \ | |||
| @@ -22,10 +22,11 @@ import os | |||
| import shutil | |||
| import stat | |||
| import sys | |||
| import time | |||
| from collections import defaultdict | |||
| import threading | |||
| from threading import Thread, Lock | |||
| import time | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| from mindspore.train.checkpoint_pb2 import Checkpoint | |||
| @@ -67,6 +68,7 @@ _ckpt_mutex = Lock() | |||
| SLICE_SIZE = 512 * 1024 | |||
| PROTO_LIMIT_SIZE = 1024 * 1024 * 2 | |||
| TOTAL_SAVE = 1024 * 1024 | |||
| PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024 | |||
| def _special_process_par(par, new_par): | |||
| @@ -807,6 +809,108 @@ def _export(net, file_name, file_format, *inputs, **kwargs): | |||
| net.set_train(mode=True) | |||
| def _generate_front_info_for_param_data_file(is_encrypt, kwargs): | |||
| front_info = bytes() | |||
| check_code = sys.byteorder == "little" | |||
| front_info += check_code.to_bytes(1, byteorder=sys.byteorder) | |||
| front_info += bytes(63) | |||
| if is_encrypt(): | |||
| front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), kwargs['enc_mode']) | |||
| return front_info | |||
| def _change_file(ori_data_file_name, dirname, external_local): | |||
| # The parameter has been not written in the file | |||
| if os.path.getsize(ori_data_file_name) == 64: | |||
| raise RuntimeError("The parameter size is exceed 1T,cannot export to the file") | |||
| data_file_name = os.path.join(dirname, external_local) | |||
| if os.path.exists(data_file_name): | |||
| os.chmod(data_file_name, stat.S_IWUSR) | |||
| return data_file_name | |||
| def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs): | |||
| ''' | |||
| The function to save parameter data | |||
| ''' | |||
| logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") | |||
| # save parameter | |||
| file_prefix = file_name.split("/")[-1] | |||
| if file_prefix.endswith(".mindir"): | |||
| file_prefix = file_prefix[:-7] | |||
| current_path = os.path.abspath(file_name) | |||
| dirname = os.path.dirname(current_path) | |||
| data_path = os.path.join(dirname, file_prefix + "_variables") | |||
| if os.path.exists(data_path): | |||
| shutil.rmtree(data_path) | |||
| os.makedirs(data_path, exist_ok=True) | |||
| os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) | |||
| # Reserves 4096 bytes as spare information such as check data | |||
| offset = 64 | |||
| index = 0 | |||
| parameter_size = (offset / 1024) | |||
| external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) | |||
| data_file_name = os.path.join(dirname, external_local) | |||
| if os.path.exists(data_file_name): | |||
| os.chmod(data_file_name, stat.S_IWUSR) | |||
| f = open(data_file_name, "wb") | |||
| f.write(bytes(offset)) | |||
| try: | |||
| for param_proto in model.graph.parameter: | |||
| name = param_proto.name[param_proto.name.find(":") + 1:] | |||
| param = net_dict[name] | |||
| raw_data = param.data.asnumpy().tobytes() | |||
| data_length = len(raw_data) | |||
| append_size = 0 | |||
| if data_length % 64 != 0: | |||
| append_size = 64 - (data_length % 64) | |||
| parameter_size += ((append_size + data_length) / 1024) | |||
| if parameter_size > PARAMETER_SPLIT_SIZE: | |||
| front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) | |||
| f.seek(0, 0) | |||
| f.write(front_info) | |||
| f.close() | |||
| os.chmod(data_file_name, stat.S_IRUSR) | |||
| offset = 64 | |||
| index += 1 | |||
| parameter_size = (offset + append_size + data_length) / 1024 | |||
| external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) | |||
| data_file_name = _change_file(data_file_name, dirname, external_local) | |||
| f = open(data_file_name, "wb") | |||
| f.write(bytes(offset)) | |||
| param_proto.external_data.location = external_local | |||
| param_proto.external_data.length = data_length | |||
| param_proto.external_data.offset = offset | |||
| write_data = raw_data + bytes(append_size) | |||
| offset += (data_length + append_size) | |||
| if is_encrypt(): | |||
| write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), kwargs['enc_mode']) | |||
| f.write(write_data) | |||
| # save graph | |||
| graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir") | |||
| if os.path.exists(graph_file_name): | |||
| os.chmod(graph_file_name, stat.S_IWUSR) | |||
| with open(graph_file_name, 'wb') as model_file: | |||
| os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR) | |||
| model_string = model.SerializeToString() | |||
| if is_encrypt(): | |||
| model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), | |||
| kwargs['enc_mode']) | |||
| model_file.write(model_string) | |||
| os.chmod(graph_file_name, stat.S_IRUSR) | |||
| front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) | |||
| f.seek(0, 0) | |||
| f.write(front_info) | |||
| finally: | |||
| f.close() | |||
| os.chmod(data_file_name, stat.S_IRUSR) | |||
| def _save_mindir(net, file_name, *inputs, **kwargs): | |||
| """Save MindIR format file.""" | |||
| model = mindir_model() | |||
| @@ -829,67 +933,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs): | |||
| if save_together: | |||
| _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs) | |||
| else: | |||
| logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") | |||
| # save parameter | |||
| file_prefix = file_name.split("/")[-1] | |||
| if file_prefix.endswith(".mindir"): | |||
| file_prefix = file_prefix[:-7] | |||
| current_path = os.path.abspath(file_name) | |||
| dirname = os.path.dirname(current_path) | |||
| data_path = os.path.join(dirname, file_prefix + "_variables") | |||
| if os.path.exists(data_path): | |||
| shutil.rmtree(data_path) | |||
| os.makedirs(data_path, exist_ok=True) | |||
| os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) | |||
| # Reserves 4096 bytes as spare information such as check data | |||
| offset = 64 | |||
| data_file_name = os.path.join(data_path, "veriables.data") | |||
| if os.path.exists(data_file_name): | |||
| os.chmod(data_file_name, stat.S_IWUSR) | |||
| with open(data_file_name, "wb") as f: | |||
| f.write(bytes(offset)) | |||
| for name, param in net_dict.items(): | |||
| for param_proto in model.graph.parameter: | |||
| if name == param_proto.name[param_proto.name.find(":") + 1:]: | |||
| data_file = os.path.join(file_prefix + "_variables", "veriables.data") | |||
| param_proto.external_data.location = data_file | |||
| raw_data = param.data.asnumpy().tobytes() | |||
| data_length = len(raw_data) | |||
| param_proto.external_data.length = data_length | |||
| param_proto.external_data.offset = offset | |||
| write_data = raw_data | |||
| offset += data_length | |||
| if data_length % 64 != 0: | |||
| append_size = 64 - (data_length % 64) | |||
| write_data += (bytes(append_size)) | |||
| offset += append_size | |||
| if is_encrypt(): | |||
| write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), kwargs['enc_mode']) | |||
| f.write(write_data) | |||
| # save graph | |||
| graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir") | |||
| if os.path.exists(graph_file_name): | |||
| os.chmod(graph_file_name, stat.S_IWUSR) | |||
| with open(graph_file_name, 'wb') as model_file: | |||
| os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR) | |||
| model_string = model.SerializeToString() | |||
| if is_encrypt(): | |||
| model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), | |||
| kwargs['enc_mode']) | |||
| model_file.write(model_string) | |||
| os.chmod(graph_file_name, stat.S_IRUSR) | |||
| front_info = bytearray() | |||
| check_code = sys.byteorder == "little" | |||
| front_info += check_code.to_bytes(1, byteorder=sys.byteorder) | |||
| f.seek(0, 0) | |||
| if is_encrypt(): | |||
| front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'], | |||
| len(kwargs['enc_key']), kwargs['enc_mode']) | |||
| f.write(front_info) | |||
| _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs) | |||
| def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs): | |||
| @@ -1201,6 +1245,7 @@ def ckpt_restore_group_info(group_info_file_name): | |||
| restore_rank_list = [rank for rank in restore_list.dim] | |||
| return restore_rank_list | |||
| def build_searched_strategy(strategy_filename): | |||
| """ | |||
| Build strategy of every parameter in network. Used in the case of distributed inference. | |||
| @@ -1480,7 +1525,6 @@ def async_ckpt_thread_status(): | |||
| def _check_predict_strategy(predict_strategy): | |||
| """Check predict strategy.""" | |||
| def _check_int_list(arg): | |||
| if not isinstance(arg, list): | |||
| return False | |||
| @@ -0,0 +1,146 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test mindir export larger than 1G """ | |||
| import os | |||
| import sys | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| from mindspore import Parameter | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.serialization import export, load | |||
| def get_front_info(): | |||
| correct_data = bytes() | |||
| check_code = sys.byteorder == "little" | |||
| correct_data += check_code.to_bytes(1, byteorder=sys.byteorder) | |||
| correct_data += bytes(63) | |||
| return correct_data | |||
| def get_correct_data(parameter): | |||
| correct_data = bytes() | |||
| data = parameter.data.asnumpy().tobytes() | |||
| data_size = len(data) | |||
| if data_size % 64 != 0: | |||
| data += bytes((64 - data_size % 64)) | |||
| correct_data += data | |||
| return correct_data | |||
| def get_data(mindir_name): | |||
| data_path = mindir_name + "_variables" | |||
| data = bytes() | |||
| for dirpath, _, filenames in os.walk(data_path): | |||
| for filename in filenames: | |||
| with open(os.path.join(dirpath, filename), "rb") as f: | |||
| data += f.readline() | |||
| return data | |||
| def test_mindir_export_split(): | |||
| """ | |||
| Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) | |||
| Description: MindIR Export model is exceed TOTAL_SAVE should be split save as model file and data file | |||
| Expectation: No exception. | |||
| """ | |||
| ms.train.serialization.TOTAL_SAVE = 0 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.addn = ops.AddN() | |||
| self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") | |||
| self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z") | |||
| def construct(self, x): | |||
| return self.addn((x, self.y, self.z)) | |||
| x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) | |||
| add_net = Net() | |||
| export(add_net, x, file_name="mindir_export_split", file_format="MINDIR") | |||
| graph = load("mindir_export_split_graph.mindir") | |||
| assert graph is not None | |||
| correct_data = get_front_info() | |||
| correct_data += get_correct_data(add_net.y) | |||
| correct_data += get_correct_data(add_net.z) | |||
| export_data = get_data("mindir_export_split") | |||
| assert export_data == correct_data | |||
| assert oct(os.stat(os.path.join("mindir_export_split_variables", "data_0")).st_mode)[-3:] == "400" | |||
| assert oct(os.stat("mindir_export_split_graph.mindir").st_mode)[-3:] == "400" | |||
| def test_mindir_export_larger_error(): | |||
| """ | |||
| Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) | |||
| Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file | |||
| and data file if the model has a parameter which exceed PARAMETER_SPLIT_SIZE(1T but mocked as 0) | |||
| the exception should be reported. | |||
| Expectation: Parameter is exceed PARAMETER_SPLIT_SIZE | |||
| """ | |||
| ms.train.serialization.TOTAL_SAVE = 0 | |||
| ms.train.serialization.PARAMETER_SPLIT_SIZE = 0 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.add = ops.Add() | |||
| self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") | |||
| def construct(self, x): | |||
| return self.add(x, self.y) | |||
| x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) | |||
| add = Net() | |||
| with pytest.raises(RuntimeError) as e: | |||
| export(add, x, file_name="net", file_format="MINDIR") | |||
| assert e.message == "The parameter size is exceed 1T,cannot export to the file" | |||
| def test_mindir_export_larger_parameter_exceed_1t_mock(): | |||
| """ | |||
| Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) | |||
| Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file | |||
| and data file if the parameter data file exceed PARAMETER_SPLIT_SIZE(1T but mocked as 129Bytes) limit, | |||
| it will be split to another file named data_0,data_1,data_2... | |||
| Expectation: No exception. | |||
| """ | |||
| ms.train.serialization.TOTAL_SAVE = 0 | |||
| ms.train.serialization.PARAMETER_SPLIT_SIZE = 129 / 1024 | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.addn = ops.AddN() | |||
| self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") | |||
| self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z") | |||
| def construct(self, x): | |||
| return self.addn((x, self.y, self.z)) | |||
| x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) | |||
| add_net = Net() | |||
| export(add_net, x, file_name="larger_parameter_exceed_1T_mock", file_format="MINDIR") | |||
| graph = load("larger_parameter_exceed_1T_mock_graph.mindir") | |||
| assert graph is not None | |||
| correct_data = get_front_info() | |||
| correct_data += get_correct_data(add_net.y) | |||
| correct_data += get_front_info() | |||
| correct_data += get_correct_data(add_net.z) | |||
| export_data = get_data("larger_parameter_exceed_1T_mock") | |||
| assert export_data == correct_data | |||