-
Notifications
You must be signed in to change notification settings - Fork 34
/
base.py
213 lines (190 loc) · 7.93 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import sounddevice as sd
import re
import queue
import threading
from typing import Callable
import numpy as np
import pysbd
from dotenv import load_dotenv
load_dotenv()
# One optional run of leading whitespace followed by one or more bracketed tags
# (each with its trailing whitespace). Matching the whole run at once lets
# adjacent tags such as "[A] [B]" collapse into a single space.
_BRACKETED_TAG_RUN = re.compile(r"\s*(?:\[.*?\]\s*)+")


def remove_words_in_brackets_and_spaces(text):
    """
    Remove bracketed words (e.g. [USER]) together with the whitespace around them.

    Every run of one or more bracketed tags is replaced by a single space, and the
    result is stripped of leading/trailing whitespace. Whitespace that is not
    adjacent to a bracketed tag is left untouched. A tag cannot span a newline
    because ``.`` does not match newlines.

    :param text: input text
    :return: input text with the words in brackets and the extra spaces around them removed.
    """
    return _BRACKETED_TAG_RUN.sub(" ", text).strip()
class BaseMouth:
    """
    Base class for text-to-speech output.

    Subclasses implement run_tts(). This class segments text into sentences,
    synthesizes them one at a time, plays them from a background thread through
    ``player`` and records when playback was interrupted.
    """

    def __init__(self, sample_rate: int, player=sd, wait=True, logger=None):
        """
        Initializes the BaseMouth class.
        :param sample_rate: The sample rate of the audio.
        :param player: The audio player object. Defaults to sounddevice. It must provide
            play(audio, samplerate=...), wait() and stop().
        :param wait: Whether to wait for the audio to finish playing. Defaults to True.
        :param logger: Optional logger; when it is falsy nothing is logged.
        """
        self.sample_rate = sample_rate
        # "" while nothing was interrupted; say() stores an
        # (interruption, text) tuple here when it stops playback early.
        self.interrupted = ""
        self.player = player
        self.seg = pysbd.Segmenter(language="en", clean=True)
        self.wait = wait
        self.logger = logger

    def run_tts(self, text: str) -> np.ndarray:
        """
        :param text: The text to synthesize speech for
        :return: audio numpy array for sounddevice
        :raises NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by the subclass")

    def say_text(self, text: str):
        """
        calls run_tts and plays the audio using the player, blocking until it finishes.
        :param text: The text to synthesize speech for
        """
        output = self.run_tts(text)
        self.player.play(output, samplerate=self.sample_rate)
        self.player.wait()

    def say(self, audio_queue: queue.Queue, listen_interruption_func: Callable):
        """
        Plays the audios in the queue using the player. Stops if interruption occurred.

        Items are (audio, text) tuples; an item whose audio is None is the end-of-stream
        sentinel. On interruption, self.interrupted is set to (interruption, text) where
        interruption is the truthy value returned by listen_interruption_func.
        :param audio_queue: The queue where the audio is stored for it to be played
        :param listen_interruption_func: callable function from the ear class. It is called
            with the duration (in seconds) of the audio that has just started playing.
        """
        self.interrupted = ""
        while True:
            output, text = audio_queue.get()
            if output is None:
                self.player.wait()  # wait for the last audio to finish
                break
            # get the duration of audio
            duration = len(output) / self.sample_rate
            self._log_event("playing audio", "TTS", f"{duration} seconds")
            self.player.play(output, samplerate=self.sample_rate)
            interruption = listen_interruption_func(duration)
            if interruption:
                self._log_event("audio interrupted", "TTS")
                self.player.stop()
                self.interrupted = (interruption, text)
                break
            if self.wait:
                # let the current audio finish before starting the next one
                self.player.wait()

    def say_multiple(self, text: str, listen_interruption_func: Callable):
        """
        Splits the text into sentences. Then plays the sentences one by one
        using run_tts() and say()
        :param text: Input text to synthesize
        :param listen_interruption_func: callable function from the ear class
        """
        sentences = self.seg.segment(text)
        self._log_event("sentences segmented", "TTS", str(sentences))
        audio_queue = queue.Queue()
        # Reset here as well as in say(): otherwise a value left over from a previous
        # call could be seen by the loop below before the playback thread clears it.
        self.interrupted = ""
        say_thread = threading.Thread(
            target=self.say, args=(audio_queue, listen_interruption_func)
        )
        say_thread.start()
        try:
            for sentence in sentences:
                output = self.run_tts(sentence)
                audio_queue.put((output, sentence))
                if self.interrupted:
                    break
        finally:
            # Always send the sentinel, even if run_tts raised, so the playback
            # thread cannot stay blocked on audio_queue.get() forever.
            audio_queue.put((None, ""))
            say_thread.join()

    def _handle_interruption(self, responses_list, interrupt_queue):
        """
        Truncates the spoken responses at the interrupted sentence and reports the interruption.
        :param responses_list: The sentences that were sent for playback, in order
        :param interrupt_queue: The queue on which the first element of self.interrupted is put
        :return: the responses before the interrupted sentence followed by "..."
        :raises ValueError: if the interrupted sentence is not in responses_list
        """
        interrupt_transcription, interrupt_text = self.interrupted
        self._log_event("interruption detected", "TTS", interrupt_transcription)
        try:
            # list.index raises ValueError (it never returns -1) when the item is missing
            idx = responses_list.index(interrupt_text)
        except ValueError as err:
            raise ValueError(
                "Interrupted text not found in responses list. "
                "This should not happen. Raise an issue."
            ) from err
        responses_list = responses_list[:idx] + ["..."]
        interrupt_queue.put(interrupt_transcription)
        return responses_list

    def _get_all_text(self, text_queue):
        """
        Blocks for the next item of text_queue and then drains whatever else is already queued.
        :param text_queue: The queue of text pieces; None marks the end of the stream
        :return: the concatenated text, or None if the first item was the end marker
        """
        text = text_queue.get()
        if text is None:
            return None
        pieces = [text]
        while not text_queue.empty():
            new_text = text_queue.get()
            if new_text is None:
                # put the end marker back so the next call returns None
                text_queue.put(None)
                break
            pieces.append(new_text)
        return "".join(pieces)

    def _log_event(self, event: str, details: str, further: str = ""):
        """
        Logs an event through self.logger, if one was provided.
        :param event: The log message
        :param details: Stored under the "details" key of the log record's extra
        :param further: Stored, wrapped in double quotes, under the "further" key
        """
        if self.logger:
            self.logger.info(
                event, extra={"details": details, "further": f'"{further}"'}
            )

    def say_multiple_stream(
        self,
        text_queue: queue.Queue,
        listen_interruption_func: Callable,
        interrupt_queue: queue.Queue,
        audio_queue: queue.Queue = None,
    ):
        """
        Receives text from the text_queue. As soon as a sentence is made run_tts is called to
        synthesize its speech and sent to the audio_queue for it to be played.

        When the stream has ended (or playback was interrupted) the text_queue is cleared and
        the response that was produced is put on it as a single space-joined string.
        :param text_queue: The queue where the llm adds the predicted tokens
        :param listen_interruption_func: callable function from the ear class
        :param interrupt_queue: The queue on which the interruption is put when one occurred.
        :param audio_queue: The queue where the audio to be played is placed
        """
        response = ""
        all_response = []
        interrupt_text_list = []
        if audio_queue is None:
            audio_queue = queue.Queue()
        # Reset here as well as in say(): otherwise a value left over from a previous
        # call could be seen by the loop below before the playback thread clears it.
        self.interrupted = ""
        say_thread = threading.Thread(
            target=self.say, args=(audio_queue, listen_interruption_func)
        )
        self._log_event("audio play thread started", "TTS")
        say_thread.start()
        try:
            text = ""
            while text is not None:
                self._log_event("getting all text", "TTS")
                text = self._get_all_text(text_queue)
                self._log_event("all text received", "TTS")
                if text is None:
                    # end of stream: speak whatever is still buffered
                    self._log_event("Stream ended", "TTS")
                    sentence = response
                else:
                    response += text
                    self._log_event("segmenting text", "TTS", response)
                    sentences = self.seg.segment(response)
                    # if there are multiple sentences we split and play the first one
                    if len(sentences) > 1:
                        self._log_event("multiple sentences detected", "TTS")
                        sentence = sentences[0]
                        response = " ".join([s for s in sentences[1:] if s != "."])
                    else:
                        # the only sentence may still be incomplete; wait for more text
                        self._log_event("single sentence detected", "TTS")
                        continue
                if sentence.strip() != "":
                    self._log_event("cleaning sentence", "TTS")
                    clean_sentence = remove_words_in_brackets_and_spaces(
                        sentence
                    ).strip()
                    # empty when the sentence only contains words in brackets
                    if clean_sentence != "":
                        self._log_event("running tts", "TTS", clean_sentence)
                        output = self.run_tts(clean_sentence)
                        self._log_event("tts output received", "TTS")
                        audio_queue.put((output, clean_sentence))
                        interrupt_text_list.append(clean_sentence)
                    all_response.append(sentence)
                # if interruption occurred, handle it
                if self.interrupted:
                    all_response = self._handle_interruption(
                        interrupt_text_list, interrupt_queue
                    )
                    self.interrupted = ""
                    break
        finally:
            # Always send the sentinel, even if something above raised, so the
            # playback thread cannot stay blocked on audio_queue.get() forever.
            audio_queue.put((None, ""))
            say_thread.join()
            self._log_event("audio play thread ended", "TTS")
        # the interruption may have happened while the last queued audio was playing
        if self.interrupted:
            all_response = self._handle_interruption(
                interrupt_text_list, interrupt_queue
            )
        text_queue.queue.clear()
        text_queue.put(" ".join(all_response))