Skip to content

Commit

Permalink
Use fault tolerant symlink during training to handle temporary file s…
Browse files Browse the repository at this point in the history
…ystem inconsistencies. (#1093)
  • Loading branch information
mjdenkowski authored Aug 10, 2023
1 parent 4c30942 commit 8753d95
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 5 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.35]

### Fixed

- Use fault tolerant symlink wrapper function during training to handle cases of temporary inconsistency in distributed file systems.
- Update DeepSpeed requirement file to specify version (`deepspeed==0.6.5`).

## [3.1.34]

### Fixed
Expand All @@ -19,7 +26,7 @@ Each version section may have subsections for: _Added_, _Changed_, _Removed_, _D

## [3.1.33]

### Fixed
### Fixed
- Two small fixes to SampleK. Before the device was not set correctly leading to issues when running sampling on GPUs. Furthermore, SampleK did not return the top-k values correctly.

## [3.1.32]
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.deepspeed.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
deepspeed
deepspeed==0.6.5
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.34'
__version__ = '3.1.35'
4 changes: 2 additions & 2 deletions sockeye/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _update_best_params(self):
actual_best_params_fname = C.PARAMS_NAME % self.state.best_checkpoint
if os.path.lexists(self.best_params_fname):
os.remove(self.best_params_fname)
os.symlink(actual_best_params_fname, self.best_params_fname)
utils.fault_tolerant_symlink(actual_best_params_fname, self.best_params_fname)
logger.info("'%s' now points to '%s'", self.best_params_fname, actual_best_params_fname)

def _save_params(self, use_checkpoint: bool = False):
Expand Down Expand Up @@ -709,7 +709,7 @@ def _save_training_state(self, train_iter: data_io.BaseParallelSampleIter):
params_file = os.path.join(training_state_dirname, C.TRAINING_STATE_PARAMS_NAME)
if os.path.exists(params_file):
os.unlink(params_file)
os.symlink(os.path.join("..", params_base_fname), params_file)
utils.fault_tolerant_symlink(os.path.join("..", params_base_fname), params_file)

# (2) Optimizer state
opt_state_fname = os.path.join(training_state_dirname, C.OPT_STATE_LAST)
Expand Down
27 changes: 27 additions & 0 deletions sockeye/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pprint
import random
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from itertools import starmap
Expand Down Expand Up @@ -817,3 +818,29 @@ def init_device(args: argparse.Namespace) -> pt.device:
logger.info('CUDA: allow tf32 (float32 but with 10 bits precision)')

return device


def fault_tolerant_symlink(src: str, dst: str, max_retries: int = 6):
"""
Attempt to create a symbolic link from source to destination. If a
FileExistsError is raised, assume a distributed filesystem is currently
synchronizing, wait, and retry. If the maximum number of retries is
exceeded, raise an error.
:param src: Source file.
:param dst: Destination file.
:param max_retries: Maximum number of retries.
"""
retries = 0
while True:
try:
os.symlink(src, dst)
return
except FileExistsError as error:
if retries >= max_retries:
break
wait_time = 2**retries
logger.warn(f'Error detected when calling symlink: {error}. Retrying in {wait_time} seconds.')
time.sleep(wait_time)
retries += 1
raise OSError(f'Max retries exceeded when attempting to create symlink: \'{src}\' -> \'{dst}\'')
18 changes: 18 additions & 0 deletions test/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import re
from tempfile import TemporaryDirectory
import unittest

import numpy as np
import pytest
Expand Down Expand Up @@ -428,3 +429,20 @@ def test_rerank_hypotheses_isometric(hypothesis, hypothesis_score, source, metri
def test_update_dict_with_prefix_kv(dest, prefix_kv, expected):
utils.update_dict_with_prefix_kv(dest, prefix_kv)
assert dest == expected


@unittest.mock.patch('time.sleep')
def test_fault_tolerant_symlink(mock_sleep):
with TemporaryDirectory() as temp:
src_fname = os.path.join(temp, 'src')
dst_fname = os.path.join(temp, 'dst')
_touch_file(src_fname, compressed=False, empty=False)
# First symlink succeeds
utils.fault_tolerant_symlink(src_fname, dst_fname)
# Second symlink fails after retries (file exists)
with pytest.raises(OSError):
utils.fault_tolerant_symlink(src_fname, dst_fname)
assert mock_sleep.called
# Same data read from source and destination
with utils.smart_open(src_fname) as src_in, utils.smart_open(dst_fname) as dst_in:
assert src_in.readlines() == dst_in.readlines()

0 comments on commit 8753d95

Please sign in to comment.