diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index 0a676e431..82f0ed814 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -283,3 +283,6 @@ const proxyProtocolInfoVarKey = "reverse_proxy.proxy_protocol_info" type ProxyProtocolInfo struct { AddrPort netip.AddrPort } + +// proxyVarKey is the key used that indicates the proxy server used for a request. +const proxyVarKey = "reverse_proxy.proxy" diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 910033ca1..81bc293c3 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -236,15 +236,15 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { - // For unix socket upstreams, we need to recover the dial info from - // the request's context, because the Host on the request's URL - // will have been modified by directing the request, overwriting - // the unix socket filename. - // Also, we need to avoid overwriting the address at this point - // when not necessary, because http.ProxyFromEnvironment may have - // modified the address according to the user's env proxy config. + // The network is usually tcp, and the address is the host in http.Request.URL.Host + // and that's been overwritten in directRequest + // However, if proxy is used according to http.ProxyFromEnvironment or proxy providers, + // address will be the address of the proxy server. + + // This means we can safely use the address in dialInfo if proxy is not used (the address and network will be same any way) + // or if the upstream is unix (because there is no way socks or http proxy can be used for unix address). if dialInfo, ok := GetDialInfo(ctx); ok { - if strings.HasPrefix(dialInfo.Network, "unix") { + if caddyhttp.GetVar(ctx, proxyVarKey) == nil || strings.HasPrefix(dialInfo.Network, "unix") { network = dialInfo.Network address = dialInfo.Address } @@ -339,9 +339,19 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e } else { proxy = http.ProxyFromEnvironment } + // we need to keep track if a proxy is used for a request + proxyWrapper := func(req *http.Request) (*url.URL, error) { + u, err := proxy(req) + if u == nil || err != nil { + return u, err + } + // there must be a proxy for this request + caddyhttp.SetVar(req.Context(), proxyVarKey, u) + return u, nil + } rt := &http.Transport{ - Proxy: proxy, + Proxy: proxyWrapper, DialContext: dialContext, MaxConnsPerHost: h.MaxConnsPerHost, ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), @@ -370,14 +380,6 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout) } - // The proxy protocol header can only be sent once right after opening the connection. - // So single connection must not be used for multiple requests, which can potentially - // come from different clients. - if !rt.DisableKeepAlives && h.ProxyProtocol != "" { - caddyCtx.Logger().Warn("disabling keepalives, they are incompatible with using PROXY protocol") - rt.DisableKeepAlives = true - } - if h.Compression != nil { rt.DisableCompression = !*h.Compression } diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 230bec951..136216609 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -1175,7 +1175,7 @@ func (lb LoadBalancing) tryAgain(ctx caddy.Context, start time.Time, retries int // directRequest modifies only req.URL so that it points to the upstream // in the given DialInfo. It must modify ONLY the request URL. -func (Handler) directRequest(req *http.Request, di DialInfo) { +func (h *Handler) directRequest(req *http.Request, di DialInfo) { // we need a host, so set the upstream's host address reqHost := di.Address @@ -1186,6 +1186,13 @@ func (Handler) directRequest(req *http.Request, di DialInfo) { reqHost = di.Host } + // add client address to the host to let transport differentiate requests from different clients + if ht, ok := h.Transport.(*HTTPTransport); ok && ht.ProxyProtocol != "" { + if proxyProtocolInfo, ok := caddyhttp.GetVar(req.Context(), proxyProtocolInfoVarKey).(ProxyProtocolInfo); ok { + reqHost = proxyProtocolInfo.AddrPort.String() + "->" + reqHost + } + } + req.URL.Host = reqHost }