| @@ -1,114 +1,126 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class TransposeGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {} | |||
| ~TransposeGpuFwdKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0); | |||
| size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_axis failed"); | |||
| size_t size = input_size_ / sizeof(T); | |||
| CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); | |||
| shape_size_ = input_shape.size(); | |||
| if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION | |||
| << "-D inputs."; | |||
| } | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < shape_size_; i++) { | |||
| input_size_ *= input_shape[i]; | |||
| input_shape_.push_back(input_shape[i]); | |||
| } | |||
| input_size_ *= sizeof(T); | |||
| output_size_ = input_size_; | |||
| std::vector<int> perm; | |||
| std::vector<int64_t> perm_me = GetAttr<std::vector<int64_t>>(kernel_node, "perm"); | |||
| (void)std::transform(perm_me.begin(), perm_me.end(), std::back_inserter(perm), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| for (size_t j = 0; j < perm.size(); j++) { | |||
| input_axis_.push_back(perm[j]); | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_ = shape_size_ * sizeof(size_t); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| return; | |||
| } | |||
| private: | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> input_axis_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t shape_size_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class TransposeGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| TransposeGpuFwdKernel() { ResetResource(); } | |||
| ~TransposeGpuFwdKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0); | |||
| size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_axis failed"); | |||
| size_t size = input_size_ / sizeof(T); | |||
| CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); | |||
| shape_size_ = input_shape.size(); | |||
| if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION | |||
| << "-D inputs."; | |||
| } | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < shape_size_; i++) { | |||
| input_size_ *= input_shape[i]; | |||
| input_shape_.push_back(input_shape[i]); | |||
| } | |||
| input_size_ *= sizeof(T); | |||
| output_size_ = input_size_; | |||
| std::vector<int> perm; | |||
| std::vector<int64_t> perm_me = GetAttr<std::vector<int64_t>>(kernel_node, "perm"); | |||
| (void)std::transform(perm_me.begin(), perm_me.end(), std::back_inserter(perm), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| for (size_t j = 0; j < perm.size(); j++) { | |||
| input_axis_.push_back(perm[j]); | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| shape_size_ = 0; | |||
| input_size_ = 0; | |||
| output_size_ = 0; | |||
| workspace_size_ = 0; | |||
| input_shape_.clear(); | |||
| input_axis_.clear(); | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_ = shape_size_ * sizeof(size_t); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| return; | |||
| } | |||
| private: | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> input_axis_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t shape_size_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRANSPOSE_H_ | |||
| @@ -44,7 +44,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An | |||
| } | |||
| std::set<string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName, kReshapeOpName, | |||
| kEmbeddingLookupOpName}; | |||
| kEmbeddingLookupOpName, kTransposeOpName}; | |||
| for (auto &t : todos) { | |||
| CNodePtr cnode = t->cast<CNodePtr>(); | |||
| ConstInputToAttrInfoRegister reg; | |||
| @@ -82,6 +82,7 @@ constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin"; | |||
| constexpr auto kFlattenGradOpName = "FlattenGrad"; | |||
| constexpr auto kExpandDimsOpName = "ExpandDims"; | |||
| constexpr auto kReshapeOpName = "Reshape"; | |||
| constexpr auto kTransposeOpName = "Transpose"; | |||
| constexpr auto kSplitOpName = "Split"; | |||
| constexpr auto kSplitVOpName = "SplitV"; | |||
| constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad"; | |||
| @@ -14,7 +14,6 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include <iterator> | |||
| #include "abstract/infer_functions.h" | |||
| @@ -260,7 +259,11 @@ AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePt | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| return std::make_shared<AbstractTensor>(x->element(), x->shape()); | |||
| ShapeVector shape = x->shape()->shape(); | |||
| ShapeVector min_shape = x->shape()->min_shape(); | |||
| ShapeVector max_shape = x->shape()->max_shape(); | |||
| (void)CheckMinMaxShape(shape, &min_shape, &max_shape); | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -270,7 +273,11 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| return std::make_shared<AbstractTensor>(x->element(), x->shape()); | |||
| ShapeVector shape = x->shape()->shape(); | |||
| ShapeVector min_shape = x->shape()->min_shape(); | |||
| ShapeVector max_shape = x->shape()->max_shape(); | |||
| (void)CheckMinMaxShape(shape, &min_shape, &max_shape); | |||
| return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| } | |||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -536,43 +543,28 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr | |||
| AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto perm = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| auto input_shp = input->shape()->shape(); | |||
| auto perm_val = perm->BuildValue(); | |||
| if (perm_val->isa<AnyValue>()) { | |||
| MS_LOG(EXCEPTION) << "Perm can't be anything: " << args_spec_list[1]->ToString(); | |||
| } | |||
| auto perm_val_data = perm_val->cast<ValueTuplePtr>()->value(); | |||
| ValuePtr perm = primitive->GetAttr("perm"); | |||
| auto perm_val = perm->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(perm_val); | |||
| auto perm_val_data = perm_val->value(); | |||
| ShapeVector perm_vec; | |||
| (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| ShapeVector result_shp; | |||
| std::set<size_t> indices; | |||
| ShapeVector max_shp; | |||
| ShapeVector min_shp; | |||
| ShapeVector x_max_shp = input->shape()->max_shape(); | |||
| ShapeVector x_min_shp = input->shape()->min_shape(); | |||
| (void)CheckMinMaxShape(input_shp, &x_min_shp, &x_max_shp); | |||
| for (size_t i = 0; i < perm_vec.size(); i++) { | |||
| size_t idx = static_cast<size_t>(perm_vec[i]); | |||
| if (indices.find(idx) != indices.end()) { | |||
| MS_LOG(EXCEPTION) << "Perm values must be unique"; | |||
| } | |||
| if (idx >= perm_vec.size()) { | |||
| MS_LOG(EXCEPTION) << "One value in perm is " << idx << ", not in range [0, " << perm_vec.size() << ")"; | |||
| } | |||
| result_shp.push_back(input_shp[idx]); | |||
| indices.insert(idx); | |||
| } | |||
| ShapeVector max_shp; | |||
| ShapeVector min_shp; | |||
| if (input->shape()->max_shape().size() == input_shp.size() && | |||
| input->shape()->min_shape().size() == input_shp.size()) { | |||
| for (size_t i = 0; i < perm_vec.size(); i++) { | |||
| size_t idx = static_cast<size_t>(perm_vec[i]); | |||
| max_shp.push_back(input->shape()->max_shape()[idx]); | |||
| min_shp.push_back(input->shape()->min_shape()[idx]); | |||
| } | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | |||
| max_shp.push_back(x_max_shp[idx]); | |||
| min_shp.push_back(x_min_shp[idx]); | |||
| } | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp)); | |||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp)); | |||
| } | |||
| AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -310,5 +310,10 @@ size_t TypeIdSize(const TypeId data_type) { | |||
// Returns the product of all dimensions in `shape`; an empty shape yields 1.
size_t ShapeSize(const std::vector<size_t> &shape) {
  size_t element_count = static_cast<size_t>(1);
  for (const size_t dim : shape) {
    element_count *= dim;
  }
  return element_count;
}
| void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) { | |||
| *min_shape = (*min_shape).empty() ? shape : *min_shape; | |||
| *max_shape = (*max_shape).empty() ? shape : *max_shape; | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -56,6 +56,10 @@ size_t ShapeSize(const std::vector<size_t> &shape); | |||
| // Get broadcasted shape for binary element-wise operation | |||
| ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); | |||
| // Dynamic-shape helper: replaces an empty min_shape or max_shape with a copy of shape | |||
| void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_ | |||
| @@ -73,17 +73,13 @@ class _ScatterOp_Dynamic(PrimitiveWithCheck): | |||
| """ | |||
| Defines Scatter operators with dynamic shape | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('indices', dtype=sig.sig_dtype.T1), | |||
| sig.make_sig('updates', dtype=sig.sig_dtype.T) | |||
| ) | |||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | |||
| if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||
| raise ValueError(f"For '{prim_name}', " | |||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||
| if np.all(np.array(x_shape) != -1): | |||
| if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||
| raise ValueError(f"For '{prim_name}', " | |||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=False): | |||
| @@ -649,7 +645,7 @@ class Squeeze(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class Transpose(PrimitiveWithCheck): | |||
| class Transpose(PrimitiveWithInfer): | |||
| """ | |||
| Permutes the dimensions of the input tensor according to input permutation. | |||
| @@ -685,14 +681,36 @@ class Transpose(PrimitiveWithCheck): | |||
| """Initialize Transpose""" | |||
| self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | |||
def check_shape(self, x, perm):
    """Check that perm is a tuple with one entry per dimension of the shape x."""
    validator.check_value_type("perm", perm, [tuple], self.name)
    if len(x) != len(perm):
        raise ValueError('The dimension of x and perm must be equal.')

