diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 8edc585e7..dd01b6ef5 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -40,6 +40,7 @@ import ( "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig" "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/caddyserver/caddy/v2/modules/caddyhttp/headers" "github.com/caddyserver/caddy/v2/modules/caddytls" "github.com/caddyserver/caddy/v2/modules/internal/network" ) @@ -514,6 +515,28 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e return rt, nil } +// RequestHeaderOps implements TransportHeaderOpsProvider. It returns header +// operations for requests when the transport's configuration indicates they +// should be applied. In particular, when TLS is enabled for this transport, +// return an operation to set the Host header to the upstream host:port +// placeholder so HTTPS upstreams get the proper Host by default. +// +// Note: this is a provision-time hook; the Handler will call this during +// its Provision and cache the resulting HeaderOps. The HeaderOps are +// applied per-request (so placeholders are expanded at request time). +func (h *HTTPTransport) RequestHeaderOps() *headers.HeaderOps { + // If TLS is not configured for this transport, don't inject Host + // defaults. TLS being non-nil indicates HTTPS to the upstream. + if h.TLS == nil { + return nil + } + return &headers.HeaderOps{ + Set: http.Header{ + "Host": []string{"{http.reverse_proxy.upstream.hostport}"}, + }, + } +} + // RoundTrip implements http.RoundTripper. func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { h.SetScheme(req) diff --git a/modules/caddyhttp/reverseproxy/httptransport_test.go b/modules/caddyhttp/reverseproxy/httptransport_test.go index 46931c8b1..1fa4965f2 100644 --- a/modules/caddyhttp/reverseproxy/httptransport_test.go +++ b/modules/caddyhttp/reverseproxy/httptransport_test.go @@ -94,3 +94,24 @@ func TestHTTPTransportUnmarshalCaddyFileWithCaPools(t *testing.T) { }) } } + +func TestHTTPTransport_RequestHeaderOps_TLS(t *testing.T) { + var ht HTTPTransport + // When TLS is nil, expect no header ops + if ops := ht.RequestHeaderOps(); ops != nil { + t.Fatalf("expected nil HeaderOps when TLS is nil, got: %#v", ops) + } + + // When TLS is configured, expect a HeaderOps that sets Host + ht.TLS = &TLSConfig{} + ops := ht.RequestHeaderOps() + if ops == nil { + t.Fatal("expected non-nil HeaderOps when TLS is set") + } + if ops.Set == nil { + t.Fatalf("expected ops.Set to be non-nil, got nil") + } + if got := ops.Set.Get("Host"); got != "{http.reverse_proxy.upstream.hostport}" { + t.Fatalf("unexpected Host value; want placeholder, got: %s", got) + } +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 2ea17046a..6f6a0f9f2 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -192,6 +192,13 @@ type Handler struct { CB CircuitBreaker `json:"-"` DynamicUpstreams UpstreamSource `json:"-"` + // transportHeaderOps is a set of header operations provided + // by the transport at provision time, if the transport + // implements TransportHeaderOpsProvider. These ops are + // applied before any user-configured header ops so the + // user can override transport defaults. + transportHeaderOps *headers.HeaderOps + // Holds the parsed CIDR ranges from TrustedProxies trustedProxies []netip.Prefix @@ -322,6 +329,18 @@ func (h *Handler) Provision(ctx caddy.Context) error { h.Transport = t } + // If the transport can provide header ops, cache them now so we don't + // have to compute them per-request. Provision the HeaderOps if present + // so any runtime artifacts (like precompiled regex) are prepared. + if tph, ok := h.Transport.(RequestHeaderOpsTransport); ok { + h.transportHeaderOps = tph.RequestHeaderOps() + if h.transportHeaderOps != nil { + if err := h.transportHeaderOps.Provision(ctx); err != nil { + return fmt.Errorf("provisioning transport header ops: %v", err) + } + } + } + // set up load balancing if h.LoadBalancing == nil { h.LoadBalancing = new(LoadBalancing) @@ -575,14 +594,26 @@ func (h *Handler) proxyLoopIteration(r *http.Request, origReq *http.Request, w h repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails()) // mutate request headers according to this upstream; - // because we're in a retry loop, we have to copy - // headers (and the r.Host value) from the original - // so that each retry is identical to the first - if h.Headers != nil && h.Headers.Request != nil { + // because we're in a retry loop, we have to copy headers + // (and the r.Host value) from the original so that each + // retry is identical to the first. If either transport or + // user ops exist, apply them in order (transport first, + // then user, so user's config wins). + var userOps *headers.HeaderOps + if h.Headers != nil { + userOps = h.Headers.Request + } + transportOps := h.transportHeaderOps + if transportOps != nil || userOps != nil { r.Header = make(http.Header) copyHeader(r.Header, reqHeader) r.Host = reqHost - h.Headers.Request.ApplyToRequest(r) + if transportOps != nil { + transportOps.ApplyToRequest(r) + } + if userOps != nil { + userOps.ApplyToRequest(r) + } } // proxy the request to that upstream @@ -1542,6 +1573,17 @@ type BufferedTransport interface { DefaultBufferSizes() (int64, int64) } +// RequestHeaderOpsTransport may be implemented by a transport to provide +// header operations to apply to requests immediately before the RoundTrip. +// For example, overriding the default Host when TLS is enabled. +type RequestHeaderOpsTransport interface { + // RequestHeaderOps allows a transport to provide header operations + // to apply to the request. The transport is asked at provision time + // to return a HeaderOps (or nil) that will be applied before + // user-configured header ops. + RequestHeaderOps() *headers.HeaderOps +} + // roundtripSucceededError is an error type that is returned if the // roundtrip succeeded, but an error occurred after-the-fact. type roundtripSucceededError struct{ error }