-
Notifications
You must be signed in to change notification settings - Fork 12.5k
cuda : implement bf16 cpy ops and enable bf16 cont #14763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally speaking I am not a fan of how the float conversions are being done currently. I think the code could be deduplicated significantly by unconditionally casting half
, nv_bfloat16
, and float
to float
and then simply using that float
value to set the destination. I would appreciate it if you were to do this in this PR, otherwise I'll keep it as one of the tasks to hand out when people ask me for a good first issue to work on.
* origin/master: (49 commits) ci : correct label refactor->refactoring (ggml-org#14832) CUDA: fix quantized KV cache + multiple sequences (ggml-org#14822) tests : add non-cont K,V FA tests memory : handle saving/loading null layers in recurrent memory (ggml-org#14675) ggml: fix loongarch quantize_row_q8_1 error (ggml-org#14827) CANN: weight format to NZ for Ascend310P3 (ggml-org#14407) CUDA: add fused rms norm (ggml-org#14800) ggml : model card yaml tab->2xspace (ggml-org#14819) vulkan: fix rms_norm_mul to handle broadcasting dim0 (ggml-org#14817) llama : add model type detection for rwkv7 7B&14B (ggml-org#14816) imatrix: add option to display importance score statistics for a given imatrix file (ggml-org#12718) Mtmd: add a way to select device for vision encoder (ggml-org#14236) cuda : implement bf16 cpy ops and enable bf16 cont (ggml-org#14763) opencl: remove unreachable `return` (ggml-org#14806) server : allow setting `--reverse-prompt` arg (ggml-org#14799) cuda: remove linking to cublasLt (ggml-org#14790) opencl: fix `im2col` when `KW!=KH` (ggml-org#14803) opencl: add conv2d kernel (ggml-org#14403) sycl: Fix im2col (ggml-org#14797) kleidiai: add support for get_rows (ggml-org#14676) ...
* implement bf16 cpy ops and enable bf16 cont * deduplicate copy functions * deduplicate checks
Implemented missing BF16 CPY ops and enabled CONT op for BF16.
Tests before
Tests after
Also fixed a cut'n'paste error for F16->F16 in
ggml_cuda_cpy_fn
and deduplicated all copy functions.