| @@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add | |||
| SoftmaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SoftmaxGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SoftmaxGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SoftmaxGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -117,9 +117,16 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| if (shape_size_ != 2) { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax only supports 2-D inputs."; | |||
| } | |||
| auto axis = GetAttr<std::vector<int>>(kernel_node, "axis"); | |||
| InitSizeByAxis(input_shape, axis[0]); | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == "LogSoftmax") { | |||
| algo_ = CUDNN_SOFTMAX_LOG; | |||
| auto axis = GetAttr<int>(kernel_node, "axis"); | |||
| InitSizeByAxis(input_shape, axis); | |||
| } else { | |||
| algo_ = CUDNN_SOFTMAX_ACCURATE; | |||
| auto axis = GetAttr<std::vector<int>>(kernel_node, "axis"); | |||
| InitSizeByAxis(input_shape, axis[0]); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), | |||
| SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/softmax_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| LogSoftmaxGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SoftmaxGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| LogSoftmaxGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SoftmaxGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,219 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/kernel_constants.h" | |||
| #include "kernel/gpu/cuda_impl/transpose_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SoftmaxGradGpuKernel : public GpuKernel { | |||
| public: | |||
| SoftmaxGradGpuKernel() | |||
| : cudnn_handle_(nullptr), | |||
| y_desc_(nullptr), | |||
| algo_(CUDNN_SOFTMAX_ACCURATE), | |||
| mode_(CUDNN_SOFTMAX_MODE_INSTANCE), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| axis_(0), | |||
| shape_size_(0), | |||
| batch_size_(0), | |||
| channel_size_(0), | |||
| height_(0), | |||
| width_(0) {} | |||
| ~SoftmaxGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *y_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *dy_addr = GetDeviceAddress<T>(inputs, 1); | |||
| T *dx_addr = GetDeviceAddress<T>(outputs, 0); | |||
| T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0); | |||
| T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1); | |||
| T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2); | |||
| int *input_shape = GetDeviceAddress<int>(workspace, 3); | |||
| int *transpose_shape = GetDeviceAddress<int>(workspace, 4); | |||
| int *transpose_axis = GetDeviceAddress<int>(workspace, 5); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if (axis_ == 1) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_, | |||
| dy_addr, &beta, y_desc_, dx_addr), | |||
| "cudnnSoftmaxBackward failed"); | |||
| } else { | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_axis failed"); | |||
| int size = SizeToInt(input_size_ / sizeof(T)); | |||
| CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, | |||
| y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr), | |||
| "cudnnSoftmaxBackward failed"); | |||
| CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| shape_size_ = SizeToInt(input_shape.size()); | |||
| if (shape_size_ != 2) { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs."; | |||
| } | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == "LogSoftmaxGrad") { | |||
| algo_ = CUDNN_SOFTMAX_LOG; | |||
| auto axis = GetAttr<int>(kernel_node, "axis"); | |||
| InitSizeByAxis(input_shape, axis); | |||
| } else { | |||
| algo_ = CUDNN_SOFTMAX_ACCURATE; | |||
| auto axis = GetAttr<std::vector<int>>(kernel_node, "axis"); | |||
| InitSizeByAxis(input_shape, axis[0]); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), | |||
| SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), | |||
| "set input_descriptor failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(input_size_); | |||
| workspace_size_list_.push_back(input_size_); | |||
| workspace_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| return; | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed"); | |||
| } | |||
| void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) { | |||
| axis_ = axis; | |||
| if (axis_ < 0) { | |||
| axis_ += shape_size_; | |||
| } | |||
| if (axis_ == 1) { | |||
| batch_size_ = input_shape[0]; | |||
| channel_size_ = input_shape[1]; | |||
| } else if (axis_ == 0) { | |||
| batch_size_ = input_shape[1]; | |||
| channel_size_ = input_shape[0]; | |||
| input_shape_.push_back(input_shape[0]); | |||
| input_shape_.push_back(input_shape[1]); | |||
| transpose_shape_.push_back(input_shape[1]); | |||
| transpose_shape_.push_back(input_shape[0]); | |||
| transpose_axis_.push_back(1); | |||
| transpose_axis_.push_back(0); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; | |||
| } | |||
| height_ = 1; | |||
| width_ = 1; | |||
| input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; | |||
| output_size_ = input_size_; | |||
| workspace_size_ = IntToSize(shape_size_) * sizeof(int); | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| cudnnTensorDescriptor_t y_desc_; | |||
| cudnnSoftmaxAlgorithm_t algo_; | |||
| cudnnSoftmaxMode_t mode_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| bool is_null_input_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::vector<int> input_shape_; | |||
| std::vector<int> transpose_shape_; | |||
| std::vector<int> transpose_axis_; | |||
| int axis_; | |||
| int shape_size_; | |||
| size_t batch_size_; | |||
| size_t channel_size_; | |||
| size_t height_; | |||
| size_t width_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,109 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_logsoftmax(): | |||
| x = np.array([[-0.08082921, -0.13706027, -0.4711177, -0.05606057], | |||
| [-0.46082982, 1.1761844, -1.016654, -1.743829 ], | |||
| [-1.5062045, 0.6910976, 0.4839723, 1.1502692 ]]).astype(np.float32) | |||
| expect = np.array([[-1.2939762, -1.3502073, -1.6842647, -1.2692076 ], | |||
| [-1.9445671, -0.3075528, -2.5003912, -3.2275662 ], | |||
| [-3.452001, -1.2546989, -1.4618242, -0.79552734]]).astype(np.float32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| LogSoftmax = P.LogSoftmax() | |||
| output = LogSoftmax(Tensor(x)) | |||
| assert np.allclose(output.asnumpy(), expect) | |||
| class LogSoftmax(nn.Cell): | |||
| def __init__(self, axis=-1): | |||
| super(LogSoftmax, self).__init__() | |||
| self.logsoftmax = P.LogSoftmax(axis) | |||
| def construct(self, x): | |||
| return self.logsoftmax(x) | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, input_data, sens): | |||
| gout = self.grad(self.network)(input_data, sens) | |||
| return gout | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_logsoftmaxgrad(): | |||
| x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655, -0.7725506, 1.4481013 ], | |||
| [ 1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024, -0.27965206, -0.702805 ], | |||
| [ 0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758, -0.4099178, 1.1861311 ], | |||
| [ 1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422, -0.9686862 ], | |||
| [ 1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694, -0.4553867, -1.5423119 ]]).astype(np.float32) | |||
| dy = np.array([[ 1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259, -0.6709239, 0.79757756], | |||
| [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155, 0.758519, -0.25322974], | |||
| [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864, -0.11677749, -1.2131723 ], | |||
| [ 0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179, 0.29770762, -0.16246222], | |||
| [ 0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136, 0.2151897, 0.30908248]]).astype(np.float32) | |||
| expect = np.array([[ 1.4219905 , -0.39837134, 0.5452743 , -0.09062839, -0.02375537, -1.5890603 , 0.10658137, 0.6185817 , -0.7411523 , 0.15054005], | |||
| [-0.94926417, 0.13830578, 0.7609547 , -0.31733334, 1.8485254 , -1.4657221 , 1.2625053 , -1.523396 , 0.601499 , -0.35607445], | |||
| [-0.14447737, -1.0622973 , 0.80294746, -0.32016528, 0.33523226, 0.63443416, 0.23186903, 0.53539133, -0.0633494 , -0.9495847 ], | |||
| [-0.36894822, 0.253609 , -0.5127511 , -0.33366728, -0.18740037, 0.19628316, -0.20430653, 1.1471655 , 0.24743511, -0.23741922], | |||
| [-1.2582518 , 0.57718843, -1.0812542 , 1.4944922 , -0.8770549 , 0.1476463 , 0.40500447, 0.23499368, 0.09027944, 0.26695627]]).astype(np.float32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = LogSoftmax() | |||
| dx = Grad(net)(Tensor(x), Tensor(dy)) | |||
| assert np.allclose(dx[0].asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_logsoftmaxgrad1(): | |||
| x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655, -0.7725506, 1.4481013 ], | |||
| [ 1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024, -0.27965206, -0.702805 ], | |||
| [ 0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758, -0.4099178, 1.1861311 ], | |||
| [ 1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422, -0.9686862 ], | |||
| [ 1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694, -0.4553867, -1.5423119 ]]).astype(np.float32) | |||
| dy = np.array([[ 1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259, -0.6709239, 0.79757756], | |||
| [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155, 0.758519, -0.25322974], | |||
| [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864, -0.11677749, -1.2131723 ], | |||
| [ 0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179, 0.29770762, -0.16246222], | |||
| [ 0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136, 0.2151897, 0.30908248]]).astype(np.float32) | |||
| expect = np.array([[ 1.464194 , -0.29578894, 0.5296974 , -0.39600563, -0.1479242 , -1.0869746 , 0.04521982, 0.5064515 , -0.7515615 , 1.0554069 ], | |||
| [-0.5774203 , 0.793861 , 0.7805745 , -0.32800734, 1.8334473 , -1.236596 , 1.2463496 , -1.5765365 , 0.6265108 , -0.22322391], | |||
| [-0.34437084, -1.4687154 , 0.27432096, -0.42420125, -0.22908019, 0.640983 , -1.4210342 , 0.10155854, -0.23266247, -1.0147638 ], | |||
| [-0.01768187, 0.26872346, -0.5037259 , -0.3376058 , -0.3291146 , 1.4752979 , -0.25972134, 0.8869053 , 0.25325722, -0.13946185], | |||
| [-0.5247209 , 0.70192003, -1.0808672 , 1.4858199 , -1.1273282 , 0.20728993, 0.38918605, 0.08162117, 0.10445589, 0.3220427 ]],).astype(np.float32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = LogSoftmax(0) | |||
| dx = Grad(net)(Tensor(x), Tensor(dy)) | |||
| assert np.allclose(dx[0].asnumpy(), expect) | |||