4 changes: 2 additions & 2 deletions llama-cpp-2/Cargo.toml
@@ -10,7 +10,7 @@ repository = "https://github.com/utilityai/llama-cpp-rs"
 
 [dependencies]
 enumflags2 = "0.7.12"
-llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113" }
+llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.122" }
 thiserror = { workspace = true }
 tracing = { workspace = true }
 tracing-core = { workspace = true }
@@ -35,7 +35,7 @@ mtmd = ["llama-cpp-sys-2/mtmd"]
 
 
 [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies]
-llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.113", features = [
+llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.122", features = [
     "metal",
 ] }
 
48 changes: 17 additions & 31 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -28,7 +28,8 @@ impl LlamaContext<'_> {
     /// * `dest` - The sequence id to copy the cache to.
     /// * `size` - The size of the cache to copy.
     pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_cp(mem, src, dest, 0, size) }
     }
 
     /// Copy the cache from one sequence to another.
@@ -57,9 +58,8 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        unsafe {
-            llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
-        }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_cp(mem, src, dest, p0, p1) };
         Ok(())
     }
 
@@ -92,18 +92,15 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) })
-    }
-
-    /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    #[must_use]
-    pub fn get_kv_cache_used_cells(&self) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        Ok(unsafe { llama_cpp_sys_2::llama_memory_seq_rm(mem, src, p0, p1) })
     }
 
     /// Clear the KV cache
     pub fn clear_kv_cache(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        // clear both metadata and data buffers to match previous semantics
+        unsafe { llama_cpp_sys_2::llama_memory_clear(mem, true) }
     }
 
     /// Removes all tokens that do not belong to the specified sequence
@@ -112,7 +109,8 @@ impl LlamaContext<'_> {
     ///
     /// * `seq_id` - The sequence id to keep
     pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_keep(mem, seq_id) }
     }
 
     #[allow(clippy::doc_markdown)]
@@ -146,9 +144,8 @@ impl LlamaContext<'_> {
         let p1 = p1
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
-        unsafe {
-            llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
-        }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_add(mem, seq_id, p0, p1, delta) };
         Ok(())
     }
 
@@ -183,7 +180,8 @@ impl LlamaContext<'_> {
             .map_or(Ok(-1), i32::try_from)
             .map_err(KvCacheConversionError::P1TooLarge)?;
         let d = c_int::from(d.get());
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_div(mem, seq_id, p0, p1, d) }
         Ok(())
     }
 
@@ -194,19 +192,7 @@ impl LlamaContext<'_> {
     /// * `seq_id` - The sequence id to get the max position for
     #[must_use]
     pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) }
-    }
-
-    /// Defragment the KV cache
-    /// This will be applied:
-    /// - lazily on next [`LlamaContext::decode`]
-    /// - explicitly with [`Self::kv_cache_update`]
-    pub fn kv_cache_defrag(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) }
-    }
-
-    /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    pub fn kv_cache_update(&mut self) {
-        unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) }
+        let mem = unsafe { llama_cpp_sys_2::llama_get_memory(self.context.as_ptr()) };
+        unsafe { llama_cpp_sys_2::llama_memory_seq_pos_max(mem, seq_id) }
     }
 }
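Review note: every wrapper in this file now follows the same two-step pattern — fetch the context's `llama_memory_t` handle via `llama_get_memory`, then call the matching `llama_memory_*` function — in place of the removed `llama_kv_self_*` API. `get_kv_cache_used_cells`, `kv_cache_defrag`, and `kv_cache_update` are dropped outright, since upstream no longer exposes equivalents. Below is a minimal caller-side sketch of how the surviving methods compose; the method names are taken from this diff, while building the model and context is assumed and omitted.

```rust
use llama_cpp_2::context::LlamaContext;

/// Fork a shared prompt prefix into a second sequence, keep only that
/// sequence, then reset the cache. A sketch only: `ctx` is assumed to be a
/// fully initialized `LlamaContext` with at least 32 decoded positions.
fn fork_then_reset(ctx: &mut LlamaContext<'_>) -> i32 {
    // Copy cached positions [0, 32) of sequence 0 into sequence 1
    // (llama_memory_seq_cp under the hood).
    ctx.copy_cache(0, 1, 32);

    // Evict every sequence except sequence 1 (llama_memory_seq_keep).
    ctx.llama_kv_cache_seq_keep(1);

    // Highest position still cached for sequence 1 (llama_memory_seq_pos_max).
    let max_pos = ctx.kv_cache_seq_pos_max(1);

    // Wipe both metadata and data buffers (llama_memory_clear(mem, true)).
    ctx.clear_kv_cache();

    max_pos
}
```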
32 changes: 9 additions & 23 deletions llama-cpp-2/src/context/params.rs
@@ -335,34 +335,20 @@ impl LlamaContextParams {
         self.context_params.n_ubatch
     }
 
-    /// Set the `flash_attention` parameter
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default()
-    ///     .with_flash_attention(true);
-    /// assert_eq!(params.flash_attention(), true);
-    /// ```
+    /// Set the flash attention policy using llama.cpp enum
     #[must_use]
-    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
-        self.context_params.flash_attn = enabled;
+    pub fn with_flash_attention_policy(
+        mut self,
+        policy: llama_cpp_sys_2::llama_flash_attn_type,
+    ) -> Self {
+        self.context_params.flash_attn_type = policy;
         self
     }
 
-    /// Get the `flash_attention` parameter
-    ///
-    /// # Examples
-    ///
-    /// ```rust
-    /// use llama_cpp_2::context::params::LlamaContextParams;
-    /// let params = LlamaContextParams::default();
-    /// assert_eq!(params.flash_attention(), false);
-    /// ```
+    /// Get the flash attention policy
     #[must_use]
-    pub fn flash_attention(&self) -> bool {
-        self.context_params.flash_attn
+    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
+        self.context_params.flash_attn_type
     }
 
     /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
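Review note: upstream has replaced the boolean `flash_attn` context field with a tri-state `flash_attn_type` policy (auto / disabled / enabled in current llama.cpp), and the wrapper now passes that enum through unchanged, so the old `with_flash_attention(bool)` doctests no longer apply. A usage sketch under one stated assumption: the constant name below reflects bindgen's default "consts" enum style and should be verified against the generated `llama-cpp-sys-2` bindings.

```rust
use llama_cpp_2::context::params::LlamaContextParams;

fn main() {
    // ASSUMPTION: bindgen's default consts style exposes the C enumerator
    // LLAMA_FLASH_ATTN_TYPE_ENABLED roughly under this name; check the
    // generated llama-cpp-sys-2 bindings for the exact identifier.
    let enabled: llama_cpp_sys_2::llama_flash_attn_type =
        llama_cpp_sys_2::llama_flash_attn_type_LLAMA_FLASH_ATTN_TYPE_ENABLED;

    // Round-trip the policy through the builder-style setter and getter.
    let params = LlamaContextParams::default().with_flash_attention_policy(enabled);
    assert_eq!(params.flash_attention_policy(), enabled);
}
```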
3 changes: 0 additions & 3 deletions llama-cpp-2/src/model/params.rs
@@ -180,9 +180,6 @@ impl LlamaModelParams {
     /// ```
     /// # use llama_cpp_2::model::params::LlamaModelParams;
     /// let params = LlamaModelParams::default();
-    /// #[cfg(not(target_os = "macos"))]
-    /// assert_eq!(params.n_gpu_layers(), 0, "n_gpu_layers should be 0");
-    /// #[cfg(target_os = "macos")]
     /// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
     /// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
     /// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
2 changes: 1 addition & 1 deletion llama-cpp-sys-2/llama.cpp
Submodule llama.cpp updated 431 files