 from huggingface_hub import HfApi, login, snapshot_download
 from transformers import AutoTokenizer, pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Conversation, Pipeline
+from transformers.pipelines import Pipeline
 
 from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available
 from sagemaker_huggingface_inference_toolkit.optimum_utils import (
@@ -117,25 +117,6 @@ def create_artifact_filter(framework):
     return []
 
 
-def wrap_conversation_pipeline(pipeline):
-    def wrapped_pipeline(inputs, *args, **kwargs):
-        converted_input = Conversation(
-            inputs["text"],
-            past_user_inputs=inputs.get("past_user_inputs", []),
-            generated_responses=inputs.get("generated_responses", []),
-        )
-        prediction = pipeline(converted_input, *args, **kwargs)
-        return {
-            "generated_text": prediction.generated_responses[-1],
-            "conversation": {
-                "past_user_inputs": prediction.past_user_inputs,
-                "generated_responses": prediction.generated_responses,
-            },
-        }
-
-    return wrapped_pipeline
-
-
 def _is_gpu_available():
     """
     checks if a gpu is available.
@@ -310,8 +291,4 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
         task=task, model=model_dir, device=device, trust_remote_code=TRUST_REMOTE_CODE, **kwargs
     )
 
-    # wrapp specific pipeline to support better ux
-    if task == "conversational":
-        hf_pipeline = wrap_conversation_pipeline(hf_pipeline)
-
     return hf_pipeline
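With the `Conversation` class and the `conversational` task removed from recent transformers releases, chat-style requests can instead go through the standard `text-generation` pipeline, which applies the model's chat template to a list of messages. A minimal sketch of that replacement path (the model id and request shape below are illustrative assumptions, not part of this change):

from transformers import pipeline

# Illustrative model id (assumption); any chat-template-capable model behaves the same way.
hf_pipeline = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta")

# Chat-format input: a list of {"role", "content"} messages instead of a Conversation object.
messages = [{"role": "user", "content": "What is Amazon SageMaker?"}]

outputs = hf_pipeline(messages, max_new_tokens=64)
# For chat input, "generated_text" holds the full message list, including the new assistant turn.
print(outputs[0]["generated_text"][-1]["content"])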