Since akg supports both Ascend and Gpu, but their supported type and format are different, so we use two directory "ascend" and "gpu" to store their registers respectively, and use an attribute "processor" to distinguish them. Main changes: 1) Add two op register class "AkgAscendRegOp" and "AkgGpuRegOp", inherited from the original AkgRegOp. 2) Rewrite akg ascend op registers with new interface, move them into directory "ascend". 3) Rename the imply_type from "AutoDiff" to "AKG". 4) Modify function FindOp, check the processor when imply_type is "AKG". 5) Modify function CheckRepetition, remove the judgement for impl_path, check processor instead. TODO: Remove op registers in akg root path.tags/v0.6.0-beta
| @@ -103,6 +103,7 @@ class OpInfo { | |||||
| partial_flag_ = opinfo.partial_flag_; | partial_flag_ = opinfo.partial_flag_; | ||||
| dynamic_format_ = opinfo.dynamic_format_; | dynamic_format_ = opinfo.dynamic_format_; | ||||
| op_pattern_ = opinfo.op_pattern(); | op_pattern_ = opinfo.op_pattern(); | ||||
| processor_ = opinfo.processor_; | |||||
| for (const auto &attr : opinfo.attrs_ptr()) { | for (const auto &attr : opinfo.attrs_ptr()) { | ||||
| attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr)); | attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr)); | ||||
| } | } | ||||
| @@ -121,6 +122,7 @@ class OpInfo { | |||||
| std::string fusion_type() const { return fusion_type_; } | std::string fusion_type() const { return fusion_type_; } | ||||
| std::string kernel_name() const { return kernel_name_; } | std::string kernel_name() const { return kernel_name_; } | ||||
| OpPattern op_pattern() const { return op_pattern_; } | OpPattern op_pattern() const { return op_pattern_; } | ||||
| std::string processor() const { return processor_; } | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | ||||
| @@ -136,6 +138,7 @@ class OpInfo { | |||||
| void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | ||||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | ||||
| void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } | ||||
| void set_processor(const std::string &processor) { processor_ = processor; } | |||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | ||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | ||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | ||||
| @@ -144,6 +147,10 @@ class OpInfo { | |||||
| void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | ||||
| void ClearInputs() { (void)inputs_ptr_.clear(); } | void ClearInputs() { (void)inputs_ptr_.clear(); } | ||||
| void ClearOutputs() { (void)outputs_ptr_.clear(); } | void ClearOutputs() { (void)outputs_ptr_.clear(); } | ||||
| bool equals_to(const std::shared_ptr<OpInfo> &other_info) const { | |||||
| return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && | |||||
| this->processor_ == other_info->processor_; | |||||
| } | |||||
| private: | private: | ||||
| std::string op_name_; | std::string op_name_; | ||||
| @@ -157,6 +164,7 @@ class OpInfo { | |||||
| bool partial_flag_ = false; | bool partial_flag_ = false; | ||||
| bool dynamic_format_ = false; | bool dynamic_format_ = false; | ||||
| OpPattern op_pattern_ = kCommonPattern; | OpPattern op_pattern_ = kCommonPattern; | ||||
| std::string processor_; | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; | ||||
| @@ -45,9 +45,10 @@ constexpr auto kAttr = "attr"; | |||||
| constexpr auto kIputs = "inputs"; | constexpr auto kIputs = "inputs"; | ||||
| constexpr auto kOutputs = "outputs"; | constexpr auto kOutputs = "outputs"; | ||||
| constexpr auto kAiCPU = "AiCPU"; | constexpr auto kAiCPU = "AiCPU"; | ||||
| constexpr auto kAiCore = "AiCore"; | |||||
| constexpr auto kCUDA = "CUDA"; | |||||
| constexpr auto kTbe = "TBE"; | constexpr auto kTbe = "TBE"; | ||||
| constexpr auto kAkg = "akg"; | |||||
| constexpr auto kAutodiff = "AutoDiff"; | |||||
| constexpr auto kAkg = "AKG"; | |||||
| constexpr auto kName = "name"; | constexpr auto kName = "name"; | ||||
| constexpr auto kParamType = "param_type"; | constexpr auto kParamType = "param_type"; | ||||
| constexpr auto kDtype = "dtype"; | constexpr auto kDtype = "dtype"; | ||||
| @@ -58,6 +59,7 @@ constexpr auto kIndex = "index"; | |||||
| constexpr auto kFormat = "format"; | constexpr auto kFormat = "format"; | ||||
| constexpr auto kNeedCompile = "need_compile"; | constexpr auto kNeedCompile = "need_compile"; | ||||
| constexpr auto kShape = "shape"; | constexpr auto kShape = "shape"; | ||||
| constexpr auto kProcessor = "processor"; | |||||
| std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_; | std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_; | ||||
| static std::string ImplTypeToStr(OpImplyType impl_type) { | static std::string ImplTypeToStr(OpImplyType impl_type) { | ||||
| @@ -81,7 +83,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) | |||||
| if (imply_type_string == kTbe) { | if (imply_type_string == kTbe) { | ||||
| OpImplyType imply_type = kTBE; | OpImplyType imply_type = kTBE; | ||||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | ret = DecodeOpInfo(op_json, imply_type, impl_path); | ||||
| } else if (imply_type_string == kAutodiff) { | |||||
| } else if (imply_type_string == kAkg) { | |||||
| OpImplyType imply_type = kAKG; | OpImplyType imply_type = kAKG; | ||||
| ret = DecodeOpInfo(op_json, imply_type, impl_path); | ret = DecodeOpInfo(op_json, imply_type, impl_path); | ||||
| } else if (imply_type_string == kAiCPU) { | } else if (imply_type_string == kAiCPU) { | ||||
| @@ -125,6 +127,11 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p | |||||
| } | } | ||||
| } | } | ||||
| void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | |||||
| op_info->set_processor(obj.at(kProcessor)); | |||||
| } | |||||
| bool OpLib::RegOpFromLocalInfo() { | bool OpLib::RegOpFromLocalInfo() { | ||||
| MS_LOG(INFO) << "Start"; | MS_LOG(INFO) << "Start"; | ||||
| static bool has_load = false; | static bool has_load = false; | ||||
| @@ -179,6 +186,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI | |||||
| op_info->set_fusion_type(obj.at(kFusionType)); | op_info->set_fusion_type(obj.at(kFusionType)); | ||||
| if (imply_type == kTBE) { | if (imply_type == kTBE) { | ||||
| DecodeTBESpecificInfo(obj, op_info); | DecodeTBESpecificInfo(obj, op_info); | ||||
| } else if (imply_type == kAKG) { | |||||
| DecodeAKGSpecificInfo(obj, op_info); | |||||
| } | } | ||||
| auto attrs = obj.at(kAttr); | auto attrs = obj.at(kAttr); | ||||
| for (const auto &attr : attrs) { | for (const auto &attr : attrs) { | ||||
| @@ -330,7 +339,12 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im | |||||
| for (const auto &op_info : op_info_) { | for (const auto &op_info : op_info_) { | ||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { | if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { | ||||
| return op_info; | |||||
| auto akg_processor_match = [&]() { | |||||
| return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore; | |||||
| }; | |||||
| if (imply_type != kAKG || akg_processor_match()) { | |||||
| return op_info; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) | ||||
| @@ -363,19 +377,14 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) { | |||||
| } | } | ||||
| bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) { | bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) { | ||||
| bool has_register = false; | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| for (const auto &exist_op_info : op_info_) { | for (const auto &exist_op_info : op_info_) { | ||||
| MS_EXCEPTION_IF_NULL(exist_op_info); | MS_EXCEPTION_IF_NULL(exist_op_info); | ||||
| if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && | |||||
| exist_op_info->impl_path() == op_info->impl_path()) { | |||||
| MS_LOG(INFO) << "Op has already exist, please use other name, op name: " << op_info->op_name() | |||||
| << " op type: " << ImplTypeToStr(op_info->imply_type()); | |||||
| has_register = true; | |||||
| break; | |||||
| if (exist_op_info->equals_to(op_info)) { | |||||
| return true; | |||||
| } | } | ||||
| } | } | ||||
| return has_register; | |||||
| return false; | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,7 @@ class OpLib { | |||||
| static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | ||||
| size_t index); | size_t index); | ||||
| static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info); | static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info); | ||||
| static void DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info); | |||||
| static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | ||||
| const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format); | const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format); | ||||
| static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info); | static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info); | ||||
| @@ -32,7 +32,7 @@ Note: | |||||
| from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | ||||
| from .op_info_register import op_info_register, AkgRegOp, AiCPURegOp, TBERegOp, DataType | |||||
| from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType | |||||
| from .primitive import constexpr | from .primitive import constexpr | ||||
| from .._c_expression import signature_rw, signature_kind | from .._c_expression import signature_rw, signature_kind | ||||
| @@ -42,6 +42,6 @@ __primitive__ = [ | |||||
| ] | ] | ||||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | __all__ = ["get_vm_impl_fn", "vm_impl_registry", | ||||
| "op_info_register", "AkgRegOp", "AiCPURegOp", "TBERegOp", "DataType", | |||||
| "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType", | |||||
| "constexpr"] | "constexpr"] | ||||
| __all__.extend(__primitive__) | __all__.extend(__primitive__) | ||||
| @@ -17,7 +17,7 @@ | |||||
| import platform | import platform | ||||
| from .aicpu import * | from .aicpu import * | ||||
| if "Windows" not in platform.system(): | if "Windows" not in platform.system(): | ||||
| from .akg.gpu import * | |||||
| from .akg import * | |||||
| from .tbe import * | from .tbe import * | ||||
| __all__ = [] | __all__ = [] | ||||
| @@ -13,77 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """autodiff ops""" | |||||
| from .abs import _abs_akg | |||||
| from .add_n import _add_n_akg | |||||
| from .add import _add_akg | |||||
| from .apply_momentum import _apply_momentum_akg | |||||
| from .assign import _assign_akg | |||||
| from .inplace_assign import _inplace_assign_akg | |||||
| from .assign_add import _assign_add_akg | |||||
| from .bias_add_grad import _bias_add_grad_akg | |||||
| from .bias_add import _bias_add_akg | |||||
| from .cast import _cast_akg | |||||
| from .clear_zero import _clear_zero_akg | |||||
| from .conv_bn1 import _conv_bn1_akg | |||||
| from .conv2d_backprop_filter import _conv2d_backprop_filter_akg | |||||
| from .conv2d_backprop_input import _conv2d_backprop_input_akg | |||||
| from .conv2d import _conv2d_akg | |||||
| from .div import _div_akg | |||||
| from .equal_count import _equal_count_akg | |||||
| from .exp import _exp_akg | |||||
| from .five2four import _five2four_akg | |||||
| from .four2five import _four2five_akg | |||||
| from .fused_batch_norm_grad import _fused_batch_norm_grad_akg | |||||
| from .fused_batch_norm_infer import _fused_batch_norm_infer_akg | |||||
| from .fused_batch_norm import _fused_batch_norm_akg | |||||
| from .fused_bn1_grad import _bn1_grad_akg | |||||
| from .fused_bn1 import _fused_bn1_akg | |||||
| from .fused_bn2_grad import _bn2_grad_akg | |||||
| from .fused_bn2 import _fused_bn2_akg | |||||
| from .fused_bn3_grad import _bn3_grad_akg | |||||
| from .fused_bn3 import _fused_bn3_akg | |||||
| from .gather_v2 import _gather_v2_akg | |||||
| from .less import _less_akg | |||||
| from .log import _log_akg | |||||
| from .matmul import _matmul_akg | |||||
| from .batchmatmul import _batchmatmul_akg | |||||
| from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg | |||||
| from .max_pool_with_argmax import _max_pool_with_argmax_akg | |||||
| from .max import _max_akg | |||||
| from .maximum import _maximum_akg | |||||
| from .mean_grad import _mean_grad_akg | |||||
| from .mean import _mean_akg | |||||
| from .minimum import _minimum_akg | |||||
| from .mul import _mul_akg | |||||
| from .neg import _neg_akg | |||||
| from .one_hot import _one_hot_akg | |||||
| from .pow import _power_akg | |||||
| from .real_div import _real_div_akg | |||||
| from .reciprocal import _reciprocal_akg | |||||
| from .reduce_max import _reduce_max_akg | |||||
| from .reduce_mean import _reduce_mean_akg | |||||
| from .reduce_sum import _reduce_sum_akg | |||||
| from .relu_grad import _relu_grad_akg | |||||
| from .relu import _relu_akg | |||||
| from .reshape import _reshape_akg | |||||
| from .round import _round_akg | |||||
| from .rsqrt import _rsqrt_akg | |||||
| from .select import _select_akg | |||||
| from .softmax import _softmax_akg | |||||
| from .sparse_softmax_cross_entropy_with_logits import _sparse_softmax_cross_entropy_with_logits_akg | |||||
| from .sqrt import _sqrt_akg | |||||
| from .strided_slice import _strided_slice_akg | |||||
| from .sub import _sub_akg | |||||
| from .sum import _sum_akg | |||||
| from .tile import _tile_akg | |||||
| from .zeros_like import _zeros_like_akg | |||||
| from .argmax import _argmax_akg | |||||
| from .floordiv import _floor_div_akg | |||||
| from .equal import _equal_akg | |||||
| from .greater_equal import _greater_equal_akg | |||||
| from .less_equal import _less_equal_akg | |||||
| from .expand_dims import _expand_dims_akg | |||||
| from .greater import _greater_akg | |||||
| from .equiv_format import _equiv_format_akg | |||||
| """akg ops""" | |||||
| from . import ascend | |||||
| from . import gpu | from . import gpu | ||||
| @@ -0,0 +1,30 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """__init__""" | |||||
| from .add import _add_akg | |||||
| from .batchmatmul import _batchmatmul_akg | |||||
| from .cast import _cast_akg | |||||
| from .expand_dims import _expand_dims_akg | |||||
| from .greater import _greater_akg | |||||
| from .inplace_assign import _inplace_assign_akg | |||||
| from .maximum import _maximum_akg | |||||
| from .minimum import _minimum_akg | |||||
| from .mul import _mul_akg | |||||
| from .real_div import _real_div_akg | |||||
| from .rsqrt import _rsqrt_akg | |||||
| from .select import _select_akg | |||||
| from .sqrt import _sqrt_akg | |||||
| from .sub import _sub_akg | |||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """TensorAdd op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("TensorAdd") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .input(0, "x") \ | |||||
| .input(1, "y") \ | |||||
| .output(0, "output") \ | |||||
| .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ | |||||
| .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ | |||||
| .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ | |||||
| .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ | |||||
| .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ | |||||
| .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ | |||||
| .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ | |||||
| .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ | |||||
| .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _add_akg(): | |||||
| """TensorAdd Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BatchMatMul op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("BatchMatMul") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x1") \ | |||||
| .input(1, "x2") \ | |||||
| .output(0, "output") \ | |||||
| .attr("transpose_a", "optional", "bool") \ | |||||
| .attr("transpose_b", "optional", "bool") \ | |||||
| .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _batchmatmul_akg(): | |||||
| """BatchMatMul AKG register""" | |||||
| return | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Cast op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("Cast") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x") \ | |||||
| .output(0, "output") \ | |||||
| .attr("dst_type", "required", "str") \ | |||||
| .dtype_format(DT.F16_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.F16_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.BOOL_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.BOOL_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.BOOL_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.F16_5HD, DT.F32_5HD) \ | |||||
| .dtype_format(DT.F32_5HD, DT.F16_5HD) \ | |||||
| .dtype_format(DT.BOOL_5HD, DT.I32_5HD) \ | |||||
| .dtype_format(DT.BOOL_5HD, DT.F32_5HD) \ | |||||
| .dtype_format(DT.F16_FracNZ, DT.F32_FracNZ) \ | |||||
| .dtype_format(DT.F32_FracNZ, DT.F16_FracNZ) \ | |||||
| .dtype_format(DT.BOOL_FracNZ, DT.I32_FracNZ) \ | |||||
| .dtype_format(DT.BOOL_FracNZ, DT.F32_FracNZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _cast_akg(): | |||||
| """Cast Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ExpandDims op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("ExpandDims") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x") \ | |||||
| .output(0, "y") \ | |||||
| .attr("axis", "required", "int") \ | |||||
| .dtype_format(DT.F16_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.I32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _expand_dims_akg(): | |||||
| """ExpandDims Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,34 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Greater op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("Greater") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .input(0, "x") \ | |||||
| .input(1, "y") \ | |||||
| .output(0, "output") \ | |||||
| .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \ | |||||
| .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ | |||||
| .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _greater_akg(): | |||||
| """Greater Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """InplaceAssign op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("InplaceAssign") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .input(0, "x") \ | |||||
| .input(1, "y") \ | |||||
| .input(2, "z") \ | |||||
| .output(0, "output") \ | |||||
| .attr("fake_output", "optional", "bool") \ | |||||
| .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ | |||||
| .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ | |||||
| .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ | |||||
| .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ | |||||
| .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ | |||||
| .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _inplace_assign_akg(): | |||||
| """InplaceAssign Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Maximum op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
| op_info = AkgAscendRegOp("Maximum") \ | |||||
| .fusion_type("COMMREDUCE") \ | |||||
| .input(0, "x") \ | |||||
| .input(1, "y") \ | |||||
| .output(0, "output") \ | |||||
| .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ | |||||
| .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ | |||||
| .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ | |||||
| .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ | |||||
| .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ | |||||
| .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ | |||||
| .get_op_info() | |||||
| @op_info_register(op_info) | |||||
| def _maximum_akg(): | |||||
| """Maximum Akg register""" | |||||
| return | |||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Minimum op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Minimum op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is an (x, y, output) triple; supported layouts
# are DefaultFormat, 5HD (NC1HWC0) and FracNZ for f16/f32/i32.
# NOTE(review): fusion_type is "COMMREDUCE" while sibling elementwise ops
# (Mul/Sub/RealDiv) use "ELEMWISE" -- confirm this is intended for Minimum.
op_info = AkgAscendRegOp("Minimum") \
    .fusion_type("COMMREDUCE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
    .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
    .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
    .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \
    .get_op_info()


@op_info_register(op_info)
def _minimum_akg():
    """Minimum Akg register"""
    return
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Mul op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Mul op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Required attrs carry the operand shapes and data format to the kernel.
# Each dtype_format row is an (x, y, output) triple; supported layouts
# are DefaultFormat, 5HD (NC1HWC0), FracZ and FracNZ for f16/f32.
op_info = AkgAscendRegOp("Mul") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .attr("x_shape", "required", "listInt") \
    .attr("y_shape", "required", "listInt") \
    .attr("data_format", "required", "listStr") \
    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
    .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
    .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
    .get_op_info()


@op_info_register(op_info)
def _mul_akg():
    """Mul Akg register"""
    return
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RealDiv op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG RealDiv op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is an (x, y, output) triple; supported layouts
# are DefaultFormat, 5HD (NC1HWC0) and FracNZ for f16/f32 only.
op_info = AkgAscendRegOp("RealDiv") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
    .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
    .get_op_info()


@op_info_register(op_info)
def _real_div_akg():
    """RealDiv Akg register"""
    return
| @@ -0,0 +1,35 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Rsqrt op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Rsqrt op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is an (x, output) pair; supported layouts are
# DefaultFormat and 5HD (NC1HWC0).
# NOTE(review): I32 -> I32 rows for reciprocal square root look unusual
# (result is generally fractional) -- confirm the kernel supports them.
op_info = AkgAscendRegOp("Rsqrt") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.I32_5HD, DT.I32_5HD) \
    .get_op_info()


@op_info_register(op_info)
def _rsqrt_akg():
    """Rsqrt Akg register"""
    return
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Select op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Select op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is a (condition, x, y, output) quadruple: the
# condition is bool while x/y/output share one of f16/f32/i32, in
# DefaultFormat or 5HD (NC1HWC0) layout.
op_info = AkgAscendRegOp("Select") \
    .fusion_type("ELEMWISE") \
    .input(0, "condition") \
    .input(1, "x") \
    .input(2, "y") \
    .output(0, "output") \
    .dtype_format(DT.BOOL_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.BOOL_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.BOOL_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.BOOL_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.BOOL_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.BOOL_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
    .get_op_info()


@op_info_register(op_info)
def _select_akg():
    """Select Akg register"""
    return
| @@ -0,0 +1,35 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Sqrt op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Sqrt op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is an (x, output) pair; supported layouts are
# DefaultFormat and 5HD (NC1HWC0).
# NOTE(review): I32 -> I32 rows for square root look unusual (result is
# generally fractional) -- confirm the kernel supports them.
op_info = AkgAscendRegOp("Sqrt") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.I32_5HD, DT.I32_5HD) \
    .get_op_info()


@op_info_register(op_info)
def _sqrt_akg():
    """Sqrt Akg register"""
    return
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Sub op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT | |||||
# Register info for the AKG Sub op on Ascend (AkgAscendRegOp sets
# imply_type "AKG", processor "AiCore").
# Each dtype_format row is an (x, y, output) triple; supported layouts
# are DefaultFormat, 5HD (NC1HWC0), FracZ and FracNZ for f16/f32/i32.
op_info = AkgAscendRegOp("Sub") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
    .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
    .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
    .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \
    .get_op_info()


@op_info_register(op_info)
def _sub_akg():
    """Sub Akg register"""
    return
| @@ -13,15 +13,16 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Cast op""" | """Cast op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| cast_op_info = AkgRegOp("Cast") \ | |||||
| cast_op_info = AkgGpuRegOp("Cast") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| .attr("dst_type", "required", "str") \ | .attr("dst_type", "required", "str") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | .dtype_format(DataType.F16_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | .dtype_format(DataType.F32_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | .dtype_format(DataType.I32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Equal op""" | """Equal op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| equal_op_info = AkgRegOp("Equal") \ | |||||
| equal_op_info = AkgGpuRegOp("Equal") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """GreaterEqual op""" | """GreaterEqual op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| greater_equal_op_info = AkgRegOp("GreaterEqual") \ | |||||
| greater_equal_op_info = AkgGpuRegOp("GreaterEqual") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """HSigmoid op""" | """HSigmoid op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| hsigmoid_op_info = AkgRegOp("HSigmoid") \ | |||||
| hsigmoid_op_info = AkgGpuRegOp("HSigmoid") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """HSigmoidGrad op""" | """HSigmoidGrad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| hsigmoidgrad_op_info = AkgRegOp("HSigmoidGrad") \ | |||||
| hsigmoidgrad_op_info = AkgGpuRegOp("HSigmoidGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "y_grad") \ | .input(0, "y_grad") \ | ||||
| .input(1, "x") \ | .input(1, "x") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """HSwish op""" | """HSwish op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| hswish_op_info = AkgRegOp("HSwish") \ | |||||
| hswish_op_info = AkgGpuRegOp("HSwish") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """HSwishGrad op""" | """HSwishGrad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| hswish_grad_op_info = AkgRegOp("HSwishGrad") \ | |||||
| hswish_grad_op_info = AkgGpuRegOp("HSwishGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "y_grad") \ | .input(0, "y_grad") \ | ||||
| .input(1, "x") \ | .input(1, "x") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """LessEqual op""" | """LessEqual op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| lessequal_op_info = AkgRegOp("LessEqual") \ | |||||
| lessequal_op_info = AkgGpuRegOp("LessEqual") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """LogicalAnd op""" | """LogicalAnd op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| logicaland_op_info = AkgRegOp("LogicalAnd") \ | |||||
| logicaland_op_info = AkgGpuRegOp("LogicalAnd") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -23,6 +23,7 @@ logicaland_op_info = AkgRegOp("LogicalAnd") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(logicaland_op_info) | @op_info_register(logicaland_op_info) | ||||
| def _logical_and_akg(): | def _logical_and_akg(): | ||||
| """LogicalAnd register""" | """LogicalAnd register""" | ||||
| @@ -13,15 +13,16 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """LogicalNot op""" | """LogicalNot op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| logical_not_op_info = AkgRegOp("LogicalNot") \ | |||||
| logical_not_op_info = AkgGpuRegOp("LogicalNot") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(logical_not_op_info) | @op_info_register(logical_not_op_info) | ||||
| def _logical_not_akg(): | def _logical_not_akg(): | ||||
| """LogicalNot AutoDiff register""" | """LogicalNot AutoDiff register""" | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """LogicalOr op""" | """LogicalOr op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| logicalor_op_info = AkgRegOp("LogicalOr") \ | |||||
| logicalor_op_info = AkgGpuRegOp("LogicalOr") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -23,6 +23,7 @@ logicalor_op_info = AkgRegOp("LogicalOr") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(logicalor_op_info) | @op_info_register(logicalor_op_info) | ||||
| def _logical_or_akg(): | def _logical_or_akg(): | ||||
| """LogicalOr register""" | """LogicalOr register""" | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """SimpleMean op""" | """SimpleMean op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| mean_op_info = AkgRegOp("SimpleMean") \ | |||||
| mean_op_info = AkgGpuRegOp("SimpleMean") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """SimpleMeanGrad op""" | """SimpleMeanGrad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| mean_grad_op_info = AkgRegOp("SimpleMeanGrad") \ | |||||
| mean_grad_op_info = AkgGpuRegOp("SimpleMeanGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "HEAD") \ | .input(0, "HEAD") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Mul op""" | """Mul op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| mul_op_info = AkgRegOp("Mul") \ | |||||
| mul_op_info = AkgGpuRegOp("Mul") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """NotEqual op""" | """NotEqual op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| notequal_op_info = AkgRegOp("NotEqual") \ | |||||
| notequal_op_info = AkgGpuRegOp("NotEqual") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """ReLU6 op""" | """ReLU6 op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| relu_op_info = AkgRegOp("ReLU6") \ | |||||
| relu_op_info = AkgGpuRegOp("ReLU6") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """ReLU6Grad op""" | """ReLU6Grad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| relu_grad_op_info = AkgRegOp("ReLU6Grad") \ | |||||
| relu_grad_op_info = AkgGpuRegOp("ReLU6Grad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "y_grad") \ | .input(0, "y_grad") \ | ||||
| .input(1, "x") \ | .input(1, "x") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Squeeze op""" | """Squeeze op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| squeeze_op_info = AkgRegOp("Squeeze") \ | |||||
| squeeze_op_info = AkgGpuRegOp("Squeeze") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """SqueezeGrad op""" | """SqueezeGrad op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| squeeze_grad_op_info = AkgRegOp("SqueezeGrad") \ | |||||
| squeeze_grad_op_info = AkgGpuRegOp("SqueezeGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "y_grad") \ | .input(0, "y_grad") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Sub op""" | """Sub op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| sub_op_info = AkgRegOp("Sub") \ | |||||
| sub_op_info = AkgGpuRegOp("Sub") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .input(1, "y") \ | .input(1, "y") \ | ||||
| @@ -25,6 +25,7 @@ sub_op_info = AkgRegOp("Sub") \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(sub_op_info) | @op_info_register(sub_op_info) | ||||
| def _sub_akg(): | def _sub_akg(): | ||||
| """Sub AutoDiff register""" | """Sub AutoDiff register""" | ||||
| @@ -13,9 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Tile op""" | """Tile op""" | ||||
| from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType | |||||
| from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType | |||||
| tile_op_info = AkgRegOp("Tile") \ | |||||
| tile_op_info = AkgGpuRegOp("Tile") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .input(0, "x") \ | .input(0, "x") \ | ||||
| .output(0, "output") \ | .output(0, "output") \ | ||||
| @@ -215,10 +215,10 @@ class RegOp: | |||||
| class AkgRegOp(RegOp): | class AkgRegOp(RegOp): | ||||
| """Class for Akg op info register.""" | """Class for Akg op info register.""" | ||||
| def __init__(self, op_name): | |||||
| def __init__(self, op_name, processor): | |||||
| super(AkgRegOp, self).__init__(op_name) | super(AkgRegOp, self).__init__(op_name) | ||||
| self.imply_type = "AutoDiff" | |||||
| self.processor = "cuda" | |||||
| self.imply_type = "AKG" | |||||
| self.processor = processor | |||||
| def input(self, index=None, name=None, **kwargs): | def input(self, index=None, name=None, **kwargs): | ||||
| """ | """ | ||||
| @@ -270,6 +270,16 @@ class AkgRegOp(RegOp): | |||||
| return self | return self | ||||
class AkgGpuRegOp(AkgRegOp):
    """Class for Akg op info register targeting GPU.

    Fixes the processor to "CUDA" so FindOp can match AKG GPU registers.
    """

    def __init__(self, op_name):
        super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
class AkgAscendRegOp(AkgRegOp):
    """Class for Akg op info register targeting Ascend.

    Fixes the processor to "AiCore" so FindOp can match AKG Ascend registers.
    """

    def __init__(self, op_name):
        super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
| class AiCPURegOp(RegOp): | class AiCPURegOp(RegOp): | ||||
| """Class for AiCPU op info register""" | """Class for AiCPU op info register""" | ||||