Skip to content

Commit 016f59d

Browse files
committed
update: add new interface based on new gemma.cpp
1 parent 7ece21c commit 016f59d

File tree

6 files changed

+116
-213
lines changed

6 files changed

+116
-213
lines changed

.clang-format

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Language: Cpp
2+
BasedOnStyle: Google

CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
1111
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
1212
FetchContent_MakeAvailable(sentencepiece)
1313

14-
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
14+
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG 8fb44ed6dd123f63dca95c20c561e8ca1de511d7)
1515
FetchContent_MakeAvailable(gemma)
1616

1717
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
@@ -30,3 +30,5 @@ FetchContent_GetProperties(gemma)
3030
FetchContent_GetProperties(sentencepiece)
3131
target_include_directories(pygemma PRIVATE ${gemma_SOURCE_DIR})
3232
target_include_directories(pygemma PRIVATE ${sentencepiece_SOURCE_DIR})
33+
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
34+
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from setuptools import setup, find_packages, Extension
55
from setuptools.command.build_ext import build_ext
6+
import platform
67

78

89
class CMakeExtension(Extension):
@@ -39,7 +40,7 @@ def build_extension(self, ext):
3940
"--",
4041
"-j",
4142
"12",
42-
] # Specifies the number of jobs to run simultaneously
43+
]
4344

4445
if not os.path.exists(self.build_temp):
4546
os.makedirs(self.build_temp)
@@ -58,7 +59,7 @@ def build_extension(self, ext):
5859
version="0.1.2",
5960
author="Nam Tran",
6061
author_email="[email protected]",
61-
description="A Python package with a C++ backend using gemma.",
62+
description="A Python package with a C++ backend using gemma.cpp",
6263
long_description="""
6364
This package provides Python bindings to a C++ library using pybind11.
6465
""",

src/gemma_binding.cpp

Lines changed: 51 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,9 @@
11
#include <pybind11/pybind11.h>
22
#include <pybind11/stl.h>
33

4-
// Command line text interface to gemma.
5-
6-
#include <ctime>
7-
#include <iostream>
8-
#include <random>
9-
#include <string>
10-
#include <thread> // NOLINT
11-
#include <vector>
12-
13-
// copybara:import_next_line:gemma_cpp
14-
#include "compression/compress.h"
15-
// copybara:end
16-
// copybara:import_next_line:gemma_cpp
17-
#include "gemma.h" // Gemma
18-
// copybara:end
19-
// copybara:import_next_line:gemma_cpp
20-
#include "util/app.h"
21-
// copybara:end
22-
// copybara:import_next_line:gemma_cpp
23-
#include "util/args.h" // HasHelp
24-
// copybara:end
25-
#include "hwy/base.h"
26-
#include "hwy/contrib/thread_pool/thread_pool.h"
27-
#include "hwy/highway.h"
28-
#include "hwy/per_target.h"
29-
#include "hwy/profiler.h"
30-
#include "hwy/timer.h"
31-
4+
#include "gemma_binding.h"
325
namespace py = pybind11;
336

34-
namespace gcpp {
35-
367
static constexpr std::string_view kAsciiArtBanner =
378
" __ _ ___ _ __ ___ _ __ ___ __ _ ___ _ __ _ __\n"
389
" / _` |/ _ \\ '_ ` _ \\| '_ ` _ \\ / _` | / __| '_ \\| '_ \\\n"
@@ -211,35 +182,51 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
211182
<< "command line flag.\n";
212183
}
213184

214-
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
185+
void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
186+
int argc = args.size() + 1; // +1 for the program name
187+
std::vector<char *> argv_vec;
188+
argv_vec.reserve(argc);
189+
argv_vec.push_back(const_cast<char *>("pygemma"));
190+
for (const auto &arg : args)
191+
{
192+
argv_vec.push_back(const_cast<char *>(arg.c_str()));
193+
}
194+
195+
char **argv = argv_vec.data();
196+
197+
this->m_loader = gcpp::LoaderArgs(argc, argv);
198+
this->m_inference = gcpp::InferenceArgs(argc, argv);
199+
this->m_app = gcpp::AppArgs(argc, argv);
200+
215201
PROFILER_ZONE("Run.misc");
216202

