[bugfix] fix megatron seq_cls lora bridge #7054
Conversation
Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a bug in the Megatron framework's LoRA integration, specifically for sequence classification tasks. The fix ensures that when model weights are saved in PEFT format, the configuration correctly reflects the 'SEQ_CLS' task type, which is essential for accurate model loading and subsequent operations.
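For context on why the task type matters, here is a minimal sketch (assuming a standard `peft` install; the example is illustrative, not code from this PR): `peft` serializes `task_type` into `adapter_config.json`, and loaders consult it to rebuild the right model head.

```python
# Minimal sketch, assuming a standard `peft` install; not code from this PR.
from peft import LoraConfig, TaskType

# task_type is written into adapter_config.json when the adapter is saved;
# a stale value (e.g. CAUSAL_LM for a seq_cls model) breaks correct reloading.
config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16)
print(config.task_type)  # TaskType.SEQ_CLS, serialized as "SEQ_CLS"
```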
Code Review
This pull request addresses a bug in saving LoRA models for sequence classification tasks within the Megatron bridge. The fix correctly sets the task_type in the PEFT configuration, which is essential for the model to be loaded and used correctly for sequence classification. The change is correct and well-targeted. I have one suggestion to improve code maintainability by using an enum from the peft library instead of a hardcoded string.
```python
from swift.llm import get_multimodal_target_regex

peft_config = copy(mg_models[0].peft_config[self._adapter_name])
if args.task_type == 'seq_cls':
    peft_config.task_type = 'SEQ_CLS'
```
For better maintainability and to avoid potential issues with string literals, it's recommended to use the `TaskType` enum from the `peft` library directly. This makes the code more robust to potential changes in the library.

To implement this, you would add `from peft import TaskType` with the other local imports and then change this line to:

```python
peft_config.task_type = TaskType.SEQ_CLS
```
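One detail worth noting, assuming current `peft` behavior: `TaskType` is a `str`-backed enum, so the member compares equal to the raw literal and the saved config is the same either way; the enum's main benefit is guarding against typos.

```python
from peft import TaskType

# TaskType subclasses str, so the enum member equals the raw literal;
# the value written to adapter_config.json is the same either way.
assert TaskType.SEQ_CLS == 'SEQ_CLS'
```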
/gemini review
Code Review
This pull request introduces a bugfix for the Megatron sequence classification LoRA bridge by correctly setting the task_type in the PEFT config during export. It also includes several improvements to the model conversion process, such as copying the args.json file for better portability and skipping precision tests when LoRA adapters are loaded to prevent errors.
My review identifies a potential race condition in swift/megatron/convert.py where a file operation is not guarded in a distributed context, which could lead to issues. I've also suggested a minor improvement in swift/megatron/export/export.py to make logging cleaner in a distributed environment. Overall, the changes are beneficial and address the intended issue.
swift/megatron/convert.py (Outdated)
```python
if os.path.exists(args_path):
    shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
else:
    args.save_args(args.output_dir)
```
In a distributed environment, all ranks will execute this block. `shutil.copy` is not safe when called by multiple processes on the same destination file, which can lead to a race condition. `args.save_args` has an internal check for the master rank, but `shutil.copy` does not. You should guard this file operation with a rank check (e.g., `if dist.get_rank() == 0:`) to ensure it's only executed by a single process.
Suggested change:

```diff
-if os.path.exists(args_path):
-    shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
-else:
-    args.save_args(args.output_dir)
+if dist.get_rank() == 0:
+    if os.path.exists(args_path):
+        shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
+    else:
+        args.save_args(args.output_dir)
```
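To flesh out the pattern, a hedged sketch follows; `save_args_once` is a hypothetical helper (not part of this PR), `dist` is `torch.distributed`, and `args` is assumed to expose `output_dir` and `save_args` as in the diff above. It also covers single-process runs, where no process group is initialized, and adds a barrier so no rank reads the file before it exists.

```python
# Hypothetical helper sketching the rank-0 guard; not code from this PR.
import os
import shutil

import torch.distributed as dist

def save_args_once(args, args_path: str) -> None:
    # In single-process runs torch.distributed is never initialized,
    # so treat that case as the master rank.
    is_master = not dist.is_initialized() or dist.get_rank() == 0
    if is_master:
        if os.path.exists(args_path):
            shutil.copy(args_path, os.path.join(args.output_dir, 'args.json'))
        else:
            args.save_args(args.output_dir)
    if dist.is_initialized():
        dist.barrier()  # no rank proceeds until args.json has been written
```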
swift/megatron/export/export.py

```python
    shutil.copy(args_path, os.path.join(args.save, 'args.json'))
else:
    args.save_args(args.save)
logger.info(f'Successfully saved HF model weights in `{args.save}`.')
```
This log message will be printed by all ranks in a distributed environment, which can lead to cluttered logs. It's better to guard this with a rank check, similar to the log message on line 135 (`logger.info_if(..., cond=is_last_rank())`), to ensure the message is printed only once.
Suggested change:

```diff
-logger.info(f'Successfully saved HF model weights in `{args.save}`.')
+logger.info_if(f'Successfully saved HF model weights in `{args.save}`.', cond=is_last_rank())
```
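For readers unfamiliar with swift's logging utilities, here is a minimal sketch of what such helpers could look like; the actual swift implementations may differ, and `is_last_rank` here is modeled on `torch.distributed`.

```python
# Hedged sketch of conditional, rank-aware logging; not swift's actual code.
import logging

import torch.distributed as dist

def is_last_rank() -> bool:
    # Without an initialized process group there is only one rank.
    if not dist.is_initialized():
        return True
    return dist.get_rank() == dist.get_world_size() - 1

def info_if(logger: logging.Logger, msg: str, cond: bool) -> None:
    # Emit the message only when the condition holds (e.g. on one rank),
    # so multi-rank jobs do not repeat the same line once per process.
    if cond:
        logger.info(msg)
```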
No description provided.