@@ -553,8 +553,8 @@ class AdamOffload(Optimizer):
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
 >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
->>>                 {'params': no_conv_params, 'lr': 0.01},
->>>                 {'order_params': net.trainable_params()}]
+...                 {'params': no_conv_params, 'lr': 0.01},
+...                 {'order_params': net.trainable_params()}]
 >>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
 >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
 >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
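Reviewer note (not part of the patch): this hunk, like the DistributedGradReducer and Primitive hunks below, only changes the prompt on continuation lines from '>>>' to '...'. In doctest syntax '>>>' begins a new statement and '...' continues the one above it, so a multi-line literal written with '>>>' on every line does not parse when the examples are executed. A minimal self-contained sketch (the names are illustrative, not MindSpore API) showing the corrected form passing under the standard doctest module:

import doctest

def grouped_params_example():
    """
    >>> group_params = [{'weight_decay': 0.01},
    ...                 {'lr': 0.01}]
    >>> len(group_params)
    2
    """

# With '...' the bracketed literal parses as a single statement, so both
# examples in the docstring run and pass.
print(doctest.testmod())    # TestResults(failed=0, attempted=2)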
@@ -265,7 +265,7 @@ class DistributedGradReducer(Cell):
 >>>
 >>> device_id = int(os.environ["DEVICE_ID"])
 >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
->>>                     device_id=int(device_id))
+...                     device_id=int(device_id))
 >>> init()
 >>> context.reset_auto_parallel_context()
 >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
@@ -384,9 +384,11 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
 >>> @add.register("Tensor", "Tensor")
 ... def add_tensor(x, y):
 ...     return tensor_add(x, y)
->>> add(1, 2)
+>>> output = add(1, 2)
+>>> print(output)
 3
->>> add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
+>>> output = add(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
+>>> print(output)
 Tensor(shape=[], dtype=Float32, 3)
 """
@@ -470,11 +472,13 @@ class HyperMap(HyperMap_):
 ...     return F.square(x)
 >>>
 >>> common_map = HyperMap()
->>> common_map(square, nest_tensor_list)
+>>> output = common_map(square, nest_tensor_list)
+>>> print(output)
 ((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
 (Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
 >>> square_map = HyperMap(square)
->>> square_map(nest_tensor_list)
+>>> output = square_map(nest_tensor_list)
+>>> print(output)
 ((Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4)),
 (Tensor(shape=[], dtype=Float32, 9), Tensor(shape=[], dtype=Float32, 16)))
 """
@@ -531,10 +535,12 @@ class Map(Map_):
 ...     return F.square(x)
 >>>
 >>> common_map = Map()
->>> common_map(square, tensor_list)
+>>> output = common_map(square, tensor_list)
+>>> print(output)
 (Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
 >>> square_map = Map(square)
->>> square_map(tensor_list)
+>>> output = square_map(tensor_list)
+>>> print(output)
 (Tensor(shape=[], dtype=Float32, 1), Tensor(shape=[], dtype=Float32, 4), Tensor(shape=[], dtype=Float32, 9))
 """
@@ -35,9 +35,9 @@ class Primitive(Primitive_):
 >>> # or work with prim_attr_register:
 >>> # init a Primitive class with attr1 and attr2
 >>> class Add(Primitive):
->>>     @prim_attr_register
->>>     def __init__(self, attr1, attr2):
->>>         # check attr1 and attr2 or do some initializations
+...     @prim_attr_register
+...     def __init__(self, attr1, attr2):
+...         # check attr1 and attr2 or do some initializations
 >>> # init a Primitive obj with attr1=1 and attr2=2
 >>> add = Add(attr1=1, attr2=2)
 """