[infer] Add progress_callback to inference, judge, and synthesis APIs#2335

Draft
rlehman221 wants to merge 2 commits into `main` from `rlehman/progress-callback-inference`

Conversation

@rlehman221
Contributor

Summary

  • Adds an optional `progress_callback: Callable[[int, int], None]` parameter to `BaseInferenceEngine.infer()`, `BaseJudge.judge()`, and `AttributeSynthesizer.synthesize()`
  • The callback fires with `(completed_count, total_count)` after each item finishes, so callers can report granular progress during bulk operations
  • `RemoteInferenceEngine` fires the callback as each async task completes in its gather loop; the other engines accept the parameter but don't use it yet
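To make the contract concrete, here is a minimal stand-in for the callback shape. The real entry point is `BaseInferenceEngine.infer()`; this toy `infer` function only mimics the `(completed_count, total_count)` protocol described above:

```python
from typing import Callable, Optional

# Toy stand-in for the new parameter on BaseInferenceEngine.infer();
# only the callback signature comes from this PR, the rest is illustrative.
def infer(
    items: list,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> list:
    results = []
    total = len(items)
    for completed, item in enumerate(items, start=1):
        results.append(item.upper())  # stand-in for the actual inference work
        if progress_callback is not None:
            progress_callback(completed, total)  # (completed_count, total_count)
    return results

seen = []
infer(["a", "b"], progress_callback=lambda done, total: seen.append((done, total)))
print(seen)  # [(1, 2), (2, 2)]
```

Because the parameter defaults to `None`, existing callers are unaffected.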

Motivation

When the Oumi Enterprise worker calls bulk inference/judge/synthesis, it passes all rows at once. Without this callback, the worker can only update heartbeat progress at the start and end of the call — leaving users with no progress updates for potentially long-running operations.

Changes

| File | Change |
| --- | --- |
| `base_inference_engine.py` | `progress_callback` on `infer()` and `_infer_online()` |
| `remote_inference_engine.py` | Fires callback per task in async gather loop |
| `native_text_inference_engine.py` | Accepts param in `_infer_online()` |
| `vllm_inference_engine.py` | Accepts param in `_infer_online()` |
| `llama_cpp_inference_engine.py` | Accepts param in `_infer_online()` |
| `base_judge.py` | `progress_callback` on `judge()` and `_infer()`, forwarded to engine |
| `attribute_synthesizer.py` | `progress_callback` on `synthesize()`, forwarded to engine |

Design

  • Callback is optional (`None` default), so the change is fully backwards compatible
  • Thread-safe: `asyncio.gather()` runs all tasks on a single event-loop thread, so the `nonlocal` counter can't race
  • Best-effort: exceptions raised by the callback are caught and ignored, so a faulty callback can never crash inference
  • Engine-level: judge and synthesis simply forward the callback to the underlying inference engine
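The per-task-completion pattern described above can be sketched as follows. This is an illustrative sketch, not the actual `RemoteInferenceEngine` code; function and variable names here are assumptions:

```python
import asyncio
from typing import Callable, Optional

# Sketch of firing a progress callback per completed task inside a gather
# loop, with the safety properties listed in the Design bullets.
async def run_all(
    inputs: list,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> list:
    total = len(inputs)
    completed = 0

    async def run_one(x):
        nonlocal completed
        result = await asyncio.sleep(0, result=x * 2)  # stand-in for a remote API call
        completed += 1  # safe: all tasks share one event-loop thread
        if progress_callback is not None:
            try:
                progress_callback(completed, total)
            except Exception:
                pass  # best-effort: a failing callback must not crash inference
        return result

    return await asyncio.gather(*(run_one(x) for x in inputs))

calls = []
results = asyncio.run(run_all([1, 2, 3], progress_callback=lambda c, t: calls.append((c, t))))
```

Note that `asyncio.gather()` preserves input order in `results`, while the callback reflects completion order, which is why it reports counts rather than indices.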

Add an optional progress_callback parameter that fires (completed, total)
after each item is processed. This enables callers to report granular
progress during bulk operations instead of only at start/end.

Changes:
- BaseInferenceEngine.infer() and _infer_online(): new progress_callback param
- RemoteInferenceEngine._infer(): fires callback per async task completion
- BaseJudge.judge() and _infer(): thread callback to inference engine
- AttributeSynthesizer.synthesize(): thread callback to inference engine
- All other engine subclasses: accept param in _infer_online() signature

Tests:
- BaseInferenceEngine: test callback forwarding, None default, exception safety
- RemoteInferenceEngine: test callback fires per conversation, None works,
  exception in callback doesn't crash inference
- BaseJudge: test callback forwarded to engine, None default
- AttributeSynthesizer: test callback forwarded to engine, None default
- Fix existing tests for updated _infer_online signature (3rd arg)
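The exception-safety case listed above can be expressed as a small self-contained check. This is a hedged sketch of the test's idea against a stand-in function, not the repo's actual test code:

```python
# Stand-in mirroring the swallow-callback-errors behavior the tests cover;
# the real tests exercise BaseInferenceEngine, which this toy does not import.
def infer(items, progress_callback=None):
    out = []
    total = len(items)
    for completed, item in enumerate(items, start=1):
        out.append(item)
        if progress_callback is not None:
            try:
                progress_callback(completed, total)
            except Exception:
                pass  # callback errors are swallowed by design

    return out

def exploding_callback(done, total):
    raise RuntimeError("boom")

# Inference still completes even when every callback invocation raises.
result = infer([1, 2, 3], progress_callback=exploding_callback)
print(result)  # [1, 2, 3]
```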