10 changes: 6 additions & 4 deletions examples/vllm_inference.py
@@ -37,12 +37,14 @@ def main():
pipeline_name=pipeline_name, model_args=model_args, data_args=data_args, pipeline_args=pipeline_args
)

inferencer.inference(
# `release_gpu=True` does an in-process best-effort cleanup; it is
# sufficient for this standalone example. For colocated training+inference
# (e.g. iterative DPO) or tensor_parallel_size > 1, prefer
# `MemorySafeVLLMInferencer` instead.
res = inferencer.inference(
model,
dataset,
release_gpu=False,
enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
release_gpu=True,
)


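For completeness, a minimal sketch of consuming `res` after the call above, assuming the pipeline hands back the `DataProto` produced by the model's vLLM path (the `inputs`/`outputs` field names are taken from the `hf_decoder_model.py` changes later in this diff):

```python
# Sketch only: `res` is assumed to be a DataProto whose non-tensor batch holds
# the prompts under "inputs" and one generated string per row under "outputs".
for prompt, completion in zip(
    res.non_tensor_batch["inputs"],
    res.non_tensor_batch["outputs"],
):
    print(f"PROMPT: {prompt}\nOUTPUT: {completion}\n")
```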
2 changes: 1 addition & 1 deletion scripts/run_sglang_inference.sh
@@ -9,4 +9,4 @@ python examples/sglang_inference.py \
--top_p 0.95 \
--random_seed 42 \
--save_inference_results True \
--inference_results_path output_data/sglang_inference_results/results.json
--inference_results_path output_data/sglang_inference_results/
13 changes: 13 additions & 0 deletions scripts/run_vllm_inference.sh
@@ -0,0 +1,13 @@
python examples/vllm_inference.py \
--model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
--dataset_path data/alpaca/prompt_only \
--inference_engine vllm \
--inference_gpu_memory_utilization 0.8 \
--inference_max_model_len 16384 \
--num_output_sequences 2 \
--temperature 1.0 \
--max_new_tokens 2048 \
--top_p 0.95 \
--random_seed 42 \
--save_inference_results True \
--inference_results_path output_data/vllm_inference_results/
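For orientation, a rough sketch of the vLLM engine this script ends up configuring, written as a direct vLLM call (values mirror the flags above and are not the only valid choices; `data_parallel_size` needs vLLM >= 0.8, matching the `setup.py` bump below):

```python
# Sketch of the approximate vLLM equivalent of the script flags above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-4B-Instruct-2507",
    gpu_memory_utilization=0.8,   # --inference_gpu_memory_utilization
    max_model_len=16384,          # --inference_max_model_len; omit to keep the model default
    tensor_parallel_size=1,       # total GPUs = tensor_parallel_size * data_parallel_size
    data_parallel_size=1,
)
outputs = llm.generate(
    ["An example prompt from data/alpaca/prompt_only"],
    SamplingParams(n=2, temperature=1.0, top_p=0.95, max_tokens=2048, seed=42),
)
```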
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@

extra_require = {
"multimodal": ["Pillow"],
"vllm": ["vllm>=0.4.3"],
"vllm": ["vllm>=0.8.0"],
"sglang": ["sglang"],
"ray": ["ray>=2.22.0"],
"gradio": ["gradio"],
32 changes: 27 additions & 5 deletions src/lmflow/args.py
@@ -1041,9 +1041,27 @@ class InferencerArguments:
inference_tensor_parallel_size: Optional[int] = field(
default=1, metadata={"help": "The tensor parallel size for inference."}
)
inference_data_parallel_size: Optional[int] = field(
default=1,
metadata={
"help": (
"The data parallel size for inference. Only supported for vLLM (>= 0.8) inference engine. "
"Total GPUs used = tensor_parallel_size * data_parallel_size."
)
},
)
inference_gpu_memory_utilization: Optional[float] = field(
default=0.95, metadata={"help": "The GPU memory utilization for inference."}
)
inference_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": (
"Maximum model context length for inference. If not set, uses the model's default. "
"Reduce this if the model's default exceeds available GPU memory."
)
},
)
enable_deterministic_inference: bool = field(
default=False,
metadata={
@@ -1065,7 +1083,14 @@ class InferencerArguments:
results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."})

save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
inference_results_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Directory to save inference results. Results are saved as 'inference_results.pkl' inside this directory."
)
},
)

def __post_init__(self):
if self.use_accelerator is not None:
@@ -1087,10 +1112,7 @@ def __post_init__(self):
if self.inference_results_path is None:
raise ValueError("Need to specify inference_results_path when save_inference_results is True.")
else:
if not self.inference_results_path.endswith(".json"):
raise ValueError("The inference_results_path must be a json file.")
else:
Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True)
Path(self.inference_results_path).mkdir(parents=True, exist_ok=True)

if self.use_vllm is True:
logger.warning(
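To illustrate the new `__post_init__` behaviour, a minimal sketch assuming the remaining `InferencerArguments` fields keep their defaults (the exact constructor may differ in other LMFlow versions):

```python
from lmflow.args import InferencerArguments

# inference_results_path is now treated as a directory and created eagerly;
# the old requirement that it end in ".json" is gone.
args = InferencerArguments(
    save_inference_results=True,
    inference_results_path="output_data/vllm_inference_results/",
)
# Results are later written to output_data/vllm_inference_results/inference_results.pkl.

# Omitting the path while requesting saving still fails fast:
# InferencerArguments(save_inference_results=True)  # raises ValueError
```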
93 changes: 35 additions & 58 deletions src/lmflow/models/hf_decoder_model.py
@@ -37,10 +37,9 @@
TEXT_ONLY_DATASET_DESCRIPTION,
)
from lmflow.utils.conversation_template import PRESET_TEMPLATES
from lmflow.utils.data_utils import VLLMInferenceResultWithInput
from lmflow.utils.deprecated import deprecated_args
from lmflow.utils.envs import is_accelerate_env
from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
from lmflow.utils.versioning import is_flash_attn_available, is_vllm_available
from lmflow.utils.protocol import DataProto

logger = logging.getLogger(__name__)
Expand All @@ -54,10 +53,6 @@
if is_vllm_available():
from vllm import SamplingParams

if is_ray_available():
import ray
import ray.data


class HFDecoderModel(DecoderModel, HFModelMixin, Tunable):
r"""
@@ -321,20 +316,21 @@ def inference(
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
gpu_memory_utilization: Optional[float] = None,
tensor_parallel_size: Optional[int] = None,
data_parallel_size: int = 1,
max_model_len: Optional[int] = None,
enable_deterministic_inference: bool = False,
attention_backend: Optional[str] = None,
**kwargs,
) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
) -> Union[list, DataProto]:
"""
Perform the generation process of the model.

Parameters
------------
inputs : Union[str, list[str], torch.Tensor, DataProto]
The sequence used as a prompt for generation, or the inputs to the model.
When the inference engine is "vllm", this should be a string or a list of strings.
When the inference engine is "vllm" or "sglang", this should be a DataProto.
When the inference engine is "huggingface", this should be a tensor.
When the inference engine is "sglang", this should be a DataProto.
sampling_params : Optional[Union[dict, "SamplingParams"]], optional
The sampling parameters to use, by default None.
return_logprob : bool, optional
@@ -347,6 +343,10 @@ def inference(
The GPU memory utilization to use, by default None.
tensor_parallel_size : int, optional
The tensor parallel size to use, by default None.
data_parallel_size : int, optional
The data parallel size for vllm inference, by default 1.
max_model_len : int, optional
Maximum model context length for vllm inference, by default None.
enable_deterministic_inference : bool, optional
Whether to enable deterministic inference, by default False.
attention_backend : Optional[str], optional
@@ -365,12 +365,14 @@
inference_engine=inference_engine,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
max_model_len=max_model_len,
enable_deterministic_inference=enable_deterministic_inference,
attention_backend=attention_backend,
)

if inference_engine == "vllm":
res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
res = self.__vllm_inference(inputs=inputs)
elif inference_engine == "sglang":
res = self.__sglang_inference(
inputs=inputs,
@@ -424,46 +426,29 @@ def __inference(self, inputs, *args, **kwargs):

def __vllm_inference(
self,
inputs: list[str],
sampling_params: Optional["SamplingParams"] = None,
) -> list[VLLMInferenceResultWithInput]:
"""Perform VLLM inference process of the model.

Parameters
----------
inputs : list[str]
Prompt(s), string or a list of strings.
sampling_params : Optional[SamplingParams], optional
vllm SamplingParams object, by default None.

Returns
-------
list[VLLMInferenceResultWithInput]
Return a list of VLLMInferenceResultWithInput, where each
element contains the input prompt and the corresponding output.

When `sampling_params.detokenize = True`, the output would be a list of strings,
contains sampling_params.n samples for the corresponding prompt.
inputs: DataProto,
) -> DataProto:
"""Perform VLLM inference process of the model."""
prompts = inputs.non_tensor_batch["inputs"].tolist()
sampling_params_dict = inputs.meta_info["sampling_params"]

vllm_sampling_params = SamplingParams(
n=sampling_params_dict.get("n", 1),
temperature=sampling_params_dict.get("temperature", 0.0),
max_tokens=sampling_params_dict.get("max_new_tokens", 100),
seed=sampling_params_dict.get("seed"),
top_p=sampling_params_dict.get("top_p", 1.0),
top_k=sampling_params_dict.get("top_k", 0),
stop_token_ids=sampling_params_dict.get("stop_token_ids"),
)

When `sampling_params.detokenize = False`, return a list of list of ints
(token ids, no decoding after generation).
"""
vllm_outputs = self.backend_model_for_inference.generate(
inputs,
sampling_params=sampling_params,
prompts,
sampling_params=vllm_sampling_params,
use_tqdm=True,
)
# TODO: unified lmflow sample format
final_output = []
for output in vllm_outputs:
if sampling_params.detokenize:
output_list = [sentence.text for sentence in output.outputs]
else:
output_list = [sentence.token_ids for sentence in output.outputs]

final_output.append({"input": output.prompt, "output": output_list})

return final_output
inputs.non_tensor_batch["outputs"] = [output.outputs[0].text for output in vllm_outputs]
return inputs

def __sglang_inference(
self,
@@ -495,9 +480,8 @@ def prepare_inputs_for_inference(
dataset: Dataset,
apply_chat_template: bool = True,
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
enable_distributed_inference: bool = False,
sampling_params: Optional[dict] = None,
) -> Union[list[str], "ray.data.Dataset", DataProto]:
) -> Union[list[str], DataProto]:
if dataset.get_type() == "text_only":
if apply_chat_template:
dataset = dataset.map(
@@ -572,24 +556,17 @@ def preprocess_conversation(sample):

inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]

if inference_engine == "vllm" and enable_distributed_inference:
inference_inputs = ray.data.from_items(
inference_inputs
) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}

if inference_engine == "sglang":
if inference_engine in ("sglang", "vllm"):
if self.tokenizer.bos_token:
# for consistency with the sglang bench_serving.py demo
inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]

# currently DataProto is only tested on sglang inference
inference_inputs = np.array(inference_inputs)
inference_inputs = DataProto.from_single_dict(
data={"inputs": inference_inputs},
meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]}
)
# handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines.

# handling n>1 since we don't want one-to-many mapping
inference_inputs = inference_inputs.repeat(sampling_params["n"])

return inference_inputs
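Tying this file's changes together, a sketch of the new vLLM data flow, assuming `model` is an already-loaded `HFDecoderModel` and that `DataProto.from_single_dict`/`repeat` behave as used above (note how `n > 1` is folded into row repetition so every row maps to exactly one output):

```python
import numpy as np
from lmflow.utils.protocol import DataProto

prompts = ["Hello!", "Summarize LMFlow in one sentence."]
sampling_params = {"n": 2, "temperature": 1.0, "max_new_tokens": 128, "top_p": 0.95}

# prepare_inputs_for_inference: one row per prompt, n forced to 1, then each
# row repeated n times so outputs stay one-to-one with rows.
batch = DataProto.from_single_dict(
    data={"inputs": np.array(prompts)},
    meta_info={"sampling_params": {**sampling_params, "n": 1},
               "actual_n_rollouts": sampling_params["n"]},
)
batch = batch.repeat(sampling_params["n"])  # 2 prompts * n=2 -> 4 rows

# inference(): activates vLLM, maps the dict onto vllm.SamplingParams
# (max_new_tokens -> max_tokens), generates, and fills "outputs".
result = model.inference(
    batch,
    inference_engine="vllm",
    gpu_memory_utilization=0.8,
    tensor_parallel_size=1,
)
print(result.non_tensor_batch["outputs"])  # 4 strings, aligned with the repeated inputs
```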
35 changes: 26 additions & 9 deletions src/lmflow/models/hf_model_mixin.py
@@ -448,20 +448,27 @@ def __prepare_model_for_vllm_inference(
model_args: ModelArguments,
gpu_memory_utilization: float,
tensor_parallel_size: int,
data_parallel_size: int = 1,
max_model_len: Optional[int] = None,
):
if not is_vllm_available():
raise ImportError('VLLM is not available. Please install via `pip install -e ".[vllm]"`.')

from vllm import LLM

self.backend_model_for_inference = LLM(
kwargs = dict(
model=model_args.model_name_or_path,
tokenizer=model_args.model_name_or_path,
dtype=model_args.torch_dtype if model_args.torch_dtype else "auto",
load_format="auto",
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
)
if max_model_len is not None:
kwargs["max_model_len"] = max_model_len

self.backend_model_for_inference = LLM(**kwargs)

def __prepare_model_for_sglang_inference(
self,
@@ -513,6 +520,8 @@ def activate_model_for_inference(
inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
gpu_memory_utilization: Optional[float] = None,
tensor_parallel_size: Optional[int] = None,
data_parallel_size: int = 1,
max_model_len: Optional[int] = None,
enable_deterministic_inference: bool = False,
attention_backend: Optional[str] = None,
):
@@ -525,6 +534,8 @@ def activate_model_for_inference(
model_args=self.model_args,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
max_model_len=max_model_len,
)
elif inference_engine == "sglang":
self.__prepare_model_for_sglang_inference(
@@ -548,20 +559,26 @@ def deactivate_model_for_inference(
):
"""Deactivate the model and release the resources.

NOTE: Currently, VLLM doesn't have an official way to do this, and the
implementation below cannot release all gpu resources by our observation.
Thus this method is just a placeholder for future implementation. See:
[Github issue](https://github.com/vllm-project/vllm/issues/1908)
NOTE: For vllm (>=0.8), the best-effort release below works for most
single-GPU, inference-only use cases. It remains unreliable when
``tensor_parallel_size > 1``, CUDA graphs are enabled, or the same
process also holds an HF training model — in those cases use
:class:`MemorySafeVLLMInferencer`, which isolates inference in a
subprocess. vllm still has no official in-process shutdown API
(RFC vllm-project/vllm#24885); ``MemorySafeVLLMInferencer`` is kept
for backward compatibility and will be migrated to vllm sleep mode
in a follow-up.
"""
if not self._activated:
logger.warning("You are trying to deactivate the model for inference, but it is already deactivated.")
return

if inference_engine == "vllm":
from vllm.distributed.parallel_state import destroy_model_parallel

destroy_model_parallel()
del self.backend_model_for_inference.llm_engine.model_executor.driver_worker
try:
from vllm.distributed.parallel_state import destroy_model_parallel
destroy_model_parallel()
except Exception:
pass
del self.backend_model_for_inference
gc.collect()
torch.cuda.empty_cache()
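The cleanup above is what `release_gpu=True` in the updated example ultimately relies on. As a standalone illustration of the same best-effort pattern (a sketch, not part of the LMFlow API):

```python
import gc

import torch


def best_effort_vllm_release() -> None:
    """Call after dropping the last reference to the vllm.LLM object.

    Frees most GPU memory for single-GPU, inference-only processes; it remains
    unreliable with tensor_parallel_size > 1 or when the same process also
    holds a training model (prefer MemorySafeVLLMInferencer there).
    """
    try:
        from vllm.distributed.parallel_state import destroy_model_parallel

        destroy_model_parallel()
    except Exception:
        pass  # the helper has moved between vLLM versions; ignore if absent
    gc.collect()
    torch.cuda.empty_cache()


# usage sketch:
#   llm = LLM(...); llm.generate(...)
#   del llm
#   best_effort_vllm_release()
```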
12 changes: 6 additions & 6 deletions src/lmflow/pipeline/sglang_inferencer.py
@@ -2,6 +2,7 @@
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import json
import logging
import os
from typing import Optional, Union

from transformers import AutoTokenizer
@@ -101,14 +102,13 @@ def save_inference_results(
outputs: DataProto,
inference_results_path: str,
):
if not inference_results_path.endswith(".pkl"):
logger.warning(f"The inference results path must be a pickle file. Change the path to {inference_results_path}.pkl")
inference_results_path = inference_results_path + ".pkl"
outputs.save_to_disk(inference_results_path)
logger.info(f"Inference results are saved to {inference_results_path}.")
save_path = os.path.join(inference_results_path, "inference_results.pkl")
outputs.save_to_disk(save_path)
logger.info(f"Inference results are saved to {save_path}.")

def load_inference_results(
self,
inference_results_path: str,
) -> DataProto:
return DataProto.load_from_disk(inference_results_path)
load_path = os.path.join(inference_results_path, "inference_results.pkl")
return DataProto.load_from_disk(load_path)
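A short sketch of reading the results back for downstream use (the directory is whatever was passed as `--inference_results_path`; the file name is fixed by `save_inference_results` above, and the `inputs` field is assumed to be carried through from the inference batch):

```python
import os

from lmflow.utils.protocol import DataProto

results_dir = "output_data/sglang_inference_results/"  # illustrative path from the script above
results = DataProto.load_from_disk(os.path.join(results_dir, "inference_results.pkl"))
print(len(results.non_tensor_batch["inputs"]), "prompts answered")
```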