Merge pull request !6185 from gziyan/fix_api_comments
| @@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl") | |||
| def _get_group(group): | |||
| """Get the global world group if the group is default world comm group.""" | |||
| """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" | |||
| if group == DEFAULT_WORLD_COMM_GROUP: | |||
| return GlobalComm.WORLD_COMM_GROUP | |||
| return group | |||
| class GlobalComm: | |||
| """Global communication info.""" | |||
| """World communication information.""" | |||
| BACKEND = DEFAULT_BACKEND | |||
| WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP | |||
| def init(backend_name=None): | |||
| """ | |||
| Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used. | |||
| Initialize the distributed backend, e.g. HCCL/NCCL; it is required before using the communication service. | |||
| Note: | |||
| The full name of hccl is Huawei Collective Communication Library. | |||
| The full name of nccl is NVIDIA Collective Communication Library. | |||
| The full name of HCCL is Huawei Collective Communication Library. | |||
| The full name of NCCL is NVIDIA Collective Communication Library. | |||
| Args: | |||
| backend_name (str): Backend. | |||
| Raises: | |||
| TypeError: If backen_name is not a string. | |||
| RuntimeError: If device target is invalid. | |||
| RuntimeError: If backend is invalid or distributed init fails. | |||
| TypeError: If `backend_name` is not a string. | |||
| RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. | |||
| """ | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| return | |||
| @@ -88,17 +87,17 @@ def init(backend_name=None): | |||
| def release(): | |||
| """ | |||
| Release distributed resource. e.g., hccl/nccl. | |||
| Release distributed resources, e.g. HCCL/NCCL. | |||
| Raises: | |||
| RuntimeError: If distributed resource release fails. | |||
| RuntimeError: If it fails to release the distributed resource. | |||
| """ | |||
| finalize_hccl() | |||
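For orientation, a minimal usage sketch of the init/release pair documented above; it assumes an Ascend device with HCCL available, and the device target and mode are illustrative choices rather than the only valid ones.

    from mindspore import context
    from mindspore.communication.management import init, get_rank, release

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init()                       # backend is derived from device_target when backend_name is None
    print("rank:", get_rank())   # world rank of this process
    release()                    # free the HCCL/NCCL resources on exit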
| def get_rank(group=GlobalComm.WORLD_COMM_GROUP): | |||
| """ | |||
| Gets rank ID for current device in specified collective communication group. | |||
| Get the rank ID for the current device in the specified collective communication group. | |||
| Args: | |||
| group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP. | |||
| @@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): | |||
| Raises: | |||
| TypeError: If group is not a string. | |||
| ValueError: If backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) | |||
| @@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): | |||
| Raises: | |||
| TypeError: If group is not a string. | |||
| ValueError: If backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) | |||
| def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): | |||
| """ | |||
| Gets rank size of the specified collective communication group. | |||
| Get the rank size of the specified collective communication group. | |||
| Args: | |||
| group (str): ProcessGroup, the process group to work on. | |||
| @@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): | |||
| Raises: | |||
| TypeError: If group is not a string. | |||
| ValueError: If backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) | |||
| @@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): | |||
| group (str): ProcessGroup, the process group to work on. | |||
| Returns: | |||
| int, the local rank size where the calling process is being within the group. | |||
| int, the local rank size where the calling process is within the group. | |||
| Raises: | |||
| TypeError: If group is not a string. | |||
| ValueError: If backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) | |||
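A minimal sketch of the rank and size queries above, assuming init() has already succeeded; get_local_rank and get_local_rank_size additionally assume an HCCL backend.

    from mindspore.communication.management import (init, get_rank, get_local_rank,
                                                    get_group_size, get_local_rank_size)

    init()
    print(get_rank(), get_group_size())              # rank ID and size in the world group
    print(get_local_rank(), get_local_rank_size())   # rank ID and size within the local host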
| def get_world_rank_from_group_rank(group, group_rank_id): | |||
| """ | |||
| Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group. | |||
| Get the rank ID in the world communication group corresponding to | |||
| the rank ID in the specified user communication group. | |||
| Note: | |||
| Nccl is not supported. | |||
| NCCL is not supported. | |||
| The parameter group should not be "hccl_world_group". | |||
| Args: | |||
| @@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id): | |||
| int, the rank ID in world communication group. | |||
| Raises: | |||
| TypeError: If group_rank_id is not a int or group is not a string. | |||
| TypeError: If `group_rank_id` is not an integer or the group is not a string. | |||
| ValueError: If group is 'hccl_world_group' or backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) | |||
| def get_group_rank_from_world_rank(world_rank_id, group): | |||
| """ | |||
| Gets the rank ID in specified user communication group corresponding to the rank ID in world communication group. | |||
| Get the rank ID in the specified user communication group corresponding to | |||
| the rank ID in the world communication group. | |||
| Note: | |||
| Nccl is not supported. | |||
| NCCL is not supported. | |||
| The parameter group should not be "hccl_world_group". | |||
| Args: | |||
| world_rank_id (int): A rank ID in world communication group. | |||
| world_rank_id (int): A rank ID in the world communication group. | |||
| group (str): The user communication group. | |||
| Returns: | |||
| int, the rank ID in user communication group. | |||
| int, the rank ID in the user communication group. | |||
| Raises: | |||
| TypeError: If world_rank_id is not a int or group is not a string. | |||
| TypeError: If world_rank_id is not an integer or the group is not a string. | |||
| ValueError: If group is 'hccl_world_group' or backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) | |||
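A rough sketch of converting between world and group rank IDs as described above, assuming an HCCL backend; the group name and rank list are hypothetical.

    from mindspore.communication.management import (init, create_group,
                                                    get_world_rank_from_group_rank,
                                                    get_group_rank_from_world_rank)

    init()
    group = "group_0_3"                                   # hypothetical user group name
    create_group(group, [0, 1, 2, 3])
    world_id = get_world_rank_from_group_rank(group, 1)   # group rank 1 -> world rank
    group_id = get_group_rank_from_world_rank(2, group)   # world rank 2 -> group rank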
| def create_group(group, rank_ids): | |||
| """ | |||
| Creates user collective communication group. | |||
| Create a user collective communication group. | |||
| Note: | |||
| Nccl is not supported. | |||
| NCCL is not supported. | |||
| The size of rank_ids should be larger than 1. | |||
| Rank_ids should not have duplicate data. | |||
| Args: | |||
| group (str): ProcessGroup, the process group to create. | |||
| rank_ids (list): List of device ID. | |||
| rank_ids (list): A list of device IDs. | |||
| Raises: | |||
| TypeError: If group is not a string or rank_ids is not a list. | |||
| ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. | |||
| TypeError: If group is not a string or `rank_ids` is not a list. | |||
| ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| Examples: | |||
| >>> group = "0-1" | |||
| @@ -247,7 +248,7 @@ def create_group(group, rank_ids): | |||
| def destroy_group(group): | |||
| """ | |||
| Destroys user collective communication group. | |||
| Destroy the user collective communication group. | |||
| Note: | |||
| Nccl is not supported. | |||
| @@ -259,6 +260,6 @@ def destroy_group(group): | |||
| Raises: | |||
| TypeError: If group is not a string. | |||
| ValueError: If group is "hccl_world_group" or backend is invalid. | |||
| RuntimeError: If hccl/nccl is not available or nccl not supports. | |||
| RuntimeError: If HCCL/NCCL is not available or NCCL is not supported. | |||
| """ | |||
| _destroy_group_helper(group, backend=GlobalComm.BACKEND) | |||
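A short sketch of the create_group/destroy_group lifecycle, again assuming an HCCL backend; the group name is hypothetical and the rank list must hold more than one unique ID.

    from mindspore.communication.management import init, create_group, destroy_group

    init()
    create_group("pair_0_1", [0, 1])   # more than one unique rank ID, no duplicates
    # ... run collectives on "pair_0_1" ...
    destroy_group("pair_0_1")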
| @@ -336,6 +336,8 @@ def set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| Auto parallel context should be configured before the initialization of your network. | |||
| Note: | |||
| Attribute name is required for setting attributes. | |||
| If a program has tasks with different parallel modes, then before setting new parallel mode for the | |||
| @@ -344,12 +346,25 @@ def set_auto_parallel_context(**kwargs): | |||
| Setting or changing parallel modes must be called before any creating Initializer, otherwise, | |||
| RuntimeError may be raised when compiling the network. | |||
| Some configurations are parallel mode specific, see the below table for details: | |||
| =========================== =========================== ================= | |||
| Common AUTO_PARALLEL DATA_PARALLEL | |||
| =========================== =========================== ================= | |||
| device_num gradient_fp32_sync enable_parallel_optimizer | |||
| global_rank loss_repeated_mean | |||
| gradients_mean auto_parallel_search_mode | |||
| parallel_mode strategy_ckpt_load_file | |||
| all_reduce_fusion_config strategy_ckpt_save_file | |||
| full_batch | |||
| =========================== =========================== ================= | |||
| Args: | |||
| device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. | |||
| global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. | |||
| gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. | |||
| "stand_alone" does not support `gradients_mean`. Default: False. | |||
| gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True.. | |||
| gradients_mean (bool): Whether to perform mean operator after allreduce of gradients. | |||
| "stand_alone" do not support gradients_mean. Default: False. | |||
| gradient_fp32_sync (bool): Run allreduce of gradients in fp32. | |||
| "stand_alone", "data_parallel" and "hybrid_parallel" do not support | |||
| gradient_fp32_sync. Default: True. | |||
| parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", | |||
| @@ -364,8 +379,8 @@ def set_auto_parallel_context(**kwargs): | |||
| - semi_auto_parallel: Achieves data parallelism and model parallelism by | |||
| setting parallel strategies. | |||
| - auto_parallel: Achieves parallelism automatically. | |||
| auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming" | |||
| - auto_parallel: Achieves parallelism automatically. | |||
| auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming" | |||
| and "dynamic_programming". Default: "dynamic_programming". | |||
| - recursive_programming: Recursive programming search mode. | |||
| @@ -376,9 +391,11 @@ def set_auto_parallel_context(**kwargs): | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in | |||
| data parallel training in the benefit of time and memory saving. | |||
| full_batch (bool): If you load the whole batch of the dataset in auto_parallel mode, this parameter | |||
| should be set to True. Default: False. | |||
| enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for | |||
| data parallel training to save time and memory. For now, | |||
| `Lamb` and `AdamWeightDecay` are supported in data parallel mode. | |||
| all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM | |||
| and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. | |||
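To tie the options above together, a hedged configuration sketch for data-parallel training; the device target, fusion indices and flag values are illustrative, not prescribed.

    from mindspore import context
    from mindspore.communication.management import init, get_group_size

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init()
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      device_num=get_group_size(),
                                      gradients_mean=True,
                                      gradient_fp32_sync=True,
                                      all_reduce_fusion_config=[85, 160])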
| @@ -479,7 +496,7 @@ def set_context(**kwargs): | |||
| Some configurations are device specific, see the bellow table for details: | |||
| =========================== =========================== ================= | |||
| Common(CPU/GPU/Asecend) Ascend GPU | |||
| Common(CPU/GPU/Ascend) Ascend GPU | |||
| =========================== =========================== ================= | |||
| check_bprop enable_auto_mixed_precision max_device_memory | |||
| device_id enable_dump | |||
| @@ -33,7 +33,7 @@ __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor | |||
| def add_flags(fn=None, **flags): | |||
| """ | |||
| An decorator to add flag for a function. | |||
| A decorator that adds a flag to the function. | |||
| Note: | |||
| Only supports bool value. | |||
| @@ -43,7 +43,7 @@ def add_flags(fn=None, **flags): | |||
| flags (dict): Flags use kwargs. Default: None. | |||
| Returns: | |||
| Function, the fn added flags. | |||
| Function, the function with added flags. | |||
| Examples: | |||
| >>> add_flags(net, predit=True) | |||
| @@ -63,9 +63,9 @@ def add_flags(fn=None, **flags): | |||
| def core(fn=None, **flags): | |||
| """ | |||
| A decorator to add flag to a function. | |||
| A decorator that adds a flag to the function. | |||
| By default, the function is marked core=True using this decorator to | |||
| By default, the function is marked as core=True, enabling this decorator to | |||
| set flag to a graph. | |||
| Args: | |||
| @@ -91,11 +91,12 @@ def core(fn=None, **flags): | |||
| class GradOperation(GradOperation_): | |||
| """ | |||
| An higher-order function which is used to generate the gradient function for the input function. | |||
| A higher-order function which is used to generate the gradient function for the input function. | |||
| The gradient function generated by `GradOperation` higher-order function can be customized by construction args. | |||
| The gradient function generated by `GradOperation` higher-order function can be customized by | |||
| construction arguments. | |||
| Given an input function `net = Net()` that take `x` and `y` as inputs, and has a parameter `z`, | |||
| Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`, | |||
| see `Net` in Examples. | |||
| To generate a gradient function that returns gradients with respect to the first input | |||
| @@ -126,7 +127,7 @@ class GradOperation(GradOperation_): | |||
| 1. Construct a `GradOperation` higher-order function with `get_by_list=True`: | |||
| `grad_op = GradOperation(get_by_list=True)`. | |||
| 2. Construct a `ParameterTuple` that will be passed along input function when constructing | |||
| 2. Construct a `ParameterTuple` that will be passed to the input function when constructing | |||
| `GradOperation` higher-order function, it will be used as a parameter filter that determine | |||
| which gradient to return: `params = ParameterTuple(net.trainable_params())`. | |||
| @@ -151,20 +152,20 @@ class GradOperation(GradOperation_): | |||
| 4. Call the gradient function with input function's inputs | |||
| to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`. | |||
| We can configure the sensitiviy(gradient with respect to output) by setting `sens_param=True` and | |||
| passing in an extra sensitiviy input to the gradient function, the sensitiviy input should be | |||
| with same shape and type with input function's output(see `GradNetWrtXYWithSensParam` in Examples). | |||
| We can configure the sensitivity (gradient with respect to output) by setting `sens_param` as True and | |||
| passing an extra sensitivity input to the gradient function. The sensitivity input should have the | |||
| same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples). | |||
| 1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`: | |||
| `grad_op = GradOperation(get_all=True, sens_param=True)`. | |||
| 2. Define grad_wrt_output as sens_param which works as the gradient with respect to output: | |||
| 2. Define `grad_wrt_output` as `sens_param` which works as the gradient with respect to output: | |||
| `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`. | |||
| 3. Call it with input function as argument to get the gradient function: | |||
| `gradient_function = grad_op(net)`. | |||
| 4. Call the gradient function with input function's inputs and sens_param to | |||
| 4. Call the gradient function with input function's inputs and `sens_param` to | |||
| get the gradients with respect to all inputs: | |||
| `gradient_function(x, y, grad_wrt_output)`. | |||
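Putting the numbered steps above in one place, a sketch of the sens_param workflow; `Net` and `GradNetWrtXYWithSensParam` are illustrative stand-ins for the classes referenced in the docstring's Examples, using the 2x2 shapes it describes.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    class Net(nn.Cell):
        """The input function: takes x and y, owns a parameter z."""
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
        def construct(self, x, y):
            x = x * self.z
            return self.matmul(x, y)

    class GradNetWrtXYWithSensParam(nn.Cell):
        """Steps 1-2: build GradOperation with get_all/sens_param and keep the sensitivity."""
        def __init__(self, net):
            super(GradNetWrtXYWithSensParam, self).__init__()
            self.net = net
            self.grad_op = C.GradOperation(get_all=True, sens_param=True)
            self.grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))
        def construct(self, x, y):
            # Steps 3-4: get the gradient function, then call it with the sensitivity input.
            gradient_function = self.grad_op(self.net)
            return gradient_function(x, y, self.grad_wrt_output)

    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))
    grads = GradNetWrtXYWithSensParam(Net())(x, y)   # gradients with respect to x and y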
| @@ -175,8 +176,9 @@ class GradOperation(GradOperation_): | |||
| If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables | |||
| at the same time in the form of ((gradients with respect to inputs), | |||
| (gradients with respect to parameters)). Default: False. | |||
| sens_param (bool): Whether append sensitivity(gradient with respect to output) as input. If sens_param is False, | |||
| a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. | |||
| sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. | |||
| If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. | |||
| Default: False. | |||
| Returns: | |||
| The higher-order function which takes a function as argument and returns gradient function for it. | |||
| @@ -349,9 +351,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| """ | |||
| Generate overloaded functions. | |||
| MultitypeFuncGraph is a class used to generate overloaded functions with different type as inputs. | |||
| MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs. | |||
| Initialize an `MultitypeFuncGraph` object with name, and use `register` with input types as the decorator | |||
| for the function to be registed. And the object can be called with different type of inputs, | |||
| for the function to be registered. And the object can be called with different types of inputs, | |||
| and work with `HyperMap` and `Map`. | |||
| Args: | |||
| @@ -360,7 +362,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| and all inputs will pass by value, set `read_value` to True. Default: False. | |||
| Raises: | |||
| ValueError: Cannot find matching functions for the given args. | |||
| ValueError: If it fails to find a matching function for the given arguments. | |||
| Examples: | |||
| >>> # `add` is a metagraph object which will add two objects according to | |||
| @@ -431,7 +433,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| class HyperMap(HyperMap_): | |||
| """ | |||
| Hypermap will apply the set operation on input sequences. | |||
| Hypermap will apply the set operation to input sequences. | |||
| Apply the operations to every elements of the sequence or nested sequence. Different | |||
| from `Map`, the `HyperMap` supports to apply on nested structure. | |||
| @@ -441,11 +443,10 @@ class HyperMap(HyperMap_): | |||
| the operations should be put in the first input of the instance. | |||
| Inputs: | |||
| - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, | |||
| and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence | |||
| `(args[0][i], args[1][i])` will be the input of the operation. | |||
| - **args** (Tuple[sequence]) - If `ops` is `None`, all the inputs should be sequences with the same length. | |||
| And each row of the sequences will be the inputs of the operation. | |||
| If `ops` is not `None`, the first input is the operation, and the other is inputs. | |||
| If `ops` is not `None`, the first input is the operation, and the others are inputs. | |||
| Outputs: | |||
| Sequence or nested sequence, the sequence of output after applying the function. | |||
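A brief sketch of how MultitypeFuncGraph and HyperMap combine; it assumes the constructs can be called directly, as in the library's own doctest examples, and `square` plus the input tensors are illustrative.

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F

    square = C.MultitypeFuncGraph('square')   # overloaded function, dispatching on input type

    @square.register("Tensor")
    def square_tensor(x):
        return F.square(x)

    tensors = (Tensor(np.array([1.0, 2.0], np.float32)),
               Tensor(np.array([3.0, 4.0], np.float32)))

    out1 = C.HyperMap()(square, tensors)   # operation passed as the first input
    out2 = C.HyperMap(square)(tensors)     # operation bound at construction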
| @@ -48,14 +48,15 @@ def normal(shape, mean, stddev, seed=0): | |||
| Args: | |||
| shape (tuple): The shape of random tensor to be generated. | |||
| mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. | |||
| With float32 data type. | |||
| with float32 data type. | |||
| stddev (Tensor): The deviation σ distribution parameter. It should be greater than 0. | |||
| With float32 data type. | |||
| seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| with float32 data type. | |||
| seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers, | |||
| must be non-negative. Default: 0. | |||
| Returns: | |||
| Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean and stddev. | |||
| Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes | |||
| of `mean` and `stddev`. | |||
| The dtype is float32. | |||
| Examples: | |||
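As a rough orientation for the signature above (the docstring's own example is truncated by the diff), a hedged sketch of sampling with normal; the shape, parameter values and seed are illustrative.

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import composite as C

    shape = (4, 16)
    mean = Tensor(np.array([0.0]), mstype.float32)     # float32, broadcastable
    stddev = Tensor(np.array([1.0]), mstype.float32)   # float32, greater than 0
    samples = C.normal(shape, mean, stddev, seed=5)    # shape (4, 16), dtype float32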
| @@ -123,20 +124,21 @@ def uniform(shape, minval, maxval, seed=0, dtype=mstype.float32): | |||
| Args: | |||
| shape (tuple): The shape of random tensor to be generated. | |||
| minval (Tensor): The a distribution parameter. | |||
| It defines the minimum possibly generated value. With int32 or float32 data type. | |||
| minval (Tensor): The distribution parameter `a`. | |||
| It defines the minimum possible generated value, with int32 or float32 data type. | |||
| If dtype is int32, only one number is allowed. | |||
| maxval (Tensor): The b distribution parameter. | |||
| It defines the maximum possibly generated value. With int32 or float32 data type. | |||
| maxval (Tensor): The distribution parameter `b`. | |||
| It defines the maximum possible generated value, with int32 or float32 data type. | |||
| If dtype is int32, only one number is allowed. | |||
| seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers, | |||
| must be non-negative. Default: 0. | |||
| dtype (mindspore.dtype): type of the Uniform distribution. If it is int32, it generates numbers from discrete | |||
| uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only | |||
| supports these two data types. Default: mstype.float32. | |||
| Returns: | |||
| Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of minval and maxval. | |||
| Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes | |||
| of `minval` and `maxval`. | |||
| The dtype is designated as the input `dtype`. | |||
| Examples: | |||
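Similarly, a hedged sketch of uniform covering both the continuous (float32) and discrete (int32) cases described above; all values are illustrative.

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import composite as C

    shape = (2, 3)
    # Continuous uniform on [a, b) with float32 bounds.
    a = Tensor(np.array([0.0]), mstype.float32)
    b = Tensor(np.array([1.0]), mstype.float32)
    x = C.uniform(shape, a, b, seed=10)
    # Discrete uniform integers in [low, high); int32 bounds must be single numbers.
    low = Tensor(0, mstype.int32)
    high = Tensor(10, mstype.int32)
    n = C.uniform(shape, low, high, seed=10, dtype=mstype.int32)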
| @@ -175,13 +177,14 @@ def gamma(shape, alpha, beta, seed=0): | |||
| Args: | |||
| shape (tuple): The shape of random tensor to be generated. | |||
| alpha (Tensor): The alpha α distribution parameter. It should be greater than 0. With float32 data type. | |||
| beta (Tensor): The beta β distribution parameter. It should be greater than 0. With float32 data type. | |||
| seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| alpha (Tensor): The alpha α distribution parameter. It should be greater than 0 with float32 data type. | |||
| beta (Tensor): The beta β distribution parameter. It should be greater than 0 with float32 data type. | |||
| seed (int): Seed is used as entropy source for the random number engines to generate | |||
| pseudo-random numbers, must be non-negative. Default: 0. | |||
| Returns: | |||
| Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta. | |||
| Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes | |||
| of `alpha` and `beta`. | |||
| The dtype is float32. | |||
| Examples: | |||
| @@ -203,12 +206,12 @@ def poisson(shape, mean, seed=0): | |||
| Args: | |||
| shape (tuple): The shape of random tensor to be generated. | |||
| mean (Tensor): The mean μ distribution parameter. It should be greater than 0. With float32 data type. | |||
| seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type. | |||
| seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers | |||
| and must be non-negative. Default: 0. | |||
| Returns: | |||
| Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of mean. | |||
| Tensor. The shape should be equal to the broadcasted shape between the input "shape" and shapes of `mean`. | |||
| The dtype is float32. | |||
| Examples: | |||
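A combined sketch for the gamma and poisson samplers documented above; the parameter values and seeds are illustrative.

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import composite as C

    shape = (4, 16)
    alpha = Tensor(np.array([1.0]), mstype.float32)   # greater than 0, float32
    beta = Tensor(np.array([1.0]), mstype.float32)    # greater than 0, float32
    g = C.gamma(shape, alpha, beta, seed=5)

    mean = Tensor(np.array([5.0]), mstype.float32)    # greater than 0, float32
    p = C.poisson(shape, mean, seed=5)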
| @@ -226,21 +229,23 @@ def poisson(shape, mean, seed=0): | |||
| def multinomial(inputs, num_sample, replacement=True, seed=0): | |||
| r""" | |||
| Returns a tensor sampled from the multinomial probability distribution located in the corresponding | |||
| row of tensor input. | |||
| row of the input tensor. | |||
| Note: | |||
| The rows of input do not need to sum to one (in which case we use the values as weights), | |||
| but must be non-negative, finite and have a non-zero sum. | |||
| Args: | |||
| inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type. | |||
| num_sample (int): number of samples to draw. | |||
| replacement (bool, optional): whether to draw with replacement or not, default True. | |||
| seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| inputs (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with | |||
| float32 data type. | |||
| num_sample (int): Number of samples to draw. | |||
| replacement (bool, optional): Whether to draw with replacement or not, default True. | |||
| seed (int, optional): Seed is used as entropy source for the random number engines to generate | |||
| pseudo-random numbers, must be non-negative. Default: 0. | |||
| Outputs: | |||
| Tensor. have the same rows with input, each row has num_samples sampled indices. | |||
| Tensor, has the same number of rows as the input. The number of sampled indices in each row is `num_sample`. | |||
| The dtype is float32. | |||
| Examples: | |||
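A minimal sketch of multinomial as documented above, assuming a backend that supports this sampler; the weights are illustrative and, per the note, need not sum to one.

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import composite as C

    weights = Tensor(np.array([0.1, 0.6, 0.3]).astype(np.float32))   # one row of non-negative weights
    indices = C.multinomial(weights, 2, replacement=False, seed=3)   # 2 sampled column indices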
| @@ -197,6 +197,9 @@ class _AutoParallelContext: | |||
| parameter_broadcast (bool): Parameter broadcast or not. | |||
| """ | |||
| self.check_context_handle() | |||
| if parameter_broadcast is True and context.get_context("enable_ge") is False: | |||
| raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to" | |||
| " use mindspore.common.set_seed() to share parameters among devices.") | |||
| self._context_handle.set_parameter_broadcast(parameter_broadcast) | |||
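A hedged sketch of the workaround suggested by the new error message: seed every process identically through mindspore.common.set_seed instead of relying on parameter_broadcast; the seed value is illustrative.

    from mindspore import context
    from mindspore.common import set_seed
    from mindspore.communication.management import init, get_group_size

    set_seed(1)   # same seed on every device -> identical initial parameters
    init()
    context.set_auto_parallel_context(parallel_mode="data_parallel",
                                      device_num=get_group_size(),
                                      gradients_mean=True)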
| def get_parameter_broadcast(self): | |||
| @@ -58,7 +58,7 @@ if __name__ == '__main__': | |||
| cfg.group_size = get_group_size() | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, | |||
| parameter_broadcast=True, gradients_mean=True) | |||
| gradients_mean=True) | |||
| else: | |||
| cfg.rank = 0 | |||
| cfg.group_size = 1 | |||
| @@ -61,7 +61,7 @@ if __name__ == '__main__': | |||
| cfg.group_size = get_group_size() | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, | |||
| parameter_broadcast=True, gradients_mean=True) | |||
| gradients_mean=True) | |||
| else: | |||
| cfg.rank = 0 | |||
| cfg.group_size = 1 | |||
| @@ -136,8 +136,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): | |||
| os.environ['RANK_SIZE'] = str(device_num) | |||
| if enable_hccl: | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True, parameter_broadcast=True, | |||
| all_reduce_fusion_config=[107, 160]) | |||
| gradients_mean=True, all_reduce_fusion_config=[107, 160]) | |||
| init() | |||
| # network | |||
| @@ -239,8 +238,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): | |||
| os.environ['RANK_SIZE'] = str(device_num) | |||
| if enable_hccl: | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True, parameter_broadcast=True, | |||
| all_reduce_fusion_config=[107]) | |||
| gradients_mean=True, all_reduce_fusion_config=[107]) | |||
| init() | |||
| # network | |||