def check_dtype(self, x, perm):
    """Check that the dtype x is a tensor type."""
    validator.check_subclass("x", x, mstype.tensor, self.name)

def __infer__(self, x, perm):
    """
    Infer shape, dtype and (when present on x) min/max shape of the transposed output.

    Args:
        x (dict): Abstract of the input; 'shape' and 'dtype' are read, plus 'min_shape'
            and 'max_shape' when both are present.
        perm (dict): Abstract of the permutation; its 'value' must be a tuple.

    Returns:
        dict, with 'shape', 'dtype' and 'value' (always None), and with 'min_shape' and
        'max_shape' when x provides both.

    Raises:
        ValueError: If perm and the input shape differ in length, or a perm value repeats.
    """
    x_shape = x['shape']
    p_value = perm['value']
    x_type = x['dtype']
    validator.check_value_type("p_value", p_value, [tuple], self.name)
    validator.check_subclass("x_type", x_type, mstype.tensor, self.name)

    if len(x_shape) != len(p_value):
        raise ValueError('The dimension of x and perm must be equal.')

    for i, dim in enumerate(p_value):
        validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
        validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
        # A value that shows up again later in perm is a duplicate.
        if dim in p_value[i + 1:]:
            raise ValueError('The value of perm is wrong.')

    out = {'shape': tuple(x_shape[i] for i in p_value),
           'dtype': x_type,
           'value': None}
    if 'min_shape' in x and 'max_shape' in x:
        out['min_shape'] = tuple(x['min_shape'][i] for i in p_value)
        out['max_shape'] = tuple(x['max_shape'][i] for i in p_value)
    return out
