caddyhttp: Reject invalid Host header (fix #7449)

This commit is contained in:
Matthew Holt
2026-01-30 12:24:16 -07:00
parent 565c1c3054
commit 7d24124430

View File

@@ -18,6 +18,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -297,6 +298,8 @@ var (
// ServeHTTP is the entry point for all HTTP requests. // ServeHTTP is the entry point for all HTTP requests.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil. // If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil.
if r.TLS == nil { if r.TLS == nil {
if tlsConnStateFunc, ok := r.Context().Value(tlsConnectionStateFuncCtxKey).(func() *tls.ConnectionState); ok { if tlsConnStateFunc, ok := r.Context().Value(tlsConnectionStateFuncCtxKey).(func() *tls.ConnectionState); ok {
@@ -304,6 +307,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
// enable full-duplex for HTTP/1, ensuring the entire
// request body gets consumed before writing the response
if s.EnableFullDuplex && r.ProtoMajor == 1 {
if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose
if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil {
c.Write(zap.Error(err))
}
}
}
// set the Server header
h := w.Header() h := w.Header()
h["Server"] = serverHeader h["Server"] = serverHeader
@@ -316,39 +330,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
// reject very long methods; probably a mistake or an attack // prepare internals of the request for the handler pipeline
if len(r.Method) > 32 {
if s.shouldLogRequest(r) {
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
c.Write(
zap.String("method_trunc", r.Method[:32]),
zap.String("remote_addr", r.RemoteAddr),
)
}
}
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
repl := caddy.NewReplacer() repl := caddy.NewReplacer()
r = PrepareRequest(r, repl, w, s) r = PrepareRequest(r, repl, w, s)
// enable full-duplex for HTTP/1, ensuring the entire // clone the request for logging purposes before it enters any handler chain;
// request body gets consumed before writing the response // this is necessary to capture the original request in case it gets modified
if s.EnableFullDuplex && r.ProtoMajor == 1 { // during handling (cloning the request and using .WithLazy is considerably
if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose // faster than using .With, which will JSON-encode the request immediately)
if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil {
c.Write(zap.Error(err))
}
}
}
// clone the request for logging purposes before
// it enters any handler chain; this is necessary
// to capture the original request in case it gets
// modified during handling
// cloning the request and using .WithLazy is considerably faster
// than using .With, which will JSON encode the request immediately
shouldLogCredentials := s.Logs != nil && s.Logs.ShouldLogCredentials shouldLogCredentials := s.Logs != nil && s.Logs.ShouldLogCredentials
loggableReq := zap.Object("request", LoggableHTTPRequest{ loggableReq := zap.Object("request", LoggableHTTPRequest{
Request: r.Clone(r.Context()), Request: r.Clone(r.Context()),
@@ -376,36 +365,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// capture the original version of the request // capture the original version of the request
accLog := s.accessLogger.With(loggableReq) accLog := s.accessLogger.WithLazy(loggableReq)
defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials) defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials)
} }
start := time.Now() // guarantee ACME HTTP challenges; handle them separately from any user-defined handlers
// guarantee ACME HTTP challenges; handle them
// separately from any user-defined handlers
if s.tlsApp.HandleHTTPChallenge(w, r) { if s.tlsApp.HandleHTTPChallenge(w, r) {
duration = time.Since(start) duration = time.Since(start)
return return
} }
// execute the primary handler chain err := s.serveHTTP(w, r)
err := s.primaryHandlerChain.ServeHTTP(w, r)
duration = time.Since(start) duration = time.Since(start)
// if no errors, we're done!
if err == nil { if err == nil {
return return
} }
// restore original request before invoking error handler chain (issue #3717) // restore original request before invoking error handler chain (issue #3717)
// TODO: this does not restore original headers, if modified (for efficiency) // NOTE: this does not restore original headers if modified (for efficiency)
origReq := r.Context().Value(OriginalRequestCtxKey).(http.Request) origReq, ok := r.Context().Value(OriginalRequestCtxKey).(http.Request)
r.Method = origReq.Method if ok {
r.RemoteAddr = origReq.RemoteAddr r.Method = origReq.Method
r.RequestURI = origReq.RequestURI r.RemoteAddr = origReq.RemoteAddr
cloneURL(origReq.URL, r.URL) r.RequestURI = origReq.RequestURI
cloneURL(origReq.URL, r.URL)
}
// prepare the error log // prepare the error log
errLog = errLog.With(zap.Duration("duration", duration)) errLog = errLog.With(zap.Duration("duration", duration))
@@ -424,8 +410,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.Errors != nil && len(s.Errors.Routes) > 0 { if s.Errors != nil && len(s.Errors.Routes) > 0 {
// execute user-defined error handling route // execute user-defined error handling route
if err2 := s.errorHandlerChain.ServeHTTP(w, r); err2 == nil { if err2 := s.errorHandlerChain.ServeHTTP(w, r); err2 == nil {
// user's error route handled the error response // user's error route handled the error response successfully, so now just log the error
// successfully, so now just log the error
for _, logger := range errLoggers { for _, logger := range errLoggers {
if c := logger.Check(zapcore.DebugLevel, errMsg); c != nil { if c := logger.Check(zapcore.DebugLevel, errMsg); c != nil {
if fields == nil { if fields == nil {
@@ -473,6 +458,35 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
// reject very long methods; probably a mistake or an attack
if len(r.Method) > 32 {
if s.shouldLogRequest(r) {
if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil {
c.Write(
zap.String("method_trunc", r.Method[:32]),
zap.String("remote_addr", r.RemoteAddr),
)
}
}
return HandlerError{StatusCode: http.StatusMethodNotAllowed}
}
// RFC 9112 section 3.2: "A server MUST respond with a 400 (Bad Request) status
// code to any HTTP/1.1 request message that lacks a Host header field and to any
// request message that contains more than one Host header field line or a Host
// header field with an invalid field value."
if r.Host == "" {
return HandlerError{
Err: errors.New("rfc9112 forbids empty Host"),
StatusCode: http.StatusBadRequest,
}
}
// execute the primary handler chain
return s.primaryHandlerChain.ServeHTTP(w, r)
}
// wrapPrimaryRoute wraps stack (a compiled middleware handler chain) // wrapPrimaryRoute wraps stack (a compiled middleware handler chain)
// in s.enforcementHandler which performs crucial security checks, etc. // in s.enforcementHandler which performs crucial security checks, etc.
func (s *Server) wrapPrimaryRoute(stack Handler) Handler { func (s *Server) wrapPrimaryRoute(stack Handler) Handler {