Browse Source

add more infer

tags/v1.5.0-rc1
Yang Jiao 4 years ago
parent
commit
74c96bf4ee
7 changed files with 612 additions and 83 deletions
  1. +26
    -22
      mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
  2. +6
    -26
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc
  3. +0
    -22
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h
  4. +12
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h
  5. +323
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc
  6. +155
    -8
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h
  7. +90
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.h

+ 26
- 22
mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc View File

@@ -643,29 +643,33 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
expressions_map_ = GetExpressions();
for (auto node : func_graph->GetOrderedCnodes()) {
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
try {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
graphkernel::LiteGraphPtr lg = AnfGraph2LiteGraph(sub_graph);
bool find_pattern = true;
bool change_anf_graph = false;
while (find_pattern) {
find_pattern = false;
find_pattern = DoArithmeticTrans(lg) || find_pattern;
find_pattern = DoConstantFold(lg) || find_pattern;
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
} catch (const graphkernel::GKException &e) {
MS_LOG(WARNING) << e.what() << ", so we undo airthmetic simplify for this graph";
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;
}
}
return do_simplify;


+ 6
- 26
mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc View File

@@ -27,6 +27,7 @@
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/model/op_node.h"
#include "backend/optimizer/graph_kernel/model/op_register.h"
namespace mindspore {
namespace opt {
@@ -107,36 +108,15 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &
// Builds a primitive-op node: creates the concrete PrimOp via CreateOp, wires up
// inputs/attrs/baseinfo, then hands ownership to the graph and returns the node.
// NOTE(review): this hunk interleaved the pre- and post-commit bodies (both the
// removed `Emit(...)` call and the new `CreateOp(...)` path were present); only
// the post-commit version is kept here.
NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
                                    const DAttrs &attrs, std::string node_name) {
  PrimOpPtr op_ptr = CreateOp(op, node_name);
  op_ptr->SetInputs(inputs);
  op_ptr->SetAttrs(attrs);
  op_ptr->SetBaseInfo(baseinfo);
  return graph_->Add(op_ptr);
}
// Creates a PrimOp instance for `op` by delegating to the central OpRegistry;
// unknown op names fall back to the registry's "Opaque" creator.
// NOTE(review): this hunk contained both the removed local `creators` map and
// the new registry delegation; only the post-commit delegation is kept here.
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
  return OpRegistry::Instance().NewOp(op, node_name);
}
} // namespace graphkernel
} // namespace opt


+ 0
- 22
mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h View File

@@ -81,28 +81,6 @@ class LiteGraph::GraphBuilder {
LiteGraphPtr Get() { return graph_; }
private:
static PrimOpPtr Elemwise(const std::string &op, const std::string &name) {
return std::make_shared<ElemwiseOp>(op, name);
}
static PrimOpPtr BroadcastTo(const std::string &op, const std::string &name) {
return std::make_shared<BroadcastToOp>(op, name);
}
static PrimOpPtr Reshape(const std::string &op, const std::string &name) {
return std::make_shared<ReshapeOp>(op, name);
}
static PrimOpPtr Reduce(const std::string &op, const std::string &name) {
return std::make_shared<ReduceOp>(op, name);
}
static PrimOpPtr Opaque(const std::string &op, const std::string &name) {
return std::make_shared<OpaqueOp>(op, name);
}
static PrimOpPtr Conv2d(const std::string &op, const std::string &name) {
return std::make_shared<Conv2dOp>(op, name);
}
PrimOpPtr CreateOp(const std::string &id, const std::string &name);
std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); }


+ 12
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h View File

@@ -26,6 +26,7 @@
#include <iostream>
#include <utility>
#include <string>
#include <stdexcept>
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/core/ir/value.h"
@@ -85,6 +86,8 @@ class Node : public NodeBase {
void SetInput(size_t i, const NodePtr &new_input);
void SetInputs(const NodePtrList &inputs);
void ReplaceWith(const NodePtr &other_node);
void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; }
void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
template <typename T>
T *As() {
@@ -146,6 +149,15 @@ class OutputNode : public Node {
void Dump(std::ostringstream &os) const override { ; }
NType NodeType() override { return NType::Output; }
};
// Exception thrown by graph-kernel model passes (shape/type/format inference)
// when an op cannot be handled; callers (e.g. ArithmeticSimplify::Run) catch it
// to abandon the optimization for that subgraph instead of failing the build.
class GKException : public std::exception {
 public:
  explicit GKException(const std::string &message) : msg_(message) {}
  // Returned pointer is valid only for the lifetime of this exception object.
  const char *what() const noexcept override { return msg_.c_str(); }

 protected:
  std::string msg_;
};
} // namespace graphkernel
} // namespace opt
} // namespace mindspore


