src/llama-memory-recurrent.cpp (7 changes: 4 additions & 3 deletions)

@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased at the end
+    // of the sequence because their state isn't preserved for previous tokens
     if (seq_id >= (int64_t) size) {
         // could be fatal
         return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            // partial intersection is invalid if it includes the final pos
+            if ((0 < p0 && p0 <= cell.pos && p1 > cell.pos)) {
Review comment (Member), suggested change (drops the redundant outer parentheses):

-            if ((0 < p0 && p0 <= cell.pos && p1 > cell.pos)) {
+            if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {

Review comment (Member):

Why do we check strictly greater than 0 (`0 < p0`) rather than `0 <= p0`?

                 //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
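
To make the new boundary check concrete, here is a minimal standalone sketch (not the llama.cpp implementation; `cell_t` and `is_partial_intersection` are hypothetical stand-ins). A recurrent cell only stores the state after its final position, so a removal range `[p0, p1)` is rejected exactly when it starts inside the stored span but still covers that final position:

```cpp
#include <cstdio>

// Hypothetical stand-in for the tail cell of a recurrent sequence.
struct cell_t {
    int pos;  // position of the last token folded into the state
};

// Mirrors the merged condition: the range [p0, p1) starts inside the
// stored span (0 < p0 <= cell.pos) and still covers the final position
// (p1 > cell.pos), which a recurrent state cannot partially undo.
static bool is_partial_intersection(const cell_t & cell, int p0, int p1) {
    return 0 < p0 && p0 <= cell.pos && p1 > cell.pos;
}

int main() {
    const cell_t cell = {10};
    printf("%d\n", is_partial_intersection(cell, 5, 20));  // 1: cuts into the middle, rejected
    printf("%d\n", is_partial_intersection(cell, 0, 20));  // 0: erases from the start, allowed
    printf("%d\n", is_partial_intersection(cell, 11, 20)); // 0: starts past the tail, nothing stored to erase
    return 0;
}
```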
tools/main/main.cpp (5 changes: 4 additions & 1 deletion)

@@ -354,7 +354,10 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
+        if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
+            LOG_INF("%s: unable to reuse common prefix\n", __func__);
+            llama_memory_seq_rm(mem, -1, -1, -1);
+        }
     }
 
     LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
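
The new fallback follows a try-then-reset pattern: attempt to keep the common session prefix, and if the memory backend refuses the partial erase, wipe the sequence so the prompt is re-evaluated from scratch. A hedged, self-contained mock of that contract (`mock_cache` and `try_seq_rm` are illustrative stand-ins, not the llama.cpp API):

```cpp
#include <cstdio>

// Stand-in for a recurrent cache that keeps one state per sequence,
// snapshotted after the last token (no per-token history).
struct mock_cache {
    int  tail_pos = 10;  // position of the last token folded into the state
    bool valid    = true;
};

// Mock of the seq_rm contract: p1 < 0 means "to the end of the sequence".
// Returns false when asked for a partial erase that a recurrent backend
// cannot honor, mirroring llama_memory_seq_rm's boolean result.
static bool try_seq_rm(mock_cache & c, int p0, int p1) {
    if (p0 <= 0 && p1 < 0) {
        c.valid = false;  // full clear is always possible
        return true;
    }
    if (0 < p0 && p0 <= c.tail_pos && (p1 < 0 || p1 > c.tail_pos)) {
        return false;  // would erase the final pos but keep an earlier prefix
    }
    return true;
}

int main() {
    mock_cache cache;
    // mirror the main.cpp hunk: try to keep the common prefix, otherwise
    // clear everything so the prompt is recomputed from scratch
    const int n_matching_session_tokens = 5;
    if (!try_seq_rm(cache, n_matching_session_tokens, -1)) {
        printf("unable to reuse common prefix, clearing the cache\n");
        try_seq_rm(cache, -1, -1);
    }
    return 0;
}
```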