feat: Add extended sampling API with candidate token lists #14612 #14765

Open · wants to merge 1 commit into master
103 changes: 103 additions & 0 deletions README.md
@@ -525,6 +525,109 @@ To learn more about model quantization, [read this documentation](tools/quantize
- Make sure to read this: [Inference at the edge](https://github.com/ggml-org/llama.cpp/discussions/205)
- A bit of backstory for those who are interested: [Changelog podcast](https://changelog.com/podcast/532)

## Extended Sampling API

The `llama.cpp` library provides an extended sampling API that allows developers to access detailed information about the sampling process, including candidate tokens and their probabilities. This feature is particularly useful for debugging, analysis, and building applications that need insight into the model's decision-making process.

### New API Functions

#### `llama_sampler_sample_with_candidates`

```c
int32_t llama_sampler_sample_with_candidates(
struct llama_sampler * smpl,
struct llama_context * ctx,
int32_t idx,
size_t max_candidates,
struct llama_sampling_result * result
);
```

This function extends the standard `llama_sampler_sample` by returning detailed information about the sampling process:

- **Parameters:**
- `smpl`: The sampler instance
- `ctx`: The context containing the model state
- `idx`: Index of the output to sample from (typically -1 for the last output)
- `max_candidates`: Maximum number of candidate tokens to return (0 for all candidates)
- `result`: Pointer to the result structure

- **Returns:** `0` on success, `-1` if `result` is `NULL`, `-2` if no logits are available at `idx` (see the error-handling sketch below)
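
As an illustration, here is a minimal sketch of how a caller might handle these return codes. The helper name is hypothetical; it assumes `smpl` and `ctx` are already initialized and that `llama_decode()` has been called so logits exist:

```c
#include <stdio.h>

#include "llama.h"

// Hypothetical helper: samples one token and maps the documented return
// codes to messages.
static int sample_one(struct llama_sampler * smpl, struct llama_context * ctx) {
    struct llama_sampling_result result = {0};
    const int32_t ret = llama_sampler_sample_with_candidates(smpl, ctx, -1, 10, &result);

    if (ret == -1) {
        fprintf(stderr, "error: result pointer was NULL\n");
        return 1;
    }
    if (ret == -2) {
        fprintf(stderr, "error: no logits available at the requested index\n");
        return 1;
    }

    if (result.is_selected) {
        printf("selected token %d with p = %.3f\n", result.selected_token, result.selected_prob);
    }

    llama_sampling_result_free(&result);
    return 0;
}
```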

#### `llama_sampling_result` Structure

```c
typedef struct llama_sampling_result {
llama_token selected_token; // The selected token ID
float selected_logit; // Logit value of the selected token
float selected_prob; // Probability of the selected token
bool is_selected; // True if a token was successfully selected
llama_token_data_array candidates; // Array of candidate tokens and their probabilities
} llama_sampling_result;
```

#### `llama_sampling_result_free`

```c
void llama_sampling_result_free(struct llama_sampling_result * result);
```

Frees the candidate array allocated within a `llama_sampling_result`. The structure itself is owned by the caller (typically stack-allocated), so only `candidates.data` is released.
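
A minimal sketch of the intended ownership pattern, assuming `smpl` and `ctx` are set up as in the example below (passing `0` for `max_candidates` requests all candidates):

```c
struct llama_sampling_result result = {0}; // owned by the caller, e.g. on the stack
if (llama_sampler_sample_with_candidates(smpl, ctx, -1, 0, &result) == 0) {
    // ... inspect result.selected_token and result.candidates ...
}
llama_sampling_result_free(&result); // releases only result.candidates.data
```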

### Usage Example

```c
#include "llama.h"

// Initialize model and context
llama_model * model = llama_model_load_from_file("model.gguf", model_params);
llama_context * ctx = llama_init_from_model(model, ctx_params);

// Create sampler chain
llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(smpl, llama_sampler_init_dist(42));

// Sample with candidates
struct llama_sampling_result result;
int ret = llama_sampler_sample_with_candidates(smpl, ctx, -1, 10, &result);

if (ret == 0 && result.is_selected) {
printf("Selected token: %d (Probability: %.3f)\n",
result.selected_token, result.selected_prob);

printf("Top candidates:\n");
for (size_t i = 0; i < result.candidates.size; i++) {
const llama_token_data * candidate = &result.candidates.data[i];
printf(" %zu. Token %d (Probability: %.3f, Logit: %.3f)\n",
i + 1, candidate->id, candidate->p, candidate->logit);
}
}

// Clean up
llama_sampling_result_free(&result);
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);
```

### Applications

This extended API is particularly useful for:

- **Debugging sampling strategies**: Understanding why specific tokens were chosen
- **Analyzing model behavior**: Examining the probability distribution of candidate tokens (see the analysis sketch after this list)
- **Building interactive applications**: Providing users with insight into the model's decision process
- **Quality control**: Verifying that the model is selecting reasonable candidates
- **Research and development**: Studying the relationship between different sampling parameters and token selection
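
As a concrete illustration of the analysis use cases above, the sketch below computes the cumulative probability mass and Shannon entropy of a populated result. The helper name is hypothetical, and it assumes the sampler chain filled in the probability field of each candidate (e.g. via a distribution sampler):

```c
#include <math.h>
#include <stdio.h>

#include "llama.h"

// Hypothetical helper: summarizes a populated sampling result.
// Assumes the chain computed probabilities for the candidates.
static void analyze_candidates(const struct llama_sampling_result * result) {
    float mass    = 0.0f; // cumulative probability of the returned candidates
    float entropy = 0.0f; // Shannon entropy (in bits) over the returned candidates

    for (size_t i = 0; i < result->candidates.size; i++) {
        const float p = result->candidates.data[i].p;
        mass += p;
        if (p > 0.0f) {
            entropy -= p * log2f(p);
        }
    }

    printf("candidates: %zu, cumulative p: %.3f, entropy: %.3f bits\n",
            result->candidates.size, mass, entropy);
}
```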

### Implementation Details

The API integrates with the existing sampling chain architecture in `llama.cpp`: it applies the same sampler chain (e.g. top-k, temperature, distribution) as the standard `llama_sampler_sample` call, but additionally exposes the candidate array as it stands after the chain has run.
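
For instance, because the candidate list reflects the array after the chain has run, varying a single chain parameter makes its effect directly observable. A minimal sketch, assuming `ctx` already holds decoded logits (the top-k step is what leaves the candidates sorted; the helper name is hypothetical):

```c
#include <stdio.h>

#include "llama.h"

// Hypothetical helper: prints the top-3 candidates for a given temperature,
// showing how the chain parameters reshape the returned distribution.
static void show_top3(struct llama_context * ctx, float temp) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));   // keeps (and sorts) the top 40
    llama_sampler_chain_add(chain, llama_sampler_init_temp(temp));  // rescales the logits
    llama_sampler_chain_add(chain, llama_sampler_init_dist(42));    // softmax + random draw

    struct llama_sampling_result res = {0};
    if (llama_sampler_sample_with_candidates(chain, ctx, -1, 3, &res) == 0) {
        printf("temp %.1f:", temp);
        for (size_t i = 0; i < res.candidates.size; i++) {
            printf(" %d (p=%.3f)", res.candidates.data[i].id, res.candidates.data[i].p);
        }
        printf("\n");
    }

    llama_sampling_result_free(&res);
    llama_sampler_free(chain);
}
```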

For more information about the implementation, see the source code in `src/llama-sampling.cpp` and the header definitions in `include/llama.h`.

## Other documentation

- [main (cli)](tools/main/README.md)
29 changes: 29 additions & 0 deletions include/llama.h
@@ -202,6 +202,15 @@ extern "C" {
bool sorted;
} llama_token_data_array;

// Extended sampling result with candidate tokens and probabilities
typedef struct llama_sampling_result {
llama_token selected_token; // the selected token
float selected_logit; // logit of the selected token
float selected_prob; // probability of the selected token
bool is_selected; // flag indicating if a token was selected
llama_token_data_array candidates; // array of candidate tokens with their probabilities
} llama_sampling_result;

typedef bool (*llama_progress_callback)(float progress, void * user_data);

// Input data for llama_encode/llama_decode
@@ -1356,6 +1365,26 @@ extern "C" {
// Returns the sampled token
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);

/// @details Extended sampling function that returns detailed information about the sampling process
/// including the selected token and a list of candidate tokens with their probabilities.
/// @param smpl The sampler to use
/// @param ctx The context containing the model
/// @param idx The index of the output to sample from (-1 for the last token)
/// @param max_candidates Maximum number of candidate tokens to return (0 for all available)
/// @param result Pointer to the result structure to fill
/// @return 0 on success, negative value on error
LLAMA_API int32_t llama_sampler_sample_with_candidates(
struct llama_sampler * smpl,
struct llama_context * ctx,
int32_t idx,
size_t max_candidates,
struct llama_sampling_result * result
);

/// @details Free the memory allocated for candidate tokens in a sampling result
/// @param result The sampling result to free
LLAMA_API void llama_sampling_result_free(struct llama_sampling_result * result);

// TODO: extend in the future
//LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);

90 changes: 90 additions & 0 deletions src/llama-sampling.cpp
@@ -409,6 +409,96 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
return token;
}

int32_t llama_sampler_sample_with_candidates(
struct llama_sampler * smpl,
struct llama_context * ctx,
int32_t idx,
size_t max_candidates,
struct llama_sampling_result * result
) {
if (!result) {
return -1; // Invalid result pointer
}

const auto * logits = llama_get_logits_ith(ctx, idx);
if (!logits) {
return -2; // Invalid logits
}

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);

// Initialize result structure
result->selected_token = LLAMA_TOKEN_NULL;
result->selected_logit = 0.0f;
result->selected_prob = 0.0f;
result->is_selected = false;
result->candidates.data = nullptr;
result->candidates.size = 0;
result->candidates.selected = -1;
result->candidates.sorted = false;

// Create candidate tokens array
std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = {
/* .data = */ cur.data(),
/* .size = */ cur.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};

// Apply sampling
llama_sampler_apply(smpl, &cur_p);

// Check if a token was selected
if (cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size) {
result->selected_token = cur_p.data[cur_p.selected].id;
result->selected_logit = cur_p.data[cur_p.selected].logit;
result->selected_prob = cur_p.data[cur_p.selected].p;
result->is_selected = true;

// Accept the selected token
llama_sampler_accept(smpl, result->selected_token);
}

// Determine how many candidates to return
size_t num_candidates = cur_p.size;
if (max_candidates > 0 && max_candidates < num_candidates) {
num_candidates = max_candidates;
}

// Allocate and copy candidate data
if (num_candidates > 0) {
result->candidates.data = new llama_token_data[num_candidates];
result->candidates.size = num_candidates;
// keep the selected index only if it falls within the copied range
result->candidates.selected = (cur_p.selected >= 0 && (size_t) cur_p.selected < num_candidates) ? cur_p.selected : -1;
result->candidates.sorted = cur_p.sorted;

// Copy the leading candidates (sorted only if the chain included a sorting step such as top-k)
for (size_t i = 0; i < num_candidates; i++) {
result->candidates.data[i] = cur_p.data[i];
}
}

return 0; // Success
}

void llama_sampling_result_free(struct llama_sampling_result * result) {
if (result && result->candidates.data) {
delete[] result->candidates.data;
result->candidates.data = nullptr;
result->candidates.size = 0;
result->candidates.selected = -1;
result->candidates.sorted = false;
}
}

// sampler chain

static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
Expand Down