+ 323
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc View File

@@ -49,7 +49,40 @@ std::vector<int64_t> GetListInt(const ValuePtr &attr_value) {
return list_int;
}
// Runs all per-op validation passes (shape, type, format) on the inputs before
// inference; each Check* hook may be overridden by concrete op subclasses.
void PrimOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
  CheckShape(inputs, attrs);
  CheckType(inputs, attrs);
  CheckFormat(inputs, attrs);
}
// check all type to be identical
void PrimOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
TypeId tid = inputs[0]->type;
for (size_t i = 1; i < inputs.size(); i++) {
if (inputs[i]->type != tid) {
MS_LOG(EXCEPTION) << "Incompatible dtype between input " << 0 << "and" << i;
}
}
}
// check all formats are compatible; only the default format is compatible with others
void PrimOp::CheckFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  DFormat res = inputs[0]->format;
  size_t i = 0;  // index of the input that currently determines `res`
  for (size_t j = 1; j < inputs.size(); j++) {
    if (inputs[j]->format != res) {
      if (inputs[j]->format != kOpFormat_DEFAULT && res != kOpFormat_DEFAULT) {
        // was: << i << "and" << (j + 1) — missing spaces, and j+1 over-counted
        // the conflicting input's index
        MS_LOG(EXCEPTION) << "Incompatible format between input " << i << " and " << j;
      }
      if (res == kOpFormat_DEFAULT) {
        res = inputs[j]->format;
        i = j;  // was `j + 1`: off-by-one, `res` now comes from input j
      }
    }
  }
}
void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
Check(inputs, attrs);
this->shape = InferShape(inputs, attrs);
this->type = InferType(inputs, attrs);
this->format = InferFormat(inputs, attrs);
@@ -164,6 +197,88 @@ NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const
return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
}
// default format shape to fractal_NZ format shape.
// Supports rank-1 vectors, (1,N) rows, (N,1) columns and (M,N) matrices; the
// mapped dimensions must be multiples of the 16-element fractal block, else
// GKException is thrown.
DShape ToNz(const DShape &default_shape) {
  if (default_shape.size() != 1 && default_shape.size() != 2) {
    throw GKException("shape is too long");
  }
  DShape output_shape;
  if (default_shape.size() == 1 || (default_shape.size() == 2 && default_shape[0] == 1)) {
    // vector (N) or row vector (1, N)
    if (default_shape[default_shape.size() - 1] % 16 != 0) {
      throw GKException("should be multiples of 16");
    }
    output_shape = {default_shape[default_shape.size() - 1] / 16, 1, 1, 16};
  } else if (default_shape.size() == 2 && default_shape[1] == 1) {
    // column vector (N, 1).
    // BUGFIX: was `size() == 2 || shape[1] == 1`, which is always true for the
    // remaining 2-D shapes and made the general matrix branch below unreachable.
    if (default_shape[0] % 16 != 0) {
      throw GKException("should be multiples of 16");
    }
    output_shape = {1, default_shape[0] / 16, 16, 1};
  } else {
    // general matrix (M, N)
    if (default_shape[0] % 16 != 0 || default_shape[1] % 16 != 0) {
      throw GKException("should be multiples of 16");
    }
    output_shape = {default_shape[1] / 16, default_shape[0] / 16, 16, 16};
  }
  return output_shape;
}
// Computes the numpy-style broadcast shape of all inputs: shapes are aligned to
// a common rank with implicit leading 1s, and each dimension must either be 1 or
// agree with every other non-1 dimension. When `to_nz` is set, inputs that are
// not already in fractal_NZ format are first converted with ToNz. Throws
// GKException when two inputs disagree on a non-1 dimension.
DShape BroadcastShape(const NodePtrList &inputs, bool to_nz = false) {
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(inputs.size());
  for (const auto &inp : inputs) {
    shapes.emplace_back((to_nz && inp->format != kOpFormat_FRAC_NZ) ? ToNz(inp->shape) : inp->shape);
  }
  size_t rank = 0;
  for (const auto &s : shapes) {
    rank = std::max(rank, s.size());
  }
  std::vector<int64_t> result(rank, 1);
  for (const auto &s : shapes) {
    size_t offset = rank - s.size();  // implicit leading 1s for lower-rank inputs
    for (size_t d = 0; d < s.size(); d++) {
      auto dim = s[d];
      if (dim <= 1) continue;  // 1 (or 0) broadcasts to anything already recorded
      if (result[offset + d] == 1) {
        result[offset + d] = dim;
      } else if (result[offset + d] != dim) {
        throw GKException("shape broadcast failed");
      }
    }
  }
  return result;
}
// Output shape of an element-wise op is the broadcast of its input shapes.
// Two layout regimes are accepted: all "flat" layouts (default/NHWC/NCHW)
// broadcast directly; if any input is fractal_NZ, all shapes are broadcast in
// the NZ layout. Any other layout mix is rejected with GKException.
DShape ElemwiseOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  auto is_flat = [](const NodePtr &input) {
    return input->format == kOpFormat_DEFAULT || input->format == kOpFormat_NHWC || input->format == kOpFormat_NCHW;
  };
  if (std::all_of(inputs.begin(), inputs.end(), is_flat)) {
    return BroadcastShape(inputs, false);
  }
  auto is_flat_or_nz = [&is_flat](const NodePtr &input) {
    return is_flat(input) || input->format == kOpFormat_FRAC_NZ;
  };
  if (std::all_of(inputs.begin(), inputs.end(), is_flat_or_nz)) {
    return BroadcastShape(inputs, true);
  }
  throw GKException("Only support default and fractal_nz");
}
// Output format of an element-wise op: the first non-default input format wins;
// if every input uses the default format, the result is the default format.
DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  for (const auto &input : inputs) {
    if (input->format != kOpFormat_DEFAULT) {
      return input->format;
    }
  }
  return kOpFormat_DEFAULT;
}
void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
PrimOp::Infer(inputs, attrs);
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
@@ -178,25 +293,63 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
}
DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
return GetListInt(attrs.find("shape")->second);
// Cast's output dtype comes from the "dst_type" attr, which may be stored
// either as a Type object or as a dtype-name string (e.g. "float32").
TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "dst_type");
  auto dst_type = attrs.find("dst_type")->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  // string form: translate the dtype name to its TypeId
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}
// Select(cond, x, y): the condition must be bool and the two value branches
// must share a dtype. Overrides the identical-dtypes check from PrimOp.
void SelectOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
  if (inputs[0]->type != TypeId::kNumberTypeBool) {
    MS_LOG(EXCEPTION) << "Select's input[0] should be bool type";
  }
  if (inputs[1]->type != inputs[2]->type) {
    MS_LOG(EXCEPTION) << "Select's input[1] and input[2]'s type doesn't match";
  }
}
// Computes Reshape's output shape from the "shape" attr. A single -1 entry is a
// placeholder inferred so the element counts match; otherwise the products of
// the old and new shapes must be equal.
// NOTE(review): this hunk interleaved the pre- and post-commit bodies (duplicate
// accumulate statements and an orphan `break;`); only the post-commit version is
// kept. Also: std::accumulate's init was the int literal 1, so the products were
// accumulated in `int` and could overflow for large tensors — now int64_t.
DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "shape");
  auto new_shape = GetListInt(attrs.find("shape")->second);
  auto origin_shape = inputs[0]->shape;
  auto origin_product =
    std::accumulate(origin_shape.begin(), origin_shape.end(), int64_t(1), std::multiplies<int64_t>());
  auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), int64_t(1), std::multiplies<int64_t>());
  for (size_t i = 0; i < new_shape.size(); i++) {
    if (new_shape[i] == -1) {
      // new_product includes the -1 factor, so dividing by it flips the sign back
      new_shape[i] = origin_product / new_product * (-1);
      return new_shape;
    }
  }
  if (origin_product != new_product) {
    MS_LOG(EXCEPTION) << "The shape product before and after reshaping should be equal";
  }
  return new_shape;
}
// BroadcastTo's output shape is given directly by the mandatory "shape" attr;
// the input shape is not consulted here.
DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "shape");
  return GetListInt(attrs.find("shape")->second);
}
// check reduce axis in range [-size, size)
void ReduceOp::Check(const NodePtrList &inputs, const DAttrs &attrs) {
  // run the generic shape/type/format checks first
  PrimOp::Check(inputs, attrs);
  CHECK_ATTR(attrs, "axis");
  auto axis = GetListInt(attrs.find("axis")->second);
  int64_t size = static_cast<int64_t>(inputs[0]->shape.size());
  // reject any axis outside [-rank, rank)
  auto it = std::find_if(axis.begin(), axis.end(), [&size](const int64_t &i) { return (i >= size || i < (-size)); });
  if (it != axis.end()) {
    MS_LOG(EXCEPTION) << "reduce_axis should be in range [" << (-size) << "," << size << ")"
                      << ",but got " << (*it);
  }
}
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
CHECK_ATTR(attrs, "axis");
CHECK_ATTR(attrs, "keep_dims");
auto axis = GetListInt(attrs.find("axis")->second);
auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
if (keepdims) {
@@ -218,6 +371,171 @@ DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
}
return new_shape;
}
// Validates that `shape` has exactly `n` dimensions; throws GKException with a
// descriptive message otherwise.
void CheckNd(const std::vector<int64_t> &shape, size_t n) {
  if (shape.size() == n) {
    return;
  }
  std::ostringstream info;
  info << "input dimension should be " << n << ", but got " << shape.size();
  throw GKException(info.str());
}
// Computes the NHWC output shape of Conv2D from the data shape (input 0), the
// weight shape (input 1) and the pad/stride/dilation attrs.
DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  auto shape0 = inputs[0]->shape;
  auto shape1 = inputs[1]->shape;
  CheckNd(shape0, 4);
  CheckNd(shape1, 4);
  // Only NHWC is supported; the layout may come from either input's format or
  // from the "format" attr. BUGFIX: the attr was dereferenced without checking
  // it exists, which dereferenced end() when "format" was absent.
  CHECK_ATTR(attrs, "format");
  if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
      GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
    throw GKException("check NHWC format failed");
  }
  auto n = shape0[0];
  auto h = shape0[1];
  auto w = shape0[2];
  auto out_channel = shape1[0];
  CHECK_ATTR(attrs, "pad_list");
  CHECK_ATTR(attrs, "pad_mode");
  CHECK_ATTR(attrs, "kernel_size");
  CHECK_ATTR(attrs, "stride");
  CHECK_ATTR(attrs, "dilation");
  auto pad_list = GetListInt(attrs.find("pad_list")->second);
  auto pad_mode = GetValue<std::string>(attrs.find("pad_mode")->second);
  auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
  auto stride = GetListInt(attrs.find("stride")->second);
  auto dilation = GetListInt(attrs.find("dilation")->second);
  CheckNd(pad_list, 4);
  CheckNd(kernel_size, 2);
  CheckNd(stride, 4);
  CheckNd(dilation, 4);
  bool has_pad = false;
  if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
    has_pad = true;
  } else {
    if (pad_mode == "VALID" || pad_mode == "valid") {
      // NOTE(review): this flags padding when ANY pad entry is zero in VALID
      // mode, which looks inverted (VALID normally implies no padding). The
      // behavior is preserved as-is — confirm the intended condition with the
      // author. (Lambda parameter widened from int to int64_t to match the
      // element type.)
      if (std::any_of(pad_list.begin(), pad_list.end(), [](int64_t i) { return i == 0; })) {
        has_pad = true;
      }
    }
  }
  if (!has_pad) {
    pad_list = {0, 0, 0, 0};
  }
  // effective (dilated) kernel extent, then the standard conv output-size formula
  auto k_h = (kernel_size[0] - 1) * dilation[2] + 1;
  auto k_w = (kernel_size[1] - 1) * dilation[3] + 1;
  auto out_h = (h + pad_list[0] + pad_list[1] - k_h) / stride[2] + 1;
  auto out_w = (w + pad_list[2] + pad_list[3] - k_w) / stride[3] + 1;
  std::vector<int64_t> output = {n, out_h, out_w, out_channel};
  return output;
}
// Output dtype: defaults to input 0's dtype; an optional "dst_type" attr
// (Type object or dtype-name string) overrides it.
TypeId Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  auto attr_it = attrs.find("dst_type");
  if (attr_it == attrs.end()) {
    return inputs[0]->type;
  }
  const auto &dst_type = attr_it->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}
