| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| // Register the Range GPU kernel for the two supported dtype combinations: | |||||
| // float32 -> float32 and int32 -> int32 (one input tensor, one output tensor each). | |||||
| MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| RangeGPUKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| RangeGPUKernel, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,89 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| // GPU kernel for the Range op. Computes output[i] = input[i] * delta_ + start_ | |||||
| // element-wise on the device (see CalRange in range_impl.cu). | |||||
| // NOTE(review): limit_ is read from the node's attrs but never used by the CUDA | |||||
| // kernel — presumably the framework sizes the input so the limit is implicit; confirm. | |||||
| template <typename T> | |||||
| class RangeGPUKernel : public GpuKernel { | |||||
| public: | |||||
| // Sizes start at 0 until Init(); start/limit/delta get benign defaults. | |||||
| RangeGPUKernel() : input_size_(0), output_size_(0), start_(0.), limit_(1.), delta_(1.) {} | |||||
| ~RangeGPUKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| // Launches the CUDA kernel on stream_ptr. `workspace` is unused (no scratch needed). | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| // Element count recovered from the byte size computed in Init(). | |||||
| int size = SizeToInt(input_size_ / sizeof(T)); | |||||
| CalRange(size, start_, limit_, delta_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| // Validates arity (exactly 1 input, 1 output), derives the input byte size from | |||||
| // the device shape, reads the start/limit/delta attrs, and fills the size lists. | |||||
| // Returns false (after logging) on arity mismatch. | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 1) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but Range needs 1 input."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but Range needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||||
| auto shape_size = input_shape.size(); | |||||
| // input_size_ = product of dims * sizeof(T)  (total bytes of the input tensor). | |||||
| input_size_ = 1; | |||||
| for (size_t i = 0; i < shape_size; i++) { | |||||
| input_size_ *= input_shape[i]; | |||||
| } | |||||
| input_size_ *= sizeof(T); | |||||
| // Output has the same shape/dtype as the input, so the byte sizes match. | |||||
| output_size_ = input_size_; | |||||
| start_ = GetAttr<float>(kernel_node, "start"); | |||||
| limit_ = GetAttr<float>(kernel_node, "limit"); | |||||
| delta_ = GetAttr<float>(kernel_node, "delta"); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| // Publishes the byte sizes computed in Init(); no workspace is requested. | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| return; | |||||
| } | |||||
| private: | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| size_t input_size_;   // input tensor size in bytes | |||||
| size_t output_size_;  // output tensor size in bytes (== input_size_) | |||||
| float start_; | |||||
| float limit_;  // read from attrs but not consumed by the CUDA kernel | |||||
| float delta_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cuda_runtime.h> | |||||
| #include "range_impl.cuh" | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| // Device kernel: grid-stride loop writing output[pos] = input[pos] * delta + start. | |||||
| // `limit` is accepted for signature symmetry with the op's attrs but is unused here. | |||||
| template <typename T> | |||||
| __global__ void Range(const int size, const float start, const float limit, const float delta, const T *input, | |||||
| T *output) { | |||||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||||
| output[pos] = input[pos] * delta + start; | |||||
| } | |||||
| } | |||||
| // Host-side launcher: enqueues the Range kernel on cuda_stream (asynchronous; | |||||
| // the caller owns stream synchronization). | |||||
| template <typename T> | |||||
| void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output, | |||||
| cudaStream_t cuda_stream) { | |||||
| Range<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, start, limit, delta, input, output); | |||||
| return; | |||||
| } | |||||
| // Explicit instantiations for the two dtypes registered in range_gpu_kernel.cc. | |||||
| template void CalRange<float>(const int size, const float start, const float limit, const float delta, | |||||
| const float *input, float *output, cudaStream_t cuda_stream); | |||||
| template void CalRange<int>(const int size, const float start, const float limit, const float delta, const int *input, | |||||
| int *output, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,23 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_ | |||||
| // Launches the element-wise Range kernel: output[i] = input[i] * delta + start. | |||||
| // Defined (and explicitly instantiated for float/int) in range_impl.cu. | |||||
| template <typename T> | |||||
| void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_ | |||||
| @@ -13,8 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Categorical Distribution""" | """Categorical Distribution""" | ||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| import mindspore.nn as nn | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from .distribution import Distribution | from .distribution import Distribution | ||||
| from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error | from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error | ||||
| @@ -119,17 +119,19 @@ class Categorical(Distribution): | |||||
| """ | """ | ||||
| return self._probs | return self._probs | ||||
| def _sample(self, sample_shape=(1,)): | |||||
| def _sample(self, sample_shape=()): | |||||
| """ | """ | ||||
| Sampling. | Sampling. | ||||
| Args: | Args: | ||||
| sample_shape (tuple): shape of the sample. Default: (1,). | |||||
| sample_shape (tuple): shape of the sample. Default: (). | |||||
| Returns: | Returns: | ||||
| Tensor, shape is shape(probs)[:-1] + sample_shape | Tensor, shape is shape(probs)[:-1] + sample_shape | ||||
| """ | """ | ||||
| self.checktuple(sample_shape, 'shape') | self.checktuple(sample_shape, 'shape') | ||||
| if sample_shape == (): | |||||
| sample_shape = (1,) | |||||
| num_sample = 1 | num_sample = 1 | ||||
| for i in sample_shape: | for i in sample_shape: | ||||
| num_sample *= i | num_sample *= i | ||||
| @@ -184,16 +186,15 @@ class Categorical(Distribution): | |||||
| if value is not None: | if value is not None: | ||||
| check_tensor_type("value", value, [mstype.float32, bool, mstype.int32]) | check_tensor_type("value", value, [mstype.float32, bool, mstype.int32]) | ||||
| value = self.expandim(self.cast(value, mstype.float32), -1) | value = self.expandim(self.cast(value, mstype.float32), -1) | ||||
| index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32)) | |||||
| index = self.expandim(index, -1) | |||||
| logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0) | |||||
| broad_shape = self._broad_cast_shape(value, logits) | |||||
| broad_shape = self._broad_cast_shape(value, self._logits) | |||||
| broad = P.BroadcastTo(broad_shape) | broad = P.BroadcastTo(broad_shape) | ||||
| value = broad(value)[..., :1] | |||||
| index = broad(index)[..., :1] | |||||
| logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1])) | |||||
| value = self.reshape(broad(value)[..., :1], (-1, 1)) | |||||
| index = nn.Range(0., self.shape(value)[0], 1)() | |||||
| index = self.reshape(index, (-1, 1)) | |||||
| value = self.concat((index, value)) | value = self.concat((index, value)) | ||||
| value = self.cast(value, mstype.int32) | value = self.cast(value, mstype.int32) | ||||
| return self.gather(logits, value) | |||||
| return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1]) | |||||
| return None | return None | ||||
| def _entropy(self): | def _entropy(self): | ||||
| @@ -211,7 +212,7 @@ class Categorical(Distribution): | |||||
| Enumerate categories. | Enumerate categories. | ||||
| """ | """ | ||||
| num_events = self._num_events | num_events = self._num_events | ||||
| values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32) | |||||
| values = nn.Range(0., num_events, 1)() | |||||
| values = self.reshape(values, (num_events, 1)) | values = self.reshape(values, (num_events, 1)) | ||||
| if expand: | if expand: | ||||
| values = P.BroadcastTo((num_events, self._batch_shape))(values) | values = P.BroadcastTo((num_events, self._batch_shape))(values) | ||||
| @@ -450,8 +450,8 @@ class Multinomial(PrimitiveWithInfer): | |||||
| Examples: | Examples: | ||||
| >>> input = Tensor([0., 9., 4., 0.], mstype.float32) | >>> input = Tensor([0., 9., 4., 0.], mstype.float32) | ||||
| >>> multinomial = P.Multinomial(seed=10) | |||||
| >>> output = multinomial(input, 2, True) | |||||
| >>> multinomial = P.Multinomial(replacement=True, seed=10) | |||||
| >>> output = multinomial(input, 2) | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||