From b9b1d71d3e1c9e9fe08d3936cb0d4d3311882529 Mon Sep 17 00:00:00 2001 From: hitsmaxft Date: Sun, 4 Feb 2024 14:49:05 +0800 Subject: [PATCH] support stream output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes in Cargo.toml: • Updated tokio from 1.0 to 1.0-beta.1. • Added google-generative-ai-rs as a dependency with version 0.1.7. Changes in src/main.rs: • Added the shellexpand crate. • Made config_path configurable via command-line argument --config-file. • Added a --stream command-line argument to indicate whether the response should be streamed. Changes in gemini-cli-example.toml: • Added an example config with google-generative-ai-rs model specified. • Set template to Context: {system}\nMessages:\nAuthor: User\nContent: {prompt}. Commit message: Update Tokio, add support for Google Generative AI RS, and configurable streaming --- Cargo.lock | 424 ++++++++++++++++++++++++++++++++++++---- Cargo.toml | 9 +- gemini-cli-example.toml | 12 ++ src/main.rs | 151 ++++++++------ 4 files changed, 496 insertions(+), 100 deletions(-) create mode 100644 gemini-cli-example.toml diff --git a/Cargo.lock b/Cargo.lock index 3985183..ff54d35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,7 +103,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -234,6 +234,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -250,12 +263,45 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "encoding_rs" version = "0.8.33" @@ -275,19 +321,6 @@ dependencies = [ "regex", ] -[[package]] -name = "env_logger" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" -dependencies = [ - "humantime", - "is-terminal", - "log", - "regex", - "termcolor", -] - [[package]] name = "env_logger" version = "0.11.1" @@ -409,7 +442,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -473,15 +506,17 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", - "env_logger 0.11.1", + "env_logger", "futures", "google-generative-ai-rs", "log", "serde", "serde_json", + "shellexpand", "thiserror", "tokio", - "toml", + "tokio-stream", + "toml 0.5.11", ] [[package]] @@ -503,14 +538,14 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "google-generative-ai-rs" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1675b2fba1151aa52e96f73e861d3563a5b79abfc8ffcb2f1c09ea41fa371c9" +version = "0.1.8" +source = "git+https://github.com/hitsmaxft/google-generative-ai-rs.git?branch=streaming#794f924728575bf5509ede574ccd6a948aff1742" dependencies = [ - "env_logger 0.10.2", + "env_logger", "futures", "gcp_auth", "log", + "pin", "reqwest", "reqwest-streams", "serde", @@ -549,6 +584,12 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.1.19" @@ -718,6 +759,28 @@ dependencies = [ "hashbrown 0.14.3", ] +[[package]] +name = "indicatif" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "ipnet" version = "2.9.0" @@ -725,14 +788,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] -name = "is-terminal" -version = "0.4.10" +name = "itertools" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" dependencies = [ - "hermit-abi 0.3.4", - "rustix", - "windows-sys 0.52.0", + "either", ] [[package]] @@ -750,6 +811,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "078e285eafdfb6c4b434e0d31e8cfcb5115b651496faca5749b88fafd4f23bfd" + [[package]] name = "lazy_static" version = "1.4.0" @@ -762,6 +829,17 @@ version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.2", + "libc", + "redox_syscall", +] + [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -784,6 +862,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.7.1" @@ -834,6 +921,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-traits" version = "0.2.17" @@ -853,6 +950,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.32.2" @@ -891,7 +994,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -912,12 +1015,24 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "os_str_bytes" version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -947,6 +1062,35 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b306f2297baf731f5b5eb28a015190c57d9b4ccd4aa91b02c84ec6fe27ffb911" +dependencies = [ + "async-trait", + "bytes", + "clap", + "futures", + "home", + "indicatif", + "itertools", + "json", + "pin-project", + "reqwest", + "serde", + "serde_json", + "snafu", + "strfmt", + "tokio", + "toml 0.4.10", + "tracing", + "tracing-subscriber", + "tracing-tree", + "unicode-segmentation", + "url", +] + [[package]] name = "pin-project" version = "1.1.4" @@ -964,7 +1108,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -985,6 +1129,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "proc-macro2" version = "1.0.78" @@ -1012,6 +1162,17 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_users" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.10.3" @@ -1020,8 +1181,17 @@ checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.5", + "regex-syntax 0.8.2", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -1032,9 +1202,15 @@ checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.2", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.2" @@ -1246,7 +1422,7 @@ checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -1272,6 +1448,24 @@ dependencies = [ "serde", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shellexpand" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b" +dependencies = [ + "dirs", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -1296,6 +1490,29 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "backtrace", + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.5.5" @@ -1312,12 +1529,29 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "strfmt" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b278b244ef7aa5852b277f52dd0c6cac3a109919e1f6d699adde63251227a30f" + [[package]] name = "strsim" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.48" @@ -1395,7 +1629,17 @@ checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", +] + +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", ] [[package]] @@ -1440,7 +1684,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -1463,6 +1707,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -1477,6 +1732,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758664fc71a3a69038656bee8b6be6477d2a6c315a6b81f7081f591bffa4111f" +dependencies = [ + "serde", +] + [[package]] name = "toml" version = "0.5.11" @@ -1511,7 +1775,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", ] [[package]] @@ -1521,6 +1785,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", ] [[package]] @@ -1533,6 +1798,71 @@ dependencies = [ "tracing", ] +[[package]] +name = "tracing-log" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f751112709b4e791d8ce53e32c4ed2d353565a795ce84da2285393f41557bdf2" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-serde", +] + +[[package]] +name = "tracing-tree" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ec6adcab41b1391b08a308cc6302b79f8095d1673f6947c2dc65ffb028b0b2d" +dependencies = [ + "nu-ansi-term", + "tracing-core", + "tracing-log 0.1.4", + "tracing-subscriber", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -1560,6 +1890,18 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" + +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "untrusted" version = "0.9.0" @@ -1583,6 +1925,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" @@ -1625,7 +1973,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.48", "wasm-bindgen-shared", ] @@ -1659,7 +2007,7 @@ checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.48", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 107cc31..c9ee9d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,17 @@ edition = "2021" clap = "3.1.0" serde = { version = "1.0", features = ["derive"] } toml = "0.5.8" -tokio = { version = "1.0", features = ["full"] } +tokio = { version = "1.0", features = ["full", "time"] } anyhow = "1.0" thiserror = "1.0" -google-generative-ai-rs = "0.1.7" log = "0.4.20" serde_json = "1.0.112" env_logger = "0.11.1" futures = "0.3.30" +shellexpand = "3.1.0" +tokio-stream = "0.1.14" +[dependencies.google-generative-ai-rs] + +git = "https://github.com/hitsmaxft/google-generative-ai-rs.git" +branch = "streaming" diff --git a/gemini-cli-example.toml b/gemini-cli-example.toml new file mode 100644 index 0000000..c709d9b --- /dev/null +++ b/gemini-cli-example.toml @@ -0,0 +1,12 @@ +token= "" + +model = "gemini-pro" +template = "Context: {system}\nMessages:\nAuthor: User\nContent: {prompt}" +stream=true +markdown=true + +generation_config={} + +[code] +model = "code-bison" + diff --git a/src/main.rs b/src/main.rs index 3a7b405..16cbe53 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,12 @@ -use futures::stream::{self, StreamExt}; +extern crate shellexpand; // 1.0.0 + +use log::info; use clap::{App, Arg, ArgMatches}; use env_logger::Env; -use tokio::io::{self, AsyncWriteExt}; +use google_generative_ai_rs::v1::gemini::response::{Candidate, GeminiResponse}; +use std::io::{stdout, Write}; use serde::{Deserialize, Serialize}; -//use std::env; use google_generative_ai_rs::v1::{ api::Client, @@ -17,14 +19,53 @@ struct Config { generation_config: std::collections::HashMap, } -fn read_config(file_path: &str) -> Result> { - let contents = std::fs::read_to_string(file_path)?; +async fn read_config(input: &str) -> Result> { + + let real_path: &str = &shellexpand::tilde(input); + + info!("final config file path is {}", real_path); + let contents = tokio::fs::read_to_string(real_path).await?; let config: Config = toml::from_str(&contents)?; Ok(config) } -#[tokio::main] +async fn output_response(gemini: &GeminiResponse) -> String { + + if gemini.candidates.len() ==0 { + return "".to_string(); + } + + let first_candi:&Candidate = &gemini.candidates[0]; + + if first_candi.content.parts.len() == 0 { + return "".to_string(); + } + + let first_part : &Part = &first_candi.content.parts[0]; + let may_text: &Option = &first_part.text; + + match may_text { + Some(text ) => { + let mut lock = stdout().lock(); + let _ = write!(lock, "{}", text); + "".to_string() + } + _ => "".to_string(), + } +} + async fn run(matches: ArgMatches) -> Result<(), Box> { + + + let env = Env::default() + .filter_or("MY_LOG_LEVEL", match matches.contains_id("verbose") { + true => "info", + _ => "warn", }) + .write_style_or("MY_LOG_STYLE", "always"); + + env_logger::init_from_env(env); + + // Parse command-line arguments let prompt = matches.value_of("prompt").unwrap_or_else(|| { eprintln!("No prompt provided. Please use --prompt to specify the prompt."); @@ -33,14 +74,10 @@ async fn run(matches: ArgMatches) -> Result<(), Box> { let config_path = matches .value_of("config-file") - .unwrap_or("~/.config/gemini.toml"); - - let is_stream = match matches.value_of("stream") { - Some("true") => true, - _ => false, - }; + .unwrap_or("~/.config/gemini-cli.toml"); - let config = read_config(config_path)?; + let is_stream = matches.contains_id("stream"); + let config = read_config(config_path).await?; let token = matches .value_of("token") @@ -48,7 +85,7 @@ async fn run(matches: ArgMatches) -> Result<(), Box> { .expect("No token provided. Please use --token or configure in the TOML file."); let client = match is_stream { - true => Client::new_from_model_reponse_type( + true => Client::new_from_model_response_type( google_generative_ai_rs::v1::gemini::Model::GeminiPro, token.to_string(), google_generative_ai_rs::v1::gemini::ResponseType::StreamGenerateContent, @@ -72,47 +109,35 @@ async fn run(matches: ArgMatches) -> Result<(), Box> { tools: vec![], safety_settings: vec![], + //TODO read from config generation_config: None, }; let response = client.post(30, &txt_request).await?; - match is_stream { - true => match response.streamed() { - Some(gemini) => { - let stream_iter = stream::iter(&gemini.streamed_candidates); - stream_iter.then(|gemini| async move { - match &(gemini.candidates[0].content.parts[0].text) { - Some(text) => { - print!("{}", text.to_string()); - let _ =io::stdout().flush().await; - "" - } - _ => "", - } - }).collect::().await; + if is_stream { + info!("streaming output"); + if let Some(stream_response) = response.streamed() { + if let Some(json_stream) = stream_response.response_stream { + Client::for_each_async(json_stream, move |gr:GeminiResponse| async move { + output_response(&gr).await; + }).await } - , - _ => (), - }, - _ => match response.rest() { - Some(gemini) => match &(gemini.candidates[0].content.parts[0].text) { - Some(text) => print!("{}", text.to_string()), - _ => (), - }, - _ => (), - }, + } + } else { + if let Some(gemini) = response.rest() { + if let Some(text) = &gemini.candidates.get(0).and_then(|c| c.content.parts.get(0).and_then(|p| p.text.as_ref())) { + print!("{}", text); + } + } } Ok(()) } -fn main() -> Result<(), Box> { - let env = Env::default() - .filter_or("MY_LOG_LEVEL", "warn") - .write_style_or("MY_LOG_STYLE", "always"); - - env_logger::init_from_env(env); +#[tokio::main] +async fn main() -> Result<(), Box> { + // Define the interval duration let matches = App::new("Gemini CLI") .version("0.1.0") @@ -120,32 +145,38 @@ fn main() -> Result<(), Box> { .about("Interacts with the Gemini model") .arg( Arg::with_name("prompt") - .long("prompt") - .value_name("PROMPT") - .help("Sets the prompt for the Gemini model") - .takes_value(true), + .long("prompt") + .value_name("PROMPT") + .help("Sets the prompt for the Gemini model") + .takes_value(true), + ) + .arg( + Arg::with_name("verbose") + .short('v') + .long("verbose") + .help("output more logs"), ) .arg( Arg::with_name("stream") - .long("stream") - .help("Streams the response from the model"), + .long("stream") + .help("Streams the response from the model"), ) .arg( Arg::with_name("config-file") - .short('f') - .long("config-file") - .value_name("FILE") - .help("Specify a custom TOML file for configuration") - .takes_value(true), + .short('f') + .long("config-file") + .value_name("FILE") + .help("Specify a custom TOML file for configuration") + .takes_value(true), ) .arg( Arg::with_name("token") - .long("token") - .value_name("TOKEN") - .help("Specify the API token directly") - .takes_value(true), + .long("token") + .value_name("TOKEN") + .help("Specify the API token directly") + .takes_value(true), ) .get_matches(); - run(matches) + run(matches).await }