diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc new file mode 100644 index 0000000000..33f898273a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void SmoothL1LossCPUKernel::InitKernel(const CNodePtr &kernel_node) { + beta_ = AnfAlgo::GetNodeAttr(kernel_node, "beta"); + CheckParam(kernel_node); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (const uint64_t &d : x_shape) { + tensor_size_ *= d; + } +} + +bool SmoothL1LossCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void SmoothL1LossCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto predict_addr = reinterpret_cast(inputs[0]->addr); + auto target_addr = reinterpret_cast(inputs[1]->addr); + auto result_addr = reinterpret_cast(outputs[0]->addr); + T zero = (T)0.0; + T half = (T)0.5; + T beta = (T)beta_; + for (uint64_t i = 0; i < tensor_size_; ++i) { + T diff = predict_addr[i] - target_addr[i]; + if (diff < zero) { + diff = -diff; + } + if (diff < beta) { + result_addr[i] = half * diff * diff / beta; + } else { + result_addr[i] = diff - (half * beta); + } + } +} + +void SmoothL1LossCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossCPUKernel needs 2 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossCPUKernel needs 1 output."; + } + if (beta_ == 0.0) { + MS_LOG(EXCEPTION) << "Attr beta can not be zero."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h new file mode 100644 index 0000000000..321322a3de --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SmoothL1LossCPUKernel : public CPUKernel { + public: + SmoothL1LossCPUKernel() = default; + ~SmoothL1LossCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + float beta_ = 1.0; + TypeId dtype_{kTypeUnknown}; + uint64_t tensor_size_ = 1; +}; + +MS_REG_CPU_KERNEL( + SmoothL1Loss, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SmoothL1LossCPUKernel); + +MS_REG_CPU_KERNEL( + SmoothL1Loss, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SmoothL1LossCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc new file mode 100644 index 0000000000..a33acfc41c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void SmoothL1LossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + beta_ = AnfAlgo::GetNodeAttr(kernel_node, "beta"); + CheckParam(kernel_node); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (const uint64_t &d : x_shape) { + tensor_size_ *= d; + } +} + +bool SmoothL1LossGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void SmoothL1LossGradCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto predict_addr = reinterpret_cast(inputs[0]->addr); + auto target_addr = reinterpret_cast(inputs[1]->addr); + auto dloss_addr = reinterpret_cast(inputs[2]->addr); + auto result_addr = reinterpret_cast(outputs[0]->addr); + T beta = (T)beta_; + for (uint64_t i = 0; i < tensor_size_; ++i) { + T diff = predict_addr[i] - target_addr[i]; + if (diff > beta) { + result_addr[i] = dloss_addr[i]; + } else if (diff < -beta) { + result_addr[i] = -dloss_addr[i]; + } else { + result_addr[i] = (diff / beta) * dloss_addr[i]; + } + } +} + +void SmoothL1LossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SmoothL1LossGradCPUKernel needs 3 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SmoothL1LossGradCPUKernel needs 1 output."; + } + if (beta_ == 0.0) { + MS_LOG(EXCEPTION) << "Attr beta can not be zero."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h new file mode 100644 index 0000000000..a703e33b6e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_grad_cpu_kernel.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SmoothL1LossGradCPUKernel : public CPUKernel { + public: + SmoothL1LossGradCPUKernel() = default; + ~SmoothL1LossGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + float beta_ = 1.0; + TypeId dtype_{kTypeUnknown}; + uint64_t tensor_size_ = 1; +}; + +MS_REG_CPU_KERNEL(SmoothL1LossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SmoothL1LossGradCPUKernel); + +MS_REG_CPU_KERNEL(SmoothL1LossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SmoothL1LossGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SMOOTH_L1_LOSS_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.cc new file mode 100644 index 0000000000..8d0f7f6bc3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/tile_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + y_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + multiples_ = AnfAlgo::GetNodeAttr>(kernel_node, "multiples"); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); +} + +bool TileCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt64) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void TileRecTask(T *x, T *y, size_t dim, size_t *offset, std::vector *pos, const std::vector &multiples, + const std::vector &cargo_x, const std::vector &cargo_y, + const std::vector &x_shape) { + if (dim == x_shape.size()) { + return; + } + for (size_t i = 0; i < x_shape[dim]; ++i) { + (*pos)[dim] = i; + if (dim == x_shape.size() - 1) { + size_t x_offset = 0; + for (size_t j = 0; j < (*pos).size(); ++j) { + x_offset += (*pos)[j] * cargo_x[j]; + } + memcpy(y + *offset, x + x_offset, sizeof(T)); + *offset += 1; + continue; + } + TileRecTask(x, y, dim + 1, offset, pos, multiples, cargo_x, cargo_y, x_shape); + } + for (int m = 0; m < multiples[dim] - 1; ++m) { + size_t y_offset = *offset - cargo_y[dim]; + memcpy(y + *offset, y + y_offset, cargo_y[dim] * sizeof(T)); + *offset += cargo_y[dim]; + } +} + +template +void TileCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto x_addr = reinterpret_cast(inputs[0]->addr); + auto y_addr = reinterpret_cast(outputs[0]->addr); + size_t ones = multiples_.size() - x_shape_.size(); + if (ones > 0) { + for (size_t i = 0; i < ones; ++i) { + x_shape_.insert(x_shape_.begin(), 1); + } + } + int d = multiples_.size(); + std::vector pos(d, 0); + std::vector cargo_x(d, 1); + std::vector cargo_y = x_shape_; + for (int i = d - 2; i >= 0; --i) { + cargo_x[i] = x_shape_[i + 1] * cargo_x[i]; + cargo_y[i] *= cargo_y[i + 1] * multiples_[i + 1]; + } + size_t offset = 0; + TileRecTask(x_addr, y_addr, 0, &offset, &pos, multiples_, cargo_x, cargo_y, x_shape_); +} + +void TileCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TileCPUKernel needs 1 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TileCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.h new file mode 100644 index 0000000000..240953f30b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class TileCPUKernel : public CPUKernel { + public: + TileCPUKernel() = default; + ~TileCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector x_shape_; + std::vector y_shape_; + std::vector multiples_; + TypeId dtype_{kTypeUnknown}; +}; + +MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel); + +MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel); + +MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel); + +MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 24c3784afa..43405c23f6 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -3959,7 +3959,7 @@ class GatherD(PrimitiveWithInfer): - **x** (Tensor) - The source tensor. - **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed. - **index** (Tensor) - The indices of elements to gather. It can be one of the following data types: - int32, int64. + int32, int64. Outputs: Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. diff --git a/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py b/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py new file mode 100644 index 0000000000..f72dbe4590 --- /dev/null +++ b/tests/st/ops/cpu/test_smooth_l1_loss_grad_op.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.composite import GradOperation + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self, sigma=1.0): + super(Net, self).__init__() + self.SmoothL1Loss = P.SmoothL1Loss(sigma) + + def construct(self, pred, gt): + return self.SmoothL1Loss(pred, gt) + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, pred, gt, dout): + return self.grad(self.network)(pred, gt, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_net(): + pred = np.random.randn(2, 4).astype(np.float32) + gt = np.random.randn(2, 4).astype(np.float32) + dout = np.random.randn(2, 4).astype(np.float32) + smooth_l1_loss_grad = Grad(Net()) + output = smooth_l1_loss_grad(Tensor(pred), Tensor(gt), Tensor(dout)) + print("------------- input ---------------") + print("predict:\n", pred) + print("grount truth:\n", gt) + print("dout:\n", dout) + print("------------- output ---------------") + print("predict grad:\n", output[0].asnumpy()) diff --git a/tests/st/ops/cpu/test_smooth_l1_loss_op.py b/tests/st/ops/cpu/test_smooth_l1_loss_op.py new file mode 100644 index 0000000000..f0fe298ff7 --- /dev/null +++ b/tests/st/ops/cpu/test_smooth_l1_loss_op.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self, sigma=1.0): + super(Net, self).__init__() + self.SmoothL1Loss = P.SmoothL1Loss(sigma) + + def construct(self, pred, gt): + return self.SmoothL1Loss(pred, gt) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_net(): + pred = np.random.randn(2, 4).astype(np.float32) + gt = np.random.randn(2, 4).astype(np.float32) + smooth_l1_loss = Net() + loss = smooth_l1_loss(Tensor(pred), Tensor(gt)) + print("------------- input ---------------") + print("predict:\n", pred) + print("grount truth:\n", gt) + print("------------- output ---------------") + print("loss:\n", loss.asnumpy()) diff --git a/tests/st/ops/cpu/test_tile_op.py b/tests/st/ops/cpu/test_tile_op.py new file mode 100644 index 0000000000..6ea52f10ce --- /dev/null +++ b/tests/st/ops/cpu/test_tile_op.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.tile = P.Tile() + + def construct(self, x): + return self.tile(x, (1, 4)) + + +arr_x = np.array([[0], [1], [2], [3]]).astype(np.int32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_net(): + tile = Net() + print(arr_x) + output = tile(Tensor(arr_x)) + print(output.asnumpy())