Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MQA Implementation for 2B models #114

Merged
merged 6 commits into from
Mar 22, 2024
Merged

Conversation

ufownl
Copy link
Contributor

@ufownl ufownl commented Mar 20, 2024

This PR implements "Multi-Query Attention" for the 2B models and modifies vocabulary size to be the same as gemma_pytorch (mentioned in #103). It works fine with weights converted from gemma_pytorch but will lead to the original gemma.cpp weights are unusable.

It needs more testing, and I'll use it to test the fine-tuned weights.

Copy link
Collaborator

@austinvhuang austinvhuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much, MQA is one of the most important low-hanging fruit to implement right now! Looks pretty good overall, have a look at the comment about avoiding branching.

Tagging @pculliton to check the model exporting + vocab size change and @jan-wassenberg on any perf suggestions.

@ufownl ufownl changed the base branch from experimental to dev March 20, 2024 15:51
Copy link
Member

@jan-wassenberg jan-wassenberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thank you :) Some small suggestions:

Copy link
Collaborator

@austinvhuang austinvhuang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, if the performance looks good/better (I'm curious how much) and generation looks correct + @jan-wassenberg LGTMs can probably move forward with merging to dev.

@ufownl
Copy link
Contributor Author

ufownl commented Mar 21, 2024

I tested the weights converted from gemma_pytorch (2b-it and 7b-it) and the generation looks fine.

Copy link
Member

@jan-wassenberg jan-wassenberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice use of lambdas! Thanks for making the change.

@austinvhuang austinvhuang added the copybara-import Trigger Copybara for merging pull requests label Mar 22, 2024
@copybara-service copybara-service bot merged commit fcf5c1a into google:dev Mar 22, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
copybara-import Trigger Copybara for merging pull requests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants