diff --git a/main.go b/main.go index dbbf44a1826..68bb424e82f 100644 --- a/main.go +++ b/main.go @@ -2,13 +2,17 @@ package main import ( "bytes" + "context" "embed" + "errors" "fmt" "log" "net/http" "os" + "os/signal" "strconv" "strings" + "syscall" "time" "github.com/QuantumNous/new-api/common" @@ -192,10 +196,30 @@ func main() { // Log startup success message common.LogStartupSuccess(startTime, port) - err = server.Run(":" + port) - if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + srv := &http.Server{ + Addr: ":" + port, + Handler: server, + } + + go func() { + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + common.FatalLog("failed to start HTTP server: " + err.Error()) + } + }() + + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + sig := <-quit + common.SysLog(fmt.Sprintf("received signal: %v, shutting down...", sig)) + + // SSE streams may run for minutes; give them time to finish before forced exit + shutdownTimeout := time.Duration(common.GetEnvOrDefault("SHUTDOWN_TIMEOUT_SECONDS", 120)) * time.Second + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + common.SysError(fmt.Sprintf("server forced to shutdown: %v", err)) } + common.SysLog("server exited") } func InjectUmamiAnalytics() {