Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,21 @@ def get_latest_save_path(self, name=None, post=''):
paths = [p for p in paths if '_cn' not in p]

if len(paths) > 0:
latest_path = max(paths, key=os.path.getctime)
# Smart sorting: prioritize by step number in filename, fallback to ctime
import re
def get_sort_key(p):
    """Rank a checkpoint path so that ``max(paths, key=get_sort_key)`` picks the best one.

    Ordering, from most to least preferred:
      1. Final files without a step suffix (e.g. "model.safetensors").
      2. Step checkpoints, highest step first (e.g. "model_000500.safetensors" -> 500).
      3. Within a tie, the file with the newest ctime.

    Returns a tuple ``(is_final, step, ctime)``. Every component is "bigger is
    better" because the caller selects with ``max()``; the ctime must therefore
    NOT be negated, otherwise the oldest file would win a tie.

    NOTE(review): the step pattern is matched against the whole path, so a final
    file whose base name itself ends in "_<digits>" (e.g. "lora_1.safetensors")
    is read as step 1 — confirm whether such names can occur here.
    """
    step_match = re.search(r'_(\d+)\.(safetensors|pt)$', p)
    ctime = os.path.getctime(p)
    if step_match:
        # Intermediate checkpoint: ranked below final files, ordered by step number.
        return (False, int(step_match.group(1)), ctime)
    # Final file without a step number: always outranks step checkpoints.
    return (True, 0, ctime)

latest_path = max(paths, key=get_sort_key)

if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None:
# set pretrained lora path as load path if we do not have a checkpoint to resume from
Expand Down
139 changes: 139 additions & 0 deletions ui/src/app/api/jobs/[jobID]/continue/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';

// Prisma client created once when this module is loaded; the POST handler in this
// file uses it to look up, update and create job rows.
const prisma = new PrismaClient();

/**
 * Picks the checkpoint file in `jobFolder` that a cloned job should start from.
 *
 * Only files whose name starts with `jobName` and ends in `.safetensors` or `.pt`
 * are considered. Preference order:
 *   1. final files without a step suffix (newest modification time first),
 *   2. step checkpoints, highest step first (newest modification time on a tie).
 *
 * The step suffix is looked for only in the part of the filename that follows
 * `jobName`, so a job called "lora_1" does not have its final file
 * "lora_1.safetensors" mistaken for a step-1 checkpoint.
 *
 * @returns the full path of the best checkpoint, or `null` when the folder does
 *          not exist or holds no matching file.
 */
function findLatestCheckpoint(jobFolder: string, jobName: string): string | null {
  if (!fs.existsSync(jobFolder)) {
    return null;
  }

  const stepSuffix = /_(\d+)\.(safetensors|pt)$/;
  const candidates = fs
    .readdirSync(jobFolder)
    .filter(file => file.startsWith(jobName) && (file.endsWith('.safetensors') || file.endsWith('.pt')))
    .map(file => {
      const match = stepSuffix.exec(file.slice(jobName.length));
      return {
        file,
        step: match ? parseInt(match[1], 10) : null,
        // Stat each file once up front instead of on every comparison.
        mtimeMs: fs.statSync(path.join(jobFolder, file)).mtimeMs,
      };
    });

  candidates.sort((a, b) => {
    if (a.step === null && b.step === null) return b.mtimeMs - a.mtimeMs;
    if (a.step === null) return -1; // a is a final file, prefer it
    if (b.step === null) return 1; // b is a final file, prefer it
    return b.step - a.step || b.mtimeMs - a.mtimeMs;
  });

  const best = candidates[0];
  return best ? path.join(jobFolder, best.file) : null;
}

