refactor: add ssh sniff support and improve protocol sniffing error handling

This commit is contained in:
SunBK201 2025-11-08 17:24:13 +08:00
parent cf7fc96835
commit 44624e09c3
5 changed files with 57 additions and 24 deletions

View File

@ -157,15 +157,17 @@ func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr
} }
}() }()
if strings.HasSuffix(destAddr, "443") && sniff.SniffTLSClientHello(reader) { if strings.HasSuffix(destAddr, "443") {
r.Cache.Add(destAddr, struct{}{}) if isTLS, _ := sniff.SniffTLS(reader); isTLS {
log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, added to cache") r.Cache.Add(destAddr, struct{}{})
statistics.AddConnection(&statistics.ConnectionRecord{ log.LogInfoWithAddr(srcAddr, destAddr, "tls client hello detected, added to cache")
Protocol: sniff.HTTPS, statistics.AddConnection(&statistics.ConnectionRecord{
SrcAddr: srcAddr, Protocol: sniff.HTTPS,
DestAddr: destAddr, SrcAddr: srcAddr,
}) DestAddr: destAddr,
return })
return
}
} }
var isHTTP bool var isHTTP bool
@ -177,7 +179,7 @@ func (r *Rewriter) Process(dst net.Conn, src net.Conn, destAddr string, srcAddr
if !isHTTP { if !isHTTP {
r.Cache.Add(destAddr, struct{}{}) r.Cache.Add(destAddr, struct{}{})
log.LogInfoWithAddr(srcAddr, destAddr, "sniff first request is not http, added to cache, switching to raw proxy") log.LogInfoWithAddr(srcAddr, destAddr, "sniff first request is not http, added to cache, switching to raw proxy")
if sniff.SniffTLSClientHello(reader) { if isTLS, _ := sniff.SniffTLS(reader); isTLS {
statistics.AddConnection(&statistics.ConnectionRecord{ statistics.AddConnection(&statistics.ConnectionRecord{
Protocol: sniff.TLS, Protocol: sniff.TLS,
SrcAddr: srcAddr, SrcAddr: srcAddr,

View File

@ -16,10 +16,24 @@ const (
HTTPS Protocol = "HTTPS" HTTPS Protocol = "HTTPS"
TLS Protocol = "TLS" TLS Protocol = "TLS"
WebSocket Protocol = "WebSocket" WebSocket Protocol = "WebSocket"
SSH Protocol = "SSH"
) )
var ErrPeekTimeout = errors.New("peek timeout") var ErrPeekTimeout = errors.New("peek timeout")
func SniffProtocol(reader *bufio.Reader) (Protocol, error) {
if isTLS, _ := SniffTLS(reader); isTLS {
return TLS, nil
}
if isHTTP, _ := SniffHTTP(reader); isHTTP {
return HTTP, nil
}
if isSSH, _ := SniffSSH(reader); isSSH {
return SSH, nil
}
return TCP, nil
}
// peekLineSlice reads a line from bufio.Reader without consuming it. // peekLineSlice reads a line from bufio.Reader without consuming it.
// returns the line bytes (without CRLF) or error. // returns the line bytes (without CRLF) or error.
func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) { func peekLineSlice(br *bufio.Reader, maxSize int) ([]byte, error) {

14
src/internal/sniff/ssh.go Normal file
View File

@ -0,0 +1,14 @@
package sniff
import "bufio"
func SniffSSH(reader *bufio.Reader) (bool, error) {
header, err := reader.Peek(4)
if err != nil {
return false, err
}
if string(header) == "SSH-" {
return true, nil
}
return false, nil
}

View File

@ -2,24 +2,24 @@ package sniff
import "bufio" import "bufio"
func SniffTLSClientHello(reader *bufio.Reader) bool { func SniffTLS(reader *bufio.Reader) (bool, error) {
header, err := reader.Peek(3) header, err := reader.Peek(3)
if err != nil { if err != nil {
return false return false, err
} }
// TLS record type 0x16 = Handshake // TLS record type 0x16 = Handshake
if header[0] != 0x16 { if header[0] != 0x16 {
return false return false, nil
} }
// TLS version // TLS version
versionMajor := header[1] versionMajor := header[1]
versionMinor := header[2] versionMinor := header[2]
if versionMajor != 0x03 { if versionMajor != 0x03 {
return false return false, nil
} }
if versionMinor < 0x01 || versionMinor > 0x04 { if versionMinor < 0x01 || versionMinor > 0x04 {
return false return false, nil
} }
return true return true, nil
} }

View File

@ -1,8 +1,11 @@
package sniff package sniff
func SniffWebSocket(header []byte) bool { import "bufio"
if len(header) < 2 {
return false func SniffWebSocket(reader *bufio.Reader) (bool, error) {
header, err := reader.Peek(2)
if err != nil {
return false, err
} }
b0 := header[0] b0 := header[0]
@ -14,21 +17,21 @@ func SniffWebSocket(header []byte) bool {
// requested frames from client to server must be masked // requested frames from client to server must be masked
if mask == 0 { if mask == 0 {
return false return false, nil
} }
// Control frames must have FIN set // Control frames must have FIN set
if rsv != 0 { if rsv != 0 {
return false return false, nil
} }
// opcode must be in valid range // opcode must be in valid range
if opcode > 0xA { if opcode > 0xA {
return false return false, nil
} }
// payload length // payload length
payloadLen := b1 & 0x7F payloadLen := b1 & 0x7F
if payloadLen > 0 && payloadLen <= 125 { if payloadLen > 0 && payloadLen <= 125 {
return true return true, nil
} }
return true return true, nil
} }