217203
hwy::ThreadPool inner_pool(0);
218-
hwy::ThreadPool pool(app.num_threads);
204+
hwy::ThreadPool pool(this->m_app.num_threads);
219205
// For many-core, pinning threads to cores helps.
220-
if (app.num_threads > 10) {
221-
PinThreadToCore(app.num_threads - 1); // Main thread
206+
if (this->m_app.num_threads > 10) {
207+
PinThreadToCore(this->m_app.num_threads - 1); // Main thread
222208

223209
pool.Run(0, pool.NumThreads(),
224210
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
225211
}
226212

227-
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
228-
loader.ModelType(), pool);
229-
230-
auto kv_cache = CreateKVCache(loader.ModelType());
213+
if (!this->m_model) {
214+
this->m_model.reset(new gcpp::Gemma(this->m_loader.tokenizer, this->m_loader.compressed_weights, this->m_loader.ModelType(), pool));
215+
}
216+
// auto kvcache = CreateKVCache(loader.ModelType());
217+
this->m_kvcache = CreateKVCache(this->m_loader.ModelType());
231218

232-
if (const char* error = inference.Validate()) {
233-
ShowHelp(loader, inference, app);
219+
if (const char* error = this->m_inference.Validate()) {
220+
ShowHelp(this->m_loader, this->m_inference, this->m_app);
234221
HWY_ABORT("\nInvalid args: %s", error);
235222
}
236223

237-
if (app.verbosity >= 1) {
224+
if (this->m_app.verbosity >= 1) {
238225
const std::string instructions =
239226
"*Usage*\n"
240227
" Enter an instruction and press enter (%C resets conversation, "
241228
"%Q quits).\n" +
242-
(inference.multiturn == 0
229+
(this->m_inference.multiturn == 0
243230
? std::string(" Since multiturn is set to 0, conversation will "
244231
"automatically reset every turn.\n\n")
245232
: "\n") +
@@ -252,153 +239,35 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
252239

253240
std::cout << "\033[2J\033[1;1H" // clear screen
254241
<< kAsciiArtBanner << "\n\n";
255-
ShowConfig(loader, inference, app);
242+
ShowConfig(this->m_loader, this->m_inference, this->m_app);
256243
std::cout << "\n" << instructions << "\n";
257244
}
258-
259-
ReplGemma(
260-
model, kv_cache, pool, inner_pool, inference, app.verbosity,
261-
/*accept_token=*/[](int) { return true; }, app.eot_line);
262245
}
263246

264-
// std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
265-
// hwy::ThreadPool &inner_pool, const InferenceArgs &args,
266-
// int verbosity, const gcpp::AcceptFunc &accept_token,
267-
// std::string &prompt_string)
268-
// {
269-
// std::string generated_text;
270-
// // Seed the random number generator
271-
// std::random_device rd;
272-
// std::mt19937 gen(rd());
273-
// int prompt_size{};
274-
// if (model.model_training == ModelTraining::GEMMA_IT)
275-
// {
276-
// // For instruction-tuned models: add control tokens.
277-
// prompt_string = "<start_of_turn>user\n" + prompt_string +
278-
// "<end_of_turn>\n<start_of_turn>model\n";
279-
// }
280-
// // Encode the prompt string into tokens
281-
// std::vector<int> prompt;
282-
// HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
283-
// // Placeholder for generated token IDs
284-
// std::vector<int> generated_tokens;
285-
// // Define lambda for token decoding
286-
// StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
287-
// generated_tokens.push_back(token);
288-
// return true; // Continue generating
289-
// };
290-
// // Decode tokens
291-
// prompt_size = prompt.size();
292-
// GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token,
293-
// accept_token, gen, verbosity);
294-
// HWY_ASSERT(model.Tokenizer()->Decode(generated_tokens, &generated_text).ok());
295-
// generated_text = generated_text.substr(prompt_string.size());
296-
297-
// return generated_text;
298-
// }
299-
300-
// std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app,
301-
// std::string &prompt_string){
302-
// hwy::ThreadPool inner_pool(0);
303-
// hwy::ThreadPool pool(app.num_threads);
304-
// if (app.num_threads > 10)
305-
// {
306-
// PinThreadToCore(app.num_threads - 1); // Main thread
307-
308-
// pool.Run(0, pool.NumThreads(),
309-
// [](uint64_t /*task*/, size_t thread)
310-
// { PinThreadToCore(thread); });
311-
// }
312-
// gcpp::Gemma model(loader, pool);
313-
// return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
314-
// { return true; }, prompt_string);
315-
316-
// }
317-
318-
} // namespace gcpp
319-
320-
void chat_base(int argc, char **argv)
321-
{
322-
{
323-
PROFILER_ZONE("Startup.misc");
324-
325-
gcpp::LoaderArgs loader(argc, argv);
326-
gcpp::InferenceArgs inference(argc, argv);
327-
gcpp::AppArgs app(argc, argv);
328-
329-
if (gcpp::HasHelp(argc, argv))
330-
{
331-
ShowHelp(loader, inference, app);
332-
// return 0;
333-
}
334-
335-
if (const char *error = loader.Validate())
336-
{
337-
ShowHelp(loader, inference, app);
338-
HWY_ABORT("\nInvalid args: %s", error);
339-
}
340-
341-
gcpp::Run(loader, inference, app);
342-
}
343-
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
344-
// return 1;
345-
}
346-
// std::string completion_base(int argc, char **argv)
347-
// {
348-
// gcpp::LoaderArgs loader(argc, argv);
349-
// gcpp::InferenceArgs inference(argc, argv);
350-
// gcpp::AppArgs app(argc, argv);
351-
// std::string prompt_string = argv[argc-1];
352-
// return gcpp::completion(loader, inference, app, prompt_string);
353-
// }
354-
// std::string completion_base_wrapper(const std::vector<std::string> &args,std::string &prompt_string)
355-
// {
356-
// int argc = args.size() + 2; // +1 for the program name
357-
// std::vector<char *> argv_vec;
358-
// argv_vec.reserve(argc);
359-
360-
// argv_vec.push_back(const_cast<char *>("pygemma"));
361-
362-
// for (const auto &arg : args)
363-
// {
364-
// argv_vec.push_back(const_cast<char *>(arg.c_str()));
365-
// }
366-
// argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
367-
// char **argv = argv_vec.data();
368-
// return completion_base(argc, argv);
369-
// }
370-
void show_help_wrapper()
371-
{
372-
// Assuming ShowHelp does not critically depend on argv content
373-
gcpp::LoaderArgs loader(0, nullptr);
374-
gcpp::InferenceArgs inference(0, nullptr);
375-
gcpp::AppArgs app(0, nullptr);
376-
377-
ShowHelp(loader, inference, app);
247+
void GemmaWrapper::showConfig() {
248+
ShowConfig(this->m_loader,this->m_inference, this->m_app);
378249
}
379250

380-
std::string chat_base_wrapper(const std::vector<std::string> &args)
381-
{
382-
int argc = args.size() + 1; // +1 for the program name
383-
std::vector<char *> argv_vec;
384-
argv_vec.reserve(argc);
385-
argv_vec.push_back(const_cast<char *>("pygemma"));
386-
387-
for (const auto &arg : args)
388-
{
389-
argv_vec.push_back(const_cast<char *>(arg.c_str()));
390-
}
391-
392-
char **argv = argv_vec.data();
393-
394-
chat_base(argc, argv);
251+
void GemmaWrapper::showHelp() {
252+
ShowHelp(this->m_loader,this->m_inference, this->m_app);
395253
}
396254

397255

398-
PYBIND11_MODULE(pygemma, m)
399-
{
400-
m.doc() = "Pybind11 integration for chat_base function";
401-
m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments");
402-
m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
403-
// m.def("completion", &completion_base_wrapper, "A wrapper for inference function");
256+
PYBIND11_MODULE(pygemma, m) {
  // Restore the module docstring that the previous (function-based) binding
  // set and this rewrite dropped.
  m.doc() = "Python bindings for gemma.cpp exposed through the Gemma class";

  py::class_<GemmaWrapper>(m, "Gemma")
      .def(py::init<>())
      .def("show_config", &GemmaWrapper::showConfig,
           "Print the current loader/inference/app configuration")
      .def("show_help", &GemmaWrapper::showHelp,
           "Print command-line style usage help")
      .def(
          "load_model",
          [](GemmaWrapper &self, const std::string &tokenizer,
             const std::string &compressed_weights, const std::string &model) {
            // Translate the keyword arguments into the argv-style flag list
            // that GemmaWrapper::loadModel(std::vector<std::string>) parses.
            std::vector<std::string> args = {
                "--tokenizer",          tokenizer,
                "--compressed_weights", compressed_weights,
                "--model",              model};
            self.loadModel(args);
          },
          py::arg("tokenizer"), py::arg("compressed_weights"),
          py::arg("model"),
          "Load a gemma.cpp model from a tokenizer file and compressed "
          "weights")
      .def("completion", &GemmaWrapper::completionPrompt,
           "Generate a completion; see GemmaWrapper::completionPrompt");
}

src/gemma_binding.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
// Command line text interface to gemma.
3+
4+
#include <ctime>
5+
#include <iostream>
6+
#include <random>
7+
#include <string>
8+
#include <thread> // NOLINT
9+
#include <vector>
10+
11+
// copybara:import_next_line:gemma_cpp
12+
#include "compression/compress.h"
13+
// copybara:end
14+
// copybara:import_next_line:gemma_cpp
15+
#include "gemma.h" // Gemma
16+
// copybara:end
17+
// copybara:import_next_line:gemma_cpp
18+
#include "util/app.h"
19+
// copybara:end
20+
// copybara:import_next_line:gemma_cpp
21+
#include "util/args.h" // HasHelp
22+
// copybara:end
23+
#include "hwy/base.h"
24+
#include "hwy/contrib/thread_pool/thread_pool.h"
25+
#include "hwy/highway.h"
26+
#include "hwy/per_target.h"
27+
#include "hwy/profiler.h"
28+
#include "hwy/timer.h"
29+
30+
using namespace gcpp;
31+
32+
class GemmaWrapper {
33+
public:
34+
// GemmaWrapper();
35+
void loadModel(const std::vector<std::string> &args); // Consider exception safety
36+
void showConfig();
37+
void showHelp();
38+
std::string completionPrompt();
39+
40+
private:
41+
gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr);
42+
gcpp::InferenceArgs m_inference = gcpp::InferenceArgs(0, nullptr);
43+
gcpp::AppArgs m_app = gcpp::AppArgs(0, nullptr);
44+
std::unique_ptr<gcpp::Gemma> m_model;
45+
KVCache m_kvcache;
46+
};

0 commit comments

Comments
 (0)