// Output shape of Transpose: the input shape permuted by the "perm" attr, whose
// length must equal the input rank.
DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  CHECK_ATTR(attrs, "perm");
  auto perm = GetListInt(attrs.find("perm")->second);
  const auto &input_shape = inputs[0]->shape;
  if (perm.size() != input_shape.size()) {
    MS_LOG(EXCEPTION) << "perm.size() != old_shape.size(). " << perm.size() << " vs " << input_shape.size();
  }
  DShape result;
  result.reserve(perm.size());
  for (auto axis : perm) {
    result.push_back(input_shape[axis]);
  }
  return result;
}
// Infers Transpose's output data format from its permutation. Only two 4-D
// cases are recognized: NCHW/default -> NHWC (perm {0,2,3,1}) and
// NHWC -> default (perm {0,3,1,2}); any other 4-D permutation raises
// GKException so the caller can abandon the rewrite.
DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
  // non-4D tensors keep the default format
  if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT;
  CHECK_ATTR(attrs, "perm");
  auto perm = GetListInt(attrs.find("perm")->second);
  const auto &ori_format = inputs[0]->format;
  if (ori_format == kOpFormat_DEFAULT || ori_format == kOpFormat_NCHW) {
    std::vector<int64_t> nchw2nhwc = {0, 2, 3, 1};
    if (perm == nchw2nhwc) return kOpFormat_NHWC;
  } else if (ori_format == kOpFormat_NHWC) {
    std::vector<int64_t> nhwc2nchw = {0, 3, 1, 2};
    if (perm == nhwc2nchw) return kOpFormat_DEFAULT;
  }
  std::ostringstream info;
  info << "Unsupported Transpose. ori_format = " << ori_format << ", perm = " << attrs.find("perm")->second->ToString();
  throw GKException(info.str());
}
// Output shape of a 2-D MatMul: (m, n) where the contracted dimensions k of
// both operands (after the optional transposes) must agree.
DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  // const refs: the original copied both shape vectors needlessly
  const std::vector<int64_t> &shape0 = inputs[0]->shape;
  const std::vector<int64_t> &shape1 = inputs[1]->shape;
  if (shape0.size() != 2 || shape1.size() != 2) {
    std::ostringstream info;
    info << "MatMul's input's dimension must be 2, but got " << shape0.size() << " and " << shape1.size();
    throw GKException(info.str());
  }
  // BUGFIX: the attrs were dereferenced without existence checks, which
  // dereferenced end() when either transpose attr was missing.
  CHECK_ATTR(attrs, "transpose_a");
  CHECK_ATTR(attrs, "transpose_b");
  auto transpose_a = GetValue<bool>(attrs.find("transpose_a")->second);
  auto transpose_b = GetValue<bool>(attrs.find("transpose_b")->second);
  int64_t m = transpose_a ? shape0[1] : shape0[0];
  int64_t k1 = transpose_a ? shape0[0] : shape0[1];
  int64_t k2 = transpose_b ? shape1[1] : shape1[0];
  int64_t n = transpose_b ? shape1[0] : shape1[1];
  if (k1 != k2) {
    MS_LOG(EXCEPTION) << "MatMul's inputs have different k value " << k1 << " vs " << k2;
  }
  std::vector<int64_t> output = {m, n};
  return output;
}
// Output dtype: defaults to input 0's dtype; an optional "dst_type" attr (Type
// object or dtype-name string) overrides it. Same convention as CastOp/Conv2dOp.
TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
  if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
  auto dst_type = attrs.find("dst_type")->second;
  if (dst_type->isa<Type>()) {
    return dst_type->cast<TypePtr>()->type_id();
  }
  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
}
// Output shape of PadAkg: each dimension grows by the per-dimension "head"
// (leading) and "tail" (trailing) pad sizes.
DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  const std::vector<int64_t> &shape0 = inputs[0]->shape;
  size_t n = shape0.size();
  // BUGFIX: the attrs were dereferenced without existence checks, which
  // dereferenced end() when "head" or "tail" was missing.
  CHECK_ATTR(attrs, "head");
  CHECK_ATTR(attrs, "tail");
  std::vector<int64_t> pad_before = GetListInt(attrs.find("head")->second);
  std::vector<int64_t> pad_after = GetListInt(attrs.find("tail")->second);
  if (pad_before.size() != n || pad_after.size() != n) {
    MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << pad_before.size() << " vs "
                      << pad_after.size();
  }
  std::vector<int64_t> output;
  output.reserve(n);
  for (size_t i = 0; i < n; i++) {
    output.emplace_back(shape0[i] + pad_before[i] + pad_after[i]);
  }
  return output;
}
// Output shape of UnPadAkg: each dimension shrinks by the per-dimension "tail"
// (trailing) pad size.
DShape UnPadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
  const std::vector<int64_t> &shape0 = inputs[0]->shape;
  size_t n = shape0.size();
  // BUGFIX: the attr was dereferenced without an existence check, which
  // dereferenced end() when "tail" was missing.
  CHECK_ATTR(attrs, "tail");
  std::vector<int64_t> unpad_after = GetListInt(attrs.find("tail")->second);
  if (unpad_after.size() != n) {
    MS_LOG(EXCEPTION) << "Input dimension and pad mismatch: " << n << " vs " << unpad_after.size();
  }
  std::vector<int64_t> output;
  output.reserve(n);
  for (size_t i = 0; i < n; i++) {
    output.emplace_back(shape0[i] - unpad_after[i]);
  }
  return output;
}
// Complex(real, imag): both inputs must be float32.
void ComplexOp::CheckType(const NodePtrList &inputs, const DAttrs &attrs) {
  if (inputs[0]->type != TypeId::kNumberTypeFloat32) {
    // NOTE(review): this path throws GKException while the path below uses
    // MS_LOG(EXCEPTION) — callers catch only GKException; confirm whether the
    // mixed exception styles are intentional.
    throw GKException("Complex's input[0] should be float32");
  }
  if (inputs[0]->type != inputs[1]->type) {
    MS_LOG(EXCEPTION) << "Complex's input[0] and inputs[1]'s type mismatch";
  }
}
} // namespace graphkernel
} // namespace opt
} // namespace mindspore

