improve 2dmse; filter duplicate
SeanLee97 committed Feb 29, 2024
1 parent 05a2c59 commit 347bbb6
Showing 1 changed file with 66 additions and 40 deletions.
106 changes: 66 additions & 40 deletions angle_emb/angle.py
@@ -4,6 +4,7 @@
import re
import sys
import json
+import math
import random
from functools import partial
from typing import Any, Dict, Optional, List, Union, Tuple, Callable
@@ -581,12 +582,13 @@ class AngleDataCollator:
:param padding: Union[bool, str, PaddingStrategy], padding strategy
:param max_length: Optional[int], max length
:param return_tensors: str
+    :param filter_duplicate: bool. Whether to filter duplicate data
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = 'longest'
max_length: Optional[int] = None
return_tensors: str = "pt"
+    filter_duplicate: bool = True

def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:
if return_tensors is None:
@@ -595,6 +597,7 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:
end_with_eos = features[0]['extra']['end_with_eos']

new_features = []
+        duplicate_set = set()
for feature in features:
seperate_ids = feature['seperate_ids']
input_ids = feature['input_ids']
@@ -609,26 +612,41 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:

max_seperate_id = max(seperate_ids)
prev_start_idx = 0
+            current_features = []
+            is_duplicate = False
for seperate_id in range(1, max_seperate_id + 1):
start_idx = seperate_ids.index(seperate_id)

new_feature = {}
-                new_feature['input_ids'] = input_ids[prev_start_idx:start_idx]
+                new_input_ids = input_ids[prev_start_idx:start_idx]
+                if tuple(new_input_ids) in duplicate_set:
+                    is_duplicate = True
+                    if self.filter_duplicate:
+                        break
+                duplicate_set.add(tuple(new_input_ids))
+                new_feature['input_ids'] = new_input_ids
new_feature['attention_mask'] = attention_mask[prev_start_idx:start_idx]
if has_token_type_ids:
new_feature['token_type_ids'] = token_type_ids[prev_start_idx:start_idx]
new_feature['labels'] = feature['labels']
-                new_features.append(new_feature)
+                current_features.append(new_feature)
prev_start_idx = start_idx

# last
new_feature = {}
-            new_feature['input_ids'] = input_ids[prev_start_idx:]
+            new_input_ids = input_ids[prev_start_idx:]
+            if tuple(new_input_ids) in duplicate_set:
+                is_duplicate = True
+            duplicate_set.add(tuple(new_input_ids))
+            new_feature['input_ids'] = new_input_ids
new_feature['attention_mask'] = attention_mask[prev_start_idx:]
if has_token_type_ids:
new_feature['token_type_ids'] = token_type_ids[prev_start_idx:]
new_feature['labels'] = feature['labels']
-            new_features.append(new_feature)
+            current_features.append(new_feature)

+            if self.filter_duplicate and is_duplicate:
+                continue
+            new_features += current_features

# remove features
del features
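
Taken together, the collator change above splits each packed feature back into its segments, keys every segment by `tuple(input_ids)` in a batch-level `duplicate_set`, and, when `filter_duplicate` is on, drops the whole sample as soon as one of its segments has already been seen. A minimal standalone sketch of the same idea (the function name and data shape are illustrative, not the library's API):

# Minimal sketch of the duplicate filter above: token-id segments are
# hashed as tuples; a sample is dropped if any segment repeats in the batch.
def drop_duplicate_samples(batch):
    seen = set()
    kept = []
    for segments in batch:  # each sample is a list of token-id lists
        keys = [tuple(ids) for ids in segments]
        if any(k in seen for k in keys):
            continue  # mirrors the `continue` on is_duplicate above
        seen.update(keys)
        kept.append(segments)
    return kept

# e.g. drop_duplicate_samples([[[101, 7], [102]], [[101, 7], [9]]]) keeps only the first sample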
@@ -685,13 +703,17 @@ def __init__(self,
self.padding_strategy = padding_strategy
self.is_llm = is_llm

-    def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None) -> torch.Tensor:
+    def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None,
+                 return_all_layer_outputs: bool = False) -> torch.Tensor:
"""
:param inputs: Dict. Model inputs.
:param layer_index: int. Get embeddings from specific layer.
:param embedding_size: int. Set embedding size for sentence embeddings for 2DMSE models.
"""
-        outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states[layer_index]
+        all_layer_outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states
+        if return_all_layer_outputs:
+            return all_layer_outputs
+        outputs = all_layer_outputs[layer_index]
if self.is_llm:
batch_size = inputs['input_ids'].shape[0]
sequence_lengths = -1 if self.padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1
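
The new `return_all_layer_outputs` flag lets one forward pass serve both the teacher (last layer) and the student (a randomly sampled shallow layer), where the previous code ran the model once per requested layer. A hedged usage sketch, assuming a constructed `pooler`, tokenized `inputs`, and a known `n_layers`:

import random

# Sketch: a single forward pass yields every hidden state; the teacher and
# student layers are then sliced from the same list (assumes `pooler`,
# `inputs`, and `n_layers` exist as in the trainer below).
sample_layer = random.randint(1, n_layers - 1)
all_layer_outputs = pooler(inputs, layer_index=-1, return_all_layer_outputs=True)
teacher_hidden = all_layer_outputs[-1]            # final layer
student_hidden = all_layer_outputs[sample_layer]  # sampled shallow layer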
@@ -802,46 +824,48 @@ def __init__(self,
self.tdmse_student_lambda = tdmse_student_lambda
self.apply_tdmse_kl = apply_tdmse_kl
self.n_layers = self.pooler.model.config.num_hidden_layers
-        self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.pooler.model.config.hidden_size)
+        self.hidden_size = self.pooler.model.config.hidden_size
+        self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.hidden_size)
self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        logger.info('Train 2DMSE!')
+        logger.info('Train with 2DMSE!')

def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels", None)
# layer
sample_layer = random.randint(1, self.n_layers - 1)
-        if self.fixed_teacher_name_or_path is not None:
-            all_teacher_outputs = self.pooler(inputs, layer_index=-1)
-            teacher_outputs = get_pooling(all_teacher_outputs, inputs,
-                                          self.alignment_pooling_strategy,
-                                          self.pooler.padding_strategy)
-            all_student_outputs = self.pooler(inputs, layer_index=sample_layer)
-            student_outputs = get_pooling(all_student_outputs, inputs,
-                                          self.alignment_pooling_strategy,
-                                          self.pooler.padding_strategy)
-        else:
-            teacher_outputs = self.pooler(inputs, layer_index=-1)
-            student_outputs = self.pooler(inputs, layer_index=sample_layer)
-
-        kl_outputs = teacher_outputs
+        pooling_strategy = (self.alignment_pooling_strategy
+                            if self.pooler.pooling_strategy == 'all'
+                            else self.pooler.pooling_strategy)
+        all_layer_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True)
+        all_teacher_outputs = all_layer_outputs[-1]
+        teacher_outputs = get_pooling(all_teacher_outputs, inputs,
+                                      pooling_strategy,
+                                      self.pooler.padding_strategy)
+        all_student_outputs = all_layer_outputs[sample_layer]
+        student_outputs = get_pooling(all_student_outputs,
+                                      inputs,
+                                      pooling_strategy,
+                                      self.pooler.padding_strategy)
+
+        teacher_kl_outputs = teacher_outputs
if self.fixed_teacher_name_or_path is not None:
with torch.no_grad():
self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device)
all_fixed_outputs = self.fixed_teacher_pooler(inputs)
-                kl_outputs = get_pooling(all_fixed_outputs, inputs,
-                                         self.alignment_pooling_strategy,
-                                         self.pooler.padding_strategy)
+                teacher_kl_outputs = get_pooling(all_fixed_outputs,
+                                                 inputs,
+                                                 self.alignment_pooling_strategy,
+                                                 self.pooler.padding_strategy)

teacher_loss = self.loss_fct(labels, teacher_outputs)
-        loss1 = self.tdmse_teacher_lambda * teacher_loss
-        if self.tdmse_student_lambda > 0:
-            student_loss = self.loss_fct(labels, student_outputs)
-            loss1 += self.tdmse_student_lambda * student_loss
+        loss1 = teacher_loss
+        student_loss = self.loss_fct(labels, student_outputs)
+        loss1 += student_loss / sample_layer
if self.apply_tdmse_kl and self.tdmse_student_lambda > 0:
kl_loss = self.kl_loss_fct(
-                F.log_softmax(student_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1),
-                F.softmax(kl_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1)
-            ) * self.tdmse_kl_temperature**2
+                F.log_softmax(student_outputs / self.tdmse_kl_temperature, dim=-1),
+                F.softmax(teacher_kl_outputs / self.tdmse_kl_temperature, dim=-1)
+            ) * self.tdmse_kl_temperature * math.log(2 + sample_layer)
loss1 += kl_loss

# feature
@@ -850,10 +874,10 @@ def compute_loss(self, model, inputs, return_outputs=False):
slimmed_student_outputs = student_outputs[:, :hidden_size]

slimmed_teacher_loss = self.loss_fct(labels, slimmed_teacher_outputs)
-        loss2 = self.tdmse_teacher_lambda * slimmed_teacher_loss
-        if self.tdmse_student_lambda > 0:
-            slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs)
-            loss2 += self.tdmse_student_lambda * slimmed_student_loss
+        loss2 = slimmed_teacher_loss
+        slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs)
+        loss2 += slimmed_student_loss / sample_layer

loss = loss1 + loss2

if self.fixed_teacher_name_or_path is not None:
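
Net effect of the loss rewrite: the fixed `tdmse_teacher_lambda`/`tdmse_student_lambda` weights are gone. The student term is now divided by `sample_layer`, so losses from deeper sampled layers are down-weighted, and the KL term is scaled by `temperature * log(2 + sample_layer)` instead of `temperature**2`. A sketch of the combined objective under those assumptions (`loss_fct`, `kl_div`, the pooled outputs, and the slimmed width `h` are stand-ins for the trainer's own members, not verbatim trainer code):

import math

# Illustrative sketch of the updated 2DMSE objective.
def tdmse_loss(loss_fct, kl_div, labels, teacher, student,
               sample_layer, h, temperature=1.0, apply_kl=True):
    # full-width term: teacher loss plus depth-discounted student loss
    loss1 = loss_fct(labels, teacher) + loss_fct(labels, student) / sample_layer
    if apply_kl:
        # KL weight grows slowly with the sampled depth
        loss1 = loss1 + kl_div(student, teacher) * temperature * math.log(2 + sample_layer)
    # slimmed-width term: the same recipe on the first h embedding dimensions,
    # with h drawn from the geometric ladder (8, 16, ..., hidden_size)
    loss2 = loss_fct(labels, teacher[:, :h]) + loss_fct(labels, student[:, :h]) / sample_layer
    return loss1 + loss2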
@@ -1334,7 +1358,8 @@ def fit(self,
argument_kwargs: Optional[Dict] = None,
trainer_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[Dict] = None,
-            apply_tdmse: bool = False):
+            apply_tdmse: bool = False,
+            filter_duplicate: bool = True):
"""
Fit using AnglE.
@@ -1412,7 +1437,7 @@ def fit(self,
),
callbacks=callbacks,
data_collator=AngleDataCollator(
-                self.tokenizer, return_tensors="pt", max_length=self.max_length
+                self.tokenizer, return_tensors="pt", max_length=self.max_length, filter_duplicate=filter_duplicate
),
**trainer_kwargs
)
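
Callers reach the new behaviour through `fit`, which forwards `filter_duplicate` to the training collator (evaluation, below, pins it to `False` so every example is scored). A hedged usage sketch; the checkpoint name and dataset variable are placeholders, and argument names beyond `filter_duplicate` follow the library's `fit` signature as of this commit:

from angle_emb import AnglE

# Hypothetical call: 'WhereIsAI/UAE-Large-V1' and `train_ds` are placeholders.
angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', max_length=128)
angle.fit(train_ds=train_ds, batch_size=32, filter_duplicate=True)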
@@ -1428,6 +1453,7 @@ def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[float]
self.tokenizer,
return_tensors="pt",
max_length=self.max_length,
+            filter_duplicate=False,
)
y_trues, y_preds = [], []
# for X, y in data.make_iter(random=False):