Commit 3cd8601

Merge pull request #967 from OptimalScale/lmflow-vllm-dataproto
[Data] Apply DataProto to vLLM Inference & Align API with SGLang
2 parents f3597db + 7408430 commit 3cd8601

11 files changed

Lines changed: 418 additions & 312 deletions

examples/vllm_inference.py

Lines changed: 6 additions & 4 deletions
@@ -37,12 +37,14 @@ def main():
         pipeline_name=pipeline_name, model_args=model_args, data_args=data_args, pipeline_args=pipeline_args
     )
 
-    inferencer.inference(
+    # `release_gpu=True` does an in-process best-effort cleanup; it is
+    # sufficient for this standalone example. For colocated training+inference
+    # (e.g. iterative DPO) or tensor_parallel_size > 1, prefer
+    # `MemorySafeVLLMInferencer` instead.
+    res = inferencer.inference(
         model,
         dataset,
-        release_gpu=False,
-        enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
-        enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
+        release_gpu=True,
     )
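
With the DataProto-aligned API this PR introduces, the `res` captured above is expected to carry prompts and generations together (see `__vllm_inference` in hf_decoder_model.py below). A minimal sketch of inspecting it, assuming the vLLM inferencer pipeline returns that DataProto unchanged:

    # Hypothetical inspection of `res`; assumes it is the DataProto produced by
    # HFDecoderModel.__vllm_inference (field names taken from the diff below).
    for prompt, completion in zip(
        res.non_tensor_batch["inputs"], res.non_tensor_batch["outputs"]
    ):
        print(f"PROMPT: {prompt!r}\nOUTPUT: {completion!r}\n")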

scripts/run_sglang_inference.sh

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@ python examples/sglang_inference.py \
     --top_p 0.95 \
     --random_seed 42 \
     --save_inference_results True \
-    --inference_results_path output_data/sglang_inference_results/results.json
+    --inference_results_path output_data/sglang_inference_results/

scripts/run_vllm_inference.sh

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+python examples/vllm_inference.py \
+    --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
+    --dataset_path data/alpaca/prompt_only \
+    --inference_engine vllm \
+    --inference_gpu_memory_utilization 0.8 \
+    --inference_max_model_len 16384 \
+    --num_output_sequences 2 \
+    --temperature 1.0 \
+    --max_new_tokens 2048 \
+    --top_p 0.95 \
+    --random_seed 42 \
+    --save_inference_results True \
+    --inference_results_path output_data/vllm_inference_results/

setup.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 extra_require = {
     "multimodal": ["Pillow"],
-    "vllm": ["vllm>=0.4.3"],
+    "vllm": ["vllm>=0.8.0"],
     "sglang": ["sglang"],
     "ray": ["ray>=2.22.0"],
     "gradio": ["gradio"],

src/lmflow/args.py

Lines changed: 27 additions & 5 deletions
@@ -1041,9 +1041,27 @@ class InferencerArguments:
     inference_tensor_parallel_size: Optional[int] = field(
         default=1, metadata={"help": "The tensor parallel size for inference."}
     )
+    inference_data_parallel_size: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": (
+                "The data parallel size for inference. Only supported for vLLM (>= 0.8) inference engine. "
+                "Total GPUs used = tensor_parallel_size * data_parallel_size."
+            )
+        },
+    )
     inference_gpu_memory_utilization: Optional[float] = field(
         default=0.95, metadata={"help": "The GPU memory utilization for inference."}
     )
+    inference_max_model_len: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Maximum model context length for inference. If not set, uses the model's default. "
+                "Reduce this if the model's default exceeds available GPU memory."
+            )
+        },
+    )
     enable_deterministic_inference: bool = field(
         default=False,
         metadata={

@@ -1065,7 +1083,14 @@ class InferencerArguments:
     results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."})
 
     save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
-    inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
+    inference_results_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Directory to save inference results. Results are saved as 'inference_results.pkl' inside this directory."
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.use_accelerator is not None:

@@ -1087,10 +1112,7 @@ def __post_init__(self):
             if self.inference_results_path is None:
                 raise ValueError("Need to specify inference_results_path when save_inference_results is True.")
             else:
-                if not self.inference_results_path.endswith(".json"):
-                    raise ValueError("The inference_results_path must be a json file.")
-                else:
-                    Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True)
+                Path(self.inference_results_path).mkdir(parents=True, exist_ok=True)
 
         if self.use_vllm is True:
             logger.warning(
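
For reference, the two new parallelism/context fields and the directory-based results path combine like this. A hedged sketch that constructs the dataclass directly, assuming the remaining InferencerArguments fields (not shown in this diff) all have defaults:

    from lmflow.args import InferencerArguments

    pipeline_args = InferencerArguments(
        inference_tensor_parallel_size=2,
        inference_data_parallel_size=2,   # vLLM >= 0.8 only; total GPUs = 2 * 2 = 4
        inference_gpu_memory_utilization=0.8,
        inference_max_model_len=16384,    # cap the context if the model default does not fit
        save_inference_results=True,
        inference_results_path="output_data/vllm_inference_results/",  # a directory now, not a .json file
    )
    # __post_init__ now creates the directory itself; results are written to
    # output_data/vllm_inference_results/inference_results.pkl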

src/lmflow/models/hf_decoder_model.py

Lines changed: 35 additions & 58 deletions
@@ -37,10 +37,9 @@
     TEXT_ONLY_DATASET_DESCRIPTION,
 )
 from lmflow.utils.conversation_template import PRESET_TEMPLATES
-from lmflow.utils.data_utils import VLLMInferenceResultWithInput
 from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
-from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
+from lmflow.utils.versioning import is_flash_attn_available, is_vllm_available
 from lmflow.utils.protocol import DataProto
 
 logger = logging.getLogger(__name__)

@@ -54,10 +53,6 @@
 if is_vllm_available():
     from vllm import SamplingParams
 
-if is_ray_available():
-    import ray
-    import ray.data
-
 
 class HFDecoderModel(DecoderModel, HFModelMixin, Tunable):
     r"""

@@ -321,20 +316,21 @@ def inference(
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         gpu_memory_utilization: Optional[float] = None,
         tensor_parallel_size: Optional[int] = None,
+        data_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
         **kwargs,
-    ) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
+    ) -> Union[list, DataProto]:
         """
         Perform generation process of the model.
 
         Parameters
         ------------
         inputs : Union[str, list[str], torch.Tensor, DataProto]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When the inference engine is "vllm", this should be a string or a list of strings.
+            When the inference engine is "vllm" or "sglang", this should be a DataProto.
             When the inference engine is "huggingface", this should be a tensor.
-            When the inference engine is "sglang", this should be a DataProto.
         sampling_params : Optional[Union[dict, "SamplingParams"]], optional
             The sampling parameters to use, by default None.
         return_logprob : bool, optional

@@ -347,6 +343,10 @@ def inference(
             The GPU memory utilization to use, by default None.
         tensor_parallel_size : int, optional
             The tensor parallel size to use, by default None.
+        data_parallel_size : int, optional
+            The data parallel size for vllm inference, by default 1.
+        max_model_len : int, optional
+            Maximum model context length for vllm inference, by default None.
         enable_deterministic_inference : bool, optional
             Whether to enable deterministic inference, by default False.
         attention_backend : Optional[str], optional

@@ -365,12 +365,14 @@ def inference(
             inference_engine=inference_engine,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
+            data_parallel_size=data_parallel_size,
+            max_model_len=max_model_len,
             enable_deterministic_inference=enable_deterministic_inference,
             attention_backend=attention_backend,
         )
 
         if inference_engine == "vllm":
-            res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
+            res = self.__vllm_inference(inputs=inputs)
         elif inference_engine == "sglang":
             res = self.__sglang_inference(
                 inputs=inputs,

@@ -424,46 +426,29 @@ def __inference(self, inputs, *args, **kwargs):
 
     def __vllm_inference(
         self,
-        inputs: list[str],
-        sampling_params: Optional["SamplingParams"] = None,
-    ) -> list[VLLMInferenceResultWithInput]:
-        """Perform VLLM inference process of the model.
-
-        Parameters
-        ----------
-        inputs : list[str]
-            Prompt(s), string or a list of strings.
-        sampling_params : Optional[SamplingParams], optional
-            vllm SamplingParams object, by default None.
-
-        Returns
-        -------
-        list[VLLMInferenceResultWithInput]
-            Return a list of VLLMInferenceResultWithInput, where each
-            element contains the input prompt and the corresponding output.
-
-            When `sampling_params.detokenize = True`, the output would be a list of strings,
-            contains sampling_params.n samples for the corresponding prompt.
+        inputs: DataProto,
+    ) -> DataProto:
+        """Perform VLLM inference process of the model."""
+        prompts = inputs.non_tensor_batch["inputs"].tolist()
+        sampling_params_dict = inputs.meta_info["sampling_params"]
+
+        vllm_sampling_params = SamplingParams(
+            n=sampling_params_dict.get("n", 1),
+            temperature=sampling_params_dict.get("temperature", 0.0),
+            max_tokens=sampling_params_dict.get("max_new_tokens", 100),
+            seed=sampling_params_dict.get("seed"),
+            top_p=sampling_params_dict.get("top_p", 1.0),
+            top_k=sampling_params_dict.get("top_k", 0),
+            stop_token_ids=sampling_params_dict.get("stop_token_ids"),
+        )
 
-            When `sampling_params.detokenize = False`, return a list of list of ints
-            (token ids, no decoding after generation).
-        """
         vllm_outputs = self.backend_model_for_inference.generate(
-            inputs,
-            sampling_params=sampling_params,
+            prompts,
+            sampling_params=vllm_sampling_params,
             use_tqdm=True,
         )
-        # TODO: unified lmflow sample format
-        final_output = []
-        for output in vllm_outputs:
-            if sampling_params.detokenize:
-                output_list = [sentence.text for sentence in output.outputs]
-            else:
-                output_list = [sentence.token_ids for sentence in output.outputs]
-
-            final_output.append({"input": output.prompt, "output": output_list})
-
-        return final_output
+        inputs.non_tensor_batch["outputs"] = [output.outputs[0].text for output in vllm_outputs]
+        return inputs
 
     def __sglang_inference(
         self,

@@ -495,9 +480,8 @@ def prepare_inputs_for_inference(
         dataset: Dataset,
         apply_chat_template: bool = True,
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
-        enable_distributed_inference: bool = False,
         sampling_params: Optional[dict] = None,
-    ) -> Union[list[str], "ray.data.Dataset", DataProto]:
+    ) -> Union[list[str], DataProto]:
         if dataset.get_type() == "text_only":
             if apply_chat_template:
                 dataset = dataset.map(

@@ -572,24 +556,17 @@ def preprocess_conversation(sample):
 
         inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]
 
-        if inference_engine == "vllm" and enable_distributed_inference:
-            inference_inputs = ray.data.from_items(
-                inference_inputs
-            )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}
-
-        if inference_engine == "sglang":
+        if inference_engine in ("sglang", "vllm"):
             if self.tokenizer.bos_token:
-                # in consistent with sglang bench_serving.py demo
                 inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
 
-            # currently only test dataproto on sglang inference
             inference_inputs = np.array(inference_inputs)
             inference_inputs = DataProto.from_single_dict(
                 data={"inputs": inference_inputs},
                 meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]}
             )
-
-            # handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines.
+
+            # handling n>1 since we don't want one-to-many mapping
            inference_inputs = inference_inputs.repeat(sampling_params["n"])
 
        return inference_inputs
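
End to end, the vLLM path now mirrors the SGLang one: prompts are packed into a DataProto by prepare_inputs_for_inference, and __vllm_inference writes the generations back into the same object. A minimal sketch, assuming an HFDecoderModel instance `model` and a text_only Dataset `dataset` (both names illustrative):

    # Sampling-param keys follow what __vllm_inference reads above; "n" is
    # required because prepare_inputs_for_inference repeats each prompt n times.
    data_proto = model.prepare_inputs_for_inference(
        dataset=dataset,
        apply_chat_template=True,
        inference_engine="vllm",
        sampling_params={"n": 2, "temperature": 1.0, "max_new_tokens": 2048, "top_p": 0.95},
    )
    outputs = model.inference(
        data_proto,
        inference_engine="vllm",
        gpu_memory_utilization=0.8,
        max_model_len=16384,
    )
    # Per __vllm_inference, `outputs` is the same DataProto with
    # non_tensor_batch["outputs"] holding one generation per (repeated) prompt.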

src/lmflow/models/hf_model_mixin.py

Lines changed: 26 additions & 9 deletions
@@ -448,20 +448,27 @@ def __prepare_model_for_vllm_inference(
         model_args: ModelArguments,
         gpu_memory_utilization: float,
         tensor_parallel_size: int,
+        data_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
     ):
         if not is_vllm_available():
             raise ImportError('VLLM is not available. Please install via `pip install -e ".[vllm]"`.')
 
         from vllm import LLM
 
-        self.backend_model_for_inference = LLM(
+        kwargs = dict(
             model=model_args.model_name_or_path,
             tokenizer=model_args.model_name_or_path,
             dtype=model_args.torch_dtype if model_args.torch_dtype else "auto",
             load_format="auto",
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
+            data_parallel_size=data_parallel_size,
         )
+        if max_model_len is not None:
+            kwargs["max_model_len"] = max_model_len
+
+        self.backend_model_for_inference = LLM(**kwargs)
 
     def __prepare_model_for_sglang_inference(
         self,

@@ -513,6 +520,8 @@ def activate_model_for_inference(
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         gpu_memory_utilization: Optional[float] = None,
         tensor_parallel_size: Optional[int] = None,
+        data_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
     ):

@@ -525,6 +534,8 @@ def activate_model_for_inference(
                 model_args=self.model_args,
                 gpu_memory_utilization=gpu_memory_utilization,
                 tensor_parallel_size=tensor_parallel_size,
+                data_parallel_size=data_parallel_size,
+                max_model_len=max_model_len,
             )
         elif inference_engine == "sglang":
             self.__prepare_model_for_sglang_inference(

@@ -548,20 +559,26 @@ def deactivate_model_for_inference(
     ):
         """Deactivate the model and release the resources.
 
-        NOTE: Currently, VLLM doesn't have an official way to do this, and the
-        implementation below cannot release all gpu resources by our observation.
-        Thus this method is just a placeholder for future implementation. See:
-        [Github issue](https://github.com/vllm-project/vllm/issues/1908)
+        NOTE: For vllm (>=0.8), the best-effort release below works for most
+        single-GPU, inference-only use cases. It remains unreliable when
+        ``tensor_parallel_size > 1``, CUDA graphs are enabled, or the same
+        process also holds an HF training model — in those cases use
+        :class:`MemorySafeVLLMInferencer`, which isolates inference in a
+        subprocess. vllm still has no official in-process shutdown API
+        (RFC vllm-project/vllm#24885); ``MemorySafeVLLMInferencer`` is kept
+        for backward compatibility and will be migrated to vllm sleep mode
+        in a follow-up.
         """
         if not self._activated:
             logger.warning("You are trying to deactivate the model for inference, but it is already deactivated.")
             return
 
         if inference_engine == "vllm":
-            from vllm.distributed.parallel_state import destroy_model_parallel
-
-            destroy_model_parallel()
-            del self.backend_model_for_inference.llm_engine.model_executor.driver_worker
+            try:
+                from vllm.distributed.parallel_state import destroy_model_parallel
+                destroy_model_parallel()
+            except Exception:
+                pass
             del self.backend_model_for_inference
             gc.collect()
             torch.cuda.empty_cache()
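
Since the docstring above still calls this release "best-effort", one quick way to sanity-check whether memory actually came back in a given setup is to compare allocator stats around the call. This is a hypothetical check, not part of the PR, and `model` is an illustrative name:

    import torch

    before = torch.cuda.memory_allocated()
    model.deactivate_model_for_inference(inference_engine="vllm")  # illustrative call
    after = torch.cuda.memory_allocated()
    print(f"allocated: {before / 1e9:.2f} GB -> {after / 1e9:.2f} GB")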

src/lmflow/pipeline/sglang_inferencer.py

Lines changed: 6 additions & 6 deletions
@@ -2,6 +2,7 @@
 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
 import json
 import logging
+import os
 from typing import Optional, Union
 
 from transformers import AutoTokenizer

@@ -101,14 +102,13 @@ def save_inference_results(
         outputs: DataProto,
         inference_results_path: str,
     ):
-        if not inference_results_path.endswith(".pkl"):
-            logger.warning(f"The inference results path must be a pickle file. Change the path to {inference_results_path}.pkl")
-            inference_results_path = inference_results_path + ".pkl"
-        outputs.save_to_disk(inference_results_path)
-        logger.info(f"Inference results are saved to {inference_results_path}.")
+        save_path = os.path.join(inference_results_path, "inference_results.pkl")
+        outputs.save_to_disk(save_path)
+        logger.info(f"Inference results are saved to {save_path}.")
 
     def load_inference_results(
         self,
         inference_results_path: str,
     ) -> DataProto:
-        return DataProto.load_from_disk(inference_results_path)
+        load_path = os.path.join(inference_results_path, "inference_results.pkl")
+        return DataProto.load_from_disk(load_path)
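
Because results now live at a fixed filename inside the results directory, reading them back outside the pipeline is straightforward. A minimal sketch, assuming the DataProto field names used by the vLLM path above:

    from lmflow.utils.protocol import DataProto

    results = DataProto.load_from_disk(
        "output_data/vllm_inference_results/inference_results.pkl"
    )
    for prompt, completion in zip(
        results.non_tensor_batch["inputs"], results.non_tensor_batch["outputs"]
    ):
        print(prompt, "->", completion)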
