
clean code for litegraph

remove the "node_name" argument from class Node's constructor,
rename the node's "name_" to "debug_name_", and add the SetDebugName interface.

refactor the `OpRegistry`:
register the nodes via the `OP_REGISTER` macro instead of writing them
in OpRegistry's constructor (a simplified sketch of the idiom follows below).

remove the five CompareOp subclasses, which did not override any virtual functions;
their op names can all be bound directly to the `CompareOp` class.
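
(For context, a minimal, self-contained sketch of the static-registrar idiom this refactor adopts. The Registry/Registrar/SKETCH_* names are illustrative stand-ins, not MindSpore code; the real OP_REGISTER macro and OpRegistry live in model/op_register.cc and model/op_register.h, shown later in this diff.)

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Simplified stand-ins for PrimOp/PrimOpPtr.
struct Op {
  explicit Op(std::string n) : name(std::move(n)) {}
  virtual ~Op() = default;
  std::string name;
};
using OpPtr = std::shared_ptr<Op>;
using Creator = std::function<OpPtr(const std::string &)>;

class Registry {
 public:
  static Registry &Instance() {
    static Registry inst;
    return inst;
  }
  void Register(const std::string &op, Creator c) { creators_[op] = std::move(c); }
  OpPtr NewOp(const std::string &op) {
    auto it = creators_.find(op);
    // unknown ops fall back to the default creator, like "_opaque" in the real code
    return (it == creators_.end() ? creators_.at("_opaque") : it->second)(op);
  }

 private:
  Registry() = default;
  std::map<std::string, Creator> creators_;
};

// A file-scope Registrar object registers its creator before main() runs,
// so adding an op never requires editing the registry itself.
struct Registrar {
  Registrar(const std::string &op, Creator c) {
    Registry::Instance().Register(op, std::move(c));
  }
};

#define SKETCH_JOIN(x, y) x##y
#define SKETCH_UNIQUE(prefix, cnt) SKETCH_JOIN(prefix, cnt)
#define SKETCH_REGISTER(name, cls)                           \
  static const Registrar SKETCH_UNIQUE(g_reg_, __COUNTER__)( \
      name, [](const std::string &op) -> OpPtr { return std::make_shared<cls>(op); })

struct OpaqueOp : Op { using Op::Op; };
struct CompareOp : Op { using Op::Op; };

SKETCH_REGISTER("_opaque", OpaqueOp);
SKETCH_REGISTER("Less", CompareOp);     // one class bound to several op names,
SKETCH_REGISTER("Greater", CompareOp);  // as this commit does for CompareOp

int main() {
  std::cout << Registry::Instance().NewOp("Less")->name << "\n";     // Less
  std::cout << Registry::Instance().NewOp("Unknown")->name << "\n";  // Unknown (opaque fallback)
  return 0;
}

The __COUNTER__-based unique naming is what lets the five comparison op names share the single CompareOp class without duplicate registrar symbols.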
dayschan, 4 years ago · commit 9b2c402344 · tag v1.6.0
12 changed files with 270 additions and 272 deletions
  1. mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc (+2 -2)
  2. mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.cc (+2 -2)
  3. mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc (+2 -2)
  4. mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc (+29 -18)
  5. mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h (+22 -20)
  6. mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc (+24 -20)
  7. mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h (+19 -48)
  8. mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc (+30 -21)
  9. mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h (+56 -90)
  10. mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.cc (+70 -0)
  11. mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.h (+13 -48)
  12. mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc (+1 -1)

mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc (+2 -2)

@@ -619,11 +619,11 @@ void ReorganizeEmptyGraph(const inner::LiteGraphPtr &litegraph) {
inner::LiteGraph::GraphBuilder gb;
std::vector<int64_t> new_shape = {1};
auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}});
litegraph->output()->SetInput(i, op_ptr);
litegraph->SetOutput(i, op_ptr);
} else if (outputs[i]->NodeType() == inner::NType::Parameter) {
inner::LiteGraph::GraphBuilder gb;
auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
litegraph->output()->SetInput(i, op_ptr);
litegraph->SetOutput(i, op_ptr);
}
}
return;


mindspore/ccsrc/backend/optimizer/graph_kernel/core/graph_kernel_utils.cc (+2 -2)

@@ -156,7 +156,7 @@ FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph)
// Create CNodes.
for (const auto &op_node : lite_graph->GetOrderedNodes()) {
if (op_node->NodeType() != inner::NType::Primitive) {
MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node";
MS_LOG(EXCEPTION) << "Node " << op_node->debug_name() << "should be a Primitive node";
}
auto op = std::static_pointer_cast<inner::PrimOp>(op_node);
AnfNodePtrList inputs = {NewValueNode(std::make_shared<Primitive>(op->op(), op->attrs()))};
@@ -168,7 +168,7 @@ FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph)
return iter->second;
} else {
if (inp->NodeType() != inner::NType::Value) {
MS_LOG(EXCEPTION) << "Node " << inp->name() << "should be a Value node";
MS_LOG(EXCEPTION) << "Node " << inp->debug_name() << "should be a Value node";
}
auto inp_value = inp->As<inner::ConstTensorNode>()->data();
auto value_node = NewValueNode(inp_value);


mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc (+2 -2)

@@ -458,8 +458,8 @@ inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
return inner::NodeBase({shape, type, format});
};
// set inputs
for (size_t i = 0; i < params.size(); i++) {
node_map[params[i]] = gb.Parameter(ExtractBuildInfo(params[i]), std::string("input_") + std::to_string(i));
for (auto &p : params) {
node_map[p] = gb.Parameter(ExtractBuildInfo(p));
}
// set ops
for (auto node : todos) {


mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc (+29 -18)

@@ -18,11 +18,10 @@
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <utility>
#include <string>
#include <iostream>
#include <sstream>
#include "utils/hash_map.h"
#include "backend/optimizer/graph_kernel/model/node.h"
@@ -30,22 +29,31 @@
#include "backend/optimizer/graph_kernel/model/op_register.h"
namespace mindspore::graphkernel::inner {
std::string LiteGraph::Dump() const {
std::string LiteGraph::ToString(bool reset_node_name) const {
if (reset_node_name) {
param_id_ = node_id_ = 0;
for (auto &inp : inputs_) {
inp->SetDebugName(ParamName());
}
for (auto &node : ops_) {
node->SetDebugName(NodeName());
}
}
std::ostringstream os;
os << name_ << "(";
for (size_t i = 0; i < inputs_.size(); i++) {
os << inputs_[i]->name();
os << inputs_[i]->debug_name();
if (i != inputs_.size() - 1) os << ", ";
}
os << ") -> ";
auto &outputs = GetOutputs();
for (size_t i = 0; i < outputs.size(); i++) {
os << outputs[i]->name();
os << outputs[i]->debug_name();
if (i != outputs.size() - 1) os << ", ";
}
os << " {\n";
for (NodePtr op : ops_) {
os << " " << *op << "\n";
for (const NodePtr &op : ops_) {
os << " " << op->ToString() << "\n";
}
os << "}";
return os.str();
@@ -55,6 +63,7 @@ const NodePtrList &LiteGraph::GetOrderedNodes() {
mindspore::HashMap<NodePtr, size_t> outdegrees;
std::function<void(NodePtr)> dfs;
std::set<NodePtr> visited;
// record the out degree of each nodes by Dfs.
dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) {
(void)visited.insert(node);
for (auto &input : node->inputs()) {
@@ -69,6 +78,8 @@ const NodePtrList &LiteGraph::GetOrderedNodes() {
dfs(output_);
NodePtrList res;
NodePtrList stack;
// toposort algorithm with out degree
stack.push_back(output_);
while (!stack.empty()) {
auto cur = stack.back();
@@ -86,20 +97,19 @@ const NodePtrList &LiteGraph::GetOrderedNodes() {
if (!outdegrees.empty()) {
MS_LOG(ERROR) << "Circle was found:";
for (auto &node : outdegrees) {
MS_LOG(ERROR) << " " << *(node.first);
MS_LOG(ERROR) << " " << node.first->debug_name();
}
MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size();
}
std::reverse(res.begin(), res.end());
res.pop_back(); // erase the output node
// remove the "OutputNode"
res.pop_back();
ops_ = std::move(res);
return ops_;
}
NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs,
std::string node_name) {
if (node_name.empty()) node_name = NewName();
PrimOpPtr op_ptr = CreateOp(op, node_name);
NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs) {
PrimOpPtr op_ptr = CreateOp(op);
auto baseinfo = op_ptr->Infer(inputs, attrs);
op_ptr->SetInputs(inputs);
op_ptr->SetAttrs(attrs);
@@ -108,16 +118,17 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &
}
NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
const DAttrs &attrs, std::string node_name) {
if (node_name.empty()) node_name = NewName();
PrimOpPtr op_ptr = CreateOp(op, node_name);
const DAttrs &attrs) {
PrimOpPtr op_ptr = CreateOp(op);
op_ptr->SetInputs(inputs);
op_ptr->SetAttrs(attrs);
op_ptr->SetBaseInfo(baseinfo);
return graph_->Add(op_ptr);
}
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
return OpRegistry::Instance().NewOp(op, node_name);
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op) {
auto node = OpRegistry::Instance().NewOp(op);
node->SetDebugName(graph_->NodeName());
return node;
}
} // namespace mindspore::graphkernel::inner
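
(Side note on the GetOrderedNodes hunks above: the new comments describe an out-degree-based topological sort. Below is an illustrative, self-contained sketch of that scheme with simplified types; it is not the MindSpore code, which starts from the graph's single OutputNode and logs an error when a cycle remains.)

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>

struct N {
  const char *name;
  std::vector<N *> inputs;
};

// Emit nodes so that every node appears after all of its inputs.
std::vector<N *> Toposort(N *output) {
  // 1) DFS from the output, counting each node's out degree
  //    (how many times it is used as an input).
  std::map<N *, int> outdeg;
  std::map<N *, bool> seen;
  std::vector<N *> dfs{output};
  seen[output] = true;
  while (!dfs.empty()) {
    N *cur = dfs.back();
    dfs.pop_back();
    for (N *in : cur->inputs) {
      ++outdeg[in];
      if (!seen[in]) {
        seen[in] = true;
        dfs.push_back(in);
      }
    }
  }
  // 2) Pop users first; push an input once all of its users are emitted.
  std::vector<N *> res, stack{output};
  while (!stack.empty()) {
    N *cur = stack.back();
    stack.pop_back();
    res.push_back(cur);
    for (N *in : cur->inputs) {
      if (--outdeg[in] == 0) {
        outdeg.erase(in);
        stack.push_back(in);
      }
    }
  }
  // a non-empty outdeg here would mean a cycle (the real code raises an error)
  std::reverse(res.begin(), res.end());
  res.pop_back();  // drop the output node itself, mirroring the real code
  return res;
}

int main() {
  N a{"a", {}}, b{"b", {&a}}, out{"out", {&b, &a}};
  for (N *n : Toposort(&out)) std::cout << n->name << " ";  // a b
}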

mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h (+22 -20)

@@ -17,12 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_
#include <memory>
#include <vector>
#include <list>
#include <stack>
#include <string>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/model/op_node.h"
@@ -39,13 +34,16 @@ class LiteGraph {
const NodePtrList &GetOrderedNodes();
std::string Dump() const;
std::string ToString(bool reset_node_name = false) const;
const std::string &name() const { return name_; }
const NodePtrList &ops() const { return ops_; }
const NodePtrList &inputs() const { return inputs_; }
const NodePtr &output() const { return output_; }
const NodePtr &output(size_t i) const { return output_->input(i); }
const NodePtrList &GetOutputs() const { return output_->inputs(); }
void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); }
void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); }
protected:
std::string name_;
NodePtrList ops_; // save all operators in topo order
@@ -53,7 +51,10 @@ class LiteGraph {
NodePtr output_;
private:
int name_id_{0};
std::string ParamName() const { return "input_" + std::to_string(param_id_++); }
std::string NodeName() const { return "output_" + std::to_string(node_id_++); }
mutable int param_id_{0};
mutable int node_id_{0};
};
using LiteGraphPtr = std::shared_ptr<LiteGraph>;
@@ -61,27 +62,28 @@ class LiteGraph::GraphBuilder {
public:
explicit GraphBuilder(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); }
NodePtr Parameter(const NodeBase &baseinfo, std::string name = "") {
if (name.empty()) name = NewName();
auto para = std::make_shared<ParamNode>(name, baseinfo);
// Create a parameter of graph
NodePtr Parameter(const NodeBase &baseinfo) {
auto para = std::make_shared<ParamNode>(baseinfo);
para->SetDebugName(graph_->ParamName());
graph_->inputs_.push_back(para);
return para;
}
NodePtr Value(const tensor::TensorPtr &data, const std::string &name = "") {
return std::make_shared<ConstTensorNode>(data, name);
}
// Create a const value node
NodePtr Value(const tensor::TensorPtr &data) { return std::make_shared<ConstTensorNode>(data); }
void SetOutputs(const NodePtrList &nodes) { graph_->output_->SetInputs(nodes); }
NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}, std::string node_name = "");
NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {},
std::string node_name = "");
// Emit op, auto inferring the baseinfo of Node.
NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {});
// Create op node with given baseinfo.
NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {});
LiteGraphPtr Get() { return graph_; }
private:
PrimOpPtr CreateOp(const std::string &id, const std::string &name);
std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); }
PrimOpPtr CreateOp(const std::string &op);
LiteGraphPtr graph_;
};
} // namespace mindspore::graphkernel::inner


mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.cc (+24 -20)

@@ -14,30 +14,25 @@
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/model/node.h"
#include <memory>
#include <algorithm>
#include <functional>
#include <sstream>
#include <vector>
#include <iostream>
#include <string>
#include "utils/hash_map.h"
#include "ir/dtype/type_id.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
#include <utility>
namespace mindspore::graphkernel::inner {
void Node::DumpTensor(std::ostringstream &os) const {
os << name_ << "[";
void Node::SetBaseInfo(NodeBase baseinfo) {
this->shape = std::move(baseinfo.shape);
this->type = std::move(baseinfo.type);
this->format = std::move(baseinfo.format);
}
std::string Node::ToString() const {
std::ostringstream oss;
oss << debug_name() << "[";
for (size_t i = 0; i < shape.size(); i++) {
os << shape[i];
if (i + 1 < shape.size()) os << ",";
oss << shape[i];
if (i + 1 < shape.size()) oss << ",";
}
os << "]{" << TypeIdToString(type) << "x" << format << "}";
oss << "]{" << TypeIdToString(type) << "x" << format << "}";
return oss.str();
}
void Node::AddInput(const NodePtr &new_input) {
@@ -73,7 +68,7 @@ void Node::SetInputs(const NodePtrList &inputs) {
void Node::ReplaceWith(const NodePtr &other_node) {
if (this->users_.empty()) return;
// copy the users before traversal
// the users_ will be changed, so we copy the users before traversal
auto users = this->users_;
for (auto &user : users) {
for (auto idx : user.second) {
@@ -81,4 +76,13 @@ void Node::ReplaceWith(const NodePtr &other_node) {
}
}
}
void Node::RemoveUser(Node *user, size_t index) {
if (auto iter = users_.find(user); iter != users_.end()) {
iter->second.erase(index);
if (iter->second.empty()) {
users_.erase(iter);
}
}
}
} // namespace mindspore::graphkernel::inner

mindspore/ccsrc/backend/optimizer/graph_kernel/model/node.h (+19 -48)

@@ -17,21 +17,15 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
#include <memory>
#include <algorithm>
#include <functional>
#include <sstream>
#include <vector>
#include <set>
#include <iostream>
#include <utility>
#include <string>
#include <stdexcept>
#include "ir/dtype/type_id.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "utils/hash_map.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "mindspore/core/ir/value.h"
#include "mindspore/core/ir/tensor.h"
#include "mindspore/core/utils/shape_utils.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
namespace mindspore::graphkernel::inner {
@@ -58,74 +52,52 @@ using NodePtr = std::shared_ptr<Node>;
using NodePtrList = std::vector<NodePtr>;
class Node : public NodeBase, public std::enable_shared_from_this<Node> {
public:
Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {}
virtual ~Node() {
// remove this node from the previous nodes' user.
SetInputs({});
}
explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {}
virtual ~Node() { SetInputs({}); } // remove this node from the previous nodes' user.
void SetBaseInfo(NodeBase baseinfo) {
this->shape = std::move(baseinfo.shape);
this->type = std::move(baseinfo.type);
this->format = std::move(baseinfo.format);
}
virtual NType NodeType() { return NType::Base; }
friend std::ostream &operator<<(std::ostream &output, const Node &n) {
std::ostringstream os;
n.Dump(os);
output << os.str();
return output;
}
virtual void Dump(std::ostringstream &os) const = 0;
virtual void DumpTensor(std::ostringstream &os) const;
virtual std::string ToString() const;
void SetBaseInfo(NodeBase baseinfo);
void AddInput(const NodePtr &new_input);
void SetInput(size_t i, const NodePtr &new_input);
void SetInputs(const NodePtrList &inputs);
void ReplaceWith(const NodePtr &other_node);
void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; }
void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; }
template <typename T>
std::shared_ptr<T> As() {
return std::static_pointer_cast<T>(shared_from_this());
}
const std::string &name() const { return name_; }
const std::string &debug_name() const { return debug_name_; }
const DAttrs &attrs() const { return attrs_; }
const NodePtr &input(size_t i) const { return inputs_[i]; }
const NodePtrList &inputs() const { return inputs_; }
const mindspore::HashMap<Node *, std::set<size_t>> &users() const { return users_; }
protected:
std::string name_;
mutable std::string debug_name_; // only used in Dump function
DAttrs attrs_;
NodePtrList inputs_;
mindspore::HashMap<Node *, std::set<size_t>> users_;
mindspore::HashMap<Node *, std::set<size_t>> users_; // {user_node: {input edge index set}}
private:
// the nodes' users are only maintained by AddInput/SetInput.
void AddUser(Node *user, size_t index) { users_[user].insert(index); }
void RemoveUser(Node *user, size_t index) {
if (auto iter = users_.find(user); iter != users_.end()) {
iter->second.erase(index);
if (iter->second.empty()) {
users_.erase(iter);
}
}
}
void RemoveUser(Node *user, size_t index);
};
class ConstTensorNode : public Node {
public:
explicit ConstTensorNode(const tensor::TensorPtr &data, const std::string &name = "")
: Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}, name), data_(data) {}
explicit ConstTensorNode(const tensor::TensorPtr &data)
: Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}), data_(data) {}
~ConstTensorNode() = default;
NType NodeType() override { return NType::Value; }
void Dump(std::ostringstream &os) const override { os << ToString(); }
void DumpTensor(std::ostringstream &os) const override { os << ToString(); }
std::string ToString() const { return data_->data().ToString(this->type, this->shape, false); }
std::string ToString() const override { return data_->data().ToString(this->type, this->shape, false); }
const tensor::TensorPtr data() const { return data_; }
protected:
@@ -134,19 +106,18 @@ class ConstTensorNode : public Node {
class ParamNode : public Node {
public:
ParamNode(const std::string &name, const NodeBase &baseinfo) : Node(baseinfo, name) {}
explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {}
~ParamNode() = default;
void Dump(std::ostringstream &os) const override { DumpTensor(os); }
NType NodeType() override { return NType::Parameter; }
};
// the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph.
class OutputNode : public Node {
public:
OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {}
OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) {}
~OutputNode() = default;
void Dump(std::ostringstream &os) const override { ; }
NType NodeType() override { return NType::Output; }
};
} // namespace mindspore::graphkernel::inner


mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc (+30 -21)

@@ -16,16 +16,14 @@
#include "backend/optimizer/graph_kernel/model/op_node.h"
#include <math.h>
#include <sstream>
#include <memory>
#include <algorithm>
#include <set>
#include <string>
#include <sstream>
#include <vector>
#include <functional>
#include <numeric>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "backend/optimizer/graph_kernel/model/node.h"
namespace mindspore::graphkernel::inner {
@@ -89,30 +87,36 @@ NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
return nodebase;
}
void PrimOp::Dump(std::ostringstream &os) const {
DumpTensor(os);
os << " = " << this->op_ << "(";
std::string PrimOp::ToString() const {
std::ostringstream oss;
oss << Node::ToString();
oss << " = " << this->op_ << "(";
for (size_t i = 0; i < inputs_.size(); i++) {
inputs_[i]->DumpTensor(os);
if (i != inputs_.size() - 1) os << ", ";
if (inputs_[i]->NodeType() == NType::Primitive) {
oss << inputs_[i]->Node::ToString();
} else {
oss << inputs_[i]->ToString();
}
if (i != inputs_.size() - 1) oss << ", ";
}
os << ")";
std::ostringstream attr_os;
oss << ")";
std::ostringstream attr_oss;
bool has_attr = false;
std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
for (auto attr : attrs_) {
if (attr.second != nullptr && black_list.count(attr.first) == 0) {
if (has_attr) {
attr_os << ", ";
attr_oss << ", ";
} else {
has_attr = true;
}
attr_os << attr.first << ": " << attr.second->ToString();
attr_oss << attr.first << ": " << attr.second->ToString();
}
}
if (has_attr) {
os << " // attr {" << attr_os.str() << "}";
oss << " // attr {" << attr_oss.str() << "}";
}
return oss.str();
}
template <typename TM, typename TD>
@@ -293,6 +297,7 @@ DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs)
NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
auto nodebase = PrimOp::Infer(inputs, attrs);
// change the compute_type to BROADCAST if the result shape is greater than the input shapes.
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
for (auto &ref : inputs) {
if (ref->shape.size() != this->shape.size()) return true;
@@ -393,17 +398,21 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
};
auto shape0 = inputs[0]->shape;
auto shape1 = inputs[1]->shape;
check_nd(shape0, 4);
check_nd(shape1, 4);
constexpr auto dim_len = 4;
check_nd(shape0, dim_len);
check_nd(shape1, dim_len);
CHECK_ATTR(attrs, "format");
if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC &&
GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) {
MS_LOG(EXCEPTION) << "check NHWC format failed";
}
auto n = shape0[0];
auto h = shape0[1];
auto w = shape0[2];
auto out_channel = shape1[0];
constexpr auto axis_n = 0;
constexpr auto axis_h = 1;
constexpr auto axis_w = 2;
auto n = shape0[axis_n];
auto h = shape0[axis_h];
auto w = shape0[axis_w];
auto out_channel = shape1[axis_n];
CHECK_ATTR(attrs, "pad_list");
CHECK_ATTR(attrs, "pad_mode");
CHECK_ATTR(attrs, "kernel_size");
@@ -414,7 +423,6 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
auto stride = GetListInt(attrs.find("stride")->second);
auto dilation = GetListInt(attrs.find("dilation")->second);
constexpr auto dim_len = 4;
check_nd(pad_list, dim_len);
constexpr auto kernel_len = 2;
check_nd(kernel_size, kernel_len);
@@ -464,6 +472,7 @@ DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
}
DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) {
// only support NCHW/NHWC now
if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT;
CHECK_ATTR(attrs, "perm");
auto perm = GetListInt(attrs.find("perm")->second);


mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h (+56 -90)

@@ -17,12 +17,8 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
#include <memory>
#include <algorithm>
#include <sstream>
#include <string>
#include <functional>
#include "utils/hash_map.h"
#include "backend/optimizer/graph_kernel/model/node.h"
#include "ir/dtype/type.h"
@@ -44,14 +40,14 @@ class PrimOp : public Node {
OPAQUE,
};
PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
: Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
PrimOp(const std::string &op, ComputeType compute)
: Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {}
~PrimOp() = default;
virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs);
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
void Dump(std::ostringstream &os) const override;
std::string ToString() const override;
NType NodeType() override { return NType::Primitive; }
const std::string &op() const { return op_; }
@@ -72,9 +68,22 @@ class PrimOp : public Node {
};
using PrimOpPtr = std::shared_ptr<PrimOp>;
class ReshapeOp : public PrimOp {
public:
explicit ReshapeOp(const std::string &op) : PrimOp(op, RESHAPE) {}
~ReshapeOp() = default;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
: GetValue<std::string>(attrs.find("format")->second);
}
};
class ElemwiseOp : public PrimOp {
public:
ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
explicit ElemwiseOp(const std::string &op) : PrimOp(op, ELEMWISE) {}
~ElemwiseOp() = default;
NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
@@ -84,9 +93,35 @@ class ElemwiseOp : public PrimOp {
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class BroadcastToOp : public PrimOp {
public:
explicit BroadcastToOp(const std::string &op) : PrimOp(op, BROADCAST) {}
~BroadcastToOp() = default;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class ReduceOp : public PrimOp {
public:
explicit ReduceOp(const std::string &op) : PrimOp(op, REDUCE) {}
~ReduceOp() = default;
protected:
void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
};
class OpaqueOp : public PrimOp {
public:
explicit OpaqueOp(const std::string &op) : PrimOp(op, OPAQUE) {}
~OpaqueOp() = default;
};
class CastOp : public ElemwiseOp {
public:
CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {}
explicit CastOp(const std::string &op) : ElemwiseOp("Cast") {}
~CastOp() = default;
protected:
@@ -95,7 +130,7 @@ class CastOp : public ElemwiseOp {
class InplaceAssignOp : public ElemwiseOp {
public:
InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {}
explicit InplaceAssignOp(const std::string &op) : ElemwiseOp("InplaceAssign") {}
~InplaceAssignOp() = default;
protected:
@@ -106,7 +141,7 @@ class InplaceAssignOp : public ElemwiseOp {
class SelectOp : public ElemwiseOp {
public:
SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {}
explicit SelectOp(const std::string &op) : ElemwiseOp("Select") {}
~SelectOp() = default;
protected:
@@ -116,85 +151,16 @@ class SelectOp : public ElemwiseOp {
class CompareOp : public ElemwiseOp {
public:
CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {}
explicit CompareOp(const std::string &op) : ElemwiseOp(op) {}
~CompareOp() = default;
protected:
TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; }
};
class LessOp : public CompareOp {
public:
LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {}
~LessOp() = default;
};
class EqualOp : public CompareOp {
public:
EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {}
~EqualOp() = default;
};
class LessEqualOp : public CompareOp {
public:
LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {}
~LessEqualOp() = default;
};
class GreaterOp : public CompareOp {
public:
GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {}
~GreaterOp() = default;
};
class GreaterEqualOp : public CompareOp {
public:
GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {}
~GreaterEqualOp() = default;
};
class ReshapeOp : public PrimOp {
public:
ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
~ReshapeOp() = default;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override {
return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
: GetValue<std::string>(attrs.find("format")->second);
}
};
class BroadcastToOp : public PrimOp {
public:
BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
~BroadcastToOp() = default;
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class ReduceOp : public PrimOp {
public:
ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
~ReduceOp() = default;
protected:
void Check(const NodePtrList &inputs, const DAttrs &attrs) override;
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; };
};
class OpaqueOp : public PrimOp {
public:
OpaqueOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, OPAQUE) {}
~OpaqueOp() = default;
};
class Conv2dOp : public OpaqueOp {
public:
Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
explicit Conv2dOp(const std::string &op) : OpaqueOp("Conv2D") {}
~Conv2dOp() = default;
protected:
@@ -204,7 +170,7 @@ class Conv2dOp : public OpaqueOp {
class TransposeOp : public OpaqueOp {
public:
TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {}
explicit TransposeOp(const std::string &op) : OpaqueOp("Transpose") {}
~TransposeOp() = default;
protected:
@@ -214,7 +180,7 @@ class TransposeOp : public OpaqueOp {
class MatMulOp : public OpaqueOp {
public:
MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {}
explicit MatMulOp(const std::string &op) : OpaqueOp("MatMul") {}
~MatMulOp() = default;
protected:
@@ -224,7 +190,7 @@ class MatMulOp : public OpaqueOp {
class PadAkgOp : public OpaqueOp {
public:
PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {}
explicit PadAkgOp(const std::string &op) : OpaqueOp("PadAkg") {}
~PadAkgOp() = default;
protected:
@@ -233,7 +199,7 @@ class PadAkgOp : public OpaqueOp {
class UnPadAkgOp : public OpaqueOp {
public:
UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {}
explicit UnPadAkgOp(const std::string &op) : OpaqueOp("UnPadAkg") {}
~UnPadAkgOp() = default;
protected:
@@ -242,7 +208,7 @@ class UnPadAkgOp : public OpaqueOp {
class CImagOp : public ElemwiseOp {
public:
CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {}
explicit CImagOp(const std::string &op) : ElemwiseOp("CImag") {}
~CImagOp() = default;
protected:
@@ -257,7 +223,7 @@ class CImagOp : public ElemwiseOp {
class CRealOp : public ElemwiseOp {
public:
CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {}
explicit CRealOp(const std::string &op) : ElemwiseOp("CReal") {}
~CRealOp() = default;
protected:
@@ -272,7 +238,7 @@ class CRealOp : public ElemwiseOp {
class ComplexOp : public ElemwiseOp {
public:
ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {}
explicit ComplexOp(const std::string &op) : ElemwiseOp("Complex") {}
~ComplexOp() = default;
protected:
@@ -282,7 +248,7 @@ class ComplexOp : public ElemwiseOp {
class StandardNormalOp : public OpaqueOp {
public:
StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {}
explicit StandardNormalOp(const std::string &op) : OpaqueOp("StandardNormal") {}
~StandardNormalOp() = default;
protected:


mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.cc (+70 -0)

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/model/op_register.h"
#include <memory>
namespace mindspore::graphkernel::inner {
namespace {
class OpRegister {
public:
OpRegister(const std::string &name, const CreatorFunc &func) { OpRegistry::Instance().Register(name, func); }
~OpRegister() = default;
};
#define JOIN(x, y) x##y
#define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt)
#define OP_REGISTER(name, cls) \
static_assert(std::is_base_of<PrimOp, cls>::value, " should be base of PrimOp"); \
static const OpRegister UNIQUE_NAME(g_graphkernel_op, __COUNTER__)( \
name, [](const std::string &op) -> PrimOpPtr { return std::make_shared<cls>(op); })
} // namespace
// All nodes supported by GraphKernel are listed below.
OP_REGISTER("_opaque", OpaqueOp);
OP_REGISTER("Add", ElemwiseOp);
OP_REGISTER("Sub", ElemwiseOp);
OP_REGISTER("RealDiv", ElemwiseOp);
OP_REGISTER("Mul", ElemwiseOp);
OP_REGISTER("Log", ElemwiseOp);
OP_REGISTER("Exp", ElemwiseOp);
OP_REGISTER("Pow", ElemwiseOp);
OP_REGISTER("Sqrt", ElemwiseOp);
OP_REGISTER("Rsqrt", ElemwiseOp);
OP_REGISTER("Neg", ElemwiseOp);
OP_REGISTER("Reciprocal", ElemwiseOp);
OP_REGISTER("Abs", ElemwiseOp);
OP_REGISTER("BroadcastTo", BroadcastToOp);
OP_REGISTER("Reshape", ReshapeOp);
OP_REGISTER("ReduceSum", ReduceOp);
OP_REGISTER("ReduceMax", ReduceOp);
OP_REGISTER("ReduceMin", ReduceOp);
OP_REGISTER("Cast", CastOp);
OP_REGISTER("InplaceAssign", InplaceAssignOp);
OP_REGISTER("Select", SelectOp);
OP_REGISTER("Less", CompareOp);
OP_REGISTER("Equal", CompareOp);
OP_REGISTER("LessEqual", CompareOp);
OP_REGISTER("GreaterEqual", CompareOp);
OP_REGISTER("Greater", CompareOp);
OP_REGISTER("Transpose", TransposeOp);
OP_REGISTER("MatMul", MatMulOp);
OP_REGISTER("PadAkg", PadAkgOp);
OP_REGISTER("UnPadAkg", UnPadAkgOp);
OP_REGISTER("CReal", CRealOp);
OP_REGISTER("CImag", CImagOp);
OP_REGISTER("Complex", ComplexOp);
OP_REGISTER("StandardNormal", StandardNormalOp);
} // namespace mindspore::graphkernel::inner

mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_register.h (+13 -48)

@@ -18,69 +18,34 @@
#include <functional>
#include <string>
#include <memory>
#include "utils/hash_map.h"
#include "backend/optimizer/graph_kernel/model/node.h"
#include "backend/optimizer/graph_kernel/model/op_node.h"
namespace mindspore::graphkernel::inner {
#define OP_CREATOR(cls) \
[](const std::string &op, const std::string &name) -> PrimOpPtr { return std::make_shared<cls>(op, name); }
using CreatorFunc = std::function<PrimOpPtr(const std::string &)>;
class OpRegistry {
public:
static OpRegistry &Instance() {
static OpRegistry instance{};
return instance;
}
void Register(const std::string &op_name,
const std::function<PrimOpPtr(const std::string &, const std::string &)> &func) {
creators.insert({op_name, func});
}
void Register(const std::string &op_name, const CreatorFunc &func) { creators.insert({op_name, func}); }
PrimOpPtr NewOp(const std::string &op, const std::string &name) {
return creators.find(op) == creators.end() ? creators["Opaque"](op, name) : creators[op](op, name);
PrimOpPtr NewOp(const std::string &op) {
// "OpaqueOp" is registered by default.
return creators.find(op) == creators.end() ? creators["_opaque"](op) : creators[op](op);
}
private:
OpRegistry() {
Register("Add", OP_CREATOR(ElemwiseOp));
Register("Sub", OP_CREATOR(ElemwiseOp));
Register("RealDiv", OP_CREATOR(ElemwiseOp));
Register("Mul", OP_CREATOR(ElemwiseOp));
Register("Log", OP_CREATOR(ElemwiseOp));
Register("Exp", OP_CREATOR(ElemwiseOp));
Register("Pow", OP_CREATOR(ElemwiseOp));
Register("Sqrt", OP_CREATOR(ElemwiseOp));
Register("Rsqrt", OP_CREATOR(ElemwiseOp));
Register("Neg", OP_CREATOR(ElemwiseOp));
Register("Reciprocal", OP_CREATOR(ElemwiseOp));
Register("Abs", OP_CREATOR(ElemwiseOp));
Register("BroadcastTo", OP_CREATOR(BroadcastToOp));
Register("Reshape", OP_CREATOR(ReshapeOp));
Register("ReduceSum", OP_CREATOR(ReduceOp));
Register("ReduceMax", OP_CREATOR(ReduceOp));
Register("ReduceMin", OP_CREATOR(ReduceOp));
Register("Cast", OP_CREATOR(CastOp));
Register("InplaceAssign", OP_CREATOR(InplaceAssignOp));
Register("Select", OP_CREATOR(SelectOp));
Register("Less", OP_CREATOR(LessOp));
Register("Equal", OP_CREATOR(EqualOp));
Register("LessEqual", OP_CREATOR(LessEqualOp));
Register("GreaterEqual", OP_CREATOR(GreaterEqualOp));
Register("Greater", OP_CREATOR(GreaterOp));
Register("Transpose", OP_CREATOR(TransposeOp));
Register("MatMul", OP_CREATOR(MatMulOp));
Register("PadAkg", OP_CREATOR(PadAkgOp));
Register("UnPadAkg", OP_CREATOR(UnPadAkgOp));
Register("CReal", OP_CREATOR(CRealOp));
Register("CImag", OP_CREATOR(CImagOp));
Register("Complex", OP_CREATOR(ComplexOp));
Register("Opaque", OP_CREATOR(OpaqueOp));
Register("StandardNormal", OP_CREATOR(StandardNormalOp));
}
OpRegistry() = default;
~OpRegistry() = default;
mindspore::HashMap<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
OpRegistry(const OpRegistry &) = delete;
OpRegistry(const OpRegistry &&) = delete;
OpRegistry &operator=(const OpRegistry &) = delete;
mindspore::HashMap<std::string, CreatorFunc> creators;
};
} // namespace mindspore::graphkernel::inner
#endif

mindspore/ccsrc/backend/optimizer/graph_kernel/transform_op_optimizer.cc (+1 -1)

@@ -232,7 +232,7 @@ class TransformOp {
if (perm.empty()) {
MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_;
}
auto op = inner::OpRegistry::Instance().NewOp("Transpose", "new_trans");
auto op = inner::OpRegistry::Instance().NewOp("Transpose");
op->SetAttr("perm", MakeValue(perm));
return op;
}

