// Patch overview (reviewer preamble placed ahead of the first diff header; no hunk content is changed).
//
// Files touched: ggml/include/ggml-cann.h, ggml/src/ggml-cann.cpp,
// ggml/src/ggml-cann/common.h, src/llama.cpp.
//
// ggml/include/ggml-cann.h
//   - Adds a GGML_COMMON_DECL_C define and an include of ../src/ggml-common.h.
//   - Removes the GGML_CANN_NAME macro, the local QK4_0 / block_q4_0 and
//     QK8_0 / block_q8_0 definitions, and the declarations of
//     ggml_cann_backend_init and ggml_cann_backend_free.
//   - NOTE(review): presumably the block types are now supplied by
//     ggml-common.h - TODO confirm. The new include makes a header under
//     include/ depend on a relative path under src/.
//
// ggml/src/ggml-cann.cpp
//   - Replaces every use of GGML_CANN_NAME with the literal CANN.
//   - ggml_backend_cann_buffer_context loses its name member and the matching
//     constructor initializer.
//   - ggml_backend_cann_buffer_get_name and ggml_backend_cann_buffer_type_name
//     now return the fixed string CANN and mark their parameter GGML_UNUSED
//     after the return statement.
//   - NOTE(review): ggml_backend_cann_buffer_type still constructs each buffer
//     type context with CANN + std::to_string(i), but the name getter in this
//     patch no longer reads that field.
//   - ggml_backend_cann_free gains a call to aclFinalize, executed when
//     cann_ctx->device == ggml_backend_cann_get_device_count() - 1.
//   - NOTE(review): the added comment says finalize when last backend freed,
//     but the condition tests the device index against device_count - 1, not
//     the number of backends still alive - confirm this is intended.
//   - ggml_backend_cann_supports_buft now returns
//     buft->iface.get_name == ggml_backend_cann_buffer_type_name and no longer
//     compares the buffer type device id with the backend device id.
//   - ggml_backend_cann_init now calls aclInit(nullptr) as its first statement,
//     before the device range check.
//   - NOTE(review): that aclInit result is not checked, whereas the removed
//     ggml_cann_backend_init wrapped the same call in ACL_CHECK. Behavior of
//     repeated aclInit calls is not visible here - TODO confirm against the
//     AscendCL documentation.
//   - ggml_backend_cann_reg_devices no longer calls aclInit before
//     ggml_backend_cann_get_device_count.
//   - NOTE(review): whether the device count query needs ACL to be initialized
//     first cannot be told from this patch - TODO confirm.
//   - NOTE(review): the snprintf format in ggml_backend_cann_reg_devices is
//     CANN%d while the argument i is declared uint32_t; %u would match the
//     argument type.
//   - Removes the definitions of ggml_cann_backend_init and
//     ggml_cann_backend_free.
//
// ggml/src/ggml-cann/common.h
//   - The ggml_backend_cann_context constructor builds its name from the
//     literal CANN + std::to_string(device) instead of GGML_CANN_NAME.
//
// src/llama.cpp
//   - Removes the GGML_USE_CANN guarded calls to ggml_cann_backend_init in
//     llama_backend_init and ggml_cann_backend_free in llama_backend_free.
diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index e5dc1d911e9d4..4280056fa97b7 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -22,6 +22,9 @@ #pragma once +#define GGML_COMMON_DECL_C + +#include "../src/ggml-common.h" #include "ggml-backend.h" #include "ggml.h" @@ -29,35 +32,11 @@ extern "C" { #endif -/** - * @def GGML_CANN_NAME - * @brief Define for the name of the CANN backend. - */ -#define GGML_CANN_NAME "CANN" - /** * @brief Maximum number of CANN devices supported. */ #define GGML_CANN_MAX_DEVICES 16 -/** - * @brief Structure for QK4_0 data format. - */ -#define QK4_0 32 -typedef struct { - uint16_t d; /**< Delta */ - uint8_t qs[QK4_0 / 2]; /**< Nibbles / quants */ -} block_q4_0; - -/** - * @brief Structure for QK8_0 data format. - */ -#define QK8_0 32 -typedef struct { - uint16_t d; /**< Delta */ - int8_t qs[QK8_0]; /**< Quants */ -} block_q8_0; - /** * @brief Initializes the CANN backend for a specified device. * @@ -133,16 +112,6 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free, size_t* total); -/** - * @brief Initializes resources required by the CANN backend. - */ -void ggml_cann_backend_init(void); - -/** - * @brief Frees resources used by the CANN backend. - */ -void ggml_cann_backend_free(void); - #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp index 15465bc422c60..58278b7139098 100644 --- a/ggml/src/ggml-cann.cpp +++ b/ggml/src/ggml-cann.cpp @@ -96,7 +96,7 @@ static ggml_cann_device_info ggml_cann_init() { aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count); if (err != ACL_SUCCESS) { - fprintf(stderr, "%s: failed to initialize " GGML_CANN_NAME ": %s\n", + fprintf(stderr, "%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg()); return info; } @@ -464,7 +464,6 @@ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. 
void* dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. - std::string name; ///< Name of the buffer context. /** * @brief Constructor to initialize the CANN buffer context. @@ -474,8 +473,7 @@ struct ggml_backend_cann_buffer_context { */ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) : device(device), - dev_ptr(dev_ptr), - name(GGML_CANN_NAME + std::to_string(device)) {} + dev_ptr(dev_ptr) {} /** * @brief Destructor to free the device memory allocated for the buffer. @@ -495,9 +493,9 @@ struct ggml_backend_cann_buffer_context { GGML_CALL static const char* ggml_backend_cann_buffer_get_name( ggml_backend_buffer_t buffer) { - ggml_backend_cann_buffer_context* ctx = - (ggml_backend_cann_buffer_context*)buffer->context; - return ctx->name.c_str(); + return "CANN"; + + GGML_UNUSED(buffer); } /** @@ -1004,10 +1002,9 @@ struct ggml_backend_cann_buffer_type_context { */ GGML_CALL static const char* ggml_backend_cann_buffer_type_name( ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context* ctx = - (ggml_backend_cann_buffer_type_context*)buft->context; + return "CANN"; - return ctx->name.c_str(); + GGML_UNUSED(buft); } /** @@ -1151,8 +1148,8 @@ ggml_backend_cann_buffer_type(int32_t device) { ggml_backend_cann_buffer_types[i] = { /* .iface = */ ggml_backend_cann_buffer_type_interface, /* .context = */ - new ggml_backend_cann_buffer_type_context{ - i, GGML_CANN_NAME + std::to_string(i)}, + new ggml_backend_cann_buffer_type_context{ + i, "CANN" + std::to_string(i)}, }; } ggml_backend_cann_buffer_type_initialized = true; @@ -1344,6 +1341,12 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) { (ggml_backend_cann_context*)backend->context; ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtResetDevice(cann_ctx->device)); + + // finalize when last backend freed. 
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) { + ACL_CHECK(aclFinalize()); + } + delete cann_ctx; delete backend; } @@ -1703,15 +1706,9 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { */ GGML_CALL static bool ggml_backend_cann_supports_buft( ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (ggml_backend_buft_is_cann(buft)) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - ggml_backend_cann_buffer_type_context* buft_ctx = - (ggml_backend_cann_buffer_type_context*)buft->context; - return buft_ctx->device == cann_ctx->device; - } + return buft->iface.get_name == ggml_backend_cann_buffer_type_name; - return false; + GGML_UNUSED(backend); } /** @@ -1870,6 +1867,7 @@ static ggml_guid_t ggml_backend_cann_guid() { } GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) { + aclInit(nullptr); if (device < 0 || device >= ggml_backend_cann_get_device_count()) { fprintf(stderr, "%s: error: invalid device %d\n", __func__, device); return nullptr; @@ -1945,19 +1943,14 @@ extern "C" GGML_CALL int ggml_backend_cann_reg_devices(); * @return int The number of CANN devices registered. 
 */ GGML_CALL int ggml_backend_cann_reg_devices() { - aclInit(nullptr); uint32_t device_count = ggml_backend_cann_get_device_count(); // initialization for (uint32_t i = 0; i < device_count; i++) { char name[128]; - snprintf(name, sizeof(name), "%s%d", GGML_CANN_NAME, i); + snprintf(name, sizeof(name), "CANN%d", i); ggml_backend_register(name, ggml_backend_reg_cann_init, ggml_backend_cann_buffer_type(i), (void*)(intptr_t)i); } return device_count; } - -void ggml_cann_backend_init(void) { ACL_CHECK(aclInit(nullptr)); } - -void ggml_cann_backend_free(void) { ACL_CHECK(aclFinalize()); } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 0989fc058cf3f..e6a5701075f02 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -221,7 +221,7 @@ struct ggml_backend_cann_context { * @param device Device ID. */ explicit ggml_backend_cann_context(int device) - : device(device), name(GGML_CANN_NAME + std::to_string(device)) {} + : device(device), name("CANN" + std::to_string(device)) {} /** * @brief Destructor for cleaning up resources. diff --git a/src/llama.cpp b/src/llama.cpp index 0d7de57816169..13c5f10c56c86 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18916,10 +18916,6 @@ void llama_backend_init(void) { struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } - -#if defined(GGML_USE_CANN) - ggml_cann_backend_init(); -#endif } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -18929,10 +18925,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) { } void llama_backend_free(void) { -#if defined(GGML_USE_CANN) - ggml_cann_backend_free(); -#endif - ggml_quantize_free(); }