Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions controller/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,17 @@ func GetSubscriptionSelf(c *gin.Context) {
activeSubscriptions = []model.SubscriptionSummary{}
}

var primarySubscription any
if primary := model.SelectPrimarySubscriptionSummary(activeSubscriptions); primary != nil {
primarySubscription = primary
}

common.ApiSuccess(c, gin.H{
"billing_preference": pref,
"subscriptions": activeSubscriptions, // all active subscriptions
"all_subscriptions": allSubscriptions, // all subscriptions including expired
"billing_preference": pref,
"subscriptions": activeSubscriptions, // all active subscriptions
"all_subscriptions": allSubscriptions, // all subscriptions including expired
"primary_subscription": primarySubscription,
"active_subscription_count": len(activeSubscriptions),
})
}

Expand Down
227 changes: 227 additions & 0 deletions controller/subscription_self_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package controller

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)

type subscriptionSelfAPIResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}

type subscriptionSelfPlan struct {
ID int `json:"id"`
Title string `json:"title"`
QuotaResetPeriod string `json:"quota_reset_period"`
QuotaResetCustomSeconds int64 `json:"quota_reset_custom_seconds"`
}

type subscriptionSelfSummary struct {
Subscription *model.UserSubscription `json:"subscription"`
Plan *subscriptionSelfPlan `json:"plan"`
}

type subscriptionSelfResponseData struct {
BillingPreference string `json:"billing_preference"`
Subscriptions []subscriptionSelfSummary `json:"subscriptions"`
AllSubscriptions []subscriptionSelfSummary `json:"all_subscriptions"`
PrimarySubscription *subscriptionSelfSummary `json:"primary_subscription"`
ActiveSubscriptionCount int `json:"active_subscription_count"`
}

func setupSubscriptionSelfControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()

gin.SetMode(gin.TestMode)
common.UsingSQLite = true
common.UsingMySQL = false
common.UsingPostgreSQL = false
common.RedisEnabled = false

dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_"))
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
require.NoError(t, err)

model.DB = db
model.LOG_DB = db

require.NoError(t, db.AutoMigrate(&model.User{}, &model.SubscriptionPlan{}, &model.UserSubscription{}))

t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})

return db
}

func seedSubscriptionSelfUser(t *testing.T, db *gorm.DB, id int) *model.User {
t.Helper()

user := &model.User{
Id: id,
Username: fmt.Sprintf("user-%d", id),
Password: "password123",
Status: 1,
Role: common.RoleCommonUser,
Group: "default",
}
require.NoError(t, db.Create(user).Error)
return user
}

func seedSubscriptionSelfPlan(t *testing.T, db *gorm.DB, plan model.SubscriptionPlan) *model.SubscriptionPlan {
t.Helper()

require.NoError(t, db.Create(&plan).Error)
return &plan
}

func seedSubscriptionSelfSubscription(t *testing.T, db *gorm.DB, sub model.UserSubscription) *model.UserSubscription {
t.Helper()

require.NoError(t, db.Create(&sub).Error)
return &sub
}

func newSubscriptionSelfContext(t *testing.T, userID int) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()

recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/api/subscription/self", bytes.NewReader(nil))
ctx.Set("id", userID)
return ctx, recorder
}

func decodeSubscriptionSelfResponse(t *testing.T, recorder *httptest.ResponseRecorder) subscriptionSelfAPIResponse {
t.Helper()

var response subscriptionSelfAPIResponse
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &response))
return response
}

func decodeSubscriptionSelfData(t *testing.T, response subscriptionSelfAPIResponse) subscriptionSelfResponseData {
t.Helper()

var data subscriptionSelfResponseData
require.NoError(t, common.Unmarshal(response.Data, &data))
return data
}

func TestGetSubscriptionSelfReturnsPrimarySubscriptionCountAndPlanTitle(t *testing.T) {
db := setupSubscriptionSelfControllerTestDB(t)
seedSubscriptionSelfUser(t, db, 1)
now := common.GetTimestamp()

planMonthly := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{
Id: 101,
Title: "Max 月订阅",
TotalAmount: 80000000,
QuotaResetPeriod: model.SubscriptionResetMonthly,
QuotaResetCustomSeconds: 0,
Enabled: true,
})

seedSubscriptionSelfSubscription(t, db, model.UserSubscription{
Id: 201,
UserId: 1,
PlanId: planMonthly.Id,
AmountTotal: 80000000,
AmountUsed: 2300000,
StartTime: now - 3600,
EndTime: now + 86400,
Status: "active",
Source: "order",
NextResetTime: now + 3600,
})

ctx, recorder := newSubscriptionSelfContext(t, 1)
GetSubscriptionSelf(ctx)

response := decodeSubscriptionSelfResponse(t, recorder)
require.True(t, response.Success)

data := decodeSubscriptionSelfData(t, response)
require.Equal(t, 1, data.ActiveSubscriptionCount)
require.NotNil(t, data.PrimarySubscription)
require.NotNil(t, data.PrimarySubscription.Plan)
require.Equal(t, "Max 月订阅", data.PrimarySubscription.Plan.Title)
require.Equal(t, "Max 月订阅", data.Subscriptions[0].Plan.Title)
require.Equal(t, "Max 月订阅", data.AllSubscriptions[0].Plan.Title)
}

