Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
line-length = 100 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
line-length = 120 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
target-version = "py313"
lint.select = [
"E", # pycodestyle errors
Expand Down
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@CLAUDE.md
27 changes: 3 additions & 24 deletions src/kernelbot/api/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]
user_name = user_json.get("username")

if not user_id or not user_name:
raise HTTPException(
status_code=500, detail="Failed to retrieve user ID or username from Discord."
)
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from Discord.")

return user_id, user_name

Expand Down Expand Up @@ -135,16 +133,12 @@ async def _handle_github_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
user_name = user_json.get("login") # GitHub uses 'login' for username

if not user_id or not user_name:
raise HTTPException(
status_code=500, detail="Failed to retrieve user ID or username from GitHub."
)
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from GitHub.")

return user_id, user_name


async def _run_submission(
submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend
):
async def _run_submission(submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend):
try:
req = prepare_submission(submission, backend)
except Exception as e:
Expand Down Expand Up @@ -225,21 +219,6 @@ async def to_submit_info(

try:
with db_context as db:
# Per-user rate limit: max 1 submission per hour on Modal B200 for leaderboard 730
if gpu_type == "B200":
lb_id = db.get_leaderboard_id(leaderboard_name)
if lb_id == 730:
last_submission_time = db.check_user_rate_limit(user_id)
if last_submission_time:
raise HTTPException(
status_code=429,
detail=(
f"Rate limit exceeded. You can submit once per hour. "
f"Last submission: {last_submission_time.isoformat()}. "
f"Consider using the NVIDIA runner instead of Modal for faster iteration."
),
)

leaderboard_item = db.get_leaderboard(leaderboard_name)
gpus = leaderboard_item.get("gpu_types", [])
if gpu_type not in gpus:
Expand Down
82 changes: 53 additions & 29 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

app = FastAPI()


def json_serializer(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
Expand Down Expand Up @@ -185,9 +186,7 @@ def require_admin(
@app.get("/auth/init")
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
if provider not in ["discord", "github"]:
raise HTTPException(
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
)
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")

"""
Initialize authentication flow for the specified provider.
Expand Down Expand Up @@ -230,9 +229,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
"""

if auth_provider not in ["discord", "github"]:
raise HTTPException(
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
)
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")

if not code or not state:
raise HTTPException(status_code=400, detail="Missing authorization code or state")
Expand All @@ -252,8 +249,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
if not api_base_url:
raise HTTPException(
status_code=500,
detail="Redirect URI base not configured."
"Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
detail="Redirect URI base not configured.Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
)
redirect_uri_base = api_base_url.rstrip("/")
redirect_uri = f"https://{redirect_uri_base}/auth/cli/{auth_provider}"
Expand All @@ -275,7 +271,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e

if not user_id or not user_name:
raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",)
raise HTTPException(
status_code=500,
detail="Failed to retrieve user ID or username from provider.",
)

try:
with db_context as db:
Expand All @@ -297,6 +296,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
"is_reset": is_reset,
}


async def _stream_submission_response(
submission_request: SubmissionRequest,
submission_mode_enum: SubmissionMode,
Expand All @@ -315,18 +315,18 @@ async def _stream_submission_response(

while not task.done():
elapsed_time = time.time() - start_time
yield f"event: status\ndata: {json.dumps({'status': 'processing',
'elapsed_time': round(elapsed_time, 2)},
default=json_serializer)}\n\n"
yield f"event: status\ndata: {
json.dumps({'status': 'processing', 'elapsed_time': round(elapsed_time, 2)}, default=json_serializer)
}\n\n"

try:
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
yield f"event: error\ndata: {json.dumps(
{'status': 'error', 'detail': 'Submission cancelled'},
default=json_serializer)}\n\n"
yield f"event: error\ndata: {
json.dumps({'status': 'error', 'detail': 'Submission cancelled'}, default=json_serializer)
}\n\n"
return

result, reports = await task
Expand Down Expand Up @@ -360,6 +360,7 @@ async def _stream_submission_response(
except asyncio.CancelledError:
pass


@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
async def run_submission( # noqa: C901
leaderboard_name: str,
Expand Down Expand Up @@ -398,27 +399,28 @@ async def run_submission( # noqa: C901
)
return StreamingResponse(generator, media_type="text/event-stream")


async def enqueue_background_job(
req: ProcessedSubmissionRequest,
mode: SubmissionMode,
backend: KernelBackend,
manager: BackgroundSubmissionManager,
):

# pre-create the submission for api returns
with backend.db as db:
sub_id = db.create_submission(
leaderboard=req.leaderboard,
file_name=req.file_name,
code=req.code,
user_id=req.user_id,
time=datetime.datetime.now(),
time=datetime.datetime.now(datetime.timezone.utc),
user_name=req.user_name,
)
job_id = db.upsert_submission_job_status(sub_id, "initial", None)
# put submission request in queue
await manager.enqueue(req, mode, sub_id)
return sub_id,job_id
return sub_id, job_id


@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}")
async def run_submission_async(
Expand All @@ -445,15 +447,13 @@ async def run_submission_async(
JSONResponse: A JSON response containing job_id and and submission_id for the client to poll for status.
"""
try:

await simple_rate_limit()
logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}")


# throw error if submission request is invalid
try:
submission_request, submission_mode_enum = await to_submit_info(
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
)

req = prepare_submission(submission_request, backend_instance)
Expand All @@ -466,13 +466,13 @@ async def run_submission_async(
raise HTTPException(status_code=400, detail="Invalid GPU type")

# put submission request to background manager to run in background
sub_id,job_status_id = await enqueue_background_job(
sub_id, job_status_id = await enqueue_background_job(
req, submission_mode_enum, backend_instance, background_submission_manager
)

return JSONResponse(
status_code=202,
content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
content={"details": {"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
)
# Preserve FastAPI HTTPException as-is
except HTTPException:
Expand Down Expand Up @@ -536,8 +536,7 @@ async def create_dev_leaderboard(
# GPUs must be specified in task.yml
if not definition.gpus:
raise HTTPException(
status_code=400,
detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
status_code=400, detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
)

with db_context as db:
Expand Down Expand Up @@ -629,7 +628,7 @@ async def admin_update_problems(
branch=branch,
force=force,
creator_id=0, # API-created
forum_id=-1, # No Discord forum
forum_id=-1, # No Discord forum
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
Expand All @@ -643,6 +642,33 @@ async def admin_update_problems(
}


@app.get("/leaderboard/rate-limits/{leaderboard_name}")
async def get_leaderboard_rate_limits(leaderboard_name: str, db_context=Depends(get_db)) -> dict:
with db_context as db:
rate_limits = db.get_leaderboard_rate_limits(leaderboard_name)
return {"status": "ok", "rate_limits": rate_limits}


@app.post("/leaderboard/rate-limits/{leaderboard_name}/{gpu_type}")
async def set_leaderboard_gpu_rate_limit(
leaderboard_name: str,
gpu_type: str,
rate_limit_seconds: int,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
if rate_limit_seconds <= 0:
rate_limit_seconds = None
with db_context as db:
db.set_leaderboard_gpu_rate_limit(leaderboard_name, gpu_type, rate_limit_seconds)
return {
"status": "ok",
"leaderboard_name": leaderboard_name,
"gpu_type": gpu_type,
"rate_limit_seconds": rate_limit_seconds,
}


@app.get("/leaderboards")
async def get_leaderboards(db_context=Depends(get_db)):
"""An endpoint that returns all leaderboards.
Expand Down Expand Up @@ -692,9 +718,7 @@ async def get_submissions(
try:
with db_context as db:
# Add validation for leaderboard and GPU? Might be redundant if DB handles it.
return db.get_leaderboard_submissions(
leaderboard_name, gpu_name, limit=limit, offset=offset
)
return db.get_leaderboard_submissions(leaderboard_name, gpu_name, limit=limit, offset=offset)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching submissions: {e}") from e

Expand Down
Loading
Loading