Skip to content

Commit

Permalink
beam_search fix for running with torch.use_deterministic_algorithms(T…
Browse files Browse the repository at this point in the history
…rue) (#1096)
  • Loading branch information
Jehovan authored Sep 7, 2023
1 parent a59e7a2 commit 2d80b2a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.37]

### Fixed

- Fixed beam_search for running with torch.use_deterministic_algorithms(True)

## [3.1.36]

### Changed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.36'
__version__ = '3.1.37'
2 changes: 1 addition & 1 deletion sockeye/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ def forward(self,
# locations of each batch item when first dimension is (batch * beam)
batch_indices = pt.arange(0, batch_size * self.beam_size, self.beam_size, dtype=pt.int64, device=self.device)
first_step_mask = pt.full((batch_size * self.beam_size, 1), fill_value=np.inf, device=self.device, dtype=self.dtype)
first_step_mask[batch_indices] = 0.0
first_step_mask[batch_indices] = pt.full((batch_size, 1), fill_value=0.0, device=self.device, dtype=self.dtype)
if target_prefix is not None:
first_step_mask = utils.adjust_first_step_masking(target_prefix, first_step_mask)

Expand Down

0 comments on commit 2d80b2a

Please sign in to comment.