| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -43,9 +43,30 @@ class ConcatCPUKernel : public CPUKernel { | |||||
| MS_REG_CPU_KERNEL_T( | MS_REG_CPU_KERNEL_T( | ||||
| Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| ConcatCPUKernel, float); | ConcatCPUKernel, float); | ||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), | |||||
| ConcatCPUKernel, int8_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| ConcatCPUKernel, int16_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | MS_REG_CPU_KERNEL_T(Concat, | ||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ConcatCPUKernel, int) | ConcatCPUKernel, int) | ||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| ConcatCPUKernel, int64_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ConcatCPUKernel, uint8_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| ConcatCPUKernel, uint16_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| ConcatCPUKernel, uint32_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| ConcatCPUKernel, uint64_t) | |||||
| MS_REG_CPU_KERNEL_T(Concat, | MS_REG_CPU_KERNEL_T(Concat, | ||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | ||||
| ConcatCPUKernel, bool) | ConcatCPUKernel, bool) | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -38,23 +38,47 @@ class ReshapeCPUKernel : public CPUKernel { | |||||
| size_t type_size_ = 4; | size_t type_size_ = 4; | ||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| @@ -62,14 +86,36 @@ MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOut | |||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | ||||
| ReshapeCPUKernel); | ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| ReshapeCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| ReshapeCPUKernel); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -31,18 +31,30 @@ void TileCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (dtype_ == kTypeUnknown) { | if (dtype_ == kTypeUnknown) { | ||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| } | } | ||||
| launch_map_[kNumberTypeInt8] = &TileCPUKernel::LaunchKernel<int8_t>; | |||||
| launch_map_[kNumberTypeInt16] = &TileCPUKernel::LaunchKernel<int16_t>; | |||||
| launch_map_[kNumberTypeInt32] = &TileCPUKernel::LaunchKernel<int>; | |||||
| launch_map_[kNumberTypeInt64] = &TileCPUKernel::LaunchKernel<int64_t>; | |||||
| launch_map_[kNumberTypeUInt8] = &TileCPUKernel::LaunchKernel<uint8_t>; | |||||
| launch_map_[kNumberTypeUInt16] = &TileCPUKernel::LaunchKernel<uint16_t>; | |||||
| launch_map_[kNumberTypeUInt32] = &TileCPUKernel::LaunchKernel<uint32_t>; | |||||
| launch_map_[kNumberTypeUInt64] = &TileCPUKernel::LaunchKernel<uint64_t>; | |||||
| launch_map_[kNumberTypeFloat32] = &TileCPUKernel::LaunchKernel<float>; | |||||
| launch_map_[kNumberTypeBool] = &TileCPUKernel::LaunchKernel<bool>; | |||||
| auto iter = launch_map_.find(dtype_); | |||||
| if (iter != launch_map_.end()) { | |||||
| launch_func_ = iter->second; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for Tile kernel on CPU."; | |||||
| } | |||||
| } | } | ||||
| bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool TileCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeInt32) { | |||||
| LaunchKernel<int>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| LaunchKernel<float>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeInt64) { | |||||
| LaunchKernel<int64_t>(inputs, outputs); | |||||
| } | |||||
| launch_func_(this, inputs, outputs); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -43,14 +43,30 @@ class TileCPUKernel : public CPUKernel { | |||||
| std::vector<size_t> y_shape_; | std::vector<size_t> y_shape_; | ||||
| std::vector<int> multiples_; | std::vector<int> multiples_; | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| using TypeKernel = | |||||
| std::function<void(TileCPUKernel *, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs)>; | |||||
| std::unordered_map<TypeId, TypeKernel> launch_map_; | |||||
| TypeKernel launch_func_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel); | MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel); | MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel); | MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), TileCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), TileCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel); | MS_REG_CPU_KERNEL(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -29,13 +29,36 @@ void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| if (shape_.size() != axis_.size()) { | if (shape_.size() != axis_.size()) { | ||||
| MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal."; | MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal."; | ||||
| } | } | ||||
| dtype_ = AnfAlgo ::GetPrevNodeOutputDeviceDataType(kernel_node, 0); | |||||
| if (dtype_ == kTypeUnknown) { | |||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| } | |||||
| launch_map_[kNumberTypeInt8] = &TransposeCPUFwdKernel::LaunchKernel<int8_t>; | |||||
| launch_map_[kNumberTypeInt16] = &TransposeCPUFwdKernel::LaunchKernel<int16_t>; | |||||
| launch_map_[kNumberTypeInt32] = &TransposeCPUFwdKernel::LaunchKernel<int>; | |||||
| launch_map_[kNumberTypeInt64] = &TransposeCPUFwdKernel::LaunchKernel<int64_t>; | |||||
| launch_map_[kNumberTypeUInt8] = &TransposeCPUFwdKernel::LaunchKernel<uint8_t>; | |||||
| launch_map_[kNumberTypeUInt16] = &TransposeCPUFwdKernel::LaunchKernel<uint16_t>; | |||||
| launch_map_[kNumberTypeUInt32] = &TransposeCPUFwdKernel::LaunchKernel<uint32_t>; | |||||
| launch_map_[kNumberTypeUInt64] = &TransposeCPUFwdKernel::LaunchKernel<uint64_t>; | |||||
| launch_map_[kNumberTypeFloat32] = &TransposeCPUFwdKernel::LaunchKernel<float>; | |||||
| launch_map_[kNumberTypeBool] = &TransposeCPUFwdKernel::LaunchKernel<bool>; | |||||
| auto iter = launch_map_.find(dtype_); | |||||
| if (iter != launch_map_.end()) { | |||||
| launch_func_ = iter->second; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for Transpose kernel on CPU."; | |||||
| } | |||||
| } | } | ||||
| bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| auto input = reinterpret_cast<float *>(inputs[0]->addr); | |||||
| auto output = reinterpret_cast<float *>(outputs[0]->addr); | |||||
| size_t size = IntToSize(inputs[0]->size / sizeof(float)); | |||||
| template <typename T> | |||||
| void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| const std::vector<AddressPtr> &outputs) { | |||||
| auto input = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| auto output = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| size_t size = IntToSize(inputs[0]->size / sizeof(T)); | |||||
| size_t shape_size = IntToSize(shape_.size()); | size_t shape_size = IntToSize(shape_.size()); | ||||
| if (shape_size > kMaxDim) { | if (shape_size > kMaxDim) { | ||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs."; | MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs."; | ||||
| @@ -61,7 +84,14 @@ bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs | |||||
| } | } | ||||
| output[new_position] = input[position]; | output[new_position] = input[position]; | ||||
| } | } | ||||
| } | |||||
| bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| launch_func_(this, inputs, outputs); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -16,6 +16,7 @@ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ | #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ | ||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ | #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ | ||||
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
| @@ -32,12 +33,47 @@ class TransposeCPUFwdKernel : public CPUKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| template <typename T> | |||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||||
| private: | private: | ||||
| std::vector<size_t> shape_; | std::vector<size_t> shape_; | ||||
| std::vector<int> axis_; | std::vector<int> axis_; | ||||
| TypeId dtype_{kTypeUnknown}; | |||||
| using TypeKernel = | |||||
| std::function<void(TransposeCPUFwdKernel *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>; | |||||
| std::unordered_map<TypeId, TypeKernel> launch_map_; | |||||
| TypeKernel launch_func_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||||
| TransposeCPUFwdKernel); | |||||
| MS_REG_CPU_KERNEL(Transpose, | |||||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| TransposeCPUFwdKernel); | TransposeCPUFwdKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -46,6 +46,7 @@ def axis10(nptype): | |||||
| print(output) | print(output) | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| @@ -171,6 +172,7 @@ def axis21(nptype): | |||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| print(output) | print(output) | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| @@ -287,6 +289,18 @@ def test_concat_4i_float32(): | |||||
| def test_concat_4i_int32(): | def test_concat_4i_int32(): | ||||
| concat_4i(np.int32) | concat_4i(np.int32) | ||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_int8():
    """Run the shared ``concat_4i`` scenario with int8 data."""
    dtype = np.int8
    concat_4i(dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_4i_uint64():
    """Run the shared ``concat_4i`` scenario with uint64 data."""
    dtype = np.uint64
    concat_4i(dtype)
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -40,7 +40,8 @@ def test_squeeze_shape_float32(): | |||||
| expect = np.ones(shape=[2, 8, 3]).astype(np.float32) | expect = np.ones(shape=[2, 8, 3]).astype(np.float32) | ||||
| net = SqueezeNet() | net = SqueezeNet() | ||||
| result = net(Tensor(x)) | result = net(Tensor(x)) | ||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, | |||||
| atol=1.e-8, equal_nan=True) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -51,7 +52,8 @@ def test_squeeze_shape_int32(): | |||||
| expect = np.array([7, 11]).astype(np.int32) | expect = np.array([7, 11]).astype(np.int32) | ||||
| net = SqueezeNet() | net = SqueezeNet() | ||||
| result = net(Tensor(x)) | result = net(Tensor(x)) | ||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, | |||||
| atol=1.e-8, equal_nan=True) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @@ -62,4 +64,31 @@ def test_squeeze_shape_bool(): | |||||
| expect = np.array([True, False]).astype(np.bool_) | expect = np.array([True, False]).astype(np.bool_) | ||||
| net = SqueezeNet() | net = SqueezeNet() | ||||
| result = net(Tensor(x)) | result = net(Tensor(x)) | ||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) | |||||
| assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, | |||||
| atol=1.e-8, equal_nan=True) | |||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_squeeze_shape_float64():
    """Squeeze a random float64 tensor with SqueezeNet and compare against ``np.squeeze``."""
    data = np.random.random([1, 2, 1, 1, 8, 3, 1]).astype(np.float64)
    reference = np.squeeze(data)
    squeeze_net = SqueezeNet()
    actual = squeeze_net(Tensor(data))
    # Print the first row of both arrays to ease debugging when the comparison fails.
    print(actual.asnumpy()[0][0], reference[0][0])
    assert np.allclose(actual.asnumpy(), reference, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_squeeze_shape_uint16():
    """Squeeze a random uint16 tensor with SqueezeNet and compare against ``np.squeeze``."""
    # np.random.random() only yields values in [0, 1), which all truncate to 0 when cast to
    # uint16; draw integers instead so the comparison can actually detect misplaced elements.
    x = np.random.randint(0, 1000, size=[1, 2, 1, 1, 8, 3, 1]).astype(np.uint16)
    expect = np.squeeze(x)
    net = SqueezeNet()
    result = net(Tensor(x))
    print(result.asnumpy()[0][0], expect[0][0])
    assert np.allclose(result.asnumpy(), expect, rtol=1.e-4,
                       atol=1.e-8, equal_nan=True)
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -43,3 +43,29 @@ def test_net(): | |||||
| print(arr_x) | print(arr_x) | ||||
| output = tile(Tensor(arr_x)) | output = tile(Tensor(arr_x)) | ||||
| print(output.asnumpy()) | print(output.asnumpy()) | ||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_float64():
    """Run Net on a 4x1 float64 input and print the result."""
    # Build the input inside the test. Module-level assignments to a shared ``arr_x`` all run at
    # import time, so by the time pytest calls the tests only the last assigned value is visible.
    x = np.array([[0], [1], [2], [3]]).astype(np.float64)
    tile = Net()
    print(x)
    output = tile(Tensor(x))
    print(output.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_bool():
    """Run Net on a 4x1 bool input and print the result."""
    # Build the input inside the test instead of rebinding a module-level ``arr_x``; a rebinding
    # at import time would also change the data seen by every other test reading that global.
    x = np.array([[0], [1], [2], [3]]).astype(np.bool_)
    tile = Net()
    print(x)
    output = tile(Tensor(x))
    print(output.asnumpy())
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -40,7 +40,8 @@ class Transpose(nn.Cell): | |||||
| self.perm_3D = (1, 0, 2) | self.perm_3D = (1, 0, 2) | ||||
| self.x_4D = Parameter( | self.x_4D = Parameter( | ||||
| initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]), | |||||
| initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, | |||||
| 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]), | |||||
| name='x_4D') | name='x_4D') | ||||
| self.perm_4D = (0, 1, 2, 3) | self.perm_4D = (0, 1, 2, 3) | ||||
| @@ -145,3 +146,247 @@ def test_transpose(): | |||||
| test_transpose() | test_transpose() | ||||
| class Transpose_int64(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Transpose_int64, self).__init__() | |||||
| self.transpose = P.Transpose() | |||||
| self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.int64)), [5, 6]), | |||||
| name='x_2D') | |||||
| self.perm_2D = (1, 0) | |||||
| self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.int64)), [2, 2, 4]), | |||||
| name='x_3D') | |||||
| self.perm_3D = (1, 0, 2) | |||||
| self.x_4D = Parameter( | |||||
| initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, | |||||
| 3, 4, 5).astype(np.int64)), [2, 3, 4, 5]), | |||||
| name='x_4D') | |||||
| self.perm_4D = (0, 1, 2, 3) | |||||
| self.x_5D = Parameter( | |||||
| initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.int64)), | |||||
| [1, 2, 3, 4, 5]), name='x_5D') | |||||
| self.perm_5D = (1, 0, 3, 4, 2) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), | |||||
| self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_transpose_int64(): | |||||
| transpose = Transpose_int64() | |||||
| output = transpose() | |||||
| expect0 = np.array([[[0, 6, 12, 18, 24], | |||||
| [1, 7, 13, 19, 25], | |||||
| [2, 8, 14, 20, 26], | |||||
| [3, 9, 15, 21, 27], | |||||
| [4, 10, 16, 22, 28], | |||||
| [5, 11, 17, 23, 29]]]).astype(np.int64) | |||||
| expect1 = np.array([[[[0, 1, 2, 3], | |||||
| [8, 9, 10, 11]], | |||||
| [[4, 5, 6, 7], | |||||
| [12, 13, 14, 15]]]]).astype(np.int64) | |||||
| expect2 = np.array([[[[[0, 1, 2, 3, 4], | |||||
| [5, 6, 7, 8, 9], | |||||
| [10, 11, 12, 13, 14], | |||||
| [15, 16, 17, 18, 19]], | |||||
| [[20, 21, 22, 23, 24], | |||||
| [25, 26, 27, 28, 29], | |||||
| [30, 31, 32, 33, 34], | |||||
| [35, 36, 37, 38, 39]], | |||||
| [[40, 41, 42, 43, 44], | |||||
| [45, 46, 47, 48, 49], | |||||
| [50, 51, 52, 53, 54], | |||||
| [55, 56, 57, 58, 59]]], | |||||
| [[[60, 61, 62, 63, 64], | |||||
| [65, 66, 67, 68, 69], | |||||
| [70, 71, 72, 73, 74], | |||||
| [75, 76, 77, 78, 79]], | |||||
| [[80, 81, 82, 83, 84], | |||||
| [85, 86, 87, 88, 89], | |||||
| [90, 91, 92, 93, 94], | |||||
| [95, 96, 97, 98, 99]], | |||||
| [[100, 101, 102, 103, 104], | |||||
| [105, 106, 107, 108, 109], | |||||
| [110, 111, 112, 113, 114], | |||||
| [115, 116, 117, 118, 119]]]]]).astype(np.int64) | |||||
| expect3 = np.array([[[[[[0, 20, 40], | |||||
| [1, 21, 41], | |||||
| [2, 22, 42], | |||||
| [3, 23, 43], | |||||
| [4, 24, 44]], | |||||
| [[5, 25, 45], | |||||
| [6, 26, 46], | |||||
| [7, 27, 47], | |||||
| [8, 28, 48], | |||||
| [9, 29, 49]], | |||||
| [[10, 30, 50], | |||||
| [11, 31, 51], | |||||
| [12, 32, 52], | |||||
| [13, 33, 53], | |||||
| [14, 34, 54]], | |||||
| [[15, 35, 55], | |||||
| [16, 36, 56], | |||||
| [17, 37, 57], | |||||
| [18, 38, 58], | |||||
| [19, 39, 59]]]], | |||||
| [[[[60, 80, 100], | |||||
| [61, 81, 101], | |||||
| [62, 82, 102], | |||||
| [63, 83, 103], | |||||
| [64, 84, 104]], | |||||
| [[65, 85, 105], | |||||
| [66, 86, 106], | |||||
| [67, 87, 107], | |||||
| [68, 88, 108], | |||||
| [69, 89, 109]], | |||||
| [[70, 90, 110], | |||||
| [71, 91, 111], | |||||
| [72, 92, 112], | |||||
| [73, 93, 113], | |||||
| [74, 94, 114]], | |||||
| [[75, 95, 115], | |||||
| [76, 96, 116], | |||||
| [77, 97, 117], | |||||
| [78, 98, 118], | |||||
| [79, 99, 119]]]]]]).astype(np.int64) | |||||
| assert (output[0].asnumpy() == expect0).all() | |||||
| assert (output[1].asnumpy() == expect1).all() | |||||
| assert (output[2].asnumpy() == expect2).all() | |||||
| assert (output[3].asnumpy() == expect3).all() | |||||
| test_transpose_int64() | |||||
| class Transpose_uint8(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Transpose_uint8, self).__init__() | |||||
| self.transpose = P.Transpose() | |||||
| self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.uint8)), [5, 6]), | |||||
| name='x_2D') | |||||
| self.perm_2D = (1, 0) | |||||
| self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.uint8)), [2, 2, 4]), | |||||
| name='x_3D') | |||||
| self.perm_3D = (1, 0, 2) | |||||
| self.x_4D = Parameter( | |||||
| initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, | |||||
| 3, 4, 5).astype(np.uint8)), [2, 3, 4, 5]), | |||||
| name='x_4D') | |||||
| self.perm_4D = (0, 1, 2, 3) | |||||
| self.x_5D = Parameter( | |||||
| initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.uint8)), | |||||
| [1, 2, 3, 4, 5]), name='x_5D') | |||||
| self.perm_5D = (1, 0, 3, 4, 2) | |||||
| @ms_function | |||||
| def construct(self): | |||||
| return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), | |||||
| self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_transpose_uint8(): | |||||
| transpose = Transpose_uint8() | |||||
| output = transpose() | |||||
| expect0 = np.array([[[0, 6, 12, 18, 24], | |||||
| [1, 7, 13, 19, 25], | |||||
| [2, 8, 14, 20, 26], | |||||
| [3, 9, 15, 21, 27], | |||||
| [4, 10, 16, 22, 28], | |||||
| [5, 11, 17, 23, 29]]]).astype(np.uint8) | |||||
| expect1 = np.array([[[[0, 1, 2, 3], | |||||
| [8, 9, 10, 11]], | |||||
| [[4, 5, 6, 7], | |||||
| [12, 13, 14, 15]]]]).astype(np.uint8) | |||||
| expect2 = np.array([[[[[0, 1, 2, 3, 4], | |||||
| [5, 6, 7, 8, 9], | |||||
| [10, 11, 12, 13, 14], | |||||
| [15, 16, 17, 18, 19]], | |||||
| [[20, 21, 22, 23, 24], | |||||
| [25, 26, 27, 28, 29], | |||||
| [30, 31, 32, 33, 34], | |||||
| [35, 36, 37, 38, 39]], | |||||
| [[40, 41, 42, 43, 44], | |||||
| [45, 46, 47, 48, 49], | |||||
| [50, 51, 52, 53, 54], | |||||
| [55, 56, 57, 58, 59]]], | |||||
| [[[60, 61, 62, 63, 64], | |||||
| [65, 66, 67, 68, 69], | |||||
| [70, 71, 72, 73, 74], | |||||
| [75, 76, 77, 78, 79]], | |||||
| [[80, 81, 82, 83, 84], | |||||
| [85, 86, 87, 88, 89], | |||||
| [90, 91, 92, 93, 94], | |||||
| [95, 96, 97, 98, 99]], | |||||
| [[100, 101, 102, 103, 104], | |||||
| [105, 106, 107, 108, 109], | |||||
| [110, 111, 112, 113, 114], | |||||
| [115, 116, 117, 118, 119]]]]]).astype(np.uint8) | |||||
| expect3 = np.array([[[[[[0, 20, 40], | |||||
| [1, 21, 41], | |||||
| [2, 22, 42], | |||||
| [3, 23, 43], | |||||
| [4, 24, 44]], | |||||
| [[5, 25, 45], | |||||
| [6, 26, 46], | |||||
| [7, 27, 47], | |||||
| [8, 28, 48], | |||||
| [9, 29, 49]], | |||||
| [[10, 30, 50], | |||||
| [11, 31, 51], | |||||
| [12, 32, 52], | |||||
| [13, 33, 53], | |||||
| [14, 34, 54]], | |||||
| [[15, 35, 55], | |||||
| [16, 36, 56], | |||||
| [17, 37, 57], | |||||
| [18, 38, 58], | |||||
| [19, 39, 59]]]], | |||||
| [[[[60, 80, 100], | |||||
| [61, 81, 101], | |||||
| [62, 82, 102], | |||||
| [63, 83, 103], | |||||
| [64, 84, 104]], | |||||
| [[65, 85, 105], | |||||
| [66, 86, 106], | |||||
| [67, 87, 107], | |||||
| [68, 88, 108], | |||||
| [69, 89, 109]], | |||||
| [[70, 90, 110], | |||||
| [71, 91, 111], | |||||
| [72, 92, 112], | |||||
| [73, 93, 113], | |||||
| [74, 94, 114]], | |||||
| [[75, 95, 115], | |||||
| [76, 96, 116], | |||||
| [77, 97, 117], | |||||
| [78, 98, 118], | |||||
| [79, 99, 119]]]]]]).astype(np.uint8) | |||||
| assert (output[0].asnumpy() == expect0).all() | |||||
| assert (output[1].asnumpy() == expect1).all() | |||||
| assert (output[2].asnumpy() == expect2).all() | |||||
| assert (output[3].asnumpy() == expect3).all() | |||||
| test_transpose_uint8() | |||||