Skip to content

Commit

Permalink
Modify the code based on review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
hipudding committed Jul 16, 2024
1 parent 0da1e1f commit f50f090
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 69 deletions.
37 changes: 3 additions & 34 deletions ggml/include/ggml-cann.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,21 @@

#pragma once

#define GGML_COMMON_DECL_C

#include "../src/ggml-common.h"
#include "ggml-backend.h"
#include "ggml.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
* @def GGML_CANN_NAME
* @brief Define for the name of the CANN backend.
*/
#define GGML_CANN_NAME "CANN"

/**
* @brief Maximum number of CANN devices supported.
*/
#define GGML_CANN_MAX_DEVICES 16

/**
* @brief Structure for QK4_0 data format.
*/
#define QK4_0 32
typedef struct {
uint16_t d; /**< Delta */
uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */
} block_q4_0;

/**
* @brief Structure for QK8_0 data format.
*/
#define QK8_0 32
typedef struct {
uint16_t d; /**< Delta */
int8_t qs[QK8_0]; /**< Quants */
} block_q8_0;

/**
* @brief Initializes the CANN backend for a specified device.
*
Expand Down Expand Up @@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
size_t* free,
size_t* total);

/**
* @brief Initializes resources required by the CANN backend.
*/
void ggml_cann_backend_init(void);

/**
* @brief Frees resources used by the CANN backend.
*/
void ggml_cann_backend_free(void);

#ifdef __cplusplus
}
#endif
45 changes: 19 additions & 26 deletions ggml/src/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() {
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);

if (err != ACL_SUCCESS) {
fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n",
fprintf(stderr, "%s: failed to initialize CANN: %s\n",
__func__, aclGetRecentErrMsg());
return info;
}
Expand Down Expand Up @@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context {
int32_t device; ///< The device ID associated with this buffer context.
void* dev_ptr =
nullptr; ///< Pointer to the device memory allocated for the buffer.
std::string name; ///< Name of the buffer context.

/**
* @brief Constructor to initialize the CANN buffer context.
Expand All @@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context {
*/
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
: device(device),
dev_ptr(dev_ptr),
name(GGML_CANN_NAME + std::to_string(device)) {}
dev_ptr(dev_ptr) {}

/**
* @brief Destructor to free the device memory allocated for the buffer.
Expand All @@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context {

GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
ggml_backend_buffer_t buffer) {
ggml_backend_cann_buffer_context* ctx =
(ggml_backend_cann_buffer_context*)buffer->context;
return ctx->name.c_str();
return "CANN";

GGML_UNUSED(buffer);
}

/**
Expand Down Expand Up @@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context {
*/
GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context* ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
return "CANN";

return ctx->name.c_str();
GGML_UNUSED(buft);
}

/**
Expand Down Expand Up @@ -1151,8 +1148,8 @@ ggml_backend_cann_buffer_type(int32_t device) {
ggml_backend_cann_buffer_types[i] = {
/* .iface = */ ggml_backend_cann_buffer_type_interface,
/* .context = */
new ggml_backend_cann_buffer_type_context{
i, GGML_CANN_NAME + std::to_string(i)},
new ggml_backend_cann_buffer_type_context{
i, "CANN" + std::to_string(i)},
};
}
ggml_backend_cann_buffer_type_initialized = true;
Expand Down Expand Up @@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
(ggml_backend_cann_context*)backend->context;
ACL_CHECK(aclrtSynchronizeDevice());
ACL_CHECK(aclrtResetDevice(cann_ctx->device));

// finalize when last backend freed.
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
ACL_CHECK(aclFinalize());
}

delete cann_ctx;
delete backend;
}
Expand Down Expand Up @@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
*/
GGML_CALL static bool ggml_backend_cann_supports_buft(
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
if (ggml_backend_buft_is_cann(buft)) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_backend_cann_buffer_type_context* buft_ctx =
(ggml_backend_cann_buffer_type_context*)buft->context;
return buft_ctx->device == cann_ctx->device;
}
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;

return false;
GGML_UNUSED(backend);
}

/**
Expand Down Expand Up @@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() {
}

GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
aclInit(nullptr);
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
return nullptr;
Expand Down Expand Up @@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
* @return int The number of CANN devices registered.
*/
GGML_CALL int ggml_backend_cann_reg_devices() {
aclInit(nullptr);
uint32_t device_count = ggml_backend_cann_get_device_count();
// initialization
for (uint32_t i = 0; i < device_count; i++) {
char name[128];
snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i);
snprintf(name, sizeof(name), "CANN%d", i);
ggml_backend_register(name, ggml_backend_reg_cann_init,
ggml_backend_cann_buffer_type(i),
(void*)(intptr_t)i);
}
return device_count;
}

void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); }

void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); }
2 changes: 1 addition & 1 deletion ggml/src/ggml-cann/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ struct ggml_backend_cann_context {
* @param device Device ID.
*/
explicit ggml_backend_cann_context(int device)
: device(device), name(GGML_CANN_NAME + std::to_string(device)) {}
: device(device), name("CANN" + std::to_string(device)) {}

/**
* @brief Destructor for cleaning up resources.
Expand Down
8 changes: 0 additions & 8 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18916,10 +18916,6 @@ void llama_backend_init(void) {
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}

#if defined(GGML_USE_CANN)
ggml_cann_backend_init();
#endif
}

void llama_numa_init(enum ggml_numa_strategy numa) {
Expand All @@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
}

void llama_backend_free(void) {
#if defined(GGML_USE_CANN)
ggml_cann_backend_free();
#endif

ggml_quantize_free();
}

Expand Down

0 comments on commit f50f090

Please sign in to comment.