diff --git a/controller/subscription.go b/controller/subscription.go index c6095312b77..6c386b0ab94 100644 --- a/controller/subscription.go +++ b/controller/subscription.go @@ -55,10 +55,17 @@ func GetSubscriptionSelf(c *gin.Context) { activeSubscriptions = []model.SubscriptionSummary{} } + var primarySubscription any + if primary := model.SelectPrimarySubscriptionSummary(activeSubscriptions); primary != nil { + primarySubscription = primary + } + common.ApiSuccess(c, gin.H{ - "billing_preference": pref, - "subscriptions": activeSubscriptions, // all active subscriptions - "all_subscriptions": allSubscriptions, // all subscriptions including expired + "billing_preference": pref, + "subscriptions": activeSubscriptions, // all active subscriptions + "all_subscriptions": allSubscriptions, // all subscriptions including expired + "primary_subscription": primarySubscription, + "active_subscription_count": len(activeSubscriptions), }) } diff --git a/controller/subscription_self_test.go b/controller/subscription_self_test.go new file mode 100644 index 00000000000..7824dcac34d --- /dev/null +++ b/controller/subscription_self_test.go @@ -0,0 +1,227 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +type subscriptionSelfAPIResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} + +type subscriptionSelfPlan struct { + ID int `json:"id"` + Title string `json:"title"` + QuotaResetPeriod string `json:"quota_reset_period"` + QuotaResetCustomSeconds int64 `json:"quota_reset_custom_seconds"` +} + +type subscriptionSelfSummary struct { + Subscription *model.UserSubscription `json:"subscription"` + Plan *subscriptionSelfPlan `json:"plan"` +} + +type subscriptionSelfResponseData struct { + BillingPreference string `json:"billing_preference"` + Subscriptions []subscriptionSelfSummary `json:"subscriptions"` + AllSubscriptions []subscriptionSelfSummary `json:"all_subscriptions"` + PrimarySubscription *subscriptionSelfSummary `json:"primary_subscription"` + ActiveSubscriptionCount int `json:"active_subscription_count"` +} + +func setupSubscriptionSelfControllerTestDB(t *testing.T) *gorm.DB { + t.Helper() + + gin.SetMode(gin.TestMode) + common.UsingSQLite = true + common.UsingMySQL = false + common.UsingPostgreSQL = false + common.RedisEnabled = false + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + require.NoError(t, err) + + model.DB = db + model.LOG_DB = db + + require.NoError(t, db.AutoMigrate(&model.User{}, &model.SubscriptionPlan{}, &model.UserSubscription{})) + + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) + + return db +} + +func seedSubscriptionSelfUser(t *testing.T, db *gorm.DB, id int) *model.User { + t.Helper() + + user := &model.User{ + Id: id, + Username: fmt.Sprintf("user-%d", id), + Password: "password123", + Status: 1, + Role: common.RoleCommonUser, + Group: "default", + } + require.NoError(t, db.Create(user).Error) + return user +} + +func seedSubscriptionSelfPlan(t *testing.T, db *gorm.DB, plan model.SubscriptionPlan) *model.SubscriptionPlan { + t.Helper() + + require.NoError(t, db.Create(&plan).Error) + return &plan +} + +func seedSubscriptionSelfSubscription(t *testing.T, db *gorm.DB, sub model.UserSubscription) *model.UserSubscription { + t.Helper() + + require.NoError(t, db.Create(&sub).Error) + return &sub +} + +func newSubscriptionSelfContext(t *testing.T, userID int) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/api/subscription/self", bytes.NewReader(nil)) + ctx.Set("id", userID) + return ctx, recorder +} + +func decodeSubscriptionSelfResponse(t *testing.T, recorder *httptest.ResponseRecorder) subscriptionSelfAPIResponse { + t.Helper() + + var response subscriptionSelfAPIResponse + require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &response)) + return response +} + +func decodeSubscriptionSelfData(t *testing.T, response subscriptionSelfAPIResponse) subscriptionSelfResponseData { + t.Helper() + + var data subscriptionSelfResponseData + require.NoError(t, common.Unmarshal(response.Data, &data)) + return data +} + +func TestGetSubscriptionSelfReturnsPrimarySubscriptionCountAndPlanTitle(t *testing.T) { + db := setupSubscriptionSelfControllerTestDB(t) + seedSubscriptionSelfUser(t, db, 1) + now := common.GetTimestamp() + + planMonthly := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{ + Id: 101, + Title: "Max 月订阅", + TotalAmount: 80000000, + QuotaResetPeriod: model.SubscriptionResetMonthly, + QuotaResetCustomSeconds: 0, + Enabled: true, + }) + + seedSubscriptionSelfSubscription(t, db, model.UserSubscription{ + Id: 201, + UserId: 1, + PlanId: planMonthly.Id, + AmountTotal: 80000000, + AmountUsed: 2300000, + StartTime: now - 3600, + EndTime: now + 86400, + Status: "active", + Source: "order", + NextResetTime: now + 3600, + }) + + ctx, recorder := newSubscriptionSelfContext(t, 1) + GetSubscriptionSelf(ctx) + + response := decodeSubscriptionSelfResponse(t, recorder) + require.True(t, response.Success) + + data := decodeSubscriptionSelfData(t, response) + require.Equal(t, 1, data.ActiveSubscriptionCount) + require.NotNil(t, data.PrimarySubscription) + require.NotNil(t, data.PrimarySubscription.Plan) + require.Equal(t, "Max 月订阅", data.PrimarySubscription.Plan.Title) + require.Equal(t, "Max 月订阅", data.Subscriptions[0].Plan.Title) + require.Equal(t, "Max 月订阅", data.AllSubscriptions[0].Plan.Title) +} + +func TestGetSubscriptionSelfSkipsExhaustedLimitedSubscription(t *testing.T) { + db := setupSubscriptionSelfControllerTestDB(t) + seedSubscriptionSelfUser(t, db, 1) + now := common.GetTimestamp() + + exhaustedPlan := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{ + Id: 201, + Title: "Daily Exhausted", + TotalAmount: 1000, + QuotaResetPeriod: model.SubscriptionResetDaily, + Enabled: true, + }) + activePlan := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{ + Id: 202, + Title: "Monthly Available", + TotalAmount: 8000, + QuotaResetPeriod: model.SubscriptionResetMonthly, + Enabled: true, + }) + + seedSubscriptionSelfSubscription(t, db, model.UserSubscription{ + Id: 301, + UserId: 1, + PlanId: exhaustedPlan.Id, + AmountTotal: 1000, + AmountUsed: 1000, + StartTime: now - 7200, + EndTime: now + 3600, + Status: "active", + Source: "order", + }) + seedSubscriptionSelfSubscription(t, db, model.UserSubscription{ + Id: 302, + UserId: 1, + PlanId: activePlan.Id, + AmountTotal: 8000, + AmountUsed: 2500, + StartTime: now - 7200, + EndTime: now + 7200, + Status: "active", + Source: "order", + }) + + ctx, recorder := newSubscriptionSelfContext(t, 1) + GetSubscriptionSelf(ctx) + + response := decodeSubscriptionSelfResponse(t, recorder) + require.True(t, response.Success) + + data := decodeSubscriptionSelfData(t, response) + require.Equal(t, 2, data.ActiveSubscriptionCount) + require.NotNil(t, data.PrimarySubscription) + require.NotNil(t, data.PrimarySubscription.Plan) + require.Equal(t, "Monthly Available", data.PrimarySubscription.Plan.Title) + require.Equal(t, int64(8000), data.PrimarySubscription.Subscription.AmountTotal) + require.Equal(t, int64(2500), data.PrimarySubscription.Subscription.AmountUsed) +} diff --git a/controller/topup.go b/controller/topup.go index e7a392a4d31..25407187444 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -23,6 +23,9 @@ import ( ) func GetTopUpInfo(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } // 获取支付方式 payMethods := operation_setting.PayMethods @@ -164,6 +167,9 @@ func getMinTopup() int64 { } func RequestEpay(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } var req EpayRequest err := c.ShouldBindJSON(&req) if err != nil { @@ -366,6 +372,9 @@ func EpayNotify(c *gin.Context) { } func RequestAmount(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } var req AmountRequest err := c.ShouldBindJSON(&req) if err != nil { @@ -463,4 +472,3 @@ func AdminCompleteTopUp(c *gin.Context) { } common.ApiSuccess(c, nil) } - diff --git a/controller/topup_access.go b/controller/topup_access.go new file mode 100644 index 00000000000..4f43e7d3528 --- /dev/null +++ b/controller/topup_access.go @@ -0,0 +1,21 @@ +package controller + +import ( + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" +) + +func ensureRechargeAllowed(c *gin.Context) bool { + userId := c.GetInt("id") + user, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return false + } + if !user.AllowRecharge { + common.ApiErrorMsg(c, "当前账户不支持充值") + return false + } + return true +} diff --git a/controller/topup_access_test.go b/controller/topup_access_test.go new file mode 100644 index 00000000000..3c6408ea386 --- /dev/null +++ b/controller/topup_access_test.go @@ -0,0 +1,71 @@ +package controller + +import ( + "net/http" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/stretchr/testify/require" +) + +func TestGetTopUpInfoRejectsRechargeRestrictedUser(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + user := seedRechargeUser(t, db, model.User{ + Username: "topup-disabled-user", + Password: "password123", + DisplayName: "Topup Disabled", + Role: common.RoleCommonUser, + Status: common.UserStatusEnabled, + Group: "default", + AllowRecharge: false, + }) + + ctx, recorder := newUserRechargeContext(t, http.MethodGet, "/api/user/topup/info", nil, user.Id, common.RoleCommonUser) + GetTopUpInfo(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.False(t, response.Success) + require.Equal(t, "当前账户不支持充值", response.Message) +} + +func TestRequestAmountRejectsRechargeRestrictedUser(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + user := seedRechargeUser(t, db, model.User{ + Username: "amount-disabled-user", + Password: "password123", + DisplayName: "Amount Disabled", + Role: common.RoleCommonUser, + Status: common.UserStatusEnabled, + Group: "default", + AllowRecharge: false, + }) + + ctx, recorder := newUserRechargeContext(t, http.MethodPost, "/api/user/amount", map[string]any{ + "amount": 1, + }, user.Id, common.RoleCommonUser) + RequestAmount(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.False(t, response.Success) + require.Equal(t, "当前账户不支持充值", response.Message) +} + +func TestGetTopUpInfoAllowsRechargeEnabledUser(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + user := seedRechargeUser(t, db, model.User{ + Username: "topup-enabled-user", + Password: "password123", + DisplayName: "Topup Enabled", + Role: common.RoleCommonUser, + Status: common.UserStatusEnabled, + Group: "default", + AllowRecharge: true, + }) + + ctx, recorder := newUserRechargeContext(t, http.MethodGet, "/api/user/topup/info", nil, user.Id, common.RoleCommonUser) + GetTopUpInfo(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.True(t, response.Success, response.Message) +} diff --git a/controller/topup_creem.go b/controller/topup_creem.go index 54b67b854f4..40add794fc6 100644 --- a/controller/topup_creem.go +++ b/controller/topup_creem.go @@ -143,6 +143,9 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) { } func RequestCreemPay(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } var req CreemPayRequest // 读取body内容用于打印,同时保留原始数据供后续使用 diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index e1718cc5ec8..392048cc9de 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -126,6 +126,9 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { } func RequestStripeAmount(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } var req StripePayRequest err := c.ShouldBindJSON(&req) if err != nil { @@ -136,6 +139,9 @@ func RequestStripeAmount(c *gin.Context) { } func RequestStripePay(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } var req StripePayRequest err := c.ShouldBindJSON(&req) if err != nil { diff --git a/controller/topup_waffo.go b/controller/topup_waffo.go index fce37642338..690146e7b33 100644 --- a/controller/topup_waffo.go +++ b/controller/topup_waffo.go @@ -101,6 +101,9 @@ type WaffoPayRequest struct { // RequestWaffoPay 创建 Waffo 支付订单 func RequestWaffoPay(c *gin.Context) { + if !ensureRechargeAllowed(c) { + return + } if !setting.WaffoEnabled { c.JSON(200, gin.H{"message": "error", "data": "Waffo 支付未启用"}) return diff --git a/controller/user.go b/controller/user.go index 7921e337c6e..0dc7440faa0 100644 --- a/controller/user.go +++ b/controller/user.go @@ -175,11 +175,12 @@ func Register(c *gin.Context) { affCode := user.AffCode // this code is the inviter's code, not the user's own code inviterId, _ := model.GetUserIdByAffCode(affCode) cleanUser := model.User{ - Username: user.Username, - Password: user.Password, - DisplayName: user.Username, - InviterId: inviterId, - Role: common.RoleCommonUser, // 明确设置角色为普通用户 + Username: user.Username, + Password: user.Password, + DisplayName: user.Username, + InviterId: inviterId, + Role: common.RoleCommonUser, // 明确设置角色为普通用户 + AllowRecharge: true, } if common.EmailVerificationEnabled { cleanUser.Email = user.Email @@ -412,6 +413,7 @@ func GetSelf(c *gin.Context) { "linux_do_id": user.LinuxDOId, "setting": user.Setting, "stripe_customer": user.StripeCustomer, + "allow_recharge": user.AllowRecharge, "sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段 "permissions": permissions, // 新增权限字段 } @@ -542,8 +544,19 @@ func GetUserModels(c *gin.Context) { } func UpdateUser(c *gin.Context) { + var requestData map[string]any + err := json.NewDecoder(c.Request.Body).Decode(&requestData) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + requestDataBytes, err := json.Marshal(requestData) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } var updatedUser model.User - err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) + err = json.Unmarshal(requestDataBytes, &updatedUser) if err != nil || updatedUser.Id == 0 { common.ApiErrorI18n(c, i18n.MsgInvalidParams) return @@ -569,6 +582,9 @@ func UpdateUser(c *gin.Context) { common.ApiErrorI18n(c, i18n.MsgUserCannotCreateHigherLevel) return } + if _, ok := requestData["allow_recharge"]; !ok { + updatedUser.AllowRecharge = originUser.AllowRecharge + } if updatedUser.Password == "$I_LOVE_U" { updatedUser.Password = "" // rollback to what it should be } @@ -802,8 +818,21 @@ func DeleteSelf(c *gin.Context) { } func CreateUser(c *gin.Context) { + var requestData map[string]any + err := json.NewDecoder(c.Request.Body).Decode(&requestData) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + + requestDataBytes, err := json.Marshal(requestData) + if err != nil { + common.ApiErrorI18n(c, i18n.MsgInvalidParams) + return + } + var user model.User - err := json.NewDecoder(c.Request.Body).Decode(&user) + err = json.Unmarshal(requestDataBytes, &user) user.Username = strings.TrimSpace(user.Username) if err != nil || user.Username == "" || user.Password == "" { common.ApiErrorI18n(c, i18n.MsgInvalidParams) @@ -822,16 +851,29 @@ func CreateUser(c *gin.Context) { return } // Even for admin users, we cannot fully trust them! + allowRecharge := true + if rawAllowRecharge, ok := requestData["allow_recharge"]; ok { + if allowRechargeValue, ok := rawAllowRecharge.(bool); ok { + allowRecharge = allowRechargeValue + } + } cleanUser := model.User{ - Username: user.Username, - Password: user.Password, - DisplayName: user.DisplayName, - Role: user.Role, // 保持管理员设置的角色 + Username: user.Username, + Password: user.Password, + DisplayName: user.DisplayName, + Role: user.Role, // 保持管理员设置的角色 + AllowRecharge: allowRecharge, } if err := cleanUser.Insert(0); err != nil { common.ApiError(c, err) return } + if !allowRecharge { + if err := model.DB.Model(&model.User{}).Where("id = ?", cleanUser.Id).Update("allow_recharge", false).Error; err != nil { + common.ApiError(c, err) + return + } + } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/user_recharge_test.go b/controller/user_recharge_test.go new file mode 100644 index 00000000000..b4e8ad1025b --- /dev/null +++ b/controller/user_recharge_test.go @@ -0,0 +1,197 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +type userRechargeAPIResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} + +type selfUserRechargeResponse struct { + ID int `json:"id"` + Username string `json:"username"` + AllowRecharge bool `json:"allow_recharge"` +} + +func setupUserRechargeControllerTestDB(t *testing.T) *gorm.DB { + t.Helper() + + gin.SetMode(gin.TestMode) + common.UsingSQLite = true + common.UsingMySQL = false + common.UsingPostgreSQL = false + common.RedisEnabled = false + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + require.NoError(t, err) + + model.DB = db + model.LOG_DB = db + + require.NoError(t, db.AutoMigrate(&model.User{})) + + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) + + return db +} + +func newUserRechargeContext(t *testing.T, method string, target string, body any, userID int, role int) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + + var requestBody *bytes.Reader + if body != nil { + payload, err := common.Marshal(body) + require.NoError(t, err) + requestBody = bytes.NewReader(payload) + } else { + requestBody = bytes.NewReader(nil) + } + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(method, target, requestBody) + if body != nil { + ctx.Request.Header.Set("Content-Type", "application/json") + } + ctx.Set("id", userID) + ctx.Set("role", role) + return ctx, recorder +} + +func decodeUserRechargeResponse(t *testing.T, recorder *httptest.ResponseRecorder) userRechargeAPIResponse { + t.Helper() + + var response userRechargeAPIResponse + require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &response)) + return response +} + +func seedRechargeUser(t *testing.T, db *gorm.DB, user model.User) *model.User { + t.Helper() + + createMap := map[string]interface{}{ + "username": user.Username, + "password": user.Password, + "display_name": user.DisplayName, + "role": user.Role, + "status": user.Status, + "group": user.Group, + "allow_recharge": user.AllowRecharge, + } + if user.Id > 0 { + createMap["id"] = user.Id + } + + require.NoError(t, db.Model(&model.User{}).Create(createMap).Error) + var created model.User + require.NoError(t, db.Where("username = ?", user.Username).First(&created).Error) + require.NoError(t, db.First(&created, created.Id).Error) + return &created +} + +func TestCreateUserPersistsAllowRecharge(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + + ctx, recorder := newUserRechargeContext(t, http.MethodPost, "/api/user/", map[string]any{ + "username": "create-recharge-user", + "password": "password123", + "display_name": "Create Recharge User", + "role": common.RoleCommonUser, + "allow_recharge": false, + }, 900, common.RoleRootUser) + + CreateUser(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.True(t, response.Success, response.Message) + + var created model.User + require.NoError(t, db.Where("username = ?", "create-recharge-user").First(&created).Error) + require.False(t, created.AllowRecharge) +} + +func TestUpdateUserPersistsAllowRecharge(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + user := seedRechargeUser(t, db, model.User{ + Id: 1001, + Username: "update-recharge-user", + Password: "password123", + DisplayName: "Update Recharge User", + Role: common.RoleCommonUser, + Status: common.UserStatusEnabled, + Group: "default", + AllowRecharge: true, + }) + + ctx, recorder := newUserRechargeContext(t, http.MethodPut, "/api/user/", map[string]any{ + "id": user.Id, + "username": user.Username, + "display_name": user.DisplayName, + "role": user.Role, + "status": user.Status, + "group": user.Group, + "allow_recharge": false, + }, 901, common.RoleRootUser) + + UpdateUser(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.True(t, response.Success, response.Message) + + updated, err := model.GetUserById(user.Id, true) + require.NoError(t, err) + require.False(t, updated.AllowRecharge) +} + +func TestGetSelfIncludesAllowRecharge(t *testing.T) { + db := setupUserRechargeControllerTestDB(t) + user := seedRechargeUser(t, db, model.User{ + Id: 1101, + Username: "self-recharge-user", + Password: "password123", + DisplayName: "Self Recharge User", + Role: common.RoleCommonUser, + Status: common.UserStatusEnabled, + Group: "default", + AllowRecharge: false, + }) + var allowRechargeRaw int + require.NoError(t, db.Raw("SELECT allow_recharge FROM users WHERE id = ?", user.Id).Scan(&allowRechargeRaw).Error) + require.Equal(t, 0, allowRechargeRaw) + user, err := model.GetUserById(user.Id, false) + require.NoError(t, err) + require.False(t, user.AllowRecharge) + + ctx, recorder := newUserRechargeContext(t, http.MethodGet, "/api/user/self", nil, user.Id, common.RoleCommonUser) + GetSelf(ctx) + + response := decodeUserRechargeResponse(t, recorder) + require.True(t, response.Success, response.Message) + + var self selfUserRechargeResponse + require.NoError(t, common.Unmarshal(response.Data, &self)) + require.Equal(t, user.Id, self.ID) + require.False(t, self.AllowRecharge) +} diff --git a/go.mod b/go.mod index a078b2091ad..f66d6267795 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( ) require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/DmitriyVTitov/size v1.5.0 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect diff --git a/go.sum b/go.sum index 8b687906336..1d568a3c875 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A= github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g= github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= github.com/abema/go-mp4 v1.4.1 h1:YoS4VRqd+pAmddRPLFf8vMk74kuGl6ULSjzhsIqwr6M= @@ -172,6 +174,7 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= diff --git a/model/subscription.go b/model/subscription.go index 2d23a8b5bf2..1913962927b 100644 --- a/model/subscription.go +++ b/model/subscription.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "sort" "strconv" "strings" "sync" @@ -266,8 +267,16 @@ func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error { return nil } +type SubscriptionPlanSummary struct { + Id int `json:"id"` + Title string `json:"title"` + QuotaResetPeriod string `json:"quota_reset_period"` + QuotaResetCustomSeconds int64 `json:"quota_reset_custom_seconds"` +} + type SubscriptionSummary struct { - Subscription *UserSubscription `json:"subscription"` + Subscription *UserSubscription `json:"subscription"` + Plan *SubscriptionPlanSummary `json:"plan,omitempty"` } func calcPlanEndTime(start time.Time, plan *SubscriptionPlan) (int64, error) { @@ -701,16 +710,80 @@ func buildSubscriptionSummaries(subs []UserSubscription) []SubscriptionSummary { if len(subs) == 0 { return []SubscriptionSummary{} } + planIDs := make([]int, 0, len(subs)) + seenPlanIDs := make(map[int]struct{}, len(subs)) + for _, sub := range subs { + if sub.PlanId <= 0 { + continue + } + if _, ok := seenPlanIDs[sub.PlanId]; ok { + continue + } + seenPlanIDs[sub.PlanId] = struct{}{} + planIDs = append(planIDs, sub.PlanId) + } + planSummaryMap := make(map[int]*SubscriptionPlanSummary, len(planIDs)) + if len(planIDs) > 0 { + var plans []SubscriptionPlan + if err := DB.Where("id IN ?", planIDs).Find(&plans).Error; err == nil { + for i := range plans { + plan := plans[i] + planSummaryMap[plan.Id] = &SubscriptionPlanSummary{ + Id: plan.Id, + Title: plan.Title, + QuotaResetPeriod: plan.QuotaResetPeriod, + QuotaResetCustomSeconds: plan.QuotaResetCustomSeconds, + } + } + } + } result := make([]SubscriptionSummary, 0, len(subs)) for _, sub := range subs { subCopy := sub - result = append(result, SubscriptionSummary{ + summary := SubscriptionSummary{ Subscription: &subCopy, - }) + } + if planSummary, ok := planSummaryMap[subCopy.PlanId]; ok { + planCopy := *planSummary + summary.Plan = &planCopy + } + result = append(result, summary) } return result } +func isSubscriptionStillConsumable(sub *UserSubscription) bool { + if sub == nil { + return false + } + return sub.AmountTotal == 0 || sub.AmountUsed < sub.AmountTotal +} + +func SelectPrimarySubscriptionSummary(subs []SubscriptionSummary) *SubscriptionSummary { + if len(subs) == 0 { + return nil + } + ordered := make([]SubscriptionSummary, len(subs)) + copy(ordered, subs) + sort.SliceStable(ordered, func(i, j int) bool { + left := ordered[i].Subscription + right := ordered[j].Subscription + if left == nil || right == nil { + return left != nil + } + if left.EndTime != right.EndTime { + return left.EndTime < right.EndTime + } + return left.Id < right.Id + }) + for i := range ordered { + if isSubscriptionStillConsumable(ordered[i].Subscription) { + return &ordered[i] + } + } + return &ordered[0] +} + // AdminInvalidateUserSubscription marks a user subscription as cancelled and ends it immediately. func AdminInvalidateUserSubscription(userSubscriptionId int) (string, error) { if userSubscriptionId <= 0 { diff --git a/model/user.go b/model/user.go index 79e63e8fd59..ddfefe50593 100644 --- a/model/user.go +++ b/model/user.go @@ -50,6 +50,7 @@ type User struct { Setting string `json:"setting" gorm:"type:text;column:setting"` Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"` + AllowRecharge bool `json:"allow_recharge" gorm:"default:true"` } func (user *User) ToBaseUser() *UserBase { @@ -293,12 +294,13 @@ func GetUserById(id int, selectAll bool) (*User, error) { if id == 0 { return nil, errors.New("id 为空!") } - user := User{Id: id} - var err error = nil - if selectAll { - err = DB.First(&user, "id = ?", id).Error - } else { - err = DB.Omit("password").First(&user, "id = ?", id).Error + user := User{} + err := DB.First(&user, "id = ?", id).Error + if err != nil { + return &user, err + } + if !selectAll { + user.Password = "" } return &user, err } @@ -520,10 +522,11 @@ func (user *User) Edit(updatePassword bool) error { newUser := *user updates := map[string]interface{}{ - "username": newUser.Username, - "display_name": newUser.DisplayName, - "group": newUser.Group, - "remark": newUser.Remark, + "username": newUser.Username, + "display_name": newUser.DisplayName, + "group": newUser.Group, + "remark": newUser.Remark, + "allow_recharge": newUser.AllowRecharge, } if updatePassword { updates["password"] = newUser.Password diff --git a/model/user_allow_recharge_test.go b/model/user_allow_recharge_test.go new file mode 100644 index 00000000000..e69b928a8f0 --- /dev/null +++ b/model/user_allow_recharge_test.go @@ -0,0 +1,60 @@ +package model + +import ( + "bytes" + "log" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" + "gorm.io/driver/postgres" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +func TestGetUserByIdDoesNotIssueExtraAllowRechargeQueryOnPostgres(t *testing.T) { + sqlDB, mock, err := sqlmock.New() + require.NoError(t, err) + defer sqlDB.Close() + + var logBuffer bytes.Buffer + gormDB, err := gorm.Open(postgres.New(postgres.Config{Conn: sqlDB}), &gorm.Config{ + Logger: gormlogger.New( + log.New(&logBuffer, "", 0), + gormlogger.Config{ + LogLevel: gormlogger.Error, + Colorful: false, + }, + ), + }) + require.NoError(t, err) + + originalDB := DB + DB = gormDB + defer func() { + DB = originalDB + }() + + rows := sqlmock.NewRows([]string{ + "id", + "username", + "password", + "display_name", + "role", + "status", + "group", + "allow_recharge", + }).AddRow(7, "pg-user", "secret", "PG User", 1, 1, "default", false) + + mock.ExpectQuery(`SELECT \* FROM "users" WHERE id = \$1 AND "users"\."deleted_at" IS NULL ORDER BY "users"\."id" LIMIT 1`). + WithArgs(7). + WillReturnRows(rows) + + user, err := GetUserById(7, false) + require.NoError(t, err) + require.NotNil(t, user) + require.False(t, user.AllowRecharge) + require.Empty(t, user.Password) + require.NoError(t, mock.ExpectationsWereMet()) + require.NotContains(t, logBuffer.String(), "SELECT allow_recharge FROM users") +} diff --git a/web/src/components/dashboard/DashboardHeader.jsx b/web/src/components/dashboard/DashboardHeader.jsx index c2867e90c2a..b49b16afe83 100644 --- a/web/src/components/dashboard/DashboardHeader.jsx +++ b/web/src/components/dashboard/DashboardHeader.jsx @@ -20,6 +20,7 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Button } from '@douyinfe/semi-ui'; import { RefreshCw, Search } from 'lucide-react'; +import DashboardSubscriptionSummary from './DashboardSubscriptionSummary'; const DashboardHeader = ({ getGreeting, @@ -27,6 +28,7 @@ const DashboardHeader = ({ showSearchModal, refresh, loading, + dashboardSubscriptionSummary, t, }) => { const ICON_BUTTON_CLASS = 'text-white hover:bg-opacity-80 !rounded-full'; @@ -39,20 +41,26 @@ const DashboardHeader = ({ > {getGreeting} -
-
); diff --git a/web/src/components/dashboard/DashboardSubscriptionSummary.jsx b/web/src/components/dashboard/DashboardSubscriptionSummary.jsx new file mode 100644 index 00000000000..e346b768949 --- /dev/null +++ b/web/src/components/dashboard/DashboardSubscriptionSummary.jsx @@ -0,0 +1,177 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React, { useMemo } from 'react'; +import { Popover } from '@douyinfe/semi-ui'; +import { + buildDashboardSubscriptionTriggerText, + getDashboardSubscriptionTranslator, + shouldShowDashboardSubscriptionTriggerTitle, +} from '../../helpers/dashboardSubscriptionSummary'; + +const DashboardSubscriptionSummary = ({ dashboardSubscriptionSummary, t }) => { + const translate = getDashboardSubscriptionTranslator(t); + const summary = dashboardSubscriptionSummary?.summary || dashboardSubscriptionSummary || {}; + const summaryText = summary?.summaryText?.trim() || ''; + const resetText = summary?.resetText?.trim() || ''; + const extraText = summary?.extraText?.trim() || ''; + const badgeText = summary?.badgeText?.trim() || ''; + const usedAmountText = summary?.usedAmountText?.trim() || ''; + const totalAmountText = summary?.totalAmountText?.trim() || ''; + const showProgress = Boolean(summary?.showProgress); + const progressPercent = Number.isFinite(summary?.progressPercent) + ? Math.max(0, Math.min(100, Number(summary.progressPercent))) + : 0; + const displayProgressPercent = Number.isFinite(summary?.displayProgressPercent) + ? Math.max(0, Math.min(100, Number(summary.displayProgressPercent))) + : progressPercent; + const rows = Array.isArray(dashboardSubscriptionSummary?.rows) + ? dashboardSubscriptionSummary.rows + : []; + const shouldShowTitle = shouldShowDashboardSubscriptionTriggerTitle(summary); + const triggerText = buildDashboardSubscriptionTriggerText(summary); + + const popoverContent = useMemo(() => { + if (!rows.length) { + return ( +
+
+ {translate('暂无活跃订阅')} +
+
+ ); + } + + return ( +
+
+ {translate('活跃订阅')} +
+
+ {rows.map((row, index) => ( +
0 ? 'border-t border-black/8 dark:border-white/[0.08]' : ''}`} + > +
+
+ + + + {row.badgeText || 'PLAN'} + + + {row.titleText && row.titleText !== row.badgeText ? ( + + {row.titleText} + + ) : null} +
+ {row.isPrimary ? ( + + {translate('主订阅')} + + ) : null} +
+
+ + {row.usedAmountText || '$0.00'} + + / + + {row.totalAmountText || '-'} + +
+
+ {row.resetText || '-'} +
+ {row.showProgress ? ( +
+
+
+ ) : null} +
+ ))} +
+
+ ); + }, [rows, translate]); + + if (!summaryText) { + return null; + } + + return ( + + + + ); +}; + +export default DashboardSubscriptionSummary; diff --git a/web/src/components/dashboard/index.jsx b/web/src/components/dashboard/index.jsx index 811e23ca760..5c77d56558b 100644 --- a/web/src/components/dashboard/index.jsx +++ b/web/src/components/dashboard/index.jsx @@ -96,11 +96,13 @@ const Dashboard = () => { }; const initChart = async () => { - await dashboardData.loadQuotaData().then((data) => { - if (data && data.length > 0) { - dashboardCharts.updateChartData(data); - } - }); + const [data] = await Promise.all([ + dashboardData.loadQuotaData(), + dashboardData.loadDashboardSubscriptionSummary(), + ]); + if (data && data.length > 0) { + dashboardCharts.updateChartData(data); + } await loadUserData(); await dashboardData.loadUptimeData(); }; @@ -158,6 +160,7 @@ const Dashboard = () => { showSearchModal={dashboardData.showSearchModal} refresh={handleRefresh} loading={dashboardData.loading} + dashboardSubscriptionSummary={dashboardData.dashboardSubscriptionSummary} t={dashboardData.t} /> diff --git a/web/src/components/layout/SiderBar.jsx b/web/src/components/layout/SiderBar.jsx index bcbe41237a6..94ecaaa50f2 100644 --- a/web/src/components/layout/SiderBar.jsx +++ b/web/src/components/layout/SiderBar.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { useEffect, useMemo, useState } from 'react'; +import React, { useContext, useEffect, useMemo, useState } from 'react'; import { Link, useLocation } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { getLucideIcon } from '../../helpers/render'; @@ -26,6 +26,8 @@ import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed'; import { useSidebar } from '../../hooks/common/useSidebar'; import { useMinimumLoadingTime } from '../../hooks/common/useMinimumLoadingTime'; import { isAdmin, isRoot, showError } from '../../helpers'; +import { UserContext } from '../../context/User'; +import { canAccessWalletManagement } from '../../helpers/rechargeAccess'; import SkeletonWrapper from './components/SkeletonWrapper'; import { Nav, Divider, Button } from '@douyinfe/semi-ui'; @@ -53,6 +55,7 @@ const routerMap = { const SiderBar = ({ onNavigate = () => {} }) => { const { t } = useTranslation(); + const [userState] = useContext(UserContext); const [collapsed, toggleCollapsed] = useSidebarCollapsed(); const { isModuleVisible, @@ -138,12 +141,18 @@ const SiderBar = ({ onNavigate = () => {} }) => { // 根据配置过滤项目 const filteredItems = items.filter((item) => { + if ( + item.itemKey === 'topup' && + !canAccessWalletManagement(userState?.user) + ) { + return false; + } const configVisible = isModuleVisible('personal', item.itemKey); return configVisible; }); return filteredItems; - }, [t, isModuleVisible]); + }, [t, isModuleVisible, userState?.user]); const adminItems = useMemo(() => { const items = [ diff --git a/web/src/components/layout/headerbar/UserArea.jsx b/web/src/components/layout/headerbar/UserArea.jsx index 9fc011da18a..c15d7993f0b 100644 --- a/web/src/components/layout/headerbar/UserArea.jsx +++ b/web/src/components/layout/headerbar/UserArea.jsx @@ -28,6 +28,7 @@ import { IconKey, } from '@douyinfe/semi-icons'; import { stringToColor } from '../../../helpers'; +import { canAccessWalletManagement } from '../../../helpers/rechargeAccess'; import SkeletonWrapper from '../components/SkeletonWrapper'; const UserArea = ({ @@ -87,20 +88,22 @@ const UserArea = ({ {t('令牌管理')}
- { - navigate('/console/topup'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('钱包管理')} -
-
+ {canAccessWalletManagement(userState.user) && ( + { + navigate('/console/topup'); + }} + className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' + > +
+ + {t('钱包管理')} +
+
+ )} { display_name: '', password: '', remark: '', + allow_recharge: true, }); const submit = async (values) => { @@ -173,6 +174,17 @@ const AddUserModal = (props) => { showClear /> + + + diff --git a/web/src/components/table/users/modals/EditUserModal.jsx b/web/src/components/table/users/modals/EditUserModal.jsx index 94ccc0a0ce5..3418593b40d 100644 --- a/web/src/components/table/users/modals/EditUserModal.jsx +++ b/web/src/components/table/users/modals/EditUserModal.jsx @@ -94,6 +94,7 @@ const EditUserModal = (props) => { quota_amount: 0, group: 'default', remark: '', + allow_recharge: true, }); const fetchGroups = async () => { @@ -335,85 +336,104 @@ const EditUserModal = (props) => { {/* 权限设置 */} - {userId && ( - -
- - - -
- - {t('权限设置')} - -
- {t('用户分组和额度管理')} -
+ +
+ + + +
+ + {t('权限设置')} + +
+ {userId + ? t('用户分组、额度和充值权限管理') + : t('配置用户的基础权限与充值能力')}
+
- - - - - - - - - - - - - - - - -
setShowQuotaInput((v) => !v)} - > - {showQuotaInput - ? `▾ ${t('收起原生额度输入')}` - : `▸ ${t('使用原生额度输入')}`} -
-
+ + + + + + {userId && ( + <> + + + + + -
- -
-
- )} + + + + + + + + + +
setShowQuotaInput((v) => !v)} + > + {showQuotaInput + ? `▾ ${t('收起原生额度输入')}` + : `▸ ${t('使用原生额度输入')}`} +
+
+ +
+ + + )} + + {/* 绑定信息入口 */} {userId && ( diff --git a/web/src/components/topup/RechargeCard.jsx b/web/src/components/topup/RechargeCard.jsx index f37d129b33d..3fe99a1eb20 100644 --- a/web/src/components/topup/RechargeCard.jsx +++ b/web/src/components/topup/RechargeCard.jsx @@ -54,6 +54,7 @@ const { Text } = Typography; const RechargeCard = ({ t, + walletAccessAllowed = true, enableOnlineTopUp, enableStripeTopUp, enableCreemTopUp, @@ -223,7 +224,16 @@ const RechargeCard = ({ } > {/* 在线充值表单 */} - {statusLoading ? ( + {!walletAccessAllowed ? ( + + ) : statusLoading ? (
diff --git a/web/src/components/topup/index.jsx b/web/src/components/topup/index.jsx index 0348e3c8dd9..afbcb66949d 100644 --- a/web/src/components/topup/index.jsx +++ b/web/src/components/topup/index.jsx @@ -29,6 +29,7 @@ import { copy, getQuotaPerUnit, } from '../../helpers'; +import { canAccessWalletManagement } from '../../helpers/rechargeAccess'; import { Modal, Toast } from '@douyinfe/semi-ui'; import { useTranslation } from 'react-i18next'; import { UserContext } from '../../context/User'; @@ -45,6 +46,7 @@ const TopUp = () => { const [searchParams, setSearchParams] = useSearchParams(); const [userState, userDispatch] = useContext(UserContext); const [statusState] = useContext(StatusContext); + const walletAccessAllowed = canAccessWalletManagement(userState?.user); const [redemptionCode, setRedemptionCode] = useState(''); const [amount, setAmount] = useState(0.0); @@ -591,12 +593,29 @@ const TopUp = () => { getAffLink().then(); }, []); - // 在 statusState 可用时获取充值信息 useEffect(() => { - getTopupInfo().then(); + if (!userState?.user) return; + + if (walletAccessAllowed) { + getTopupInfo().then(); + } else { + setTopupInfo({ + amount_options: [], + discount: {}, + }); + setPayMethods([]); + setPresetAmounts([]); + setEnableOnlineTopUp(false); + setEnableStripeTopUp(false); + setEnableCreemTopUp(false); + setEnableWaffoTopUp(false); + setWaffoPayMethods([]); + setCreemProducts([]); + } + getSubscriptionPlans().then(); getSubscriptionSelf().then(); - }, []); + }, [userState?.user?.id, userState?.user?.allow_recharge, walletAccessAllowed]); useEffect(() => { if (statusState?.status) { @@ -783,6 +802,7 @@ const TopUp = () => {
. + +For commercial licensing, please contact support@quantumnous.com +*/ + +function toNumber(value) { + const number = Number(value); + return Number.isFinite(number) ? number : 0; +} + +function toTimestampMs(value) { + const number = toNumber(value); + if (number <= 0) return 0; + return number < 1e12 ? number * 1000 : number; +} + +function formatCompactDecimal(value) { + const number = Math.abs(toNumber(value)); + return number.toFixed(2); +} + +export function formatDashboardSubscriptionAmount(value) { + const number = toNumber(value); + const sign = number < 0 ? '-' : ''; + return `$${sign}${formatCompactDecimal(number)}`; +} + +export function getDashboardSubscriptionTranslator(translate) { + return typeof translate === 'function' ? translate : (value) => value; +} + +function convertQuotaToUsd(value, options = {}) { + const quotaPerUnit = toNumber(options.quotaPerUnit) > 0 ? toNumber(options.quotaPerUnit) : 1; + return toNumber(value) / quotaPerUnit; +} + +export function formatDashboardSubscriptionResetTime(value, options = {}) { + const timestampMs = toTimestampMs(value); + if (!timestampMs) return ''; + + const date = new Date(timestampMs); + const formatter = new Intl.DateTimeFormat('en-US', { + timeZone: options.timeZone, + month: '2-digit', + day: '2-digit', + hour: '2-digit', + minute: '2-digit', + hour12: false, + }); + const parts = formatter.formatToParts(date); + const getPart = (type) => parts.find((part) => part.type === type)?.value || ''; + return `${getPart('month')}-${getPart('day')} ${getPart('hour')}:${getPart('minute')}`; +} + +function buildQuotaText(subscription = {}, options = {}) { + const usedText = formatDashboardSubscriptionAmount( + convertQuotaToUsd(subscription.amount_used, options), + ); + const totalAmount = toNumber(subscription.amount_total); + const totalText = totalAmount > 0 + ? formatDashboardSubscriptionAmount(convertQuotaToUsd(totalAmount, options)) + : '∞'; + return { + usedAmountText: usedText, + totalAmountText: totalText, + quotaText: `${usedText} / ${totalText}`, + }; +} + +function buildResetText(subscription = {}, options = {}) { + const nextResetTime = toTimestampMs(subscription.next_reset_time); + if (nextResetTime > 0) { + const formatted = formatDashboardSubscriptionResetTime(nextResetTime, options); + return formatted + ? { resetText: `${formatted} 刷新`, summaryText: `${formatted} 刷新` } + : { resetText: '', summaryText: '' }; + } + return { resetText: '有效期总额度', summaryText: '有效期总额度' }; +} + +function buildBadgeText(titleText) { + const normalizedTitle = String(titleText || '').trim(); + if (!normalizedTitle) return ''; + + const englishWord = normalizedTitle.match(/[A-Za-z0-9+-]+/); + if (englishWord?.[0]) { + return englishWord[0].slice(0, 8).toUpperCase(); + } + + return normalizedTitle.slice(0, 4).toUpperCase(); +} + +function normalizeTriggerTitleToken(value) { + return String(value || '') + .trim() + .toUpperCase() + .replace(/[^A-Z0-9\u4E00-\u9FFF]+/g, ''); +} + +export function shouldShowDashboardSubscriptionTriggerTitle(summary = {}) { + const titleText = String(summary.titleText || '').trim(); + if (!titleText) return false; + + const badgeText = String(summary.badgeText || '').trim(); + if (!badgeText) return true; + + const normalizedTitle = normalizeTriggerTitleToken(titleText); + const normalizedBadge = normalizeTriggerTitleToken(badgeText); + if (!normalizedTitle || !normalizedBadge) return true; + + return !normalizedTitle.startsWith(normalizedBadge); +} + +export function buildDashboardSubscriptionTriggerText(summary = {}) { + const quotaText = String(summary.quotaText || '').trim(); + const extraText = String(summary.extraText || '').trim(); + + if (!quotaText) { + return extraText; + } + + return extraText ? `${quotaText} · ${extraText}` : quotaText; +} + +function buildProgressData(subscription = {}) { + const totalAmount = toNumber(subscription.amount_total); + const usedAmount = Math.max(0, toNumber(subscription.amount_used)); + + if (totalAmount <= 0) { + return { + showProgress: false, + progressPercent: 0, + displayProgressPercent: 0, + }; + } + + const progressPercent = Math.max(0, Math.min(100, (usedAmount / totalAmount) * 100)); + const displayProgressPercent = + progressPercent > 0 && progressPercent < 2 ? 2 : progressPercent; + + return { + showProgress: true, + progressPercent, + displayProgressPercent, + }; +} + +function getSubscriptionId(summary = {}) { + const subscription = summary.subscription || summary.Subscription || {}; + const rawId = + subscription.id ?? subscription.Id ?? summary.id ?? summary.Id ?? summary.subscription_id; + return Math.trunc(toNumber(rawId)); +} + +function orderSubscriptionSummariesForPopover(payload = {}) { + const primarySubscription = + payload.primary_subscription || payload.primarySubscription || payload.subscription || null; + const primaryId = getSubscriptionId(primarySubscription); + const activeSubscriptions = Array.isArray(payload.subscriptions) + ? payload.subscriptions + : Array.isArray(payload.active_subscriptions) + ? payload.active_subscriptions + : Array.isArray(payload.activeSubscriptions) + ? payload.activeSubscriptions + : []; + + const ordered = activeSubscriptions.filter(Boolean).slice(); + if (!primarySubscription) { + return ordered; + } + + if (primaryId > 0) { + const primaryIndex = ordered.findIndex((item) => getSubscriptionId(item) === primaryId); + if (primaryIndex >= 0) { + const [primaryItem] = ordered.splice(primaryIndex, 1); + ordered.unshift(primaryItem); + return ordered; + } + } + + return [primarySubscription, ...ordered]; +} + +export function buildDashboardSubscriptionSummaryViewModel(source = {}, options = {}) { + const plan = source.plan || {}; + const subscription = source.subscription || {}; + const titleText = String(plan.title || '').trim(); + const badgeText = buildBadgeText(titleText); + const { + usedAmountText, + totalAmountText, + quotaText, + } = buildQuotaText(subscription, options); + const { resetText, summaryText: resetSummaryText } = buildResetText(subscription, options); + const { showProgress, progressPercent, displayProgressPercent } = buildProgressData(subscription); + const extraCount = Math.max(0, Math.trunc(toNumber(options.extraCount))); + const extraText = extraCount > 0 ? `+${extraCount}` : ''; + + let summaryText = titleText ? `${titleText} ${quotaText}` : quotaText; + if (resetSummaryText) summaryText += ` · ${resetSummaryText}`; + if (extraText) summaryText += ` · ${extraText}`; + + return { + titleText, + badgeText, + usedAmountText, + totalAmountText, + quotaText, + resetText, + extraText, + showProgress, + progressPercent, + displayProgressPercent, + summaryText, + }; +} + +export function buildDashboardSubscriptionSummaryFromPayload(payload = {}, options = {}) { + const primarySubscription = + payload.primary_subscription || payload.primarySubscription || payload.subscription || null; + + if (!primarySubscription) { + return { + titleText: '', + quotaText: '', + resetText: '', + extraText: '', + summaryText: '', + }; + } + + const activeSubscriptionCount = Math.max( + 0, + Math.trunc(toNumber(payload.active_subscription_count ?? payload.activeSubscriptionCount)), + ); + + return buildDashboardSubscriptionSummaryViewModel(primarySubscription, { + ...options, + extraCount: Math.max(0, activeSubscriptionCount - 1), + }); +} + +export function buildDashboardSubscriptionPopoverRows(items = [], options = {}) { + return items.map((item) => { + const vm = buildDashboardSubscriptionSummaryViewModel(item, options); + return { + badgeText: vm.badgeText, + titleText: vm.titleText, + usedAmountText: vm.usedAmountText, + totalAmountText: vm.totalAmountText, + quotaText: vm.quotaText, + resetText: vm.resetText, + extraText: '', + showProgress: vm.showProgress, + progressPercent: vm.progressPercent, + displayProgressPercent: vm.displayProgressPercent, + }; + }); +} + +export function buildDashboardSubscriptionPopoverRowsFromPayload(payload = {}, options = {}) { + const orderedItems = orderSubscriptionSummariesForPopover(payload); + return orderedItems.map((item, index) => { + const vm = buildDashboardSubscriptionSummaryViewModel(item, { + ...options, + extraCount: 0, + }); + return { + ...vm, + isPrimary: index === 0, + }; + }); +} + +export function buildDashboardSubscriptionDisplayFromPayload(payload = {}, options = {}) { + return { + summary: buildDashboardSubscriptionSummaryFromPayload(payload, options), + rows: buildDashboardSubscriptionPopoverRowsFromPayload(payload, options), + }; +} diff --git a/web/src/helpers/dashboardSubscriptionSummary.test.js b/web/src/helpers/dashboardSubscriptionSummary.test.js new file mode 100644 index 00000000000..ef6c269f7aa --- /dev/null +++ b/web/src/helpers/dashboardSubscriptionSummary.test.js @@ -0,0 +1,317 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { + buildDashboardSubscriptionPopoverRows, + buildDashboardSubscriptionPopoverRowsFromPayload, + buildDashboardSubscriptionDisplayFromPayload, + buildDashboardSubscriptionSummaryFromPayload, + buildDashboardSubscriptionSummaryViewModel, + buildDashboardSubscriptionTriggerText, + formatDashboardSubscriptionAmount, + formatDashboardSubscriptionResetTime, + getDashboardSubscriptionTranslator, + shouldShowDashboardSubscriptionTriggerTitle, +} from './dashboardSubscriptionSummary.js'; + +test('formats usd quota amounts compactly', () => { + assert.equal(formatDashboardSubscriptionAmount(80), '$80.00'); + assert.equal(formatDashboardSubscriptionAmount(2.3), '$2.30'); + assert.equal(formatDashboardSubscriptionAmount(2.35), '$2.35'); + assert.equal(formatDashboardSubscriptionAmount(0), '$0.00'); +}); + +test('falls back to identity translator when translation function is missing', () => { + const fallbackTranslate = getDashboardSubscriptionTranslator(); + assert.equal(fallbackTranslate('主订阅'), '主订阅'); + + const customTranslate = getDashboardSubscriptionTranslator((value) => `x:${value}`); + assert.equal(customTranslate('活跃订阅'), 'x:活跃订阅'); +}); + +test('builds trigger text without directly exposing reset time', () => { + assert.equal( + buildDashboardSubscriptionTriggerText({ + titleText: 'Max 月订阅', + badgeText: 'MAX', + quotaText: '$2.30 / $80.00', + resetText: '04-15 00:00 刷新', + extraText: '+2', + }), + '$2.30 / $80.00 · +2', + ); + + assert.equal( + buildDashboardSubscriptionTriggerText({ + titleText: 'Daily', + badgeText: 'DAILY', + quotaText: '$1.00 / $5.00', + extraText: '', + }), + '$1.00 / $5.00', + ); +}); + +test('builds a styled summary view model with badge, exact amounts and progress', () => { + const vm = buildDashboardSubscriptionSummaryViewModel( + { + plan: { + title: 'Max 月订阅', + }, + subscription: { + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + { + quotaPerUnit: 10, + extraCount: 2, + timeZone: 'UTC', + }, + ); + + assert.equal( + vm.summaryText, + 'Max 月订阅 $2.30 / $80.00 · 04-15 00:00 刷新 · +2', + ); + assert.equal(vm.titleText, 'Max 月订阅'); + assert.equal(vm.badgeText, 'MAX'); + assert.equal(vm.usedAmountText, '$2.30'); + assert.equal(vm.totalAmountText, '$80.00'); + assert.equal(vm.quotaText, '$2.30 / $80.00'); + assert.equal(vm.resetText, '04-15 00:00 刷新'); + assert.equal(vm.extraText, '+2'); + assert.equal(vm.showProgress, true); + assert.equal(vm.progressPercent, 2.875); + assert.equal(vm.displayProgressPercent, 2.875); +}); + +test('ensures tiny positive usage still has a visible progress width', () => { + const vm = buildDashboardSubscriptionSummaryViewModel( + { + plan: { + title: 'Max', + }, + subscription: { + amount_used: 1, + amount_total: 8000, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + { + quotaPerUnit: 100, + timeZone: 'UTC', + }, + ); + + assert.equal(vm.showProgress, true); + assert.equal(vm.progressPercent, 0.0125); + assert.equal(vm.displayProgressPercent, 2); +}); + +test('hides duplicated trigger title when badge already represents the plan', () => { + assert.equal( + shouldShowDashboardSubscriptionTriggerTitle({ + titleText: 'Max', + badgeText: 'MAX', + }), + false, + ); + + assert.equal( + shouldShowDashboardSubscriptionTriggerTitle({ + titleText: 'Max 月订阅', + badgeText: 'MAX', + }), + false, + ); +}); + +test('builds a summary from subscription self payload with derived extra count', () => { + const vm = buildDashboardSubscriptionSummaryFromPayload( + { + primary_subscription: { + plan: { + title: 'Max', + }, + subscription: { + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + active_subscription_count: 3, + }, + { + quotaPerUnit: 10, + timeZone: 'UTC', + }, + ); + + assert.equal(vm.summaryText, 'Max $2.30 / $80.00 · 04-15 00:00 刷新 · +2'); + assert.equal(vm.extraText, '+2'); + assert.equal(vm.quotaText, '$2.30 / $80.00'); +}); + +test('orders popover rows with the primary subscription first', () => { + const rows = buildDashboardSubscriptionPopoverRowsFromPayload( + { + primary_subscription: { + plan: { + title: 'Max', + }, + subscription: { + id: 200, + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + subscriptions: [ + { + plan: { + title: 'Daily', + }, + subscription: { + id: 201, + amount_used: 12, + amount_total: 100, + next_reset_time: Date.UTC(2026, 3, 16, 0, 0, 0) / 1000, + }, + }, + { + plan: { + title: 'Max', + }, + subscription: { + id: 200, + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + ], + active_subscription_count: 2, + }, + { + quotaPerUnit: 10, + timeZone: 'UTC', + }, + ); + + assert.equal(rows[0].titleText, 'Max'); + assert.equal(rows[0].summaryText, 'Max $2.30 / $80.00 · 04-15 00:00 刷新'); + assert.equal(rows[0].progressPercent, 2.875); + assert.equal(rows[0].displayProgressPercent, 2.875); + assert.equal(rows[1].titleText, 'Daily'); +}); + +test('builds dashboard subscription display payload with summary and ordered rows', () => { + const display = buildDashboardSubscriptionDisplayFromPayload( + { + primary_subscription: { + plan: { + title: 'Max', + }, + subscription: { + id: 200, + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + subscriptions: [ + { + plan: { + title: 'Daily', + }, + subscription: { + id: 201, + amount_used: 12, + amount_total: 100, + next_reset_time: Date.UTC(2026, 3, 16, 0, 0, 0) / 1000, + }, + }, + ], + active_subscription_count: 2, + }, + { + quotaPerUnit: 10, + timeZone: 'UTC', + }, + ); + + assert.equal( + display.summary.summaryText, + 'Max $2.30 / $80.00 · 04-15 00:00 刷新 · +1', + ); + assert.equal(display.rows[0].titleText, 'Max'); + assert.equal(display.rows[0].isPrimary, true); + assert.equal(display.rows[1].titleText, 'Daily'); +}); + +test('omits extra count when it is zero', () => { + const vm = buildDashboardSubscriptionSummaryViewModel( + { + plan: { + title: 'Max', + }, + subscription: { + amount_used: 23, + amount_total: 800, + next_reset_time: Date.UTC(2026, 3, 15, 0, 0, 0) / 1000, + }, + }, + { + quotaPerUnit: 10, + extraCount: 0, + timeZone: 'UTC', + }, + ); + + assert.equal(vm.summaryText, 'Max $2.30 / $80.00 · 04-15 00:00 刷新'); + assert.equal(vm.extraText, ''); +}); + +test('builds popover rows for unlimited plans without reset time', () => { + const rows = buildDashboardSubscriptionPopoverRows([ + { + plan: { + title: 'Daily', + }, + subscription: { + amount_used: 125, + amount_total: 0, + next_reset_time: 0, + }, + }, + ], { + quotaPerUnit: 10, + }); + + assert.equal(rows.length, 1); + assert.deepEqual(rows[0], { + titleText: 'Daily', + badgeText: 'DAILY', + usedAmountText: '$12.50', + totalAmountText: '∞', + quotaText: '$12.50 / ∞', + resetText: '有效期总额度', + extraText: '', + showProgress: false, + progressPercent: 0, + displayProgressPercent: 0, + }); +}); + +test('formats reset time using mm-dd hh:mm', () => { + const resetText = formatDashboardSubscriptionResetTime( + Date.UTC(2026, 11, 5, 9, 7, 0) / 1000, + { + timeZone: 'UTC', + }, + ); + + assert.equal(resetText, '12-05 09:07'); +}); diff --git a/web/src/helpers/rechargeAccess.js b/web/src/helpers/rechargeAccess.js new file mode 100644 index 00000000000..b8a33e9f467 --- /dev/null +++ b/web/src/helpers/rechargeAccess.js @@ -0,0 +1,7 @@ +export function canAccessWalletManagement(user) { + return !!user && user.allow_recharge !== false; +} + +export function isRechargeRestricted(user) { + return !!user && user.allow_recharge === false; +} diff --git a/web/src/helpers/rechargeAccess.test.js b/web/src/helpers/rechargeAccess.test.js new file mode 100644 index 00000000000..ccbe0fbf655 --- /dev/null +++ b/web/src/helpers/rechargeAccess.test.js @@ -0,0 +1,56 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; + +import { + canAccessWalletManagement, + isRechargeRestricted, +} from './rechargeAccess.js'; + +test('wallet management is available for recharge-enabled users', () => { + assert.equal( + canAccessWalletManagement({ + id: 1, + allow_recharge: true, + }), + true, + ); +}); + +test('wallet management is hidden for recharge-restricted users', () => { + assert.equal( + canAccessWalletManagement({ + id: 2, + allow_recharge: false, + }), + false, + ); + assert.equal( + isRechargeRestricted({ + id: 2, + allow_recharge: false, + }), + true, + ); +}); + +test('wallet management is hidden safely when user data is missing', () => { + assert.equal(canAccessWalletManagement(null), false); + assert.equal(isRechargeRestricted(null), false); +}); + +test('wallet management stays available when allow_recharge is missing', () => { + assert.equal( + canAccessWalletManagement({ + id: 3, + username: 'legacy-user', + }), + true, + ); + assert.equal( + isRechargeRestricted({ + id: 3, + username: 'legacy-user', + }), + false, + ); +}); diff --git a/web/src/hooks/dashboard/useDashboardData.js b/web/src/hooks/dashboard/useDashboardData.js index e9b2cad83e7..db115034cd8 100644 --- a/web/src/hooks/dashboard/useDashboardData.js +++ b/web/src/hooks/dashboard/useDashboardData.js @@ -22,6 +22,8 @@ import { useNavigate } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { API, isAdmin, showError, timestamp2string } from '../../helpers'; import { getDefaultTime, getInitialTimestamp } from '../../helpers/dashboard'; +import { getQuotaPerUnit } from '../../helpers/quota'; +import { buildDashboardSubscriptionDisplayFromPayload } from '../../helpers/dashboardSubscriptionSummary'; import { TIME_OPTIONS } from '../../constants/dashboard.constants'; import { useIsMobile } from '../common/useIsMobile'; import { useMinimumLoadingTime } from '../common/useMinimumLoadingTime'; @@ -60,6 +62,8 @@ export const useDashboardData = (userState, userDispatch, statusState) => { const [pieData, setPieData] = useState([{ type: 'null', value: '0' }]); const [lineData, setLineData] = useState([]); const [modelColors, setModelColors] = useState({}); + const [dashboardSubscriptionSummary, setDashboardSubscriptionSummary] = + useState(null); // ========== 图表状态 ========== const [activeChartTab, setActiveChartTab] = useState('1'); @@ -234,6 +238,30 @@ export const useDashboardData = (userState, userDispatch, statusState) => { } }, [inputs, isAdminUser]); + const loadDashboardSubscriptionSummary = useCallback(async () => { + try { + const res = await API.get('/api/subscription/self'); + const { success, message, data } = res.data; + if (success) { + const display = buildDashboardSubscriptionDisplayFromPayload(data || {}, { + quotaPerUnit: getQuotaPerUnit(), + timeZone: Intl.DateTimeFormat().resolvedOptions().timeZone, + }); + setDashboardSubscriptionSummary( + display.summary?.summaryText ? display : null, + ); + return display; + } + showError(message); + setDashboardSubscriptionSummary(null); + return null; + } catch (err) { + console.error(err); + setDashboardSubscriptionSummary(null); + return null; + } + }, []); + const getUserData = useCallback(async () => { let res = await API.get(`/api/user/self`); const { success, message, data } = res.data; @@ -245,10 +273,13 @@ export const useDashboardData = (userState, userDispatch, statusState) => { }, [userDispatch]); const refresh = useCallback(async () => { - const data = await loadQuotaData(); + const [data] = await Promise.all([ + loadQuotaData(), + loadDashboardSubscriptionSummary(), + ]); await loadUptimeData(); return data; - }, [loadQuotaData, loadUptimeData]); + }, [loadQuotaData, loadDashboardSubscriptionSummary, loadUptimeData]); const handleSearchConfirm = useCallback( async (updateChartDataCallback) => { @@ -300,6 +331,8 @@ export const useDashboardData = (userState, userDispatch, statusState) => { setLineData, modelColors, setModelColors, + dashboardSubscriptionSummary, + loadDashboardSubscriptionSummary, // 图表状态 activeChartTab,