@@ -37,10 +37,9 @@
     TEXT_ONLY_DATASET_DESCRIPTION,
 )
 from lmflow.utils.conversation_template import PRESET_TEMPLATES
-from lmflow.utils.data_utils import VLLMInferenceResultWithInput
 from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
-from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
+from lmflow.utils.versioning import is_flash_attn_available, is_vllm_available
 from lmflow.utils.protocol import DataProto
 
 logger = logging.getLogger(__name__)
@@ -54,10 +53,6 @@
 if is_vllm_available():
     from vllm import SamplingParams
 
-if is_ray_available():
-    import ray
-    import ray.data
-
 
 class HFDecoderModel(DecoderModel, HFModelMixin, Tunable):
     r"""
@@ -321,20 +316,21 @@ def inference(
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         gpu_memory_utilization: Optional[float] = None,
         tensor_parallel_size: Optional[int] = None,
+        data_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
         **kwargs,
-    ) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
+    ) -> Union[list, DataProto]:
         """
         Perform generation process of the model.
 
         Parameters
         ------------
         inputs : Union[str, list[str], torch.Tensor, DataProto]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When the inference engine is "vllm", this should be a string or a list of strings.
+            When the inference engine is "vllm" or "sglang", this should be a DataProto.
             When the inference engine is "huggingface", this should be a tensor.
-            When the inference engine is "sglang", this should be a DataProto.
         sampling_params : Optional[Union[dict, "SamplingParams"]], optional
             The sampling parameters to use, by default None.
         return_logprob : bool, optional
@@ -347,6 +343,10 @@ def inference(
             The GPU memory utilization to use, by default None.
         tensor_parallel_size : int, optional
             The tensor parallel size to use, by default None.
+        data_parallel_size : int, optional
+            The data parallel size for vllm inference, by default 1.
+        max_model_len : int, optional
+            Maximum model context length for vllm inference, by default None.
         enable_deterministic_inference : bool, optional
             Whether to enable deterministic inference, by default False.
         attention_backend : Optional[str], optional
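
The sampling settings themselves travel as a plain dict rather than through new keyword arguments. As a quick reference, the sketch below lists the keys the rewritten __vllm_inference (later in this diff) reads via .get(), with its fallback defaults; the dict itself is illustrative, not code from the commit.

    sampling_params = {
        "n": 1,                   # samples per prompt; prepare_inputs_for_inference pins this to 1
        "temperature": 0.0,
        "max_new_tokens": 100,    # forwarded to vLLM as max_tokens
        "seed": None,
        "top_p": 1.0,
        "top_k": 0,
        "stop_token_ids": None,
    }
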
@@ -365,12 +365,14 @@ def inference(
             inference_engine=inference_engine,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
+            data_parallel_size=data_parallel_size,
+            max_model_len=max_model_len,
             enable_deterministic_inference=enable_deterministic_inference,
             attention_backend=attention_backend,
         )
 
         if inference_engine == "vllm":
-            res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
+            res = self.__vllm_inference(inputs=inputs)
         elif inference_engine == "sglang":
             res = self.__sglang_inference(
                 inputs=inputs,
@@ -424,46 +426,29 @@ def __inference(self, inputs, *args, **kwargs):
 
     def __vllm_inference(
         self,
-        inputs: list[str],
-        sampling_params: Optional["SamplingParams"] = None,
-    ) -> list[VLLMInferenceResultWithInput]:
-        """Perform VLLM inference process of the model.
-
-        Parameters
-        ----------
-        inputs : list[str]
-            Prompt(s), string or a list of strings.
-        sampling_params : Optional[SamplingParams], optional
-            vllm SamplingParams object, by default None.
-
-        Returns
-        -------
-        list[VLLMInferenceResultWithInput]
-            Return a list of VLLMInferenceResultWithInput, where each
-            element contains the input prompt and the corresponding output.
-
-            When `sampling_params.detokenize = True`, the output would be a list of strings,
-            contains sampling_params.n samples for the corresponding prompt.
+        inputs: DataProto,
+    ) -> DataProto:
+        """Perform VLLM inference process of the model."""
+        prompts = inputs.non_tensor_batch["inputs"].tolist()
+        sampling_params_dict = inputs.meta_info["sampling_params"]
+
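+        # Map the LMFlow-style sampling dict onto vLLM's SamplingParams;
+        # note that "max_new_tokens" becomes vLLM's max_tokens.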
+        vllm_sampling_params = SamplingParams(
+            n=sampling_params_dict.get("n", 1),
+            temperature=sampling_params_dict.get("temperature", 0.0),
+            max_tokens=sampling_params_dict.get("max_new_tokens", 100),
+            seed=sampling_params_dict.get("seed"),
+            top_p=sampling_params_dict.get("top_p", 1.0),
+            top_k=sampling_params_dict.get("top_k", 0),
+            stop_token_ids=sampling_params_dict.get("stop_token_ids"),
+        )
 
-            When `sampling_params.detokenize = False`, return a list of list of ints
-            (token ids, no decoding after generation).
-        """
         vllm_outputs = self.backend_model_for_inference.generate(
-            inputs,
-            sampling_params=sampling_params,
+            prompts,
+            sampling_params=vllm_sampling_params,
             use_tqdm=True,
         )
-        # TODO: unified lmflow sample format
-        final_output = []
-        for output in vllm_outputs:
-            if sampling_params.detokenize:
-                output_list = [sentence.text for sentence in output.outputs]
-            else:
-                output_list = [sentence.token_ids for sentence in output.outputs]
-
-            final_output.append({"input": output.prompt, "output": output_list})
-
-        return final_output
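+        # Keep only the first sample per prompt: prepare_inputs_for_inference
+        # pins n to 1 and repeats each prompt instead.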
+        inputs.non_tensor_batch["outputs"] = [output.outputs[0].text for output in vllm_outputs]
+        return inputs
 
     def __sglang_inference(
         self,
@@ -495,9 +480,8 @@ def prepare_inputs_for_inference(
         dataset: Dataset,
         apply_chat_template: bool = True,
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
-        enable_distributed_inference: bool = False,
         sampling_params: Optional[dict] = None,
-    ) -> Union[list[str], "ray.data.Dataset", DataProto]:
+    ) -> Union[list[str], DataProto]:
         if dataset.get_type() == "text_only":
             if apply_chat_template:
                 dataset = dataset.map(
@@ -572,24 +556,17 @@ def preprocess_conversation(sample):
 
         inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]
 
-        if inference_engine == "vllm" and enable_distributed_inference:
-            inference_inputs = ray.data.from_items(
-                inference_inputs
-            )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}
-
-        if inference_engine == "sglang":
+        if inference_engine in ("sglang", "vllm"):
             if self.tokenizer.bos_token:
-                # in consistent with sglang bench_serving.py demo
                 inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
 
-            # currently only test dataproto on sglang inference
             inference_inputs = np.array(inference_inputs)
             inference_inputs = DataProto.from_single_dict(
                 data={"inputs": inference_inputs},
                 meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]}
             )
-
-            # handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines.
+
+            # handling n>1 since we don't want one-to-many mapping
            inference_inputs = inference_inputs.repeat(sampling_params["n"])
 
         return inference_inputs
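
Taken together, the commit replaces the Ray-based distributed path with a single DataProto round trip. The following sketch shows how a caller might exercise the new flow; it assumes an already-constructed HFDecoderModel (model) and a text_only Dataset (dataset), and the variable names and values are illustrative rather than code from the repository.

    # Hypothetical driver for the reworked vLLM path.
    sampling_params = {"n": 4, "temperature": 0.7, "max_new_tokens": 128, "top_p": 0.9}

    # Packs prompts into a DataProto, pins "n" to 1 in meta_info, records the
    # requested rollout count as "actual_n_rollouts", and repeats each prompt
    # n times, so every row maps to exactly one generated sample.
    inputs = model.prepare_inputs_for_inference(
        dataset=dataset,
        apply_chat_template=True,
        inference_engine="vllm",
        sampling_params=sampling_params,
    )

    # __vllm_inference reads non_tensor_batch["inputs"] and meta_info["sampling_params"],
    # then writes one completion per row into non_tensor_batch["outputs"].
    outputs = model.inference(
        inputs=inputs,
        inference_engine="vllm",
        data_parallel_size=2,    # new argument introduced by this commit
        max_model_len=4096,      # new argument introduced by this commit
    )

    for prompt, completion in zip(outputs.non_tensor_batch["inputs"], outputs.non_tensor_batch["outputs"]):
        print(prompt, "->", completion)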