Skip to content

Commit

Permalink
Make summarizing run in parallel.
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Sep 4, 2023
1 parent ea8171a commit eed5b16
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "jarvis"
version = "0.11.2"
version = "0.11.3"
edition = "2021"
rust-version = "1.64"
description = ""
Expand Down
6 changes: 3 additions & 3 deletions crates/axum-web/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ pub struct ReqContext {
}

impl ReqContext {
pub fn new(rid: &str, user: xid::Id, rating: i8) -> Self {
pub fn new(rid: String, user: xid::Id, rating: i8) -> Self {
Self {
rid: rid.to_string(),
rid,
user,
rating,
unix_ms: unix_ms(),
Expand Down Expand Up @@ -63,7 +63,7 @@ pub async fn middleware<B>(mut req: Request<B>, next: Next<B>) -> Response {

let uid = xid::Id::from_str(&user).unwrap_or_default();

let ctx = Arc::new(ReqContext::new(&rid, uid, rating));
let ctx = Arc::new(ReqContext::new(rid.clone(), uid, rating));
req.extensions_mut().insert(ctx.clone());

let res = next.run(req).await;
Expand Down
2 changes: 1 addition & 1 deletion src/api/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async fn embedding(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
let mut total_tokens: i32 = 0;
let mut progress = 0usize;
for unit_group in content {
let ctx = ReqContext::new(&rid, user, 0);
let ctx = ReqContext::new(rid.clone(), user, 0);
let embedding_input: Vec<String> = unit_group
.iter()
.map(|unit| unit.to_embedding_string())
Expand Down
2 changes: 1 addition & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ static SECTION_SEPARATOR: &str = "------";

// gpt-35-turbo, 4096
static SUMMARIZE_SECTION_TOKENS: usize = 2400;
static SUMMARIZE_HIGH_TOKENS: usize = 3000;
pub(crate) static SUMMARIZE_HIGH_TOKENS: usize = 3000;

// text-embedding-ada-002, 8191
// https://community.openai.com/t/embedding-text-length-vs-accuracy/96564
Expand Down
121 changes: 101 additions & 20 deletions src/api/summarizing.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use axum::{extract::State, Extension};
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Instant};
use tokio::sync::{mpsc, Semaphore};
use validator::Validate;

use axum_web::context::{unix_ms, ReqContext};
use axum_web::erring::{HTTPError, SuccessResponse};
use axum_web::object::{cbor_from_slice, PackObject};
use scylla_orm::ColumnsMap;

use crate::api::{AppState, TEContentList, TEOutput, TEParams, TESegmenter};
use crate::api::{AppState, TEContentList, TEOutput, TEParams, TESegmenter, SUMMARIZE_HIGH_TOKENS};
use crate::db;
use crate::lang::Language;
use crate::openai;
Expand Down Expand Up @@ -173,7 +174,7 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)

log::info!(target: "summarizing",
action = "start_job",
rid = rid,
rid = rid.clone(),
user = user.to_string(),
gid = te.gid.to_string(),
cid = te.cid.to_string(),
Expand All @@ -187,23 +188,39 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
let mut total_tokens: usize = 0;
let mut doc = db::Summarizing::with_pk(te.gid, te.cid, te.language, te.version);

let mut output = String::new();
let mut progress = 0usize;
if content.len() == 1 && tokenizer::tokens_len(&content[0]) < 100 {
output = content[0].to_owned();
let output = if pieces == 1 && tokenizer::tokens_len(&content[0]) < 100 {
content[0].to_owned()
} else {
for c in content {
let ctx = ReqContext::new(&rid, user, 0);
let text = if output.is_empty() {
c.to_owned()
} else {
output.clone() + "\n" + &c
};

let res = app.ai.summarize(&ctx, te.language.to_name(), &text).await;
let semaphore = Arc::new(Semaphore::new(5));
let (tx, mut rx) =
mpsc::channel::<(usize, ReqContext, Result<(u32, String), HTTPError>)>(pieces);

for (i, text) in content.into_iter().enumerate() {
let rid = rid.clone();
let user = user;
let app = app.clone();
let lang = te.language.to_name();
let tx = tx.clone();
let sem = semaphore.clone();
tokio::spawn(async move {
let permit = sem.acquire().await.unwrap();
let ctx = ReqContext::new(rid, user, 0);
let res = app.ai.summarize(&ctx, lang, &text).await;
let _ = tx.send((i, ctx, res)).await;
drop(permit)
});
}

let mut res_list: Vec<String> = Vec::with_capacity(pieces);
res_list.resize(pieces, "".to_string());

while let Some((i, ctx, res)) = rx.recv().await {
let ai_elapsed = ctx.start.elapsed().as_millis() as u64;
let kv = ctx.get_kv().await;
if let Err(err) = res {
semaphore.close();

let mut cols = ColumnsMap::with_capacity(2);
cols.set_as("updated_at", &(unix_ms() as i64));
cols.set_as("error", &err.to_string());
Expand All @@ -215,6 +232,7 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
cid = te.cid.to_string(),
language = te.language.to_639_3().to_string(),
elapsed = ai_elapsed,
piece_at = i,
kv = log::as_serde!(kv);
"{}", err.to_string(),
);
Expand All @@ -225,11 +243,11 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
let used_tokens = res.0 as usize;
total_tokens += used_tokens;
progress += 1;
output = res.1;
res_list[i] = res.1;

let mut cols = ColumnsMap::with_capacity(3);
cols.set_as("updated_at", &(unix_ms() as i64));
cols.set_as("progress", &((progress * 100 / pieces) as i8));
cols.set_as("progress", &((progress * 100 / pieces + 1) as i8));
cols.set_as("tokens", &(total_tokens as i32));
let _ = doc.upsert_fields(&app.scylla, cols).await;

Expand All @@ -241,11 +259,74 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
tokens = used_tokens,
total_elapsed = start.elapsed().as_millis(),
total_tokens = total_tokens,
piece_at = i,
kv = log::as_serde!(kv);
"{}/{}", progress, pieces,
"{}/{}", progress, pieces+1,
);
}
}

// summarize all pieces
let mut tokens_list: Vec<usize> =
res_list.iter().map(|s| tokenizer::tokens_len(s)).collect();
while tokens_list.len() > 2 && tokens_list.iter().sum::<usize>() > SUMMARIZE_HIGH_TOKENS {
let i = tokens_list.len() / 2 + 1;
// ignore pieces in middle.
res_list.remove(i);
tokens_list.remove(i);
}

let ctx = ReqContext::new(rid.clone(), user, 0);
let res = app
.ai
.summarize(&ctx, te.language.to_name(), &res_list.join("\n"))
.await;
let ai_elapsed = ctx.start.elapsed().as_millis() as u64;
let kv = ctx.get_kv().await;
if let Err(err) = res {
let mut cols = ColumnsMap::with_capacity(2);
cols.set_as("updated_at", &(unix_ms() as i64));
cols.set_as("error", &err.to_string());
let _ = doc.upsert_fields(&app.scylla, cols).await;

log::error!(target: "summarizing",
action = "call_openai",
rid = ctx.rid,
cid = te.cid.to_string(),
language = te.language.to_639_3().to_string(),
elapsed = ai_elapsed,
piece_at = pieces + 1,
kv = log::as_serde!(kv);
"{}", err.to_string(),
);
return;
}

let res = res.unwrap();
let used_tokens = res.0 as usize;
total_tokens += used_tokens;
progress += 1;

let mut cols = ColumnsMap::with_capacity(3);
cols.set_as("updated_at", &(unix_ms() as i64));
cols.set_as("progress", &100i8);
cols.set_as("tokens", &(total_tokens as i32));
let _ = doc.upsert_fields(&app.scylla, cols).await;

log::info!(target: "summarizing",
action = "call_openai",
rid = ctx.rid,
cid = te.cid.to_string(),
elapsed = ai_elapsed,
tokens = used_tokens,
total_elapsed = start.elapsed().as_millis(),
total_tokens = total_tokens,
piece_at = progress,
kv = log::as_serde!(kv);
"{}/{}", progress, pieces+1,
);

res.1
};

// save target lang doc to db
let mut cols = ColumnsMap::with_capacity(5);
Expand All @@ -260,7 +341,7 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
Err(err) => {
log::error!(target: "summarizing",
action = "to_scylla",
rid = &rid,
rid = rid.clone(),
cid = te.cid.to_string(),
elapsed = start.elapsed().as_millis() as u64 - elapsed,
summary_length = output.len();
Expand All @@ -270,7 +351,7 @@ async fn summarize(app: Arc<AppState>, rid: String, user: xid::Id, te: TEParams)
Ok(_) => {
log::info!(target: "summarizing",
action = "to_scylla",
rid = &rid,
rid = rid.clone(),
cid = te.cid.to_string(),
elapsed = start.elapsed().as_millis() as u64 - elapsed,
summary_length = output.len();
Expand Down
2 changes: 1 addition & 1 deletion src/api/translating.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ async fn translate(

let mut progress = 0usize;
for unit in content {
let ctx = ReqContext::new(&rid, user, 0);
let ctx = ReqContext::new(rid.clone(), user, 0);
let res = app
.ai
.translate(
Expand Down

0 comments on commit eed5b16

Please sign in to comment.