
Commit a266dff

Remove Conversation pipeline utils
Remove support for the deprecated Conversation pipeline, which was removed from transformers as of v4.42. Fixes issue #129.
1 parent 45d94b7 commit a266dff

2 files changed (+1 −60 lines)

src/sagemaker_huggingface_inference_toolkit/transformers_utils.py (+1 −24)

```diff
@@ -21,7 +21,7 @@
 from huggingface_hub import HfApi, login, snapshot_download
 from transformers import AutoTokenizer, pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Conversation, Pipeline
+from transformers.pipelines import Pipeline
 
 from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available
 from sagemaker_huggingface_inference_toolkit.optimum_utils import (
@@ -117,25 +117,6 @@ def create_artifact_filter(framework):
     return []
 
 
-def wrap_conversation_pipeline(pipeline):
-    def wrapped_pipeline(inputs, *args, **kwargs):
-        converted_input = Conversation(
-            inputs["text"],
-            past_user_inputs=inputs.get("past_user_inputs", []),
-            generated_responses=inputs.get("generated_responses", []),
-        )
-        prediction = pipeline(converted_input, *args, **kwargs)
-        return {
-            "generated_text": prediction.generated_responses[-1],
-            "conversation": {
-                "past_user_inputs": prediction.past_user_inputs,
-                "generated_responses": prediction.generated_responses,
-            },
-        }
-
-    return wrapped_pipeline
-
-
 def _is_gpu_available():
     """
     checks if a gpu is available.
@@ -310,8 +291,4 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
         task=task, model=model_dir, device=device, trust_remote_code=TRUST_REMOTE_CODE, **kwargs
     )
 
-    # wrapp specific pipeline to support better ux
-    if task == "conversational":
-        hf_pipeline = wrap_conversation_pipeline(hf_pipeline)
-
     return hf_pipeline
```
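For context: the removed wrapper translated the toolkit's conversational request payload (`text`, `past_user_inputs`, `generated_responses`) into a transformers `Conversation` object. With the `conversational` task gone from transformers, the nearest equivalent is the chat-message input of the `text-generation` pipeline. The sketch below is illustrative only and not part of this commit; the model name, `max_new_tokens` value, and payload mapping are assumptions, and it presumes a transformers version whose text-generation pipeline accepts role/content message lists for models that ship a chat template.

```python
# Illustrative sketch, not part of this commit: serving the legacy conversational
# payload with a chat-capable text-generation pipeline instead.
from transformers import pipeline

# Placeholder model; any model that ships a chat template would work here.
generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")


def handle_conversational_request(inputs):
    """Map the legacy conversational payload onto a list of chat messages."""
    messages = []
    # Interleave previous user turns with the model's earlier replies.
    for user_turn, bot_turn in zip(
        inputs.get("past_user_inputs", []), inputs.get("generated_responses", [])
    ):
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": inputs["text"]})

    outputs = generator(messages, max_new_tokens=64)
    # With chat input, "generated_text" holds the full message list, ending with
    # the newly generated assistant turn.
    reply = outputs[0]["generated_text"][-1]["content"]
    return {"generated_text": reply}
```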

tests/unit/test_transformers_utils.py (−36)

```diff
@@ -14,7 +14,6 @@
 import os
 import tempfile
 
-from transformers import pipeline
 from transformers.file_utils import is_torch_available
 from transformers.testing_utils import require_tf, require_torch, slow
 
@@ -26,7 +25,6 @@
     get_pipeline,
     infer_task_from_hub,
     infer_task_from_model_architecture,
-    wrap_conversation_pipeline,
 )
 
 
@@ -129,37 +127,3 @@ def test_infer_task_from_model_architecture():
         storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
         task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
         assert task == "token-classification"
-
-
-@require_torch
-def test_wrap_conversation_pipeline():
-    init_pipeline = pipeline(
-        "conversational",
-        model="microsoft/DialoGPT-small",
-        tokenizer="microsoft/DialoGPT-small",
-        framework="pt",
-    )
-    conv_pipe = wrap_conversation_pipeline(init_pipeline)
-    data = {
-        "past_user_inputs": ["Which movie is the best ?"],
-        "generated_responses": ["It's Die Hard for sure."],
-        "text": "Can you explain why?",
-    }
-    res = conv_pipe(data)
-    assert "conversation" in res
-    assert "generated_text" in res
-
-
-@require_torch
-def test_wrapped_pipeline():
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        storage_dir = _load_model_from_hub("microsoft/DialoGPT-small", tmpdirname)
-        conv_pipe = get_pipeline("conversational", -1, storage_dir)
-        data = {
-            "past_user_inputs": ["Which movie is the best ?"],
-            "generated_responses": ["It's Die Hard for sure."],
-            "text": "Can you explain why?",
-        }
-        res = conv_pipe(data)
-        assert "conversation" in res
-        assert "generated_text" in res
```
