 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
-// Command line text interface to gemma.
-
-#include <ctime>
-#include <iostream>
-#include <random>
-#include <string>
-#include <thread>  // NOLINT
-#include <vector>
-
-// copybara:import_next_line:gemma_cpp
-#include "compression/compress.h"
-// copybara:end
-// copybara:import_next_line:gemma_cpp
-#include "gemma.h"  // Gemma
-// copybara:end
-// copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:end
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // HasHelp
-// copybara:end
-#include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/highway.h"
-#include "hwy/per_target.h"
-#include "hwy/profiler.h"
-#include "hwy/timer.h"
-
+#include "gemma_binding.h"
 
 namespace py = pybind11;
 
-namespace gcpp {
-
 static constexpr std::string_view kAsciiArtBanner =
     "  __ _  ___ _ __ ___  _ __ ___   __ _   ___ _ __  _ __\n"
     " / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
@@ -211,35 +182,51 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
              << " command line flag.\n";
 }
 
-void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
+  // Rebuild an argc/argv pair so the existing gcpp flag parsers can be reused.
+  int argc = args.size() + 1;  // +1 for the program name
+  std::vector<char *> argv_vec;
+  argv_vec.reserve(argc);
+  argv_vec.push_back(const_cast<char *>("pygemma"));
+  for (const auto &arg : args) {
+    argv_vec.push_back(const_cast<char *>(arg.c_str()));
+  }
+
+  char **argv = argv_vec.data();
+
+  this->m_loader = gcpp::LoaderArgs(argc, argv);
+  this->m_inference = gcpp::InferenceArgs(argc, argv);
+  this->m_app = gcpp::AppArgs(argc, argv);
+
   PROFILER_ZONE("Run.misc");
 
   hwy::ThreadPool inner_pool(0);
-  hwy::ThreadPool pool(app.num_threads);
+  hwy::ThreadPool pool(this->m_app.num_threads);
   // For many-core, pinning threads to cores helps.
-  if (app.num_threads > 10) {
-    PinThreadToCore(app.num_threads - 1);  // Main thread
+  if (this->m_app.num_threads > 10) {
+    PinThreadToCore(this->m_app.num_threads - 1);  // Main thread
 
     pool.Run(0, pool.NumThreads(),
              [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
   }
 
-  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
-                    loader.ModelType(), pool);
-
-  auto kv_cache = CreateKVCache(loader.ModelType());
+  // Construct the model once and keep it; later calls reuse the instance.
+  if (!this->m_model) {
+    this->m_model.reset(new gcpp::Gemma(this->m_loader.tokenizer,
+                                        this->m_loader.compressed_weights,
+                                        this->m_loader.ModelType(), pool));
+  }
+  this->m_kvcache = CreateKVCache(this->m_loader.ModelType());
 
-  if (const char* error = inference.Validate()) {
-    ShowHelp(loader, inference, app);
+  if (const char* error = this->m_inference.Validate()) {
+    ShowHelp(this->m_loader, this->m_inference, this->m_app);
     HWY_ABORT("\nInvalid args: %s", error);
   }
 
-  if (app.verbosity >= 1) {
+  if (this->m_app.verbosity >= 1) {
     const std::string instructions =
         "*Usage*\n"
        "  Enter an instruction and press enter (%C resets conversation, "
        "%Q quits).\n" +
-        (inference.multiturn == 0
+        (this->m_inference.multiturn == 0
             ? std::string("  Since multiturn is set to 0, conversation will "
                           "automatically reset every turn.\n\n")
             : "\n") +
@@ -252,153 +239,35 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\033[2J\033[1;1H"  // clear screen
               << kAsciiArtBanner << "\n\n";
-    ShowConfig(loader, inference, app);
+    ShowConfig(this->m_loader, this->m_inference, this->m_app);
     std::cout << "\n" << instructions << "\n";
   }
-
-  ReplGemma(
-      model, kv_cache, pool, inner_pool, inference, app.verbosity,
-      /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
 
-// std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
-//                    hwy::ThreadPool &inner_pool, const InferenceArgs &args,
-//                    int verbosity, const gcpp::AcceptFunc &accept_token,
-//                    std::string &prompt_string)
-// {
-//   std::string generated_text;
-//   // Seed the random number generator
-//   std::random_device rd;
-//   std::mt19937 gen(rd());
-//   int prompt_size{};
-//   if (model.model_training == ModelTraining::GEMMA_IT)
-//   {
-//     // For instruction-tuned models: add control tokens.
-//     prompt_string = "<start_of_turn>user\n" + prompt_string +
-//                     "<end_of_turn>\n<start_of_turn>model\n";
-//   }
-//   // Encode the prompt string into tokens
-//   std::vector<int> prompt;
-//   HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
-//   // Placeholder for generated token IDs
-//   std::vector<int> generated_tokens;
-//   // Define lambda for token decoding
-//   StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
-//     generated_tokens.push_back(token);
-//     return true;  // Continue generating
-//   };
-//   // Decode tokens
-//   prompt_size = prompt.size();
-//   GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token,
-//                 accept_token, gen, verbosity);
-//   HWY_ASSERT(model.Tokenizer()->Decode(generated_tokens, &generated_text).ok());
-//   generated_text = generated_text.substr(prompt_string.size());
-
-//   return generated_text;
-// }
-
-// std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app,
-//                        std::string &prompt_string) {
-//   hwy::ThreadPool inner_pool(0);
-//   hwy::ThreadPool pool(app.num_threads);
-//   if (app.num_threads > 10)
-//   {
-//     PinThreadToCore(app.num_threads - 1);  // Main thread
-
-//     pool.Run(0, pool.NumThreads(),
-//              [](uint64_t /*task*/, size_t thread)
-//              { PinThreadToCore(thread); });
-//   }
-//   gcpp::Gemma model(loader, pool);
-//   return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
-//                 { return true; }, prompt_string);
-// }
-
-}  // namespace gcpp
-
-void chat_base(int argc, char **argv)
-{
-  {
-    PROFILER_ZONE("Startup.misc");
-
-    gcpp::LoaderArgs loader(argc, argv);
-    gcpp::InferenceArgs inference(argc, argv);
-    gcpp::AppArgs app(argc, argv);
-
-    if (gcpp::HasHelp(argc, argv))
-    {
-      ShowHelp(loader, inference, app);
-      // return 0;
-    }
-
-    if (const char *error = loader.Validate())
-    {
-      ShowHelp(loader, inference, app);
-      HWY_ABORT("\nInvalid args: %s", error);
-    }
-
-    gcpp::Run(loader, inference, app);
-  }
-  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
-  // return 1;
-}
-// std::string completion_base(int argc, char **argv)
-// {
-//   gcpp::LoaderArgs loader(argc, argv);
-//   gcpp::InferenceArgs inference(argc, argv);
-//   gcpp::AppArgs app(argc, argv);
-//   std::string prompt_string = argv[argc-1];
-//   return gcpp::completion(loader, inference, app, prompt_string);
-// }
-// std::string completion_base_wrapper(const std::vector<std::string> &args, std::string &prompt_string)
-// {
-//   int argc = args.size() + 2;  // +1 for the program name
-//   std::vector<char *> argv_vec;
-//   argv_vec.reserve(argc);
-
-//   argv_vec.push_back(const_cast<char *>("pygemma"));
-
-//   for (const auto &arg : args)
-//   {
-//     argv_vec.push_back(const_cast<char *>(arg.c_str()));
-//   }
-//   argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
-//   char **argv = argv_vec.data();
-//   return completion_base(argc, argv);
-// }
-void show_help_wrapper()
-{
-  // Assuming ShowHelp does not critically depend on argv content
-  gcpp::LoaderArgs loader(0, nullptr);
-  gcpp::InferenceArgs inference(0, nullptr);
-  gcpp::AppArgs app(0, nullptr);
-
-  ShowHelp(loader, inference, app);
+void GemmaWrapper::showConfig() {
+  ShowConfig(this->m_loader, this->m_inference, this->m_app);
 }
 
-std::string chat_base_wrapper(const std::vector<std::string> &args)
-{
-  int argc = args.size() + 1;  // +1 for the program name
-  std::vector<char *> argv_vec;
-  argv_vec.reserve(argc);
-  argv_vec.push_back(const_cast<char *>("pygemma"));
-
-  for (const auto &arg : args)
-  {
-    argv_vec.push_back(const_cast<char *>(arg.c_str()));
-  }
-
-  char **argv = argv_vec.data();
-
-  chat_base(argc, argv);
+void GemmaWrapper::showHelp() {
+  ShowHelp(this->m_loader, this->m_inference, this->m_app);
 }
 
 
-PYBIND11_MODULE(pygemma, m)
-{
-  m.doc() = "Pybind11 integration for chat_base function";
-  m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments");
-  m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
-  // m.def("completion", &completion_base_wrapper, "A wrapper for inference function");
+PYBIND11_MODULE(pygemma, m) {
+  py::class_<GemmaWrapper>(m, "Gemma")
+      .def(py::init<>())
+      .def("show_config", &GemmaWrapper::showConfig)
+      .def("show_help", &GemmaWrapper::showHelp)
+      .def("load_model",
+           [](GemmaWrapper &self, const std::string &tokenizer,
+              const std::string &compressed_weights, const std::string &model) {
+             // Forward the keyword arguments as command-line-style flags.
+             std::vector<std::string> args = {
+                 "--tokenizer", tokenizer,
+                 "--compressed_weights", compressed_weights,
+                 "--model", model};
+             self.loadModel(args);
+           },
+           py::arg("tokenizer"), py::arg("compressed_weights"),
+           py::arg("model"))
+      .def("completion", &GemmaWrapper::completionPrompt);
 }
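
For reference, a minimal usage sketch of the bindings above once the module is built. The paths and model type are placeholders, and the `completion` call assumes `GemmaWrapper::completionPrompt` takes a single prompt string; its declaration lives in gemma_binding.h and is not part of this diff.

import pygemma

gemma = pygemma.Gemma()
# load_model forwards these keyword arguments to the gcpp flag parsers
# as --tokenizer / --compressed_weights / --model.
gemma.load_model(
    tokenizer="/path/to/tokenizer.spm",       # placeholder path
    compressed_weights="/path/to/2b-it.sbs",  # placeholder path
    model="2b-it",                            # placeholder model type
)
gemma.show_config()
print(gemma.completion("Write a haiku."))  # assumed single-string signature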