Skip to content

Commit

Permalink
Skip sequences that are less than minimum sequence length (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtritt authored Aug 11, 2023
1 parent 8270e52 commit adb23bf
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/gtnet/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(self, window, step, vocab=None, padval=None, min_seq_len=100, devic
self.device = device

def encode(self, seq):
if len(seq) < self.min_seq_len:
raise ValueError(f"Minimum sequence length is {self.min_seq_len} - got {len(seq)}")
if seq.dtype == np.dtype('S1'):
seq = seq.view(np.uint8)
elif seq.dtype == np.dtype('U1'):
Expand Down Expand Up @@ -143,9 +145,14 @@ def readfiles(cls, encoder, fastas):
for fa in fastas:
logging.debug(f'loading {fa}')
for seqid, values in cls.readfile(fa):
batches = encoder.encode(values)
val = (fa, seqid, len(values), batches)
yield val
if len(values) < encoder.min_seq_len:
logging.warning((f"Skipping {seqid} from {fa} - length less than "
"minimum sequence length {encoder.min_seq}"))
yield (fa, seqid, len(values), torch.zeros((0, 0, 0), dtype=torch.uint8))
else:
batches = encoder.encode(values)
val = (fa, seqid, len(values), batches)
yield val


class SerialLoader(Loader):
Expand Down

0 comments on commit adb23bf

Please sign in to comment.