Commit 23e39f2

Add a T5TokenizerOptions node to set options for the T5 tokenizer. (Comfy-Org#7803)
1 parent 78992c4 commit 23e39f2

9 files changed: +60 −22 lines

comfy/sd.py

Lines changed: 10 additions & 0 deletions
@@ -120,24 +120,34 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz
         self.layer_idx = None
         self.use_clip_schedule = False
         logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
+        self.tokenizer_options = {}
 
     def clone(self):
         n = CLIP(no_init=True)
         n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
+        n.tokenizer_options = self.tokenizer_options.copy()
         n.use_clip_schedule = self.use_clip_schedule
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n
 
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)
 
+    def set_tokenizer_option(self, option_name, value):
+        self.tokenizer_options[option_name] = value
+
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx
 
     def tokenize(self, text, return_word_ids=False, **kwargs):
+        tokenizer_options = kwargs.get("tokenizer_options", {})
+        if len(self.tokenizer_options) > 0:
+            tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
+        if len(tokenizer_options) > 0:
+            kwargs["tokenizer_options"] = tokenizer_options
         return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
 
     def add_hooks_to_dict(self, pooled_dict: dict[str]):
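Two details of the merge in tokenize() are worth noting: the stored options are copied on clone(), so downstream nodes never mutate the original CLIP, and per-call tokenizer_options are merged last, so they override the stored values. A minimal usage sketch, assuming clip is an existing comfy.sd.CLIP instance and using the t5xxl option names introduced elsewhere in this commit:

    # Illustrative sketch only; `clip` is assumed to be a loaded comfy.sd.CLIP instance.
    clip = clip.clone()
    clip.set_tokenizer_option("t5xxl_min_length", 256)    # stored in clip.tokenizer_options
    tokens = clip.tokenize("a photo of a cat")            # options injected as kwargs["tokenizer_options"]
    # A per-call value is merged last, so it wins over the stored one:
    tokens = clip.tokenize("a photo of a cat", tokenizer_options={"t5xxl_min_length": 77})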

comfy/sd1_clip.py

Lines changed: 11 additions & 6 deletions
@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = min_length
         self.end_token = None
+        self.min_padding = min_padding
 
         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -518,13 +519,15 @@ def _try_get_embedding(self, embedding_name:str):
         return (embed, leftover)
 
 
-    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
         Takes a prompt and converts it to a list of (token, weight, word id) elements.
         Tokens can both be integer tokens and pre computed CLIP tensors.
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
+        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
 
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -603,10 +606,12 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         #fill last batch
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
-        if self.pad_to_max_length:
+        if min_padding is not None:
+            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+        if self.pad_to_max_length and len(batch) < self.max_length:
             batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
-        if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+        if min_length is not None and len(batch) < min_length:
+            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@@ -634,7 +639,7 @@ def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", t
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):
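The padding branches above now interact as follows: min_padding always appends that many pad tokens, pad_to_max_length only tops up when the batch is still shorter than max_length, and min_length enforces a final floor, with both per-call values read from tokenizer_options and falling back to the constructor defaults. A standalone sketch of that logic, using a hypothetical pad_batch helper rather than the SDTokenizer method itself:

    # Illustrative helper mirroring the new padding rules; not part of the commit.
    def pad_batch(batch, pad_token, max_length, pad_to_max_length=False, min_padding=None, min_length=None):
        if min_padding is not None:                              # unconditional extra padding
            batch.extend([(pad_token, 1.0, 0)] * min_padding)
        if pad_to_max_length and len(batch) < max_length:        # top up to max_length only if still short
            batch.extend([(pad_token, 1.0, 0)] * (max_length - len(batch)))
        if min_length is not None and len(batch) < min_length:   # enforce a minimum length floor
            batch.extend([(pad_token, 1.0, 0)] * (min_length - len(batch)))
        return batch

    # Example: 3 real tokens, min_padding=1, min_length=8 -> padded to 8 entries.
    padded = pad_batch([(320, 1.0, 1), (1125, 1.0, 2), (539, 1.0, 3)], pad_token=0, max_length=77, min_padding=1, min_length=8)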

comfy/sdxl_clip.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):

comfy/text_encoders/flux.py

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):

comfy/text_encoders/hidream.py

Lines changed: 4 additions & 4 deletions
@@ -16,11 +16,11 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
-        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
+        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):

comfy/text_encoders/hunyuan_video.py

Lines changed: 2 additions & 2 deletions
@@ -49,13 +49,13 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
 
         if llama_template is None:
             llama_text = self.llama_template.format(text)
         else:
             llama_text = llama_template.format(text)
-        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
         embed_count = 0
         for r in llama_text_tokens:
             for i in range(len(r)):

comfy/text_encoders/hydit.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
-        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
+        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):

comfy/text_encoders/sd3_clip.py

Lines changed: 3 additions & 3 deletions
@@ -45,9 +45,9 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out
 
     def untokenize(self, token_weight_pair):

comfy_extras/nodes_cond.py

Lines changed: 24 additions & 1 deletion
@@ -20,6 +20,29 @@ def encode(self, clip, conditioning, text):
             c.append(n)
         return (c, )
 
+class T5TokenizerOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "clip": ("CLIP", ),
+                "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+                "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "set_options"
+
+    def set_options(self, clip, min_padding, min_length):
+        clip = clip.clone()
+        for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
+            clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
+            clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
+
+        return (clip, )
+
 NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
+    "T5TokenizerOptions": T5TokenizerOptions,
 }
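The node clones the incoming CLIP and stores both options under every known T5 tokenizer key, so the values reach each sub-tokenizer through CLIP.tokenize and the kwargs forwarding added above. A hedged sketch of invoking the node directly from Python, outside the node graph (the clip object is assumed to come from a normal checkpoint load):

    # Hypothetical direct call, bypassing the graph; `clip` is an existing CLIP object.
    node = T5TokenizerOptions()
    (clip_with_options,) = node.set_options(clip, min_padding=1, min_length=256)
    # Each T5 variant key now carries "<key>_min_padding" / "<key>_min_length" options.
    tokens = clip_with_options.tokenize("a photo of a cat")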
