| @@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N | |||
| raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, ' | |||
| raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, ' | |||
| f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| @@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): | |||
| if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): | |||
| raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | |||
| return arg_value | |||
| raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | |||
| raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`') | |||
| def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): | |||
| @@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg | |||
| arg_name = f'`{arg_name}`' if arg_name else '' | |||
| type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) | |||
| if type_mismatch: | |||
| raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.') | |||
| raise TypeError("{} {} must be `{}`, but got `{}`.".format( | |||
| arg_name, prim_name, value_type.__name__, type(arg_value).__name__)) | |||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( | |||
| @@ -135,7 +135,7 @@ class _MindSporeFunction: | |||
| _exec_init_graph(self.obj, init_phase) | |||
| def compile(self, arguments_dict, method_name): | |||
| """Returns pipline for the given args.""" | |||
| """Returns pipeline for the given args.""" | |||
| args_list = tuple(arguments_dict.values()) | |||
| arg_names = tuple(arguments_dict.keys()) | |||
| @@ -32,6 +32,7 @@ from ..ops.functional import cast | |||
| from ..parallel._tensor import _load_tensor_by_layout | |||
| from ..common.tensor import Tensor | |||
| class Cell(Cell_): | |||
| """ | |||
| Base class for all neural networks. | |||
| @@ -579,7 +580,7 @@ class Cell(Cell_): | |||
| def cast_param(self, param): | |||
| """ | |||
| Cast parameter according to auto mix precison level in pynative mode. | |||
| Cast parameter according to auto mix precision level in pynative mode. | |||
| Args: | |||
| param (Parameter): The parameter to cast. | |||
| @@ -594,15 +595,13 @@ class Cell(Cell_): | |||
| param.set_cast_dtype() | |||
| return param | |||
| def insert_child_to_cell(self, child_name, child): | |||
| def insert_child_to_cell(self, child_name, child_cell): | |||
| """ | |||
| Adds a child cell to the current cell. | |||
| Inserts a subcell with a given name to the current cell. | |||
| Adds a child cell to the current cell with a given name. | |||
| Args: | |||
| child_name (str): Name of the child cell. | |||
| child (Cell): The child cell to be inserted. | |||
| child_cell (Cell): The child cell to be inserted. | |||
| Raises: | |||
| KeyError: Child Cell's name is incorrect or duplicated with the other child name. | |||
| @@ -612,15 +611,13 @@ class Cell(Cell_): | |||
| raise KeyError("Child cell name is incorrect.") | |||
| if hasattr(self, child_name) and child_name not in self._cells: | |||
| raise KeyError("Duplicate child name '{}'.".format(child_name)) | |||
| if not isinstance(child, Cell) and child is not None: | |||
| if not isinstance(child_cell, Cell) and child_cell is not None: | |||
| raise TypeError("Child cell type is incorrect.") | |||
| self._cells[child_name] = child | |||
| self._cells[child_name] = child_cell | |||
| def construct(self, *inputs, **kwargs): | |||
| """ | |||
| Defines the computation to be performed. | |||
| This method must be overridden by all subclasses. | |||
| Defines the computation to be performed. This method must be overridden by all subclasses. | |||
| Note: | |||
| The inputs of the top cell only allow Tensor. | |||
| @@ -477,7 +477,7 @@ class Model: | |||
| len_element = len(next_element) | |||
| next_element = _transfer_tensor_to_tuple(next_element) | |||
| if self._loss_fn and len_element != 2: | |||
| raise ValueError("when loss_fn is not None, train_dataset should" | |||
| raise ValueError("when loss_fn is not None, train_dataset should " | |||
| "return two elements, but got {}".format(len_element)) | |||
| cb_params.cur_step_num += 1 | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| export quantization aware training network to infer `AIR` backend. | |||
| export network to infer `AIR` backend. | |||
| """ | |||
| import argparse | |||
| @@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net, | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet import LeNet5 | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| args = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| # define fusion network | |||
| @@ -30,23 +30,21 @@ from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.common import set_seed | |||
| set_seed(1) | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||
| path where the trained ckpt file') | |||
| args = parser.parse_args() | |||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ | |||
| path where the trained ckpt file') | |||
| args = parser.parse_args() | |||
| set_seed(1) | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), | |||
| cfg.batch_size) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size) | |||
| if ds_train.get_dataset_size() == 0: | |||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||