/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ #define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ #include #include #include #include #include "ir/dtype.h" #include "kernel/kernel.h" namespace mindspore { namespace kernel { enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; enum OpIOType { kInput = 0, kOutput }; class OpAttr { public: OpAttr() = default; ~OpAttr() = default; std::string name() const { return name_; } std::string param_type() const { return param_type_; } std::string type() const { return type_; } std::string value() const { return value_; } std::string default_value() const { return default_value_; } void set_name(const std::string &name) { name_ = name; } void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } void set_type(const std::string &type) { type_ = type; } void set_value(const std::string &value) { value_ = value; } void set_default_value(const std::string &default_value) { default_value_ = default_value; } private: std::string name_; std::string param_type_; std::string type_; std::string value_; std::string default_value_; }; class OpIOInfo { public: OpIOInfo() = default; ~OpIOInfo() = default; int index() const { return index_; } std::string name() const { return name_; } bool need_compile() const { return need_compile_; } std::string param_type() const { return param_type_; } std::string reshape_type() const { return reshape_type_; } std::string shape() const { return shape_; } std::vector dtypes() const { return dtypes_; } std::vector formats() const { return formats_; } void set_index(const int index) { index_ = index; } void set_name(const std::string &name) { name_ = name; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } void set_shape(const std::string &shape) { shape_ = shape; } void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } void set_formats(const std::vector &formats) { formats_ = formats; } private: int index_ = 0; std::string name_; bool need_compile_ = false; std::string param_type_; std::string reshape_type_; std::string shape_; std::vector dtypes_; std::vector formats_; }; class OpInfo { public: OpInfo() = default; OpInfo(const OpInfo &opinfo) { op_name_ = opinfo.op_name(); imply_type_ = opinfo.imply_type(); impl_path_ = opinfo.impl_path(); fusion_type_ = opinfo.fusion_type(); async_flag_ = opinfo.async_flag_; binfile_name_ = opinfo.binfile_name_; compute_cost_ = opinfo.compute_cost_; kernel_name_ = opinfo.kernel_name(); partial_flag_ = opinfo.partial_flag_; dynamic_format_ = opinfo.dynamic_format_; op_pattern_ = opinfo.op_pattern(); for (auto attr : opinfo.attrs_ptr()) { attrs_ptr_.push_back(std::make_shared(*attr)); } for (auto input : opinfo.inputs_ptr()) { inputs_ptr_.push_back(std::make_shared(*input)); } for (auto output : opinfo.outputs_ptr()) { outputs_ptr_.push_back(std::make_shared(*output)); } ref_infos_ = opinfo.ref_infos(); } ~OpInfo() = default; std::string op_name() const { return op_name_; } OpImplyType imply_type() const { return imply_type_; } std::string impl_path() const { return impl_path_; } std::string fusion_type() const { return fusion_type_; } std::string kernel_name() const { return kernel_name_; } OpPattern op_pattern() const { return op_pattern_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } const std::unordered_map &ref_infos() const { return ref_infos_; } void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } bool is_ref() const { return !ref_infos_.empty(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } void ClearInputs() { (void)inputs_ptr_.clear(); } void ClearOutputs() { (void)outputs_ptr_.clear(); } private: std::string op_name_; OpImplyType imply_type_ = kTBE; std::string impl_path_; std::string fusion_type_; bool async_flag_ = false; std::string binfile_name_; int compute_cost_ = 0; std::string kernel_name_; bool partial_flag_ = false; bool dynamic_format_ = false; OpPattern op_pattern_ = kCommonPattern; std::vector> attrs_ptr_; std::vector> inputs_ptr_; std::vector> outputs_ptr_; std::unordered_map ref_infos_; }; } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_