Skip to content

Commit fcf601f

Browse files
author
katsu560
committed
revised script from ggml
1 parent 8404d20 commit fcf601f

File tree

1 file changed

+58
-44
lines changed

1 file changed

+58
-44
lines changed

gguf-py/scripts/gguf-addfile.py

Lines changed: 58 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,16 @@
99
import sys
1010
from pathlib import Path
1111
from typing import Any
12+
#from typing import Any, Literal, NamedTuple, TypeVar, Union
1213

1314
import numpy as np
15+
import numpy.typing as npt
1416

1517
# Necessary to load the local gguf package
1618
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
1719
sys.path.insert(0, str(Path(__file__).parent.parent))
1820

19-
from gguf import GGUFReader, GGUFWriter, ReaderField, GGUFEndian, GGUFValueType, Keys # noqa: E402
21+
from gguf import GGUFReader, GGUFWriter, ReaderField, GGMLQuantizationType, GGUFEndian, GGUFValueType, Keys # noqa: E402
2022

2123
logger = logging.getLogger("gguf-addfile")
2224

@@ -54,17 +56,11 @@ def decode_field(field: ReaderField) -> Any:
5456
sub_type = field.types[-1]
5557

5658
if sub_type == GGUFValueType.STRING:
57-
if not field.name[0] == Keys.General.FILE_MARK:
58-
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
59-
else:
60-
return [bytes(field.parts[idx]) for idx in field.data]
59+
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
6160
else:
6261
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
6362
if main_type == GGUFValueType.STRING:
64-
if not field.name[0] == Keys.General.FILE_MARK:
65-
return str(bytes(field.parts[-1]), encoding='utf8')
66-
else:
67-
return bytes(field.parts[-1])
63+
return str(bytes(field.parts[-1]), encoding='utf8')
6864
else:
6965
return field.parts[-1][0]
7066

@@ -77,54 +73,64 @@ def get_field_data(reader: GGUFReader, key: str) -> Any:
7773
return decode_field(field)
7874

7975

80-
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str]) -> None:
76+
def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, filename: str[Any]) -> None:
77+
logger.debug(f'copy_with_filename: {filename}') #debug
78+
val = filename
8179
for field in reader.fields.values():
8280
# Suppress virtual fields and fields written by GGUFWriter
8381
if field.name == Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
8482
logger.debug(f'Suppressing {field.name}')
8583
continue
8684

87-
# Skip old chat templates if we have new ones
88-
if field.name.startswith(Keys.Tokenizer.CHAT_TEMPLATE) and Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
89-
logger.debug(f'Skipping {field.name}')
85+
# Copy existing fields except 'embedded_files'
86+
if not field.name == Keys.EMBEDDED_FILES:
87+
cur_val = decode_field(field)
88+
writer.add_key(field.name)
89+
writer.add_val(cur_val, field.types[0])
90+
logger.debug(f'Copying {field.name}')
9091
continue
9192

92-
old_val = decode_field(field)
93-
val = new_metadata.get(field.name, old_val)
94-
95-
if field.name in new_metadata:
96-
logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
97-
del new_metadata[field.name]
98-
elif val is not None:
99-
logger.debug(f'Copying {field.name}')
93+
# Update embedded_files
94+
val = decode_field(field)
95+
for path in filename:
96+
logger.debug(f'Adding {field.name}: {path}')
97+
val.append(path)
10098

101-
if val is not None:
102-
writer.add_key(field.name)
103-
writer.add_val(val, field.types[0])
104-
105-
if Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
106-
logger.debug('Adding chat template(s)')
107-
writer.add_chat_template(new_metadata[Keys.Tokenizer.CHAT_TEMPLATE])
108-
del new_metadata[Keys.Tokenizer.CHAT_TEMPLATE]
109-
110-
for key, name in new_metadata.items():
111-
logger.debug(f'Adding {key}: {name}')
112-
with open(name, "rb") as f:
113-
val = f.read()
114-
writer.add_object(key, val)
99+
# Add filenames to kv
100+
logger.info(f'* Modifying {Keys.EMBEDDED_FILES} to {val}')
101+
writer.add_array(Keys.EMBEDDED_FILES, val)
115102

116103
for tensor in reader.tensors:
117104
# Dimensions are written in reverse order, so flip them first
118105
shape = np.flipud(tensor.shape)
119106
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
120107

108+
# Add file info as tensor_info
109+
for path in filename:
110+
logger.debug(f'Adding tensor_info {path}')
111+
with open(path, "rb") as f:
112+
data = f.read()
113+
data_len = len(data)
114+
dims = [data_len]
115+
raw_dtype = GGMLQuantizationType.I8
116+
writer.add_tensor_info(path, dims, np.float16, data_len, raw_dtype)
117+
121118
writer.write_header_to_file()
122119
writer.write_kv_data_to_file()
123120
writer.write_ti_data_to_file()
124121

125122
for tensor in reader.tensors:
126123
writer.write_tensor_data(tensor.data)
127124

125+
# Write file body as tensor data
126+
for path in filename:
127+
logger.debug(f'Adding tensor data {path}')
128+
with open(path, "rb") as f:
129+
data = f.read()
130+
data_len = len(data)
131+
# write data with padding
132+
writer.write_data(data)
133+
128134
writer.close()
129135

130136

@@ -133,6 +139,7 @@ def main() -> None:
133139
parser.add_argument("input", type=str, help="GGUF format model input filename")
134140
parser.add_argument("output", type=str, help="GGUF format model output filename")
135141
parser.add_argument("addfiles", type=str, nargs='+', help="add filenames ...")
142+
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
136143
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
137144
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
138145
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -142,6 +149,15 @@ def main() -> None:
142149
arch = get_field_data(reader, Keys.General.ARCHITECTURE)
143150
endianess = get_byteorder(reader)
144151

152+
if os.path.isfile(args.output) and not args.force:
153+
logger.warning('*** Warning *** Warning *** Warning **')
154+
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
155+
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
156+
response = input('YES, I am sure> ')
157+
if response != 'YES':
158+
logger.info("You didn't enter YES. Okay then, see ya!")
159+
sys.exit(0)
160+
145161
logger.info(f'* Writing: {args.output}')
146162
writer = GGUFWriter(args.output, arch=arch, endianess=endianess)
147163

@@ -150,15 +166,13 @@ def main() -> None:
150166
logger.debug(f'Setting custom alignment: {alignment}')
151167
writer.data_alignment = alignment
152168

153-
logger.info(f'* Adding: {args.addfiles}')
154-
new_metadata = {}
155-
for path in args.addfiles:
156-
# add FILE_MARK to key
157-
key = Keys.General.FILE_MARK + path
158-
new_metadata[key] = path
159-
logger.info(f'* Adding: {key} = {path}')
160-
copy_with_new_metadata(reader, writer, new_metadata)
161-
169+
if args.addfiles is not None:
170+
filename = []
171+
for path in args.addfiles:
172+
filename.append(path)
173+
logger.info(f'* Adding: {path}')
174+
copy_with_filename(reader, writer, filename)
175+
162176

163177
if __name__ == '__main__':
164178
main()

0 commit comments

Comments
 (0)