
Possible bug in masked index generation? #39

Open
tekinek opened this issue Jan 3, 2024 · 9 comments
tekinek commented Jan 3, 2024

masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())

Hi, why is masked_index extended for 15% of the tokens? If I understand correctly, the extension should be placed inside the else statement at line # 80, right?

jav-ed commented Jan 4, 2024

A few days ago, I was wondering about the very same thing. We would only want a token to be registered as masked when it is actually masked or modified in some other way. The current code is as follows:

phoneme_list = ''.join(phonemes)
masked_index = []
for z in zip(phonemes, input_ids):
    z = list(z)
    
    words.extend([z[1]] * len(z[0]))
    words.append(self.word_separator)
    labels += z[0] + " "

    if np.random.rand() < self.word_mask_prob:
        if np.random.rand() < self.replace_prob:
            if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 
                phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))])  # randomized
            else:
                phoneme += z[0]
        else:
            phoneme += self.token_mask * len(z[0]) # masked
            
        masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())
    else:
        phoneme += z[0] 

    phoneme += self.token_separator

From what I can tell, the original goal probably was:

85% of the time: keep original
12% of the time: special phoneme mask
1.5% of the time: random (from the available) phoneme mask
1.5% of the time: selected for masking but the original phoneme is kept

--> 86.5% of the time: keep original
12% of the time: special phoneme mask
1.5% of the time: random (from the available) phoneme mask

However, the code as provided above also registers as masked the 1.5% of cases where the token was selected for masking but the original phoneme was kept (the second mention of 1.5% in the list above). Even though incorrectly masking 1.5% of tokens probably has a negligible impact on performance, consider the following correction suggestion:

phoneme_list = ''.join(phonemes)
masked_index = []
for z in zip(phonemes, input_ids):
  z = list(z)
  
  words.extend([z[1]] * len(z[0]))
  words.append(self.word_separator)
  labels += z[0] + " "

  if np.random.rand() < self.word_mask_prob:
      if np.random.rand() < self.replace_prob:
          if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 
              phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))])  # randomized
              
              # added here
              masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())
              
          else:
              phoneme += z[0]
      else:
          phoneme += self.token_mask * len(z[0]) # masked
          
          # added here
          masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())
          
      # removed here
      # masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist())
      
  else:
      phoneme += z[0] 

  phoneme += self.token_separator
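As a quick sanity check of the resulting split (just a throwaway simulation, assuming the default values word_mask_prob=0.15, replace_prob=0.2 and phoneme_mask_prob=0.1 that I also quote in my annotated snippet further below; adjust them to your own config), the branch frequencies can be estimated like this:

import numpy as np

# assumed defaults; adjust to your own preprocessing config
word_mask_prob = 0.15
replace_prob = 0.2
phoneme_mask_prob = 0.1

counts = {"keep": 0, "mask_token": 0, "random_phoneme": 0, "registered_but_kept": 0}
n = 100_000
for _ in range(n):
    if np.random.rand() < word_mask_prob:
        if np.random.rand() < replace_prob:
            if np.random.rand() < (phoneme_mask_prob / replace_prob):
                counts["random_phoneme"] += 1        # ~1.5%
            else:
                counts["registered_but_kept"] += 1   # ~1.5%
        else:
            counts["mask_token"] += 1                # ~12%
    else:
        counts["keep"] += 1                          # ~85%

for name, c in counts.items():
    print(f"{name}: {100 * c / n:.2f}%")

This prints roughly 85% / 12% / 1.5% / 1.5%, matching the list above.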

tekinek (Author) commented Jan 5, 2024

Yes, here is the distribution. For each word (token) in a given sample:

1   85%: keep original
2   15%: 
3        - 80%: whole word masking (e.g. nice —> MMMM)  
4        - 20%:
5              - 50%: random replacement of every phoneme in it (e.g. nice —> csoe)
6              - 50%: keep original

I think the masked index should be registered for cases 3 and 5. But the current implementation covers case 2, which I think is a bug.

yl4579 (Owner) commented Jan 6, 2024

Thanks for your question. This was intentional. The masked indices are used for loss calculation here: https://github.com/yl4579/PL-BERT/blob/main/train.ipynb (see the if len(_masked_indices) > 0: line), so the masked indices also include unchanged tokens, and the model is trained (1.5% of the time) to reproduce the exact input tokens, guided by the loss. If we didn't include this, the model would not learn to keep the original tokens when they are unmasked (as happens at actual usage time, when you use it in TTS fine-tuning).
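Roughly speaking (this is just a minimal sketch, not the actual code in train.ipynb; the function and variable names below are only for illustration), the token-level loss is computed only over the positions collected in masked_index, so a position that was registered but left unchanged still contributes to the loss and forces the model to reproduce its own input:

import torch
import torch.nn.functional as F

def masked_token_loss(logits, target_tokens, masked_indices):
    # logits:         (seq_len, vocab_size) predictions for one sample
    # target_tokens:  (seq_len,) ground-truth token ids
    # masked_indices: positions registered in masked_index; they may be truly
    #                 masked, randomly replaced, or deliberately left unchanged
    if len(masked_indices) > 0:
        idx = torch.tensor(masked_indices, dtype=torch.long)
        # unchanged-but-registered positions teach the model to copy its input
        return F.cross_entropy(logits[idx], target_tokens[idx])
    return logits.new_zeros(())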

tekinek (Author) commented Jan 7, 2024

Thanks for your clarification. I have trained PL-BERT for my language and tried to evaluate it by asking it to predict masked/unmasked tokens and phonemes. In most cases its predictions make sense, but it fails at predicting the "space" between words (which is used as the token separator in this repo). With different checkpoints it predicts the space as a random phoneme, but never as the space itself, even when the space is not masked in the input.

So I further trained the model, additionally registering the token separator as masked with a 15% chance. Now the model can predict the space.

phoneme += self.token_separator
if np.random.rand() < self.phoneme_mask_prob:
    masked_index.extend((np.arange(len(phoneme) - 1, len(phoneme))).tolist())

jav-ed commented Jan 7, 2024

@tekinek
When it comes to the distribution, here are some snippets of my code. Note that the code was not intended for publication; the comments were just for my own personal understanding, so please overlook spelling and similar mistakes. The main point (the probability distribution) should still be extractable from the comments, inshallah.

      # 85% of the time: keep original
      # 12% of the time: special phoneme mask
      # 1.5% of the time: random (from the available) phoneme mask
      # 1.5% of the time: selected for masking but the original phoneme is kept
      
      # --> 86.5% of the time: keep original
      # 12% of the time: special phoneme mask
      # 1.5% of the time: random (from the available) phoneme mask
      

      # word_mask_prob= 0.15
      # np.random.rand() --> random number between [0;1]
      # for less than 15% of the time or ~ 15% of the time
      if np.random.rand() < self.word_mask_prob:
          
          # replace_prob=0.2
          # now for less than 20% of the time or ~ 20% of the time
          # for 0.15 * 0.2 = 0.03 of the time
          # for ~ 3% of the time
          if np.random.rand() < self.replace_prob:
              
              
              # ------------------ random replacement ------------------ #
              # phoneme_mask_prob=0.1, replace_prob=0.2
              # 0.1/0.2 = 0.5 
              # 0.03 * 0.5 = 0.015 = 1.5% of the time replace the masked phoneme with a random phoneme
              if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 
                  
                  # np.random.randint(0, len(phoneme_list)) = get an integer in [0, len(phoneme_list)),
                  # note that np.random.randint(a, b) actually returns an int between a and (b-1), not b

                  # for the length of the current phoneme string z[0] - choose random phonemes from the available phonemes and append them to the collected phoneme string, that is phoneme
                  phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))
                                      ])  # randomized
                  
                  masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
                                         len(phoneme))).tolist())
                  
              # --------------------- take original -------------------- #
              # considered for masking but, kept original:
              # 0.03 *(1 - 0.5) = 0.03 * 0.5 = 0.015 = 1.5%
              else:
                  phoneme += z[0]
               
          # ------------------- special token masking ------------------ #
          # for ~ 0.15 * (1- 0.2 = 0.8) = 0.15 * 0.8 = 0.12 of the time
          # for ~ 12% of the time special mask token
          else:
              
              # add the mask token "M" to the phoneme string
              phoneme += self.token_mask * len(z[0]) # masked
              
              masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
                                         len(phoneme))).tolist())
             

          ## Modification made here: this line of code should only be executed when masking actually occurs and not when the original is kept; for 1.5% of cases this code would be executed while no masking is actually in place
          # masked_index.extend((np.arange(len(phoneme) - len(z[0]), 
          #                                len(phoneme))).tolist()
          #                     )
         
      # -------------------- keep original phonemes -------------------- #
      # for ~ 1 - 0.15 = 0.85 of the time --> do not mask and keep the original phoneme
      else:
          phoneme += z[0] 

      phoneme += self.token_separator


  # count phonemes in the full phoneme collection string
  mel_length = len(phoneme)

jav-ed commented Jan 7, 2024

Please feel free to correct me if I made a mistake in the probability distribution calculations.

tekinek (Author) commented Jan 7, 2024

@jav-ed your calculation looks correct. However, as @yl4579 clarified, we don't need to change the distribution. But I still suggest you add the following right after the last phoneme += self.token_separator in your code, which masks the token separator (space), if your language uses spaces between words.

if np.random.rand() < self.phoneme_mask_prob:
    masked_index.extend((np.arange(len(phoneme) - 1, len(phoneme))).tolist())

yl4579 (Owner) commented Jan 8, 2024

@tekinek The token separator doesn't need to be predicted because there is a one-to-one correspondence between the grapheme and the phoneme (i.e., the space token in the phoneme domain always corresponds to the word separator token in the grapheme domain). Even if the linear projection head fails at predicting this specific token, it won't affect the downstream task, because a whitespace phoneme token means exactly the word separator.

jav-ed commented Jan 9, 2024

@tekinek thank you, yes, now I see why there is an advantage in not inserting the masked_index extension inside the else condition. Just for anybody else who might not understand it immediately: basically, @yl4579 explained that he is tricking the BERT model on purpose. He makes the model believe that something is masked while it is not, and through this the model learns that some tokens are already correct and thus should not be replaced.

@yl4579 thank you for your explanation
