Browse Source

!9781 fix tile op error when input is float16

From: @chujinjin
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8973830a4e
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.cc

+ 4
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/tile_cpu_kernel.cc View File

@@ -27,7 +27,10 @@ void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int64_t> multiples_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "multiples"); std::vector<int64_t> multiples_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "multiples");
(void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_), (void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_),
[](const int64_t &value) { return static_cast<int>(value); }); [](const int64_t &value) { return static_cast<int>(value); });
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
if (dtype_ == kTypeUnknown) {
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}
} }


bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,


Loading…
Cancel
Save