From 6f30aad221246ab5cb05d49ae04591ef9f0c23d4 Mon Sep 17 00:00:00 2001 From: diodiogod Date: Wed, 4 Feb 2026 15:36:36 -0300 Subject: [PATCH] feat: Add Continue Training feature for completed jobs Adds a 'Continue Training' feature that allows users to continue training from completed jobs in two ways: 1. Resume Training - Continue from last checkpoint with same job name and step counter 2. Start Fresh from Weights - Clone job with new name using final checkpoint as pretrained weights Changes: - Added /api/jobs/[jobID]/continue endpoint supporting both resume and clone modes - Added ContinueTrainingModal component with intuitive mode selection UI - Updated JobActionBar to show Continue Training option for completed jobs - Added continueJob() utility function for API calls - Improved Modal component to use React Portal for consistent rendering - Fixed Modal to prevent accidental close when dragging text selection - Enhanced checkpoint detection in BaseSDTrainProcess to prioritize by step number The checkpoint detection now intelligently sorts by: 1. Final files without step numbers (highest priority) 2. Checkpoints with highest step number 3. Most recently modified files (fallback) This ensures correct checkpoint loading even when files are copied or moved. Fixes issues where: - Modal transparency varied depending on render location - Modal closed when dragging text selection outside bounds - Checkpoint detection failed with copied/moved files due to unreliable creation times --- jobs/process/BaseSDTrainProcess.py | 16 +- ui/src/app/api/jobs/[jobID]/continue/route.ts | 139 +++++++++++++ ui/src/components/ContinueTrainingModal.tsx | 184 ++++++++++++++++++ ui/src/components/JobActionBar.tsx | 37 +++- ui/src/components/Modal.tsx | 23 ++- ui/src/utils/jobs.ts | 23 ++- 6 files changed, 415 insertions(+), 7 deletions(-) create mode 100644 ui/src/app/api/jobs/[jobID]/continue/route.ts create mode 100644 ui/src/components/ContinueTrainingModal.tsx diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 91f62a5ac..ccbf7fefa 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -815,7 +815,21 @@ def get_latest_save_path(self, name=None, post=''): paths = [p for p in paths if '_cn' not in p] if len(paths) > 0: - latest_path = max(paths, key=os.path.getctime) + # Smart sorting: prioritize by step number in filename, fallback to ctime + import re + def get_sort_key(p): + # Extract step number from filename (e.g., "model_000500.safetensors" -> 500) + step_match = re.search(r'_(\d+)\.(safetensors|pt)$', p) + if step_match: + # Return tuple: (has_step_number, step_number, -ctime) + # Higher step numbers sort first, newer files break ties + return (True, int(step_match.group(1)), -os.path.getctime(p)) + else: + # Final files without step numbers (e.g., "model.safetensors") + # Sort these FIRST with priority 2, then by newest ctime + return (True, float('inf'), -os.path.getctime(p)) + + latest_path = max(paths, key=get_sort_key) if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None: # set pretrained lora path as load path if we do not have a checkpoint to resume from diff --git a/ui/src/app/api/jobs/[jobID]/continue/route.ts b/ui/src/app/api/jobs/[jobID]/continue/route.ts new file mode 100644 index 000000000..7767da889 --- /dev/null +++ b/ui/src/app/api/jobs/[jobID]/continue/route.ts @@ -0,0 +1,139 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; + +const prisma = new PrismaClient(); + +export async function POST(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + const body = await request.json(); + const { mode, newSteps, newName } = body; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + if (mode === 'resume') { + // Mode 1: Resume training - same job, increase steps, change status to stopped + // DO NOT set pretrained_lora_path - let Python code auto-detect checkpoint + // This ensures metadata (step count) is loaded correctly + const jobConfig = JSON.parse(job.job_config); + + // Update steps if provided + if (newSteps && newSteps > job.step) { + jobConfig.config.process[0].train.steps = newSteps; + } + + // Remove any pretrained_lora_path that might exist from previous clone operations + if (jobConfig.config.process[0].network?.pretrained_lora_path) { + delete jobConfig.config.process[0].network.pretrained_lora_path; + } + + // Update job to allow resumption + const updatedJob = await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'stopped', + stop: false, + info: 'Ready to resume - will auto-detect latest checkpoint', + job_config: JSON.stringify(jobConfig), + }, + }); + + console.log(`Job ${jobID} ready to resume with ${newSteps} steps`); + return NextResponse.json(updatedJob); + + } else if (mode === 'clone') { + // Mode 2: Clone with new name, using final checkpoint as pretrained_lora_path + const jobConfig = JSON.parse(job.job_config); + const oldName = jobConfig.config.name; + const finalName = newName || `${oldName}_continued`; + + // Update job name + jobConfig.config.name = finalName; + + // Update steps if provided + if (newSteps) { + jobConfig.config.process[0].train.steps = newSteps; + } + + // Find the latest checkpoint from the old job + const trainingFolder = jobConfig.config.process[0].training_folder; + const oldJobFolder = path.join(trainingFolder, oldName); + + let latestCheckpoint = null; + if (fs.existsSync(oldJobFolder)) { + const files = fs.readdirSync(oldJobFolder); + const checkpoints = files.filter(f => + f.startsWith(oldName) && + (f.endsWith('.safetensors') || f.endsWith('.pt')) + ); + + if (checkpoints.length > 0) { + // Smart sorting: Find the best checkpoint + // Priority: 1) Final file without step, 2) Highest step number, 3) Most recent + checkpoints.sort((a, b) => { + // Extract step number from filename (e.g., "lora_1_4000.safetensors" -> 4000) + const stepRegex = /_(\d+)\.(safetensors|pt)$/; + const aMatch = a.match(stepRegex); + const bMatch = b.match(stepRegex); + + const aHasStep = !!aMatch; + const bHasStep = !!bMatch; + + // If neither has step (both are final files like "lora_1.safetensors"), use modification time + if (!aHasStep && !bHasStep) { + const aPath = path.join(oldJobFolder, a); + const bPath = path.join(oldJobFolder, b); + return fs.statSync(bPath).mtime.getTime() - fs.statSync(aPath).mtime.getTime(); + } + + // Prefer files WITHOUT step numbers (final files) over checkpoints + if (!aHasStep && bHasStep) return -1; // a is final, prefer it + if (aHasStep && !bHasStep) return 1; // b is final, prefer it + + // Both have step numbers, use highest step + const aStep = parseInt(aMatch![1]); + const bStep = parseInt(bMatch![1]); + return bStep - aStep; + }); + latestCheckpoint = path.join(oldJobFolder, checkpoints[0]); + } + } + + // Set pretrained_lora_path to the latest checkpoint + if (latestCheckpoint) { + if (!jobConfig.config.process[0].network) { + jobConfig.config.process[0].network = {}; + } + jobConfig.config.process[0].network.pretrained_lora_path = latestCheckpoint; + } + + // Create new job + const newJob = await prisma.job.create({ + data: { + name: finalName, + gpu_ids: job.gpu_ids, + job_config: JSON.stringify(jobConfig), + status: 'stopped', + stop: false, + step: 0, + info: latestCheckpoint + ? `Starting from checkpoint: ${path.basename(latestCheckpoint)}` + : 'Starting fresh', + queue_position: 0, + }, + }); + + console.log(`Cloned job ${jobID} as ${newJob.id} with name ${finalName}`); + return NextResponse.json(newJob); + + } else { + return NextResponse.json({ error: 'Invalid mode' }, { status: 400 }); + } +} diff --git a/ui/src/components/ContinueTrainingModal.tsx b/ui/src/components/ContinueTrainingModal.tsx new file mode 100644 index 000000000..e0d956e1e --- /dev/null +++ b/ui/src/components/ContinueTrainingModal.tsx @@ -0,0 +1,184 @@ +import React, { useState } from 'react'; +import { Modal } from './Modal'; +import { Job } from '@prisma/client'; +import { getTotalSteps } from '@/utils/jobs'; + +interface ContinueTrainingModalProps { + isOpen: boolean; + onClose: () => void; + job: Job; + onContinue: (mode: 'resume' | 'clone', newSteps: number, newName?: string) => void; +} + +export const ContinueTrainingModal: React.FC = ({ + isOpen, + onClose, + job, + onContinue, +}) => { + const [mode, setMode] = useState<'resume' | 'clone'>('resume'); + const currentSteps = getTotalSteps(job); + const [newSteps, setNewSteps] = useState(currentSteps + 2000); + const [newName, setNewName] = useState(`${job.name}_continued`); + + const handleContinue = () => { + onContinue(mode, newSteps, mode === 'clone' ? newName : undefined); + onClose(); + }; + + return ( + +
+ {/* Mode Selection */} +
+ + + {/* Resume Option */} +
setMode('resume')} + > +
+ setMode('resume')} + className="mt-1 h-4 w-4 text-blue-500" + /> +
+

