| @@ -643,29 +643,33 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { | |||
| expressions_map_ = GetExpressions(); | |||
| for (auto node : func_graph->GetOrderedCnodes()) { | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph); | |||
| bool find_pattern = true; | |||
| bool change_anf_graph = false; | |||
| while (find_pattern) { | |||
| find_pattern = false; | |||
| find_pattern = DoArithmeticTrans(lg) || find_pattern; | |||
| find_pattern = DoConstantFold(lg) || find_pattern; | |||
| change_anf_graph = change_anf_graph || find_pattern; | |||
| try { | |||
| auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); | |||
| graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph); | |||
| bool find_pattern = true; | |||
| bool change_anf_graph = false; | |||
| while (find_pattern) { | |||
| find_pattern = false; | |||
| find_pattern = DoArithmeticTrans(lg) || find_pattern; | |||
| find_pattern = DoConstantFold(lg) || find_pattern; | |||
| change_anf_graph = change_anf_graph || find_pattern; | |||
| } | |||
| if (!change_anf_graph) continue; | |||
| ReorganizeEmptyGraph(lg); | |||
| AnfNodePtrList outputs; | |||
| auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); | |||
| new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| EliminateRedundantParameters(new_funcgraph, &inputs); | |||
| auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); | |||
| SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); | |||
| mng->Replace(node, new_node); | |||
| mng->AddFuncGraph(new_funcgraph); | |||
| do_simplify = true; | |||
| } catch (const graphkernel::GKException &e) { | |||
| MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph"; | |||
| } | |||
| if (!change_anf_graph) continue; | |||
| ReorganizeEmptyGraph(lg); | |||
| AnfNodePtrList outputs; | |||
| auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs); | |||
| new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| EliminateRedundantParameters(new_funcgraph, &inputs); | |||
| auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs); | |||
| SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs); | |||
| mng->Replace(node, new_node); | |||
| mng->AddFuncGraph(new_funcgraph); | |||
| do_simplify = true; | |||
| } | |||
| } | |||
| return do_simplify; | |||
| @@ -27,6 +27,7 @@ | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include "backend/optimizer/graph_kernel/model/op_node.h" | |||
| #include "backend/optimizer/graph_kernel/model/op_register.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -107,36 +108,15 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList & | |||
| NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, | |||
| const DAttrs &attrs, std::string node_name) { | |||
| auto op_ptr = Emit(op, inputs, attrs, node_name); | |||
| PrimOpPtr op_ptr = CreateOp(op, node_name); | |||
| op_ptr->SetInputs(inputs); | |||
| op_ptr->SetAttrs(attrs); | |||
| op_ptr->SetBaseInfo(baseinfo); | |||
| return op_ptr; | |||
| return graph_->Add(op_ptr); | |||
| } | |||
| PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) { | |||
| static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators; | |||
| if (creators.empty()) { | |||
| creators = {{"Add", Elemwise}, | |||
| {"Sub", Elemwise}, | |||
| {"RealDiv", Elemwise}, | |||
| {"Mul", Elemwise}, | |||
| {"Log", Elemwise}, | |||
| {"Exp", Elemwise}, | |||
| {"Pow", Elemwise}, | |||
| {"Sqrt", Elemwise}, | |||
| {"Rsqrt", Elemwise}, | |||
| {"Neg", Elemwise}, | |||
| {"Reciprocal", Elemwise}, | |||
| {"Abs", Elemwise}, | |||
| {"BroadcastTo", BroadcastTo}, | |||
| {"Reshape", Reshape}, | |||
| {"ReduceSum", Reduce}, | |||
| {"ReduceMax", Reduce}, | |||
| {"ReduceMin", Reduce}, | |||
| {"Conv2D", Conv2d}}; | |||
| } | |||
| auto iter = creators.find(op); | |||
| auto creator = (iter == creators.end() ? Opaque : iter->second); | |||
| return creator(op, node_name); | |||
| return OpRegistry::Instance().NewOp(op, node_name); | |||
| } | |||
| } // namespace graphkernel | |||
| } // namespace opt | |||
| @@ -81,28 +81,6 @@ class LiteGraph::GraphBuilder { | |||
| LiteGraphPtr Get() { return graph_; } | |||
| private: | |||
| static PrimOpPtr Elemwise(const std::string &op, const std::string &name) { | |||
| return std::make_shared<ElemwiseOp>(op, name); | |||
| } | |||
| static PrimOpPtr BroadcastTo(const std::string &op, const std::string &name) { | |||
| return std::make_shared<BroadcastToOp>(op, name); | |||
| } | |||
| static PrimOpPtr Reshape(const std::string &op, const std::string &name) { | |||
| return std::make_shared<ReshapeOp>(op, name); | |||
| } | |||
| static PrimOpPtr Reduce(const std::string &op, const std::string &name) { | |||
| return std::make_shared<ReduceOp>(op, name); | |||
| } | |||
| static PrimOpPtr Opaque(const std::string &op, const std::string &name) { | |||
| return std::make_shared<OpaqueOp>(op, name); | |||
| } | |||
| static PrimOpPtr Conv2d(const std::string &op, const std::string &name) { | |||
| return std::make_shared<Conv2dOp>(op, name); | |||
| } | |||
| PrimOpPtr CreateOp(const std::string &id, const std::string &name); | |||
| std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); } | |||
| @@ -26,6 +26,7 @@ | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <stdexcept> | |||
| #include "mindspore/core/ir/dtype/type_id.h" | |||
| #include "mindspore/core/ir/value.h" | |||
| @@ -85,6 +86,8 @@ class Node : public NodeBase { | |||
| void SetInput(size_t i, const NodePtr &new_input); | |||
| void SetInputs(const NodePtrList &inputs); | |||
| void ReplaceWith(const NodePtr &other_node); | |||
| void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } | |||
| void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } | |||
| template <typename T> | |||
| T *As() { | |||
| @@ -146,6 +149,15 @@ class OutputNode : public Node { | |||
| void Dump(std::ostringstream &os) const override { ; } | |||
| NType NodeType() override { return NType::Output; } | |||
| }; | |||
| class GKException : public std::exception { | |||
| public: | |||
| explicit GKException(const std::string &message) : msg_(message) {} | |||
| const char *what() const noexcept override { return msg_.c_str(); } | |||
| protected: | |||
| std::string msg_; | |||
| }; | |||
| } // namespace graphkernel | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -49,7 +49,40 @@ std::vector<int64_t> GetListInt(const ValuePtr &attr_value) { | |||
| return list_int; | |||
| } | |||
| void PrimOp::Check(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CheckShape(inputs, attrs); | |||
| CheckType(inputs, attrs); | |||
| CheckFormat(inputs, attrs); | |||
| } | |||
| // check all type to be identical | |||
| void PrimOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| TypeId tid = inputs[0]->type; | |||
| for (size_t i = 1; i < inputs.size(); i++) { | |||
| if (inputs[i]->type != tid) { | |||
| MS_LOG(EXCEPTION) << "Incompatible dtype between input " << 0 << "and" << i; | |||
| } | |||
| } | |||
| } | |||
| // check all formats are compatible, only DefaultForant is compatible with others | |||
| void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| DFormat res = inputs[0]->format; | |||
| size_t i = 0; | |||
| for (size_t j = 1; j < inputs.size(); j++) { | |||
| if (inputs[j]->format != res) { | |||
| if (inputs[j]->format != kOpFormat_DEFAULT && res != kOpFormat_DEFAULT) { | |||
| MS_LOG(EXCEPTION) << "Incompatible format between input " << i << "and" << (j + 1); | |||
| } | |||
| if (res == kOpFormat_DEFAULT) { | |||
| res = inputs[j]->format; | |||
| i = j + 1; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| Check(inputs, attrs); | |||
| this->shape = InferShape(inputs, attrs); | |||
| this->type = InferType(inputs, attrs); | |||
| this->format = InferFormat(inputs, attrs); | |||
| @@ -164,6 +197,88 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const | |||
| return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res); | |||
| } | |||
| // default format shape to fractal_Nz format shape | |||
| DShape ToNz(const DShape &default_shape) { | |||
| if (default_shape.size() != 1 && default_shape.size() != 2) { | |||
| throw GKException("shape is too long"); | |||
| } | |||
| DShape output_shape; | |||
| if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) { | |||
| output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16}; | |||
| if (default_shape[default_shape.size() - 1] % 16 != 0) { | |||
| throw GKException("should be multiplies of 16"); | |||
| } | |||
| } else if (default_shape.size() == 2 || default_shape[1] == 1) { | |||
| output_shape = {1, default_shape[0] / 16, 16, 1}; | |||
| if (default_shape[0] % 16 != 0) { | |||
| throw GKException("should be multiplies of 16"); | |||
| } | |||
| } else { | |||
| output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16}; | |||
| if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) { | |||
| throw GKException("should be multiplies of 16"); | |||
| } | |||
| } | |||
| return output_shape; | |||
| } | |||
| DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) { | |||
| std::vector<std::vector<int64_t>> shapes; | |||
| for (auto &input : inputs) { | |||
| if (to_nz && input->format != kOpFormat_FRAC_NZ) { | |||
| shapes.emplace_back(ToNz(input->shape)); | |||
| } else { | |||
| shapes.emplace_back(input->shape); | |||
| } | |||
| } | |||
| auto max_dim_input = | |||
| std::max_element(shapes.begin(), shapes.end(), | |||
| [](const std::vector<int64_t> &a, const std::vector<int64_t> &b) { return a.size() < b.size(); }); | |||
| auto max_dim = max_dim_input->size(); | |||
| std::vector<std::vector<int64_t>> align_shapes; | |||
| for (auto &s : shapes) { | |||
| std::vector<int64_t> cur(max_dim - s.size(), 1); | |||
| cur.insert(cur.end(), s.begin(), s.end()); | |||
| align_shapes.emplace_back(cur); | |||
| } | |||
| std::vector<int64_t> output_shape(max_dim, 1); | |||
| for (size_t i = 0; i < max_dim; i++) { | |||
| for (auto &align_shape : align_shapes) { | |||
| if (align_shape[i] > 1) { | |||
| if (output_shape[i] == 1) { | |||
| output_shape[i] = align_shape[i]; | |||
| } | |||
| if (output_shape[i] != align_shape[i]) { | |||
| throw GKException("shape broadcast failed"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return output_shape; | |||
| } | |||
| DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) { | |||
| return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC || input->format == kOpFormat_NCHW; | |||
| })) { | |||
| return BroadcastShape(inputs, false); | |||
| } | |||
| if (std::all_of(inputs.begin(), inputs.end(), [](const NodePtr &input) { | |||
| return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC || | |||
| input->format == kOpFormat_NCHW || input->format == kOpFormat_FRAC_NZ; | |||
| })) { | |||
| return BroadcastShape(inputs, true); | |||
| } | |||
| throw GKException("Only support default and fractal_nz"); | |||
| } | |||
| DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| auto it = std::find_if(inputs.begin(), inputs.end(), [](const NodePtr &i) { return i->format != kOpFormat_DEFAULT; }); | |||
| return it == inputs.end() ? kOpFormat_DEFAULT : (*it)->format; | |||
| } | |||
| void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| PrimOp::Infer(inputs, attrs); | |||
| auto IsBroadcast = [this](const NodePtrList &inputs) -> bool { | |||
| @@ -178,25 +293,63 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE; | |||
| } | |||
| DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| return GetListInt(attrs.find("shape")->second); | |||
| TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CHECK_ATTR(attrs, "dst_type"); | |||
| auto dst_type = attrs.find("dst_type")->second; | |||
| if (dst_type->isa<Type>()) { | |||
| return dst_type->cast<TypePtr>()->type_id(); | |||
| } | |||
| return kernel::DtypeToTypeId(GetValue<std::string>(dst_type)); | |||
| } | |||
| void SelectOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (inputs[0]->type != TypeId::kNumberTypeBool) { | |||
| MS_LOG(EXCEPTION) << "Select's input[0] should be bool type"; | |||
| } | |||
| if (inputs[1]->type != inputs[2]->type) { | |||
| MS_LOG(EXCEPTION) << "Select's input[1] and input[2]'s type doesn't match"; | |||
| } | |||
| } | |||
| DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CHECK_ATTR(attrs, "shape"); | |||
| auto new_shape = GetListInt(attrs.find("shape")->second); | |||
| auto origin_shape = inputs[0]->shape; | |||
| auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>()); | |||
| auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>()); | |||
| for (size_t i = 0; i < new_shape.size(); i++) { | |||
| if (new_shape[i] == -1) { | |||
| auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>()); | |||
| auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>()); | |||
| new_shape[i] = origin_product / new_product * (-1); | |||
| break; | |||
| return new_shape; | |||
| } | |||
| } | |||
| if (origin_product != new_product) { | |||
| MS_LOG(EXCEPTION) << "The shape product before and after reshaping should be equal"; | |||
| } | |||
| return new_shape; | |||
| } | |||
| DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CHECK_ATTR(attrs, "shape"); | |||
| return GetListInt(attrs.find("shape")->second); | |||
| } | |||
| // check rudece axis in range [-size,size) | |||
| void ReduceOp::Check(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| PrimOp::Check(inputs, attrs); | |||
| CHECK_ATTR(attrs, "axis"); | |||
| auto axis = GetListInt(attrs.find("axis")->second); | |||
| int64_t size = static_cast<int64_t>(inputs[0]->shape.size()); | |||
| auto it = std::find_if(axis.begin(), axis.end(), [&size](const int64_t &i) { return (i >= size || i < (-size)); }); | |||
| if (it != axis.end()) { | |||
| MS_LOG(EXCEPTION) << "reduce_axis should be in range [" << (-size) << "," << size << ")" | |||
| << ",but got " << (*it); | |||
| } | |||
| } | |||
| DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CHECK_ATTR(attrs, "axis"); | |||
| CHECK_ATTR(attrs, "keep_dims"); | |||
| auto axis = GetListInt(attrs.find("axis")->second); | |||
| auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second); | |||
| if (keepdims) { | |||
| @@ -218,6 +371,171 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| } | |||
| return new_shape; | |||
| } | |||
| void CheckNd(const std::vector<int64_t> &shape, size_t n) { | |||
| if (shape.size() != n) { | |||
| std::ostringstream info; | |||
| info << "input dimension should be " << n << ", but got " << shape.size(); | |||
| throw GKException(info.str()); | |||
| } | |||
| } | |||
| DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| auto shape0 = inputs[0]->shape; | |||
| auto shape1 = inputs[1]->shape; | |||
| CheckNd(shape0, 4); | |||
| CheckNd(shape1, 4); | |||
| if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC && | |||
| GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) { | |||
| throw GKException("check NHWC format failed"); | |||
| } | |||
| auto n = shape0[0]; | |||
| auto h = shape0[1]; | |||
| auto w = shape0[2]; | |||
| auto out_channel = shape1[0]; | |||
| CHECK_ATTR(attrs, "pad_list"); | |||
| CHECK_ATTR(attrs, "pad_mode"); | |||
| CHECK_ATTR(attrs, "kernel_size"); | |||
| CHECK_ATTR(attrs, "stride"); | |||
| CHECK_ATTR(attrs, "dilation"); | |||
| auto pad_list = GetListInt(attrs.find("pad_list")->second); | |||
| auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second); | |||
| auto kernel_size = GetListInt(attrs.find("kernel_size")->second); | |||
| auto stride = GetListInt(attrs.find("stride")->second); | |||
| auto dilation = GetListInt(attrs.find("dilation")->second); | |||
| CheckNd(pad_list, 4); | |||
| CheckNd(kernel_size, 2); | |||
| CheckNd(stride, 4); | |||
| CheckNd(dilation, 4); | |||
| bool has_pad = false; | |||
| if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) { | |||
| has_pad = true; | |||
| } else { | |||
| if (pad_mode == "VALID" || pad_mode == "valid") { | |||
| if (std::any_of(pad_list.begin(), pad_list.end(), [](int i) { return i == 0; })) { | |||
| has_pad = true; | |||
| } | |||
| } | |||
| } | |||
| if (!has_pad) { | |||
| pad_list = {0, 0, 0, 0}; | |||
| } | |||
| auto k_h = (kernel_size[0] - 1) * dilation[2] + 1; | |||
| auto k_w = (kernel_size[1] - 1) * dilation[3] + 1; | |||
| auto out_h = (h + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1; | |||
| auto out_w = (w + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1; | |||
| std::vector<int64_t> output = {n, out_h, out_w, out_channel}; | |||
| return output; | |||
| } | |||
| TypeId Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type; | |||
| auto dst_type = attrs.find("dst_type")->second; | |||
| if (dst_type->isa<Type>()) { | |||
| return dst_type->cast<TypePtr>()->type_id(); | |||
| } | |||
| return kernel::DtypeToTypeId(GetValue<std::string>(dst_type)); | |||
| } | |||
| DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| CHECK_ATTR(attrs, "perm"); | |||
| auto perm = GetListInt(attrs.find("perm")->second); | |||
| auto &old_shape = inputs[0]->shape; | |||
| DShape new_shape; | |||
| if (perm.size() != old_shape.size()) { | |||
| MS_LOG(EXCEPTION) << "perm.size() != old_shape.size(). " << perm.size() << " vs " << old_shape.size(); | |||
| } | |||
| std::transform(perm.begin(), perm.end(), std::back_inserter(new_shape), | |||
| [&old_shape](int64_t p) { return old_shape[p]; }); | |||
| return new_shape; | |||
| } | |||
| DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT; | |||
| CHECK_ATTR(attrs, "perm"); | |||
| auto perm = GetListInt(attrs.find("perm")->second); | |||
| const auto &ori_format = inputs[0]->format; | |||
| if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) { | |||
| std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1}; | |||
| if (perm == nchw2nhwc) return kOpFormat_NHWC; | |||
| } else if (ori_format == kOpFormat_NHWC) { | |||
| std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2}; | |||
| if (perm == nhwc2nchw) return kOpFormat_DEFAULT; | |||
| } | |||
| std::ostringstream info; | |||
| info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString(); | |||
| throw GKException(info.str()); | |||
| } | |||
| DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| std::vector<int64_t> shape0 = inputs[0]->shape; | |||
| std::vector<int64_t> shape1 = inputs[1]->shape; | |||
| if (shape0.size() != 2 || shape1.size() != 2) { | |||
| std::ostringstream info; | |||
| info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size(); | |||
| throw GKException(info.str()); | |||
| } | |||
| auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second); | |||
| auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second); | |||
| int64_t m = transpose_a ? shape0[1] : shape0[0]; | |||
| int64_t k1 = transpose_a ? shape0[0] : shape0[1]; | |||
| int64_t k2 = transpose_b ? shape1[1] : shape1[0]; | |||
| int64_t n = transpose_b ? shape1[0] : shape1[1]; | |||
| if (k1 != k2) { | |||
| MS_LOG(EXCEPTION) << "MatMul's inputs have different k value " << k1 << " vs " << k2; | |||
| } | |||
| std::vector<int64_t> output = {m, n}; | |||
| return output; | |||
| } | |||
| TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type; | |||
| auto dst_type = attrs.find("dst_type")->second; | |||
| if (dst_type->isa<Type>()) { | |||
| return dst_type->cast<TypePtr>()->type_id(); | |||
| } | |||
| return kernel::DtypeToTypeId(GetValue<std::string>(dst_type)); | |||
| } | |||
| DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| std::vector<int64_t> shape0 = inputs[0]->shape; | |||
| size_t n = shape0.size(); | |||
| std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second); | |||
| std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second); | |||
| if (pad_before.size() != n || pad_after.size() != n) { | |||
| MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs " | |||
| << pad_after.size(); | |||
| } | |||
| std::vector<int64_t> output; | |||
| for (size_t i = 0; i < n; i++) { | |||
| output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]); | |||
| } | |||
| return output; | |||
| } | |||
| DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| std::vector<int64_t> shape0 = inputs[0]->shape; | |||
| size_t n = shape0.size(); | |||
| std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second); | |||
| if (unpad_after.size() != n) { | |||
| MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size(); | |||
| } | |||
| std::vector<int64_t> output; | |||
| for (size_t i = 0; i < n; i++) { | |||
| output.emplace_back(shape0[i] - unpad_after[i]); | |||
| } | |||
| return output; | |||
| } | |||
| void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| if (inputs[0]->type != TypeId::kNumberTypeFloat32) { | |||
| throw GKException("Complex's input[0] should be float32"); | |||
| } | |||
| if (inputs[0]->type != inputs[1]->type) { | |||
| MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch"; | |||
| } | |||
| } | |||
| } // namespace graphkernel | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -20,12 +20,23 @@ | |||
| #include <algorithm> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "ir/dtype/type.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace graphkernel { | |||
| #define CHECK_ATTR(attrs, attr_name) \ | |||
| do { \ | |||
| if (attrs.count(attr_name) == 0) { \ | |||
| MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \ | |||
| } \ | |||
| } while (0) | |||
| class PrimOp : public Node { | |||
| public: | |||
| enum ComputeType { | |||
| @@ -39,43 +50,109 @@ class PrimOp : public Node { | |||
| PrimOp(const std::string &op, const std::string &node_name, ComputeType compute) | |||
| : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {} | |||
| virtual void Check(const NodePtrList &inputs, const DAttrs &attrs); | |||
| virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {} | |||
| virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs); | |||
| virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs); | |||
| virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs); | |||
| virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op); | |||
| virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; } | |||
| virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; } | |||
| virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; } | |||
| void Dump(std::ostringstream &os) const override; | |||
| NType NodeType() override { return NType::Primitive; } | |||
| const std::string &op() const { return op_; } | |||
| ComputeType compute_type() const { return compute_type_; } | |||
| virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op); | |||
| protected: | |||
| std::string op_; | |||
| ComputeType compute_type_; | |||
| virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; } | |||
| virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; } | |||
| virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; } | |||
| }; | |||
| using PrimOpPtr = std::shared_ptr<PrimOp>; | |||
| class ElemwiseOp : public PrimOp { | |||
| public: | |||
| ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {} | |||
| void Infer(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| // TODO(dayschan) rewrite InferShape/InferFormat | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class CastOp : public ElemwiseOp { | |||
| public: | |||
| CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {} | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class InplaceAssignOp : public ElemwiseOp { | |||
| public: | |||
| InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; } | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; } | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; } | |||
| }; | |||
| class SelectOp : public ElemwiseOp { | |||
| public: | |||
| SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {} | |||
| void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; } | |||
| }; | |||
| class CompareOp : public ElemwiseOp { | |||
| public: | |||
| CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {} | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; } | |||
| }; | |||
| class LessOp : public CompareOp { | |||
| public: | |||
| LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {} | |||
| }; | |||
| class EqualOp : public CompareOp { | |||
| public: | |||
| EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {} | |||
| }; | |||
| class LessEqualOp : public CompareOp { | |||
| public: | |||
| LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {} | |||
| }; | |||
| class GreaterOp : public CompareOp { | |||
| public: | |||
| GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {} | |||
| }; | |||
| class GreaterEqualOp : public CompareOp { | |||
| public: | |||
| GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {} | |||
| }; | |||
| class ReshapeOp : public PrimOp { | |||
| public: | |||
| ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {} | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { | |||
| return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT | |||
| : GetValue<std::string>(attrs.find("format")->second); | |||
| } | |||
| }; | |||
| class BroadcastToOp : public PrimOp { | |||
| public: | |||
| BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {} | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| @@ -83,8 +160,10 @@ class ReduceOp : public PrimOp { | |||
| public: | |||
| ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {} | |||
| protected: | |||
| void Check(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }; | |||
| }; | |||
| class OpaqueOp : public PrimOp { | |||
| @@ -95,6 +174,74 @@ class OpaqueOp : public PrimOp { | |||
| class Conv2dOp : public OpaqueOp { | |||
| public: | |||
| Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class TransposeOp : public OpaqueOp { | |||
| public: | |||
| TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class MatMulOp : public OpaqueOp { | |||
| public: | |||
| MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class PadAkgOp : public OpaqueOp { | |||
| public: | |||
| PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class UnPadAkgOp : public OpaqueOp { | |||
| public: | |||
| UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {} | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class CImagOp : public ElemwiseOp { | |||
| public: | |||
| CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {} | |||
| void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { | |||
| if (inputs[0]->type != TypeId::kNumberTypeComplex64) { | |||
| throw GKException("CImag's input[0] should be complex64"); | |||
| } | |||
| }; | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; } | |||
| }; | |||
| class CRealOp : public ElemwiseOp { | |||
| public: | |||
| CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {} | |||
| void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override { | |||
| if (inputs[0]->type != TypeId::kNumberTypeComplex64) { | |||
| throw GKException("CReal's input[0] should be complex64"); | |||
| } | |||
| }; | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; } | |||
| }; | |||
// Complex op: elementwise construction of a complex64 tensor; input type
// validation is op-specific (CheckType defined out of line), the output type
// is always complex64.
class ComplexOp : public ElemwiseOp {
 public:
  ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
  void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
};
| } // namespace graphkernel | |||
| } // namespace opt | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_ | |||
| #include <unordered_map> | |||
| #include <functional> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace graphkernel { | |||
// Expands to a factory lambda that creates a shared PrimOp of the given
// concrete class; used as the creator callback in OpRegistry::Register.
#define OP_CREATOR(cls) \
  [](const std::string &op, const std::string &name) -> PrimOpPtr { return std::make_shared<cls>(op, name); }
| class OpRegistry { | |||
| public: | |||
| static OpRegistry &Instance() { | |||
| static OpRegistry instance{}; | |||
| return instance; | |||
| } | |||
| void Register(const std::string &op_name, | |||
| const std::function<PrimOpPtr(const std::string &, const std::string &)> &func) { | |||
| creators.insert({op_name, func}); | |||
| } | |||
| PrimOpPtr NewOp(const std::string &op, const std::string &name) { | |||
| return creators.find(op) == creators.end() ? creators["Opaque"](op, name) : creators[op](op, name); | |||
| } | |||
| private: | |||
| OpRegistry() { | |||
| Register("Add", OP_CREATOR(ElemwiseOp)); | |||
| Register("Sub", OP_CREATOR(ElemwiseOp)); | |||
| Register("RealDiv", OP_CREATOR(ElemwiseOp)); | |||
| Register("Mul", OP_CREATOR(ElemwiseOp)); | |||
| Register("Log", OP_CREATOR(ElemwiseOp)); | |||
| Register("Exp", OP_CREATOR(ElemwiseOp)); | |||
| Register("Pow", OP_CREATOR(ElemwiseOp)); | |||
| Register("Sqrt", OP_CREATOR(ElemwiseOp)); | |||
| Register("Rsqrt", OP_CREATOR(ElemwiseOp)); | |||
| Register("Neg", OP_CREATOR(ElemwiseOp)); | |||
| Register("Reciprocal", OP_CREATOR(ElemwiseOp)); | |||
| Register("Abs", OP_CREATOR(ElemwiseOp)); | |||
| Register("BroadcastTo", OP_CREATOR(BroadcastToOp)); | |||
| Register("Reshape", OP_CREATOR(ReshapeOp)); | |||
| Register("ReduceSum", OP_CREATOR(ReduceOp)); | |||
| Register("ReduceMax", OP_CREATOR(ReduceOp)); | |||
| Register("ReduceMin", OP_CREATOR(ReduceOp)); | |||
| Register("Cast", OP_CREATOR(CastOp)); | |||
| Register("InplaceAssign", OP_CREATOR(InplaceAssignOp)); | |||
| Register("Select", OP_CREATOR(SelectOp)); | |||
| Register("Less", OP_CREATOR(LessOp)); | |||
| Register("Equal", OP_CREATOR(EqualOp)); | |||
| Register("LessEqual", OP_CREATOR(LessEqualOp)); | |||
| Register("GreaterEqual", OP_CREATOR(GreaterEqualOp)); | |||
| Register("Greater", OP_CREATOR(GreaterOp)); | |||
| Register("Transpose", OP_CREATOR(TransposeOp)); | |||
| Register("MatMul", OP_CREATOR(MatMulOp)); | |||
| Register("PadAkg", OP_CREATOR(PadAkgOp)); | |||
| Register("UnPadAkg", OP_CREATOR(UnPadAkgOp)); | |||
| Register("CReal", OP_CREATOR(CRealOp)); | |||
| Register("CImag", OP_CREATOR(CImagOp)); | |||
| Register("Complex", OP_CREATOR(ComplexOp)); | |||
| Register("Opaque", OP_CREATOR(OpaqueOp)); | |||
| } | |||
| ~OpRegistry() = default; | |||
| std::unordered_map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators; | |||
| }; | |||
| } // namespace graphkernel | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif | |||