| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void SmoothL1LossCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| beta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "beta"); | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (const uint64_t &d : x_shape) { | |||
| tensor_size_ *= d; | |||
| } | |||
| } | |||
| bool SmoothL1LossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void SmoothL1LossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto predict_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto target_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto result_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| T zero = (T)0.0; | |||
| T half = (T)0.5; | |||
| T beta = (T)beta_; | |||
| for (uint64_t i = 0; i < tensor_size_; ++i) { | |||
| T diff = predict_addr[i] - target_addr[i]; | |||
| if (diff < zero) { | |||
| diff = -diff; | |||
| } | |||
| if (diff < beta) { | |||
| result_addr[i] = half * diff * diff / beta; | |||
| } else { | |||
| result_addr[i] = diff - (half * beta); | |||
| } | |||
| } | |||
| } | |||
| void SmoothL1LossCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossCPUKernel needs 2 input."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossCPUKernel needs 1 output."; | |||
| } | |||
| if (beta_ == 0.0) { | |||
| MS_LOG(EXCEPTION) << "Attr beta can not be zero."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SmoothL1LossCPUKernel : public CPUKernel { | |||
| public: | |||
| SmoothL1LossCPUKernel() = default; | |||
| ~SmoothL1LossCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| float beta_ = 1.0; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| uint64_t tensor_size_ = 1; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| SmoothL1Loss, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SmoothL1LossCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| SmoothL1Loss, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SmoothL1LossCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void SmoothL1LossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| beta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "beta"); | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (const uint64_t &d : x_shape) { | |||
| tensor_size_ *= d; | |||
| } | |||
| } | |||
| bool SmoothL1LossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void SmoothL1LossGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto predict_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto target_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto dloss_addr = reinterpret_cast<T *>(inputs[2]->addr); | |||
| auto result_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| T beta = (T)beta_; | |||
| for (uint64_t i = 0; i < tensor_size_; ++i) { | |||
| T diff = predict_addr[i] - target_addr[i]; | |||
| if (diff > beta) { | |||
| result_addr[i] = dloss_addr[i]; | |||
| } else if (diff < -beta) { | |||
| result_addr[i] = -dloss_addr[i]; | |||
| } else { | |||
| result_addr[i] = (diff / beta) * dloss_addr[i]; | |||
| } | |||
| } | |||
| } | |||
| void SmoothL1LossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossGradCPUKernel needs 3 input."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossGradCPUKernel needs 1 output."; | |||
| } | |||
| if (beta_ == 0.0) { | |||
| MS_LOG(EXCEPTION) << "Attr beta can not be zero."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SmoothL1LossGradCPUKernel : public CPUKernel { | |||
| public: | |||
| SmoothL1LossGradCPUKernel() = default; | |||
| ~SmoothL1LossGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| float beta_ = 1.0; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| uint64_t tensor_size_ = 1; | |||
| }; | |||
| MS_REG_CPU_KERNEL(SmoothL1LossGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| SmoothL1LossGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(SmoothL1LossGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SmoothL1LossGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/tile_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| y_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| multiples_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "multiples"); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| } | |||
| bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeInt32) { | |||
| LaunchKernel<int>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| LaunchKernel<int64_t>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void TileRecTask(T *x, T *y, size_t dim, size_t *offset, std::vector<size_t> *pos, const std::vector<int> &multiples, | |||
| const std::vector<size_t> &cargo_x, const std::vector<size_t> &cargo_y, | |||
| const std::vector<size_t> &x_shape) { | |||
| if (dim == x_shape.size()) { | |||
| return; | |||
| } | |||
| for (size_t i = 0; i < x_shape[dim]; ++i) { | |||
| (*pos)[dim] = i; | |||
| if (dim == x_shape.size() - 1) { | |||
| size_t x_offset = 0; | |||
| for (size_t j = 0; j < (*pos).size(); ++j) { | |||
| x_offset += (*pos)[j] * cargo_x[j]; | |||
| } | |||
| memcpy(y + *offset, x + x_offset, sizeof(T)); | |||
| *offset += 1; | |||
| continue; | |||
| } | |||
| TileRecTask(x, y, dim + 1, offset, pos, multiples, cargo_x, cargo_y, x_shape); | |||
| } | |||
| for (int m = 0; m < multiples[dim] - 1; ++m) { | |||
| size_t y_offset = *offset - cargo_y[dim]; | |||
| memcpy(y + *offset, y + y_offset, cargo_y[dim] * sizeof(T)); | |||
| *offset += cargo_y[dim]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void TileCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto x_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto y_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t ones = multiples_.size() - x_shape_.size(); | |||
| if (ones > 0) { | |||
| for (size_t i = 0; i < ones; ++i) { | |||
| x_shape_.insert(x_shape_.begin(), 1); | |||
| } | |||
| } | |||
| int d = multiples_.size(); | |||
| std::vector<size_t> pos(d, 0); | |||
| std::vector<size_t> cargo_x(d, 1); | |||
| std::vector<size_t> cargo_y = x_shape_; | |||
| for (int i = d - 2; i >= 0; --i) { | |||
| cargo_x[i] = x_shape_[i + 1] * cargo_x[i]; | |||
| cargo_y[i] *= cargo_y[i + 1] * multiples_[i + 1]; | |||
| } | |||
| size_t offset = 0; | |||
| TileRecTask<T>(x_addr, y_addr, 0, &offset, &pos, multiples_, cargo_x, cargo_y, x_shape_); | |||
| } | |||
| void TileCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TileCPUKernel needs 1 input."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TileCPUKernel needs 1 output."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class TileCPUKernel : public CPUKernel { | |||
| public: | |||
| TileCPUKernel() = default; | |||
| ~TileCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| std::vector<size_t> x_shape_; | |||
| std::vector<size_t> y_shape_; | |||
| std::vector<int> multiples_; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ | |||
| @@ -3959,7 +3959,7 @@ class GatherD(PrimitiveWithInfer): | |||
| - **x** (Tensor) - The source tensor. | |||
| - **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed. | |||
| - **index** (Tensor) - The indices of elements to gather. It can be one of the following data types: | |||
| int32, int64. | |||
| int32, int64. | |||
| Outputs: | |||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.composite import GradOperation | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class Net(nn.Cell): | |||
| def __init__(self, sigma=1.0): | |||
| super(Net, self).__init__() | |||
| self.SmoothL1Loss = P.SmoothL1Loss(sigma) | |||
| def construct(self, pred, gt): | |||
| return self.SmoothL1Loss(pred, gt) | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, pred, gt, dout): | |||
| return self.grad(self.network)(pred, gt, dout) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| pred = np.random.randn(2, 4).astype(np.float32) | |||
| gt = np.random.randn(2, 4).astype(np.float32) | |||
| dout = np.random.randn(2, 4).astype(np.float32) | |||
| smooth_l1_loss_grad = Grad(Net()) | |||
| output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout)) | |||
| print("------------- input ---------------") | |||
| print("predict:\n", pred) | |||
| print("grount truth:\n", gt) | |||
| print("dout:\n", dout) | |||
| print("------------- output ---------------") | |||
| print("predict grad:\n", output[0].asnumpy()) | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class Net(nn.Cell): | |||
| def __init__(self, sigma=1.0): | |||
| super(Net, self).__init__() | |||
| self.SmoothL1Loss = P.SmoothL1Loss(sigma) | |||
| def construct(self, pred, gt): | |||
| return self.SmoothL1Loss(pred, gt) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| pred = np.random.randn(2, 4).astype(np.float32) | |||
| gt = np.random.randn(2, 4).astype(np.float32) | |||
| smooth_l1_loss = Net() | |||
| loss = smooth_l1_loss(Tensor(pred), Tensor(gt)) | |||
| print("------------- input ---------------") | |||
| print("predict:\n", pred) | |||
| print("grount truth:\n", gt) | |||
| print("------------- output ---------------") | |||
| print("loss:\n", loss.asnumpy()) | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.tile = P.Tile() | |||
| def construct(self, x): | |||
| return self.tile(x, (1, 4)) | |||
| arr_x = np.array([[0], [1], [2], [3]]).astype(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| tile = Net() | |||
| print(arr_x) | |||
| output = tile(Tensor(arr_x)) | |||
| print(output.asnumpy()) | |||