From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosmantags/v1.3.0
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/sort_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Sort, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| SortGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Sort, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| SortGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,227 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SORT_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SORT_GPU_KERNEL_H_ | |||
| #include <algorithm> | |||
| #include <cstdint> | |||
| #include <limits> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SortGpuKernel : public GpuKernel { | |||
| public: | |||
| SortGpuKernel() { ResetResource(); } | |||
| ~SortGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_device = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_device = GetDeviceAddress<T>(outputs, 0); | |||
| int32_t *indices_device = GetDeviceAddress<int32_t>(outputs, 1); | |||
| T *temp_output_device = GetDeviceAddress<T>(workspace, 0); | |||
| int32_t *temp_indices_device = GetDeviceAddress<int32_t>(workspace, 1); | |||
| size_t *input_shape_device = GetDeviceAddress<size_t>(workspace, 2); | |||
| size_t *perm_device = GetDeviceAddress<size_t>(workspace, 3); | |||
| size_t *transposed_shape_device = GetDeviceAddress<size_t>(workspace, 4); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(input_shape_device, &input_shape_[0], workspace_size_list_[2], | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync for input_shape_ failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(perm_device, &perm_[0], workspace_size_list_[3], cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync for perm_ failed"); | |||
| // Sort is implemented using a combination of Neg, Transpose, and TopK. It's | |||
| // Not safe to treat Transpose and TopK as inplace operators, so we alternate | |||
| // between using temp_output_device and output_device for intermediate calculations, | |||
| // this way only a constant number of allocations is needed instead of needing to | |||
| // allocate once for each intermediate calculation. | |||
| T *intermediate_input_device = input_device; | |||
| T *intermediate_output_device = output_device; | |||
| // if sort in descending order, negate input and negate back after sorting | |||
| if (!descending_) { | |||
| Negative(intermediate_input_device, intermediate_output_device, input_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| intermediate_input_device = output_device; | |||
| intermediate_output_device = temp_output_device; | |||
| } | |||
| // transpose so that desired dimension to sort along becomes the last one | |||
| CalTranspose(input_size_, intermediate_input_device, input_shape_device, perm_device, input_rank_, | |||
| intermediate_output_device, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| intermediate_input_device = intermediate_output_device; | |||
| intermediate_output_device = intermediate_input_device == output_device ? temp_output_device : output_device; | |||
| // topk sorts the input along the last dimension | |||
| FastTopK(outer_size_, inner_size_, intermediate_input_device, static_cast<int32_t>(input_shape_[axis_]), | |||
| intermediate_output_device, temp_indices_device, topk_init_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| std::swap(intermediate_input_device, intermediate_output_device); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(transposed_shape_device, &transposed_shape_[0], workspace_size_list_[4], | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync for transposed_shape_ failed"); | |||
| // transpose the sorted output back to the original input shape | |||
| CalTranspose(input_size_, intermediate_input_device, transposed_shape_device, perm_device, input_rank_, | |||
| intermediate_output_device, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // transpose the indices back to the original input shape | |||
| CalTranspose(input_size_, temp_indices_device, transposed_shape_device, perm_device, input_rank_, indices_device, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // negate back the sorted values if we negated prior to sorting | |||
| if (!descending_) { | |||
| std::swap(intermediate_input_device, intermediate_output_device); | |||
| Negative(intermediate_input_device, intermediate_output_device, input_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_count != 1) { | |||
| MS_LOG(ERROR) << input_count << " inputs were provided, but SortGpuKernel expects 2."; | |||
| return false; | |||
| } | |||
| size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_count != 2) { | |||
| MS_LOG(ERROR) << "Number of outputs is " << output_count << ", but should be 2 for SortGpuKernel."; | |||
| return false; | |||
| } | |||
| input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| input_rank_ = input_shape_.size(); | |||
| if (input_rank_ > TRANSPOSE_MAX_DIMENSION) { | |||
| MS_LOG(ERROR) << "Sort cannot support input that has more than " << TRANSPOSE_MAX_DIMENSION << " dimensions."; | |||
| } | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_rank_; i++) { | |||
| input_size_ *= input_shape_[i]; | |||
| } | |||
| descending_ = GetAttr<bool>(kernel_node, "descending"); | |||
| axis_ = GetAttr<int64_t>(kernel_node, "axis"); | |||
| if (axis_ < 0) { | |||
| axis_ += input_rank_; | |||
| } | |||
| perm_.resize(input_rank_); | |||
| std::iota(perm_.begin(), perm_.end(), 0); | |||
| std::swap(perm_[input_rank_ - 1], perm_[axis_]); | |||
| transposed_shape_ = input_shape_; | |||
| std::swap(transposed_shape_[input_rank_ - 1], transposed_shape_[axis_]); | |||
| inner_size_ = input_shape_[axis_]; | |||
| outer_size_ = input_size_ / inner_size_; | |||
| if (std::is_same<T, half>::value) { | |||
| // min value representable by float16, std::numeric_limits doesn't support half | |||
| topk_init_ = static_cast<half>(-65504.); | |||
| } else { | |||
| topk_init_ = std::numeric_limits<T>::lowest(); | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| input_size_ = 0; | |||
| axis_ = 0; | |||
| descending_ = false; | |||
| input_shape_.clear(); | |||
| input_rank_ = 0; | |||
| transposed_shape_.clear(); | |||
| perm_.clear(); | |||
| outer_size_ = 0; | |||
| inner_size_ = 0; | |||
| topk_init_ = static_cast<T>(0.); | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| size_t input_bytes = input_size_ * sizeof(T); | |||
| size_t indices_bytes = input_size_ * sizeof(int32_t); | |||
| input_size_list_.push_back(input_bytes); | |||
| // outputs: sorted values, indices | |||
| output_size_list_.push_back(input_bytes); | |||
| output_size_list_.push_back(indices_bytes); | |||
| // workspace: temp output, temp indices, input shape, perm, transposed_shape | |||
| workspace_size_list_.push_back(input_bytes); | |||
| workspace_size_list_.push_back(indices_bytes); | |||
| workspace_size_list_.push_back(input_rank_ * sizeof(size_t)); | |||
| workspace_size_list_.push_back(input_rank_ * sizeof(size_t)); | |||
| workspace_size_list_.push_back(input_rank_ * sizeof(size_t)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| int64_t axis_; | |||
| bool descending_; | |||
| std::vector<size_t> input_shape_; | |||
| size_t input_rank_; | |||
| // for transpose | |||
| std::vector<size_t> transposed_shape_; | |||
| std::vector<size_t> perm_; | |||
| // for topk | |||
| size_t outer_size_; | |||
| size_t inner_size_; | |||
| T topk_init_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SORT_GPU_KERNEL_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,5 +25,13 @@ MS_REG_GPU_KERNEL_TWO(TopK, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| TopKGpuKernel, float, int) | |||
| } | |||
| MS_REG_GPU_KERNEL_TWO(TopK, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| TopKGpuKernel, half, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -221,5 +221,7 @@ void FastTopK(const int outer_size, const int inner_size, const T *input, S k_cu | |||
| } | |||
| } | |||
| template void FastTopK(const int outer_size, const int inner_size, const half *input, int k_cut, half *output, | |||
| int *output_index, const half init_K, cudaStream_t stream); | |||
| template void FastTopK(const int outer_size, const int inner_size, const float *input, int k_cut, float *output, | |||
| int *output_index, const float init_K, cudaStream_t stream); | |||
| @@ -292,6 +292,8 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| @@ -1132,5 +1132,29 @@ AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const Primit | |||
| AbstractBasePtrList result = {index, value}; | |||
| return std::make_shared<AbstractTuple>(result); | |||
| } | |||
| AbstractBasePtr InferImplSort(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| TypePtrList supported_types = {kFloat16, kFloat32}; | |||
| (void)CheckTensorDType(input, supported_types, "input for Sort should be %s"); | |||
| ValuePtr axis_ptr = primitive->GetAttr("axis"); | |||
| int64_t axis = GetValue<int64_t>(axis_ptr); | |||
| int64_t input_rank = input->shape()->shape().size(); | |||
| if (!(axis >= -input_rank && axis < input_rank)) { | |||
| MS_LOG(EXCEPTION) << "axis is not in the valid range [" << -input_rank << ", " << input_rank << ")."; | |||
| } | |||
| auto sorted_values = std::make_shared<AbstractTensor>(input->element(), input->shape()); | |||
| TypePtr idx_type = kInt32; | |||
| auto indices = std::make_shared<AbstractTensor>(idx_type, input->shape()); | |||
| AbstractBasePtrList result = {sorted_values, indices}; | |||
| return std::make_shared<AbstractTuple>(result); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -94,6 +94,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, nullptr, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, | |||
| {prim::kPrimSort, {InferImplSort, nullptr, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, | |||
| @@ -189,6 +189,7 @@ inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("Re | |||
| inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank"); | |||
| inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear"); | |||
| inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeGrad"); | |||
| inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam"); | |||
| @@ -0,0 +1,160 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class SortNet(nn.Cell): | |||
| def __init__(self, axis, descending): | |||
| super(SortNet, self).__init__() | |||
| self.sort = P.Sort(axis, descending) | |||
| def construct(self, x): | |||
| return self.sort(x) | |||
| def sort_1d(descending, nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_numpy = np.array([1, -2, 3, 4]).astype(np.float16) | |||
| x = Tensor(x_numpy) | |||
| sort_net = SortNet(0, descending) | |||
| output, indices = sort_net(x) | |||
| expected_output = np.sort(x_numpy, 0) | |||
| expected_indices = np.array([1, 0, 2, 3]) | |||
| if descending: | |||
| expected_output = expected_output[::-1] | |||
| expected_indices = expected_indices[::-1] | |||
| np.testing.assert_array_equal(output.asnumpy(), expected_output) | |||
| np.testing.assert_array_equal(indices.asnumpy(), expected_indices) | |||
| def sort_3d(descending, nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x_numpy = np.array([[[1, 2, 3, 4], | |||
| [8, 7, 2, 0], | |||
| [9, 4, 1, 8]], | |||
| [[5, 4, 1, 8], | |||
| [2, 9, 0, 7], | |||
| [6, 1, 7, 4]]]).astype(nptype) | |||
| x = Tensor(x_numpy) | |||
| axis = -1 | |||
| sort_net = SortNet(axis, descending) | |||
| output, indices = sort_net(x) | |||
| expected_output = np.sort(x_numpy, axis) | |||
| expected_indices = np.array([[[0, 1, 2, 3], | |||
| [3, 2, 1, 0], | |||
| [2, 1, 3, 0]], | |||
| [[2, 1, 0, 3], | |||
| [2, 0, 3, 1], | |||
| [1, 3, 0, 2]]]) | |||
| if descending: | |||
| expected_output = expected_output[:, :, ::-1] | |||
| expected_indices = expected_indices[:, :, ::-1] | |||
| np.testing.assert_array_equal(output.asnumpy(), expected_output) | |||
| np.testing.assert_array_equal(indices.asnumpy(), expected_indices) | |||
| axis = 1 | |||
| sort_net = SortNet(axis, descending) | |||
| output, indices = sort_net(x) | |||
| expected_output = np.sort(x_numpy, axis) | |||
| expected_indices = np.array([[[0, 0, 2, 1], | |||
| [1, 2, 1, 0], | |||
| [2, 1, 0, 2]], | |||
| [[1, 2, 1, 2], | |||
| [0, 0, 0, 1], | |||
| [2, 1, 2, 0]]]) | |||
| if descending: | |||
| expected_output = expected_output[:, ::-1, :] | |||
| expected_indices = expected_indices[:, ::-1, :] | |||
| np.testing.assert_array_equal(output.asnumpy(), expected_output) | |||
| np.testing.assert_array_equal(indices.asnumpy(), expected_indices) | |||
| axis = -3 | |||
| sort_net = SortNet(axis, descending) | |||
| output, indices = sort_net(x) | |||
| expected_output = np.sort(x_numpy, axis) | |||
| expected_indices = np.array([[[0, 0, 1, 0], | |||
| [1, 0, 1, 0], | |||
| [1, 1, 0, 1]], | |||
| [[1, 1, 0, 1], | |||
| [0, 1, 0, 1], | |||
| [0, 0, 1, 0]]]) | |||
| if descending: | |||
| expected_output = expected_output[::-1, :, :] | |||
| expected_indices = expected_indices[::-1, :, :] | |||
| np.testing.assert_array_equal(output.asnumpy(), expected_output) | |||
| np.testing.assert_array_equal(indices.asnumpy(), expected_indices) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort1d_float16(): | |||
| sort_1d(False, np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort1d_descending_float16(): | |||
| sort_1d(True, np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort1d_float32(): | |||
| sort_1d(False, np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort1d_descending_float32(): | |||
| sort_1d(True, np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort3d_float16(): | |||
| sort_3d(False, np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort3d_descending_float16(): | |||
| sort_3d(True, np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort3d_float32(): | |||
| sort_3d(False, np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sort3d_descending_float32(): | |||
| sort_3d(True, np.float32) | |||