| class Unique(Primitive): | |||
| """ | |||
| @@ -19,6 +19,7 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # all cases tested against dchip | |||
| @@ -45,6 +46,44 @@ def scatter_add_use_locking_false_net(inputx, indices, updates): | |||
| net = TestScatterAddNet(lock, inputx, indices, updates) | |||
| return net() | |||
class TestScatterAddDynamicNet(nn.Cell):
    """ScatterAdd on parameters, with inputx first passed through GpuConvertToDynamicShape."""

    def __init__(self, inputx, indices, updates):
        super().__init__()
        self.scatter_add = P.ScatterAdd()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        dynamic_x = self.test_dynamic(self.inputx)
        return self.scatter_add(dynamic_x, self.indices, self.updates)
def scatter_add_d_net(inputx, indices, updates):
    """Run TestScatterAddDynamicNet once in GPU graph mode and return its output."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_net = TestScatterAddDynamicNet(inputx, indices, updates)
    result = dynamic_net()
    return result
class TestScatterAddDynamicNet2(nn.Cell):
    """ScatterAdd on call-time inputs, with inputx first passed through GpuConvertToDynamicShape."""

    def __init__(self):
        super().__init__()
        self.scatter_add = P.ScatterAdd()
        self.test_dynamic = inner.GpuConvertToDynamicShape()

    def construct(self, inputx, indices, updates):
        dynamic_x = self.test_dynamic(inputx)
        return self.scatter_add(dynamic_x, indices, updates)
