|
- package zk
-
- import (
- "encoding/binary"
- "fmt"
- "io"
- "net"
- "sync"
- )
-
// requests maps the Xid of each request observed on a client->server stream
// to that request's opcode, so the matching response can be decoded with the
// right type. trace removes an entry once the response with that Xid is seen.
// All traced connections share this one map; access it only while holding
// requestsLock.
var requests = map[int32]int32{}

// requestsLock serializes every read and write of requests.
var requestsLock = new(sync.Mutex)
-
- func trace(conn1, conn2 net.Conn, client bool) {
- defer conn1.Close()
- defer conn2.Close()
- buf := make([]byte, 10*1024)
- init := true
- for {
- _, err := io.ReadFull(conn1, buf[:4])
- if err != nil {
- fmt.Println("1>", client, err)
- return
- }
-
- blen := int(binary.BigEndian.Uint32(buf[:4]))
-
- _, err = io.ReadFull(conn1, buf[4:4+blen])
- if err != nil {
- fmt.Println("2>", client, err)
- return
- }
-
- var cr interface{}
- opcode := int32(-1)
- readHeader := true
- if client {
- if init {
- cr = &connectRequest{}
- readHeader = false
- } else {
- xid := int32(binary.BigEndian.Uint32(buf[4:8]))
- opcode = int32(binary.BigEndian.Uint32(buf[8:12]))
- requestsLock.Lock()
- requests[xid] = opcode
- requestsLock.Unlock()
- cr = requestStructForOp(opcode)
- if cr == nil {
- fmt.Printf("Unknown opcode %d\n", opcode)
- }
- }
- } else {
- if init {
- cr = &connectResponse{}
- readHeader = false
- } else {
- xid := int32(binary.BigEndian.Uint32(buf[4:8]))
- zxid := int64(binary.BigEndian.Uint64(buf[8:16]))
- errnum := int32(binary.BigEndian.Uint32(buf[16:20]))
- if xid != -1 || zxid != -1 {
- requestsLock.Lock()
- found := false
- opcode, found = requests[xid]
- if !found {
- println("WEFWEFEW")
- opcode = 0
- }
- delete(requests, xid)
- requestsLock.Unlock()
- } else {
- opcode = opWatcherEvent
- }
- cr = responseStructForOp(opcode)
- if cr == nil {
- fmt.Printf("Unknown opcode %d\n", opcode)
- }
- if errnum != 0 {
- cr = &struct{}{}
- }
- }
- }
- opname := "."
- if opcode != -1 {
- opname = opNames[opcode]
- }
- if cr == nil {
- fmt.Printf("%+v %s %+v\n", client, opname, buf[4:4+blen])
- } else {
- n := 4
- hdrStr := ""
- if readHeader {
- var hdr interface{}
- if client {
- hdr = &requestHeader{}
- } else {
- hdr = &responseHeader{}
- }
- if n2, err := decodePacket(buf[n:n+blen], hdr); err != nil {
- fmt.Println(err)
- } else {
- n += n2
- }
- hdrStr = fmt.Sprintf(" %+v", hdr)
- }
- if _, err := decodePacket(buf[n:n+blen], cr); err != nil {
- fmt.Println(err)
- }
- fmt.Printf("%+v %s%s %+v\n", client, opname, hdrStr, cr)
- }
-
- init = false
-
- written, err := conn2.Write(buf[:4+blen])
- if err != nil {
- fmt.Println("3>", client, err)
- return
- } else if written != 4+blen {
- fmt.Printf("Written != read: %d != %d\n", written, blen)
- return
- }
- }
- }
-
- func handleConnection(addr string, conn net.Conn) {
- zkConn, err := net.Dial("tcp", addr)
- if err != nil {
- fmt.Println(err)
- return
- }
- go trace(conn, zkConn, true)
- trace(zkConn, conn, false)
- }
-
- func StartTracer(listenAddr, serverAddr string) {
- ln, err := net.Listen("tcp", listenAddr)
- if err != nil {
- panic(err)
- }
- for {
- conn, err := ln.Accept()
- if err != nil {
- fmt.Println(err)
- continue
- }
- go handleConnection(serverAddr, conn)
- }
- }
|