+ 155
- 8
mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h View File

@@ -20,12 +20,23 @@
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <functional>
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/kernel_compiler/common_utils.h"
#include "ir/dtype/type.h"
namespace mindspore {
namespace opt {
namespace graphkernel {
#define CHECK_ATTR(attrs, attr_name) \
do { \
if (attrs.count(attr_name) == 0) { \
MS_LOG(EXCEPTION) << "The attr [" << attr_name << "] does not exist in [" << #attrs << "]"; \
} \
} while (0)
class PrimOp : public Node {
public:
enum ComputeType {
@@ -39,43 +50,109 @@ class PrimOp : public Node {
PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
: Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
virtual void Check(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckShape(const NodePtrList &inputs, const DAttrs &attrs) {}
virtual void CheckType(const NodePtrList &inputs, const DAttrs &attrs);
virtual void CheckFormat(const NodePtrList &inputs, const DAttrs &attrs);
virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs);
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
void Dump(std::ostringstream &os) const override;
NType NodeType() override { return NType::Primitive; }
const std::string &op() const { return op_; }
ComputeType compute_type() const { return compute_type_; }
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
protected:
std::string op_;
ComputeType compute_type_;
virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
};
using PrimOpPtr = std::shared_ptr<PrimOp>;
// Element-wise op: output shape is the numpy-style broadcast of the input
// shapes; output format follows the first non-default input format.
class ElemwiseOp : public PrimOp {
 public:
  ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
  void Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
  // TODO(dayschan) rewrite InferShape/InferFormat
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
  DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};
// Cast: element-wise dtype conversion. The incoming `op` argument is ignored;
// the op name is fixed to "Cast".
class CastOp : public ElemwiseOp {
 public:
  CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {}
  // output dtype comes from the "dst_type" attr, not from the inputs
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};
// InplaceAssign: the node's shape/type/format all mirror input[2].
class InplaceAssignOp : public ElemwiseOp {
 public:
  InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->shape; }
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->type; }
  DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[2]->format; }
};
// Select(cond, x, y): cond must be bool (see CheckType); output dtype follows
// the value branches, i.e. input[1].
class SelectOp : public ElemwiseOp {
 public:
  SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {}
  void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return inputs[1]->type; }
};
// Common base for comparison ops: element-wise with a fixed bool output dtype.
class CompareOp : public ElemwiseOp {
 public:
  CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {}
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; }
};
// Concrete comparison ops. Each ignores the incoming `op` argument and pins its
// own op name; all inherit the bool output dtype from CompareOp.
class LessOp : public CompareOp {
 public:
  LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {}
};

