Merge pull request !1021 from chenweifeng/broadcast_grad
@@ -0,0 +1,116 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"
#include "device/gpu/cuda_common.h"
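
// Gradient functors for Minimum/Maximum: route the incoming gradient dy into
// the gradient buffer of whichever input won the forward comparison (ties go
// to x2). atomicAdd is required because, under broadcasting, many output
// positions can accumulate into the same input element.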
template <typename T>
struct MinimumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
    if (x1 < x2) {
      atomicAdd(dx1, dy);
    } else {
      atomicAdd(dx2, dy);
    }
  }
};

template <typename T>
struct MaximumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
    if (x1 > x2) {
      atomicAdd(dx1, dy);
    } else {
      atomicAdd(dx2, dy);
    }
  }
};
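
// Collapse a coordinate along a broadcast dimension: a dimension of size 1
// always reads index 0.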
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
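
// Grid-stride loop over the flattened output: decompose pos into the 4D
// coordinates (i, j, k, l) of the dy shape (d0, d1, d2, d3), then map them
// through Index() to linear offsets into the (possibly broadcast) inputs.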
template <typename T, typename Func>
__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                      const int &r0, const int &r1, const int &r2, const int &r3,
                                                      const int &d0, const int &d1, const int &d2, const int &d3,
                                                      const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
    int i = pos / (d1 * d2 * d3) % d0;
    int j = pos / (d2 * d3) % d1;
    int k = pos / d3 % d2;
    int l = pos % d3;
    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
    Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index);
  }
}
template <typename T>
__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                    const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
                                    enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                                    T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
                                                          dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
                                                          dx1, dx2);
  }
}
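
// Host-side launcher: one thread per output element, launched on the
// caller's CUDA stream with the GET_BLOCKS/GET_THREADS configuration from
// cuda_common.h.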
template <typename T>
void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
                   cudaStream_t stream) {
  int size = d0 * d1 * d2 * d3;
  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
                                                                    x1, x2, dy, dx1, dx2);
}

template <typename T, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1,
                                                    T *dx2) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
    Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos);
  }
}

template <typename T>
__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2,
                                      const T *dy, T *dx1, T *dx2) {
  switch (op) {
    case BROADCAST_GRAD_TYPE_MINIMUM:
      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
    case BROADCAST_GRAD_TYPE_MAXIMUM:
      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
  }
}

template <typename T>
void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                     T *dx2, cudaStream_t stream) {
  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, x1, x2, dy, dx1, dx2);
}

template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2,
                              const float *dy, float *dx1, float *dx2, cudaStream_t stream);
template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                            enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1,
                            float *dx2, cudaStream_t stream);
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_

#include "device/gpu/cuda_common.h"
enum BroadcastGradOpType {
  BROADCAST_GRAD_TYPE_MAXIMUM = 0,
  BROADCAST_GRAD_TYPE_MINIMUM = 1,
  BROADCAST_GRAD_TYPE_INVALID = 0xffffffff,
};
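
// Shape convention: l0..l3 describe x1, r0..r3 describe x2, and d0..d3
// describe dy, each right-aligned and padded to 4 dimensions with 1s; a
// dimension of size 1 is broadcast against the corresponding output
// dimension. dx1 and dx2 must be zero-initialised by the caller, since the
// kernels accumulate into them.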
template <typename T>
void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
                   cudaStream_t stream);

template <typename T>
void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
                     T *dx2, cudaStream_t stream);
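
// A minimal call sketch for the broadcast path (hypothetical shapes;
// pre-allocated, zeroed device buffers assumed):
//   x1: (3, 1, 5, 1), x2: (1, 4, 1, 6), dy: (3, 4, 5, 6)
//   BroadcastGrad(3, 1, 5, 1, 1, 4, 1, 6, 3, 4, 5, 6,
//                 BROADCAST_GRAD_TYPE_MINIMUM, x1, x2, dy, dx1, dx2, stream);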
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/math/broadcast_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
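// Register MinimumGrad and MaximumGrad for float32: three inputs
// (x1, x2, dy) and two outputs (dx1, dx2).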
MS_REG_GPU_KERNEL_ONE(MinimumGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      BroadcastOpGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      BroadcastOpGradGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,149 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"
#include "kernel/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
template <typename T>
class BroadcastOpGradGpuKernel : public GpuKernel {
 public:
  BroadcastOpGradGpuKernel()
      : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
  ~BroadcastOpGradGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
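
  // Zero both gradient buffers before launching: the device functors
  // accumulate into dx1/dx2 with atomicAdd.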
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *x1 = GetDeviceAddress<T>(inputs, 0);
    T *x2 = GetDeviceAddress<T>(inputs, 1);
    T *dy = GetDeviceAddress<T>(inputs, 2);
    T *dx1 = GetDeviceAddress<T>(outputs, 0);
    T *dx2 = GetDeviceAddress<T>(outputs, 1);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemsetAsync failed");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemsetAsync failed");
    if (need_broadcast_) {
      BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2],
                    x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1,
                    dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }
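
  // Init reads the shapes of x1, x2 and dy, decides whether broadcasting is
  // needed, and right-aligns each shape into a fixed 4D array (missing
  // leading dimensions keep their default of 1).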
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    need_broadcast_ = IsBroadcast(shape1, shape2);
    if (shape3.size() > 4) {
      MS_LOG(EXCEPTION) << "Broadcast operations do not support shapes with more than 4 dimensions";
    }
    for (size_t i = 0; i < shape3.size(); i++) {
      dy_shape_[i] = shape3[i];
      output_num_ *= shape3[i];
    }
    int offset = shape3.size() - shape1.size();
    for (size_t i = 0; i < shape1.size(); i++) {
      x1_shape_[i + offset] = shape1[i];
      input1_num_ *= shape1[i];
    }
    offset = shape3.size() - shape2.size();
    for (size_t i = 0; i < shape2.size(); i++) {
      x2_shape_[i + offset] = shape2[i];
      input2_num_ *= shape2[i];
    }
    InitSizeLists();
    return true;
  }
 protected:
  void InitResource() override { return; }
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(T));
    input_size_list_.push_back(output_num_ * sizeof(T));
    output_size_list_.push_back(input1_num_ * sizeof(T));
    output_size_list_.push_back(input2_num_ * sizeof(T));
  }
 private:
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, BroadcastGradOpType> kBroadcastTypeMap = {
      {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM},
      {"MinimumGrad", BROADCAST_GRAD_TYPE_MINIMUM},
    };
    auto iter = kBroadcastTypeMap.find(kernel_name);
    if (iter == kBroadcastTypeMap.end()) {
      MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
    } else {
      op_type_ = iter->second;
    }
  }
  bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
    if (lhs.size() != rhs.size()) {
      return true;
    }
    for (size_t i = 0; i < lhs.size(); i++) {
      if (lhs[i] != rhs[i]) {
        return true;
      }
    }
    return false;
  }

  BroadcastGradOpType op_type_;
  bool need_broadcast_;
  int input1_num_;
  int input2_num_;
  int output_num_;
  int x1_shape_[4] = {1, 1, 1, 1};
  int x2_shape_[4] = {1, 1, 1, 1};
  int dy_shape_[4] = {1, 1, 1, 1};
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
@@ -15,6 +15,7 @@
import pytest
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
import mindspore.context as context
@@ -29,11 +30,20 @@ class Net(Cell):
    def construct(self, x, y):
        return self.max(x, y)

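# Grad wraps a network and returns the gradients with respect to all of its
# inputs; sens is the seed gradient propagated back from the output.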
class Grad(Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, sens):
        gout = self.grad(self.network)(x1, x2, sens)
        return gout

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum():
    x = Tensor(np.array([[1, 2, 3]]).astype(np.float32))
    y = Tensor(np.array([[2]]).astype(np.float32))
    expect = [[2, 2, 3]]
@@ -53,3 +63,160 @@ def test_max():
    assert np.all(diff < error)
    assert np.all(-diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[[0.659578],
                        [0.49113268],
                        [0.75909054],
                        [0.71681815],
                        [0.30421826]]],
                      [[[0.30322495],
                        [0.02858258],
                        [0.06398096],
                        [0.09519596],
                        [0.12498625]]],
                      [[[0.7347768],
                        [0.166469],
                        [0.328553],
                        [0.54908437],
                        [0.23673844]]]]).astype(np.float32)
    x2_np = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
                       [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
                       [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]],
                       [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32)
    dy_np = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524],
                        [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796],
                        [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618],
                        [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
                        [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],
                       [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
                        [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
                        [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
                        [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658],
                        [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]],
                       [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634],
                        [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
                        [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
                        [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395],
                        [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],
                       [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
                        [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
                        [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
                        [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
                        [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]],
                      [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
                        [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436],
                        [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127],
                        [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
                        [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],
                       [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033],
                        [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369],
                        [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
                        [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872],
                        [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]],
                       [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199],
                        [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
                        [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
                        [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414],
                        [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],
                       [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
                        [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
                        [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
                        [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044],
                        [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],
                      [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683],
                        [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
                        [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658],
                        [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
                        [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]],
                       [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
                        [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
                        [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048],
                        [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874],
                        [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],
                       [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226],
                        [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
                        [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743],
                        [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
                        [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],
                       [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
                        [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
                        [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
                        [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
                        [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
    expect_dx1 = np.array([[[[6.6534014],
                             [5.649811],
                             [10.071739],
                             [6.6798244],
                             [3.0426278]]],
                           [[[4.2183976],
                             [0.8096436],
                             [0.6019127],
                             [0.11872514],
                             [0.09701069]]],
                           [[[9.573029],
                             [0.60534775],
                             [3.917112],
                             [5.9021177],
                             [2.263672]]]]).astype(np.float32)
    expect_dx2 = np.array([[[[6.4205275, 2.941831, 5.492452, 4.3212175, 2.4262471, 0.]],
                            [[7.991917, 2.3792431, 4.9190216, 5.2013817, 6.348791, 8.351772]],
                            [[5.518505, 8.401285, 4.691043, 6.463884, 7.504318, 7.620938]],
                            [[5.2708025, 1.2835244, 4.1031275, 1.9843934, 4.9320035, 4.537787]]]]).astype(np.float32)

    net = Grad(Net())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)

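# Inputs of different ranks: x1 is rank 3 and x2 rank 2, so x2 is
# right-aligned against the output shape before broadcasting.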
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_diff_dims():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[0.275478, 0.48933202, 0.71846116],
                       [0.9803821, 0.57205725, 0.28511533]],
                      [[0.61111903, 0.9671023, 0.70624334],
                       [0.53730786, 0.90413177, 0.94349676]]]).astype(np.float32)
    x2_np = np.array([[0.01045662, 0.82126397, 0.6365063],
                      [0.9900942, 0.6584232, 0.98537433]]).astype(np.float32)
    dy_np = np.array([[[0.3897645, 0.61152864, 0.33675498],
                       [0.5303635, 0.84893036, 0.4959739]],
                      [[0.5391046, 0.8443047, 0.4174708],
                       [0.57513475, 0.9225578, 0.46760973]]]).astype(np.float32)
    expect_dx1 = np.array([[[0.3897645, 0., 0.33675498],
                            [0., 0., 0.]],
                           [[0.5391046, 0.8443047, 0.4174708],
                            [0., 0.9225578, 0.]]]).astype(np.float32)
    expect_dx2 = np.array([[0., 0.61152864, 0.],
                           [1.1054983, 0.84893036, 0.96358365]]).astype(np.float32)

    net = Grad(Net())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)
@@ -0,0 +1,220 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
import numpy as np

class MinimumNet(Cell):
    def __init__(self):
        super(MinimumNet, self).__init__()
        self.min = P.Minimum()

    def construct(self, x1, x2):
        x = self.min(x1, x2)
        return x

class Grad(Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, sens):
        gout = self.grad(self.network)(x1, x2, sens)
        return gout

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nobroadcast():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.random.rand(3, 4).astype(np.float32)
    x2_np = np.random.rand(3, 4).astype(np.float32)
    dy_np = np.random.rand(3, 4).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    output0_np = np.where(x1_np < x2_np, dy_np, 0)
    output1_np = np.where(x1_np < x2_np, 0, dy_np)
    assert np.allclose(output_ms[0].asnumpy(), output0_np)
    assert np.allclose(output_ms[1].asnumpy(), output1_np)

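# A minimal NumPy reference sketch for the broadcast case (np_minimum_grad is
# a hypothetical helper, not used by the tests below): broadcast both inputs
# to dy's shape, route dy to the smaller operand (ties go to x2, matching the
# kernel), then sum each gradient back over its input's broadcast axes.
def np_minimum_grad(x1, x2, dy):
    b1 = np.broadcast_to(x1, dy.shape)
    b2 = np.broadcast_to(x2, dy.shape)
    dx1_full = np.where(b1 < b2, dy, 0)
    dx2_full = np.where(b1 < b2, 0, dy)

    def reduce_to(grad, shape):
        # Sum over the leading axes introduced by rank padding...
        while grad.ndim > len(shape):
            grad = grad.sum(axis=0)
        # ...and over axes where the input had size 1.
        for axis, size in enumerate(shape):
            if size == 1:
                grad = grad.sum(axis=axis, keepdims=True)
        return grad

    return reduce_to(dx1_full, x1.shape), reduce_to(dx2_full, x2.shape)
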
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[[0.659578],
                        [0.49113268],
                        [0.75909054],
                        [0.71681815],
                        [0.30421826]]],
                      [[[0.30322495],
                        [0.02858258],
                        [0.06398096],
                        [0.09519596],
                        [0.12498625]]],
                      [[[0.7347768],
                        [0.166469],
                        [0.328553],
                        [0.54908437],
                        [0.23673844]]]]).astype(np.float32)
    x2_np = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]],
                       [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]],
                       [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]],
                       [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32)
    dy_np = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524],
                        [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796],
                        [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618],
                        [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286],
                        [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]],
                       [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853],
                        [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313],
                        [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075],
                        [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658],
                        [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]],
                       [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634],
                        [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957],
                        [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417],
                        [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395],
                        [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]],
                       [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131],
                        [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415],
                        [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963],
                        [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585],
                        [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]],
                      [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007],
                        [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436],
                        [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127],
                        [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514],
                        [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]],
                       [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033],
                        [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369],
                        [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344],
                        [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872],
                        [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]],
                       [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199],
                        [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156],
                        [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487],
                        [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414],
                        [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]],
                       [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376],
                        [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793],
                        [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854],
                        [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044],
                        [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]],
                      [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683],
                        [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384],
                        [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658],
                        [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767],
                        [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]],
                       [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182],
                        [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934],
                        [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048],
                        [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874],
                        [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]],
                       [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226],
                        [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194],
                        [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743],
                        [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854],
                        [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]],
                       [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418],
                        [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634],
                        [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876],
                        [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543],
                        [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32)
    expect_dx1 = np.array([[[[5.7664223],
                             [6.981018],
                             [2.6029902],
                             [2.7598202],
                             [6.763105]]],
                           [[[10.06558],
                             [12.077246],
                             [9.338394],
                             [11.52271],
                             [8.889048]]],
                           [[[3.5789769],
                             [13.424448],
                             [8.732746],
                             [6.9677467],
                             [9.635765]]]]).astype(np.float32)
    expect_dx2 = np.array([[[[0., 4.250458, 2.5030296, 3.623167, 6.4171505, 7.2115746]],
                            [[0., 4.367449, 2.803152, 2.5352, 0., 0.]],
                            [[0.7087075, 0., 2.040332, 2.1372325, 0., 2.9222295]],
                            [[1.0278877, 5.247942, 2.6855955, 5.494814, 3.5657988, 0.66265094]]]]).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)

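# The same different-rank scenario on the MinimumGrad path: x1 is rank 3, x2
# rank 2.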
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_diff_dims():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target='GPU')

    x1_np = np.array([[[0.275478, 0.48933202, 0.71846116],
                       [0.9803821, 0.57205725, 0.28511533]],
                      [[0.61111903, 0.9671023, 0.70624334],
                       [0.53730786, 0.90413177, 0.94349676]]]).astype(np.float32)
    x2_np = np.array([[0.01045662, 0.82126397, 0.6365063],
                      [0.9900942, 0.6584232, 0.98537433]]).astype(np.float32)
    dy_np = np.array([[[0.3897645, 0.61152864, 0.33675498],
                       [0.5303635, 0.84893036, 0.4959739]],
                      [[0.5391046, 0.8443047, 0.4174708],
                       [0.57513475, 0.9225578, 0.46760973]]]).astype(np.float32)
    expect_dx1 = np.array([[[0., 0.61152864, 0.],
                            [0.5303635, 0.84893036, 0.4959739]],
                           [[0., 0., 0.],
                            [0.57513475, 0., 0.46760973]]]).astype(np.float32)
    expect_dx2 = np.array([[0.92886907, 0.8443047, 0.7542258],
                           [0., 0.9225578, 0.]]).astype(np.float32)

    net = Grad(MinimumNet())
    output_ms = net(Tensor(x1_np), Tensor(x2_np), Tensor(dy_np))
    assert np.allclose(output_ms[0].asnumpy(), expect_dx1)
    assert np.allclose(output_ms[1].asnumpy(), expect_dx2)