forked from abetlen/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathllama_embedding.py
More file actions
496 lines (403 loc) · 18.9 KB
/
llama_embedding.py
File metadata and controls
496 lines (403 loc) · 18.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
import numpy as np
from typing import Union, List, Optional, Dict, Any, Tuple
import llama_cpp.llama_cpp as llama_cpp
from .llama_types import Embedding
from .llama import Llama
# Pooling types from .llama_cpp
from .llama_cpp import (
LLAMA_POOLING_TYPE_UNSPECIFIED,
LLAMA_POOLING_TYPE_NONE,
LLAMA_POOLING_TYPE_MEAN,
LLAMA_POOLING_TYPE_CLS,
LLAMA_POOLING_TYPE_LAST,
LLAMA_POOLING_TYPE_RANK, # Specifically for Reranking models
)
from .mtmd import MediaChunk, mtmd_tokenize, mtmd_prefill
from ._utils import suppress_stdout_stderr
# Normalization modes accepted by LlamaEmbedding, following llama.cpp's --embd-normalize flag.
# Reference: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
#   -1 -> leave the vector untouched
#    0 -> rescale so the largest |x_i| becomes 32760 (int16-friendly range)
#    1 -> taxicab / L1 norm
#    2 -> Euclidean / L2 norm
#   >2 -> p-norm whose exponent p equals the mode value (6 is exposed as a named example)
NORM_MODE_NONE, NORM_MODE_MAX_INT16, NORM_MODE_TAXICAB, NORM_MODE_EUCLIDEAN = -1, 0, 1, 2
NORM_MODE_PNORM = 6
# TODO(JamePeng): Needs more extensive testing with various embedding and reranking models.
class LlamaEmbedding(Llama):
    """
    A specialized class for Text Embedding and Reranking.

    Inherits from the base Llama class and forces an embedding-oriented configuration.

    Key Features:
    1. Auto-configuration: always sets embeddings=True and enables the unified KV cache.
    2. Streaming Batch: inputs are packed into batches of at most n_batch tokens and
       decoded batch by batch, so long input lists are processed incrementally.
    3. Native Reranking Support: for `LLAMA_POOLING_TYPE_RANK` models (like BGE-Reranker,
       Qwen3-Reranker) the classification-head outputs (n_cls_out values) are read, so
       scalar relevance scores are returned instead of high-dimensional vectors.
    4. Normalization: MaxInt16, Taxicab (L1), Euclidean (L2) and p-norm strategies,
       implemented with NumPy.
    """

    def __init__(
        self,
        model_path: str,
        n_ctx: int = 0,
        n_batch: int = 512,
        n_ubatch: int = 512,
        pooling_type: int = LLAMA_POOLING_TYPE_UNSPECIFIED,
        n_gpu_layers: int = 0,
        verbose: bool = True,
        **kwargs):
        """
        Initialize the embedding model with enforced configuration.

        Args:
            model_path: Path to the GGUF model file.
            n_ctx: Text context size; 0 = take it from the model.
            n_batch: Maximum number of tokens submitted per decode call. In `embed`,
                no single input sequence may be longer than min(n_ctx, n_batch).
            n_ubatch: Physical (micro) batch size. `embed` also uses it as the maximum
                number of sequences packed into one batch.
            pooling_type: The pooling strategy used by the model.
                - Use `LLAMA_POOLING_TYPE_RANK` (4) for Reranker models.
                - Use `LLAMA_POOLING_TYPE_UNSPECIFIED` (-1) to let the model metadata decide
                  (for standard embeddings).
            n_gpu_layers: Number of model layers to offload to GPU.
                - Set to 0 for CPU only.
                - Set to -1 for all layers (recommended for best performance).
            verbose: Print debug and performance information.
            **kwargs: Additional arguments forwarded to the Llama base class. The keys
                `embeddings` and `kv_unified` are always overridden to True.

        NOTE(review): llama.cpp's embedding example appears to set n_ubatch = n_batch
        for non-causal models; confirm whether that should be enforced here.
        """
        kwargs["embeddings"] = True
        kwargs["n_gpu_layers"] = n_gpu_layers
        kwargs["n_ctx"] = n_ctx
        kwargs["n_batch"] = n_batch
        kwargs["n_ubatch"] = n_ubatch
        kwargs["verbose"] = verbose
        # Enable Unified KV Cache (Crucial for Batching)
        # This allows us to assign arbitrary seq_ids in a batch, enabling the parallel
        # encoding of multiple unrelated documents without "invalid seq_id" errors.
        kwargs["kv_unified"] = True
        # Set pooling type
        kwargs["pooling_type"] = pooling_type
        super().__init__(model_path=model_path, **kwargs)
        if self.verbose:
            print(f"LlamaEmbedding initialized with pooling_type: {self.pooling_type()}")

    def _resolve_pooling_type(self) -> int:
        """Return the context's pooling type, or UNSPECIFIED if `pooling_type()` is unavailable."""
        try:
            return self.pooling_type()
        except AttributeError:
            return LLAMA_POOLING_TYPE_UNSPECIFIED

    def _print_mode(self, pooling_type: int, is_none: bool, is_rank: bool, out_dim: int) -> None:
        """Print the resolved embedding mode (callers invoke this only when verbose)."""
        type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
        print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")

    def _normalize_vector(self, vector: List[float], mode: int) -> List[float]:
        """
        Apply a normalization strategy to a single vector (computed with NumPy).

        Args:
            vector: The raw embedding values.
            mode: A NORM_MODE_* constant; any integer > 2 selects a p-norm with p = mode.

        Returns:
            A new list with normalized values. `vector` itself is returned unchanged when
            mode is NORM_MODE_NONE (or any other negative value), when the vector is
            empty, or when its norm is zero.
        """
        # Empty vectors would make np.max raise; there is nothing to normalize anyway.
        if mode == NORM_MODE_NONE or len(vector) == 0:
            return vector
        arr = np.array(vector, dtype=np.float32)
        # Mode 0: Max Absolute Int16 -> 32760 * x_i / max|x_i|
        if mode == NORM_MODE_MAX_INT16:
            max_abs = np.max(np.abs(arr))
            if max_abs == 0:
                return vector
            return ((arr / max_abs) * 32760.0).tolist()
        # Mode 1: Taxicab (L1 Norm) -> x_i / sum|x_i|
        elif mode == NORM_MODE_TAXICAB:
            norm = np.sum(np.abs(arr))
            if norm == 0:
                return vector
            return (arr / norm).tolist()
        # Mode 2: Euclidean (L2 Norm) -> x_i / sqrt(sum x_i^2)
        elif mode == NORM_MODE_EUCLIDEAN:
            norm = np.linalg.norm(arr)
            if norm == 0:
                return vector
            return (arr / norm).tolist()
        # Mode > 2: p-norm -> x_i / (sum |x_i|^p)^(1/p)
        elif mode > NORM_MODE_EUCLIDEAN:
            norm = np.sum(np.abs(arr) ** mode) ** (1.0 / mode)
            if norm == 0:
                return vector
            return (arr / norm).tolist()
        return vector

    def embed(
        self,
        input: Union[str, List[str], List[List[int]]],
        normalize: int = NORM_MODE_EUCLIDEAN,
        truncate: bool = True,
        separator: Optional[str] = None,
        return_count: bool = False,
    ) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
        """
        Encode one or more inputs into embeddings (or reranker scores).

        Args:
            input: A single string, a list of strings, or a list of pre-tokenized
                sequences (lists of token ids).
            normalize: NORM_MODE_* mode applied to each output vector. Reranker scores
                are never normalized.
            truncate: If True, sequences longer than min(n_ctx, n_batch) tokens are cut
                to that length. If False, such sequences raise ValueError.
            separator: If given and `input` is a single string, split it on this
                separator and embed each piece separately.
            return_count: If True, also return the total number of tokens processed.

        Returns:
            For a single string (without separator), a single result. Otherwise, a list
            with one result per input, in input order. Each result is:
              - a list of per-token vectors for LLAMA_POOLING_TYPE_NONE,
              - a float for rank models whose head has a single output,
              - otherwise a list of floats.
            Empty inputs yield `[]` (or `0.0` for rank models).
            With return_count=True, a tuple (result, total_tokens_processed).

        Raises:
            ValueError: If an input item is neither str nor List[int], if a sequence is
                too long and truncate is False, or if llama.cpp returns no sequence
                embedding.
        """
        ctx = self._ctx.ctx
        n_batch = self.n_batch
        n_ctx = self._n_ctx
        n_ubatch = self.context_params.n_ubatch
        # Batches are flushed before they exceed n_batch tokens, so one sequence must fit
        # within n_batch on its own (and within the context window). Truncating to n_ctx
        # alone let a sequence longer than n_batch be added to an empty batch.
        max_seq_tokens = min(n_ctx, n_batch)

        # Determine if it is in Rerank mode
        pooling_type = self._resolve_pooling_type()
        is_rank = pooling_type == LLAMA_POOLING_TYPE_RANK
        is_none = pooling_type == LLAMA_POOLING_TYPE_NONE  # Token-level embedding
        # Token-level mode needs an output for every token; pooled modes do not.
        logits_all = is_none

        # Rerankers expose a classification head of n_cls_out values instead of n_embd.
        if is_rank:
            out_dim = llama_cpp.llama_model_n_cls_out(self._model.model)
        else:
            out_dim = self.n_embd()

        if self.verbose:
            self._print_mode(pooling_type, is_none, is_rank, out_dim)

        # Preprocess Input
        inputs: List[Union[str, List[int]]]
        if isinstance(input, str):
            if separator:
                inputs = input.split(separator)
                is_single = False
            else:
                inputs = [input]
                is_single = True
        else:
            inputs = list(input)
            is_single = False

        # Reset Context and Batch
        if self.verbose:
            llama_cpp.llama_perf_context_reset(ctx)
        self._batch.reset()
        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

        # One slot per input. Empty inputs are resolved immediately while earlier
        # sequences may still be queued, so results must be placed by index to keep
        # input order.
        results: List[Any] = [None] * len(inputs)
        # (input index, token count) for each queued sequence; list position == seq_id.
        pending: List[Tuple[int, int]] = []
        total_tokens_processed = 0

        def _decode_batch() -> None:
            """Decode queued sequences, store their outputs in `results`, then reset state."""
            if not pending:
                return
            self._ctx.decode(self._batch)

            if is_none:
                # Branch A: token level. Every token produced an output, in batch order.
                token_idx = 0
                for out_idx, seq_len in pending:
                    doc_tokens_embd = []
                    for _ in range(seq_len):
                        ptr = llama_cpp.llama_get_embeddings_ith(ctx, token_idx)
                        # `not ptr` catches both None and a NULL ctypes pointer (which
                        # is falsy but not None). Zero-pad to keep the shape.
                        if not ptr:
                            doc_tokens_embd.append([0.0] * out_dim)
                        else:
                            doc_tokens_embd.append(
                                self._normalize_vector(ptr[:out_dim], normalize)
                            )
                        token_idx += 1
                    results[out_idx] = doc_tokens_embd
            else:
                # Branch B: sequence level (mean/cls/last/rank/unspecified), one output per seq_id.
                for seq_id, (out_idx, _) in enumerate(pending):
                    ptr = llama_cpp.llama_get_embeddings_seq(ctx, seq_id)
                    if not ptr:
                        raise ValueError(
                            f"llama_get_embeddings_seq returned no embedding for seq_id {seq_id}"
                        )
                    data = ptr[:out_dim]
                    if is_rank:
                        # Rerank outputs are raw scores: never normalized, and a
                        # single-output head is unwrapped to a plain float.
                        results[out_idx] = data[0] if len(data) == 1 else data
                    else:
                        results[out_idx] = self._normalize_vector(data, normalize)

            # seq_ids are reused by the next batch, so clear both the batch and the KV memory.
            self._batch.reset()
            llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
            pending.clear()

        # Main Streaming Loop
        for out_idx, item in enumerate(inputs):
            # Tokenize
            if isinstance(item, str):
                tokens = self.tokenize(item.encode("utf-8"))
            elif isinstance(item, list) and (not item or isinstance(item[0], int)):
                tokens = item
            else:
                raise ValueError("Input item must be str or List[int]")

            # Enforce the per-sequence limit (truncate or reject).
            if len(tokens) > max_seq_tokens:
                if not truncate:
                    raise ValueError(
                        f"Input {out_idx} has {len(tokens)} tokens, exceeding the per-sequence "
                        f"limit of {max_seq_tokens} (min of n_ctx={n_ctx}, n_batch={n_batch}); "
                        "set truncate=True or increase n_ctx/n_batch."
                    )
                tokens = tokens[:max_seq_tokens]

            n_tokens = len(tokens)
            total_tokens_processed += n_tokens
            if n_tokens == 0:
                results[out_idx] = 0.0 if is_rank else []
                continue

            # Flush when the token budget or the per-batch sequence limit would be exceeded.
            # NOTE(review): seq_ids go up to n_ubatch - 1, and llama.cpp caps seq_ids
            # (presumably LLAMA_MAX_SEQ when kv_unified is set). Confirm n_ubatch stays
            # within that cap.
            if self._batch.n_tokens() + n_tokens > n_batch or len(pending) >= n_ubatch:
                _decode_batch()

            # Add to Batch (seq_id = position within the current batch)
            self._batch.add_sequence(tokens, len(pending), logits_all=logits_all)
            pending.append((out_idx, n_tokens))

        # Process Remaining Items
        _decode_batch()

        if self.verbose:
            llama_cpp.llama_perf_context_print(ctx)

        final_result = results[0] if is_single else results
        if return_count:
            return final_result, total_tokens_processed
        return final_result

    def rank(self, query: str, documents: List[str]) -> List[float]:
        """
        Calculate relevance scores for a list of documents against a query using a Reranking model.

        This method follows the implementation logic of the latest llama.cpp embedding example,
        supporting both specialized chat templates and manual sequence construction.
        Link: https://github.com/ggml-org/llama.cpp/blob/master/examples/embedding/embedding.cpp

        Args:
            query: The search query string.
            documents: A list of candidate document strings to be scored.

        Returns:
            One float score per document, in order; higher values indicate greater relevance.

        Raises:
            ValueError: If the model's pooling type is not LLAMA_POOLING_TYPE_RANK.
        """
        # Ensure the model is configured for Reranking (Cross-Encoding)
        pooling_type = self.pooling_type()
        if pooling_type != LLAMA_POOLING_TYPE_RANK:
            raise ValueError(f"Model pooling_type is {pooling_type}, but LLAMA_POOLING_TYPE_RANK is required.")

        # 1. Attempt to retrieve the built-in 'rerank' chat template from model metadata.
        rerank_template = llama_cpp.llama_model_chat_template(self._model.model, b"rerank")
        if rerank_template:
            rerank_template = rerank_template.decode("utf-8")

        batch_inputs: List[List[int]] = []

        if rerank_template:
            # 2. Case A: Model-specific template with {query} / {document} placeholders.
            for doc in documents:
                final_prompt = rerank_template.replace("{query}", query).replace("{document}", doc)
                # The template is expected to place BOS/EOS itself, so none is added here.
                tokens = self.tokenize(final_prompt.encode("utf-8"), add_bos=False, special=True)
                batch_inputs.append(tokens)
        else:
            # 3. Case B: Manual [BOS] Query [SEP] Doc [EOS] construction (fallback).
            eos_id = self.token_eos()
            sep_id = self.token_sep()
            if sep_id == -1:
                # No dedicated separator token: fall back to EOS.
                sep_id = eos_id

            # Pre-tokenize the query with BOS (Beginning of Sequence)
            q_tokens = self.tokenize(query.encode("utf-8"), add_bos=True, special=True)
            # Remove a trailing EOS from the query so the document can be concatenated.
            if q_tokens and q_tokens[-1] == eos_id:
                q_tokens.pop()

            for doc in documents:
                # Tokenize document without an additional BOS token
                d_tokens = self.tokenize(doc.encode("utf-8"), add_bos=False, special=True)
                full_seq = q_tokens + [sep_id] + d_tokens
                # Ensure the sequence is terminated with an EOS token.
                if not full_seq or full_seq[-1] != eos_id:
                    full_seq.append(eos_id)
                batch_inputs.append(full_seq)

        # 4. Inference. Rerankers output raw logits/scores, so normalization is skipped.
        # `embed` receives a list, so it returns exactly one entry per document. Wrapping
        # the single-document case again would make float() fail on a list.
        raw_results = self.embed(batch_inputs, normalize=NORM_MODE_NONE)

        # 5. Output Post-Processing
        # For generative rerankers like Qwen3-Reranker, output dim is 2 ([yes_logit, no_logit]).
        final_scores: List[float] = []
        for res in raw_results:
            if isinstance(res, (list, np.ndarray)) and len(res) == 2:
                final_scores.append(float(res[0]))  # yes_logit
            else:
                final_scores.append(float(res))  # Raw scalar score
        return final_scores

    def create_embedding(
        self,
        input: Union[str, List[str]],
        model: Optional[str] = None,
        normalize: int = NORM_MODE_EUCLIDEAN,
        output_format: str = "json"
    ) -> Union[Dict[str, Any], List[float], List[List[float]]]:
        """
        High-level API compatible with OpenAI format.

        Args:
            input: A string or a list of strings to embed.
            model: Model name reported in the response; defaults to `self.model_path`.
            normalize: NORM_MODE_* mode forwarded to `embed`.
            output_format:
                - 'json': OpenAI style dict (Default)
                - 'json+': OpenAI style dict + cosineSimilarity matrix
                - 'array': the raw list returned by `embed` (one entry per input)
        """
        model_name = model if model is not None else self.model_path

        # Normalize input to list
        inputs_list = [input] if isinstance(input, str) else input

        # Generate Embeddings (and get token count)
        embeddings, token_count = self.embed(
            inputs_list,
            normalize=normalize,
            return_count=True
        )

        if output_format == "array":
            return embeddings

        # `embed` was given a list, so `embeddings` has one entry per input. Rerank models
        # yield a scalar per input; wrap each scalar so every "embedding" field is a list.
        # (Indexing embeddings[0] here used to fail on empty input and merged several
        # rerank scores into a single entry.)
        iter_embeddings = [[emb] if isinstance(emb, float) else emb for emb in embeddings]

        data: List[Embedding] = [
            {
                "object": "embedding",
                "embedding": emb,
                "index": idx,
            }
            for idx, emb in enumerate(iter_embeddings)
        ]

        response = {
            "object": "list",
            "data": data,
            "model": model_name,
            "usage": {
                "prompt_tokens": token_count,  # Input consumption
                "completion_tokens": 0,  # The Embedding task does not generate text, so the value is 0.
                "total_tokens": token_count,  # Total consumption = Input consumption + Output
            }
        }

        # Cosine Similarity Matrix, only for 'json+' with more than one vector result.
        if output_format == "json+" and len(embeddings) > 1 and isinstance(embeddings[0], list):
            try:
                # Already L2-normalized when normalize == NORM_MODE_EUCLIDEAN.
                mat = np.array(embeddings)
                # Otherwise normalize rows here so the product is cosine, not a raw dot product.
                if normalize != NORM_MODE_EUCLIDEAN:
                    norm = np.linalg.norm(mat, axis=1, keepdims=True)
                    # Avoid division by zero
                    norm[norm == 0] = 1e-10
                    mat = mat / norm
                # Matrix multiplication: A @ A.T
                sim_matrix = np.dot(mat, mat.T)
                response["cosineSimilarity"] = sim_matrix.tolist()
            except Exception as e:
                # Best effort: the similarity matrix is optional extra output.
                if self.verbose:
                    print(f"Warning: Failed to calculate similarity matrix: {e}")

        return response

    def embed_multimodal(
        self,
        prompt: str,
        files: Optional[List[Union[bytes, str]]] = None,
        normalize: int = NORM_MODE_EUCLIDEAN,
        return_count: bool = False,
    ) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
        """
        Embed a prompt together with media inputs using the model's mtmd (multimodal) context.

        Args:
            prompt: The text prompt passed to `mtmd_tokenize`.
            files: Media inputs passed to `mtmd_tokenize` (bytes or str; presumably raw
                data or file paths, TODO confirm). Defaults to no files. `None` replaces
                the previous shared mutable `[]` default.
            normalize: NORM_MODE_* mode applied to the output vector.
            return_count: If True, also return the number of multimodal tokens.

        Returns:
            The embedding read at batch index `self._batch.n_tokens() - 1` after prefill,
            as a list of floats, or [] when tokenization yields no tokens. With
            return_count=True, a tuple (result, n_tokens).

        Raises:
            ValueError: If llama.cpp returns no embedding for that position.
        """
        if files is None:
            files = []

        ctx = self._ctx.ctx
        mctx = self.mtmd_context.ctx

        pooling_type = self._resolve_pooling_type()
        is_rank = pooling_type == LLAMA_POOLING_TYPE_RANK
        is_none = pooling_type == LLAMA_POOLING_TYPE_NONE  # Token-level embedding
        out_dim = self.n_embd()

        if self.verbose:
            self._print_mode(pooling_type, is_none, is_rank, out_dim)

        # Reset Context and Batch
        if self.verbose:
            llama_cpp.llama_perf_context_reset(ctx)
        self._batch.reset()
        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

        result: Any = None
        with suppress_stdout_stderr(disable=self.verbose):
            tokens = mtmd_tokenize(mctx, prompt, files)
            n_tokens = len(tokens)
            if n_tokens == 0:
                result = []
            else:
                mtmd_prefill(self._ctx, mctx, self._batch, tokens)
                # NOTE(review): assumes mtmd_prefill leaves its final decoded chunk in
                # self._batch. Also, for pooled models this reads a per-token output
                # rather than llama_get_embeddings_seq. Confirm both are intended.
                ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
                if not ptr:
                    raise ValueError("llama_get_embeddings_ith returned no embedding after multimodal prefill")
                result = self._normalize_vector(ptr[:out_dim], normalize)

            self._batch.reset()
            llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

        if self.verbose:
            llama_cpp.llama_perf_context_print(ctx)

        if return_count:
            return result, n_tokens
        return result