Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,12 @@ def reset(self):
if self.batch is not None:
self.batch.n_tokens = 0

def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
def set_batch(self,
batch: Sequence[int],
n_past: llama_cpp.llama_pos,
logits_all: bool,
logits_last: bool = True
):
if len(batch) > self.n_tokens_capacity:
raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")

Expand All @@ -684,7 +689,7 @@ def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_al
self.batch.seq_id[i][0] = 0
self.batch.n_seq_id[i] = 1
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True
self.batch.logits[n_tokens - 1] = logits_last

def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
n_tokens = len(batch)
Expand Down
30 changes: 30 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from ._logger import set_verbose
from ._utils import suppress_stdout_stderr

from .mtmd_cpp import mtmd_context_params_default, mtmd_init_from_file
from .mtmd import MultiModalContext


class Llama:
"""High-level Python wrapper for a llama.cpp model."""
Expand Down Expand Up @@ -130,6 +133,10 @@ def __init__(
# Misc
spm_infill: bool = False,
verbose: bool = True,
mmproj_path: str = None,
mmproj_use_gpu: Optional[bool] = None,
image_min_tokens: int = -1,
image_max_tokens: int = -1,
# Extra Params
**kwargs, # type: ignore
):
Expand Down Expand Up @@ -426,6 +433,29 @@ def __init__(
)
)

if mmproj_path != None:
mparams = mtmd_context_params_default();
mparams.use_gpu = mmproj_use_gpu if mmproj_use_gpu != None else n_gpu_layers == -1
mparams.print_timings = verbose
mparams.n_threads = self.n_threads
mparams.flash_attn_type = self.context_params.flash_attn_type
mparams.warmup = True
if image_min_tokens > 0:
mparams.image_min_tokens = image_min_tokens
if image_max_tokens > 0:
mparams.image_max_tokens = image_max_tokens

with suppress_stdout_stderr(disable=verbose):
mctx = mtmd_init_from_file(mmproj_path.encode("utf-8"), self._model.model, mparams)
if mctx is None:
raise RuntimeError(f"failed to load multimodal projection '{mmproj_path}'")

self.mtmd_context = self._stack.enter_context(
contextlib.closing(
MultiModalContext(mctx)
)
)

# Check for Encoder-Decoder architecture
self._has_encoder = self._model.has_encoder()
self._has_decoder = self._model.has_decoder()
Expand Down
3 changes: 3 additions & 0 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ class ChatFormatterResponse:
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
added_special: bool = False

medias: List[Union[str, bytes, bytearray]] = None
media_types: List[str] = None


class ChatFormatter(Protocol):
"""Base Protocol for a chat formatter. A chat formatter is a function that
Expand Down
71 changes: 69 additions & 2 deletions llama_cpp/llama_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
LLAMA_POOLING_TYPE_LAST,
LLAMA_POOLING_TYPE_RANK, # Specifically for Reranking models
)
from .mtmd import MediaChunk, mtmd_tokenize, mtmd_prefill
from ._utils import suppress_stdout_stderr

# Normalization modes for embedding vectors
# See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
Expand Down Expand Up @@ -168,7 +170,7 @@ def embed(
if self.verbose:
llama_cpp.llama_perf_context_reset(ctx)
self._batch.reset()
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

# Initialize State Variables
results: List[Any] = []
Expand Down Expand Up @@ -219,7 +221,7 @@ def _decode_batch():
results.append(data)

self._batch.reset()
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
batch_seq_lens = []

# Main Streaming Loop
Expand Down Expand Up @@ -427,3 +429,68 @@ def create_embedding(
print(f"Warning: Failed to calculate similarity matrix: {e}")

return response


def embed_multimodal(
    self,
    prompt: str,
    files: Optional[List[Union[bytes, str]]] = None,
    normalize: int = NORM_MODE_EUCLIDEAN,
    return_count: bool = False,
) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
    """Embed a multimodal prompt (text plus optional media payloads).

    NOTE(review): assumes the model was constructed with ``mmproj_path``
    so that ``self.mtmd_context`` exists — confirm callers guard this.

    Args:
        prompt: Text prompt; media markers inside it are resolved against
            ``files`` by ``mtmd_tokenize``.
        files: Media payloads, each raw bytes or a file path. Defaults to
            no media. (``None`` sentinel avoids the shared mutable-default
            pitfall of ``files=[]``.)
        normalize: Normalization mode applied to the embedding vector
            (one of the ``NORM_MODE_*`` constants).
        return_count: When True, also return the multimodal token count.

    Returns:
        The normalized embedding vector (``[]`` for an empty tokenization),
        or a ``(embedding, n_tokens)`` tuple when ``return_count`` is True.
    """
    if files is None:
        files = []

    ctx = self._ctx.ctx
    mctx = self.mtmd_context.ctx

    # Determine pooling mode; fall back when the context does not expose it.
    try:
        pooling_type = self.pooling_type()
    except AttributeError:
        pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED
    is_rank = pooling_type == LLAMA_POOLING_TYPE_RANK
    is_none = pooling_type == LLAMA_POOLING_TYPE_NONE  # token-level embedding

    out_dim = self.n_embd()

    if self.verbose:
        type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
        print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")

    # Reset timing, batch, and KV memory before decoding this prompt.
    if self.verbose:
        llama_cpp.llama_perf_context_reset(ctx)
    self._batch.reset()
    llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

    result: Any = None

    with suppress_stdout_stderr(disable=self.verbose):
        # Tokenize text and media into chunks for the multimodal context.
        tokens = mtmd_tokenize(mctx, prompt, files)

    n_tokens = len(tokens)

    if n_tokens == 0:
        result = []
    else:
        # Evaluate all chunks; the returned past-position is not needed here.
        mtmd_prefill(self._ctx, mctx, self._batch, tokens)

        # Read the embedding of the final token in the batch and normalize it.
        ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
        data = ptr[:out_dim]
        result = self._normalize_vector(data, normalize)

    # Leave the batch and memory clean for the next call.
    self._batch.reset()
    llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

    if self.verbose:
        llama_cpp.llama_perf_context_print(ctx)

    if return_count:
        return result, n_tokens

    return result
Loading
Loading