|
20 | 20 |
|
21 | 21 | "to_sharegpt",
|
22 | 22 | "standardize_sharegpt",
|
| 23 | + "standardize_data_formats", |
23 | 24 | "apply_chat_template",
|
24 | 25 | "train_on_responses_only",
|
25 | 26 |
|
|
37 | 38 | import re
|
38 | 39 | from unsloth_zoo.dataset_utils import (
|
39 | 40 | train_on_responses_only,
|
| 41 | + standardize_data_formats, |
40 | 42 | )
|
# Backwards-compatible alias: the previously-public `standardize_sharegpt`
# now delegates to unsloth_zoo's `standardize_data_formats`
# (NOTE(review): assumes the two share a compatible signature — confirm
# against unsloth_zoo.dataset_utils).
standardize_sharegpt = standardize_data_formats
# Registry of named chat templates, populated throughout this module.
# Each value is a 4-tuple (jinja template, eos token, flag, ollama modelfile)
# — see e.g. the Gemma-3 entries below.
CHAT_TEMPLATES = {}
# Per-template default system message; None for architectures with no
# system role (e.g. Gemma-3).
DEFAULT_SYSTEM_MESSAGE = {}
|
43 | 46 |
|
|
934 | 937 | pass
|
935 | 938 |
|
936 | 939 |
|
# =========================================== Gemma-3
# HF/Jinja chat template for Gemma-3.  Obtained via
# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
# Behavior (as written in the template below):
#  * A leading 'system' message is folded into the first user turn as
#    `first_user_prefix` — Gemma-3 has no dedicated system turn.
#  * The 'assistant' role is rendered as Gemma's 'model' role.
#  * `content` may be a plain string, or an iterable of items where
#    {'type': 'image'} emits the <start_of_image> placeholder and
#    {'type': 'text'} emits the trimmed text.
#  * Roles must strictly alternate user/model; otherwise raise_exception fires.
#  * Indentation is cosmetic: every tag uses `{%- ... -%}` whitespace control.
gemma3_template = \
"""{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
    {%- if messages[0]['content'] is string -%}
        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
    {%- else -%}
        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
    {%- endif -%}
    {%- set loop_messages = messages[1:] -%}
{%- else -%}
    {%- set first_user_prefix = "" -%}
    {%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
    {%- endif -%}
    {%- if (message['role'] == 'assistant') -%}
        {%- set role = "model" -%}
    {%- else -%}
        {%- set role = message['role'] -%}
    {%- endif -%}
    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
    {%- if message['content'] is string -%}
        {{ message['content'] | trim }}
    {%- elif message['content'] is iterable -%}
        {%- for item in message['content'] -%}
            {%- if item['type'] == 'image' -%}
                {{ '<start_of_image>' }}
            {%- elif item['type'] == 'text' -%}
                {{ item['text'] | trim }}
            {%- endif -%}
        {%- endfor -%}
    {%- else -%}
        {{ raise_exception("Invalid content type") }}
    {%- endif -%}
    {{ '<end_of_turn>\n' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{ '<start_of_turn>model\n' }}
{%- endif -%}
"""

