From 44624e09c345fd41b32924d8ef4d51c7f3b27f4e Mon Sep 17 00:00:00 2001 From: SunBK201 Date: Sat, 8 Nov 2025 17:24:13 +0800 Subject: [PATCH] refactor: add ssh sniff support and improve protocol sniffing error handling --- src/internal/rewrite/rewriter.go | 22 ++++++++++++---------- src/internal/sniff/sniff.go | 14 ++++++++++++++ src/internal/sniff/ssh.go | 14 ++++++++++++++ src/internal/sniff/tls.go | 12 ++++++------ src/internal/sniff/websocket.go | 19 +++++++++++-------- 5 files changed, 57 insertions(+), 24 deletions(-) create mode 100644 src/internal/sniff/ssh.go diff --git a/src/internal/rewrite/rewriter.go b/src/internal/rewrite/rewriter.go index acf8197..6cbd97c 100644 --- a/src/internal/rewrite/rewriter.go +++ b/src/internal/rewrite/rewriter.go @@ -157,15 +157,17 @@ func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr } }() - if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) { - r.Cache.Add(destAddr, struct{}{}) - log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, added to cache") - statistics.AddConnection(&statistics.ConnectionRecord{ - Protocol: sniff.HTTPS, - SrcAddr: srcAddr, - DestAddr: destAddr, - }) - return + if strings.HasSuffix(destAddr, "443") { + if isTLS, _ := sniff.SniffTLS(reader); isTLS { + r.Cache.Add(destAddr, struct{}{}) + log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, added to cache") + statistics.AddConnection(&statistics.ConnectionRecord{ + Protocol: sniff.HTTPS, + SrcAddr: srcAddr, + DestAddr: destAddr, + }) + return + } } var isHTTP bool @@ -177,7 +179,7 @@ func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr if !isHTTP { r.Cache.Add(destAddr, struct{}{}) log.LogInfoWithAddr(srcAddr, destAddr, "sniff first request is not http, added to cache, switching to raw proxy") - if sniff.SniffTLSClientHello(reader) { + if isTLS, _ := sniff.SniffTLS(reader); isTLS { statistics.AddConnection(&statistics.ConnectionRecord{ Protocol: sniff.TLS, SrcAddr: srcAddr, diff --git a/src/internal/sniff/sniff.go b/src/internal/sniff/sniff.go index eff96cd..4432d9f 100644 --- a/src/internal/sniff/sniff.go +++ b/src/internal/sniff/sniff.go @@ -16,10 +16,24 @@ const ( HTTPS Protocol = "HTTPS" TLS Protocol = "TLS" WebSocket Protocol = "WebSocket" + SSH Protocol = "SSH" ) var ErrPeekTimeout = errors.New("peek timeout") +func SniffProtocol(reader *bufio.Reader) (Protocol, error) { + if isTLS, _ := SniffTLS(reader); isTLS { + return TLS, nil + } + if isHTTP, _ := SniffHTTP(reader); isHTTP { + return HTTP, nil + } + if isSSH, _ := SniffSSH(reader); isSSH { + return SSH, nil + } + return TCP, nil +} + // peekLineSlice reads a line from bufio.Reader without consuming it. // returns the line bytes (without CRLF) or error. func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) { diff --git a/src/internal/sniff/ssh.go b/src/internal/sniff/ssh.go new file mode 100644 index 0000000..6ce78b8 --- /dev/null +++ b/src/internal/sniff/ssh.go @@ -0,0 +1,14 @@ +package sniff + +import "bufio" + +func SniffSSH(reader *bufio.Reader) (bool, error) { + header, err := reader.Peek(4) + if err != nil { + return false, err + } + if string(header) == "SSH-" { + return true, nil + } + return false, nil +} diff --git a/src/internal/sniff/tls.go b/src/internal/sniff/tls.go index 83042fe..e9ca2a5 100644 --- a/src/internal/sniff/tls.go +++ b/src/internal/sniff/tls.go @@ -2,24 +2,24 @@ package sniff import "bufio" -func SniffTLSClientHello(reader *bufio.Reader) bool { +func SniffTLS(reader *bufio.Reader) (bool, error) { header, err := reader.Peek(3) if err != nil { - return false + return false, err } // TLS record type 0x16 = Handshake if header[0] != 0x16 { - return false + return false, nil } // TLS version versionMajor := header[1] versionMinor := header[2] if versionMajor != 0x03 { - return false + return false, nil } if versionMinor < 0x01 || versionMinor > 0x04 { - return false + return false, nil } - return true + return true, nil } diff --git a/src/internal/sniff/websocket.go b/src/internal/sniff/websocket.go index 572005d..26269e1 100644 --- a/src/internal/sniff/websocket.go +++ b/src/internal/sniff/websocket.go @@ -1,8 +1,11 @@ package sniff -func SniffWebSocket(header []byte) bool { - if len(header) < 2 { - return false +import "bufio" + +func SniffWebSocket(reader *bufio.Reader) (bool, error) { + header, err := reader.Peek(2) + if err != nil { + return false, err } b0 := header[0] @@ -14,21 +17,21 @@ func SniffWebSocket(header []byte) bool { // requested frames from client to server must be masked if mask == 0 { - return false + return false, nil } // Control frames must have FIN set if rsv != 0 { - return false + return false, nil } // opcode must be in valid range if opcode > 0xA { - return false + return false, nil } // payload length payloadLen := b1 & 0x7F if payloadLen > 0 && payloadLen <= 125 { - return true + return true, nil } - return true + return true, nil }