refactor: improve http request handling

This commit is contained in:
SunBK201 2025-11-28 23:10:47 +08:00
parent f7bca9cf0a
commit a005803862

View File

@ -1,11 +1,9 @@
package http package http
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"time" "time"
@ -57,9 +55,6 @@ func (s *Server) Close() (err error) {
} }
func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
req.RequestURI = ""
req.URL.Scheme = "http"
req.URL.Host = req.Host
destPort := req.URL.Port() destPort := req.URL.Port()
if destPort == "" { if destPort == "" {
destPort = "80" destPort = "80"
@ -73,33 +68,21 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
}) })
defer statistics.RemoveConnection(req.RemoteAddr, destAddr) defer statistics.RemoveConnection(req.RemoteAddr, destAddr)
target, err := base.Connect(destAddr) slog.Info("HTTP proxy request", slog.String("srcAddr", req.RemoteAddr), slog.String("destAddr", destAddr))
req, err := s.rewrite(req, req.RemoteAddr, destAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable) http.Error(w, err.Error(), http.StatusServiceUnavailable)
return return
} }
defer func() { defer func() {
if err := target.Close(); err != nil { _ = resp.Body.Close()
slog.Warn("target.Close", slog.String("destAddr", destAddr), slog.Any("error", err))
}
}()
slog.Info("New HTTP proxy request", slog.String("srcAddr", req.RemoteAddr), slog.String("destAddr", destAddr))
err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
resp, err := http.ReadResponse(bufio.NewReader(target), req)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
slog.Warn("resp.Body.Close", slog.String("destAddr", destAddr), slog.Any("error", cerr))
}
}() }()
for k, v := range resp.Header { for k, v := range resp.Header {
@ -111,11 +94,11 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) {
_, _ = io.Copy(w, resp.Body) _, _ = io.Copy(w, resp.Body)
} }
func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) { func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Request, error) {
decision := s.Rewriter.EvaluateRewriteDecision(req, srcAddr, dstAddr) decision := s.Rewriter.EvaluateRewriteDecision(req, srcAddr, dstAddr)
if decision.Action == rule.ActionDrop { if decision.Action == rule.ActionDrop {
log.LogInfoWithAddr(srcAddr, dstAddr, "Request dropped by rule") log.LogInfoWithAddr(srcAddr, dstAddr, "Request dropped by rule")
return fmt.Errorf("request dropped by rule") return nil, fmt.Errorf("request dropped by rule")
} }
if decision.NeedCache { if decision.NeedCache {
s.Cache.Add(dstAddr, struct{}{}) s.Cache.Add(dstAddr, struct{}{})
@ -123,10 +106,7 @@ func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr,
if decision.ShouldRewrite() { if decision.ShouldRewrite() {
req = s.Rewriter.Rewrite(req, srcAddr, dstAddr, decision) req = s.Rewriter.Rewrite(req, srcAddr, dstAddr, decision)
} }
if err := req.Write(target); err != nil { return req, nil
return fmt.Errorf("req.Write: %w", err)
}
return nil
} }
func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
@ -142,24 +122,21 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError) http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
return return
} }
client, _, err := hijacker.Hijack() src, _, err := hijacker.Hijack()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable) http.Error(w, err.Error(), http.StatusServiceUnavailable)
return return
} }
if _, err := client.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { if _, err := src.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
slog.Warn("failed to write CONNECT response to client", slog.String("client", req.RemoteAddr), slog.Any("error", err)) slog.Warn("failed to write CONNECT response to client", slog.String("client", req.RemoteAddr), slog.Any("error", err))
_ = client.Close() _ = src.Close()
_ = dest.Close() _ = dest.Close()
return return
} }
s.ServeConnLink(&base.ConnLink{ s.ServeConnLink(&base.ConnLink{
LConn: client, LConn: src,
RConn: dest, RConn: dest,
LAddr: client.RemoteAddr().String(), LAddr: req.RemoteAddr,
RAddr: destAddr, RAddr: destAddr,
}) })
} }
func (s *Server) HandleClient(client net.Conn) {
}