def scatter_add_d2_net(inputx_1, indices_1, updates_1, inputx_2,
                       indices_2, updates_2):
    """Feed two input sets through one TestScatterAddDynamicNet2 instance; return both outputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shared_net = TestScatterAddDynamicNet2()
    first = shared_net(inputx_1, indices_1, updates_1)
    second = shared_net(inputx_2, indices_2, updates_2)
    return (first, second)
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -196,3 +235,78 @@ def test_scatter_add_disordered_int32(): | |||
| [187., 188., 189., 190.], | |||
| [492., 496., 500., 504.]]).astype(np.int32) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_disordered_dynamic_int32():
    """Dynamic-shape ScatterAdd on int32 data with repeated, unordered indices."""
    x_np = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
    idx_np = np.array([[[0, 1, 2],
                        [2, 1, 0]],
                       [[0, 0, 0],
                        [2, 2, 2]]]).astype(np.int32)
    upd_np = np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int32)
    result = scatter_add_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[464., 468., 472., 476.],
                         [187., 188., 189., 190.],
                         [492., 496., 500., 504.]]).astype(np.int32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_input_less_than_1_dynamic_float32():
    """Dynamic-shape ScatterAdd on float32 data whose initial values are all below 1."""
    x_np = np.array([[0.214141, 0.415151, 0.51516],
                     [0.876542, 0.451611, 0.55112],
                     [0.111244, 0.633333, 0.34444]]).astype(np.float32)
    idx_np = np.array([[[1, 0, 2],
                        [2, 2, 0]],
                       [[1, 0, 1],
                        [2, 1, 2]]]).astype(np.int32)
    upd_np = np.arange(34, 70).reshape((2, 2, 3, 3)).astype(np.float32)
    result = scatter_add_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[141.21414, 144.41515, 147.51517],
                         [208.87654, 212.45161, 216.55112],
                         [257.11124, 262.63333, 267.34442]], dtype=np.float32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_add_dynamic_two_inputs():
    """One dynamic-shape ScatterAdd net called with two input sets of different rank."""
    first_inputs = (
        Tensor(np.zeros((2, 3)).astype(np.float32)),
        Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32)),
        Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32)),
    )
    second_inputs = (
        Tensor(np.ones((4, 2, 3, 4)).astype(np.float32)),
        Tensor(np.array([[0, 2], [3, 1]]).astype(np.int32)),
        Tensor(np.arange(96).reshape((2, 2, 2, 3, 4)).astype(np.float32)),
    )
    output_1, output_2 = scatter_add_d2_net(*first_inputs, *second_inputs)
    expected_1 = np.array([[6., 8., 10.],
                           [12., 14., 16.]])
    expected_2 = np.array([[[[1., 2., 3., 4.],
                             [5., 6., 7., 8.],
                             [9., 10., 11., 12.]],
                            [[13., 14., 15., 16.],
                             [17., 18., 19., 20.],
                             [21., 22., 23., 24.]]],
                           [[[73., 74., 75., 76.],
                             [77., 78., 79., 80.],
                             [81., 82., 83., 84.]],
                            [[85., 86., 87., 88.],
                             [89., 90., 91., 92.],
                             [93., 94., 95., 96.]]],
                           [[[25., 26., 27., 28.],
                             [29., 30., 31., 32.],
                             [33., 34., 35., 36.]],
                            [[37., 38., 39., 40.],
                             [41., 42., 43., 44.],
                             [45., 46., 47., 48.]]],
                           [[[49., 50., 51., 52.],
                             [53., 54., 55., 56.],
                             [57., 58., 59., 60.]],
                            [[61., 62., 63., 64.],
                             [65., 66., 67., 68.],
                             [69., 70., 71., 72.]]]])
    for actual, wanted in ((output_1, expected_1), (output_2, expected_2)):
        np.testing.assert_array_almost_equal(actual.asnumpy(), wanted)
| @@ -19,6 +19,7 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # all cases tested against dchip | |||
| @@ -39,6 +40,44 @@ def scatter_update_net(inputx, indices, updates): | |||
| net = TestScatterUpdateNet(inputx, indices, updates) | |||
| return net() | |||
class TestScatterUpdateDynamicNet(nn.Cell):
    """ScatterUpdate on parameters, with inputx first passed through GpuConvertToDynamicShape."""

    def __init__(self, inputx, indices, updates):
        super().__init__()
        self.scatter_update = P.ScatterUpdate()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.inputx = Parameter(inputx, name="inputx")
        self.indices = Parameter(indices, name="indices")
        self.updates = Parameter(updates, name="updates")

    def construct(self):
        dynamic_x = self.test_dynamic(self.inputx)
        return self.scatter_update(dynamic_x, self.indices, self.updates)
def scatter_update_d_net(inputx, indices, updates):
    """Run TestScatterUpdateDynamicNet once in GPU graph mode and return its output."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dynamic_net = TestScatterUpdateDynamicNet(inputx, indices, updates)
    result = dynamic_net()
    return result
