| @@ -47,9 +47,12 @@ constexpr auto kEditDistance = "EditDistance"; | |||||
| constexpr auto kGatherD = "GatherD"; | constexpr auto kGatherD = "GatherD"; | ||||
| constexpr auto kIdentity = "Identity"; | constexpr auto kIdentity = "Identity"; | ||||
| constexpr auto kUpdateCache = "UpdateCache"; | constexpr auto kUpdateCache = "UpdateCache"; | ||||
| constexpr auto kCacheSwapTable = "CacheSwapTable"; | |||||
| constexpr auto kSubAndFilter = "SubAndFilter"; | |||||
| constexpr auto kPadAndShift = "PadAndShift"; | |||||
| constexpr auto kCustRunApi = "RunCpuKernel"; | constexpr auto kCustRunApi = "RunCpuKernel"; | ||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | ||||
| const std::set<std::string> kCacheKernelOps{kUpdateCache}; | |||||
| const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift}; | |||||
| struct AicpuParamHead { | struct AicpuParamHead { | ||||
| uint32_t length; // Total length: include cunstom message | uint32_t length; // Total length: include cunstom message | ||||
| @@ -0,0 +1,89 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.h" | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void DynamicAssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| node_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| input_x_dtype_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | |||||
| } | |||||
| bool DynamicAssignCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (input_x_dtype_ == kNumberTypeInt32) { | |||||
| LaunchKernel<int>(inputs, outputs); | |||||
| } else if (input_x_dtype_ == kNumberTypeInt64) { | |||||
| LaunchKernel<int64_t>(inputs, outputs); | |||||
| } else if (input_x_dtype_ == kNumberTypeFloat32) { | |||||
| LaunchKernel<float>(inputs, outputs); | |||||
| } else if (input_x_dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<double>(inputs, outputs); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Dtype of indices only support float32, float64, int32, int64"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <typename T> | |||||
| void DynamicAssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | |||||
| auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||||
| batch_size_ = 1; | |||||
| for (size_t i = 0; i < input_x_shape.size(); ++i) { | |||||
| batch_size_ *= input_x_shape[i]; | |||||
| } | |||||
| if (input_x_shape.size() != input_y_shape.size()) MS_LOG(EXCEPTION) << "x y must be same shape"; | |||||
| for (size_t i = 0; i < input_x_shape.size(); ++i) { | |||||
| if (input_x_shape[i] != input_y_shape[i]) { | |||||
| MS_LOG(EXCEPTION) << "x y must be same shape"; | |||||
| } | |||||
| } | |||||
| T *input_x = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| T *input_y = reinterpret_cast<T *>(inputs[1]->addr); | |||||
| auto max_size = inputs[0]->size; | |||||
| size_t total_size = input_x_dtype_size_ * batch_size_; | |||||
| if (total_size > max_size) { | |||||
| MS_LOG(EXCEPTION) << "Memcpy size must <= max_size, but got memcpy size is : " << total_size | |||||
| << ", max size is : " << max_size; | |||||
| } | |||||
| int ret = memcpy_s(input_x, total_size, input_y, total_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret; | |||||
| } | |||||
| auto node_with_idx = AnfAlgo::GetPrevNodeOutput(node_, 0); | |||||
| auto node = node_with_idx.first; | |||||
| if (node->isa<Parameter>()) { | |||||
| auto node_ptr = node->cast<ParameterPtr>(); | |||||
| auto value = node_ptr->default_param(); | |||||
| auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>(); | |||||
| ShapeVector shape_tmp; | |||||
| (void)std::transform(input_x_shape.begin(), input_x_shape.end(), std::back_inserter(shape_tmp), SizeToLong); | |||||
| tensor->set_shape(shape_tmp); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Input x must be a Parameter."; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_ASSIGN_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_ASSIGN_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class DynamicAssignCPUKernel : public CPUKernel { | |||||
| public: | |||||
| DynamicAssignCPUKernel() = default; | |||||
| ~DynamicAssignCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| template <typename T> | |||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); | |||||
| private: | |||||
| size_t batch_size_{1}; | |||||
| TypeId input_x_dtype_{kTypeUnknown}; | |||||
| size_t input_x_dtype_size_ = 4; | |||||
| CNodePtr node_ = nullptr; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL( | |||||
| DynamicAssign, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| DynamicAssignCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| DynamicAssign, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| DynamicAssignCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| DynamicAssign, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| DynamicAssignCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| DynamicAssign, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| DynamicAssignCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_ASSIGN_CPU_KERNEL_H_ | |||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.h" | |||||
| #include <string> | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| void PadAndShiftCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| node_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| type_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | |||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| batch_size_ = 1; | |||||
| for (size_t i = 0; i < indices_shape.size(); ++i) { | |||||
| batch_size_ *= indices_shape[i]; | |||||
| } | |||||
| MS_LOG(INFO) << "PadAndShift batch_size:" << batch_size_; | |||||
| auto cum_sum_arr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| if (cum_sum_arr_shape.size() != 1) { | |||||
| MS_LOG(ERROR) << "The shape of cum_sum_arr must be 1."; | |||||
| } | |||||
| cum_sum_size_ = cum_sum_arr_shape[0]; | |||||
| } | |||||
| bool PadAndShiftCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (input_x_dtype_ == kNumberTypeInt32) { | |||||
| LaunchKernel<int>(inputs, outputs); | |||||
| } else if (input_x_dtype_ == kNumberTypeInt64) { | |||||
| LaunchKernel<int64_t>(inputs, outputs); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Dtype of input_x only support int32, int64"; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <typename T> | |||||
| void PadAndShiftCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| T *input_x = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| T *cum_sum_arr = reinterpret_cast<T *>(inputs[1]->addr); | |||||
| T shift_idx = *reinterpret_cast<T *>(inputs[2]->addr); | |||||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| if (shift_idx >= static_cast<T>(cum_sum_size_)) { | |||||
| MS_LOG(EXCEPTION) << "Shift index must small than cumsum size."; | |||||
| } | |||||
| size_t output_size = cum_sum_arr[cum_sum_size_ - 1]; | |||||
| T shift_size = cum_sum_arr[shift_idx]; | |||||
| T valid_size = cum_sum_arr[shift_idx + 1] - shift_size; | |||||
| int ret = memset_s(output, outputs[0]->size, -1, type_size_ * output_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memset_s error, errorno" << ret; | |||||
| } | |||||
| ret = memcpy_s(output + shift_size, valid_size * type_size_, input_x, valid_size * type_size_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||||
| } | |||||
| std::vector<size_t> out_shape; | |||||
| out_shape.emplace_back(output_size); | |||||
| std::vector<TypeId> dtypes; | |||||
| auto output_nums = AnfAlgo::GetOutputTensorNum(node_); | |||||
| for (size_t i = 0; i < output_nums; i++) { | |||||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get()); | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_AND_SHIFT_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_AND_SHIFT_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class PadAndShiftCPUKernel : public CPUKernel { | |||||
| public: | |||||
| PadAndShiftCPUKernel() = default; | |||||
| ~PadAndShiftCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| template <typename T> | |||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); | |||||
| private: | |||||
| size_t batch_size_{1}; | |||||
| size_t cum_sum_size_{1}; | |||||
| size_t type_size_{4}; | |||||
| TypeId input_x_dtype_{kTypeUnknown}; | |||||
| CNodePtr node_ = nullptr; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL(PadAndShift, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| PadAndShiftCPUKernel); | |||||
| MS_REG_CPU_KERNEL(PadAndShift, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddInputAttr(kNumberTypeInt64) | |||||
| .AddOutputAttr(kNumberTypeInt64), | |||||
| PadAndShiftCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PAD_AND_SHIFT_CPU_KERNEL_H_ | |||||
| @@ -263,6 +263,8 @@ constexpr auto kDropoutOpName = "Dropout"; | |||||
| constexpr auto kDropoutGradOpName = "DropoutGrad"; | constexpr auto kDropoutGradOpName = "DropoutGrad"; | ||||
| constexpr auto kDropoutGenMaskOpName = "DropoutGenMask"; | constexpr auto kDropoutGenMaskOpName = "DropoutGenMask"; | ||||
| constexpr auto kDropoutDoMaskOpName = "DropoutDoMask"; | constexpr auto kDropoutDoMaskOpName = "DropoutDoMask"; | ||||
| constexpr auto kSubAndFilterOpName = "SubAndFilter"; | |||||
| constexpr auto kPadAndShiftOpName = "PadAndShift"; | |||||
| constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; | constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; | ||||
| constexpr auto kOneHotOpName = "OneHot"; | constexpr auto kOneHotOpName = "OneHot"; | ||||
| constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | ||||
| @@ -482,7 +484,8 @@ const std::set<std::string> kHWSpecialFormatSet = { | |||||
| const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | ||||
| const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName}; | |||||
| const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, | |||||
| kPadAndShiftOpName}; | |||||
| static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { | ||||
| try { | try { | ||||
| @@ -223,6 +223,10 @@ AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitiveP | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -193,6 +193,27 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| return std::make_shared<AbstractTuple>(elements); | return std::make_shared<AbstractTuple>(elements); | ||||
| } | } | ||||
| AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // inputs: a 1-d Tensor | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 3); | |||||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(input); | |||||
| auto shape = input->shape(); | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| if (shape->shape().size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1."; | |||||
| } | |||||
| ShapeVector ids_shape = {Shape::SHP_ANY}; | |||||
| ShapeVector min_shape = {1}; | |||||
| ShapeVector max_shape = shape->max_shape(); | |||||
| if (max_shape.empty()) { | |||||
| max_shape = shape->shape(); | |||||
| } | |||||
| return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape)); | |||||
| } | |||||
| AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // inputs: a 1-d Tensor | // inputs: a 1-d Tensor | ||||
| @@ -612,6 +633,29 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape)); | return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape)); | ||||
| } | } | ||||
| AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: a tensor | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| MS_LOG(INFO) << "InferImplDynamicAssign " << args_spec_list[0]; | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| if (type->type_id() == kObjectTypeRefKey) { | |||||
| return args_spec_list[1]->Broaden(); | |||||
| } else { | |||||
| auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0); | |||||
| auto y = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1); | |||||
| MS_EXCEPTION_IF_NULL(x); | |||||
| MS_EXCEPTION_IF_NULL(y); | |||||
| auto y_shape = y->shape(); | |||||
| MS_EXCEPTION_IF_NULL(y_shape); | |||||
| if (!y_shape->max_shape().empty()) { | |||||
| x->set_shape(y->shape()); | |||||
| } | |||||
| return args_spec_list[0]; | |||||
| } | |||||
| } | |||||
| AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| @@ -67,9 +67,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}}, | {prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}}, | ||||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | ||||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | ||||
| {prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}}, | |||||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | ||||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | ||||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, | {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, | ||||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, | |||||
| {prim::kPrimDiv, {InferImplDiv, true}}, | {prim::kPrimDiv, {InferImplDiv, true}}, | ||||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | ||||
| {prim::kPrimShape, {InferImplShape, false}}, | {prim::kPrimShape, {InferImplShape, false}}, | ||||
| @@ -103,6 +103,8 @@ inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCac | |||||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | ||||
| inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits"); | inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits"); | ||||
| inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | ||||
| inline const PrimitivePtr kPrimDynamicAssign = std::make_shared<Primitive>("DynamicAssign"); | |||||
| inline const PrimitivePtr kPrimPadAndShift = std::make_shared<Primitive>("PadAndShift"); | |||||
| inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice"); | ||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | ||||
| @@ -24,6 +24,8 @@ from .scatter import _scatter_aicpu | |||||
| from .identity import _identity_aicpu | from .identity import _identity_aicpu | ||||
| from .edit_distance import _edit_distance_aicpu | from .edit_distance import _edit_distance_aicpu | ||||
| from .unique_with_pad import _unique_with_pad_aicpu | from .unique_with_pad import _unique_with_pad_aicpu | ||||
| from .sub_and_filter import _sub_and_filter_aicpu | |||||
| from .pad_and_shift import _pad_and_shift_aicpu | |||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """PadAndShift op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| pad_and_shift_op_info = AiCPURegOp("PadAndShift") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .input(1, "cum_sum_arr", "required") \ | |||||
| .input(2, "shift_idx", "required") \ | |||||
| .output(0, "output", "required") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(pad_and_shift_op_info) | |||||
| def _pad_and_shift_aicpu(): | |||||
| """PadAndShift AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SubAndFilter op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| sub_and_filter_op_info = AiCPURegOp("SubAndFilter") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .input(1, "max_num", "required") \ | |||||
| .input(2, "offset", "required") \ | |||||
| .output(0, "filter_res", "required") \ | |||||
| .output(1, "filter_idx", "required") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, | |||||
| DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default, | |||||
| DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sub_and_filter_op_info) | |||||
| def _sub_and_filter_aicpu(): | |||||
| """SubAndFilter AiCPU register""" | |||||
| return | |||||
| @@ -92,8 +92,8 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg | |||||
| CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle, | CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle, | ||||
| ProdForceSeA) | ProdForceSeA) | ||||
| from .sparse_ops import SparseToDense | from .sparse_ops import SparseToDense | ||||
| from ._cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter, | |||||
| MapUniform) | |||||
| from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter, | |||||
| MapUniform, DynamicAssign, PadAndShift) | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Unique', | 'Unique', | ||||
| @@ -93,7 +93,7 @@ class SubAndFilter(PrimitiveWithCheck): | |||||
| outputs=['sub_res', 'sub_idx']) | outputs=['sub_res', 'sub_idx']) | ||||
| def check_shape(self, input_x_shape, max_num_shape, offset_shape): | def check_shape(self, input_x_shape, max_num_shape, offset_shape): | ||||
| return (-1, -1) | |||||
| return ((-1,), (-1,)) | |||||
| def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): | def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): | ||||
| validator.check_tensor_dtype_valid( | validator.check_tensor_dtype_valid( | ||||
| @@ -358,3 +358,77 @@ class MapCacheIdx(PrimitiveWithCheck): | |||||
| else: | else: | ||||
| out['min_shape'] = (0, 0, 0, 0) | out['min_shape'] = (0, 0, 0, 0) | ||||
| return out | return out | ||||
| class DynamicAssign(PrimitiveWithCheck): | |||||
| """ | |||||
| Assigns `Parameter` with a value, the `value` can have a dynamic shape. | |||||
| Inputs: | |||||
| - **variable** (Parameter) - The `Parameter`. | |||||
| - **value** (Tensor) - The value to be assigned. | |||||
| Outputs: | |||||
| Tensor, has the same type as original `variable`. | |||||
| Supported Platforms: | |||||
| `CPU` | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('value', dtype=sig.sig_dtype.T) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) | |||||
| def check_dtype(self, variable, value): | |||||
| if variable != mstype.type_refkey: | |||||
| validator.check_tensor_dtype_valid( | |||||
| "variable", variable, mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_types_same( | |||||
| {"value": value}, mstype.number_type, self.name) | |||||
class PadAndShift(PrimitiveWithCheck):
    """
    Pad a tensor with -1, and shift with a length.

    Inputs:
        - **input_x** (Tensor) - The input Tensor, which will be copied
          to `output`.
        - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
          the pad length of output tensor, cum_sum_arr[shift_idx] is
          the start to shift, and cum_sum_arr[shift_idx+1] is the end.
        - **shift_idx** (Int) - The idx of cum_sum_arr.

    if use python, PadAndShift is:
        output = [-1] * cum_sum_arr[-1]
        start = cum_sum_arr[shift_idx]
        end = cum_sum_arr[shift_idx + 1]
        output[start:end] = input_x[:(end-start)]

    Outputs:
        Tensor, has the same type as original `input_x`.

    Supported Platforms:
        `CPU`

    Examples:
        >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
        >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
        >>> shift_idx = 1
        >>> pad_and_shift = ops.PadAndShift()
        >>> output = pad_and_shift(input_x, cum_sum_arr, shift_idx)
        >>> print(output)
        [-1, -1, -1, 9, 13]
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])

    def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
        # NOTE(review): the true output length is cum_sum_arr[-1], which is
        # data-dependent; input_x_shape appears to act as a placeholder for
        # the dynamic-shape kernel — confirm against the CPU kernel.
        return input_x_shape

    def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
        # Output dtype follows input_x.
        return input_x_dtype
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, Parameter | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
# Run this test module in graph mode on the CPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
    """Deduplicates its input and assigns the unique values to a parameter."""

    def __init__(self):
        super(Net, self).__init__()
        self.unique = P.Unique()
        self.dynamic_assign = P.DynamicAssign()
        zeros = Tensor(np.zeros((5,), np.int32))
        self.param = Parameter(zeros, name="assign_x")

    def construct(self, y):
        uniq, _ = self.unique(y)
        out = self.dynamic_assign(self.param, uniq)
        return out
