Browse Source

[CPU] Add TensorScatterUpdate and Print Ops

tags/v1.3.0
zhanyuan 5 years ago
parent
commit
ff02098cd4
3 changed files with 179 additions and 0 deletions
  1. +90
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/print_cpu_kernel.cc
  2. +81
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/print_cpu_kernel.h
  3. +8
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h

+ 90
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/print_cpu_kernel.cc View File

@@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/print_cpu_kernel.h"
#include <algorithm>
#include "ir/tensor.h"
#include "runtime/device/cpu/cpu_device_address.h"

using mindspore::tensor::Tensor;

namespace mindspore {
namespace kernel {
template <typename T>
void PrintCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_tensor_num; ++i) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
input_shapes_.emplace_back(input_shape);
size_t size = input_shape.size() ? 1 : 0;
for (size_t j = 0; j < input_shape.size(); ++j) {
size *= input_shape[j];
}
input_sizes_.emplace_back(size);
}
}

template <typename T>
bool PrintCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> & /*outputs*/) {
auto data_type = CheckType();
if (data_type == kTypeUnknown) {
MS_LOG(EXCEPTION) << "CPU print does not support the input type.";
}
for (size_t i = 0; i < inputs.size(); ++i) {
if (input_sizes_[i] == 0) {
auto num = reinterpret_cast<T *>(inputs[i]->addr);
std::cout << *num << std::endl;
} else {
ShapeVector shape;
std::transform(input_shapes_[i].begin(), input_shapes_[i].end(), std::back_inserter(shape),
[](const size_t &value) { return static_cast<int64_t>(value); });
Tensor tensor(data_type, shape, inputs[i]->addr, input_sizes_[i] * sizeof(T));
std::cout << tensor.ToStringNoLimit() << std::endl;
}
}
return true;
}

template <typename T>
TypeId PrintCPUKernel<T>::CheckType() {
if constexpr (std::is_same_v<T, bool>) {
return kNumberTypeBool;
} else if constexpr (std::is_same_v<T, int8_t>) {
return kNumberTypeInt8;
} else if constexpr (std::is_same_v<T, int16_t>) {
return kNumberTypeInt16;
} else if constexpr (std::is_same_v<T, int>) {
return kNumberTypeInt32;
} else if constexpr (std::is_same_v<T, int64_t>) {
return kNumberTypeInt64;
} else if constexpr (std::is_same_v<T, uint8_t>) {
return kNumberTypeUInt8;
} else if constexpr (std::is_same_v<T, uint16_t>) {
return kNumberTypeUInt16;
} else if constexpr (std::is_same_v<T, uint32_t>) {
return kNumberTypeUInt32;
} else if constexpr (std::is_same_v<T, uint64_t>) {
return kNumberTypeUInt64;
} else if constexpr (std::is_same_v<T, float16>) {
return kNumberTypeFloat16;
} else if constexpr (std::is_same_v<T, float>) {
return kNumberTypeFloat32;
}
return kTypeUnknown;
}
} // namespace kernel
} // namespace mindspore

+ 81
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/print_cpu_kernel.h View File

@@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PRINT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PRINT_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class PrintCPUKernel : public CPUKernel {
public:
PrintCPUKernel() = default;
~PrintCPUKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override;

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

void LaunchKernel(const std::vector<AddressPtr> &inputs);

TypeId CheckType();

private:
std::vector<std::vector<size_t>> input_shapes_;
std::vector<size_t> input_sizes_;
};

MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, bool)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, int8_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, int16_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, int)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, int64_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, uint8_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, uint16_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, uint32_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, uint64_t)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, float16)
MS_REG_CPU_KERNEL_T(Print,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
PrintCPUKernel, float)
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PRINT_CPU_KERNEL_H_

+ 8
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h View File

@@ -64,6 +64,14 @@ MS_REG_CPU_KERNEL(ScatterNdUpdate,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterNdUpdateCPUKernel);

MS_REG_CPU_KERNEL(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterNdUpdateCPUKernel);
} // namespace kernel
} // namespace mindspore



Loading…
Cancel
Save