diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index 7d2f41a12..de318a953 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "net" @@ -297,6 +298,8 @@ var ( // ServeHTTP is the entry point for all HTTP requests. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + start := time.Now() + // If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil. if r.TLS == nil { if tlsConnStateFunc, ok := r.Context().Value(tlsConnectionStateFuncCtxKey).(func() *tls.ConnectionState); ok { @@ -304,6 +307,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + // enable full-duplex for HTTP/1, ensuring the entire + // request body gets consumed before writing the response + if s.EnableFullDuplex && r.ProtoMajor == 1 { + if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose + if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil { + c.Write(zap.Error(err)) + } + } + } + + // set the Server header h := w.Header() h["Server"] = serverHeader @@ -316,39 +330,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - // reject very long methods; probably a mistake or an attack - if len(r.Method) > 32 { - if s.shouldLogRequest(r) { - if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil { - c.Write( - zap.String("method_trunc", r.Method[:32]), - zap.String("remote_addr", r.RemoteAddr), - ) - } - } - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - + // prepare internals of the request for the handler pipeline repl := caddy.NewReplacer() r = PrepareRequest(r, repl, w, s) - // enable full-duplex for HTTP/1, ensuring the entire - // request body gets consumed before writing the response - if s.EnableFullDuplex && r.ProtoMajor == 1 { - if err := http.NewResponseController(w).EnableFullDuplex(); err != nil { //nolint:bodyclose - if c := s.logger.Check(zapcore.WarnLevel, "failed to enable full duplex"); c != nil { - c.Write(zap.Error(err)) - } - } - } - - // clone the request for logging purposes before - // it enters any handler chain; this is necessary - // to capture the original request in case it gets - // modified during handling - // cloning the request and using .WithLazy is considerably faster - // than using .With, which will JSON encode the request immediately + // clone the request for logging purposes before it enters any handler chain; + // this is necessary to capture the original request in case it gets modified + // during handling (cloning the request and using .WithLazy is considerably + // faster than using .With, which will JSON-encode the request immediately) shouldLogCredentials := s.Logs != nil && s.Logs.ShouldLogCredentials loggableReq := zap.Object("request", LoggableHTTPRequest{ Request: r.Clone(r.Context()), @@ -376,36 +365,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // capture the original version of the request - accLog := s.accessLogger.With(loggableReq) + accLog := s.accessLogger.WithLazy(loggableReq) defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials) } - start := time.Now() - - // guarantee ACME HTTP challenges; handle them - // separately from any user-defined handlers + // guarantee ACME HTTP challenges; handle them separately from any user-defined handlers if s.tlsApp.HandleHTTPChallenge(w, r) { duration = time.Since(start) return } - // execute the primary handler chain - err := s.primaryHandlerChain.ServeHTTP(w, r) + err := s.serveHTTP(w, r) duration = time.Since(start) - // if no errors, we're done! if err == nil { return } // restore original request before invoking error handler chain (issue #3717) - // TODO: this does not restore original headers, if modified (for efficiency) - origReq := r.Context().Value(OriginalRequestCtxKey).(http.Request) - r.Method = origReq.Method - r.RemoteAddr = origReq.RemoteAddr - r.RequestURI = origReq.RequestURI - cloneURL(origReq.URL, r.URL) + // NOTE: this does not restore original headers if modified (for efficiency) + origReq, ok := r.Context().Value(OriginalRequestCtxKey).(http.Request) + if ok { + r.Method = origReq.Method + r.RemoteAddr = origReq.RemoteAddr + r.RequestURI = origReq.RequestURI + cloneURL(origReq.URL, r.URL) + } // prepare the error log errLog = errLog.With(zap.Duration("duration", duration)) @@ -424,8 +410,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if s.Errors != nil && len(s.Errors.Routes) > 0 { // execute user-defined error handling route if err2 := s.errorHandlerChain.ServeHTTP(w, r); err2 == nil { - // user's error route handled the error response - // successfully, so now just log the error + // user's error route handled the error response successfully, so now just log the error for _, logger := range errLoggers { if c := logger.Check(zapcore.DebugLevel, errMsg); c != nil { if fields == nil { @@ -473,6 +458,35 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { + // reject very long methods; probably a mistake or an attack + if len(r.Method) > 32 { + if s.shouldLogRequest(r) { + if c := s.accessLogger.Check(zapcore.DebugLevel, "rejecting request with long method"); c != nil { + c.Write( + zap.String("method_trunc", r.Method[:32]), + zap.String("remote_addr", r.RemoteAddr), + ) + } + } + return HandlerError{StatusCode: http.StatusMethodNotAllowed} + } + + // RFC 9112 section 3.2: "A server MUST respond with a 400 (Bad Request) status + // code to any HTTP/1.1 request message that lacks a Host header field and to any + // request message that contains more than one Host header field line or a Host + // header field with an invalid field value." + if r.Host == "" { + return HandlerError{ + Err: errors.New("rfc9112 forbids empty Host"), + StatusCode: http.StatusBadRequest, + } + } + + // execute the primary handler chain + return s.primaryHandlerChain.ServeHTTP(w, r) +} + // wrapPrimaryRoute wraps stack (a compiled middleware handler chain) // in s.enforcementHandler which performs crucial security checks, etc. func (s *Server) wrapPrimaryRoute(stack Handler) Handler {