| @@ -290,8 +290,8 @@ class SampledSoftmaxLoss(_Loss): | |||||
| num_classes (int): The number of possible classes. | num_classes (int): The number of possible classes. | ||||
| num_true (int): The number of target classes per training example. | num_true (int): The number of target classes per training example. | ||||
| sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`, | sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`, | ||||
| `sampled_expected_count`) returned by a `*_candidate_sampler` function. | |||||
| Default to None, `log_uniform_candidate_sampler` is applied. | |||||
| `sampled_expected_count`) returned by a `*CandidateSampler` function. | |||||
| Default to None, `UniformCandidateSampler` is applied. | |||||
| remove_accidental_hits (bool): Whether to remove "accidental hits" | remove_accidental_hits (bool): Whether to remove "accidental hits" | ||||
| where a sampled class equals one of the target classes. Default is True. | where a sampled class equals one of the target classes. Default is True. | ||||
| seed (int): Random seed for candidate sampling. Default: 0 | seed (int): Random seed for candidate sampling. Default: 0 | ||||
| @@ -301,7 +301,7 @@ class SampledSoftmaxLoss(_Loss): | |||||
| Inputs: | Inputs: | ||||
| - **weights** (Tensor) - Tensor of shape (C, dim). | - **weights** (Tensor) - Tensor of shape (C, dim). | ||||
| - **bias** (Tensor) - Tensor of shape (C). The class biases. | - **bias** (Tensor) - Tensor of shape (C). The class biases. | ||||
| - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64`. The | |||||
| - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64, int32`. The | |||||
| target classes. | target classes. | ||||
| - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of | - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of | ||||
| the input network. | the input network. | ||||
| @@ -414,7 +414,7 @@ class SampledSoftmaxLoss(_Loss): | |||||
| activations of the input network. | activations of the input network. | ||||
| num_true (int): The number of target classes per training example. | num_true (int): The number of target classes per training example. | ||||
| sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, | sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, | ||||
| `sampled_expected_count`) returned by a `UniformSampler` function. | |||||
| `sampled_expected_count`) returned by a `UniformCandidateSampler` function. | |||||
| subtract_log_q: A `bool`. whether to subtract the log expected count of | subtract_log_q: A `bool`. whether to subtract the log expected count of | ||||
| the labels in the sample to get the logits of the true labels. | the labels in the sample to get the logits of the true labels. | ||||
| Default is True. | Default is True. | ||||
| @@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add | |||||
| from .multitype_ops.ones_like_impl import ones_like | from .multitype_ops.ones_like_impl import ones_like | ||||
| from .multitype_ops.zeros_like_impl import zeros_like | from .multitype_ops.zeros_like_impl import zeros_like | ||||
| from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | ||||
| from .math_ops import count_nonzero, TensorDot | |||||
| from .math_ops import count_nonzero, tensor_dot | |||||
| from .array_ops import repeat_elements | from .array_ops import repeat_elements | ||||
| @@ -52,5 +52,5 @@ __all__ = [ | |||||
| 'clip_by_value', | 'clip_by_value', | ||||
| 'clip_by_global_norm', | 'clip_by_global_norm', | ||||
| 'count_nonzero', | 'count_nonzero', | ||||
| 'TensorDot', | |||||
| 'tensor_dot', | |||||
| 'repeat_elements'] | 'repeat_elements'] | ||||
| @@ -12,7 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """math Operations.""" | |||||
| """array Operations.""" | |||||
| from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils | from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| @@ -69,7 +69,7 @@ def repeat_elements(x, rep, axis=0): | |||||
| Examples: | Examples: | ||||
| >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) | >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) | ||||
| >>> output = C.RepeatElements(x, rep = 2, axis = 0) | |||||
| >>> output = C.repeat_elements(x, rep = 2, axis = 0) | |||||
| >>> print(output) | >>> print(output) | ||||
| [[0, 1, 2], | [[0, 1, 2], | ||||
| [0, 1, 2], | [0, 1, 2], | ||||
| @@ -75,7 +75,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): | |||||
| return nonzero_num | return nonzero_num | ||||
| # TensorDot | |||||
| # tensor dot | |||||
| @constexpr | @constexpr | ||||
| def _int_to_tuple_conv(axes): | def _int_to_tuple_conv(axes): | ||||
| """ | """ | ||||
| @@ -92,7 +92,7 @@ def _check_axes(axes): | |||||
| """ | """ | ||||
| Check for validity and type of axes passed to function. | Check for validity and type of axes passed to function. | ||||
| """ | """ | ||||
| validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot") | |||||
| validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot") | |||||
| if not isinstance(axes, int): | if not isinstance(axes, int): | ||||
| axes = list(axes) # to avoid immutability issues | axes = list(axes) # to avoid immutability issues | ||||
| if len(axes) != 2: | if len(axes) != 2: | ||||
| @@ -156,7 +156,7 @@ def _calc_new_shape(shape, axes, position=0): | |||||
| return new_shape, transpose_perm, free_dims | return new_shape, transpose_perm, free_dims | ||||
| def TensorDot(x1, x2, axes): | |||||
| def tensor_dot(x1, x2, axes): | |||||
| """ | """ | ||||
| Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. | Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. | ||||
| @@ -171,8 +171,8 @@ def TensorDot(x1, x2, axes): | |||||
| axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` | axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` | ||||
| Inputs: | Inputs: | ||||
| - **x1** (Tensor) - First tensor in TensorDot op with datatype float16 or float32 | |||||
| - **x2** (Tensor) - Second tensor in TensorDot op with datatype float16 or float32 | |||||
| - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32 | |||||
| - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32 | |||||
| - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or | - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or | ||||
| tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, | tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, | ||||
| automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. | automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. | ||||
| @@ -184,7 +184,7 @@ def TensorDot(x1, x2, axes): | |||||
| Examples: | Examples: | ||||
| >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) | >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) | ||||
| >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) | >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) | ||||
| >>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2))) | |||||
| >>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2))) | |||||
| >>> print(output) | >>> print(output) | ||||
| [[2,2,2], | [[2,2,2], | ||||
| [2,2,2], | [2,2,2], | ||||
| @@ -206,7 +206,7 @@ def TensorDot(x1, x2, axes): | |||||
| x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | ||||
| x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | ||||
| output_shape = x1_ret + x2_ret # combine free axes from both inputs | output_shape = x1_ret + x2_ret # combine free axes from both inputs | ||||
| # run TensorDot op | |||||
| # run tensor_dot op | |||||
| x1_transposed = transpose_op(x1, x1_transpose_fwd) | x1_transposed = transpose_op(x1, x1_transpose_fwd) | ||||
| x2_transposed = transpose_op(x2, x2_transpose_fwd) | x2_transposed = transpose_op(x2, x2_transpose_fwd) | ||||
| x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd) | x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd) | ||||
| @@ -723,8 +723,10 @@ class Unique(Primitive): | |||||
| - **x** (Tensor) - The input tensor. | - **x** (Tensor) - The input tensor. | ||||
| Outputs: | Outputs: | ||||
| Tuple, containing Tensor objects `(y, idx)`, `y` is a tensor has the same type as `x`, `idx` is a tensor | |||||
| containing indices of elements in the input coressponding to the output tensor. | |||||
| Tuple, containing Tensor objects `(y, idx)., `y` is a tensor with the | |||||
| same type as `x`, and contains the unique elements in `x`, sorted in | |||||
| ascending order. `idx` is a tensor containing indices of elements in | |||||
| the input corresponding to the output tensor. | |||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` ``CPU`` | ``Ascend`` ``GPU`` ``CPU`` | ||||
| @@ -734,6 +736,23 @@ class Unique(Primitive): | |||||
| >>> output = ops.Unique()(x) | >>> output = ops.Unique()(x) | ||||
| >>> print(output) | >>> print(output) | ||||
| (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) | (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) | ||||
| >>> | |||||
| >>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode. | |||||
| >>> class UniqueNet(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(UniqueNet, self).__init__() | |||||
| >>> self.unique_op = P.Unique() | |||||
| >>> | |||||
| >>> def construct(self, x): | |||||
| >>> output, indices = self.unique_op(x) | |||||
| >>> return output, indices | |||||
| >>> | |||||
| >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32) | |||||
| >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| >>> net = UniqueNet() | |||||
| >>> output = net(x) | |||||
| >>> print(output) | |||||
| (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -29,7 +29,7 @@ class NetTensorDot(nn.Cell): | |||||
| self.axes = axes | self.axes = axes | ||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return C.TensorDot(x, y, self.axes) | |||||
| return C.tensor_dot(x, y, self.axes) | |||||
| class GradNetwork(nn.Cell): | class GradNetwork(nn.Cell): | ||||