TP + DP training error #2394

Open

iMountTai opened this issue Feb 24, 2025 · 6 comments

@iMountTai

iMountTai commented Feb 24, 2025

System Info

peft: 0.14.1.dev0
transformers: 4.50.dev0
accelerate: 1.4.0.dev0
python: 3.11
linux

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

After adding the LoRA module to the model, an error occurred:
NotImplementedError: ColwiseParallel currently only support nn.Linear and nn.Embedding

Expected behavior

LoRA module training with TP

@iMountTai
Author

iMountTai commented Feb 24, 2025

transformers#36296

@BenjaminBossan
Member

  • What is your base model and which layers are you targeting?
  • Can you show the full stack trace?
  • Ideally, you can share your training code.

@iMountTai
Author

Add the following code after line 448 of run_clm.py:

    from peft import LoraConfig
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        target_modules=["q_proj", "k_proj"],
    )
    model.load_adapter(lora_config)
    model.enable_adapters()

then

torchrun --nnodes 1 --nproc_per_node 2 --master_port 27654 run_clm.py \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--do_train \
--do_eval \
--tp_size 2 \
--max_grad_norm -1 \
--output_dir /tmp/test-clm

@BenjaminBossan

@BenjaminBossan
Member

Thanks for the additional info, but I can't reproduce your error. Instead, I get:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/name/work/forks/peft/notebooks/run_clm.py", line 671, in <module>
[rank0]:     main()
[rank0]:   File "/home/name/work/forks/peft/notebooks/run_clm.py", line 257, in main
[rank0]:     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
[rank0]:                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/name/work/forks/transformers/src/transformers/hf_argparser.py", line 348, in parse_args_into_dataclasses
[rank0]:     raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
[rank0]: ValueError: Some specified arguments are not used by the HfArgumentParser: ['--tp_size', '2']
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/name/work/forks/peft/notebooks/run_clm.py", line 671, in <module>
[rank1]:     main()
[rank1]:   File "/home/name/work/forks/peft/notebooks/run_clm.py", line 257, in main
[rank1]:     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
[rank1]:                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/name/work/forks/transformers/src/transformers/hf_argparser.py", line 348, in parse_args_into_dataclasses
[rank1]:     raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
[rank1]: ValueError: Some specified arguments are not used by the HfArgumentParser: ['--tp_size', '2']

Note that .devX versions are not fixed versions but may change over time. Ideally, you could use a release version of PEFT, accelerate, and transformers. Also, I noticed that your transformers version is quite old by now; could you please upgrade it to the latest version?

Moreover, you haven't indicated your PyTorch version and still haven't shared your full stack trace.

@iMountTai
Author

iMountTai commented Feb 25, 2025

Sorry, please install the latest version of transformers directly with pip install git+https://github.com/huggingface/transformers. The same goes for the other libraries; my version information above is wrong.

@BenjaminBossan
Member

Okay, I got it working with the latest transformers install. I also had to change this line from your snippet:

-   model.load_adapter(lora_config)
+   model.add_adapter(lora_config)

After that, I could reproduce the error.

Digging into this, I saw that transformers creates a tp_plan for the model (I'm not exactly sure how it's determined):

{
'layers.*.self_attn.q_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303a10>, 
'layers.*.self_attn.k_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303ad0>, 
'layers.*.self_attn.v_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303bd0>, 
'layers.*.self_attn.o_proj': <torch.distributed.tensor.parallel.style.RowwiseParallel object at 0x71b993303cd0>, 
'layers.*.mlp.gate_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303d50>, 
'layers.*.mlp.up_proj': <torch.distributed.tensor.parallel.style.ColwiseParallel object at 0x71b993303e50>, 
'layers.*.mlp.down_proj': <torch.distributed.tensor.parallel.style.RowwiseParallel object at 0x71b993303f50>
}

AFAICT, this plan indicates which layers to target for tensor parallelism and it is determined on the base model. Therefore, it targets layers such as "layers.*.self_attn.q_proj". However, when we apply LoRA, the q_proj layer will be replaced by a lora.Linear layer from PEFT, which is probably why we get the error that the type is not supported. The original nn.Linear layer is wrapped by the lora.Linear layer and would be accessible via q_proj.base_layer.
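
For illustration, here is a minimal sketch of the name clash (not the reporter's script; it uses get_peft_model instead of add_adapter, which should wrap the layers the same way):

# Minimal sketch: show which module the tp_plan key "layers.*.self_attn.q_proj"
# resolves to before and after applying LoRA.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
attn = model.model.layers[0].self_attn
print(type(attn.q_proj))  # torch.nn.Linear -> supported by ColwiseParallel

lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.0, target_modules=["q_proj", "k_proj"])
peft_model = get_peft_model(model, lora_config)

# The same attribute name now resolves to the LoRA wrapper, which ColwiseParallel
# rejects; the original nn.Linear sits one level deeper at q_proj.base_layer.
attn = peft_model.base_model.model.model.layers[0].self_attn
print(type(attn.q_proj))             # peft.tuners.lora.layer.Linear -> not supported
print(type(attn.q_proj.base_layer))  # torch.nn.Linear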

I don't have a solution for this. Some ideas that come to mind:

  1. Allowing the user to pass the tp_plan manually. The user would need to figure out the correct layer names as explained above (see the sketch after this list).
  2. The tp_plan creation would need to happen at a delayed stage, after add_adapter is called, and then automatically determine that it should target the base_layer.
  3. Instead of delaying the creation of the tp_plan, perhaps it's sufficient to re-create it after add_adapter is called.
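
As a rough illustration of idea 1, here is a hypothetical sketch of a hand-written plan applied with plain torch.distributed APIs. The only point is the ".base_layer" suffix in the keys; the layouts, device mesh shape, and the per-layer loop are assumptions, not something transformers currently exposes:

# Hypothetical sketch of idea 1: a manual TP plan whose keys target the wrapped
# nn.Linear modules (".base_layer") instead of the lora.Linear wrappers.
# peft_model is the PEFT-wrapped model from the sketch above; 2 GPUs assumed.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

manual_plan = {
    "self_attn.q_proj.base_layer": ColwiseParallel(),
    "self_attn.k_proj.base_layer": ColwiseParallel(),
    "self_attn.v_proj.base_layer": ColwiseParallel(),
    "self_attn.o_proj.base_layer": RowwiseParallel(),
    "mlp.gate_proj.base_layer": ColwiseParallel(),
    "mlp.up_proj.base_layer": ColwiseParallel(),
    "mlp.down_proj.base_layer": RowwiseParallel(),
}

device_mesh = init_device_mesh("cuda", (2,))
for layer in peft_model.base_model.model.model.layers:
    parallelize_module(layer, device_mesh, manual_plan)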

I'm not very knowledgeable about TP and the transformers integration. Maybe @kwen2501 or @Cyrilvallez can help.
