[docs] Appendix E (#287)
* LoRALayer

* fix error in import statement for LoRALayer

* LinearWithLoRA

* sm nit

* MultiHeadAttentionWithLoRA

* FeedForwardWithLoRA

* TransformerBlockWithLoRA

* GPTModelWithLoRA
nerdai authored Jan 23, 2025
1 parent ee66b09 commit 78dd5f0
Showing 1 changed file with 117 additions and 1 deletion.
118 changes: 117 additions & 1 deletion src/listings/apdx_e.rs
@@ -127,6 +127,22 @@ pub struct LoRALayer {
}

impl LoRALayer {
/// Creates a new `LoRALayer`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::apdx_e::LoRALayer;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let (d_in, d_out) = (20_usize, 30_usize);
/// let lora_layer = LoRALayer::new(d_in, d_out, rank, alpha, vb).unwrap();
/// ```
#[allow(non_snake_case)]
pub fn new(
in_dim: usize,
@@ -170,6 +186,25 @@ pub struct LinearWithLoRA {
}

impl LinearWithLoRA {
/// Creates a new `LinearWithLoRA` from `Linear`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{Linear, VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::ch04::Config;
/// use llms_from_scratch_rs::listings::apdx_e::LinearWithLoRA;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let cfg = Config::gpt_sm_test();
/// let linear = candle_nn::linear(cfg.emb_dim, cfg.emb_dim, vb.pp("linear")).unwrap();
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let linear_with_lora = LinearWithLoRA::from_linear(linear, rank, alpha, vb.pp("linear")).unwrap();
/// ```
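///
/// Since `LinearWithLoRA` implements `Module` (see the `impl` below), the
/// wrapped layer is called exactly like the `Linear` it replaces. A minimal
/// follow-up sketch reusing the bindings above (`ignore`d for that reason):
///
/// ```rust,ignore
/// use candle_core::{Module, Tensor};
///
/// // Input of shape (batch, emb_dim); the LoRA wrapper leaves the output shape unchanged.
/// let x = Tensor::rand(0f32, 1f32, (2_usize, cfg.emb_dim), &dev).unwrap();
/// let y = linear_with_lora.forward(&x).unwrap();
/// assert_eq!(y.dims(), &[2_usize, cfg.emb_dim]);
/// ```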
pub fn from_linear(
linear: Linear,
rank: usize,
@@ -194,7 +229,7 @@ impl Module for LinearWithLoRA {
}

/// Function to replace all `Linear` layers with `LinearWithLoRA` in a given model
/// NOTE: this won't work for Candle and is thus a no-op.
/// One would need to impl all the `XXXWithLoRA` modules and probably the `From` trait.
#[allow(unused_variables)]
pub fn replace_linear_with_lora(
@@ -205,6 +240,7 @@ pub fn replace_linear_with_lora(
varmap: &VarMap,
vb: VarBuilder<'_>,
) -> Result<()> {
// no-op
Ok(())
}
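
// Sketch of the manual route the NOTE above points at (illustration only, not
// part of this commit). A blanket `From<Linear>` impl can't carry the extra
// `rank`/`alpha` arguments or the `VarBuilder`, and construction is fallible,
// so each module instead gets a fallible `from_*` constructor and the model is
// converted wholesale from the top, e.g.:
//
//     let gpt = GPTModel::new(cfg, vb.pp("model"))?;
//     let gpt_lora = GPTModelWithLoRA::from_gpt_model(gpt, rank, alpha, vb.pp("model"))?;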

Expand All @@ -226,6 +262,25 @@ pub struct MultiHeadAttentionWithLoRA {
}

impl MultiHeadAttentionWithLoRA {
/// Creates a new `MultiHeadAttentionWithLoRA` from `MultiHeadAttention`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::ch03::MultiHeadAttention;
/// use llms_from_scratch_rs::listings::apdx_e::MultiHeadAttentionWithLoRA;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let (d_in, d_out, num_heads) = (3_usize, 6_usize, 2_usize);
/// let mha = MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn")).unwrap();
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let mha_with_lora = MultiHeadAttentionWithLoRA::from_mha(mha, rank, alpha, vb.pp("attn")).unwrap();
/// ```
pub fn from_mha(
mha: MultiHeadAttention,
rank: usize,
@@ -374,6 +429,24 @@ pub struct FeedForwardWithLoRA {
}

impl FeedForwardWithLoRA {
/// Creates a new `FeedForwardWithLoRA` from `FeedForward`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::ch04::{Config, FeedForward};
/// use llms_from_scratch_rs::listings::apdx_e::FeedForwardWithLoRA;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let ff = FeedForward::new(Config::gpt_sm_test(), vb.pp("ff")).unwrap();
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let ff_with_lora = FeedForwardWithLoRA::from_ff(ff, rank, alpha, vb.pp("ff")).unwrap();
/// ```
pub fn from_ff(ff: FeedForward, rank: usize, alpha: f64, vb: VarBuilder<'_>) -> Result<Self> {
let mut iter = ff.layers().iter();

@@ -444,6 +517,30 @@ pub struct TransformerBlockWithLoRA {
}

impl TransformerBlockWithLoRA {
/// Creates a new `TransformerBlockWithLoRA` from `TransformerBlock`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::ch04::{Config, TransformerBlock};
/// use llms_from_scratch_rs::listings::apdx_e::TransformerBlockWithLoRA;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let cfg = Config::gpt_sm_test();
/// let transformer_block = TransformerBlock::new(cfg, vb.pp("transformer")).unwrap();
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let transformer_block_with_lora = TransformerBlockWithLoRA::from_trf_block(
/// transformer_block,
/// rank,
/// alpha,
/// vb.pp("transformer"),
/// ).unwrap();
/// ```
pub fn from_trf_block(
trf_block: TransformerBlock,
rank: usize,
@@ -553,6 +650,25 @@ pub struct GPTModelWithLoRA {
}

impl GPTModelWithLoRA {
/// Creates a new `GPTModelWithLoRA` from `GPTModel`
///
/// ```rust
/// use candle_core::{Device, DType};
/// use candle_nn::{VarBuilder, VarMap};
/// use llms_from_scratch_rs::listings::ch04::{Config, GPTModel};
/// use llms_from_scratch_rs::listings::apdx_e::GPTModelWithLoRA;
///
/// let dev = Device::cuda_if_available(0).unwrap();
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
///
/// let cfg = Config::gpt_sm_test();
/// let model = GPTModel::new(cfg, vb.pp("model")).unwrap();
///
/// let alpha = 0.5_f64;
/// let rank = 3_usize;
/// let model_with_lora = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model")).unwrap();
/// ```
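///
/// A follow-up sketch, reusing the bindings above (`ignore`d for that
/// reason), that totals the parameters registered in the `VarMap`; in a LoRA
/// fine-tuning setup only the low-rank `A`/`B` matrices would be trainable
/// (an assumption, not asserted by this commit):
///
/// ```rust,ignore
/// // `VarMap::all_vars` returns every registered `Var`; `elem_count` is the
/// // number of scalars in each tensor.
/// let total_params: usize = varmap.all_vars().iter().map(|v| v.elem_count()).sum();
/// println!("Total parameters: {total_params}");
/// ```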
pub fn from_gpt_model(
gpt: GPTModel,
rank: usize,