Resume Training

+

+ Continue from the last checkpoint with the same job name. Training will resume from + step {job.step} and continue to the new step count. +

+
+ + + + Keeps same name and continues from checkpoint +
+
+
+
+ + {/* Clone Option */} +
setMode('clone')} + > +
+ setMode('clone')} + className="mt-1 h-4 w-4 text-blue-500" + /> +
+

Start Fresh from Weights

+

+ Create a new job with a different name, using the final checkpoint as starting weights. + Training will start from step 0 with the loaded weights. +

+
+ + + + Creates new job with pretrained weights +
+
+
+
+
+ + {/* New Name (only for clone mode) */} + {mode === 'clone' && ( +
+ + setNewName(e.target.value)} + className="w-full rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-gray-100 focus:border-blue-500 focus:outline-none focus:ring-2 focus:ring-blue-500" + placeholder="Enter new job name" + /> +
+ )} + + {/* New Steps */} +
+ +
+ setNewSteps(parseInt(e.target.value) || 0)} + className="flex-1 rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-gray-100 focus:border-blue-500 focus:outline-none focus:ring-2 focus:ring-blue-500" + min={mode === 'resume' ? job.step : 0} + /> + {mode === 'resume' && ( +
+ Current: {job.step} / {currentSteps} +
+ )} +
+ {mode === 'resume' && newSteps <= job.step && ( +

+ Steps must be greater than current step ({job.step}) +

+ )} +
+ + {/* Action Buttons */} +
+ + +
+
+
+ ); +}; diff --git a/ui/src/components/JobActionBar.tsx b/ui/src/components/JobActionBar.tsx index 7917459cb..c61eb5f12 100644 --- a/ui/src/components/JobActionBar.tsx +++ b/ui/src/components/JobActionBar.tsx @@ -3,10 +3,12 @@ import { Eye, Trash2, Pen, Play, Pause, Cog, X } from 'lucide-react'; import { Button } from '@headlessui/react'; import { openConfirm } from '@/components/ConfirmModal'; import { Job } from '@prisma/client'; -import { startJob, stopJob, deleteJob, getAvaliableJobActions, markJobAsStopped } from '@/utils/jobs'; +import { startJob, stopJob, deleteJob, getAvaliableJobActions, markJobAsStopped, continueJob } from '@/utils/jobs'; import { startQueue } from '@/utils/queue'; import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react'; import { redirect } from 'next/navigation'; +import { useState } from 'react'; +import { ContinueTrainingModal } from './ContinueTrainingModal'; interface JobActionBarProps { job: Job; @@ -25,10 +27,24 @@ export default function JobActionBar({ hideView, autoStartQueue = false, }: JobActionBarProps) { - const { canStart, canStop, canDelete, canEdit, canRemoveFromQueue } = getAvaliableJobActions(job); + const { canStart, canStop, canDelete, canEdit, canRemoveFromQueue, canContinue } = getAvaliableJobActions(job); + const [showContinueModal, setShowContinueModal] = useState(false); if (!afterDelete) afterDelete = onRefresh; + const handleContinue = async (mode: 'resume' | 'clone', newSteps: number, newName?: string) => { + try { + const result = await continueJob(job.id, mode, newSteps, newName); + if (onRefresh) onRefresh(); + // If cloned, optionally redirect to the new job + if (mode === 'clone' && result) { + window.location.href = `/jobs/${result.id}`; + } + } catch (error) { + console.error('Error continuing job:', error); + } + }; + return (
{canStart && ( @@ -128,6 +144,16 @@ export default function JobActionBar({ Clone Job + {canContinue && ( + +
setShowContinueModal(true)} + > + Continue Training +
+
+ )}
+ + setShowContinueModal(false)} + job={job} + onContinue={handleContinue} + />
); } diff --git a/ui/src/components/Modal.tsx b/ui/src/components/Modal.tsx index 68dbf9d5f..d7d1176d9 100644 --- a/ui/src/components/Modal.tsx +++ b/ui/src/components/Modal.tsx @@ -1,4 +1,5 @@ import React, { Fragment, useEffect } from 'react'; +import { createPortal } from 'react-dom'; interface ModalProps { isOpen: boolean; @@ -19,6 +20,8 @@ export const Modal: React.FC = ({ size = 'md', closeOnOverlayClick = true, }) => { + const [mouseDownOnOverlay, setMouseDownOnOverlay] = React.useState(false); + // Close on ESC key press useEffect(() => { const handleEscKey = (e: KeyboardEvent) => { @@ -39,11 +42,21 @@ export const Modal: React.FC = ({ }; }, [isOpen, onClose]); - // Handle overlay click + // Track mouse down on overlay + const handleOverlayMouseDown = (e: React.MouseEvent) => { + if (e.target === e.currentTarget) { + setMouseDownOnOverlay(true); + } else { + setMouseDownOnOverlay(false); + } + }; + + // Handle overlay click - only close if both mousedown AND mouseup were on overlay const handleOverlayClick = (e: React.MouseEvent) => { - if (e.target === e.currentTarget && closeOnOverlayClick) { + if (e.target === e.currentTarget && closeOnOverlayClick && mouseDownOnOverlay) { onClose(); } + setMouseDownOnOverlay(false); }; if (!isOpen) return null; @@ -56,11 +69,12 @@ export const Modal: React.FC = ({ xl: 'max-w-4xl', }; - return ( + const modalContent = ( {/* Modal backdrop */}
= ({
); + + // Render modal at document root level using portal + return typeof document !== 'undefined' ? createPortal(modalContent, document.body) : null; }; diff --git a/ui/src/utils/jobs.ts b/ui/src/utils/jobs.ts index 8e4854585..69f3c4573 100644 --- a/ui/src/utils/jobs.ts +++ b/ui/src/utils/jobs.ts @@ -66,6 +66,26 @@ export const markJobAsStopped = (jobID: string) => { }); }; +export const continueJob = (jobID: string, mode: 'resume' | 'clone', newSteps?: number, newName?: string) => { + return new Promise((resolve, reject) => { + apiClient + .post(`/api/jobs/${jobID}/continue`, { + mode, + newSteps, + newName, + }) + .then(res => res.data) + .then(data => { + console.log('Job continued:', data); + resolve(data); + }) + .catch(error => { + console.error('Error continuing job:', error); + reject(error); + }); + }); +}; + export const getJobConfig = (job: Job) => { return JSON.parse(job.job_config) as JobConfig; }; @@ -82,7 +102,8 @@ export const getAvaliableJobActions = (job: Job) => { if (job.status === 'completed' && jobConfig.config.process[0].train.steps > job.step && !isStopping) { canStart = true; } - return { canDelete, canEdit, canStop, canStart, canRemoveFromQueue }; + const canContinue = job.status === 'completed' && !isStopping; + return { canDelete, canEdit, canStop, canStart, canRemoveFromQueue, canContinue }; }; export const getNumberOfSamples = (job: Job) => {