From: @yuanwei66 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -0,0 +1,110 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <random> | |||||
| #include <thread> | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| #include "backend/kernel_compiler/cpu/random_cpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
// Fills output[start, end) with samples from the given normal distribution.
// The distribution and engine are taken by value on purpose: each worker
// thread owns an independent copy seeded by the caller.
void StandardNormal(float *output, std::normal_distribution<float> distribution,
                    std::default_random_engine random_generator, size_t start, size_t end) {
  float *pos = output + start;
  float *const stop = output + end;
  while (pos != stop) {
    *pos = distribution(random_generator);
    ++pos;
  }
}
| void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &outputs) { | |||||
| unsigned int RNG_seed; | |||||
| std::random_device rd; | |||||
| if (seed2 != 0) { | |||||
| RNG_seed = IntToUint(seed2); | |||||
| } else if (seed != 0) { | |||||
| RNG_seed = IntToUint(seed); | |||||
| } else { | |||||
| RNG_seed = rd(); | |||||
| } | |||||
| auto output = reinterpret_cast<float *>(outputs[0]->addr); | |||||
| // multithreading | |||||
| size_t lens = outputs[0]->size / sizeof(float); | |||||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||||
| if (thread_num < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||||
| return; | |||||
| } | |||||
| std::vector<std::thread> threads; | |||||
| threads.reserve(thread_num); | |||||
| size_t start = 0; | |||||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||||
| if (once_compute_size < 1) { | |||||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||||
| return; | |||||
| } | |||||
| std::normal_distribution<float> distribution; | |||||
| while (start < lens) { | |||||
| // avoid different threads using the same seed to generate the same random number | |||||
| std::default_random_engine random_generator(++RNG_seed); | |||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||||
| threads.emplace_back(std::thread(StandardNormal, output, distribution, random_generator, start, end)); | |||||
| start += once_compute_size; | |||||
| } | |||||
| for (size_t i = 0; i < threads.size(); ++i) { | |||||
| threads[i].join(); | |||||
| } | |||||
| } | |||||
| void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| auto iter = kRandomOpTypeMap.find(kernel_name); | |||||
| if (iter == kRandomOpTypeMap.end()) { | |||||
| MS_LOG(EXCEPTION) << "Random operation " << kernel_name << " is not supported."; | |||||
| } else { | |||||
| random_op_type_ = iter->second; | |||||
| } | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if ((random_op_type_ == RANDOM_OP_NORMAL) && input_num != 1) { | |||||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but random op needs 1 input."; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but random op needs 1 output."; | |||||
| } | |||||
| seed_ = LongToInt(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"))); | |||||
| seed2_ = LongToInt(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"))); | |||||
| } | |||||
| bool RandomCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| switch (random_op_type_) { | |||||
| case RANDOM_OP_NORMAL: { | |||||
| LaunchStandardNormal(seed_, seed2_, outputs); | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANDOM_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANDOM_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
// Categories of random-number operators handled by this kernel.
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_INT, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };

// Maps a primitive's node name to its RandomOptype.
// NOTE(review): UniformInt/UniformReal are listed here but only StandardNormal
// is registered below — presumably the others are added elsewhere; confirm.
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {
  {"StandardNormal", RANDOM_OP_NORMAL}, {"UniformInt", RANDOM_OP_UNIFORM_INT}, {"UniformReal", RANDOM_OP_UNIFORM_REAL}};

// CPU kernel that fills its single output tensor with random values.
class RandomCPUKernel : public CPUKernel {
 public:
  RandomCPUKernel() = default;
  ~RandomCPUKernel() override = default;

  // Reads the op type and the "seed"/"seed2" attributes from the graph node.
  void InitKernel(const CNodePtr &kernel_node) override;

  // Generates the random values into outputs[0]; returns true on success.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  RandomOptype random_op_type_{RANDOM_OP_INVALID_TYPE};
  int seed_{0};   // primary seed attribute
  int seed2_{0};  // secondary seed; takes precedence over seed_ when non-zero
};

// Register StandardNormal: int32 shape input, float32 output.
MS_REG_CPU_KERNEL(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                  RandomCPUKernel);
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANDOM_CPU_KERNEL_H_ | |||||
| @@ -98,6 +98,13 @@ inline uint64_t LongToUlong(int64_t u) { | |||||
| return static_cast<uint64_t>(u); | return static_cast<uint64_t>(u); | ||||
| } | } | ||||
| inline int32_t LongToInt(int64_t u) { | |||||
| if (u > static_cast<int64_t>((std::numeric_limits<int32_t>::max)())) { | |||||
| MS_LOG(EXCEPTION) << "The size_t value(" << u << ") exceeds the maximum value of int."; | |||||
| } | |||||
| return static_cast<int32_t>(u); | |||||
| } | |||||
| inline int64_t UlongToLong(uint64_t u) { | inline int64_t UlongToLong(uint64_t u) { | ||||
| if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) { | if (u > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) { | ||||
| MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t."; | MS_LOG(EXCEPTION) << "The uint64_t value(" << u << ") exceeds the maximum value of int64_t."; | ||||
| @@ -35,7 +35,7 @@ class StandardNormal(PrimitiveWithInfer): | |||||
| Tensor. The shape is the same as the input `shape`. The dtype is float32. | Tensor. The shape is the same as the input `shape`. The dtype is float32. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> shape = (4, 16) | >>> shape = (4, 16) | ||||
| @@ -0,0 +1,85 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from scipy.stats import kstest | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
class Net(nn.Cell):
    """Thin cell wrapping P.StandardNormal for a fixed output shape."""

    def __init__(self, shape, seed=0, seed2=0):
        super(Net, self).__init__()
        self.stdnormal = P.StandardNormal(seed, seed2)
        self.shape = shape
        self.seed = seed
        self.seed2 = seed2

    def construct(self):
        """Produce one standard-normal tensor of self.shape."""
        return self.stdnormal(self.shape)
@pytest.mark.level0
# Fix: this is a CPU test (device_target="CPU" above), not a GPU one.
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
    """StandardNormal CPU kernel: shape, normality (KS test) and seed semantics."""
    # Case 1: fixed non-zero seeds -> output must look standard-normal.
    seed = 10
    seed2 = 10
    shape = (5, 6, 8)
    net = Net(shape, seed, seed2)
    output = net()
    assert output.shape == (5, 6, 8)
    outnumpyflatten_1 = output.asnumpy().flatten()
    # p-value above the significance level: cannot reject the hypothesis that
    # the data come from the standard normal distribution.
    _, p_value = kstest(outnumpyflatten_1, "norm")
    assert p_value >= 0.05

    # Case 2: seed2 dominates, so the same seed2 reproduces the same numbers.
    seed = 0
    seed2 = 10
    shape = (5, 6, 8)
    net = Net(shape, seed, seed2)
    output = net()
    assert output.shape == (5, 6, 8)
    outnumpyflatten_2 = output.asnumpy().flatten()
    _, p_value = kstest(outnumpyflatten_2, "norm")
    assert p_value >= 0.05
    # same seed should generate same random number
    assert (outnumpyflatten_1 == outnumpyflatten_2).all()

    # Case 3: both seeds zero -> a fresh random seed is drawn each launch.
    seed = 0
    seed2 = 0
    shape = (130, 120, 141)
    net = Net(shape, seed, seed2)
    output = net()
    assert output.shape == (130, 120, 141)
    outnumpyflatten_1 = output.asnumpy().flatten()
    _, p_value = kstest(outnumpyflatten_1, "norm")
    assert p_value >= 0.05

    net = Net(shape, seed, seed2)
    output = net()
    assert output.shape == (130, 120, 141)
    outnumpyflatten_2 = output.asnumpy().flatten()
    _, p_value = kstest(outnumpyflatten_2, "norm")
    assert p_value >= 0.05
    # different seed(seed = 0) should generate different random number.
    # Fix: the original used `assert ~(...).all()`, and ~True == -2 is truthy,
    # so the assertion could never fail. Use `not` for proper boolean negation.
    assert not (outnumpyflatten_1 == outnumpyflatten_2).all()