Add OSF continual learning example #2897
Conversation
@githubnemo I have added the continual learning example as requested. Could you please review this PR? The example demonstrates OSF on 3 sequential tasks (ScienceQA, NumGLUE, FOMC) with progressive rank allocation and compares against a full fine-tuning baseline.
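At a high level, the loop looks roughly like the sketch below. This is a simplified sketch, not the script itself: the `OSFConfig` class and its `effective_rank` argument are names assumed from the log output further down, and the step that grows the preserved subspace between tasks (the progressive rank allocation) is omitted here.

```python
# Simplified sketch only -- not the example script. `OSFConfig` / `effective_rank`
# are assumed names; the re-allocation of the preserved subspace between tasks
# (progressive rank allocation) is omitted.
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
from peft import OSFConfig, get_peft_model


def run_osf_sequence(model_name, train_datasets, eval_datasets, tasks, preserve=0.3):
    """Train on `tasks` in order; after each task, evaluate on every task seen so far."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

    # Freeze `preserve` of the weight spectrum; updates live in the orthogonal complement.
    model = get_peft_model(base, OSFConfig(effective_rank=preserve))
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    results = {}
    for i, task in enumerate(tasks):
        trainer = Trainer(
            model=model,
            data_collator=collator,
            train_dataset=train_datasets[task],
            args=TrainingArguments(
                output_dir=f"osf_continual_learning_outputs/{task}",
                num_train_epochs=2,
                label_names=["labels"],  # lets Trainer report eval_loss for a PeftModel
            ),
        )
        trainer.train()
        for seen in tasks[: i + 1]:  # re-test earlier tasks to measure forgetting
            results[(task, seen)] = trainer.evaluate(eval_dataset=eval_datasets[seen])
    return results
```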
githubnemo
left a comment
Hey, thanks, this looks very nice!
Just a quick review for now; I tried to run the code and got this exception:
$ python osf_continual_learning.py --model_name meta-llama/Llama-3.2-1B-Instruct --run_baseline
================================================================================
TRAINING WITH OSF (Orthogonal Subspace Fine-tuning)
================================================================================
Loading datasets...
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 9871.85 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 23863.81 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 17309.66 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 10534.10 examples/s]
================================================================================
TASK 1: ScienceQA
Effective Rank: 0.3 (preserving 30%)
================================================================================
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Training on ScienceQA...
{'loss': 6.7117, 'grad_norm': 220.0, 'learning_rate': 4.296875e-06, 'epoch': 0.31}
{'loss': 1.8609, 'grad_norm': 136.0, 'learning_rate': 3.5156250000000003e-06, 'epoch': 0.62}
{'loss': 1.2447, 'grad_norm': 131.0, 'learning_rate': 2.7343750000000004e-06, 'epoch': 0.94}
{'loss': 1.1187, 'grad_norm': 118.5, 'learning_rate': 1.953125e-06, 'epoch': 1.25}
{'loss': 1.0351, 'grad_norm': 117.0, 'learning_rate': 1.1718750000000001e-06, 'epoch': 1.56}
{'loss': 1.0219, 'grad_norm': 117.0, 'learning_rate': 3.90625e-07, 'epoch': 1.88}
{'train_runtime': 22.6765, 'train_samples_per_second': 88.197, 'train_steps_per_second': 2.822, 'train_loss': 2.0924081951379776, 'epoch': 2.0}
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:22<00:00, 2.82it/s]
Evaluating on all tasks after training on ScienceQA:
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:01<00:00, 20.69it/s]
Traceback (most recent call last):
File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 694, in <module>
main()
File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 627, in main
osf_history = train_with_osf(
^^^^^^^^^^^^^^^
File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 364, in train_with_osf
loss, accuracy = evaluate_model(
^^^^^^^^^^^^^^^
File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 199, in evaluate_model
loss = results["eval_loss"]
~~~~~~~^^^^^^^^^^^^^
KeyError: 'eval_loss'
Maybe that's on my side; I'm investigating.
    - [OSF Documentation](../../docs/source/package_reference/osf.md)
    - [PEFT Documentation](https://huggingface.co/docs/peft)
    - [Original Paper](https://arxiv.org/abs/2504.07097)
Suggested change:

    - [Original Paper](https://huggingface.co/papers/2504.07097)
githubnemo
left a comment
I found the culprit and commented the solution above.
I ran the example with Llama-3.2 1B and got these results:
Final full fine-tuning model saved to ./osf_continual_learning_outputs/full_final
================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================
DETAILED RESULTS (Accuracy %)

| Task | After Task | OSF Acc % | Full FT Acc % | Difference |
|---|---|---|---|---|
| ScienceQA | ScienceQA | 41.00 | 45.00 | -4.00 |
| ScienceQA | NumGLUE | 44.50 | 76.50 | -32.00 |
| ScienceQA | FOMC | 51.50 | 81.50 | -30.00 |
| NumGLUE | NumGLUE | 22.00 | 48.50 | -26.50 |
| NumGLUE | FOMC | 17.00 | 50.50 | -33.50 |
| FOMC | FOMC | 28.77 | 28.77 | +0.00 |
================================================================================
SUMMARY METRICS
================================================================================
1. Average Accuracy Across All 3 Tasks (After Final Task):
OSF: 32.42%
Full FT: 53.59%
Difference: -21.17% (Full FT better)
2. Average Forgetting (Task 1 & 2):
Forgetting = Final Accuracy - Initial Accuracy (negative is worse)
ScienceQA:
OSF: +10.50% (initial: 41.00% → final: 51.50%)
Full FT: +36.50% (initial: 45.00% → final: 81.50%)
Difference: -26.00% (Full FT better)
NumGLUE:
OSF: -5.00% (initial: 22.00% → final: 17.00%)
Full FT: +2.00% (initial: 48.50% → final: 50.50%)
Difference: -7.00% (Full FT better)
Average Forgetting:
OSF: +2.75%
Full FT: +19.25%
Difference: -16.50% (Full FT better)
I'm not sure if that's expected; the effective rank is probably smaller here, since the hidden dimensions likely differ.
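For a rough sense of scale, the same preserve fraction maps to quite different absolute ranks depending on hidden size. The back-of-the-envelope below assumes the fraction is applied per weight matrix to the smaller dimension, which is an assumption on my part rather than OSF's documented rule:

```python
# Back-of-the-envelope only; assumes the preserve fraction applies to the
# smaller dimension of each weight matrix (an assumption, not OSF's spec).
for hidden in (2048, 4096):                 # e.g. Llama-3.2-1B vs. a wider model
    frozen = int(0.3 * hidden)              # square projections: min dim == hidden size
    print(f"hidden={hidden}: ~{frozen} frozen directions, {hidden - frozen} trainable")
```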
When I run the same experiment with --learning_rate=5e-5 I get the following:
================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================
DETAILED RESULTS (Accuracy %)

| Task | After Task | OSF Acc % | Full FT Acc % | Difference |
|---|---|---|---|---|
| ScienceQA | ScienceQA | 100.00 | 39.50 | +60.50 (OSF better) |
| ScienceQA | NumGLUE | 99.50 | 100.00 | -0.50 |
| ScienceQA | FOMC | 100.00 | 39.50 | +60.50 (OSF better) |
| NumGLUE | NumGLUE | 54.50 | 17.00 | +37.50 (OSF better) |
| NumGLUE | FOMC | 53.50 | 55.00 | -1.50 |
| FOMC | FOMC | 28.77 | 28.77 | +0.00 |
================================================================================
SUMMARY METRICS
================================================================================
1. Average Accuracy Across All 3 Tasks (After Final Task):
OSF: 60.76%
Full FT: 41.09%
Difference: +19.67% (OSF better)
2. Average Forgetting (Task 1 & 2):
Forgetting = Final Accuracy - Initial Accuracy (negative is worse)
ScienceQA:
OSF: +0.00% (initial: 100.00% → final: 100.00%)
Full FT: +0.00% (initial: 39.50% → final: 39.50%)
Difference: +0.00% (Full FT better)
NumGLUE:
OSF: -1.00% (initial: 54.50% → final: 53.50%)
Full FT: +38.00% (initial: 17.00% → final: 55.00%)
Difference: -39.00% (Full FT better)
Average Forgetting:
OSF: -0.50%
Full FT: +19.00%
Difference: -19.50% (Full FT better)
Is this expected? This looks a bit off.
    trainer = Trainer(
        model=model,
        data_collator=data_collator,
        eval_dataset=eval_dataset,
Suggested change:

        eval_dataset=eval_dataset,
        args=TrainingArguments(
            label_names=["labels"],
        ),
We need this to get an eval loss for PEFT models. See also: #1120 (comment)
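For anyone else hitting the `KeyError: 'eval_loss'` above, the fix in context looks roughly like this (`model`, `data_collator`, and `eval_dataset` are the objects the script already builds; the `output_dir` value here is just a placeholder):

```python
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,                      # the PeftModel wrapping the base causal LM
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    args=TrainingArguments(
        output_dir="./eval_tmp",      # placeholder
        label_names=["labels"],       # PeftModel hides the base model's signature,
                                      # so Trainer can't infer label names and would
                                      # otherwise omit eval_loss from the metrics
    ),
)
metrics = trainer.evaluate()
loss = metrics["eval_loss"]           # present once label_names is set
```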
    tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
Suggested change:

        model_name, torch_dtype=torch.bfloat16, device_map="auto",
This isn't a requirement, or is it?
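For reference, Llama-3.2 checkpoints use transformers' built-in Llama modeling code, so the flag can simply be dropped unless the example is also meant to cover custom-code models. A minimal sketch of the load without it:

```python
import torch
from transformers import AutoModelForCausalLM

# Llama 3.2 maps to the stock LlamaForCausalLM implementation,
# so trust_remote_code is not needed here.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```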
Summary
Adds a runnable continual learning example for the OSF method, as requested in PR #2685.
Changes
- examples/orthogonal_subspace_learning/

Results (2 epochs per task)

Files
- osf_continual_learning.py - Main example script with OSF and baseline training
- utils.py - Dataset loading and formatting utilities for 3 tasks
- README.md - Comprehensive documentation with usage examples

Implementation Details
Addresses feedback from #2685