@pytest.mark.level0
# Fixed platform marker: the module sets device_target="CPU" (and DynamicAssign
# is CPU-only), but the original marks targeted Ascend training platforms, so
# the test would never be scheduled on the backend it actually exercises.
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dynamic_assign():
    """DynamicAssign writes the unique values of `y` into the net's parameter."""
    y = Tensor(np.array([2, 2, 3, 3, 4]), mstype.int32)
    dynamic_assign = Net()
    _ = dynamic_assign(y)
    # Unique([2, 2, 3, 3, 4]) -> [2, 3, 4]; the parameter must hold that result.
    expect1 = np.array([2, 3, 4])
    param_np = dynamic_assign.param.data.asnumpy()
    assert (param_np == expect1).all()
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
# Run this test module in graph mode on the CPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
    """Wraps PadAndShift with a fixed shift index of 1."""

    def __init__(self):
        super(Net, self).__init__()
        self.pad_and_shift = P.PadAndShift()
        self.shift_idx = 1

    def construct(self, x, y):
        out = self.pad_and_shift(x, y, self.shift_idx)
        return out
@pytest.mark.level0
# Fixed platform marker: the test name says CPU and the module sets
# device_target="CPU", but the original marks targeted Ascend training
# platforms, so the test would never run against the CPU kernel it covers.
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pad_and_shift_cpu():
    """PadAndShift pads to cum_sum_arr[-1] with -1 and shifts x by shift_idx=1."""
    x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
    y = Tensor(np.array([0, 3, 5]), mstype.int32)
    net = Net()
    output = net(x, y)
    # output length = y[-1] = 5; slots [y[1]:y[2]] = [3:5] receive x[:2].
    expect = np.array([-1, -1, -1, 9, 13])
    assert (output.asnumpy() == expect).all()