diff --git a/internal/server/server.go b/internal/server/server.go index f0ccad606515..6bb2ae434ffa 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -314,6 +314,14 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip host validation for health check probes. Container + // orchestrators (Kubernetes, Docker, Cloud Run) typically hit + // /healthz via the pod IP or localhost, which would otherwise + // trip a strict AllowedHosts setting and break liveness probes. + if r.URL.Path == "/healthz" { + next.ServeHTTP(w, r) + return + } _, hasWildcard := allowedHosts["*"] hostname := r.Host if host, _, err := net.SplitHostPort(r.Host); err == nil { @@ -488,6 +496,16 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { _, _ = w.Write([]byte("🧰 Hello, World! 🧰")) }) + // healthz endpoint for container orchestration health checks + // (Kubernetes liveness/readiness probes, Docker HEALTHCHECK, etc.). + // Returns 200 OK with a small JSON body so probes can rely on both + // status code and payload. + r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) + }) + return s, nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ae0510b77675..2bc3297d8fea 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -116,6 +116,169 @@ func TestServe(t *testing.T) { } } +func TestHealthz(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + addr, port := "127.0.0.1", 5004 + cfg := server.ServerConfig{ + Version: "0.0.0", + Address: addr, + Port: port, + AllowedHosts: []string{"*"}, + } + + otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer func() { + err := otelShutdown(ctx) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + }() + + testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithLogger(ctx, testLogger) + + instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithInstrumentation(ctx, instrumentation) + + s, err := server.NewServer(ctx, cfg) + if err != nil { + t.Fatalf("unable to initialize server: %v", err) + } + + err = s.Listen(ctx) + if err != nil { + t.Fatalf("unable to start server: %v", err) + } + + errCh := make(chan error) + go func() { + defer close(errCh) + if serveErr := s.Serve(ctx); serveErr != nil { + errCh <- serveErr + } + }() + + url := fmt.Sprintf("http://%s:%d/healthz", addr, port) + resp, err := http.Get(url) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + + if ct := resp.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected Content-Type application/json, got %q", ct) + } + + raw, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading from request body: %s", err) + } + + var body map[string]string + if err := json.Unmarshal(raw, &body); err != nil { + t.Fatalf("expected JSON body, got %q: %s", string(raw), err) + } + if body["status"] != "ok" { + t.Fatalf(`expected {"status":"ok"}, got %q`, string(raw)) + } +} + +// TestHealthzBypassesHostCheck verifies that /healthz is reachable even when +// AllowedHosts does not include the request host. Container probes (Kubernetes, +// Docker, Cloud Run) commonly hit the endpoint via the pod IP or localhost, +// so the strict host validation must not block them. +func TestHealthzBypassesHostCheck(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + addr, port := "127.0.0.1", 5005 + cfg := server.ServerConfig{ + Version: "0.0.0", + Address: addr, + Port: port, + AllowedHosts: []string{"toolbox.example.com"}, + } + + otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer func() { + err := otelShutdown(ctx) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + }() + + testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithLogger(ctx, testLogger) + + instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithInstrumentation(ctx, instrumentation) + + s, err := server.NewServer(ctx, cfg) + if err != nil { + t.Fatalf("unable to initialize server: %v", err) + } + + err = s.Listen(ctx) + if err != nil { + t.Fatalf("unable to start server: %v", err) + } + + errCh := make(chan error) + go func() { + defer close(errCh) + if serveErr := s.Serve(ctx); serveErr != nil { + errCh <- serveErr + } + }() + + // Hit /healthz via the pod IP (127.0.0.1), which is not in AllowedHosts. + url := fmt.Sprintf("http://%s:%d/healthz", addr, port) + resp, err := http.Get(url) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected /healthz to bypass host check and return 200, got %d", resp.StatusCode) + } + + // Sanity check: confirm the host check is still active for other paths. + rootURL := fmt.Sprintf("http://%s:%d/", addr, port) + rootResp, err := http.Get(rootURL) + if err != nil { + t.Fatalf("error when sending root request: %s", err) + } + defer rootResp.Body.Close() + if rootResp.StatusCode != http.StatusForbidden { + t.Fatalf("expected / to be blocked by host check (403), got %d", rootResp.StatusCode) + } +} + func TestUpdateServer(t *testing.T) { ctx, err := testutils.ContextWithNewLogger() if err != nil {