class TestScatterUpdateDynamicNet2(nn.Cell):
    """ScatterUpdate on call-time inputs, with inputx first passed through GpuConvertToDynamicShape."""

    def __init__(self):
        super().__init__()
        self.scatter_update = P.ScatterUpdate()
        self.test_dynamic = inner.GpuConvertToDynamicShape()

    def construct(self, inputx, indices, updates):
        dynamic_x = self.test_dynamic(inputx)
        return self.scatter_update(dynamic_x, indices, updates)
def scatter_update_d2_net(inputx_1, indices_1, updates_1, inputx_2,
                          indices_2, updates_2):
    """Feed two input sets through one TestScatterUpdateDynamicNet2 instance; return both outputs."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shared_net = TestScatterUpdateDynamicNet2()
    first = shared_net(inputx_1, indices_1, updates_1)
    second = shared_net(inputx_2, indices_2, updates_2)
    return (first, second)
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -237,3 +276,72 @@ def test_scatter_update_disordered_uint8(): | |||
| [63., 64., 65., 66.], | |||
| [67., 68., 69., 70.]]).astype(np.uint8) | |||
| np.testing.assert_array_almost_equal(output.asnumpy(), expected) | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_large_shape_dynamic_int8():
    """Dynamic-shape ScatterUpdate on a 4-D int8 input; rows 1 and 0 are overwritten."""
    x_np = np.arange(96).reshape((4, 2, 3, 4)).astype(np.int8)
    idx_np = np.array([1, 0]).astype(np.int32)
    upd_np = np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.int8))
    result = scatter_update_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[[[23., 22., 21., 20.],
                           [19., 18., 17., 16.],
                           [15., 14., 13., 12.]],
                          [[11., 10., 9., 8.],
                           [7., 6., 5., 4.],
                           [3., 2., 1., 0.]]],
                         [[[47., 46., 45., 44.],
                           [43., 42., 41., 40.],
                           [39., 38., 37., 36.]],
                          [[35., 34., 33., 32.],
                           [31., 30., 29., 28.],
                           [27., 26., 25., 24.]]],
                         [[[48., 49., 50., 51.],
                           [52., 53., 54., 55.],
                           [56., 57., 58., 59.]],
                          [[60., 61., 62., 63.],
                           [64., 65., 66., 67.],
                           [68., 69., 70., 71.]]],
                         [[[72., 73., 74., 75.],
                           [76., 77., 78., 79.],
                           [80., 81., 82., 83.]],
                          [[84., 85., 86., 87.],
                           [88., 89., 90., 91.],
                           [92., 93., 94., 95.]]]]).astype(np.int8)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_disordered_dynamic_int32():
    """Dynamic-shape ScatterUpdate on a descending int32 input; rows 1 and 2 are overwritten."""
    x_np = np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))
    idx_np = np.array([1, 2]).astype(np.int32)
    upd_np = np.arange(63, 71).reshape((2, 4)).astype(np.int32)
    result = scatter_update_d_net(Tensor(x_np), Tensor(idx_np), Tensor(upd_np))
    expected = np.array([[45., 44., 43., 42.],
                         [63., 64., 65., 66.],
                         [67., 68., 69., 70.]]).astype(np.int32)
    np.testing.assert_array_almost_equal(result.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_update_two_inputs():
    """One dynamic-shape ScatterUpdate net called with two input sets of different shape."""
    first_inputs = (
        Tensor(np.zeros((2, 3)).astype(np.float32)),
        Tensor(np.array([0, 1]).astype(np.int32)),
        Tensor(np.arange(6).reshape((2, 3)).astype(np.float32)),
    )
    second_inputs = (
        Tensor(np.array([[0.214141, 0.415151, 0.51516],
                         [0.876542, 0.451611, 0.55112],
                         [0.111244, 0.633333, 0.34444]]).astype(np.float32)),
        Tensor(np.array([1, 0, 2]).astype(np.int32)),
        Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32)),
    )
    output_1, output_2 = scatter_update_d2_net(*first_inputs, *second_inputs)
    expected_1 = np.array([[0., 1., 2.],
                           [3., 4., 5.]])
    expected_2 = np.array([[37., 38., 39.],
                           [34., 35., 36.],
                           [40., 41., 42.]], dtype=np.float32)
    for actual, wanted in ((output_1, expected_1), (output_2, expected_2)):
        np.testing.assert_array_almost_equal(actual.asnumpy(), wanted)
| @@ -23,28 +23,24 @@ from mindspore.common.api import ms_function | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| context.set_context(device_target='GPU') | |||
| class Transpose(nn.Cell): | |||
| def __init__(self, nptype): | |||
| super(Transpose, self).__init__() | |||
| self.transpose = P.Transpose() | |||
| self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)), [5, 6]), | |||
| name='x_2D') | |||
| self.perm_2D = (1, 0) | |||
| self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)), [2, 2, 4]), | |||
| name='x_3D') | |||
| self.perm_3D = (1, 0, 2) | |||
| self.x_4D = Parameter( | |||
| initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype)), [2, 3, 4, 5]), | |||
| name='x_4D') | |||
| self.perm_4D = (0, 1, 2, 3) | |||
| self.x_5D = Parameter( | |||
| initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)), | |||
| [1, 2, 3, 4, 5]), name='x_5D') | |||
| @@ -55,11 +51,42 @@ class Transpose(nn.Cell): | |||
| return (self.transpose(self.x_2D, self.perm_2D), self.transpose(self.x_3D, self.perm_3D), | |||
| self.transpose(self.x_4D, self.perm_4D), self.transpose(self.x_5D, self.perm_5D)) | |||
class Transpose_dynamic(nn.Cell):
    """Cell that feeds a fixed 5-D parameter through GpuConvertToDynamicShape, then Transpose.

    Args:
        nptype: NumPy dtype used for the parameter data.
    """

    def __init__(self, nptype):
        super(Transpose_dynamic, self).__init__()
        self.transpose = P.Transpose()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        # Values 0..119 laid out as a (1, 2, 3, 4, 5) tensor.
        shape_5d = [1, 2, 3, 4, 5]
        values = np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)
        self.x = Parameter(initializer(Tensor(values), shape_5d), name='5D')
        self.perm = (1, 0, 3, 4, 2)

    @ms_function
    def construct(self):
        dynamic_x = self.test_dynamic(self.x)
        return self.transpose(dynamic_x, self.perm)
class Transpose_dynamic2(nn.Cell):
    """Cell that transposes two caller-supplied inputs, each passed through GpuConvertToDynamicShape first.

    Args:
        input_1: first tensor to transpose.
        input_2: second tensor to transpose.
        perm_1: permutation applied to ``input_1``.
        perm_2: permutation applied to ``input_2``.
    """

    def __init__(self, input_1, input_2, perm_1, perm_2):
        super(Transpose_dynamic2, self).__init__()
        self.transpose = P.Transpose()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.x_1, self.x_2 = input_1, input_2
        self.perm_1, self.perm_2 = perm_1, perm_2

    @ms_function
    def construct(self):
        first = self.transpose(self.test_dynamic(self.x_1), self.perm_1)
        second = self.transpose(self.test_dynamic(self.x_2), self.perm_2)
        return (first, second)
| def transpose1(nptype): | |||
| transpose = Transpose(nptype) | |||
| output = transpose() | |||
| expect0 = np.array([[[0, 6, 12, 18, 24], | |||
| [1, 7, 13, 19, 25], | |||
| [2, 8, 14, 20, 26], | |||
| @@ -82,7 +109,6 @@ def transpose1(nptype): | |||
| [45, 46, 47, 48, 49], | |||
| [50, 51, 52, 53, 54], | |||
| [55, 56, 57, 58, 59]]], | |||
| [[[60, 61, 62, 63, 64], | |||
| [65, 66, 67, 68, 69], | |||
| [70, 71, 72, 73, 74], | |||
| @@ -115,7 +141,6 @@ def transpose1(nptype): | |||
| [17, 37, 57], | |||
| [18, 38, 58], | |||
| [19, 39, 59]]]], | |||
| [[[[60, 80, 100], | |||
| [61, 81, 101], | |||
| [62, 82, 102], | |||
| @@ -141,6 +166,75 @@ def transpose1(nptype): | |||
| assert (output[2].asnumpy() == expect2).all() | |||
| assert (output[3].asnumpy() == expect3).all() | |||
def transpose_d(nptype):
    """Run Transpose_dynamic in graph mode on GPU and compare against a hand-written result.

    The cell transposes values 0..119 shaped (1, 2, 3, 4, 5) with perm (1, 0, 3, 4, 2),
    so the result must have shape (2, 1, 4, 5, 3).

    Args:
        nptype: NumPy dtype for both the input data and the expected array.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    transpose = Transpose_dynamic(nptype)
    output = transpose()
    # The literal is exactly 5-D. An extra outer bracket would still compare
    # equal through NumPy broadcasting and hide a wrong output rank.
    expect = np.array([[[[[0, 20, 40],
                          [1, 21, 41],
                          [2, 22, 42],
                          [3, 23, 43],
                          [4, 24, 44]],
                         [[5, 25, 45],
                          [6, 26, 46],
                          [7, 27, 47],
                          [8, 28, 48],
                          [9, 29, 49]],
                         [[10, 30, 50],
                          [11, 31, 51],
                          [12, 32, 52],
                          [13, 33, 53],
                          [14, 34, 54]],
                         [[15, 35, 55],
                          [16, 36, 56],
                          [17, 37, 57],
                          [18, 38, 58],
                          [19, 39, 59]]]],
                       [[[[60, 80, 100],
                          [61, 81, 101],
                          [62, 82, 102],
                          [63, 83, 103],
                          [64, 84, 104]],
                         [[65, 85, 105],
                          [66, 86, 106],
                          [67, 87, 107],
                          [68, 88, 108],
                          [69, 89, 109]],
                         [[70, 90, 110],
                          [71, 91, 111],
                          [72, 92, 112],
                          [73, 93, 113],
                          [74, 94, 114]],
                         [[75, 95, 115],
                          [76, 96, 116],
                          [77, 97, 117],
                          [78, 98, 118],
                          [79, 99, 119]]]]]).astype(nptype)
    result = output.asnumpy()
    # Check the shape first so broadcasting cannot mask a mismatch below.
    assert result.shape == expect.shape
    assert (result == expect).all()
