Skip to content

Commit

Permalink
Turn debug flag off
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Oct 17, 2024
1 parent af13eef commit bd59528
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions syncode/grammar_decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Iterator
import torch
import syncode.common as common
from transformers import LogitsProcessor, PreTrainedTokenizer
Expand All @@ -8,6 +7,8 @@
from syncode.dfa_mask_store import DFAMaskStore
from syncode.parsers.grammars import Grammar

# Set to True for debugging
DEBUG = False

class SyncodeLogitsProcessor(LogitsProcessor):
"""
Expand Down Expand Up @@ -135,7 +136,6 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# start_from is used for choosing where the parsing should start
debug = True
partial_codes = self._get_partial_codes(input_ids)

for idx, partial_code in enumerate(partial_codes):
Expand All @@ -152,7 +152,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

accept_mask = self.dfa_mask_store.get_accept_mask(r, logger=self.logger)

if debug:
if DEBUG:
self._log_current_status(partial_code, r)
greedy_token = self.tokenizer.decode(scores[idx].argmax(dim=-1))

Expand All @@ -167,7 +167,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
self._log_current_status(partial_code, r)

# For debugging - remove later
if debug: self._debug_greedy(scores, idx, partial_code, r, greedy_token)
if DEBUG: self._debug_greedy(scores, idx, partial_code, r, greedy_token)

return scores

Expand Down

0 comments on commit bd59528

Please sign in to comment.