diff --git a/src/internal/server/http/http.go b/src/internal/server/http/http.go index 76aa8c0..3ce2390 100644 --- a/src/internal/server/http/http.go +++ b/src/internal/server/http/http.go @@ -1,11 +1,9 @@ package http import ( - "bufio" "fmt" "io" "log/slog" - "net" "net/http" "time" @@ -57,9 +55,6 @@ func (s *Server) Close() (err error) { } func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { - req.RequestURI = "" - req.URL.Scheme = "http" - req.URL.Host = req.Host destPort := req.URL.Port() if destPort == "" { destPort = "80" @@ -73,33 +68,21 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { }) defer statistics.RemoveConnection(req.RemoteAddr, destAddr) - target, err := base.Connect(destAddr) + slog.Info("HTTP proxy request", slog.String("srcAddr", req.RemoteAddr), slog.String("destAddr", destAddr)) + + req, err := s.rewrite(req, req.RemoteAddr, destAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + + resp, err := http.DefaultTransport.RoundTrip(req) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } defer func() { - if err := target.Close(); err != nil { - slog.Warn("target.Close", slog.String("destAddr", destAddr), slog.Any("error", err)) - } - }() - - slog.Info("New HTTP proxy request", slog.String("srcAddr", req.RemoteAddr), slog.String("destAddr", destAddr)) - - err = s.rewriteAndForward(target, req, req.Host, req.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - resp, err := http.ReadResponse(bufio.NewReader(target), req) - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - defer func() { - if cerr := resp.Body.Close(); cerr != nil { - slog.Warn("resp.Body.Close", slog.String("destAddr", destAddr), slog.Any("error", cerr)) - } + _ = resp.Body.Close() }() for k, v := range resp.Header { @@ -111,11 +94,11 @@ func (s *Server) handleHTTP(w http.ResponseWriter, req *http.Request) { _, _ = io.Copy(w, resp.Body) } -func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, srcAddr string) (err error) { +func (s *Server) rewrite(req *http.Request, srcAddr, dstAddr string) (*http.Request, error) { decision := s.Rewriter.EvaluateRewriteDecision(req, srcAddr, dstAddr) if decision.Action == rule.ActionDrop { log.LogInfoWithAddr(srcAddr, dstAddr, "Request dropped by rule") - return fmt.Errorf("request dropped by rule") + return nil, fmt.Errorf("request dropped by rule") } if decision.NeedCache { s.Cache.Add(dstAddr, struct{}{}) @@ -123,10 +106,7 @@ func (s *Server) rewriteAndForward(target net.Conn, req *http.Request, dstAddr, if decision.ShouldRewrite() { req = s.Rewriter.Rewrite(req, srcAddr, dstAddr, decision) } - if err := req.Write(target); err != nil { - return fmt.Errorf("req.Write: %w", err) - } - return nil + return req, nil } func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { @@ -142,24 +122,21 @@ func (s *Server) handleTunneling(w http.ResponseWriter, req *http.Request) { http.Error(w, "Hijacking not supported", http.StatusInternalServerError) return } - client, _, err := hijacker.Hijack() + src, _, err := hijacker.Hijack() if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } - if _, err := client.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { + if _, err := src.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { slog.Warn("failed to write CONNECT response to client", slog.String("client", req.RemoteAddr), slog.Any("error", err)) - _ = client.Close() + _ = src.Close() _ = dest.Close() return } s.ServeConnLink(&base.ConnLink{ - LConn: client, + LConn: src, RConn: dest, - LAddr: client.RemoteAddr().String(), + LAddr: req.RemoteAddr, RAddr: destAddr, }) } - -func (s *Server) HandleClient(client net.Conn) { -}