/**
 * POST handler for continuing a training job.
 *
 * Request body (JSON):
 *   - `mode`: `'resume'` or `'clone'`.
 *   - `newSteps` (optional): positive integer total step count. `0`, `null` or a
 *     missing value means "leave the configured step count unchanged".
 *   - `newName` (optional, clone only): name for the new job; defaults to
 *     `<old name>_continued`.
 *
 * `resume` updates the existing job: raises `train.steps` when `newSteps` exceeds
 * the job's current step, removes any `network.pretrained_lora_path`, and marks
 * the job as `stopped` so it can be started again.
 *
 * `clone` creates a new job row from the old config, pointing
 * `network.pretrained_lora_path` at the best checkpoint found in the old job's
 * training folder (if any).
 *
 * Responses: 200 with the updated/created job, 400 for an invalid body, mode or
 * step count, 404 when the job does not exist, 500 when the stored job config
 * cannot be used.
 */
export async function POST(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  let body: unknown;
  try {
    body = await request.json();
  } catch {
    return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 });
  }
  if (typeof body !== 'object' || body === null) {
    return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 });
  }
  const { mode, newSteps, newName } = body as { mode?: unknown; newSteps?: unknown; newName?: unknown };

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  if (mode !== 'resume' && mode !== 'clone') {
    return NextResponse.json({ error: 'Invalid mode' }, { status: 400 });
  }

  // A missing, null or zero step count means "keep the configured value";
  // anything else must be a positive integer before it is written to the config.
  let requestedSteps: number | undefined;
  if (newSteps !== undefined && newSteps !== null && newSteps !== 0) {
    if (typeof newSteps !== 'number' || !Number.isInteger(newSteps) || newSteps < 0) {
      return NextResponse.json({ error: 'newSteps must be a positive integer' }, { status: 400 });
    }
    requestedSteps = newSteps;
  }

  let jobConfig;
  try {
    jobConfig = JSON.parse(job.job_config);
  } catch {
    return NextResponse.json({ error: 'Stored job config is not valid JSON' }, { status: 500 });
  }
  const processConfig = jobConfig?.config?.process?.[0];
  if (!processConfig) {
    return NextResponse.json({ error: 'Stored job config has no process entry' }, { status: 500 });
  }
  if (!processConfig.train) {
    processConfig.train = {};
  }

  if (mode === 'resume') {
    // Mode 1: Resume training - same job, increase steps, change status to stopped
    // DO NOT set pretrained_lora_path - let Python code auto-detect checkpoint
    // This ensures metadata (step count) is loaded correctly
    if (requestedSteps !== undefined && requestedSteps > job.step) {
      processConfig.train.steps = requestedSteps;
    }

    // Remove any pretrained_lora_path that might exist from previous clone operations
    if (processConfig.network?.pretrained_lora_path) {
      delete processConfig.network.pretrained_lora_path;
    }

    // Update job to allow resumption
    const updatedJob = await prisma.job.update({
      where: { id: jobID },
      data: {
        status: 'stopped',
        stop: false,
        info: 'Ready to resume - will auto-detect latest checkpoint',
        job_config: JSON.stringify(jobConfig),
      },
    });

    // Log the step count actually stored, which may differ from the requested one.
    console.log(`Job ${jobID} ready to resume with ${processConfig.train.steps} steps`);
    return NextResponse.json(updatedJob);
  }

  // Mode 2: Clone with new name, using final checkpoint as pretrained_lora_path
  const oldName = jobConfig.config.name;
  const requestedName = typeof newName === 'string' ? newName.trim() : '';
  const finalName = requestedName || `${oldName}_continued`;

  jobConfig.config.name = finalName;

  if (requestedSteps !== undefined) {
    processConfig.train.steps = requestedSteps;
  }

  // Find the latest checkpoint from the old job
  const trainingFolder = processConfig.training_folder;
  const latestCheckpoint =
    typeof trainingFolder === 'string' && typeof oldName === 'string'
      ? findLatestCheckpoint(path.join(trainingFolder, oldName), oldName)
      : null;

  // Set pretrained_lora_path to the latest checkpoint
  if (latestCheckpoint) {
    if (!processConfig.network) {
      processConfig.network = {};
    }
    processConfig.network.pretrained_lora_path = latestCheckpoint;
  }

  // Create new job
  const newJob = await prisma.job.create({
    data: {
      name: finalName,
      gpu_ids: job.gpu_ids,
      job_config: JSON.stringify(jobConfig),
      status: 'stopped',
      stop: false,
      step: 0,
      info: latestCheckpoint
        ? `Starting from checkpoint: ${path.basename(latestCheckpoint)}`
        : 'Starting fresh',
      queue_position: 0,
    },
  });

  console.log(`Cloned job ${jobID} as ${newJob.id} with name ${finalName}`);
  return NextResponse.json(newJob);
}
184 changes: 184 additions & 0 deletions ui/src/components/ContinueTrainingModal.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import React, { useState } from 'react';
import { Modal } from './Modal';
import { Job } from '@prisma/client';
import { getTotalSteps } from '@/utils/jobs';

