Skip to content

Commit

Permalink
Add support for streaming APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
hitsmaxft committed Jan 28, 2024
1 parent 7035120 commit d59c4cf
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ google-generative-ai-rs = "0.1.7"
log = "0.4.20"
serde_json = "1.0.112"
env_logger = "0.11.1"
futures = "0.3.30"

34 changes: 22 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use futures::stream::{self, StreamExt};

use clap::{App, Arg, ArgMatches};
use env_logger::Env;
use log::{info, warn};
use tokio::io::{self, AsyncWriteExt};
use serde::{Deserialize, Serialize};
//use std::env;

use google_generative_ai_rs::v1::{
api::Client,
gemini::{request::Request, response::Candidate, Content, Part, Role},
gemini::{request::Request, Content, Part, Role},
};

#[derive(Serialize, Deserialize)]
Expand All @@ -33,7 +35,7 @@ async fn run(matches: ArgMatches) -> Result<(), Box<dyn std::error::Error>> {
.value_of("config-file")
.unwrap_or("~/.config/gemini.toml");

let stream = match matches.value_of("stream") {
let is_stream = match matches.value_of("stream") {
Some("true") => true,
_ => false,
};
Expand All @@ -45,7 +47,7 @@ async fn run(matches: ArgMatches) -> Result<(), Box<dyn std::error::Error>> {
.or_else(|| Some(config.token.as_str()))
.expect("No token provided. Please use --token or configure in the TOML file.");

let client = match stream {
let client = match is_stream {
true => Client::new_from_model_reponse_type(
google_generative_ai_rs::v1::gemini::Model::GeminiPro,
token.to_string(),
Expand Down Expand Up @@ -75,22 +77,30 @@ async fn run(matches: ArgMatches) -> Result<(), Box<dyn std::error::Error>> {

let response = client.post(30, &txt_request).await?;

match stream {
match is_stream {
true => match response.streamed() {
Some(mut gemini) => gemini.streamed_candidates.iter_mut().for_each(|gemini| {
Some(gemini) => {
let stream_iter = stream::iter(&gemini.streamed_candidates);
stream_iter.then(|gemini| async move {
match &(gemini.candidates[0].content.parts[0].text) {
Some(text) => print!("{}", text.to_string()),
_ => print!("{}", "text is empty"),
Some(text) => {
print!("{}", text.to_string());
let _ =io::stdout().flush().await;
""
}
_ => "",
}
}),
_ => print!("empty response"),
}).collect::<String>().await;
}
,
_ => (),
},
_ => match response.rest() {
Some(gemini) => match &(gemini.candidates[0].content.parts[0].text) {
Some(text) => print!("{}", text.to_string()),
_ => print!("{}", "text is empty"),
_ => (),
},
_ => print!("empty response"),
_ => (),
},
}

Expand Down

0 comments on commit d59c4cf

Please sign in to comment.