# Ollama Modelfile, from https://ollama.com/library/gemma3/blobs/e0a42594d802
# {__FILE_LOCATION__} is a placeholder for the exported model file path
# (NOTE(review): substituted elsewhere during model export — confirm).
gemma3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ if $last }}<start_of_turn>model
{{ end }}
{{- else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}{{ if not $last }}<end_of_turn>
{{ end }}
{{- end }}
{{- end }}"""
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
PARAMETER temperature 0.1
PARAMETER min_p 0.0
PARAMETER top_k 64
PARAMETER top_p 0.95
PARAMETER num_predict 32768
'''

# Gemma-3's turn terminator doubles as the EOS token.
gemma3_template_eos_token = "<end_of_turn>"
# Register the Gemma-3 template under both accepted spellings
# ("gemma-3" and "gemma3"); the entries are identical.
CHAT_TEMPLATES.update({
    "gemma-3" : (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,),
    "gemma3"  : (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,),
})
# Gemma-3 defines no system role, hence no default system message.
DEFAULT_SYSTEM_MESSAGE.update({ "gemma-3" : None, "gemma3" : None })
pass
| 1017 | + |
937 | 1018 | def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
|
938 | 1019 | system_message_pattern = r"\{system_message\}"
|
939 | 1020 |
|
@@ -1033,11 +1114,12 @@ def get_chat_template(
|
1033 | 1114 |
|
1034 | 1115 | # Check fast tokenizer
|
1035 | 1116 | if not is_fast_tokenizer:
|
1036 |
| - print( |
1037 |
| - "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\ |
1038 |
| - "Please log a Github issue if you want this as a new feature!\n"\ |
1039 |
| - "Your chat template will still work, but it won't add or edit tokens." |
1040 |
| - ) |
| 1117 | + pass |
| 1118 | + # print( |
| 1119 | + # "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\ |
| 1120 | + # "Please log a Github issue if you want this as a new feature!\n"\ |
| 1121 | + # "Your chat template will still work, but it won't add or edit tokens." |
| 1122 | + # ) |
1041 | 1123 |
|
1042 | 1124 | elif token_mapping is not None:
|
1043 | 1125 | # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
|
@@ -1396,82 +1478,6 @@ def __convert_to_sharegpt__(examples):
|
1396 | 1478 | pass
|
1397 | 1479 |
|
1398 | 1480 |
|
1399 |
| -def standardize_sharegpt( |
1400 |
| - dataset, |
1401 |
| - aliases_for_system = ["system",], |
1402 |
| - aliases_for_user = ["user", "human", "input",], |
1403 |
| - aliases_for_assistant = ["gpt", "assistant", "output",], |
1404 |
| -): |
1405 |
| - """ |
1406 |
| - Standardizes ShareGPT and other formats to user/assistant Hugging Face format. |
1407 |
| - |
1408 |
| - Get aliases for the system, user and assistant roles. |
1409 |
| - These shall map to "system", "user" and "assistant" respectively. |
1410 |
| - |
1411 |
| - aliases_for_system = ["system",], |
1412 |
| - aliases_for_user = ["user", "human", "input",], |
1413 |
| - aliases_for_assistant = ["gpt", "assistant", "output",], |
1414 |
| - """ |
1415 |
| - import collections |
1416 |
| - import itertools |
1417 |
| - |
1418 |
| - convos = dataset[:10]["conversations"] |
1419 |
| - uniques = collections.defaultdict(list) |
1420 |
| - for convo in convos: |
1421 |
| - for message in convo: |
1422 |
| - for key, value in message.items(): |
1423 |
| - uniques[key].append(value) |
1424 |
| - pass |
1425 |
| - |
1426 |
| - # Must be only 2 entries |
1427 |
| - assert(len(uniques.keys()) == 2) |
1428 |
| - |
1429 |
| - keys = list(uniques.keys()) |
1430 |
| - length_first = len(set(uniques[keys[0]])) |
1431 |
| - length_second = len(set(uniques[keys[1]])) |
1432 |
| - |
1433 |
| - if length_first < length_second: |
1434 |
| - # Role is assigned to the first element |
1435 |
| - role_key = keys[0] |
1436 |
| - content_key = keys[1] |
1437 |
| - else: |
1438 |
| - role_key = keys[1] |
1439 |
| - content_key = keys[0] |
1440 |
| - pass |
1441 |
| - |
1442 |
| - # Check roles are in aliases |
1443 |
| - all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant) |
1444 |
| - roles = set(uniques[role_key]) |
1445 |
| - leftover_aliases = (all_aliases | roles) - all_aliases |
1446 |
| - if len(leftover_aliases) != 0: |
1447 |
| - raise TypeError( |
1448 |
| - f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases." |
1449 |
| - ) |
1450 |
| - pass |
1451 |
| - |
1452 |
| - # Mapping for aliases |
1453 |
| - aliases_mapping = {} |
1454 |
| - for x in aliases_for_system: aliases_mapping[x] = "system" |
1455 |
| - for x in aliases_for_user: aliases_mapping[x] = "user" |
1456 |
| - for x in aliases_for_assistant: aliases_mapping[x] = "assistant" |
1457 |
| - |
1458 |
| - def _standardize_dataset(examples): |
1459 |
| - convos = examples["conversations"] |
1460 |
| - all_convos = [] |
1461 |
| - for convo in convos: |
1462 |
| - new_convo = [ |
1463 |
| - { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], } |
1464 |
| - for message in convo |
1465 |
| - ] |
1466 |
| - all_convos.append(new_convo) |
1467 |
| - pass |
1468 |
| - return { "conversations" : all_convos, } |
1469 |
| - pass |
1470 |
| - |
1471 |
| - return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format") |
1472 |
| -pass |
1473 |
| - |
1474 |
| - |
1475 | 1481 | def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
|
1476 | 1482 | added_tokens_decoder = tokenizer.added_tokens_decoder.values()
|
1477 | 1483 | added_tokens_decoder = [str(x) for x in added_tokens_decoder]
|
@@ -1934,6 +1940,11 @@ def formatting_prompts_func(examples):
|
1934 | 1940 | tokenizer._ollama_modelfile = modelfile
|
1935 | 1941 | tokenizer._unsloth_input_part = input_part
|
1936 | 1942 | tokenizer._unsloth_output_part = output_part
|
| 1943 | + if hasattr(tokenizer, "tokenizer"): |
| 1944 | + tokenizer.tokenizer.chat_template = jinja_template |
| 1945 | + tokenizer.tokenizer._ollama_modelfile = modelfile |
| 1946 | + tokenizer.tokenizer._unsloth_input_part = input_part |
| 1947 | + tokenizer.tokenizer._unsloth_output_part = output_part |
1937 | 1948 |
|
1938 | 1949 | return dataset.map(formatting_prompts_func, batched = True,)
|
1939 | 1950 | pass
|
|
0 commit comments