/** Props accepted by the ContinueTrainingModal component. */
interface ContinueTrainingModalProps {
  /** Forwarded to the underlying Modal to control whether it is shown. */
  isOpen: boolean;
  /** Passed to the Modal, bound to the Cancel button, and called right after `onContinue` on confirm. */
  onClose: () => void;
  /** Job being continued; its name, current step and total steps seed the form. */
  job: Job;
  /**
   * Called when the user confirms, with the selected mode, the entered total step
   * count, and the new job name (clone mode only; `undefined` in resume mode).
   */
  onContinue: (mode: 'resume' | 'clone', newSteps: number, newName?: string) => void;
}

/**
 * Modal dialog that lets the user continue a training job in one of two modes:
 *
 * - `resume`: keep the same job and raise its total step count. The step count
 *   must be greater than the job's current step.
 * - `clone`: create a new job (with an editable name) that starts from the old
 *   job's weights. The step count must be at least 1.
 *
 * On confirm it calls `onContinue(mode, newSteps, newName?)` followed by
 * `onClose()`. The confirm button is disabled, and the handler is a no-op, while
 * the entered step count is invalid for the selected mode.
 */
export const ContinueTrainingModal: React.FC<ContinueTrainingModalProps> = ({
  isOpen,
  onClose,
  job,
  onContinue,
}) => {
  const [mode, setMode] = useState<'resume' | 'clone'>('resume');
  const currentSteps = getTotalSteps(job);
  const [newSteps, setNewSteps] = useState(currentSteps + 2000);
  const [newName, setNewName] = useState(`${job.name}_continued`);

  // Smallest acceptable total: one past the current step when resuming,
  // otherwise any positive step count.
  const minSteps = mode === 'resume' ? job.step + 1 : 1;
  const stepsValid = newSteps >= minSteps;

  const handleContinue = () => {
    // Guard again here so the action cannot be triggered while the input is invalid.
    if (!stepsValid) {
      return;
    }
    // An empty name is sent as undefined so the caller can apply its own default.
    const trimmedName = newName.trim();
    onContinue(mode, newSteps, mode === 'clone' ? trimmedName || undefined : undefined);
    onClose();
  };

  return (
    <Modal isOpen={isOpen} onClose={onClose} title="Continue Training" size="lg">
      <div className="space-y-6">
        {/* Mode Selection */}
        <div className="space-y-3">
          <label className="block text-sm font-medium text-gray-200">Continue Mode</label>

          {/* Resume Option */}
          <div
            className={`cursor-pointer rounded-lg border-2 p-4 transition-colors ${
              mode === 'resume'
                ? 'border-blue-500 bg-blue-500/10'
                : 'border-gray-700 bg-gray-800 hover:border-gray-600'
            }`}
            onClick={() => setMode('resume')}
          >
            <div className="flex items-start">
              <input
                type="radio"
                name="mode"
                checked={mode === 'resume'}
                onChange={() => setMode('resume')}
                className="mt-1 h-4 w-4 text-blue-500"
              />
              <div className="ml-3">
                <h4 className="text-base font-semibold text-gray-100">Resume Training</h4>
                <p className="mt-1 text-sm text-gray-400">
                  Continue from the last checkpoint with the same job name. Training will resume from
                  step {job.step} and continue to the new step count.
                </p>
                <div className="mt-2 flex items-center space-x-2 text-xs text-gray-500">
                  <svg
                    className="h-4 w-4"
                    fill="none"
                    stroke="currentColor"
                    viewBox="0 0 24 24"
                    xmlns="http://www.w3.org/2000/svg"
                  >
                    <path
                      strokeLinecap="round"
                      strokeLinejoin="round"
                      strokeWidth={2}
                      d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"
                    />
                  </svg>
                  <span>Keeps same name and continues from checkpoint</span>
                </div>
              </div>
            </div>
          </div>

          {/* Clone Option */}
          <div
            className={`cursor-pointer rounded-lg border-2 p-4 transition-colors ${
              mode === 'clone'
                ? 'border-blue-500 bg-blue-500/10'
                : 'border-gray-700 bg-gray-800 hover:border-gray-600'
            }`}
            onClick={() => setMode('clone')}
          >
            <div className="flex items-start">
              <input
                type="radio"
                name="mode"
                checked={mode === 'clone'}
                onChange={() => setMode('clone')}
                className="mt-1 h-4 w-4 text-blue-500"
              />
              <div className="ml-3">
                <h4 className="text-base font-semibold text-gray-100">Start Fresh from Weights</h4>
                <p className="mt-1 text-sm text-gray-400">
                  Create a new job with a different name, using the final checkpoint as starting weights.
                  Training will start from step 0 with the loaded weights.
                </p>
                <div className="mt-2 flex items-center space-x-2 text-xs text-gray-500">
                  <svg
                    className="h-4 w-4"
                    fill="none"
                    stroke="currentColor"
                    viewBox="0 0 24 24"
                    xmlns="http://www.w3.org/2000/svg"
                  >
                    <path
                      strokeLinecap="round"
                      strokeLinejoin="round"
                      strokeWidth={2}
                      d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
                    />
                  </svg>
                  <span>Creates new job with pretrained weights</span>
                </div>
              </div>
            </div>
          </div>
        </div>

        {/* New Name (only for clone mode) */}
        {mode === 'clone' && (
          <div>
            <label htmlFor="continue-training-name" className="block text-sm font-medium text-gray-200 mb-2">
              New Job Name
            </label>
            <input
              id="continue-training-name"
              type="text"
              value={newName}
              onChange={e => setNewName(e.target.value)}
              className="w-full rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-gray-100 focus:border-blue-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
              placeholder="Enter new job name"
            />
          </div>
        )}

        {/* New Steps */}
        <div>
          <label htmlFor="continue-training-steps" className="block text-sm font-medium text-gray-200 mb-2">
            {mode === 'resume' ? 'Total Steps (increase to continue)' : 'Total Steps for New Training'}
          </label>
          <div className="flex items-center space-x-3">
            <input
              id="continue-training-steps"
              type="number"
              value={newSteps}
              onChange={e => setNewSteps(parseInt(e.target.value, 10) || 0)}
              className="flex-1 rounded-lg border border-gray-700 bg-gray-800 px-4 py-2 text-gray-100 focus:border-blue-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
              min={minSteps}
            />
            {mode === 'resume' && (
              <div className="text-sm text-gray-400">
                Current: {job.step} / {currentSteps}
              </div>
            )}
          </div>
          {!stepsValid && (
            <p className="mt-1 text-xs text-red-400">
              {mode === 'resume'
                ? `Steps must be greater than current step (${job.step})`
                : 'Steps must be greater than 0'}
            </p>
          )}
        </div>

        {/* Action Buttons */}
        <div className="flex justify-end space-x-3 pt-4 border-t border-gray-700">
          <button
            type="button"
            onClick={onClose}
            className="rounded-lg border border-gray-600 px-4 py-2 text-sm font-medium text-gray-300 hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-gray-500"
          >
            Cancel
          </button>
          <button
            type="button"
            onClick={handleContinue}
            disabled={!stepsValid}
            className="rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 disabled:opacity-50 disabled:cursor-not-allowed"
          >
            {mode === 'resume' ? 'Resume Training' : 'Create & Start'}
          </button>
        </div>
      </div>
    </Modal>
  );
};
Loading