def transpose_d2(nptype):
    """Run Transpose_dynamic2 in graph mode on GPU with a 2-D and a 3-D input.

    Input 1 is 0..29 shaped (5, 6) with perm (1, 0), giving a (6, 5) result.
    Input 2 is 0..15 shaped (2, 2, 4) with perm (1, 0, 2), giving a (2, 2, 4) result.

    Args:
        nptype: NumPy dtype for the inputs and the expected arrays.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_1 = Parameter(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)),
                        name="input_1")
    input_2 = Parameter(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)),
                        name="input_2")
    perm_1 = (1, 0)
    perm_2 = (1, 0, 2)
    # The expected arrays have the exact output rank. An extra leading axis
    # would still compare equal through broadcasting and hide a wrong shape.
    expect_1 = np.array([[0, 6, 12, 18, 24],
                         [1, 7, 13, 19, 25],
                         [2, 8, 14, 20, 26],
                         [3, 9, 15, 21, 27],
                         [4, 10, 16, 22, 28],
                         [5, 11, 17, 23, 29]]).astype(nptype)
    expect_2 = np.array([[[0, 1, 2, 3],
                          [8, 9, 10, 11]],
                         [[4, 5, 6, 7],
                          [12, 13, 14, 15]]]).astype(nptype)
    net = Transpose_dynamic2(input_1, input_2, perm_1, perm_2)
    output_1, output_2 = net()
    result_1 = output_1.asnumpy()
    result_2 = output_2.asnumpy()
    # Check shapes first so broadcasting cannot mask a mismatch below.
    assert result_1.shape == expect_1.shape
    assert result_2.shape == expect_2.shape
    assert (result_1 == expect_1).all()
    assert (result_2 == expect_2).all()
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -158,3 +252,39 @@ def test_transpose_float16(): | |||
| @pytest.mark.env_onecard | |||
def test_transpose_int32():
    """Run the static-shape transpose checks in transpose1 with int32 data."""
    transpose1(nptype=np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_float32():
    """Run the single-input transpose_d check with float32 data."""
    transpose_d(nptype=np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_float16():
    """Run the single-input transpose_d check with float16 data."""
    transpose_d(nptype=np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_int32():
    """Run the single-input transpose_d check with int32 data."""
    transpose_d(nptype=np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_float32():
    """Run the two-input transpose_d2 check with float32 data."""
    transpose_d2(nptype=np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_float16():
    """Run the two-input transpose_d2 check with float16 data."""
    transpose_d2(nptype=np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dynamic_two_inputs_int32():
    """Run the two-input transpose_d2 check with int32 data."""
    transpose_d2(nptype=np.int32)