Skip to content

Commit 7f7255a

Browse files
authored
Flair: Add support for PoS Tagging Models & Version Update (#440)
* flair: add support for PoS tagging models
* flair: bump requirement to latest 0.14.0 release
* flair: add support for PoS tagging models
* flair: add support for PoS tagging models
* flair: add support for multiple model tests (incl. fixing unsupported task pipeline)
* flair: add parameterized test cases (incl. setup class cache clearing)
* flair: apply black to test class
1 parent db248d4 commit 7f7255a

File tree

4 files changed: 43 additions and 21 deletions.

docker_images/flair/app/pipelines/token_classification.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List
22

33
from app.pipelines import Pipeline
4-
from flair.data import Sentence
4+
from flair.data import Sentence, Span, Token
55
from flair.models import SequenceTagger
66

77

@@ -27,21 +27,30 @@ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
2727
"""
2828
sentence: Sentence = Sentence(inputs)
2929

30-
# Also show scores for recognized NEs
31-
self.tagger.predict(sentence, label_name="predicted")
30+
self.tagger.predict(sentence)
3231

3332
entities = []
34-
for span in sentence.get_spans("predicted"):
35-
if len(span.tokens) == 0:
36-
continue
37-
current_entity = {
38-
"entity_group": span.tag,
39-
"word": span.text,
40-
"start": span.tokens[0].start_position,
41-
"end": span.tokens[-1].end_position,
42-
"score": span.score,
43-
}
44-
45-
entities.append(current_entity)
33+
for label in sentence.get_labels():
34+
current_data_point = label.data_point
35+
if isinstance(current_data_point, Token):
36+
current_entity = {
37+
"entity_group": current_data_point.tag,
38+
"word": current_data_point.text,
39+
"start": current_data_point.start_position,
40+
"end": current_data_point.end_position,
41+
"score": current_data_point.score,
42+
}
43+
entities.append(current_entity)
44+
elif isinstance(current_data_point, Span):
45+
if not current_data_point.tokens:
46+
continue
47+
current_entity = {
48+
"entity_group": current_data_point.tag,
49+
"word": current_data_point.text,
50+
"start": current_data_point.tokens[0].start_position,
51+
"end": current_data_point.tokens[-1].end_position,
52+
"score": current_data_point.score,
53+
}
54+
entities.append(current_entity)
4655

4756
return entities
docker_images/flair/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
starlette==0.27.0
22
pydantic==1.8.2
3-
flair @ git+https://github.com/flairNLP/flair@b18aff236098fc6623de8bdb4c8b50e4bfe7f91f
3+
flair @ git+https://github.com/flairNLP/flair@e17ab1234fcfed2b089d8ef02b99949d520382d2
44
api-inference-community==0.0.25

docker_images/flair/tests/test_api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Dict
2+
from typing import Dict, List
33
from unittest import TestCase, skipIf
44

55
from app.main import ALLOWED_TASKS, get_pipeline
@@ -8,7 +8,9 @@
88
# Must contain at least one example of each implemented pipeline
99
# Tests do not check the actual values of the model output, so small dummy
1010
# models are recommended for faster tests.
11-
TESTABLE_MODELS: Dict[str, str] = {"token-classification": "flair/chunk-english-fast"}
11+
TESTABLE_MODELS: Dict[str, List[str]] = {
12+
"token-classification": ["flair/chunk-english-fast", "flair/upos-english-fast"]
13+
}
1214

1315

1416
ALL_TASKS = {
@@ -35,5 +37,7 @@ def test_unsupported_tasks(self):
3537
unsupported_tasks = ALL_TASKS - ALLOWED_TASKS.keys()
3638
for unsupported_task in unsupported_tasks:
3739
with self.subTest(msg=unsupported_task, task=unsupported_task):
40+
os.environ["TASK"] = unsupported_task
41+
os.environ["MODEL_ID"] = "XX"
3842
with self.assertRaises(EnvironmentError):
39-
get_pipeline(unsupported_task, model_id="XX")
43+
get_pipeline()

docker_images/flair/tests/test_api_token_classification.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest import TestCase, skipIf
44

55
from app.main import ALLOWED_TASKS
6+
from parameterized import parameterized_class
67
from starlette.testclient import TestClient
78
from tests.test_api import TESTABLE_MODELS
89

@@ -11,17 +12,25 @@
1112
"token-classification" not in ALLOWED_TASKS,
1213
"token-classification not implemented",
1314
)
15+
@parameterized_class(
16+
[{"model_id": model_id} for model_id in TESTABLE_MODELS["token-classification"]]
17+
)
1418
class TokenClassificationTestCase(TestCase):
1519
def setUp(self):
16-
model_id = TESTABLE_MODELS["token-classification"]
1720
self.old_model_id = os.getenv("MODEL_ID")
1821
self.old_task = os.getenv("TASK")
19-
os.environ["MODEL_ID"] = model_id
22+
os.environ["MODEL_ID"] = self.model_id
2023
os.environ["TASK"] = "token-classification"
2124
from app.main import app
2225

2326
self.app = app
2427

28+
@classmethod
29+
def setUpClass(cls):
30+
from app.main import get_pipeline
31+
32+
get_pipeline.cache_clear()
33+
2534
def tearDown(self):
2635
if self.old_model_id is not None:
2736
os.environ["MODEL_ID"] = self.old_model_id

0 commit comments

Comments
 (0)