@@ -818,22 +818,21 @@ def __init__(
         """
         super().__init__()
 
-        self.tokens: list[Token] = []
+        self._tokens: Optional[list[Token]] = None
+        self._text: str = ""  # Change from Optional[str] to str with empty string default
 
-        # private field for all known spans
-        self._known_spans: dict[str, _PartOfSentence] = {}
+        # private field for all known spans with explicit typing
+        self._known_spans: dict[str, Union[Span, Relation]] = {}
 
         self.language_code: Optional[str] = language_code
 
         self._start_position = start_position
 
         # the tokenizer used for this sentence
         if isinstance(use_tokenizer, Tokenizer):
-            tokenizer = use_tokenizer
-
+            self._tokenizer = use_tokenizer
         elif isinstance(use_tokenizer, bool):
-            tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
-
+            self._tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
         else:
             raise AssertionError("Unexpected type of parameter 'use_tokenizer'. Parameter should be bool or Tokenizer")
 
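A minimal usage sketch of the tokenizer selection handled above, assuming Flair's public Sentence constructor and the flair.tokenization import path (neither is part of this diff):

from flair.data import Sentence
from flair.tokenization import SpaceTokenizer

# use_tokenizer=True (the default) selects SegtokTokenizer, use_tokenizer=False
# selects SpaceTokenizer, and a Tokenizer instance is stored and used as-is.
s_default = Sentence("10,000 people visited Berlin!")
s_space = Sentence("10,000 people visited Berlin!", use_tokenizer=False)
s_custom = Sentence("10,000 people visited Berlin!", use_tokenizer=SpaceTokenizer())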
@@ -848,24 +847,79 @@ def __init__(
         self._next_sentence: Optional[Sentence] = None
         self._position_in_dataset: Optional[tuple[Dataset, int]] = None
 
-        # if text is passed, instantiate sentence with tokens (words)
-        if isinstance(text, str):
-            text = Sentence._handle_problem_characters(text)
-            words = tokenizer.tokenize(text)
-        elif text and isinstance(text[0], Token):
-            for t in text:
-                self._add_token(t)
-            self.tokens[-1].whitespace_after = 0
-            return
+        # if list of strings or tokens is passed, create tokens directly
+        if not isinstance(text, str):
+            self._tokens = []
+
+            # First construct the text from tokens to ensure proper text reconstruction
+            if len(text) > 0:
+                # Type check the input list and cast
+                if all(isinstance(t, Token) for t in text):
+                    tokens = cast(list[Token], text)
+                    reconstructed_text = ""
+                    for i, token in enumerate(tokens):
+                        reconstructed_text += token.text
+                        if i < len(tokens) - 1:  # Add whitespace between tokens
+                            reconstructed_text += " " * token.whitespace_after
+                    self._text = reconstructed_text
+                elif all(isinstance(t, str) for t in text):
+                    strings = cast(list[str], text)
+                    self._text = " ".join(strings)
+                else:
+                    raise TypeError("All elements must be either Token or str")
+            else:
+                self._text = ""
+
+            # Now add the tokens
+            current_position = 0
+            for i, item in enumerate(text):
+                # create Token if string, otherwise use existing Token
+                if isinstance(item, str):
+                    # For strings, create new Token with default whitespace
+                    token = Token(text=item)
+                    token.whitespace_after = 0 if i == len(text) - 1 else 1
+                elif isinstance(item, Token):
+                    # For existing Tokens, preserve their whitespace_after
+                    token = item
+
+                # Set start position for the token
+                token.start_position = current_position
+                current_position += len(token.text) + token.whitespace_after
+
+                self._add_token(token)
+
+            if len(text) > 0:
+                # convention: the last token has no whitespace after
+                self.tokens[-1].whitespace_after = 0
         else:
-            words = cast(list[str], text)
-            text = " ".join(words)
+            self._text = Sentence._handle_problem_characters(text)
+
+        # log a warning if the dataset is empty
+        if self._text == "":
+            log.warning("Warning: An empty Sentence was created! Are there empty strings in your dataset?")
+
+    @property
+    def tokens(self) -> list[Token]:
+        """Gets the tokens of this sentence. Automatically triggers tokenization if not yet tokenized."""
+        if self._tokens is None:
+            self._tokenize()
+        if self._tokens is None:
+            raise ValueError("Tokens are None after tokenization - this indicates a bug in the tokenization process")
+        return self._tokens
+
+    def _tokenize(self) -> None:
+        """Internal method that performs tokenization."""
+
+        # tokenize the text
+        words = self._tokenizer.tokenize(self._text)
 
         # determine token positions and whitespace_after flag
         current_offset: int = 0
         previous_token: Optional[Token] = None
+        self._tokens = []
+
         for word in words:
-            word_start_position: int = text.index(word, current_offset)
+            word_start_position: int = self._text.index(word, current_offset)
             delta_offset: int = word_start_position - current_offset
 
             token: Token = Token(text=word, start_position=word_start_position)
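A short sketch of the lazy tokenization introduced above: the constructor only stores the raw text, and the first access to tokens runs _tokenize() and caches the result. The token count is only an expectation for the default SegtokTokenizer.

from flair.data import Sentence

sentence = Sentence("The grass is green .")
assert sentence._tokens is None                   # nothing tokenized yet
assert sentence._text == "The grass is green ."   # the raw text is stored without tokenizing

n = len(sentence.tokens)                          # first access triggers _tokenize()
assert sentence._tokens is not None               # tokens are now cached
print(n)                                          # expected: 5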
@@ -878,17 +932,56 @@ def __init__(
             previous_token = token
 
         # the last token has no whitespace after
-        if len(self) > 0:
-            self.tokens[-1].whitespace_after = 0
+        if len(self._tokens) > 0:
+            self._tokens[-1].whitespace_after = 0
 
-        # log a warning if the dataset is empty
-        if text == "":
-            log.warning("Warning: An empty Sentence was created! Are there empty strings in your dataset?")
+    def __iter__(self):
+        """Allows iteration over tokens. Triggers tokenization if not yet tokenized."""
+        return iter(self.tokens)
+
+    def __len__(self) -> int:
+        """Returns the number of tokens in this sentence. Triggers tokenization if not yet tokenized."""
+        return len(self.tokens)
 
     @property
     def unlabeled_identifier(self):
         return f'Sentence[{len(self)}]: "{self.text}"'
 
+    @property
+    def text(self) -> str:
+        """Returns the original text of this sentence. Does not trigger tokenization."""
+        return self._text
+
+    def to_original_text(self) -> str:
+        """Returns the original text of this sentence."""
+        return self._text
+
+    def to_tagged_string(self, main_label: Optional[str] = None) -> str:
+        # For sentence-level labels, we don't need tokenization
+        if not self._tokens:
+            output = f'Sentence: "{self.text}"'
+            if self.labels:
+                output += self._printout_labels(main_label)
+            return output
+
+        # Only tokenize if we have token-level labels or spans to print
+        already_printed = [self]
+        output = super().__str__()
+
+        label_append = []
+        for label in self.get_labels(main_label):
+            if label.data_point in already_printed:
+                continue
+            label_append.append(
+                f'"{label.data_point.text}"{label.data_point._printout_labels(main_label=main_label, add_score=False)}'
+            )
+            already_printed.append(label.data_point)
+
+        if len(label_append) > 0:
+            output += f"{flair._arrow}[" + ", ".join(label_append) + "]"
+
+        return output
+
     def get_relations(self, label_type: Optional[str] = None) -> list[Relation]:
         relations: list[Relation] = []
         for label in self.get_labels(label_type):
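A sketch of the sentence-level fast path in the new to_tagged_string, assuming the usual DataPoint.add_label API from Flair (not shown in this diff):

from flair.data import Sentence

sentence = Sentence("I love Berlin .")
sentence.add_label("sentiment", "POSITIVE")   # sentence-level label only

# With no token- or span-level annotations, the early-return branch above is taken,
# so printing the tagged string leaves the sentence untokenized.
print(sentence.to_tagged_string())
assert sentence._tokens is None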
@@ -951,11 +1044,13 @@ def to(self, device: str, pin_memory: bool = False):
             token.to(device, pin_memory)
 
     def clear_embeddings(self, embedding_names: Optional[list[str]] = None):
+        # clear sentence embeddings
         super().clear_embeddings(embedding_names)
 
-        # clear token embeddings
-        for token in self:
-            token.clear_embeddings(embedding_names)
+        # clear token embeddings if sentence is tokenized
+        if self._is_tokenized():
+            for token in self.tokens:
+                token.clear_embeddings(embedding_names)
 
     def left_context(self, context_length: int, respect_document_boundaries: bool = True) -> list[Token]:
         sentence = self
@@ -987,29 +1082,6 @@ def right_context(self, context_length: int, respect_document_boundaries: bool =
     def __str__(self) -> str:
         return self.to_tagged_string()
 
-    def to_tagged_string(self, main_label: Optional[str] = None) -> str:
-        already_printed = [self]
-
-        output = super().__str__()
-
-        label_append = []
-        for label in self.get_labels(main_label):
-            if label.data_point in already_printed:
-                continue
-            label_append.append(
-                f'"{label.data_point.text}"{label.data_point._printout_labels(main_label=main_label, add_score=False)}'
-            )
-            already_printed.append(label.data_point)
-
-        if len(label_append) > 0:
-            output += f"{flair._arrow}[" + ", ".join(label_append) + "]"
-
-        return output
-
-    @property
-    def text(self) -> str:
-        return self.to_original_text()
-
     def to_tokenized_string(self) -> str:
         if self.tokenized is None:
             self.tokenized = " ".join([t.text for t in self.tokens])
@@ -1056,15 +1128,6 @@ def infer_space_after(self):
             last_token = token
         return self
 
-    def to_original_text(self) -> str:
-        # if sentence has no tokens, return empty string
-        if len(self) == 0:
-            return ""
-        # otherwise, return concatenation of tokens with the correct offsets
-        return (self[0].start_position - self.start_position) * " " + "".join(
-            [t.text + t.whitespace_after * " " for t in self.tokens]
-        ).strip()
-
     def to_dict(self, tag_type: Optional[str] = None) -> dict[str, Any]:
         return {
             "text": self.to_original_text(),
@@ -1090,12 +1153,6 @@ def __getitem__(self, subscript):
         else:
             return self.tokens[subscript]
 
-    def __iter__(self):
-        return iter(self.tokens)
-
-    def __len__(self) -> int:
-        return len(self.tokens)
-
     def __repr__(self) -> str:
         return self.__str__()
 
@@ -1233,20 +1290,59 @@ def get_labels(self, label_type: Optional[str] = None):
         return []
 
     def remove_labels(self, typename: str):
-        # labels also need to be deleted at all tokens
-        for token in self:
-            token.remove_labels(typename)
-
-        # labels also need to be deleted at all known spans
-        for span in self._known_spans.values():
-            span.remove_labels(typename)
+        # only access tokens if already tokenized
+        if self._is_tokenized():
+            # labels also need to be deleted at all tokens
+            for token in self.tokens:
+                token.remove_labels(typename)
 
-        # remove spans without labels
-        self._known_spans = {k: v for k, v in self._known_spans.items() if len(v.labels) > 0}
+        # labels also need to be deleted at all known spans
+        for span in self._known_spans.values():
+            span.remove_labels(typename)
 
-        # delete labels at object itself
+        # delete labels at object itself first
         super().remove_labels(typename)
 
+    def _is_tokenized(self) -> bool:
+        return self._tokens is not None
+
+    def truncate(self, max_tokens: int) -> None:
+        """Truncates the sentence to a maximum number of tokens and updates all annotations accordingly."""
+        if len(self.tokens) <= max_tokens:
+            return
+
+        # Truncate tokens
+        self._tokens = self.tokens[:max_tokens]
+
+        # Remove spans that reference removed tokens
+        self._known_spans = {
+            identifier: span
+            for identifier, span in self._known_spans.items()
+            if not isinstance(span, Span) or all(token.idx <= max_tokens for token in span.tokens)
+        }
+
+        # Remove relations that reference removed spans
+        self._known_spans = {
+            identifier: relation
+            for identifier, relation in self._known_spans.items()
+            if not isinstance(relation, Relation)
+            or (
+                all(token.idx <= max_tokens for token in relation.first.tokens)
+                and all(token.idx <= max_tokens for token in relation.second.tokens)
+            )
+        }
+
+        # Clean up any labels that reference removed spans/relations
+        for typename in list(self.annotation_layers.keys()):
+            self.annotation_layers[typename] = [
+                label
+                for label in self.annotation_layers[typename]
+                if (
+                    not isinstance(label.data_point, (Span, Relation))
+                    or label.data_point.unlabeled_identifier in self._known_spans
+                )
+            ]
+
 
 class DataPair(DataPoint, typing.Generic[DT, DT2]):
     def __init__(self, first: DT, second: DT2) -> None:
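A sketch of the new truncate method; the span annotation below relies on Flair's sentence[start:end].add_label(...) pattern, which is assumed here and not part of this diff:

from flair.data import Sentence

sentence = Sentence("George Washington went to Washington .")
sentence[0:2].add_label("ner", "PER")   # span over tokens 1-2
sentence[4:5].add_label("ner", "LOC")   # span over token 5

sentence.truncate(3)
assert len(sentence) == 3
# The span inside the first three tokens survives together with its label;
# the out-of-range span is dropped from _known_spans and its label is cleaned up.
assert len(sentence.get_labels("ner")) == 1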
@@ -1375,7 +1471,7 @@ class Corpus(typing.Generic[T_co]):
     """The main object in Flair for holding a dataset used for training and testing.
 
     A corpus consists of three splits: A `train` split used for training, a `dev` split used for model selection
-    and/or early stopping and a `test` split used for testing. All three splits are optional, so it is possible
+    or early stopping and a `test` split used for testing. All three splits are optional, so it is possible
     to create a corpus only using one or two splits. If the option `sample_missing_splits` is set to True,
     missing splits will be randomly sampled from the training split.
     """