Skip to content

Commit 94aa56f

Browse files
minor : improve C++ and Python style (ggml-org#768)
* use some STL functions * use self.field than setattr, use pathlib.Path * recover some format * const some iter * Keep the original * 2 space
1 parent 4d89ee2 commit 94aa56f

File tree

4 files changed

+103
-108
lines changed

4 files changed

+103
-108
lines changed

models/convert-h5-to-ggml.py

+17-21
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import code
2424
import torch
2525
import numpy as np
26+
from pathlib import Path
2627

2728
from transformers import WhisperForConditionalGeneration
2829

@@ -75,16 +76,13 @@ def bytes_to_unicode():
7576
print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n")
7677
sys.exit(1)
7778

78-
dir_model = sys.argv[1]
79-
dir_whisper = sys.argv[2]
80-
dir_out = sys.argv[3]
79+
dir_model = Path(sys.argv[1])
80+
dir_whisper = Path(sys.argv[2])
81+
dir_out = Path(sys.argv[3])
8182

82-
with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
83-
encoder = json.load(f)
84-
with open(dir_model + "/added_tokens.json", "r", encoding="utf8") as f:
85-
encoder_added = json.load(f)
86-
with open(dir_model + "/config.json", "r", encoding="utf8") as f:
87-
hparams = json.load(f)
83+
encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8"))
84+
encoder_added = json.load((dir_model / "added_tokens.json").open( "r", encoding="utf8"))
85+
hparams = json.load((dir_model / "config.json").open("r", encoding="utf8") )
8886

8987
model = WhisperForConditionalGeneration.from_pretrained(dir_model)
9088

@@ -96,16 +94,15 @@ def bytes_to_unicode():
9694

9795
dir_tokenizer = dir_model
9896

99-
fname_out = dir_out + "/ggml-model.bin"
97+
fname_out = dir_out / "ggml-model.bin"
10098

101-
with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
102-
tokens = json.load(f)
99+
tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8"))
103100

104101
# use 16-bit or 32-bit floats
105102
use_f16 = True
106103
if len(sys.argv) > 4:
107104
use_f16 = False
108-
fname_out = dir_out + "/ggml-model-f32.bin"
105+
fname_out = dir_out / "ggml-model-f32.bin"
109106

110107
fout = open(fname_out, "wb")
111108

@@ -171,18 +168,17 @@ def bytes_to_unicode():
171168
data = data.astype(np.float16)
172169

173170
# reshape conv bias from [n] to [n, 1]
174-
if name == "encoder.conv1.bias" or \
175-
name == "encoder.conv2.bias":
171+
if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
176172
data = data.reshape(data.shape[0], 1)
177-
print(" Reshaped variable: " + name + " to shape: ", data.shape)
173+
print(" Reshaped variable: " , name , " to shape: ", data.shape)
178174

179175
n_dims = len(data.shape)
180176
print(name, n_dims, data.shape)
181177

182178
# looks like the whisper models are in f16 by default
183179
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
184180
# ftype == 0 -> float32, ftype == 1 -> float16
185-
ftype = 1;
181+
ftype = 1
186182
if use_f16:
187183
if n_dims < 2 or \
188184
name == "encoder.conv1.bias" or \
@@ -197,16 +193,16 @@ def bytes_to_unicode():
197193
ftype = 0
198194

199195
# header
200-
str = name.encode('utf-8')
201-
fout.write(struct.pack("iii", n_dims, len(str), ftype))
196+
str_ = name.encode('utf-8')
197+
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
202198
for i in range(n_dims):
203199
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
204-
fout.write(str);
200+
fout.write(str_)
205201

206202
# data
207203
data.tofile(fout)
208204

209205
fout.close()
210206

211-
print("Done. Output file: " + fname_out)
207+
print("Done. Output file: " , fname_out)
212208
print("")

models/convert-pt-to-ggml.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import torch
4141
import numpy as np
4242
import base64
43-
43+
from pathlib import Path
4444
#from transformers import GPTJForCausalLM
4545
#from transformers import GPT2TokenizerFast
4646

@@ -194,17 +194,17 @@ def bytes_to_unicode():
194194
print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
195195
sys.exit(1)
196196

197-
fname_inp = sys.argv[1]
198-
dir_whisper = sys.argv[2]
199-
dir_out = sys.argv[3]
197+
fname_inp = Path(sys.argv[1])
198+
dir_whisper = Path(sys.argv[2])
199+
dir_out = Path(sys.argv[3])
200200

201201
# try to load PyTorch binary data
202202
try:
203203
model_bytes = open(fname_inp, "rb").read()
204204
with io.BytesIO(model_bytes) as fp:
205205
checkpoint = torch.load(fp, map_location="cpu")
206-
except:
207-
print("Error: failed to load PyTorch model file: %s" % fname_inp)
206+
except Exception:
207+
print("Error: failed to load PyTorch model file:" , fname_inp)
208208
sys.exit(1)
209209

210210
hparams = checkpoint["dims"]
@@ -218,17 +218,17 @@ def bytes_to_unicode():
218218

219219
# load mel filters
220220
n_mels = hparams["n_mels"]
221-
with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
221+
with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
222222
filters = torch.from_numpy(f[f"mel_{n_mels}"])
223223
#print (filters)
224224

225225
#code.interact(local=locals())
226226

227227
multilingual = hparams["n_vocab"] == 51865
228-
tokenizer = os.path.join(dir_whisper, "whisper/assets", multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
228+
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
229229

230230
# output in the same directory as the model
231-
fname_out = dir_out + "/ggml-model.bin"
231+
fname_out = dir_out / "ggml-model.bin"
232232

233233
with open(tokenizer, "rb") as f:
234234
contents = f.read()
@@ -238,9 +238,9 @@ def bytes_to_unicode():
238238
use_f16 = True
239239
if len(sys.argv) > 4:
240240
use_f16 = False
241-
fname_out = dir_out + "/ggml-model-f32.bin"
241+
fname_out = dir_out / "ggml-model-f32.bin"
242242

243-
fout = open(fname_out, "wb")
243+
fout = fname_out.open("wb")
244244

245245
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
246246
fout.write(struct.pack("i", hparams["n_vocab"]))
@@ -273,20 +273,19 @@ def bytes_to_unicode():
273273

274274
for name in list_vars.keys():
275275
data = list_vars[name].squeeze().numpy()
276-
print("Processing variable: " + name + " with shape: ", data.shape)
276+
print("Processing variable: " , name , " with shape: ", data.shape)
277277

278278
# reshape conv bias from [n] to [n, 1]
279-
if name == "encoder.conv1.bias" or \
280-
name == "encoder.conv2.bias":
279+
if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
281280
data = data.reshape(data.shape[0], 1)
282-
print(" Reshaped variable: " + name + " to shape: ", data.shape)
281+
print(f" Reshaped variable: {name} to shape: ", data.shape)
283282

284-
n_dims = len(data.shape);
283+
n_dims = len(data.shape)
285284

286285
# looks like the whisper models are in f16 by default
287286
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
288287
# ftype == 0 -> float32, ftype == 1 -> float16
289-
ftype = 1;
288+
ftype = 1
290289
if use_f16:
291290
if n_dims < 2 or \
292291
name == "encoder.conv1.bias" or \
@@ -307,16 +306,16 @@ def bytes_to_unicode():
307306
# data = data.transpose()
308307

309308
# header
310-
str = name.encode('utf-8')
311-
fout.write(struct.pack("iii", n_dims, len(str), ftype))
309+
str_ = name.encode('utf-8')
310+
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
312311
for i in range(n_dims):
313312
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
314-
fout.write(str);
313+
fout.write(str_)
315314

316315
# data
317316
data.tofile(fout)
318317

319318
fout.close()
320319

321-
print("Done. Output file: " + fname_out)
320+
print("Done. Output file: " , fname_out)
322321
print("")

models/convert-whisper-to-coreml.py

+22-25
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
2020
"""
2121
for k in state_dict:
2222
is_attention = all(substr in k for substr in ['attn', '.weight'])
23-
is_mlp = any([k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']])
23+
is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight'])
2424

2525
if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
2626
state_dict[k] = state_dict[k][:, :, None, None]
@@ -42,11 +42,10 @@ def __init__(self, *args, **kwargs):
4242
class MultiHeadAttentionANE(MultiHeadAttention):
4343
def __init__(self, n_state: int, n_head: int):
4444
super().__init__(n_state, n_head)
45-
46-
setattr(self, 'query', nn.Conv2d(n_state, n_state, kernel_size=1))
47-
setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1, bias=False))
48-
setattr(self, 'value', nn.Conv2d(n_state, n_state, kernel_size=1))
49-
setattr(self, 'out', nn.Conv2d(n_state, n_state, kernel_size=1))
45+
self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
46+
self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
47+
self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
48+
self.out = nn.Conv2d(n_state, n_state, kernel_size=1)
5049

5150
def forward(self,
5251
x: Tensor,
@@ -104,30 +103,28 @@ def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tens
104103
class ResidualAttentionBlockANE(ResidualAttentionBlock):
105104
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
106105
super().__init__(n_state, n_head, cross_attention)
107-
108-
setattr(self, 'attn', MultiHeadAttentionANE(n_state, n_head))
109-
setattr(self, 'attn_ln', LayerNormANE(n_state))
110-
111-
setattr(self, 'cross_attn', MultiHeadAttentionANE(n_state, n_head) if cross_attention else None)
112-
setattr(self, 'cross_attn_ln', LayerNormANE(n_state) if cross_attention else None)
106+
self.attn = MultiHeadAttentionANE(n_state, n_head)
107+
self.attn_ln = LayerNormANE(n_state)
108+
self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
109+
self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None
113110

114111
n_mlp = n_state * 4
115-
setattr(self, 'mlp', nn.Sequential(
112+
self.mlp = nn.Sequential(
116113
nn.Conv2d(n_state, n_mlp, kernel_size=1),
117114
nn.GELU(),
118115
nn.Conv2d(n_mlp, n_state, kernel_size=1)
119-
))
120-
setattr(self, 'mlp_ln', LayerNormANE(n_state))
116+
)
117+
self.mlp_ln = LayerNormANE(n_state)
121118

122119

123120
class AudioEncoderANE(AudioEncoder):
124121
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
125122
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
126123

127-
setattr(self, 'blocks', nn.ModuleList(
124+
self.blocks = nn.ModuleList(
128125
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
129-
))
130-
setattr(self, 'ln_post', LayerNormANE(n_state))
126+
)
127+
self.ln_post = LayerNormANE(n_state)
131128

132129
def forward(self, x: Tensor):
133130
"""
@@ -168,10 +165,10 @@ class TextDecoderANE(TextDecoder):
168165
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
169166
super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
170167

171-
setattr(self, 'blocks', nn.ModuleList(
168+
self.blocks= nn.ModuleList(
172169
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
173-
))
174-
setattr(self, 'ln', LayerNormANE(n_state))
170+
)
171+
self.ln= LayerNormANE(n_state)
175172

176173
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
177174
"""
@@ -213,20 +210,20 @@ class WhisperANE(Whisper):
213210
def __init__(self, dims: ModelDimensions):
214211
super().__init__(dims)
215212

216-
setattr(self, 'encoder', AudioEncoderANE(
213+
self.encoder = AudioEncoderANE(
217214
self.dims.n_mels,
218215
self.dims.n_audio_ctx,
219216
self.dims.n_audio_state,
220217
self.dims.n_audio_head,
221218
self.dims.n_audio_layer,
222-
))
223-
setattr(self, 'decoder', TextDecoderANE(
219+
)
220+
self.decoder = TextDecoderANE(
224221
self.dims.n_vocab,
225222
self.dims.n_text_ctx,
226223
self.dims.n_text_state,
227224
self.dims.n_text_head,
228225
self.dims.n_text_layer,
229-
))
226+
)
230227

231228
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
232229

0 commit comments

Comments
 (0)