package dag

import (
	"fmt"
	"reflect"
	"strings"

	"gitlink.org.cn/cloudream/common/utils/lo2"
	"gitlink.org.cn/cloudream/common/utils/reflect2"
)

// Graph is an ordered collection of nodes.
//
// Nodes may be removed while a Walk is in progress; such removals leave a nil
// placeholder in Nodes, which is compacted away when the outermost Walk returns.
type Graph struct {
	Nodes []Node
	// isWalking is true while at least one Walk call is running on this graph.
	isWalking bool
}

// NewGraph returns an empty Graph.
func NewGraph() *Graph {
	return &Graph{}
}

// AddNode appends node to the graph and calls node.SetGraph(g).
func (g *Graph) AddNode(node Node) {
	g.Nodes = append(g.Nodes, node)
	node.SetGraph(g)
}

// RemoveNode removes the first occurrence of node from the graph; it is a
// no-op if node is not present.
//
// While a Walk is in progress the entry is replaced with nil rather than being
// deleted, so the index-based loop in Walk is not disturbed.
func (g *Graph) RemoveNode(node Node) {
	for i, n := range g.Nodes {
		if n != node {
			continue
		}
		if g.isWalking {
			g.Nodes[i] = nil
		} else {
			g.Nodes = lo2.RemoveAt(g.Nodes, i)
		}
		return
	}
}

// Walk calls cb for each non-nil node in g.Nodes, in order, until cb returns
// false or all nodes have been visited.
//
// Nodes appended by AddNode during the walk are visited too, because the loop
// re-reads len(g.Nodes) on every iteration. Nodes removed by RemoveNode during
// the walk become nil placeholders and are skipped.
//
// Walk may be called re-entrantly from cb: only the outermost call clears the
// walking state and compacts the nil placeholders (lo2.RemoveAllDefault), so an
// inner walk cannot shift the slice under an outer one. The cleanup is
// deferred so it also runs if cb panics.
func (g *Graph) Walk(cb func(node Node) bool) {
	outermost := !g.isWalking
	g.isWalking = true
	defer func() {
		if outermost {
			g.isWalking = false
			g.Nodes = lo2.RemoveAllDefault(g.Nodes)
		}
	}()

	for i := 0; i < len(g.Nodes); i++ {
		node := g.Nodes[i]
		if node == nil {
			continue
		}
		if !cb(node) {
			return
		}
	}
}

// NewStreamVar returns a new, zero-valued StreamVar. The variable is not
// recorded in g.
func (g *Graph) NewStreamVar() *StreamVar {
	return &StreamVar{}
}

// NewValueVar returns a new, zero-valued ValueVar. The variable is not
// recorded in g.
func (g *Graph) NewValueVar() *ValueVar {
	return &ValueVar{}
}

// Dump renders a human-readable description of the graph, one entry per node.
//
// Each node is printed as "[id]TypeName", followed by up to four lines that
// list its connections by node ID (a line is omitted when its Len() is 0):
//
//	SIn:  source node of each input stream slot ("?" for an empty slot)
//	SOut: destination nodes of each output stream slot, grouped in parentheses
//	VIn:  source node of each input value slot ("?" for an empty slot)
//	VOut: destination nodes of each output value slot, grouped in parentheses
//
// IDs follow the order of g.Nodes. A node that is referenced by a slot but is
// not in g.Nodes gets the next unused ID, so it cannot be confused with a graph
// member. A nil node reference is printed as "?". Nil entries in g.Nodes (left
// by RemoveNode during a Walk) are skipped.
func (g *Graph) Dump() string {
	ids := make(map[Node]int, len(g.Nodes))
	nextID := 0
	// idOf returns n's ID as a string, assigning a fresh one on first sight.
	idOf := func(n Node) string {
		if n == nil {
			return "?"
		}
		id, ok := ids[n]
		if !ok {
			id = nextID
			ids[n] = id
			nextID++
		}
		return fmt.Sprint(id)
	}

	// Register graph members first so their IDs match their order in g.Nodes.
	for _, node := range g.Nodes {
		if node != nil {
			idOf(node)
		}
	}

	var sb strings.Builder
	for _, node := range g.Nodes {
		if node == nil {
			continue
		}
		fmt.Fprintf(&sb, "[%v]%v\n", idOf(node), nodeTypeName(node))

		if node.InputStreams().Len() > 0 {
			var srcs []string
			for _, in := range node.InputStreams().Slots {
				if in == nil {
					srcs = append(srcs, "?")
				} else {
					srcs = append(srcs, idOf(in.Src))
				}
			}
			dumpWriteLine(&sb, "SIn", srcs)
		}

		if node.OutputStreams().Len() > 0 {
			var groups []string
			for _, out := range node.OutputStreams().Slots {
				var dsts []string
				for _, dst := range out.Dst {
					dsts = append(dsts, idOf(dst))
				}
				groups = append(groups, "("+strings.Join(dsts, ", ")+")")
			}
			dumpWriteLine(&sb, "SOut", groups)
		}

		if node.InputValues().Len() > 0 {
			var srcs []string
			for _, in := range node.InputValues().Slots {
				if in == nil {
					srcs = append(srcs, "?")
				} else {
					srcs = append(srcs, idOf(in.Src))
				}
			}
			dumpWriteLine(&sb, "VIn", srcs)
		}

		if node.OutputValues().Len() > 0 {
			var groups []string
			for _, out := range node.OutputValues().Slots {
				var dsts []string
				for _, dst := range out.Dst {
					dsts = append(dsts, idOf(dst))
				}
				groups = append(groups, "("+strings.Join(dsts, ", ")+")")
			}
			dumpWriteLine(&sb, "VOut", groups)
		}
	}
	return sb.String()
}

// dumpWriteLine writes "label: item1, item2, ...\n" to sb.
func dumpWriteLine(sb *strings.Builder, label string, items []string) {
	sb.WriteString(label)
	sb.WriteString(": ")
	sb.WriteString(strings.Join(items, ", "))
	sb.WriteString("\n")
}

// nodeTypeName returns the name of node's dynamic type (obtained via
// reflect2.TypeOfValue) with any pointer indirections stripped, e.g. "FooNode"
// for a *FooNode. Like reflect.Type.Name, it returns "" for unnamed types.
func nodeTypeName(node Node) string {
	typ := reflect2.TypeOfValue(node)
	for typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	return typ.Name()
}

// AddNode adds node to graph and returns it with its concrete type preserved,
// so a node can be constructed and registered in a single expression.
func AddNode[N Node](graph *Graph, node N) N {
	graph.AddNode(node)
	return node
}

// WalkOnlyType walks g like Graph.Walk but only calls cb for nodes that can be
// type-asserted to N. Other nodes are skipped without stopping the walk;
// returning false from cb stops it.
func WalkOnlyType[N Node](g *Graph, cb func(node N) bool) {
	g.Walk(func(n Node) bool {
		if node, ok := n.(N); ok {
			return cb(node)
		}
		return true
	})
}