|
- package dag
-
- import (
- "fmt"
- "reflect"
- "strings"
-
- "gitlink.org.cn/cloudream/common/utils/lo2"
- "gitlink.org.cn/cloudream/common/utils/reflect2"
- )
-
// Graph is a mutable collection of nodes. It stores only the node list;
// connections between nodes are reached through each node's own slots.
type Graph struct {
	// Nodes holds every node added with AddNode, in insertion order.
	// While Walk is running, RemoveNode leaves nil placeholders here instead
	// of deleting entries, so Walk's index stays valid.
	Nodes []Node
	// isWalking is set by Walk while it iterates over Nodes; RemoveNode reads
	// it to choose between nil-ing an entry and removing it from the slice.
	isWalking bool
}
-
- func NewGraph() *Graph {
- return &Graph{}
- }
-
- func (g *Graph) AddNode(node Node) {
- g.Nodes = append(g.Nodes, node)
- node.SetGraph(g)
- }
-
- func (g *Graph) RemoveNode(node Node) {
- for i, n := range g.Nodes {
- if n == node {
- if g.isWalking {
- g.Nodes[i] = nil
- } else {
- g.Nodes = lo2.RemoveAt(g.Nodes, i)
- }
- break
- }
- }
- }
-
- func (g *Graph) Walk(cb func(node Node) bool) {
- g.isWalking = true
- for i := 0; i < len(g.Nodes); i++ {
- if g.Nodes[i] == nil {
- continue
- }
-
- if !cb(g.Nodes[i]) {
- break
- }
- }
- g.isWalking = false
-
- g.Nodes = lo2.RemoveAllDefault(g.Nodes)
- }
-
- func (g *Graph) NewStreamVar() *StreamVar {
- return &StreamVar{}
- }
-
- func (g *Graph) NewValueVar() *ValueVar {
- return &ValueVar{}
- }
-
- func (g *Graph) Dump() string {
- nodeIDs := make(map[Node]int)
- for i, node := range g.Nodes {
- nodeIDs[node] = i
- }
-
- var sb strings.Builder
- for _, node := range g.Nodes {
- id, ok := nodeIDs[node]
- if !ok {
- id = len(nodeIDs)
- nodeIDs[node] = id
- }
- sb.WriteString(fmt.Sprintf("[%v]%v\n", id, nodeTypeName(node)))
- if node.InputStreams().Len() > 0 {
- sb.WriteString("SIn: ")
- for i, in := range node.InputStreams().Slots {
- if i > 0 {
- sb.WriteString(", ")
- }
-
- if in == nil {
- sb.WriteString("?")
- } else {
- sb.WriteString(fmt.Sprintf("%v", nodeIDs[in.Src]))
- }
- }
- sb.WriteString("\n")
- }
- if node.OutputStreams().Len() > 0 {
- sb.WriteString("SOut: ")
- for i, out := range node.OutputStreams().Slots {
- if i > 0 {
- sb.WriteString(", ")
- }
-
- sb.WriteString("(")
- for i2, dst := range out.Dst {
- if i2 > 0 {
- sb.WriteString(", ")
- }
- sb.WriteString(fmt.Sprintf("%v", nodeIDs[dst]))
- }
- sb.WriteString(")")
- }
- sb.WriteString("\n")
- }
-
- if node.InputValues().Len() > 0 {
- sb.WriteString("VIn: ")
- for i, in := range node.InputValues().Slots {
- if i > 0 {
- sb.WriteString(", ")
- }
-
- if in == nil {
- sb.WriteString("?")
- } else {
- sb.WriteString(fmt.Sprintf("%v", nodeIDs[in.Src]))
- }
- }
- sb.WriteString("\n")
- }
- if node.OutputValues().Len() > 0 {
- sb.WriteString("VOut: ")
- for i, out := range node.OutputValues().Slots {
- if i > 0 {
- sb.WriteString(", ")
- }
-
- sb.WriteString("(")
- for i2, dst := range out.Dst {
- if i2 > 0 {
- sb.WriteString(", ")
- }
- sb.WriteString(fmt.Sprintf("%v", nodeIDs[dst]))
- }
- sb.WriteString(")")
- }
- sb.WriteString("\n")
- }
-
- }
- return sb.String()
- }
-
- func nodeTypeName(node Node) string {
- typ := reflect2.TypeOfValue(node)
- for typ.Kind() == reflect.Ptr {
- typ = typ.Elem()
- }
- return typ.Name()
- }
-
- func AddNode[N Node](graph *Graph, typ N) N {
- graph.AddNode(typ)
- return typ
- }
-
- func WalkOnlyType[N Node](g *Graph, cb func(node N) bool) {
- g.Walk(func(n Node) bool {
- node, ok := n.(N)
- if ok {
- return cb(node)
- }
- return true
- })
- }
|