
Commit

chore: linter 1/2
b4yuan committed Aug 20, 2024
1 parent 68823cf commit 3c97c43
Showing 1 changed file with 99 additions and 84 deletions.
183 changes: 99 additions & 84 deletions packages/ecog2vec/ecog2vec/data_generator.py
@@ -43,64 +43,64 @@ class NeuralDataGenerator():
def __init__(self, nwb_dir, patient):

self.patient = patient

file_list = os.listdir(nwb_dir)
self.nwb_dir = nwb_dir
self.nwb_files = [file
for file in file_list
if file.startswith(f"{patient}")]
self.target_sr = 100

self.bad_electrodes = []
self.good_electrodes = list(np.arange(256))

self.high_gamma_min = 70
self.high_gamma_max = 199

# Bad electrodes are 1-indexed!

if patient == 'EFC400':
self.electrode_name = 'R256GridElectrode electrodes'
self.grid_size = np.array([16, 16])
self.bad_electrodes = [x - 1 for x in [1, 2, 33, 50, 54, 64,
128, 129, 193, 194, 256]]
self.blocks_ID_mocha = [3, 23, 72]

elif patient == 'EFC401':
self.electrode_name = 'L256GridElectrode electrodes'
self.grid_size = np.array([16, 16])
self.bad_electrodes = [x - 1 for x in [1, 2, 63, 64, 65, 127,
143, 193, 194, 195, 196,
235, 239, 243, 252, 254,
255, 256]]
self.blocks_ID_mocha = [4, 41, 57, 61, 66, 69, 73, 77, 83, 87]

elif patient == "EFC402":
self.electrode_name = 'InferiorGrid electrodes'
self.grid_size = np.array([8, 16])
self.bad_electrodes = [x - 1 for x in list(range(129, 257))]
self.blocks_ID_demo2 = [4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17,
18, 19, 25, 26, 27, 33, 34, 35, 44, 45,
46, 47, 48, 49, 58, 59, 60]

elif patient == 'EFC403':
self.electrode_name = 'Grid electrodes'
self.grid_size = np.array([16, 16])
self.bad_electrodes = [x - 1 for x in [129, 130, 131, 132, 133,
134, 135, 136, 137, 138,
139, 140, 141, 142, 143,
144, 145, 146, 147, 148,
149, 161, 162, 163, 164,
165, 166, 167, 168, 169,
170, 171, 172, 173, 174,
175, 176, 177, 178, 179,
180, 181]]
self.blocks_ID_demo2 = [4, 7, 10, 13, 19, 20, 21, 28, 35, 39,
52, 53, 54, 55, 56, 59, 60, 61, 62, 63,
64, 70, 73, 74, 75, 76, 77, 83, 92, 93,
94, 95, 97, 98, 99, 100, 101, 108, 109,
110, 111, 112, 113, 114, 115]

else:
self.electrode_name = None
self.grid_size = None
@@ -110,8 +110,8 @@ def __init__(self, nwb_dir, patient):

self.config = None


def make_data(self,
chopped_sentence_dir=None,
sentence_dir=None,
chopped_recording_dir=None,
@@ -135,11 +135,12 @@ def make_data(self,
(None)
"""
all_example_dict = [] # not maintained at the moment; stores ALL example dicts

block_pattern = re.compile(r'B(\d+)')

if BPR is None:
raise ValueError("Please specify whether to use common average reference or bipolar referencing")
raise ValueError("Please specify whether to use \
common average reference or bipolar referencing")

if self.config is None:
self.config = {
@@ -148,23 +149,23 @@ def make_data(self,
'target sampling rate': None,
'grid size': self.grid_size
}

for file in self.nwb_files:

create_training_data = True

match = block_pattern.search(file)
block = int(match.group(1))

if self.patient == 'EFC400' or self.patient == 'EFC401':
if block in self.blocks_ID_mocha:
create_training_data = False
elif self.patient == 'EFC402' or self.patient == 'EFC403':
if block in self.blocks_ID_demo2:
create_training_data = False

path = os.path.join(self.nwb_dir, file)

io = NWBHDF5IO(path, load_namespaces=True, mode='r')
nwbfile = io.read()

@@ -173,13 +174,13 @@ def make_data(self,
with NeuralDataProcessor(
nwb_path=path, config=self.config, WRITE=False
) as processor:

# Grab the electrode table and sampling rate,
# and then process the raw ECoG data.

electrode_table = nwbfile.acquisition["ElectricalSeries"].\
electrodes.table[:]

self.nwb_sr = nwbfile.acquisition["ElectricalSeries"].\
rate

@@ -197,22 +198,23 @@ def make_data(self,

nwbfile_electrodes = processor.nwb_file.processing['ecephys'].\
data_interfaces['LFP'].\
electrical_series[f'high gamma \
({list(self.config["referencing"])[0]})'].\
data[()][:, self.good_electrodes]

print(f"Number of good electrodes in {file}: {nwbfile_electrodes.shape[1]}")
print(f"Number of good electrodes in {file}: {nwbfile_electrodes.shape[1]}")

# Begin building the WAVE files for wav2vec training
# and evaluation.

# Starts/stops for each intrablock trial.
starts = [int(start)
for start
in list(nwbfile.trials[:]["start_time"] * self.nwb_sr)]
stops = [int(start)
for start
in list(nwbfile.trials[:]["stop_time"] * self.nwb_sr)]

# Manage the speaking segments only... as an option.
# Training data for wav2vec as speaking segments only
# will be saved in the `chopped_sentence_dir` directory.
@@ -222,22 +224,23 @@ def make_data(self,
for start, stop in zip(starts, stops):
speaking_segment = nwbfile_electrodes[start:stop,:]
all_speaking_segments.append(speaking_segment)

if sentence_dir:
file_name = f'{sentence_dir}/{file}_{i}.wav'
sf.write(file_name,
speaking_segment, 16000, subtype='FLOAT')

i = i + 1

concatenated_speaking_segments = np.concatenate(all_speaking_segments, axis=0)

# Training data: speaking segments only
if create_training_data and chopped_sentence_dir:
num_full_chunks = len(concatenated_speaking_segments) // chunk_length
# last_chunk_size = len(nwbfile_electrodes) % chunk_size

full_chunks = np.split(concatenated_speaking_segments[:num_full_chunks * chunk_length],
num_full_chunks)
last_chunk = concatenated_speaking_segments[num_full_chunks * chunk_length:]

chunks = full_chunks # + [last_chunk] omit the last non-100000 chunk
@@ -247,54 +250,61 @@ def make_data(self,
file_name = f'{chopped_sentence_dir}/{file}_{i}.wav'
sf.write(file_name, chunk, 16000, subtype='FLOAT')

print(f'Out of distribution block. \
Number of chopped chunks w/o intertrial silences \
of length {chunk_length} added to training data: {num_full_chunks}')
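A minimal sketch (illustrative, not part of this commit) of the fixed-length chunking used above, assuming `segments` is a NumPy array of shape (samples, electrodes) and `chunk_length` is given in samples:

import numpy as np

def chop_into_chunks(segments, chunk_length):
    # Keep only the samples that fill complete chunks; the trailing remainder
    # is dropped, mirroring the "omit the last chunk" behaviour in make_data.
    num_full_chunks = len(segments) // chunk_length
    if num_full_chunks == 0:
        return []
    return np.split(segments[:num_full_chunks * chunk_length], num_full_chunks)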


# Training data: silences included
if create_training_data and chopped_recording_dir:

_nwbfile_electrodes = nwbfile_electrodes # [starts[0]:stops[-1],:]
num_full_chunks = len(_nwbfile_electrodes) // chunk_length
# last_chunk_size = len(_nwbfile_electrodes) % chunk_size

if num_full_chunks != 0:

full_chunks = np.split(_nwbfile_electrodes[:num_full_chunks * chunk_length],
num_full_chunks)
last_chunk = _nwbfile_electrodes[num_full_chunks * chunk_length:]

chunks = full_chunks # + [last_chunk] omit the last non-100000 chunk

# Checking lengths here
# for chunk in chunks:
# print(chunk.shape)
# print(last_chunk.shape)

# Loop through the chunks and save them as WAV files
for i, chunk in enumerate(chunks):
file_name = f'{chopped_recording_dir}/{file}_{i}.wav'
sf.write(file_name, chunk, 16000, subtype='FLOAT')

print(f'Out of distribution block. \
Number of chopped chunks w/ intertrial silences \
of length {chunk_length} added to training data: {num_full_chunks}')


if full_recording_dir:
file_name = f'{full_recording_dir}/{file}.wav'
sf.write(file_name, nwbfile_electrodes, 16000, subtype='FLOAT')

print('Full recording saved as a WAVE file.')

if (ecog_tfrecords_dir and
((self.patient in ('EFC402', 'EFC403') and (block in self.blocks_ID_demo2) or
(self.patient in ('EFC400', 'EFC401') and (block in self.blocks_ID_mocha))))):

# Create TFRecords for the ECoG data

high_gamma = downsample(nwbfile_electrodes,
self.nwb_sr,
self.target_sr,
'NWB',
ZSCORE=True)
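`downsample` is defined elsewhere in this repository; as a rough, assumed sketch of what resampling to `target_sr` with per-electrode z-scoring might look like (not the repo's actual implementation):

import numpy as np
from scipy.signal import resample_poly

def downsample_and_zscore(x, fs_in, fs_out):
    # x: (samples, electrodes); resample along time, then z-score each electrode.
    # Integer up/down factors are assumed; a non-integer NWB rate would need rounding.
    y = resample_poly(x, int(fs_out), int(round(fs_in)), axis=0)
    return (y - y.mean(axis=0)) / y.std(axis=0)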

phoneme_transcriptions = nwbfile.processing['behavior'].\
data_interfaces['BehavioralEpochs'].\
interval_series

token_type = 'word_sequence'

@@ -327,44 +337,44 @@ def make_data(self,

i0 = np.rint(self.target_sr * t0).astype(int)
iF = np.rint(self.target_sr * tF).astype(int)

# ECOG (C) SEQUENCE
c = high_gamma[i0:iF,:]
# print(c.shape)
# plt.plot(c[:,0])
# break

nsamples = c.shape[0]

# TEXT SEQUENCE
speech_string = trial['transcription'].values[0]
text_sequence = sentence_tokenize(speech_string.split(' '))

# AUDIO SEQUENCE
audio_sequence = []

# PHONEME SEQUENCE

M = iF - i0

max_seconds = max_seconds_dict.get(token_type)
max_samples = int(np.floor(self.target_sr * max_seconds))
max_length = min(M, max_samples)

phoneme_array = transcription_to_array(
t0, tF, phoneme_onset_times, phoneme_offset_times,
phoneme_transcript, max_length, self.target_sr
)

phoneme_sequence = [ph.encode('utf-8') for ph in phoneme_array]

if len(phoneme_sequence) != nsamples:
if len(phoneme_sequence) > nsamples:
phoneme_sequence = [phoneme_sequence[i] for i in range(nsamples)]
else:
for i in range(nsamples - len(phoneme_sequence)):
phoneme_sequence.append(phoneme_sequence[len(phoneme_sequence) - 1])
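The if/else above pads or truncates the phoneme labels so they line up one-to-one with the ECoG samples; a self-contained sketch of the same idea (illustrative only, `match_length` is a hypothetical helper):

def match_length(labels, nsamples):
    # Truncate if too long; otherwise repeat the last label until lengths agree.
    if len(labels) > nsamples:
        return labels[:nsamples]
    return list(labels) + [labels[-1]] * (nsamples - len(labels))

# e.g. match_length(['pau', 'b', 'aa'], 5) -> ['pau', 'b', 'aa', 'aa', 'aa']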

print('\n------------------------')
print(f'For sentence {index}: ')
print(c[0:5,0:5])
@@ -374,17 +384,21 @@ def make_data(self,
print(f'Length of phoneme sequence: {len(phoneme_sequence)}')
print(phoneme_sequence)
print('------------------------\n')


example_dicts.append({'ecog_sequence': c,
'text_sequence': text_sequence,
'audio_sequence': [],
'phoneme_sequence': phoneme_sequence,})

# all_example_dict.extend(example_dicts)
# print(len(example_dicts))
# print(len(all_example_dict))
write_to_Protobuf(f'{ecog_tfrecords_dir}/{self.patient}_B{block}.tfrecord',
example_dicts)

print('In distribution block. TFRecords created.')

except Exception as e:
print(f"An error occured and block {path} is not inluded in the wav2vec training data: {e}")

io.close()
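For orientation, a hedged usage sketch of this generator; the import path, directory names, and the BPR=False choice are assumptions for illustration, not taken from this diff, and only keyword arguments visible above are used:

from ecog2vec.data_generator import NeuralDataGenerator

# Hypothetical NWB directory and output paths; patient ID is one handled in __init__.
generator = NeuralDataGenerator('/data/nwb', 'EFC401')
generator.make_data(
    chopped_sentence_dir='wav/chopped_sentences',
    sentence_dir='wav/sentences',
    chopped_recording_dir='wav/chopped_recordings',
    BPR=False,  # assumed to select common average reference rather than bipolar
)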
@@ -435,7 +449,8 @@ def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times, transc
# print('exactly one phoneme:', np.all(np.sum(indices, 0) == 1))
assert np.all(np.sum(indices, 0) < 2)
except:
# pdb.set_trace()
pass

# ...but there can be locations with *zero* phonemes; assume 'pau' here
transcript = np.insert(transcript, 0, 'pau')