feat: add mpnet model family #405
base: main
Conversation
lib/bumblebee/text/mpnet.ex
```diff
@@ -0,0 +1,458 @@
defmodule Bumblebee.Text.MPNet do
```
Note: this was copied from the Bert implementation. A few adjustments in the options, but that was about it.
test/bumblebee/text/mpnet_test.exs
```elixir
test ":for_masked_language_modeling" do
  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetForMaskedLM"})
```
test/bumblebee/text/mpnet_test.exs
```elixir
assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([
    [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]]
  ])
)
```
We want to compare against the reference values from hf/transformers:
```python
from transformers import MPNetModel
import torch

model = MPNetModel.from_pretrained("hf-internal-testing/tiny-random-MPNetModel")

inputs = {
    "input_ids": torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
print(outputs.last_hidden_state[:, 1:4, 1:4])
#=> torch.Size([1, 10, 64])
#=> tensor([[[ 0.0033, -0.2547,  0.4954],
#=>          [-1.5348, -1.5433,  0.4846],
#=>          [ 0.7795, -0.3995, -0.9499]]], grad_fn=<SliceBackward0>)
```
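For the Bumblebee side, a minimal sketch of the equivalent check, reusing the expected values printed above (this assumes the same test helpers as the snippets earlier in this PR, e.g. `assert_all_close`, and the usual `Axon.predict/3` call):

```elixir
# Sketch: mirrors the Python reference above; helper names are assumed
# to match the existing test conventions in this PR.
assert {:ok, %{model: model, params: params, spec: _spec}} =
         Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetModel"})

inputs = %{
  "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
  "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 64}

assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([
    [[0.0033, -0.2547, 0.4954], [-1.5348, -1.5433, 0.4846], [0.7795, -0.3995, -0.9499]]
  ])
)
```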
I believe there are a few differences between MPNet and BERT, so we need to align the implementation accordingly. In particular, from a quick look some layer names differ, for example `key` -> `k`, `value` -> `v`, `query` -> `q`, so we need to update the layer mapping as well :)
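As a rough sketch, those renames might show up in the params mapping like this; the left-hand keys follow the Bert `params_mapping` and the right-hand HF paths are an assumption based on the renames above, so they should be verified against modeling_mpnet.py:

```elixir
# Sketch only: HF parameter paths are assumed from the key -> k,
# value -> v, query -> q renames noted above.
def params_mapping(_spec) do
  %{
    "encoder.blocks.{n}.self_attention.query" => "encoder.layer.{n}.attention.attn.q",
    "encoder.blocks.{n}.self_attention.key" => "encoder.layer.{n}.attention.attn.k",
    "encoder.blocks.{n}.self_attention.value" => "encoder.layer.{n}.attention.attn.v"
    # ...remaining entries as in the Bert mapping
  }
end
```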
Thanks a ton! Regarding your last comment, are you looking here?
Yes! And the Bert implementation is https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py, which may be helpful for differences.
ok made another round of improvements! thx for the direction.
lib/bumblebee/text/mpnet.ex
```diff
@@ -0,0 +1,458 @@
defmodule Bumblebee.Text.MPNet do
```
Nitpick: I think we want the name to be `MpNet` to align with our naming conventions. Basically, in acronyms we capitalize only the first letter, as in BERT -> `Bert`, RoBERTa -> `Roberta`. And we capitalize each word, as in `ResNet`, `ConvNext`. We do this because the reference names are often arbitrarily capitalized, and it's not ergonomic for library users to have to know the exact capitalization.
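Concretely, under that convention the module definition would become (illustrative only; the body is unchanged):

```elixir
# Renamed per the naming convention above
defmodule Bumblebee.Text.MpNet do
  # ...
end
```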
References:
- MPNet paper: https://arxiv.org/pdf/2004.09297
- Hugging Face model cards:
  - https://huggingface.co/microsoft/mpnet-base
  - https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1