TP + DP training error #2394
Comments
Add the following code after line 448, then:
Thanks for the additional info, but I can't reproduce your error; I get a different one instead. Moreover, you haven't indicated your PyTorch version and you still haven't shared your full stack trace.
Sorry, please install the latest version of transformers directly from source.
Okay, I got it working with the latest transformers install. I also had to change this line from your snippet:

```diff
- model.load_adapter(lora_config)
+ model.add_adapter(lora_config)
```

After that, I could reproduce the error. Digging into this, I saw that transformers creates a TP plan:

```python
{
'layers.*.self_attn.q_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303a10>,
'layers.*.self_attn.k_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303ad0>,
'layers.*.self_attn.v_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303bd0>,
'layers.*.self_attn.o_proj': <torch.distributed.tensor.parallel.style.RowwiseParallel object at 0x71b993303cd0>,
'layers.*.mlp.gate_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303d50>,
'layers.*.mlp.up_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303e50>,
'layers.*.mlp.down_proj': <torch.distributed.tensor.parallel.style.RowwiseParallel object at 0x71b993303f50>
}
```

AFAICT, this plan indicates which layers to target for tensor parallelism, and it is determined on the base model. Therefore, it targets modules such as `layers.*.self_attn.q_proj`, which after adding the LoRA adapter are no longer plain `nn.Linear` layers, hence the error. I don't have a solution for this. Some ideas come to mind, but I'm not very knowledgeable about TP and the transformers integration. Maybe @kwen2501 or @Cyrilvallez can help.
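To make the mismatch concrete, here is a small sketch (the checkpoint name, module path, and LoRA settings are illustrative assumptions, not taken from this thread) showing that a module targeted by the plan stops being a plain `nn.Linear` once the adapter is attached:

```python
# Inspect the type of a plan-targeted module before and after adding LoRA.
import torch.nn as nn
from transformers import AutoModelForCausalLM
from peft import LoraConfig

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint
q_proj = model.model.layers[0].self_attn.q_proj
print(type(q_proj), isinstance(q_proj, nn.Linear))
# <class 'torch.nn.modules.linear.Linear'> True

model.add_adapter(LoraConfig(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]))
q_proj = model.model.layers[0].self_attn.q_proj
print(type(q_proj), isinstance(q_proj, nn.Linear))
# PEFT's lora.Linear wrapper -> False, which is exactly what ColwiseParallel rejects
```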
System Info
- peft: 0.14.1.dev0
- transformers: 4.50.dev0
- accelerate: 1.4.0.dev0
- python: 3.11
- platform: Linux
Who can help?
No response
Information

Tasks

- An officially supported task in the examples folder

Reproduction
After adding the LoRA module to the model, an error occurred:
NotImplementedError: ColwiseParallel currently only support nn.Linear and nn.Embedding
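The reproduction snippet itself is not shown in this thread; based on the discussion above, a setup along these lines (checkpoint, dtype, and LoRA hyperparameters are assumptions) triggers the error:

```python
# Hypothetical reproduction: load a model with a transformers TP plan, then
# attach a LoRA adapter. Launch with e.g.: torchrun --nproc_per_node=2 repro.py
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",   # assumed checkpoint
    tp_plan="auto",              # transformers applies its tensor-parallel plan
    torch_dtype=torch.bfloat16,
)
model.add_adapter(LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]))

# Per the discussion above: once q_proj/v_proj are LoRA wrappers rather than
# nn.Linear, applying the plan to them fails with:
# NotImplementedError: ColwiseParallel currently only support nn.Linear and nn.Embedding
```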
Expected behavior
LoRA module training with TP.