Merge pull request !7786 from JonathanY/ops_octtags/v1.1.0
| @@ -30,7 +30,8 @@ namespace kernel { | |||
| template <typename T, typename S> | |||
| class UniformSamplerGpuKernel : public GpuKernel { | |||
| public: | |||
| UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {} | |||
| UniformSamplerGpuKernel() | |||
| : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {} | |||
| ~UniformSamplerGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -43,6 +44,16 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||
| T *sampled_candidates = GetDeviceAddress<T>(outputs, 0); | |||
| S *true_expected_count = GetDeviceAddress<S>(outputs, 1); | |||
| S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2); | |||
| if (remove_accidental_hits_) { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| array_input_ = std::vector<T>(input_size_, 0); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&array_input_[0], input, input_size_ * sizeof(T), | |||
| cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync sampled_candidates failed"); | |||
| for (const auto item : array_input_) { | |||
| set_input_.insert(item); | |||
| } | |||
| } | |||
| int counter = Sampling(); | |||
| float prob = Probability(); | |||
| size_t sampled_candidates_size = num_sampled_ * sizeof(T); | |||
| @@ -72,6 +83,7 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||
| unique_ = GetAttr<bool>(kernel_node, "unique"); | |||
| range_max_ = GetAttr<int>(kernel_node, "range_max"); | |||
| int seed = GetAttr<int>(kernel_node, "seed"); | |||
| remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits"); | |||
| if (seed == 0) seed = time(NULL); | |||
| generator_.seed(seed); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| @@ -80,6 +92,9 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| input_size_ = input_shape[0] * input_shape[1]; | |||
| if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) { | |||
| remove_accidental_hits_ = false; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -105,7 +120,8 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||
| while (picked < num_sampled_) { | |||
| tmp = distribution(generator_); | |||
| counter++; | |||
| if (set_container.find(tmp) == set_container.end()) { | |||
| if ((set_container.find(tmp) == set_container.end()) && | |||
| ((!remove_accidental_hits_) || set_input_.find(tmp) == set_input_.end())) { | |||
| set_container.insert(tmp); | |||
| sampled_candidates_.push_back(tmp); | |||
| picked++; | |||
| @@ -133,6 +149,9 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||
| bool unique_; | |||
| int range_max_; | |||
| size_t input_size_; | |||
| bool remove_accidental_hits_; | |||
| std::vector<T> array_input_; | |||
| std::set<int> set_input_; | |||
| std::default_random_engine generator_; | |||
| std::vector<int> sampled_candidates_; | |||
| std::vector<size_t> input_size_list_; | |||
| @@ -20,8 +20,9 @@ It shows how well the model works on a dataset and the optimization target which | |||
| """ | |||
| from .loss import L1Loss, MSELoss, SmoothL1Loss, \ | |||
| SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss | |||
| SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \ | |||
| SampledSoftmaxLoss | |||
| __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', | |||
| 'SoftmaxCrossEntropyWithLogits', 'BCELoss', | |||
| 'CosineEmbeddingLoss'] | |||
| 'CosineEmbeddingLoss', 'SampledSoftmaxLoss'] | |||
| @@ -263,6 +263,186 @@ class SoftmaxCrossEntropyWithLogits(_Loss): | |||
| return self.get_loss(x) | |||
| class SampledSoftmaxLoss(_Loss): | |||
| r""" | |||
| Computes the sampled softmax training loss. | |||
| Args: | |||
| num_sampled (int): The number of classes to randomly sample per batch. | |||
| num_classes (int): The number of possible classes. | |||
| num_true (int): The number of target classes per training example. | |||
| sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`, | |||
| `sampled_expected_count`) returned by a `*_candidate_sampler` function. | |||
| Default to None, `log_uniform_candidate_sampler` is applied. | |||
| remove_accidental_hits (bool): Whether to remove "accidental hits" | |||
| where a sampled class equals one of the target classes. Default is True. | |||
| seed (int): Random seed for candidate sampling. Default: 0 | |||
| reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none". | |||
| If "none", do not perform reduction. Default: "None". | |||
| Inputs: | |||
| - **weights** (Tensor) - Tensor of shape (C, dim). | |||
| - **bias** (Tensor) - Tensor of shape (C). The class biases. | |||
| - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64`. The | |||
| target classes. | |||
| - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of | |||
| the input network. | |||
| Outputs: | |||
| Tensor, a tensor of shape (N) with the per-example sampled softmax losses. | |||
| """ | |||
| def __init__(self, num_sampled, num_classes, num_true=1, | |||
| sampled_values=None, remove_accidental_hits=True, seed=0, | |||
| reduction='none'): | |||
| super(SampledSoftmaxLoss, self).__init__() | |||
| self.num_sampled = num_sampled | |||
| self.num_classes = num_classes | |||
| self.num_true = num_true | |||
| self.sampled_values = sampled_values | |||
| self.remove_accidental_hits = remove_accidental_hits | |||
| self.seed = seed | |||
| self.sampler = P.UniformSampler( | |||
| num_true, | |||
| num_sampled, | |||
| True, | |||
| num_classes, | |||
| seed, | |||
| remove_accidental_hits) | |||
| self.cast = P.Cast() | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.exp = P.Exp() | |||
| self.log = P.Log() | |||
| self.slice_op = P.Slice() | |||
| self.matmul = P.MatMul(False, True) | |||
| self.gather_v2 = P.GatherV2() | |||
| self.reduce_max_true = P.ReduceMax(True) | |||
| self.reduce_sum = P.ReduceSum() | |||
| self.reduce_sum_true = P.ReduceSum(True) | |||
| self.concat_dim0 = P.Concat(0) | |||
| self.concat_dim1 = P.Concat(1) | |||
| self.ones_like = P.OnesLike() | |||
| self.zeros_like = P.ZerosLike() | |||
| self.mul = P.Mul() | |||
| self.expand_dims = P.ExpandDims() | |||
| def construct(self, weights, biases, labels, inputs): | |||
| logits, labels = self._compute_sampled_logits( | |||
| weights=weights, | |||
| biases=biases, | |||
| labels=labels, | |||
| inputs=inputs, | |||
| num_true=self.num_true, | |||
| sampled_values=self.sampled_values, | |||
| subtract_log_q=True) | |||
| x = self._softmax_cross_entropy(logits, labels) | |||
| return x | |||
| def _softmax_cross_entropy(self, logits, targets): | |||
| stable_exp_logits = self.exp(logits - self.reduce_max_true(logits, 1)) | |||
| pred = stable_exp_logits / self.reduce_sum_true(stable_exp_logits, 1) | |||
| return -self.reduce_sum(targets * self.log(pred + 1.0e-20), 1) | |||
| def _compute_sampled_logits(self, weights, | |||
| biases, | |||
| labels, | |||
| inputs, | |||
| num_true=1, | |||
| sampled_values=None, | |||
| subtract_log_q=True): | |||
| """Helper function for SampledSoftmaxLoss functions. | |||
| Computes sampled output training logits and labels suitable | |||
| Note: In the case where num_true > 1, we assign to each target class | |||
| the target probability 1 / num_true so that the target probabilities | |||
| sum to 1 per-example. | |||
| Args: | |||
| weights (Tensor): Tensor of shape `[num_classes, dim]`. | |||
| biases (Tensor): Tensor of shape `[num_classes]`. | |||
| labels (Tensor): Tensor of shape `[batch_size, num_true]`. The target classes. | |||
| inputs (Tensor): Tensor of shape `[batch_size, dim]`. The forward | |||
| activations of the input network. | |||
| num_true (int): The number of target classes per training example. | |||
| sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, | |||
| `sampled_expected_count`) returned by a `UniformSampler` function. | |||
| subtract_log_q: A `bool`. whether to subtract the log expected count of | |||
| the labels in the sample to get the logits of the true labels. | |||
| Default is True. | |||
| Returns: | |||
| out_logits: `Tensor` object with shape | |||
| `[batch_size, num_true + num_sampled]` | |||
| out_labels: A Tensor object with the same shape as `out_logits`. | |||
| """ | |||
| if not labels.dtype == mstype.int32: | |||
| labels = self.cast(labels, mstype.int32) | |||
| labels = self.reshape(labels, (-1, num_true)) | |||
| labels_flat = self.reshape(labels, (-1,)) | |||
| # Sample the negative labels. | |||
| # sampled shape: [num_sampled] tensor | |||
| # true_expected_count shape = [batch_size, 1] tensor | |||
| # sampled_expected_count shape = [num_sampled] tensor | |||
| if sampled_values is None: | |||
| sampled_values = self.sampler(labels) | |||
| (sampled, true_expected_count, sampled_expected_count) = sampled_values | |||
| if not sampled.dtype == mstype.int32: | |||
| sampled = self.cast(sampled, mstype.int32) | |||
| all_ids = self.concat_dim0((labels_flat, sampled)) | |||
| all_w = self.gather_v2(weights, all_ids, 0) | |||
| n_true = self.shape(labels_flat)[0] | |||
| n_sampled = self.shape(sampled)[0] | |||
| n_dim = self.shape(all_w)[1] | |||
| # true_w shape is [batch_size * num_true, dim] | |||
| true_w = self.slice_op(all_w, [0, 0], [n_true, n_dim]) | |||
| sampled_w = self.slice_op(all_w, [n_true, 0], [n_sampled, n_dim]) | |||
| sampled_logits = self.matmul(inputs, sampled_w) | |||
| all_b = self.gather_v2(biases, all_ids, 0) | |||
| true_b = self.slice_op(all_b, [0], [n_true]) | |||
| sampled_b = self.slice_op(all_b, [n_true], [n_sampled]) | |||
| # inputs shape is [batch_size, dim] | |||
| # true_w shape is [batch_size * num_true, dim] | |||
| # row_wise_dots is [batch_size, num_true, dim] | |||
| new_true_w_shape = (-1, num_true, n_dim) | |||
| row_wise_dots = self.mul(self.expand_dims(inputs, 1), | |||
| self.reshape(true_w, new_true_w_shape)) | |||
| # We want the row-wise dot plus biases which yields a | |||
| # [batch_size, num_true] tensor of true_logits. | |||
| dots_as_matrix = self.reshape(row_wise_dots, (-1, n_dim)) | |||
| true_logits = self.reshape(self.reduce_sum(dots_as_matrix, 1), (-1, num_true)) | |||
| true_b = self.reshape(true_b, (-1, num_true)) | |||
| true_logits += true_b | |||
| sampled_logits += sampled_b | |||
| if subtract_log_q: | |||
| # Subtract log of Q(l), prior probability that l appears in sampled. | |||
| true_logits -= self.log(true_expected_count) | |||
| sampled_logits -= self.log(sampled_expected_count) | |||
| # Construct output logits and labels. The true labels/logits start at col 0. | |||
| out_logits = self.concat_dim1((true_logits, sampled_logits)) | |||
| # true_logits is a float tensor, ones_like(true_logits) is a float | |||
| # tensor of ones. We then divide by num_true to ensure the per-example | |||
| # labels sum to 1.0, i.e. form a proper probability distribution. | |||
| out_labels = self.concat_dim1(( | |||
| self.ones_like(true_logits) / num_true, | |||
| self.zeros_like(sampled_logits) | |||
| )) | |||
| return out_logits, out_labels | |||
| class BCELoss(_Loss): | |||
| r""" | |||
| BCELoss creates a criterion to measure the Binary Cross Entropy between the true labels and predicted labels. | |||
| @@ -5831,6 +5831,7 @@ class UniformSampler(PrimitiveWithInfer): | |||
| unique (bool): Whether all sampled classes in a batch are unique. | |||
| range_max (int): The number of possible classes. | |||
| seed (int): Random seed, must be non-negative. Default: 0. | |||
| remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. | |||
| Inputs: | |||
| true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true). | |||
| @@ -5850,13 +5851,14 @@ class UniformSampler(PrimitiveWithInfer): | |||
| [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, num_true, num_sampled, unique, range_max, seed=0): | |||
| def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): | |||
| """Initialize UniformSampler""" | |||
| validator.check_value_type("num_true", num_true, [int], self.name) | |||
| validator.check_value_type("num_sampled", num_sampled, [int], self.name) | |||
| validator.check_value_type("unique", unique, [bool], self.name) | |||
| validator.check_value_type("range_max", range_max, [int], self.name) | |||
| validator.check_value_type("seed", seed, [int], self.name) | |||
| validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name) | |||
| validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name) | |||
| if unique: | |||
| validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name) | |||
| @@ -0,0 +1,137 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| def generate_test_data(num_classes, batch_size, sampled): | |||
| dim = 10 | |||
| weights_s = np.linspace(start=1, stop=num_classes * dim, num=num_classes * dim) | |||
| weights_s = np.reshape(weights_s, (num_classes, dim)).astype(np.float32) / 100.0 | |||
| biases_s = np.linspace(start=1, stop=num_classes, num=num_classes) | |||
| biases_s = np.reshape(biases_s, (num_classes,)).astype(np.float32) / 100.0 | |||
| hidden_acts_s = np.linspace(start=1, stop=batch_size * dim, num=batch_size * dim) | |||
| hidden_acts_s = np.reshape( | |||
| hidden_acts_s, (batch_size, dim)).astype(np.float32) / 100.0 | |||
| true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) | |||
| sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) | |||
| sampled_values = (Tensor(sampled), Tensor(true_exp), Tensor(sampled_exp)) | |||
| return weights_s, biases_s, hidden_acts_s, sampled_values | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sampled_softmax_loss_assigned_sampler(): | |||
| np.random.seed(0) | |||
| num_classes = 7 | |||
| batch_size = 3 | |||
| labels = [0, 1, 2] | |||
| (weights, biases, hidden_acts, sampled_vals) = generate_test_data( | |||
| num_classes=num_classes, | |||
| batch_size=batch_size, | |||
| sampled=[4, 0, 2, 3]) | |||
| def case_not_remove_accidental_hits(): | |||
| loss = nn.SampledSoftmaxLoss( | |||
| num_sampled=4, | |||
| num_classes=num_classes, | |||
| num_true=1, | |||
| sampled_values=sampled_vals, | |||
| remove_accidental_hits=False) | |||
| got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases), | |||
| Tensor(labels), Tensor(hidden_acts)) | |||
| exp_sampled_softmax_loss = np.array( | |||
| [1.7318448, 1.8015041, 1.7211525]).astype(np.float32) | |||
| assert np.allclose(got_sampled_softmax_loss.asnumpy(), | |||
| exp_sampled_softmax_loss) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| case_not_remove_accidental_hits() | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| case_not_remove_accidental_hits() | |||
| (weights, biases, hidden_acts, sampled_vals) = generate_test_data( | |||
| num_classes=num_classes, | |||
| batch_size=batch_size, | |||
| sampled=[4, 5, 6, 3]) | |||
| def case_remove_accidental_hits(): | |||
| loss = nn.SampledSoftmaxLoss( | |||
| num_sampled=4, | |||
| num_classes=num_classes, | |||
| num_true=1, | |||
| sampled_values=sampled_vals, | |||
| remove_accidental_hits=True) | |||
| got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases), | |||
| Tensor(labels), Tensor(hidden_acts)) | |||
| exp_sampled_softmax_loss = np.array( | |||
| [[1.85211, 2.10999, 2.20862]]).astype(np.float32) | |||
| assert np.allclose(got_sampled_softmax_loss.asnumpy(), | |||
| exp_sampled_softmax_loss) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| case_remove_accidental_hits() | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| case_remove_accidental_hits() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sampled_softmax_loss_none_sampler(): | |||
| np.random.seed(0) | |||
| num_classes = 7 | |||
| batch_size = 3 | |||
| labels = [0, 1, 2] | |||
| (weights, biases, hidden_acts, _) = generate_test_data( | |||
| num_classes=num_classes, | |||
| batch_size=batch_size, | |||
| sampled=[4, 0, 2, 3]) | |||
| def case_no_sampler(): | |||
| loss = nn.SampledSoftmaxLoss( | |||
| num_sampled=4, | |||
| num_classes=num_classes, | |||
| num_true=1, | |||
| sampled_values=None, | |||
| seed=1, | |||
| remove_accidental_hits=False) | |||
| got_sampled_softmax_loss = loss(Tensor(weights), Tensor(biases), | |||
| Tensor(labels), Tensor(hidden_acts)) | |||
| exp_sampled_softmax_loss = np.array( | |||
| [1.7345718, 1.820291, 1.7704818]).astype(np.float32) | |||
| assert np.allclose(got_sampled_softmax_loss.asnumpy(), | |||
| exp_sampled_softmax_loss) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| case_no_sampler() | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||
| case_no_sampler() | |||
| if __name__ == "__main__": | |||
| test_sampled_softmax_loss_assigned_sampler() | |||
| test_sampled_softmax_loss_none_sampler() | |||
| @@ -35,6 +35,25 @@ def uniform_sampler(x, num_true, num_sampled, unique, range_max): | |||
| out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) | |||
| return out1.shape, out2.shape, out3.shape | |||
| class UniformSamplerHitNet(nn.Cell): | |||
| def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): | |||
| super(UniformSamplerHitNet, self).__init__() | |||
| self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed, | |||
| remove_accidental_hits=remove_accidental_hits) | |||
| def construct(self, x): | |||
| return self.sampler(x) | |||
| def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed, | |||
| remove_accidental_hits): | |||
| uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max, | |||
| seed, remove_accidental_hits) | |||
| out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) | |||
| return out1, out2, out3 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -114,3 +133,23 @@ def test_uniform_sampler_large_random(): | |||
| np.testing.assert_array_equal(ms1, expected_1) | |||
| np.testing.assert_array_equal(ms2, expected_2) | |||
| np.testing.assert_array_equal(ms3, expected_3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_uniform_sampler_unique_1_true_hit(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False) | |||
| expected_1 = np.array([0, 3, 1]) | |||
| np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_uniform_sampler_unique_1_true_no_hit(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True) | |||
| expected_1 = np.array([0, 3, 2]) | |||
| np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | |||