remove the "node_name" argument from class Node's constructor, rename the node's "name_" to "debug_name_", and add the SetDebugName interface. refactor the `OpRegistry`: register the nodes by macro `OP_REGISTER`, instead of write them in OpRegistry's constructor. remove the five CompareOp's subclass, which did not override any virtual functions. they can be binded to the `CompareOp` class.tags/v1.6.0
| @@ -619,11 +619,11 @@ void ReorganizeEmptyGraph(const inner::LiteGraphPtr &litegraph) { | |||
| inner::LiteGraph::GraphBuilder gb; | |||
| std::vector<int64_t> new_shape = {1}; | |||
| auto op_ptr = gb.Emit("BroadcastTo", {outputs[i]}, {{"shape", MakeValue(new_shape)}}); | |||
| litegraph->output()->SetInput(i, op_ptr); | |||
| litegraph->SetOutput(i, op_ptr); | |||
| } else if (outputs[i]->NodeType() == inner::NType::Parameter) { | |||
| inner::LiteGraph::GraphBuilder gb; | |||
| auto op_ptr = gb.Emit("Reshape", {outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}}); | |||
| litegraph->output()->SetInput(i, op_ptr); | |||
| litegraph->SetOutput(i, op_ptr); | |||
| } | |||
| } | |||
| return; | |||
| @@ -156,7 +156,7 @@ FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph) | |||
| // Create CNodes. | |||
| for (const auto &op_node : lite_graph->GetOrderedNodes()) { | |||
| if (op_node->NodeType() != inner::NType::Primitive) { | |||
| MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node"; | |||
| MS_LOG(EXCEPTION) << "Node " << op_node->debug_name() << "should be a Primitive node"; | |||
| } | |||
| auto op = std::static_pointer_cast<inner::PrimOp>(op_node); | |||
| AnfNodePtrList inputs = {NewValueNode(std::make_shared<Primitive>(op->op(), op->attrs()))}; | |||
| @@ -168,7 +168,7 @@ FuncGraphPtr GkUtils::LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph) | |||
| return iter->second; | |||
| } else { | |||
| if (inp->NodeType() != inner::NType::Value) { | |||
| MS_LOG(EXCEPTION) << "Node " << inp->name() << "should be a Value node"; | |||
| MS_LOG(EXCEPTION) << "Node " << inp->debug_name() << "should be a Value node"; | |||
| } | |||
| auto inp_value = inp->As<inner::ConstTensorNode>()->data(); | |||
| auto value_node = NewValueNode(inp_value); | |||
| @@ -458,8 +458,8 @@ inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) { | |||
| return inner::NodeBase({shape, type, format}); | |||
| }; | |||
| // set inputs | |||
| for (size_t i = 0; i < params.size(); i++) { | |||
| node_map[params[i]] = gb.Parameter(ExtractBuildInfo(params[i]), std::string("input_") + std::to_string(i)); | |||
| for (auto &p : params) { | |||
| node_map[p] = gb.Parameter(ExtractBuildInfo(p)); | |||
| } | |||
| // set ops | |||
| for (auto node : todos) { | |||
| @@ -18,11 +18,10 @@ | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <map> | |||
| #include <set> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include "utils/hash_map.h" | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| @@ -30,22 +29,31 @@ | |||
| #include "backend/optimizer/graph_kernel/model/op_register.h" | |||
| namespace mindspore::graphkernel::inner { | |||
| std::string LiteGraph::Dump() const { | |||
| std::string LiteGraph::ToString(bool reset_node_name) const { | |||
| if (reset_node_name) { | |||
| param_id_ = node_id_ = 0; | |||
| for (auto &inp : inputs_) { | |||
| inp->SetDebugName(ParamName()); | |||
| } | |||
| for (auto &node : ops_) { | |||
| node->SetDebugName(NodeName()); | |||
| } | |||
| } | |||
| std::ostringstream os; | |||
| os << name_ << "("; | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| os << inputs_[i]->name(); | |||
| os << inputs_[i]->debug_name(); | |||
| if (i != inputs_.size() - 1) os << ", "; | |||
| } | |||
| os << ") -> "; | |||
| auto &outputs = GetOutputs(); | |||
| for (size_t i = 0; i < outputs.size(); i++) { | |||
| os << outputs[i]->name(); | |||
| os << outputs[i]->debug_name(); | |||
| if (i != outputs.size() - 1) os << ", "; | |||
| } | |||
| os << " {\n"; | |||
| for (NodePtr op : ops_) { | |||
| os << " " << *op << "\n"; | |||
| for (const NodePtr &op : ops_) { | |||
| os << " " << op->ToString() << "\n"; | |||
| } | |||
| os << "}"; | |||
| return os.str(); | |||
| @@ -55,6 +63,7 @@ const NodePtrList &LiteGraph::GetOrderedNodes() { | |||
| mindspore::HashMap<NodePtr, size_t> outdegrees; | |||
| std::function<void(NodePtr)> dfs; | |||
| std::set<NodePtr> visited; | |||
| // record the out degree of each nodes by Dfs. | |||
| dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) { | |||
| (void)visited.insert(node); | |||
| for (auto &input : node->inputs()) { | |||
| @@ -69,6 +78,8 @@ const NodePtrList &LiteGraph::GetOrderedNodes() { | |||
| dfs(output_); | |||
| NodePtrList res; | |||
| NodePtrList stack; | |||
| // toposort algorithm with out degree | |||
| stack.push_back(output_); | |||
| while (!stack.empty()) { | |||
| auto cur = stack.back(); | |||
| @@ -86,20 +97,19 @@ const NodePtrList &LiteGraph::GetOrderedNodes() { | |||
| if (!outdegrees.empty()) { | |||
| MS_LOG(ERROR) << "Circle was found:"; | |||
| for (auto &node : outdegrees) { | |||
| MS_LOG(ERROR) << " " << *(node.first); | |||
| MS_LOG(ERROR) << " " << node.first->debug_name(); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size(); | |||
| } | |||
| std::reverse(res.begin(), res.end()); | |||
| res.pop_back(); // erase the output node | |||
| // remove the "OutputNode" | |||
| res.pop_back(); | |||
| ops_ = std::move(res); | |||
| return ops_; | |||
| } | |||
| NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs, | |||
| std::string node_name) { | |||
| if (node_name.empty()) node_name = NewName(); | |||
| PrimOpPtr op_ptr = CreateOp(op, node_name); | |||
| NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs) { | |||
| PrimOpPtr op_ptr = CreateOp(op); | |||
| auto baseinfo = op_ptr->Infer(inputs, attrs); | |||
| op_ptr->SetInputs(inputs); | |||
| op_ptr->SetAttrs(attrs); | |||
| @@ -108,16 +118,17 @@ NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList & | |||
| } | |||
| NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, | |||
| const DAttrs &attrs, std::string node_name) { | |||
| if (node_name.empty()) node_name = NewName(); | |||
| PrimOpPtr op_ptr = CreateOp(op, node_name); | |||
| const DAttrs &attrs) { | |||
| PrimOpPtr op_ptr = CreateOp(op); | |||
| op_ptr->SetInputs(inputs); | |||
| op_ptr->SetAttrs(attrs); | |||
| op_ptr->SetBaseInfo(baseinfo); | |||
| return graph_->Add(op_ptr); | |||
| } | |||
| PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) { | |||
| return OpRegistry::Instance().NewOp(op, node_name); | |||
| PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op) { | |||
| auto node = OpRegistry::Instance().NewOp(op); | |||
| node->SetDebugName(graph_->NodeName()); | |||
| return node; | |||
| } | |||
| } // namespace mindspore::graphkernel::inner | |||
| @@ -17,12 +17,7 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <list> | |||
| #include <stack> | |||
| #include <string> | |||
| #include "utils/hash_map.h" | |||
| #include "utils/hash_set.h" | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include "backend/optimizer/graph_kernel/model/op_node.h" | |||
| @@ -39,13 +34,16 @@ class LiteGraph { | |||
| const NodePtrList &GetOrderedNodes(); | |||
| std::string Dump() const; | |||
| std::string ToString(bool reset_node_name = false) const; | |||
| const std::string &name() const { return name_; } | |||
| const NodePtrList &ops() const { return ops_; } | |||
| const NodePtrList &inputs() const { return inputs_; } | |||
| const NodePtr &output() const { return output_; } | |||
| const NodePtr &output(size_t i) const { return output_->input(i); } | |||
| const NodePtrList &GetOutputs() const { return output_->inputs(); } | |||
| void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); } | |||
| void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); } | |||
| protected: | |||
| std::string name_; | |||
| NodePtrList ops_; // save all operators in topo order | |||
| @@ -53,7 +51,10 @@ class LiteGraph { | |||
| NodePtr output_; | |||
| private: | |||
| int name_id_{0}; | |||
| std::string ParamName() const { return "input_" + std::to_string(param_id_++); } | |||
| std::string NodeName() const { return "output_" + std::to_string(node_id_++); } | |||
| mutable int param_id_{0}; | |||
| mutable int node_id_{0}; | |||
| }; | |||
| using LiteGraphPtr = std::shared_ptr<LiteGraph>; | |||
| @@ -61,27 +62,28 @@ class LiteGraph::GraphBuilder { | |||
| public: | |||
| explicit GraphBuilder(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); } | |||
| NodePtr Parameter(const NodeBase &baseinfo, std::string name = "") { | |||
| if (name.empty()) name = NewName(); | |||
| auto para = std::make_shared<ParamNode>(name, baseinfo); | |||
| // Create a parameter of graph | |||
| NodePtr Parameter(const NodeBase &baseinfo) { | |||
| auto para = std::make_shared<ParamNode>(baseinfo); | |||
| para->SetDebugName(graph_->ParamName()); | |||
| graph_->inputs_.push_back(para); | |||
| return para; | |||
| } | |||
| NodePtr Value(const tensor::TensorPtr &data, const std::string &name = "") { | |||
| return std::make_shared<ConstTensorNode>(data, name); | |||
| } | |||
| // Create a const value node | |||
| NodePtr Value(const tensor::TensorPtr &data) { return std::make_shared<ConstTensorNode>(data); } | |||
| void SetOutputs(const NodePtrList &nodes) { graph_->output_->SetInputs(nodes); } | |||
| NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}, std::string node_name = ""); | |||
| NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {}, | |||
| std::string node_name = ""); | |||
| // Emit op, auto inferring the baseinfo of Node. | |||
| NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}); | |||
| // Create op node with given baseinfo. | |||
| NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {}); | |||
| LiteGraphPtr Get() { return graph_; } | |||
| private: | |||
| PrimOpPtr CreateOp(const std::string &id, const std::string &name); | |||
| std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); } | |||
| PrimOpPtr CreateOp(const std::string &op); | |||
| LiteGraphPtr graph_; | |||
| }; | |||
| } // namespace mindspore::graphkernel::inner | |||
| @@ -14,30 +14,25 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <sstream> | |||
| #include <vector> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include "utils/hash_map.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "ir/value.h" | |||
| #include "ir/tensor.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/utils.h" | |||
| #include <utility> | |||
| namespace mindspore::graphkernel::inner { | |||
| void Node::DumpTensor(std::ostringstream &os) const { | |||
| os << name_ << "["; | |||
| void Node::SetBaseInfo(NodeBase baseinfo) { | |||
| this->shape = std::move(baseinfo.shape); | |||
| this->type = std::move(baseinfo.type); | |||
| this->format = std::move(baseinfo.format); | |||
| } | |||
| std::string Node::ToString() const { | |||
| std::ostringstream oss; | |||
| oss << debug_name() << "["; | |||
| for (size_t i = 0; i < shape.size(); i++) { | |||
| os << shape[i]; | |||
| if (i + 1 < shape.size()) os << ","; | |||
| oss << shape[i]; | |||
| if (i + 1 < shape.size()) oss << ","; | |||
| } | |||
| os << "]{" << TypeIdToString(type) << "x" << format << "}"; | |||
| oss << "]{" << TypeIdToString(type) << "x" << format << "}"; | |||
| return oss.str(); | |||
| } | |||
| void Node::AddInput(const NodePtr &new_input) { | |||
| @@ -73,7 +68,7 @@ void Node::SetInputs(const NodePtrList &inputs) { | |||
| void Node::ReplaceWith(const NodePtr &other_node) { | |||
| if (this->users_.empty()) return; | |||
| // copy the users before traversal | |||
| // the users_ will be changed, so we copy the users before traversal | |||
| auto users = this->users_; | |||
| for (auto &user : users) { | |||
| for (auto idx : user.second) { | |||
| @@ -81,4 +76,13 @@ void Node::ReplaceWith(const NodePtr &other_node) { | |||
| } | |||
| } | |||
| } | |||
| void Node::RemoveUser(Node *user, size_t index) { | |||
| if (auto iter = users_.find(user); iter != users_.end()) { | |||
| iter->second.erase(index); | |||
| if (iter->second.empty()) { | |||
| users_.erase(iter); | |||
| } | |||
| } | |||
| } | |||
| } // namespace mindspore::graphkernel::inner | |||
| @@ -17,21 +17,15 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <sstream> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <stdexcept> | |||
| #include "ir/dtype/type_id.h" | |||
| #include "ir/value.h" | |||
| #include "ir/tensor.h" | |||
| #include "utils/hash_map.h" | |||
| #include "mindspore/core/ir/dtype/type_id.h" | |||
| #include "mindspore/core/ir/value.h" | |||
| #include "mindspore/core/ir/tensor.h" | |||
| #include "mindspore/core/utils/shape_utils.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore::graphkernel::inner { | |||
| @@ -58,74 +52,52 @@ using NodePtr = std::shared_ptr<Node>; | |||
| using NodePtrList = std::vector<NodePtr>; | |||
| class Node : public NodeBase, public std::enable_shared_from_this<Node> { | |||
| public: | |||
| Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {} | |||
| virtual ~Node() { | |||
| // remove this node from the previous nodes' user. | |||
| SetInputs({}); | |||
| } | |||
| explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {} | |||
| virtual ~Node() { SetInputs({}); } // remove this node from the previous nodes' user. | |||
| void SetBaseInfo(NodeBase baseinfo) { | |||
| this->shape = std::move(baseinfo.shape); | |||
| this->type = std::move(baseinfo.type); | |||
| this->format = std::move(baseinfo.format); | |||
| } | |||
| virtual NType NodeType() { return NType::Base; } | |||
| friend std::ostream &operator<<(std::ostream &output, const Node &n) { | |||
| std::ostringstream os; | |||
| n.Dump(os); | |||
| output << os.str(); | |||
| return output; | |||
| } | |||
| virtual void Dump(std::ostringstream &os) const = 0; | |||
| virtual void DumpTensor(std::ostringstream &os) const; | |||
| virtual std::string ToString() const; | |||
| void SetBaseInfo(NodeBase baseinfo); | |||
| void AddInput(const NodePtr &new_input); | |||
| void SetInput(size_t i, const NodePtr &new_input); | |||
| void SetInputs(const NodePtrList &inputs); | |||
| void ReplaceWith(const NodePtr &other_node); | |||
| void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } | |||
| void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } | |||
| void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; } | |||
| template <typename T> | |||
| std::shared_ptr<T> As() { | |||
| return std::static_pointer_cast<T>(shared_from_this()); | |||
| } | |||
| const std::string &name() const { return name_; } | |||
| const std::string &debug_name() const { return debug_name_; } | |||
| const DAttrs &attrs() const { return attrs_; } | |||
| const NodePtr &input(size_t i) const { return inputs_[i]; } | |||
| const NodePtrList &inputs() const { return inputs_; } | |||
| const mindspore::HashMap<Node *, std::set<size_t>> &users() const { return users_; } | |||
| protected: | |||
| std::string name_; | |||
| mutable std::string debug_name_; // only used in Dump function | |||
| DAttrs attrs_; | |||
| NodePtrList inputs_; | |||
| mindspore::HashMap<Node *, std::set<size_t>> users_; | |||
| mindspore::HashMap<Node *, std::set<size_t>> users_; // {user_node: {input edge index set}} | |||
| private: | |||
| // the nodes' users are only maintained by AddInput/SetInput. | |||
| void AddUser(Node *user, size_t index) { users_[user].insert(index); } | |||
| void RemoveUser(Node *user, size_t index) { | |||
| if (auto iter = users_.find(user); iter != users_.end()) { | |||
| iter->second.erase(index); | |||
| if (iter->second.empty()) { | |||
| users_.erase(iter); | |||
| } | |||
| } | |||
| } | |||
| void RemoveUser(Node *user, size_t index); | |||
| }; | |||
| class ConstTensorNode : public Node { | |||
| public: | |||
| explicit ConstTensorNode(const tensor::TensorPtr &data, const std::string &name = "") | |||
| : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}, name), data_(data) {} | |||
| explicit ConstTensorNode(const tensor::TensorPtr &data) | |||
| : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}), data_(data) {} | |||
| ~ConstTensorNode() = default; | |||
| NType NodeType() override { return NType::Value; } | |||
| void Dump(std::ostringstream &os) const override { os << ToString(); } | |||
| void DumpTensor(std::ostringstream &os) const override { os << ToString(); } | |||
| std::string ToString() const { return data_->data().ToString(this->type, this->shape, false); } | |||
| std::string ToString() const override { return data_->data().ToString(this->type, this->shape, false); } | |||
| const tensor::TensorPtr data() const { return data_; } | |||
| protected: | |||
| @@ -134,19 +106,18 @@ class ConstTensorNode : public Node { | |||
| class ParamNode : public Node { | |||
| public: | |||
| ParamNode(const std::string &name, const NodeBase &baseinfo) : Node(baseinfo, name) {} | |||
| explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {} | |||
| ~ParamNode() = default; | |||
| void Dump(std::ostringstream &os) const override { DumpTensor(os); } | |||
| NType NodeType() override { return NType::Parameter; } | |||
| }; | |||
| // the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph. | |||
| class OutputNode : public Node { | |||
| public: | |||
| OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "Output") {} | |||
| OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) {} | |||
| ~OutputNode() = default; | |||
| void Dump(std::ostringstream &os) const override { ; } | |||
| NType NodeType() override { return NType::Output; } | |||
| }; | |||
| } // namespace mindspore::graphkernel::inner | |||
| @@ -16,16 +16,14 @@ | |||
| #include "backend/optimizer/graph_kernel/model/op_node.h" | |||
| #include <math.h> | |||
| #include <sstream> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include <string> | |||
| #include <sstream> | |||
| #include <vector> | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include "utils/hash_map.h" | |||
| #include "utils/hash_set.h" | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| namespace mindspore::graphkernel::inner { | |||
| @@ -89,30 +87,36 @@ NodeBase PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| return nodebase; | |||
| } | |||
| void PrimOp::Dump(std::ostringstream &os) const { | |||
| DumpTensor(os); | |||
| os << " = " << this->op_ << "("; | |||
| std::string PrimOp::ToString() const { | |||
| std::ostringstream oss; | |||
| oss << Node::ToString(); | |||
| oss << " = " << this->op_ << "("; | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| inputs_[i]->DumpTensor(os); | |||
| if (i != inputs_.size() - 1) os << ", "; | |||
| if (inputs_[i]->NodeType() == NType::Primitive) { | |||
| oss << inputs_[i]->Node::ToString(); | |||
| } else { | |||
| oss << inputs_[i]->ToString(); | |||
| } | |||
| if (i != inputs_.size() - 1) oss << ", "; | |||
| } | |||
| os << ")"; | |||
| std::ostringstream attr_os; | |||
| oss << ")"; | |||
| std::ostringstream attr_oss; | |||
| bool has_attr = false; | |||
| std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"}; | |||
| for (auto attr : attrs_) { | |||
| if (attr.second != nullptr && black_list.count(attr.first) == 0) { | |||
| if (has_attr) { | |||
| attr_os << ", "; | |||
| attr_oss << ", "; | |||
| } else { | |||
| has_attr = true; | |||
| } | |||
| attr_os << attr.first << ": " << attr.second->ToString(); | |||
| attr_oss << attr.first << ": " << attr.second->ToString(); | |||
| } | |||
| } | |||
| if (has_attr) { | |||
| os << " // attr {" << attr_os.str() << "}"; | |||
| oss << " // attr {" << attr_oss.str() << "}"; | |||
| } | |||
| return oss.str(); | |||
| } | |||
| template <typename TM, typename TD> | |||
| @@ -293,6 +297,7 @@ DFormat ElemwiseOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) | |||
| NodeBase ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| auto nodebase = PrimOp::Infer(inputs, attrs); | |||
| // change the compute_type to BROADCAST if the result shape is greater than the input shapes. | |||
| auto IsBroadcast = [this](const NodePtrList &inputs) -> bool { | |||
| for (auto &ref : inputs) { | |||
| if (ref->shape.size() != this->shape.size()) return true; | |||
| @@ -393,17 +398,21 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| }; | |||
| auto shape0 = inputs[0]->shape; | |||
| auto shape1 = inputs[1]->shape; | |||
| check_nd(shape0, 4); | |||
| check_nd(shape1, 4); | |||
| constexpr auto dim_len = 4; | |||
| check_nd(shape0, dim_len); | |||
| check_nd(shape1, dim_len); | |||
| CHECK_ATTR(attrs, "format"); | |||
| if (inputs[0]->format != kOpFormat_NHWC && inputs[1]->format != kOpFormat_NHWC && | |||
| GetValue<std::string>(attrs.find("format")->second) != kOpFormat_NHWC) { | |||
| MS_LOG(EXCEPTION) << "check NHWC format failed"; | |||
| } | |||
| auto n = shape0[0]; | |||
| auto h = shape0[1]; | |||
| auto w = shape0[2]; | |||
| auto out_channel = shape1[0]; | |||
| constexpr auto axis_n = 0; | |||
| constexpr auto axis_h = 1; | |||
| constexpr auto axis_w = 2; | |||
| auto n = shape0[axis_n]; | |||
| auto h = shape0[axis_h]; | |||
| auto w = shape0[axis_w]; | |||
| auto out_channel = shape1[axis_n]; | |||
| CHECK_ATTR(attrs, "pad_list"); | |||
| CHECK_ATTR(attrs, "pad_mode"); | |||
| CHECK_ATTR(attrs, "kernel_size"); | |||
| @@ -414,7 +423,6 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| auto kernel_size = GetListInt(attrs.find("kernel_size")->second); | |||
| auto stride = GetListInt(attrs.find("stride")->second); | |||
| auto dilation = GetListInt(attrs.find("dilation")->second); | |||
| constexpr auto dim_len = 4; | |||
| check_nd(pad_list, dim_len); | |||
| constexpr auto kernel_len = 2; | |||
| check_nd(kernel_size, kernel_len); | |||
| @@ -464,6 +472,7 @@ DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| } | |||
| DFormat TransposeOp::InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { | |||
| // only support NCHW/NHWC now | |||
| if (inputs[0]->shape.size() != 4) return kOpFormat_DEFAULT; | |||
| CHECK_ATTR(attrs, "perm"); | |||
| auto perm = GetListInt(attrs.find("perm")->second); | |||
| @@ -17,12 +17,8 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_ | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <functional> | |||
| #include "utils/hash_map.h" | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include "ir/dtype/type.h" | |||
| @@ -44,14 +40,14 @@ class PrimOp : public Node { | |||
| OPAQUE, | |||
| }; | |||
| PrimOp(const std::string &op, const std::string &node_name, ComputeType compute) | |||
| : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {} | |||
| PrimOp(const std::string &op, ComputeType compute) | |||
| : Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}), op_(op), compute_type_(compute) {} | |||
| ~PrimOp() = default; | |||
| virtual NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs); | |||
| virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op); | |||
| void Dump(std::ostringstream &os) const override; | |||
| std::string ToString() const override; | |||
| NType NodeType() override { return NType::Primitive; } | |||
| const std::string &op() const { return op_; } | |||
| @@ -72,9 +68,22 @@ class PrimOp : public Node { | |||
| }; | |||
| using PrimOpPtr = std::shared_ptr<PrimOp>; | |||
| class ReshapeOp : public PrimOp { | |||
| public: | |||
| explicit ReshapeOp(const std::string &op) : PrimOp(op, RESHAPE) {} | |||
| ~ReshapeOp() = default; | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { | |||
| return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT | |||
| : GetValue<std::string>(attrs.find("format")->second); | |||
| } | |||
| }; | |||
| class ElemwiseOp : public PrimOp { | |||
| public: | |||
| ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {} | |||
| explicit ElemwiseOp(const std::string &op) : PrimOp(op, ELEMWISE) {} | |||
| ~ElemwiseOp() = default; | |||
| NodeBase Infer(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| @@ -84,9 +93,35 @@ class ElemwiseOp : public PrimOp { | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class BroadcastToOp : public PrimOp { | |||
| public: | |||
| explicit BroadcastToOp(const std::string &op) : PrimOp(op, BROADCAST) {} | |||
| ~BroadcastToOp() = default; | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class ReduceOp : public PrimOp { | |||
| public: | |||
| explicit ReduceOp(const std::string &op) : PrimOp(op, REDUCE) {} | |||
| ~ReduceOp() = default; | |||
| protected: | |||
| void Check(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }; | |||
| }; | |||
| class OpaqueOp : public PrimOp { | |||
| public: | |||
| explicit OpaqueOp(const std::string &op) : PrimOp(op, OPAQUE) {} | |||
| ~OpaqueOp() = default; | |||
| }; | |||
| class CastOp : public ElemwiseOp { | |||
| public: | |||
| CastOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Cast", node_name) {} | |||
| explicit CastOp(const std::string &op) : ElemwiseOp("Cast") {} | |||
| ~CastOp() = default; | |||
| protected: | |||
| @@ -95,7 +130,7 @@ class CastOp : public ElemwiseOp { | |||
| class InplaceAssignOp : public ElemwiseOp { | |||
| public: | |||
| InplaceAssignOp(const std::string &op, const std::string &node_name) : ElemwiseOp("InplaceAssign", node_name) {} | |||
| explicit InplaceAssignOp(const std::string &op) : ElemwiseOp("InplaceAssign") {} | |||
| ~InplaceAssignOp() = default; | |||
| protected: | |||
| @@ -106,7 +141,7 @@ class InplaceAssignOp : public ElemwiseOp { | |||
| class SelectOp : public ElemwiseOp { | |||
| public: | |||
| SelectOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Select", node_name) {} | |||
| explicit SelectOp(const std::string &op) : ElemwiseOp("Select") {} | |||
| ~SelectOp() = default; | |||
| protected: | |||
| @@ -116,85 +151,16 @@ class SelectOp : public ElemwiseOp { | |||
| class CompareOp : public ElemwiseOp { | |||
| public: | |||
| CompareOp(const std::string &op, const std::string &node_name) : ElemwiseOp(op, node_name) {} | |||
| explicit CompareOp(const std::string &op) : ElemwiseOp(op) {} | |||
| ~CompareOp() = default; | |||
| protected: | |||
| TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) override { return TypeId::kNumberTypeBool; } | |||
| }; | |||
| class LessOp : public CompareOp { | |||
| public: | |||
| LessOp(const std::string &op, const std::string &node_name) : CompareOp("Less", node_name) {} | |||
| ~LessOp() = default; | |||
| }; | |||
| class EqualOp : public CompareOp { | |||
| public: | |||
| EqualOp(const std::string &op, const std::string &node_name) : CompareOp("Equal", node_name) {} | |||
| ~EqualOp() = default; | |||
| }; | |||
| class LessEqualOp : public CompareOp { | |||
| public: | |||
| LessEqualOp(const std::string &op, const std::string &node_name) : CompareOp("LessEqual", node_name) {} | |||
| ~LessEqualOp() = default; | |||
| }; | |||
| class GreaterOp : public CompareOp { | |||
| public: | |||
| GreaterOp(const std::string &op, const std::string &node_name) : CompareOp("Greater", node_name) {} | |||
| ~GreaterOp() = default; | |||
| }; | |||
| class GreaterEqualOp : public CompareOp { | |||
| public: | |||
| GreaterEqualOp(const std::string &op, const std::string &node_name) : CompareOp("GreaterEqual", node_name) {} | |||
| ~GreaterEqualOp() = default; | |||
| }; | |||
| class ReshapeOp : public PrimOp { | |||
| public: | |||
| ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {} | |||
| ~ReshapeOp() = default; | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { | |||
| return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT | |||
| : GetValue<std::string>(attrs.find("format")->second); | |||
| } | |||
| }; | |||
| class BroadcastToOp : public PrimOp { | |||
| public: | |||
| BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {} | |||
| ~BroadcastToOp() = default; | |||
| protected: | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| }; | |||
| class ReduceOp : public PrimOp { | |||
| public: | |||
| ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {} | |||
| ~ReduceOp() = default; | |||
| protected: | |||
| void Check(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override; | |||
| DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override { return kOpFormat_DEFAULT; }; | |||
| }; | |||
| class OpaqueOp : public PrimOp { | |||
| public: | |||
| OpaqueOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, OPAQUE) {} | |||
| ~OpaqueOp() = default; | |||
| }; | |||
| class Conv2dOp : public OpaqueOp { | |||
| public: | |||
| Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {} | |||
| explicit Conv2dOp(const std::string &op) : OpaqueOp("Conv2D") {} | |||
| ~Conv2dOp() = default; | |||
| protected: | |||
| @@ -204,7 +170,7 @@ class Conv2dOp : public OpaqueOp { | |||
| class TransposeOp : public OpaqueOp { | |||
| public: | |||
| TransposeOp(const std::string &op, const std::string &node_name) : OpaqueOp("Transpose", node_name) {} | |||
| explicit TransposeOp(const std::string &op) : OpaqueOp("Transpose") {} | |||
| ~TransposeOp() = default; | |||
| protected: | |||
| @@ -214,7 +180,7 @@ class TransposeOp : public OpaqueOp { | |||
| class MatMulOp : public OpaqueOp { | |||
| public: | |||
| MatMulOp(const std::string &op, const std::string &node_name) : OpaqueOp("MatMul", node_name) {} | |||
| explicit MatMulOp(const std::string &op) : OpaqueOp("MatMul") {} | |||
| ~MatMulOp() = default; | |||
| protected: | |||
| @@ -224,7 +190,7 @@ class MatMulOp : public OpaqueOp { | |||
| class PadAkgOp : public OpaqueOp { | |||
| public: | |||
| PadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("PadAkg", node_name) {} | |||
| explicit PadAkgOp(const std::string &op) : OpaqueOp("PadAkg") {} | |||
| ~PadAkgOp() = default; | |||
| protected: | |||
| @@ -233,7 +199,7 @@ class PadAkgOp : public OpaqueOp { | |||
| class UnPadAkgOp : public OpaqueOp { | |||
| public: | |||
| UnPadAkgOp(const std::string &op, const std::string &node_name) : OpaqueOp("UnPadAkg", node_name) {} | |||
| explicit UnPadAkgOp(const std::string &op) : OpaqueOp("UnPadAkg") {} | |||
| ~UnPadAkgOp() = default; | |||
| protected: | |||
| @@ -242,7 +208,7 @@ class UnPadAkgOp : public OpaqueOp { | |||
| class CImagOp : public ElemwiseOp { | |||
| public: | |||
| CImagOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CImag", node_name) {} | |||
| explicit CImagOp(const std::string &op) : ElemwiseOp("CImag") {} | |||
| ~CImagOp() = default; | |||
| protected: | |||
| @@ -257,7 +223,7 @@ class CImagOp : public ElemwiseOp { | |||
| class CRealOp : public ElemwiseOp { | |||
| public: | |||
| CRealOp(const std::string &op, const std::string &node_name) : ElemwiseOp("CReal", node_name) {} | |||
| explicit CRealOp(const std::string &op) : ElemwiseOp("CReal") {} | |||
| ~CRealOp() = default; | |||
| protected: | |||
| @@ -272,7 +238,7 @@ class CRealOp : public ElemwiseOp { | |||
| class ComplexOp : public ElemwiseOp { | |||
| public: | |||
| ComplexOp(const std::string &op, const std::string &node_name) : ElemwiseOp("Complex", node_name) {} | |||
| explicit ComplexOp(const std::string &op) : ElemwiseOp("Complex") {} | |||
| ~ComplexOp() = default; | |||
| protected: | |||
| @@ -282,7 +248,7 @@ class ComplexOp : public ElemwiseOp { | |||
| class StandardNormalOp : public OpaqueOp { | |||
| public: | |||
| StandardNormalOp(const std::string &op, const std::string &node_name) : OpaqueOp("StandardNormal", node_name) {} | |||
| explicit StandardNormalOp(const std::string &op) : OpaqueOp("StandardNormal") {} | |||
| ~StandardNormalOp() = default; | |||
| protected: | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/graph_kernel/model/op_register.h" | |||
| #include <memory> | |||
| namespace mindspore::graphkernel::inner { | |||
| namespace { | |||
| class OpRegister { | |||
| public: | |||
| OpRegister(const std::string &name, const CreatorFunc &func) { OpRegistry::Instance().Register(name, func); } | |||
| ~OpRegister() = default; | |||
| }; | |||
| #define JOIN(x, y) x##y | |||
| #define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt) | |||
| #define OP_REGISTER(name, cls) \ | |||
| static_assert(std::is_base_of<PrimOp, cls>::value, " should be base of PrimOp"); \ | |||
| static const OpRegister UNIQUE_NAME(g_graphkernel_op, __COUNTER__)( \ | |||
| name, [](const std::string &op) -> PrimOpPtr { return std::make_shared<cls>(op); }) | |||
| } // namespace | |||
| // All nodes supported by GraphKernel are listed below. | |||
| OP_REGISTER("_opaque", OpaqueOp); | |||
| OP_REGISTER("Add", ElemwiseOp); | |||
| OP_REGISTER("Sub", ElemwiseOp); | |||
| OP_REGISTER("RealDiv", ElemwiseOp); | |||
| OP_REGISTER("Mul", ElemwiseOp); | |||
| OP_REGISTER("Log", ElemwiseOp); | |||
| OP_REGISTER("Exp", ElemwiseOp); | |||
| OP_REGISTER("Pow", ElemwiseOp); | |||
| OP_REGISTER("Sqrt", ElemwiseOp); | |||
| OP_REGISTER("Rsqrt", ElemwiseOp); | |||
| OP_REGISTER("Neg", ElemwiseOp); | |||
| OP_REGISTER("Reciprocal", ElemwiseOp); | |||
| OP_REGISTER("Abs", ElemwiseOp); | |||
| OP_REGISTER("BroadcastTo", BroadcastToOp); | |||
| OP_REGISTER("Reshape", ReshapeOp); | |||
| OP_REGISTER("ReduceSum", ReduceOp); | |||
| OP_REGISTER("ReduceMax", ReduceOp); | |||
| OP_REGISTER("ReduceMin", ReduceOp); | |||
| OP_REGISTER("Cast", CastOp); | |||
| OP_REGISTER("InplaceAssign", InplaceAssignOp); | |||
| OP_REGISTER("Select", SelectOp); | |||
| OP_REGISTER("Less", CompareOp); | |||
| OP_REGISTER("Equal", CompareOp); | |||
| OP_REGISTER("LessEqual", CompareOp); | |||
| OP_REGISTER("GreaterEqual", CompareOp); | |||
| OP_REGISTER("Greater", CompareOp); | |||
| OP_REGISTER("Transpose", TransposeOp); | |||
| OP_REGISTER("MatMul", MatMulOp); | |||
| OP_REGISTER("PadAkg", PadAkgOp); | |||
| OP_REGISTER("UnPadAkg", UnPadAkgOp); | |||
| OP_REGISTER("CReal", CRealOp); | |||
| OP_REGISTER("CImag", CImagOp); | |||
| OP_REGISTER("Complex", ComplexOp); | |||
| OP_REGISTER("StandardNormal", StandardNormalOp); | |||
| } // namespace mindspore::graphkernel::inner | |||
| @@ -18,69 +18,34 @@ | |||
| #include <functional> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "utils/hash_map.h" | |||
| #include "backend/optimizer/graph_kernel/model/node.h" | |||
| #include "backend/optimizer/graph_kernel/model/op_node.h" | |||
| namespace mindspore::graphkernel::inner { | |||
| #define OP_CREATOR(cls) \ | |||
| [](const std::string &op, const std::string &name) -> PrimOpPtr { return std::make_shared<cls>(op, name); } | |||
| using CreatorFunc = std::function<PrimOpPtr(const std::string &)>; | |||
| class OpRegistry { | |||
| public: | |||
| static OpRegistry &Instance() { | |||
| static OpRegistry instance{}; | |||
| return instance; | |||
| } | |||
| void Register(const std::string &op_name, | |||
| const std::function<PrimOpPtr(const std::string &, const std::string &)> &func) { | |||
| creators.insert({op_name, func}); | |||
| } | |||
| void Register(const std::string &op_name, const CreatorFunc &func) { creators.insert({op_name, func}); } | |||
| PrimOpPtr NewOp(const std::string &op, const std::string &name) { | |||
| return creators.find(op) == creators.end() ? creators["Opaque"](op, name) : creators[op](op, name); | |||
| PrimOpPtr NewOp(const std::string &op) { | |||
| // "OpaqueOp" is registered by default. | |||
| return creators.find(op) == creators.end() ? creators["_opaque"](op) : creators[op](op); | |||
| } | |||
| private: | |||
| OpRegistry() { | |||
| Register("Add", OP_CREATOR(ElemwiseOp)); | |||
| Register("Sub", OP_CREATOR(ElemwiseOp)); | |||
| Register("RealDiv", OP_CREATOR(ElemwiseOp)); | |||
| Register("Mul", OP_CREATOR(ElemwiseOp)); | |||
| Register("Log", OP_CREATOR(ElemwiseOp)); | |||
| Register("Exp", OP_CREATOR(ElemwiseOp)); | |||
| Register("Pow", OP_CREATOR(ElemwiseOp)); | |||
| Register("Sqrt", OP_CREATOR(ElemwiseOp)); | |||
| Register("Rsqrt", OP_CREATOR(ElemwiseOp)); | |||
| Register("Neg", OP_CREATOR(ElemwiseOp)); | |||
| Register("Reciprocal", OP_CREATOR(ElemwiseOp)); | |||
| Register("Abs", OP_CREATOR(ElemwiseOp)); | |||
| Register("BroadcastTo", OP_CREATOR(BroadcastToOp)); | |||
| Register("Reshape", OP_CREATOR(ReshapeOp)); | |||
| Register("ReduceSum", OP_CREATOR(ReduceOp)); | |||
| Register("ReduceMax", OP_CREATOR(ReduceOp)); | |||
| Register("ReduceMin", OP_CREATOR(ReduceOp)); | |||
| Register("Cast", OP_CREATOR(CastOp)); | |||
| Register("InplaceAssign", OP_CREATOR(InplaceAssignOp)); | |||
| Register("Select", OP_CREATOR(SelectOp)); | |||
| Register("Less", OP_CREATOR(LessOp)); | |||
| Register("Equal", OP_CREATOR(EqualOp)); | |||
| Register("LessEqual", OP_CREATOR(LessEqualOp)); | |||
| Register("GreaterEqual", OP_CREATOR(GreaterEqualOp)); | |||
| Register("Greater", OP_CREATOR(GreaterOp)); | |||
| Register("Transpose", OP_CREATOR(TransposeOp)); | |||
| Register("MatMul", OP_CREATOR(MatMulOp)); | |||
| Register("PadAkg", OP_CREATOR(PadAkgOp)); | |||
| Register("UnPadAkg", OP_CREATOR(UnPadAkgOp)); | |||
| Register("CReal", OP_CREATOR(CRealOp)); | |||
| Register("CImag", OP_CREATOR(CImagOp)); | |||
| Register("Complex", OP_CREATOR(ComplexOp)); | |||
| Register("Opaque", OP_CREATOR(OpaqueOp)); | |||
| Register("StandardNormal", OP_CREATOR(StandardNormalOp)); | |||
| } | |||
| OpRegistry() = default; | |||
| ~OpRegistry() = default; | |||
| mindspore::HashMap<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators; | |||
| OpRegistry(const OpRegistry &) = delete; | |||
| OpRegistry(const OpRegistry &&) = delete; | |||
| OpRegistry &operator=(const OpRegistry &) = delete; | |||
| mindspore::HashMap<std::string, CreatorFunc> creators; | |||
| }; | |||
| } // namespace mindspore::graphkernel::inner | |||
| #endif | |||
| @@ -232,7 +232,7 @@ class TransformOp { | |||
| if (perm.empty()) { | |||
| MS_LOG(EXCEPTION) << "unsupported format: " << format_a_ << " to " << format_b_; | |||
| } | |||
| auto op = inner::OpRegistry::Instance().NewOp("Transpose", "new_trans"); | |||
| auto op = inner::OpRegistry::Instance().NewOp("Transpose"); | |||
| op->SetAttr("perm", MakeValue(perm)); | |||
| return op; | |||
| } | |||