| @@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N | |||||
| raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') | raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') | ||||
| if type_mismatch or not rel_fn(arg_value, value): | if type_mismatch or not rel_fn(arg_value, value): | ||||
| rel_str = Rel.get_strs(rel).format(value) | rel_str = Rel.get_strs(rel).format(value) | ||||
| raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, ' | |||||
| raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, ' | |||||
| f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') | f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') | ||||
| return arg_value | return arg_value | ||||
| @@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): | |||||
| if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): | if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): | ||||
| raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | ||||
| return arg_value | return arg_value | ||||
| raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | |||||
| raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`') | |||||
| def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): | def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): | ||||
| @@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg | |||||
| arg_name = f'`{arg_name}`' if arg_name else '' | arg_name = f'`{arg_name}`' if arg_name else '' | ||||
| type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) | type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) | ||||
| if type_mismatch: | if type_mismatch: | ||||
| raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.') | |||||
| raise TypeError("{} {} must be `{}`, but got `{}`.".format( | |||||
| arg_name, prim_name, value_type.__name__, type(arg_value).__name__)) | |||||
| if not rel_fn(arg_value, lower_limit, upper_limit): | if not rel_fn(arg_value, lower_limit, upper_limit): | ||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | ||||
| raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( | raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( | ||||
| @@ -135,7 +135,7 @@ class _MindSporeFunction: | |||||
| _exec_init_graph(self.obj, init_phase) | _exec_init_graph(self.obj, init_phase) | ||||
| def compile(self, arguments_dict, method_name): | def compile(self, arguments_dict, method_name): | ||||
| """Returns pipline for the given args.""" | |||||
| """Returns pipeline for the given args.""" | |||||
| args_list = tuple(arguments_dict.values()) | args_list = tuple(arguments_dict.values()) | ||||
| arg_names = tuple(arguments_dict.keys()) | arg_names = tuple(arguments_dict.keys()) | ||||
| @@ -32,6 +32,7 @@ from ..ops.functional import cast | |||||
| from ..parallel._tensor import _load_tensor_by_layout | from ..parallel._tensor import _load_tensor_by_layout | ||||
| from ..common.tensor import Tensor | from ..common.tensor import Tensor | ||||
| class Cell(Cell_): | class Cell(Cell_): | ||||
| """ | """ | ||||
| Base class for all neural networks. | Base class for all neural networks. | ||||
| @@ -579,7 +580,7 @@ class Cell(Cell_): | |||||
| def cast_param(self, param): | def cast_param(self, param): | ||||
| """ | """ | ||||
| Cast parameter according to auto mix precison level in pynative mode. | |||||
| Cast parameter according to auto mix precision level in pynative mode. | |||||
| Args: | Args: | ||||
| param (Parameter): The parameter to cast. | param (Parameter): The parameter to cast. | ||||
| @@ -594,15 +595,13 @@ class Cell(Cell_): | |||||
| param.set_cast_dtype() | param.set_cast_dtype() | ||||
| return param | return param | ||||
| def insert_child_to_cell(self, child_name, child): | |||||
| def insert_child_to_cell(self, child_name, child_cell): | |||||
| """ | """ | ||||
| Adds a child cell to the current cell. | |||||
| Inserts a subcell with a given name to the current cell. | |||||
| Adds a child cell to the current cell with a given name. | |||||
| Args: | Args: | ||||
| child_name (str): Name of the child cell. | child_name (str): Name of the child cell. | ||||
| child (Cell): The child cell to be inserted. | |||||
| child_cell (Cell): The child cell to be inserted. | |||||
| Raises: | Raises: | ||||
| KeyError: Child Cell's name is incorrect or duplicated with the other child name. | KeyError: Child Cell's name is incorrect or duplicated with the other child name. | ||||
| @@ -612,15 +611,13 @@ class Cell(Cell_): | |||||
| raise KeyError("Child cell name is incorrect.") | raise KeyError("Child cell name is incorrect.") | ||||
| if hasattr(self, child_name) and child_name not in self._cells: | if hasattr(self, child_name) and child_name not in self._cells: | ||||
| raise KeyError("Duplicate child name '{}'.".format(child_name)) | raise KeyError("Duplicate child name '{}'.".format(child_name)) | ||||
| if not isinstance(child, Cell) and child is not None: | |||||
| if not isinstance(child_cell, Cell) and child_cell is not None: | |||||
| raise TypeError("Child cell type is incorrect.") | raise TypeError("Child cell type is incorrect.") | ||||
| self._cells[child_name] = child | |||||
| self._cells[child_name] = child_cell | |||||
| def construct(self, *inputs, **kwargs): | def construct(self, *inputs, **kwargs): | ||||
| """ | """ | ||||
| Defines the computation to be performed. | |||||
| This method must be overridden by all subclasses. | |||||
| Defines the computation to be performed. This method must be overridden by all subclasses. | |||||
| Note: | Note: | ||||
| The inputs of the top cell only allow Tensor. | The inputs of the top cell only allow Tensor. | ||||
| @@ -477,7 +477,7 @@ class Model: | |||||
| len_element = len(next_element) | len_element = len(next_element) | ||||
| next_element = _transfer_tensor_to_tuple(next_element) | next_element = _transfer_tensor_to_tuple(next_element) | ||||
| if self._loss_fn and len_element != 2: | if self._loss_fn and len_element != 2: | ||||
| raise ValueError("when loss_fn is not None, train_dataset should" | |||||
| raise ValueError("when loss_fn is not None, train_dataset should " | |||||
| "return two elements, but got {}".format(len_element)) | "return two elements, but got {}".format(len_element)) | ||||
| cb_params.cur_step_num += 1 | cb_params.cur_step_num += 1 | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ | """ | ||||
| export quantization aware training network to infer `AIR` backend. | |||||
| export network to infer `AIR` backend. | |||||
| """ | """ | ||||
| import argparse | import argparse | ||||
| @@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net, | |||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet import LeNet5 | from src.lenet import LeNet5 | ||||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||||
| choices=['Ascend', 'GPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--ckpt_path', type=str, default="", | |||||
| help='if mode is test, must provide path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||||
| choices=['Ascend', 'GPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--ckpt_path', type=str, default="", | |||||
| help='if mode is test, must provide path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | ||||
| # define fusion network | # define fusion network | ||||
| @@ -30,23 +30,21 @@ from mindspore.train import Model | |||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| set_seed(1) | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--data_path', type=str, default="./Data", | |||||
| help='path where the dataset is saved') | |||||
| parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||||
| path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||||
| help='device where the code will be implemented (default: Ascend)') | |||||
| parser.add_argument('--data_path', type=str, default="./Data", | |||||
| help='path where the dataset is saved') | |||||
| parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||||
| path where the trained ckpt file') | |||||
| args = parser.parse_args() | |||||
| set_seed(1) | |||||
| if __name__ == "__main__": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | ||||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), | |||||
| cfg.batch_size) | |||||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size) | |||||
| if ds_train.get_dataset_size() == 0: | if ds_train.get_dataset_size() == 0: | ||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | ||||