Merge pull request !934 from chenweifeng/broadcast
@@ -0,0 +1,138 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/cuda_impl/broadcast_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T, typename S>
struct GreaterFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs; }
};

template <typename T, typename S>
struct LessFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs; }
};

template <typename T, typename S>
struct MinimumFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
};

template <typename T, typename S>
struct MaximumFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
};

template <typename T, typename S>
struct PowerFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
};
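
// Maps an output index to an input index along one dimension: a broadcast
// dimension has size 1, so every output position reads the input's element 0.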
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
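
// Grid-stride loop over the (at most 4-D) output tensor: each thread walks the
// flat output range and applies Func to the broadcast-aligned input elements.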
template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                  const int &r0, const int &r1, const int &r2, const int &r3,
                                                  const int &d0, const int &d1, const int &d2, const int &d3,
                                                  const T *input0, const T *input1, S *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<size_t>(d0 * d1 * d2 * d3);
       pos += blockDim.x * gridDim.x) {
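    // Decompose the flat output position into 4-D coordinates (i, j, k, l).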
    int i = pos / (d1 * d2 * d3) % d0;
    int j = pos / (d2 * d3) % d1;
    int k = pos / d3 % d2;
    int l = pos % d3;
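    // Re-linearize the coordinates against each input's own shape; Index()
    // collapses broadcast (size-1) dimensions to offset 0.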
    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
    output[pos] = Func()(input0[l_index], input1[r_index]);
  }
}
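
// Each thread resolves the runtime op tag once, then runs the element loop with
// the selected functor; the cases mirror BroadcastOpType in the header.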
template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
                                const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
                                enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
  switch (op) {
    case BROADCAST_TYPE_GREATER:
      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                        output);
    case BROADCAST_TYPE_LESS:
      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                     output);
    case BROADCAST_TYPE_MINIMUM:
      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                        output);
    case BROADCAST_TYPE_MAXIMUM:
      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                        output);
    case BROADCAST_TYPE_POWER:
      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                      output);
  }
}
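
// Host-side launcher: one thread per output element, with the grid sized by
// the GET_BLOCKS/GET_THREADS helpers from cuda_common.h.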
template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
               const T *input0, const T *input1, S *output, cudaStream_t stream) {
  int size = d0 * d1 * d2 * d3;
  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
                                                                input0, input1, output);
}

template <typename T, typename S, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<size_t>(nums);
       pos += blockDim.x * gridDim.x) {
    output[pos] = Func()(input0[pos], input1[pos]);
  }
}
template <typename T, typename S>
__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
                                  S *output) {
  switch (op) {
    case BROADCAST_TYPE_GREATER:
      return NoBroadcastOperator<T, S, GreaterFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_LESS:
      return NoBroadcastOperator<T, S, LessFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_MINIMUM:
      return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_MAXIMUM:
      return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_POWER:
      return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
  }
}
template <typename T, typename S>
void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
                 cudaStream_t stream) {
  NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
}

template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                        enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
                        cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                        enum BroadcastOpType op, const float *input0, const float *input1, float *output,
                        cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
                          bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
                          float *output, cudaStream_t stream);
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_

#include "device/gpu/cuda_common.h"
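
// Op tags shared by the host launchers, the device kernels, and the GPU kernel
// class that dispatches on the primitive name.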
enum BroadcastOpType {
  BROADCAST_TYPE_GREATER = 0,
  BROADCAST_TYPE_LESS = 1,
  BROADCAST_TYPE_MAXIMUM = 2,
  BROADCAST_TYPE_MINIMUM = 3,
  BROADCAST_TYPE_POWER = 4,
  BROADCAST_TYPE_INVALID = 0xffffffff,
};

template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
               const T *input0, const T *input1, S *output, cudaStream_t stream);

template <typename T, typename S>
void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
                 cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
@@ -38,13 +38,5 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BinaryOpGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(
-  Maximum,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BinaryOpGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-  Maximum,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BinaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
@@ -27,16 +27,9 @@
 #include "kernel/gpu/kernel_constants.h"
 namespace mindspore {
 namespace kernel {
-enum BinaryOpType {
-  BINARY_OP_ADD = 0,
-  BINARY_OP_SUB,
-  BINARY_OP_MUL,
-  BINARY_OP_DIV,
-  BINARY_OP_MAX,
-  BINARY_OP_INVALID_TYPE = 255
-};
+enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 };
 static const std::map<std::string, BinaryOpType> kBinaryOpTypeMap = {
-  {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}};
+  {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}};
 template <typename T>
 class BinaryOpGpuKernel : public GpuKernel {
  public:
@@ -88,10 +81,6 @@ class BinaryOpGpuKernel : public GpuKernel {
         inputB_addr = workspace_addr;
         break;
       }
-      case BINARY_OP_MAX: {
-        inputB_addr = input_addr2;
-        break;
-      }
       default: {
         MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
       }
@@ -209,10 +198,6 @@
         tensor_op_ = CUDNN_OP_TENSOR_ADD;
         break;
       }
-      case BINARY_OP_MAX: {
-        tensor_op_ = CUDNN_OP_TENSOR_MAX;
-        break;
-      }
       default: {
         MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
       }
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/math/broadcast_gpu_kernel.h"

namespace mindspore {
namespace kernel {
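// float32 registrations: the comparison ops (Greater/Less) produce bool
// outputs, while Maximum/Minimum/Pow keep the input dtype.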
MS_REG_GPU_KERNEL_TWO(
  Greater,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
  BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
  Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
  BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
  Maximum,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
  Minimum,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
  Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BroadcastOpGpuKernel, float, float)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,132 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/broadcast_impl.cuh"
#include "kernel/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class BroadcastOpGpuKernel : public GpuKernel {
 public:
  BroadcastOpGpuKernel()
      : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
  ~BroadcastOpGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *lhs = GetDeviceAddress<T>(inputs, 0);
    T *rhs = GetDeviceAddress<T>(inputs, 1);
    S *output = GetDeviceAddress<S>(outputs, 0);

    if (need_broadcast_) {
      Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
                rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
                rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    GetOpType(kernel_node);
    auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    need_broadcast_ = IsBroadcast(shape1, shape2);
    if (need_broadcast_ && shape1.size() > 4) {
      MS_LOG(EXCEPTION) << "Broadcast operations do not support shapes with more than 4 dimensions.";
    }

    for (size_t i = 0; i < shape1.size(); i++) {
      lhs_shape_[i] = shape1[i];
      rhs_shape_[i] = shape2[i];
      output_shape_[i] = shape3[i];
      input1_num_ *= shape1[i];
      input2_num_ *= shape2[i];
      output_num_ *= shape3[i];
    }

    InitSizeLists();
    return true;
  }
 protected:
  void InitResource() override { return; }
  void InitSizeLists() override {
    input_size_list_.push_back(input1_num_ * sizeof(T));
    input_size_list_.push_back(input2_num_ * sizeof(T));
    output_size_list_.push_back(output_num_ * sizeof(S));
  }
 private:
  void GetOpType(const CNodePtr &kernel_node) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);

    static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
    };

    auto iter = kBroadcastTypeMap.find(kernel_name);
    if (iter == kBroadcastTypeMap.end()) {
      MS_LOG(EXCEPTION) << "Operation " << kernel_name << " is not supported.";
    } else {
      op_type_ = iter->second;
    }
  }
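
  // Inputs are assumed to share the same rank here; broadcasting is required
  // as soon as any dimension differs.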
  bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
    for (size_t i = 0; i < lhs.size(); i++) {
      if (lhs[i] != rhs[i]) {
        return true;
      }
    }
    return false;
  }
  BroadcastOpType op_type_;
  bool need_broadcast_;
  int input1_num_;
  int input2_num_;
  int output_num_;
  int lhs_shape_[4] = {1, 1, 1, 1};
  int rhs_shape_[4] = {1, 1, 1, 1};
  int output_shape_[4] = {1, 1, 1, 1};

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_
@@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
import numpy as np


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nobroadcast():
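    # Same-shape inputs exercise the elementwise (NoBroadcast) code path.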
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    x1_np = np.random.rand(10, 20).astype(np.float32)
    x2_np = np.random.rand(10, 20).astype(np.float32)

    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.minimum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.maximum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np > x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np < x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.power(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast():
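    # (3, 1, 5, 1) against (1, 4, 1, 6) broadcasts to (3, 4, 5, 6), exercising
    # broadcasting on every one of the four supported axes.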
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    x1_np = np.random.rand(3, 1, 5, 1).astype(np.float32)
    x2_np = np.random.rand(1, 4, 1, 6).astype(np.float32)

    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.minimum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.maximum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np > x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np < x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.power(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)