
Conversation

@NikhilNayak-debug (Contributor)

Summary

Adds a runnable continual learning example for the OSF method, as requested in PR #2685.

Changes

  • New example in examples/orthogonal_subspace_learning/
  • Demonstrates OSF preventing catastrophic forgetting on 3 sequential tasks (ScienceQA, NumGLUE, FOMC)
  • Includes full fine-tuning baseline for comparison
  • Progressive capacity allocation: 70% trainable (Task 1) → 50% (Task 2) → 30% (Task 3)
  • Tracks accuracy and backward transfer metrics (see the sketch after this list)
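
For reference, a minimal sketch of how the accuracy and backward transfer metrics can be computed from a per-task accuracy matrix. The names and layout are illustrative, not the script's actual API; acc[i][j] is assumed to be accuracy on task j after training on task i.

# Illustrative sketch, not the exact helpers in osf_continual_learning.py.
def average_accuracy(acc):
    # Mean accuracy over all tasks after training on the final task.
    final = acc[-1]
    return sum(final) / len(final)

def backward_transfer(acc):
    # Final accuracy on each earlier task minus the accuracy measured
    # right after that task was learned; negative values mean forgetting.
    n = len(acc)
    deltas = [acc[-1][j] - acc[j][j] for j in range(n - 1)]
    return sum(deltas) / len(deltas)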

Results (2 epochs per task)

  • OSF: 53.42% average accuracy, +30.25% backward transfer
  • Full FT: 46.26% average accuracy, -6.00% backward transfer (i.e., forgetting)
  • OSF prevents catastrophic forgetting and enables positive backward transfer

Files

  • osf_continual_learning.py - Main example script with OSF and baseline training
  • utils.py - Dataset loading and formatting utilities for 3 tasks
  • README.md - Comprehensive documentation with usage examples

Implementation Details

  • Uses meta-llama/Llama-3.1-8B-Instruct by default
  • Learning rate: 5e-6, batch size: 32
  • Progressive effective_rank allocation (0.3 → 0.5 → 0.7); see the configuration sketch after this list
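
The progressive allocation loop can be sketched roughly as follows, assuming OSF is configured through PEFT's OSFConfig with an effective_rank fraction (the preserved share of each weight's spectrum, so trainable capacity is 1 - effective_rank), as the per-task headers in the training log below suggest:

# Rough sketch under the assumptions above; the actual script also wires
# in tokenization, TrainingArguments, and per-task evaluation, and
# carries the updated weights forward between tasks.
from peft import OSFConfig, get_peft_model

tasks = [("ScienceQA", 0.3), ("NumGLUE", 0.5), ("FOMC", 0.7)]
for task_name, effective_rank in tasks:
    peft_config = OSFConfig(effective_rank=effective_rank)  # fraction preserved
    model = get_peft_model(base_model, peft_config)
    # ... train on task_name, then evaluate on all tasks seen so far ...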

Addresses feedback from #2685

@NikhilNayak-debug (Contributor, Author)

@githubnemo I have added the continual learning example as requested. Could you please review this PR?

The example demonstrates OSF on 3 sequential tasks (ScienceQA, NumGLUE, FOMC) with progressive rank allocation and compares against a full fine-tuning baseline.

@githubnemo (Collaborator) left a comment:


Hey, thanks, this looks very nice!

Just a quick review for now; I tried to run the code and got this exception:

$ python osf_continual_learning.py --model_name meta-llama/Llama-3.2-1B-Instruct --run_baseline

================================================================================
TRAINING WITH OSF (Orthogonal Subspace Fine-tuning)
================================================================================

Loading datasets...
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 9871.85 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 23863.81 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 17309.66 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 10534.10 examples/s]

================================================================================
TASK 1: ScienceQA
Effective Rank: 0.3 (preserving 30%)
================================================================================
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.

Training on ScienceQA...
{'loss': 6.7117, 'grad_norm': 220.0, 'learning_rate': 4.296875e-06, 'epoch': 0.31}                                                                                                                                   
{'loss': 1.8609, 'grad_norm': 136.0, 'learning_rate': 3.5156250000000003e-06, 'epoch': 0.62}                                                                                                                         
{'loss': 1.2447, 'grad_norm': 131.0, 'learning_rate': 2.7343750000000004e-06, 'epoch': 0.94}                                                                                                                         
{'loss': 1.1187, 'grad_norm': 118.5, 'learning_rate': 1.953125e-06, 'epoch': 1.25}                                                                                                                                   
{'loss': 1.0351, 'grad_norm': 117.0, 'learning_rate': 1.1718750000000001e-06, 'epoch': 1.56}                                                                                                                         
{'loss': 1.0219, 'grad_norm': 117.0, 'learning_rate': 3.90625e-07, 'epoch': 1.88}                                                                                                                                    
{'train_runtime': 22.6765, 'train_samples_per_second': 88.197, 'train_steps_per_second': 2.822, 'train_loss': 2.0924081951379776, 'epoch': 2.0}                                                                      
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:22<00:00,  2.82it/s]

Evaluating on all tasks after training on ScienceQA:
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:01<00:00, 20.69it/s]
Traceback (most recent call last):
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 694, in <module>
    main()
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 627, in main
    osf_history = train_with_osf(
                  ^^^^^^^^^^^^^^^
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 364, in train_with_osf
    loss, accuracy = evaluate_model(
                     ^^^^^^^^^^^^^^^
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 199, in evaluate_model
    loss = results["eval_loss"]
           ~~~~~~~^^^^^^^^^^^^^
KeyError: 'eval_loss'

Maybe that's on my side, I'm investigating.


- [OSF Documentation](../../docs/source/package_reference/osf.md)
- [PEFT Documentation](https://huggingface.co/docs/peft)
- [Original Paper](https://arxiv.org/abs/2504.07097)
@githubnemo (Collaborator):

Suggested change
- [Original Paper](https://arxiv.org/abs/2504.07097)
- [Original Paper](https://huggingface.co/papers/2504.07097)

@githubnemo (Collaborator) left a comment:

I found the culprit and commented the solution above.

I ran the example with Llama-3.2 1B and got these results:

Final full fine-tuning model saved to ./osf_continual_learning_outputs/full_final

================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================

--------------------------------------------------------------------------------
DETAILED RESULTS (Accuracy %)
--------------------------------------------------------------------------------
Task            After Task      OSF Acc %       Full FT Acc %   Difference     
--------------------------------------------------------------------------------
ScienceQA       ScienceQA       41.00           45.00                     -4.00
ScienceQA       NumGLUE         44.50           76.50                    -32.00
ScienceQA       FOMC            51.50           81.50                    -30.00
NumGLUE         NumGLUE         22.00           48.50                    -26.50
NumGLUE         FOMC            17.00           50.50                    -33.50
FOMC            FOMC            28.77           28.77                     +0.00

================================================================================
SUMMARY METRICS
================================================================================

1. Average Accuracy Across All 3 Tasks (After Final Task):
   OSF:     32.42%
   Full FT: 53.59%
   Difference: -21.17% (Full FT better)

2. Average Forgetting (Task 1 & 2):
   Forgetting = Final Accuracy - Initial Accuracy (negative is worse)

   ScienceQA:
     OSF:     +10.50% (initial: 41.00% → final: 51.50%)
     Full FT: +36.50% (initial: 45.00% → final: 81.50%)
     Difference: -26.00% (Full FT better)

   NumGLUE:
     OSF:     -5.00% (initial: 22.00% → final: 17.00%)
     Full FT: +2.00% (initial: 48.50% → final: 50.50%)
     Difference: -7.00% (Full FT better)

   Average Forgetting:
     OSF:     +2.75%
     Full FT: +19.25%
     Difference: -16.50% (Full FT better)

I'm not sure if that's expected. The effective rank is probably smaller here, since the 1B model's hidden dimensions differ from the 8B default.
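
A quick back-of-envelope under that hypothesis (assuming hidden sizes of 2048 for Llama-3.2-1B and 4096 for Llama-3.1-8B, and taking preserved directions as effective_rank times hidden size):

# Back-of-envelope only, under the assumptions stated above.
for name, hidden_size in [("Llama-3.2-1B", 2048), ("Llama-3.1-8B", 4096)]:
    preserved = int(0.3 * hidden_size)
    print(f"{name}: ~{preserved} of {hidden_size} directions preserved at effective_rank=0.3")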

When I run the same experiment with --learning_rate=5e-5, I get the following:

================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================

--------------------------------------------------------------------------------
DETAILED RESULTS (Accuracy %)
--------------------------------------------------------------------------------
Task            After Task      OSF Acc %       Full FT Acc %   Difference     
--------------------------------------------------------------------------------
ScienceQA       ScienceQA       100.00          39.50                    +60.50  (OSF better)
ScienceQA       NumGLUE         99.50           100.00                    -0.50
ScienceQA       FOMC            100.00          39.50                    +60.50  (OSF better)
NumGLUE         NumGLUE         54.50           17.00                    +37.50  (OSF better)
NumGLUE         FOMC            53.50           55.00                     -1.50
FOMC            FOMC            28.77           28.77                     +0.00

================================================================================
SUMMARY METRICS
================================================================================

1. Average Accuracy Across All 3 Tasks (After Final Task):
   OSF:     60.76%
   Full FT: 41.09%
   Difference: +19.67% (OSF better)

2. Average Forgetting (Task 1 & 2):
   Forgetting = Final Accuracy - Initial Accuracy (negative is worse)

   ScienceQA:
     OSF:     +0.00% (initial: 100.00% → final: 100.00%)
     Full FT: +0.00% (initial: 39.50% → final: 39.50%)
     Difference: +0.00% (Full FT better)

   NumGLUE:
     OSF:     -1.00% (initial: 54.50% → final: 53.50%)
     Full FT: +38.00% (initial: 17.00% → final: 55.00%)
     Difference: -39.00% (Full FT better)

   Average Forgetting:
     OSF:     -0.50%
     Full FT: +19.00%
     Difference: -19.50% (Full FT better)

Is this expected? This looks a bit off.

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
@githubnemo (Collaborator):

Suggested change
    eval_dataset=eval_dataset,
    eval_dataset=eval_dataset,
    args=TrainingArguments(
        label_names=["labels"],
    ),

We need this to get an eval loss for PEFT models. See also: #1120 (comment)
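
In context, the fixed evaluation setup might look like the sketch below; output_dir is a placeholder and the surrounding variables come from the script:

# Without explicit label_names, Trainer can't infer the label column for
# a PeftModel, evaluate() then returns no "eval_loss", and the KeyError
# above follows. Passing label_names=["labels"] restores it.
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,  # the OSF-wrapped PeftModel
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    args=TrainingArguments(
        output_dir="./eval_tmp",  # placeholder
        label_names=["labels"],
    ),
)
results = trainer.evaluate()
loss = results["eval_loss"]  # now present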

tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
@githubnemo (Collaborator):

Suggested change
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    model_name, torch_dtype=torch.bfloat16, device_map="auto",

This isn't a requirement, is it?
