rename: padding_strategy -> padding_side & support specifying padding
SeanLee97 committed Aug 23, 2024
1 parent a9721ca commit 4224d13
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions angle_emb/angle.py
@@ -249,13 +249,13 @@ def check_llm(model_name_or_path: str, llm_regex_patterns: List[str] = None) ->
 def get_pooling(outputs: torch.Tensor,
                 inputs: Dict,
                 pooling_strategy: str,
-                padding_strategy: str) -> torch.Tensor:
+                padding_side: str) -> torch.Tensor:
     """ Pooling the model outputs.

     :param outputs: torch.Tensor. Model outputs (without pooling)
     :param inputs: Dict. Model inputs
     :param pooling_strategy: str. Pooling strategy ['cls', 'cls_avg', 'cls_max', 'last', 'avg', 'max', 'all', index]
-    :param padding_strategy: str. Padding strategy of tokenizers (`left` or `right`).
+    :param padding_side: str. Padding side of tokenizers (`left` or `right`).
         It can be obtained by `tokenizer.padding_side`.
     """
     if pooling_strategy == 'cls':
@@ -269,7 +269,7 @@ def get_pooling(outputs: torch.Tensor,
         outputs = (outputs[:, 0] + maximum) / 2.0
     elif pooling_strategy == 'last':
         batch_size = inputs['input_ids'].shape[0]
-        sequence_lengths = -1 if padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1
+        sequence_lengths = -1 if padding_side == 'left' else inputs["attention_mask"].sum(dim=1) - 1
         outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths]
     elif pooling_strategy == 'avg':
         outputs = torch.sum(
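
For the `last` strategy, `padding_side` decides where the final real token lives: with left padding it is always the final slot, with right padding its index must be recovered from the attention mask. A minimal illustration (not part of the commit; tensor shapes are made up):

    import torch

    # two sequences of lengths 3 and 2, right-padded to length 4
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    hidden = torch.randn(2, 4, 8)  # (batch, seq_len, hidden_dim)

    padding_side = 'right'
    if padding_side == 'left':
        last = hidden[:, -1]  # left padding: last real token sits at index -1
    else:
        lengths = attention_mask.sum(dim=1) - 1  # -> tensor([2, 1])
        last = hidden[torch.arange(hidden.shape[0]), lengths]
    print(last.shape)  # torch.Size([2, 8])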
@@ -691,16 +691,16 @@ class Pooler:

     :param model: PreTrainedModel
     :param pooling_strategy: Optional[str]. Currently support [`cls`, `last`, `avg`, `cls_avg`, `max`]. Default None.
-    :param padding_strategy: Optional[str]. `left` or `right`. Default None.
+    :param padding_side: Optional[str]. `left` or `right`. Default None.
     :param is_llm: bool. Default False
     """
     def __init__(self,
                  model: PreTrainedModel,
                  pooling_strategy: Optional[Union[int, str]] = None,
-                 padding_strategy: Optional[str] = None):
+                 padding_side: Optional[str] = None):
         self.model = model
         self.pooling_strategy = pooling_strategy
-        self.padding_strategy = padding_strategy
+        self.padding_side = padding_side

     def __call__(self,
                  inputs: Dict,
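
A hedged usage sketch of the renamed keyword (illustrative only; it assumes the remaining `__call__` parameters are optional and that `Pooler` runs the model internally):

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')

    pooler = Pooler(model,
                    pooling_strategy='cls',
                    padding_side=tokenizer.padding_side)  # was padding_strategy=
    inputs = tokenizer(['hello world'], return_tensors='pt')
    embeddings = pooler(inputs)  # pooled sentence embedding(s)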
@@ -728,7 +728,7 @@ def __call__(self,
             outputs = all_layer_outputs[layer_index]
             outputs = get_pooling(outputs, inputs,
                                   pooling_strategy or self.pooling_strategy,
-                                  padding_strategy=self.padding_strategy)
+                                  padding_side=self.padding_side)
             n_dim = len(outputs.shape)
             if embedding_start is not None:
                 if n_dim == 2:
@@ -787,7 +787,7 @@ def __init__(self,
         self.teacher_pooler = Pooler(
             teacher_backbone,
             pooling_strategy=self.teacher_pooling_strategy,
-            padding_strategy=self.pooler.padding_strategy)
+            padding_side=self.pooler.padding_side)
         logger.info(f'Train with teacher={teacher_name_or_path}')

     def compute_distillation_loss(self,
@@ -840,7 +840,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
         all_outputs = all_layer_outputs[-1]
         outputs = get_pooling(all_outputs, inputs,
                               self.pooler.pooling_strategy,
-                              self.pooler.padding_strategy)
+                              self.pooler.padding_side)
         loss = self.loss_fct(labels, outputs)
         if self.teacher_name_or_path is not None:
             with torch.no_grad():
@@ -910,7 +910,7 @@ def compute_student_loss(self,
                              all_layer_outputs: torch.Tensor,
                              labels: torch.Tensor,
                              pooling_strategy: str,
-                             padding_strategy: str) -> torch.Tensor:
+                             padding_side: str) -> torch.Tensor:
         loss = 0.
         compression_loss = 0.
         for i in range(self.n_layers - 1):
@@ -919,7 +919,7 @@ def compute_student_loss(self,
             student_outputs = get_pooling(all_student_outputs,
                                           inputs,
                                           pooling_strategy,
-                                          padding_strategy)
+                                          padding_side)

             slimmed_outputs = student_outputs[:, :self.ese_compression_size]
             loss += self.loss_fct(labels, slimmed_outputs) / division
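
The slicing above follows the Espresso/matryoshka idea: the first `ese_compression_size` dimensions of each pooled embedding are trained to stand alone as a smaller embedding. A toy sketch of that objective (stand-in loss; the actual training uses `self.loss_fct`):

    import torch
    import torch.nn.functional as F

    embeddings = torch.randn(4, 768)   # pooled student outputs
    targets = torch.randn(4, 768)      # stand-in supervision signal
    compression_size = 128             # hypothetical ese_compression_size

    slimmed = embeddings[:, :compression_size]  # keep only the leading dims
    loss = F.mse_loss(slimmed, targets[:, :compression_size])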
@@ -951,7 +951,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
         all_teacher_outputs = all_layer_outputs[-1]
         teacher_outputs = get_pooling(all_teacher_outputs, inputs,
                                       self.pooler.pooling_strategy,
-                                      self.pooler.padding_strategy)
+                                      self.pooler.padding_side)

         loss = self.loss_fct(labels, teacher_outputs)

@@ -970,7 +970,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             all_layer_outputs,
             labels,
             self.pooler.pooling_strategy,
-            self.pooler.padding_strategy,
+            self.pooler.padding_side,
         )

         # alignment loss
@@ -1301,7 +1301,7 @@ def __init__(self,
         self.pooler = Pooler(
             self.backbone,
             pooling_strategy=self.pooling_strategy,
-            padding_strategy=self.tokenizer.padding_side)
+            padding_side=self.tokenizer.padding_side)

         self.__cfg = {
             'model_name_or_path': model_name_or_path,
@@ -1432,7 +1432,8 @@ def fit(self,
             push_to_hub: bool = False,
             hub_model_id: Optional[str] = None,
             hub_private_repo: bool = True,
-            coword_random_mask_rate: float = 0.):
+            coword_random_mask_rate: float = 0.,
+            padding: str = 'longest'):
         """
         Fit using AnglE.
@@ -1459,7 +1460,8 @@ def fit(self,
         :param push_to_hub: bool, whether push to hub.
         :param hub_model_id: Optional[str], hub model id.
         :param hub_private_repo: bool, whether push to private repo.
-        :param coword_random_mask_rate: float, random mask common token rate. Default 0..
+        :param coword_random_mask_rate: float, random mask common token rate. Default 0.
+        :param padding: str, padding strategy of tokenizer. Default 'longest'.
         """  # NOQA
         if output_dir is not None:
             os.makedirs(output_dir, exist_ok=True)
@@ -1541,6 +1543,7 @@ def fit(self,
             callbacks=callbacks,
             data_collator=AngleDataCollator(
                 self.tokenizer,
+                padding=padding,
                 return_tensors="pt",
                 max_length=self.max_length,
                 filter_duplicate=filter_duplicate,
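
With `padding` now forwarded to `AngleDataCollator`, training can choose between bucketed and fixed-length batches. A hypothetical call (`train_ds` and the output path are assumed from the usual `fit()` signature, which this diff truncates):

    angle.fit(
        train_ds=train_ds,
        output_dir='ckpts/my-angle',  # illustrative path
        padding='longest',            # new; 'max_length' pads every batch to max_length
    )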
@@ -1592,7 +1595,8 @@ def encode(self,
                embedding_size: Optional[int] = None,
                device: Optional[Any] = None,
                prompt: Optional[str] = None,
-               normalize_embedding: bool = False):
+               normalize_embedding: bool = False,
+               padding: str = 'longest'):
         """
         encode texts.
@@ -1605,6 +1609,7 @@ def encode(self,
         :param device: Optional[Any]. Default None.
         :param prompt: Optional[str]. Default None.
         :param normalize_embedding: bool. Default False.
+        :param padding: str. Padding strategy of tokenizer. Default 'longest'.
         """
         self.backbone.eval()

@@ -1628,11 +1633,11 @@ def encode(self,
                 max_length=max_length or self.max_length,
                 truncation=True)
             tok['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in tok['input_ids']]
-            tok = self.tokenizer.pad(tok, padding=True, return_attention_mask=True, return_tensors='pt')
+            tok = self.tokenizer.pad(tok, padding=padding, return_attention_mask=True, return_tensors='pt')
         else:
             tok = self.tokenizer(
                 inputs,
-                padding='longest',
+                padding=padding,
                 max_length=max_length or self.max_length,
                 truncation=True,
                 return_tensors='pt')
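
On the inference side, the same `padding` knob reaches both tokenization paths (the LLM path that appends an EOS token before `tokenizer.pad`, and the plain `tokenizer(...)` path). A usage sketch, assuming the library's documented `AnglE.from_pretrained` entry point (model id illustrative):

    from angle_emb import AnglE

    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')
    vecs = angle.encode(['hello world', 'angle embeddings'],
                        padding='longest',  # default; 'max_length' gives fixed-size batches
                        normalize_embedding=True)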