class EqualOp : public CompareOp {
 public:
  EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {}
};

class LessEqualOp : public CompareOp {
 public:
  LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {}
};

class GreaterOp : public CompareOp {
 public:
  GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {}
};

class GreaterEqualOp : public CompareOp {
 public:
  GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {}
};
// Reshape: output shape comes from the "shape" attr (with -1 inference); output
// format comes from the optional "format" attr, defaulting to the default format.
class ReshapeOp : public PrimOp {
 public:
  ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}

 protected:
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
  DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
    return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
                                               : GetValue<std::string>(attrs.find("format")->second);
  }
};
// BroadcastTo: output shape is taken directly from the "shape" attr.
class BroadcastToOp : public PrimOp {
 public:
  BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}

 protected:
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
@@ -83,8 +160,10 @@ class ReduceOp : public PrimOp {
public:
ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
protected:
void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
};
class OpaqueOp : public PrimOp {
@@ -95,6 +174,74 @@ class OpaqueOp : public PrimOp {
// Conv2D (NHWC only): output shape derived from pad/stride/dilation attrs;
// output dtype optionally overridden by the "dst_type" attr. The incoming `op`
// argument is ignored; the op name is fixed to "Conv2D".
class Conv2dOp : public OpaqueOp {
 public:
  Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};
// Transpose: output shape is the input shape permuted by the "perm" attr;
// output format is inferred for the recognized NCHW<->NHWC permutations.
class TransposeOp : public OpaqueOp {
 public:
  TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
  DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};

// MatMul: 2-D matrix multiply honoring "transpose_a"/"transpose_b" attrs;
// output dtype optionally overridden by "dst_type".
class MatMulOp : public OpaqueOp {
 public:
  MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override;
};

// PadAkg: grows each dimension by the "head"/"tail" attr pad sizes.
class PadAkgOp : public OpaqueOp {
 public:
  PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};

// UnPadAkg: shrinks each dimension by the "tail" attr pad size.
class UnPadAkgOp : public OpaqueOp {
 public:
  UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {}
  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
// CImag: takes a complex64 tensor and yields its imaginary part as float32.
class CImagOp : public ElemwiseOp {
 public:
  CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {}
  void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
    if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
      throw GKException("CImag's input[0] should be complex64");
    }
  };
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
};
// CReal: takes a complex64 tensor and yields its real part as float32.
class CRealOp : public ElemwiseOp {
 public:
  CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {}
  void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override {
    if (inputs[0]->type != TypeId::kNumberTypeComplex64) {
      throw GKException("CReal's input[0] should be complex64");
    }
  };
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeFloat32; }
};
// Complex: combines two float32 tensors (real, imag) into a complex64 tensor.
class ComplexOp : public ElemwiseOp {
 public:
  ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
  void CheckType(const NodePtrList &inputs, const DAttrs &attrs) override;
  TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeComplex64; }
};
} // namespace graphkernel
} // namespace opt


+ 90
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.h View File

@@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_REGISTER_H_
#include <unordered_map>
#include <functional>
#include <string>
#include <memory>
#include "backend/optimizer/graph_kernel/model/node.h"
namespace mindspore {
namespace opt {
namespace graphkernel {
#define OP_CREATOR(cls) \
[](const std::string &op, const std::string &name) -> PrimOpPtr { return std::make_shared<cls>(op, name); }
// Process-wide registry mapping op names to factories that build the matching
// PrimOp subclass. Unknown ops fall back to the "Opaque" creator.
class OpRegistry {
 public:
  /// Returns the singleton registry instance.
  static OpRegistry &Instance() {
    static OpRegistry instance{};
    return instance;
  }
  // Singleton: forbid copying.
  OpRegistry(const OpRegistry &) = delete;
  OpRegistry &operator=(const OpRegistry &) = delete;

  /// Binds `op_name` to a factory producing its PrimOp subclass.
  void Register(const std::string &op_name,
                const std::function<PrimOpPtr(const std::string &, const std::string &)> &func) {
    creators.insert({op_name, func});
  }

  /// Creates a PrimOp for `op`; unregistered ops use the "Opaque" creator.
  /// BUGFIX: previously did two lookups and used non-const operator[], which
  /// would default-insert an empty std::function (throwing bad_function_call on
  /// call) if a creator were ever missing; now a single find with fallback.
  PrimOpPtr NewOp(const std::string &op, const std::string &name) {
    auto iter = creators.find(op);
    if (iter == creators.end()) {
      iter = creators.find("Opaque");  // always registered in the constructor
    }
    return iter->second(op, name);
  }

