Add new LLM module and a prompt generator
hitsmaxft committed Feb 7, 2024
1 parent 70ba797 commit f67222c
Showing 3 changed files with 46 additions and 24 deletions.
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -1,2 +1,3 @@
 
-pub mod cli;
+pub mod cli;
+pub mod llm;
35 changes: 35 additions & 0 deletions src/llm.rs
@@ -0,0 +1,35 @@
use google_generative_ai_rs::v1::api::PostResult;
use google_generative_ai_rs::v1::gemini::{Role, Content, Part};
use google_generative_ai_rs::v1::gemini::request::Request;
use google_generative_ai_rs::v1::api::Client;
use google_generative_ai_rs::v1::errors::GoogleAPIError;

pub struct LLMRequest<'a> {
    pub stream: bool,
    pub rich: bool,
    pub token: &'a str,
    pub prompt: Option<String>,

}

pub async fn request(client: Client, req: LLMRequest<'_>) -> Result<PostResult, GoogleAPIError> {
    let txt_request = Request {
        contents: vec![Content {
            role: Role::User,
            parts: vec![Part {
                text: req.prompt,
                inline_data: None,
                file_data: None,
                video_metadata: None,
            }],
        }],

        tools: vec![],
        safety_settings: vec![],
        //TODO read from config
        generation_config: None,
    };

    return client.post(30, &txt_request).await;

}
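
Editor's sketch (not part of the commit): the new llm::request wrapper can be driven from an async caller roughly as below. The helper name ask_gemini and the literal flag values are illustrative only, and the Client is assumed to have been constructed elsewhere, as src/main.rs already does before the hunk shown further down.

use gemini_pro_cli::llm;
use google_generative_ai_rs::v1::api::Client;

// Illustrative helper: send one prompt through llm::request and ignore the
// PostResult body. Field values mirror the LLMRequest struct added above.
async fn ask_gemini(client: Client, token: &str, prompt: String) -> Result<(), Box<dyn std::error::Error>> {
    let req = llm::LLMRequest {
        stream: false,          // no streaming for this one-shot call
        rich: false,            // plain-text rendering
        token,                  // API token carried on the request struct
        prompt: Some(prompt),   // becomes the single User part in llm::request
    };
    // llm::request posts with the same 30 (seconds) value passed to client.post.
    let _result = llm::request(client, req).await?;
    Ok(())
}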
32 changes: 9 additions & 23 deletions src/main.rs
@@ -2,11 +2,8 @@
 use gemini_pro_cli::cli;
 use clap::ArgMatches;
 use env_logger::Env;
+use gemini_pro_cli::llm;
 use google_generative_ai_rs::v1::gemini::response::GeminiResponse;
-use google_generative_ai_rs::v1::gemini::request::Request;
-use google_generative_ai_rs::v1::gemini::Content;
-use google_generative_ai_rs::v1::gemini::Role;
-use google_generative_ai_rs::v1::gemini::Part;
 use log::info;
 use std::io::{stdin, Read};
 
@@ -66,24 +63,12 @@ async fn run(matches: ArgMatches) -> Result<(), Box<dyn std::error::Error>> {
         ),
     };
 
-    let txt_request = Request {
-        contents: vec![Content {
-            role: Role::User,
-            parts: vec![Part {
-                text: Some(prompt),
-                inline_data: None,
-                file_data: None,
-                video_metadata: None,
-            }],
-        }],
-
-        tools: vec![],
-        safety_settings: vec![],
-        //TODO read from config
-        generation_config: None,
-    };
-
-    let response = client.post(30, &txt_request).await?;
+    let response = llm::request(client, llm::LLMRequest {
+        stream : is_stream,
+        rich : is_rich,
+        token : token,
+        prompt : Some(prompt),
+    }).await?;
 
     if is_stream {
         info!("streaming output");
@@ -99,7 +84,8 @@
         if let Some(text) = &gemini
             .candidates
             .first()
-            .and_then(|c| c.content.parts.first().and_then(|p| p.text.as_ref()))
+            .and_then(|c| c.content.parts.first()
+                .and_then(|p| p.text.as_ref()))
         {
             if is_rich {
                 termimad::print_inline(text);
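
Another illustrative sketch (not in the commit): the candidate-walking chain touched in the last hunk could be factored into a small helper; the name first_text is hypothetical.

use google_generative_ai_rs::v1::gemini::response::GeminiResponse;

// Hypothetical helper mirroring the chain in src/main.rs: return the text of
// the first part of the first candidate, if present.
fn first_text(gemini: &GeminiResponse) -> Option<&String> {
    gemini
        .candidates
        .first()
        .and_then(|c| c.content.parts.first())
        .and_then(|p| p.text.as_ref())
}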
