diff --git a/src/main.go b/src/main.go index d7775d6..7c53867 100644 --- a/src/main.go +++ b/src/main.go @@ -18,6 +18,11 @@ import ( "github.com/sunbk201/ua3f/log" ) +var ( + ErrInvalidSocksVersion = errors.New("invalid socks version") + ErrInvalidSocksCmd = errors.New("invalid socks cmd") +) + var version = "0.7.3" var payload string var uaPattern string @@ -90,17 +95,20 @@ func process(client net.Conn) { client.Close() return } - target, destAddrPort, err := Socks5Connect(client) + destAddrPort, err := ParseSocks5Request(client) if err != nil { - // UDP if strings.Contains(err.Error(), "UDP Associate") { Socks5UDP(client) client.Close() return - } else if strings.Contains(err.Error(), "connection timed out") { - logrus.Debug("Connect timeout: ", err) + } else { + logrus.Debug(fmt.Sprintf("[%s][%s] ParseSocks5Request failed: %s", client.RemoteAddr().String(), destAddrPort, err.Error())) + client.Close() return } + } + target, err := Socks5Connect(client, destAddrPort) + if err != nil { logrus.Debug("Connect failed: ", err) client.Close() return @@ -112,7 +120,7 @@ func Socks5Auth(client net.Conn) (err error) { buf := make([]byte, 256) n, err := io.ReadFull(client, buf[:2]) if n != 2 { - if err == io.EOF { + if errors.Is(err, io.EOF) { logrus.Warn(fmt.Sprintf("[%s][Auth] read EOF", client.RemoteAddr().String())) } else { logrus.Error(fmt.Sprintf("[%s][Auth] read header: %s", client.RemoteAddr().String(), err.Error())) @@ -122,7 +130,7 @@ func Socks5Auth(client net.Conn) (err error) { ver, nMethods := int(buf[0]), int(buf[1]) if ver != 5 { logrus.Error(fmt.Sprintf("[%s][Auth] invalid ver", client.RemoteAddr().String())) - return errors.New("invalid version") + return ErrInvalidSocksVersion } n, err = io.ReadFull(client, buf[:nMethods]) if n != nMethods { @@ -264,64 +272,72 @@ func Socks5UDP(client net.Conn) { udpserver.Close() } -func Socks5Connect(client net.Conn) (net.Conn, string, error) { - buf := make([]byte, 256) - n, err := io.ReadFull(client, buf[:4]) - if n != 4 { - return nil, "", errors.New("read header:" + err.Error()) - } - ver, cmd, _, atyp := buf[0], buf[1], buf[2], buf[3] - if ver != 5 { - return nil, "", errors.New("invalid ver") - } - if cmd == 3 { - return nil, "", errors.New("UDP Associate") - } - if cmd != 1 { - return nil, "", errors.New("invalid cmd, only support connect") - } - addr := "" - switch atyp { - case 1: - n, err = io.ReadFull(client, buf[:4]) - if n != 4 { - return nil, "", errors.New("invalid IPv4:" + err.Error()) - } - addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]) - case 3: - n, err = io.ReadFull(client, buf[:1]) - if n != 1 { - return nil, "", errors.New("invalid hostname:" + err.Error()) - } - addrLen := int(buf[0]) - n, err = io.ReadFull(client, buf[:addrLen]) - if n != addrLen { - return nil, "", errors.New("invalid hostname:" + err.Error()) - } - addr = string(buf[:addrLen]) - case 4: - return nil, "", errors.New("IPv6: no supported yet") - default: - return nil, "", errors.New("invalid atyp") - } - n, err = io.ReadFull(client, buf[:2]) - if n != 2 { - return nil, "", errors.New("read port:" + err.Error()) - } - port := binary.BigEndian.Uint16(buf[:2]) - destAddrPort := fmt.Sprintf("%s:%d", addr, port) +func Socks5Connect(client net.Conn, destAddrPort string) (target net.Conn, err error) { logrus.Debug(fmt.Sprintf("Connecting %s", destAddrPort)) dest, err := net.Dial("tcp", destAddrPort) if err != nil { - return nil, destAddrPort, errors.New("dial dst:" + err.Error()) + return nil, err } logrus.Debug(fmt.Sprintf("Connected %s", destAddrPort)) _, err = client.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) if err != nil { dest.Close() - return nil, destAddrPort, errors.New("write rsp:" + err.Error()) + return nil, err } - return dest, destAddrPort, nil + return dest, nil +} + +func ParseSocks5Request(client net.Conn) (destAddrPort string, err error) { + buf := make([]byte, 256) + if _, err = io.ReadFull(client, buf[:4]); err != nil { + err = fmt.Errorf("read header: %w", err) + return + } + ver, cmd, _, atyp := buf[0], buf[1], buf[2], buf[3] + if ver != 5 { + err = ErrInvalidSocksVersion + return + } + if cmd == 3 { + return "", errors.New("UDP Associate") + } else if cmd != 1 { + err = ErrInvalidSocksCmd + return + } + var addr string + switch atyp { + case 1: + if _, err = io.ReadFull(client, buf[:4]); err != nil { + err = fmt.Errorf("invalid IPv4: %w", err) + return + } + addr = fmt.Sprintf("%d.%d.%d.%d", buf[0], buf[1], buf[2], buf[3]) + case 3: + if _, err = io.ReadFull(client, buf[:1]); err != nil { + err = fmt.Errorf("invalid hostname: %w", err) + return + } + addrLen := int(buf[0]) + if _, err = io.ReadFull(client, buf[:addrLen]); err != nil { + err = fmt.Errorf("invalid hostname: %w", err) + return + } + addr = string(buf[:addrLen]) + case 4: + err = errors.New("IPv6: no supported yet") + return + default: + err = errors.New("invalid atyp") + return + } + + if _, err = io.ReadFull(client, buf[:2]); err != nil { + err = fmt.Errorf("read port: %w", err) + return + } + port := binary.BigEndian.Uint16(buf[:2]) + destAddrPort = fmt.Sprintf("%s:%d", addr, port) + return destAddrPort, nil } func Socks5Forward(client, target net.Conn, destAddrPort string) {