Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor inference optimizations #1094

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.36]

### Changed

- Make the `inference_only` mode switchable.
- Simplify inference optimizations by
(1) using `eval()` to disable dropout instead of explicitly setting dropout modules to None;
(2) always using default value `inplace=False` for activation modules.

## [3.1.35]

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions sockeye/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand All @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.35'
__version__ = '3.1.36'
14 changes: 1 addition & 13 deletions sockeye/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -83,15 +83,3 @@ def copy(self, **kwargs):
for name, value in kwargs.items():
object.__setattr__(copy_obj, name, value)
return copy_obj

def disable_dropout(self):
    """
    Reset every float-valued dropout setting held by this config to 0.0.

    Iterates over the instance attributes: values that are themselves Config objects are processed
    recursively, and any attribute whose name contains 'dropout' and whose value is a float is set to 0.0.
    """
    for name, value in self.__dict__.items():
        if isinstance(value, Config):
            # Nested configs take care of their own dropout attributes.
            value.disable_dropout()
            continue
        if 'dropout' in name and isinstance(value, float):
            logger.debug("Setting %s to 0.0", name)
            setattr(self, name, 0.0)
24 changes: 17 additions & 7 deletions sockeye/decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -93,6 +93,10 @@ def get_decoder(cls,
def __init__(self):
super().__init__()

@abstractmethod
def set_inference_only(self, inference_only: bool):
    """
    Set inference_only. Declared abstract here, so subclasses provide the implementation.

    :param inference_only: New value of the inference_only flag.
    """
    raise NotImplementedError()

@abstractmethod
def state_structure(self) -> str:
raise NotImplementedError()
Expand Down Expand Up @@ -147,7 +151,6 @@ def __init__(self,
Decoder.__init__(self)
pt.nn.Module.__init__(self)
self.config = config
self.inference_only = inference_only
self.pos_embedding = layers.PositionalEmbeddings(weight_type=self.config.positional_embedding_type,
num_embed=self.config.model_size,
max_seq_len=self.config.max_seq_len_target,
Expand All @@ -158,7 +161,7 @@ def __init__(self,

self.layers = pt.nn.ModuleList( # using ModuleList because we have additional inputs
transformer.TransformerDecoderBlock(config,
inference_only=self.inference_only,
inference_only=inference_only,
dtype=dtype,
clamp_to_dtype=clamp_to_dtype)
for _ in range(config.num_layers))
Expand All @@ -168,8 +171,16 @@ def __init__(self,
num_hidden=self.config.model_size,
dtype=dtype,
clamp_to_dtype=clamp_to_dtype)
if self.config.dropout_prepost > 0.0:
self.dropout = pt.nn.Dropout(p=self.config.dropout_prepost, inplace=inference_only)
self.dropout = pt.nn.Dropout(p=self.config.dropout_prepost)
self.set_inference_only(inference_only)

def set_inference_only(self, inference_only: bool):
    """
    Store the inference_only flag on this decoder and forward it to every entry of self.layers.

    :param inference_only: Value to store and to pass on to each layer.
    """
    self.inference_only = inference_only
    for block in self.layers:
        block.set_inference_only(inference_only)

def state_structure(self) -> str:
"""
Expand Down Expand Up @@ -279,8 +290,7 @@ def forward(self, step_input: pt.Tensor, states: List[pt.Tensor]) -> Tuple[pt.Te
# (length, batch_size, model_size)
target = target.transpose(1, 0)

if self.config.dropout_prepost > 0.0:
target = self.dropout(target)
target = self.dropout(target)

new_autoregr_states = [] # type: List[pt.Tensor]
for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers, autoregr_states, enc_att_kv):
Expand Down
6 changes: 3 additions & 3 deletions sockeye/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(self,
self.factor_embeds.append(factor_embed)
self.factor_combinations.append(fc.combine)

self.dropout = pt.nn.Dropout(p=self.config.dropout) if self.config.dropout > 0.0 else None
self.dropout = pt.nn.Dropout(p=self.config.dropout)

def forward(self, data: pt.Tensor) -> pt.Tensor:
primary_data = data[:, :, 0]
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self,
pt.nn.Module.__init__(self)
self.config = config

self.dropout = pt.nn.Dropout(p=config.dropout_prepost) if config.dropout_prepost > 0.0 else None
self.dropout = pt.nn.Dropout(p=config.dropout_prepost)

self.pos_embedding = layers.PositionalEmbeddings(weight_type=self.config.positional_embedding_type,
num_embed=self.config.model_size,
Expand Down
17 changes: 14 additions & 3 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -787,6 +787,8 @@ def __init__(self,
if strip_unknown_words:
self.strip_ids.add(self.unk_id)
self.models = models
for model in self.models:
model.eval()

# after models are loaded we ensured that they agree on max_input_length, max_output_length and batch size
# set a common max_output length for all models.
Expand Down Expand Up @@ -943,8 +945,7 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool =
batch = batch + [batch[0]] * rest

translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch]
with pt.inference_mode():
batch_translations = self._translate_np(*self._get_inference_input(translator_inputs))
batch_translations = self._translate_batch(translator_inputs)

# truncate to remove filler translations
if fill_up_batches and rest > 0:
Expand Down Expand Up @@ -988,6 +989,16 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool =

return results

def _translate_batch(self, translator_inputs: List[TranslatorInput]) -> List[Translation]:
    """
    Translate a batch of inputs inside a pt.inference_mode() context.

    :param translator_inputs: List of TranslatorInputs.
    :return: List of Translation.
    """
    with pt.inference_mode():
        # Build the inference input first, then unpack it into the translation call.
        inference_input = self._get_inference_input(translator_inputs)
        return self._translate_np(*inference_input)

def _get_inference_input(self,
trans_inputs: List[TranslatorInput]) -> Tuple[pt.Tensor,
pt.Tensor,
Expand Down
35 changes: 27 additions & 8 deletions sockeye/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -26,12 +26,12 @@
logger = logging.getLogger(__name__)


def get_activation(act_type: str, inplace: bool = False) -> pt.nn.Module:
def get_activation(act_type: str) -> pt.nn.Module:
if act_type == C.SWISH1:
return pt.nn.SiLU(inplace=inplace)
return pt.nn.SiLU()
if act_type == C.GELU:
return pt.nn.GELU()
return pt.nn.ReLU(inplace=inplace)
return pt.nn.ReLU()


class LHUC(pt.nn.Module):
Expand Down Expand Up @@ -287,7 +287,7 @@ class DotAttentionCell(pt.nn.Module):

def __init__(self, dropout: float = 0.0, heads: int = 1) -> None:
super().__init__()
self.dropout = pt.nn.Dropout(p=dropout) if dropout > 0.0 else None
self.dropout = pt.nn.Dropout(p=dropout)
self.heads = heads

def forward(self,
Expand Down Expand Up @@ -420,6 +420,13 @@ def get_state_shape(self, batch_size) -> Tuple:
"""
raise NotImplementedError

@abstractmethod
def set_inference_only(self, inference_only: bool):
    """
    Set inference_only. Declared abstract here, so subclasses provide the implementation.

    :param inference_only: New value of the inference_only flag.
    """
    raise NotImplementedError()

@abstractmethod
def forward(self, inputs: pt.Tensor, previous_states: pt.Tensor, *args) -> Tuple:
"""
Expand Down Expand Up @@ -461,6 +468,12 @@ def __init__(self,
# Interleaved format is used for inference, non-interleaved format is used for fused MHA in training.
self.kv_interleaved = False

def set_inference_only(self, inference_only: bool):
    """
    Set inference_only. Not needed for MultiHeadSelfAttention.

    Calling this method always raises NotImplementedError.
    """
    raise NotImplementedError()

def separate_kv(self):
""" write kv input projection parameters in non-interleaved format (compatible with F.multi_head_attention) """
assert self.kv_interleaved
Expand Down Expand Up @@ -799,11 +812,9 @@ def __init__(self,
clamp_to_dtype: bool = False,) -> None:
super().__init__()
self.model_size = model_size
self.inference_only = inference_only
self.clamp_to_dtype = clamp_to_dtype

self.cell_state_transform = self._inference_cell_state_transform \
if inference_only else self._training_cell_state_transform
self.set_inference_only(inference_only)

self.forget_gate = pt.nn.Linear(in_features=model_size, out_features=model_size, bias=True, dtype=dtype)
self.forget_gate_act = pt.nn.Sigmoid()
Expand All @@ -812,6 +823,14 @@ def __init__(self,

self.relu = pt.nn.ReLU(inplace=False) # inplace=False because we need the non-activated data as well

def set_inference_only(self, inference_only: bool):
    """
    Set inference_only and select the matching cell state transform.

    :param inference_only: If True, self.cell_state_transform is bound to the inference transform,
                           otherwise to the training transform.
    """
    self.inference_only = inference_only
    if inference_only:
        self.cell_state_transform = self._inference_cell_state_transform
    else:
        self.cell_state_transform = self._training_cell_state_transform

@property
def num_state_tensors(self) -> int:
""" Number of state tensors returned by the layer """
Expand Down
38 changes: 17 additions & 21 deletions sockeye/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -111,7 +111,6 @@ def __init__(self,
super().__init__()
self.config = copy.deepcopy(config)
self.dtype = utils.get_torch_dtype(config.dtype)
self.inference_only = inference_only
self.clamp_to_dtype = clamp_to_dtype
logger.info("%s", self.config)
self.train_decoder_only = train_decoder_only
Expand Down Expand Up @@ -144,10 +143,10 @@ def __init__(self,
vocab_size=self.config.vocab_target_size,
weight=output_weight,
dtype=self.dtype)
if self.inference_only:
# Running this layer scripted with a newly initialized model can
# cause an overflow error.
self.output_layer = pt.jit.script(self.output_layer)
self.output_layer_module_cached = self.output_layer
# Running this layer scripted with a newly initialized model can cause an overflow error.
self.output_layer_script_cached = pt.jit.script(self.output_layer_module_cached)
self.set_inference_only(inference_only)

self.factor_output_layers = pt.nn.ModuleList()
# Optional target factor output layers
Expand Down Expand Up @@ -189,6 +188,14 @@ def __init__(self,

self.knn : Optional[layers.KNN] = None

def set_inference_only(self, inference_only: bool):
    """
    Turn inference_only optimization on or off.

    Stores the flag, binds self.output_layer to the cached scripted output layer when the flag is set
    (to the cached plain module otherwise), and forwards the flag to the decoder.

    :param inference_only: Whether inference-only optimizations should be active.
    """
    self.inference_only = inference_only
    if self.inference_only:
        self.output_layer = self.output_layer_script_cached
    else:
        self.output_layer = self.output_layer_module_cached
    self.decoder.set_inference_only(self.inference_only)

def cast(self, dtype: Union[pt.dtype, str]):
dtype = utils.get_torch_dtype(dtype)
Expand Down Expand Up @@ -417,7 +424,8 @@ def save_parameters(self, fname: str):
# filter their names from the state dictionary to avoid saving redundant
# copies of their parameters. Copies can also cause errors at loadtime
# if the traced modules do not yet exist.
filtered_state_dict = {name: param for (name, param) in self.state_dict().items() if 'traced' not in name}
filtered_state_dict = {name: param for (name, param) in self.state_dict().items()
if 'traced' not in name and 'cached' not in name}
pt.save(filtered_state_dict, fname)
self.apply(layers.separate_kv)
logging.info('Saved params/state_dict to "%s"', fname)
Expand Down Expand Up @@ -445,12 +453,12 @@ def load_parameters(self,
missing, unexpected = self.load_state_dict(state_dict, strict=False)
# Earlier versions of Sockeye may have saved parameters for traced
# modules. These parameters can be safely ignored.
unexpected = [key for key in unexpected if 'traced' not in key]
unexpected = [key for key in unexpected if 'traced' not in key and 'cached' not in key]
# We also ignore cases where traced modules exist and appear to be
# missing parameters. These modules actually use the same parameters as
# their original non-traced versions so there are no separate parameters
# to load.
missing = [key for key in missing if 'traced' not in key]
missing = [key for key in missing if 'traced' not in key and 'cached' not in key]
if not allow_missing:
utils.check_condition(not missing, f"missing keys: {missing}")
if not ignore_extra:
Expand Down Expand Up @@ -706,7 +714,6 @@ def load_model(model_folder: str,
inference_only: bool = False,
train_decoder_only: bool = False,
allow_missing: bool = False,
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0,
knn_index: Optional[str] = None) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]:
"""
Expand All @@ -721,7 +728,6 @@ def load_model(model_folder: str,
:param train_decoder_only: Training will only update the decoder. Disable
autograd for encoder and embeddings to save memory.
:param allow_missing: Allow missing parameters in the loaded model.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
:param knn_index: Optional path to a folder containing a KNN model index.
:return: Loaded model, source vocabularies, target vocabularies.
Expand All @@ -733,10 +739,6 @@ def load_model(model_folder: str,
utils.check_version(model_version)
model_config = SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))

if inference_only:
logger.info("Disabling dropout layers for performance reasons")
model_config.disable_dropout()

if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
else:
Expand All @@ -755,9 +757,6 @@ def load_model(model_folder: str,

model.to(device)

if set_grad_req_null:
model.eval()

if dtype is None:
logger.info("Model dtype: %s" % model.dtype)
else:
Expand All @@ -783,7 +782,6 @@ def load_models(device: pt.device,
inference_only: bool = False,
train_decoder_only: bool = False,
allow_missing: bool = False,
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0,
knn_index: Optional[str] = None) -> Tuple[List[SockeyeModel],
List[vocab.Vocab],
Expand All @@ -800,7 +798,6 @@ def load_models(device: pt.device,
:param train_decoder_only: Training will only update the decoder. Disable
autograd for encoder and embeddings to save memory.
:param allow_missing: Allow missing parameters in the loaded models.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
:param knn_index: Optional path to a folder containing a KNN model index.
:return: List of models, source vocabulary, target vocabulary, source factor vocabularies.
Expand All @@ -825,7 +822,6 @@ def load_models(device: pt.device,
inference_only=inference_only,
train_decoder_only=train_decoder_only,
allow_missing=allow_missing,
set_grad_req_null=set_grad_req_null,
forward_pass_cache_size=forward_pass_cache_size,
knn_index=knn_index)
models.append(model)
Expand Down
Loading
Loading