 private:
  OpRegistry() {
    Register("Add", OP_CREATOR(ElemwiseOp));
    Register("Sub", OP_CREATOR(ElemwiseOp));
    Register("RealDiv", OP_CREATOR(ElemwiseOp));
    Register("Mul", OP_CREATOR(ElemwiseOp));
    Register("Log", OP_CREATOR(ElemwiseOp));
    Register("Exp", OP_CREATOR(ElemwiseOp));
    Register("Pow", OP_CREATOR(ElemwiseOp));
    Register("Sqrt", OP_CREATOR(ElemwiseOp));
    Register("Rsqrt", OP_CREATOR(ElemwiseOp));
    Register("Neg", OP_CREATOR(ElemwiseOp));
    Register("Reciprocal", OP_CREATOR(ElemwiseOp));
    Register("Abs", OP_CREATOR(ElemwiseOp));
    Register("BroadcastTo", OP_CREATOR(BroadcastToOp));
    Register("Reshape", OP_CREATOR(ReshapeOp));
    Register("ReduceSum", OP_CREATOR(ReduceOp));
    Register("ReduceMax", OP_CREATOR(ReduceOp));
    Register("ReduceMin", OP_CREATOR(ReduceOp));
    Register("Cast", OP_CREATOR(CastOp));
    Register("InplaceAssign", OP_CREATOR(InplaceAssignOp));
    Register("Select", OP_CREATOR(SelectOp));
    Register("Less", OP_CREATOR(LessOp));
    Register("Equal", OP_CREATOR(EqualOp));
    Register("LessEqual", OP_CREATOR(LessEqualOp));
    Register("GreaterEqual", OP_CREATOR(GreaterEqualOp));
    Register("Greater", OP_CREATOR(GreaterOp));
    Register("Transpose", OP_CREATOR(TransposeOp));
    Register("MatMul", OP_CREATOR(MatMulOp));
    Register("PadAkg", OP_CREATOR(PadAkgOp));
    Register("UnPadAkg", OP_CREATOR(UnPadAkgOp));
    Register("CReal", OP_CREATOR(CRealOp));
    Register("CImag", OP_CREATOR(CImagOp));
    Register("Complex", OP_CREATOR(ComplexOp));
    Register("Opaque", OP_CREATOR(OpaqueOp));
  }
  ~OpRegistry() = default;
  std::unordered_map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
};
} // namespace graphkernel
} // namespace opt
} // namespace mindspore
#endif

Loading…
Cancel
Save