func TestGetSubscriptionSelfSkipsExhaustedLimitedSubscription(t *testing.T) {
db := setupSubscriptionSelfControllerTestDB(t)
seedSubscriptionSelfUser(t, db, 1)
now := common.GetTimestamp()

exhaustedPlan := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{
Id: 201,
Title: "Daily Exhausted",
TotalAmount: 1000,
QuotaResetPeriod: model.SubscriptionResetDaily,
Enabled: true,
})
activePlan := seedSubscriptionSelfPlan(t, db, model.SubscriptionPlan{
Id: 202,
Title: "Monthly Available",
TotalAmount: 8000,
QuotaResetPeriod: model.SubscriptionResetMonthly,
Enabled: true,
})

seedSubscriptionSelfSubscription(t, db, model.UserSubscription{
Id: 301,
UserId: 1,
PlanId: exhaustedPlan.Id,
AmountTotal: 1000,
AmountUsed: 1000,
StartTime: now - 7200,
EndTime: now + 3600,
Status: "active",
Source: "order",
})
seedSubscriptionSelfSubscription(t, db, model.UserSubscription{
Id: 302,
UserId: 1,
PlanId: activePlan.Id,
AmountTotal: 8000,
AmountUsed: 2500,
StartTime: now - 7200,
EndTime: now + 7200,
Status: "active",
Source: "order",
})

ctx, recorder := newSubscriptionSelfContext(t, 1)
GetSubscriptionSelf(ctx)

response := decodeSubscriptionSelfResponse(t, recorder)
require.True(t, response.Success)

data := decodeSubscriptionSelfData(t, response)
require.Equal(t, 2, data.ActiveSubscriptionCount)
require.NotNil(t, data.PrimarySubscription)
require.NotNil(t, data.PrimarySubscription.Plan)
require.Equal(t, "Monthly Available", data.PrimarySubscription.Plan.Title)
require.Equal(t, int64(8000), data.PrimarySubscription.Subscription.AmountTotal)
require.Equal(t, int64(2500), data.PrimarySubscription.Subscription.AmountUsed)
}
10 changes: 9 additions & 1 deletion controller/topup.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ import (
)

func GetTopUpInfo(c *gin.Context) {
if !ensureRechargeAllowed(c) {
return
}
// 获取支付方式
payMethods := operation_setting.PayMethods

Expand Down Expand Up @@ -164,6 +167,9 @@ func getMinTopup() int64 {
}

func RequestEpay(c *gin.Context) {
if !ensureRechargeAllowed(c) {
return
}
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
Expand Down Expand Up @@ -366,6 +372,9 @@ func EpayNotify(c *gin.Context) {
}

func RequestAmount(c *gin.Context) {
if !ensureRechargeAllowed(c) {
return
}
var req AmountRequest
err := c.ShouldBindJSON(&req)
if err != nil {
Expand Down Expand Up @@ -463,4 +472,3 @@ func AdminCompleteTopUp(c *gin.Context) {
}
common.ApiSuccess(c, nil)
}

21 changes: 21 additions & 0 deletions controller/topup_access.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package controller

import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)

func ensureRechargeAllowed(c *gin.Context) bool {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return false
}
if !user.AllowRecharge {
common.ApiErrorMsg(c, "当前账户不支持充值")
return false
}
return true
}
71 changes: 71 additions & 0 deletions controller/topup_access_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package controller

import (
"net/http"
"testing"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/stretchr/testify/require"
)

func TestGetTopUpInfoRejectsRechargeRestrictedUser(t *testing.T) {
db := setupUserRechargeControllerTestDB(t)
user := seedRechargeUser(t, db, model.User{
Username: "topup-disabled-user",
Password: "password123",
DisplayName: "Topup Disabled",
Role: common.RoleCommonUser,
Status: common.UserStatusEnabled,
Group: "default",
AllowRecharge: false,
})

ctx, recorder := newUserRechargeContext(t, http.MethodGet, "/api/user/topup/info", nil, user.Id, common.RoleCommonUser)
GetTopUpInfo(ctx)

response := decodeUserRechargeResponse(t, recorder)
require.False(t, response.Success)
require.Equal(t, "当前账户不支持充值", response.Message)
}

func TestRequestAmountRejectsRechargeRestrictedUser(t *testing.T) {
db := setupUserRechargeControllerTestDB(t)
user := seedRechargeUser(t, db, model.User{
Username: "amount-disabled-user",
Password: "password123",
DisplayName: "Amount Disabled",
Role: common.RoleCommonUser,
Status: common.UserStatusEnabled,
Group: "default",
AllowRecharge: false,
})

ctx, recorder := newUserRechargeContext(t, http.MethodPost, "/api/user/amount", map[string]any{
"amount": 1,
}, user.Id, common.RoleCommonUser)
RequestAmount(ctx)

response := decodeUserRechargeResponse(t, recorder)
require.False(t, response.Success)
require.Equal(t, "当前账户不支持充值", response.Message)
}

func TestGetTopUpInfoAllowsRechargeEnabledUser(t *testing.T) {
db := setupUserRechargeControllerTestDB(t)
user := seedRechargeUser(t, db, model.User{
Username: "topup-enabled-user",
Password: "password123",
DisplayName: "Topup Enabled",
Role: common.RoleCommonUser,
Status: common.UserStatusEnabled,
Group: "default",
AllowRecharge: true,
})

ctx, recorder := newUserRechargeContext(t, http.MethodGet, "/api/user/topup/info", nil, user.Id, common.RoleCommonUser)
GetTopUpInfo(ctx)

response := decodeUserRechargeResponse(t, recorder)
require.True(t, response.Success, response.Message)
}
3 changes: 3 additions & 0 deletions controller/topup_creem.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) {
}

func RequestCreemPay(c *gin.Context) {
if !ensureRechargeAllowed(c) {
return
}
var req CreemPayRequest

// 读取body内容用于打印,同时保留原始数据供后续使用
Expand Down
Loading