Skip to content

Commit

Permalink
chore: linter 2/2
Browse files Browse the repository at this point in the history
  • Loading branch information
b4yuan committed Aug 20, 2024
1 parent 3c97c43 commit 006c460
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
39 changes: 22 additions & 17 deletions packages/ecog2vec/ecog2vec/data_generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import re
from pynwb import NWBHDF5IO
import numpy as np
import soundfile as sf
import os
import scipy
from scipy.fft import fft, ifft, fftfreq, rfftfreq, rfft, irfft
from scipy.signal import butter, lfilter, filtfilt, hilbert
import re
# from ripple2nwb.neural_processing import NeuralDataProcessor
# from prepype import NeuralDataProcessor
from prepype.neural_processing import NeuralDataProcessor, downsample, downsample_NWB
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(self, nwb_dir, patient):
self.nwb_files = [file
for file in file_list
if file.startswith(f"{patient}")]
self.nwb_sr = None
self.target_sr = 100

self.bad_electrodes = []
Expand Down Expand Up @@ -184,7 +185,7 @@ def make_data(self,
self.nwb_sr = nwbfile.acquisition["ElectricalSeries"].\
rate

# indices = np.where(electrode_table["group_name"] ==
# indices = np.where(electrode_table["group_name"] ==
# self.electrode_name
# )[0]

Expand All @@ -197,10 +198,10 @@ def make_data(self,
print('High gamma extraction done.')

nwbfile_electrodes = processor.nwb_file.processing['ecephys'].\
data_interfaces['LFP'].\
electrical_series[f'high gamma \
({list(self.config["referencing"])[0]})'].\
data[()][:, self.good_electrodes]
data_interfaces['LFP'].\
electrical_series[f'high gamma \
({list(self.config["referencing"])[0]})'].\
data[()][:, self.good_electrodes]

print(f"Number of good electrodes in {file}: {nwbfile_electrodes.shape[1]}")

Expand All @@ -215,10 +216,10 @@ def make_data(self,
for start
in list(nwbfile.trials[:]["stop_time"] * self.nwb_sr)]

# Manage the speaking segments only... as an option .
# Manage the speaking segments only... as an option.
# Training data for wav2vec as speaking segments only
# will be saved in the `chopped_sentence_dir` directory.
# This block also saves the individual sentences.
# will be saved in the `chopped_sentence_dir` directory.
# This block also saves the individual sentences.
i = 0
all_speaking_segments = []
for start, stop in zip(starts, stops):
Expand All @@ -232,7 +233,8 @@ def make_data(self,

i = i + 1

concatenated_speaking_segments = np.concatenate(all_speaking_segments, axis=0)
concatenated_speaking_segments = np.concatenate(all_speaking_segments,
axis=0)

# Training data: speaking segments only
if create_training_data and chopped_sentence_dir:
Expand All @@ -258,7 +260,7 @@ def make_data(self,
# Training data: silences included
if create_training_data and chopped_recording_dir:

_nwbfile_electrodes = nwbfile_electrodes # [starts[0]:stops[-1],:]
_nwbfile_electrodes = nwbfile_electrodes
num_full_chunks = len(_nwbfile_electrodes) // chunk_length
# last_chunk_size = len(_nwbfile_electrodes) % chunk_size

Expand Down Expand Up @@ -291,8 +293,8 @@ def make_data(self,
print('Full recording saved as a WAVE file.')

if (ecog_tfrecords_dir and
((self.patient in ('EFC402', 'EFC403') and (block in self.blocks_ID_demo2) or
(self.patient in ('EFC400', 'EFC401') and (block in self.blocks_ID_mocha))))):
((self.patient in {'EFC402', 'EFC403'} and (block in self.blocks_ID_demo2) or
(self.patient in {'EFC400', 'EFC401'} and (block in self.blocks_ID_mocha))))):

# Create TFRecords for the ECoG data

Expand Down Expand Up @@ -399,7 +401,9 @@ def make_data(self,
print('In distribution block. TFRecords created.')

except Exception as e:
print(f"An error occurred and block {path} is not included in the wav2vec training data: {e}")
print(f"An error occurred \
and block {path} is not included \
in the wav2vec training data: {e}")

io.close()

Expand All @@ -425,7 +429,8 @@ def write_to_Protobuf(path, example_dicts):
feature_example = tfh.make_feature_example(example_dict)
writer.write(feature_example.SerializeToString())

def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times, transcription, max_length, sampling_rate):
def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times,
transcription, max_length, sampling_rate):

# if the transcription is missing (e.g. for covert trials)
if transcription is None:
Expand Down Expand Up @@ -456,4 +461,4 @@ def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times, transc
transcript = np.insert(transcript, 0, 'pau')
indices = np.sum(indices*(np.arange(1, len(transcript))[:, None]), 0)

return transcript[indices]
return transcript[indices]
2 changes: 1 addition & 1 deletion packages/ecog2vec/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
install_requires=[
# Add any other required packages here
],
)
)

0 comments on commit 006c460

Please sign in to comment.