-
Notifications
You must be signed in to change notification settings - Fork 2
/
tflite.html
163 lines (128 loc) · 7.57 KB
/
tflite.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<title>Lyra V2 Web/Browser Demo - TFLite Runtime</title>
<script src="./enable-threads.js"></script>
</head>
<body>
<!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script> -->
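<!-- NOTE: the version pins in these two URLs were lost (mangled to "[email protected]"); the unpinned URLs below load the latest releases - re-pin once the intended versions are confirmed. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite/dist/tf-tflite.js"></script>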
<div>
<span>Model version:</span>
<select id="modelVersionSelectEl">
<option value="1.3.0">1.3.0</option>
<option value="1.2.0">1.2.0</option>
</select>
</div>
<br>
<div>
<button onclick="testEncoder(); this.parentElement.querySelectorAll('button').forEach(btn => btn.disabled=true);">Test Encoder</button>
<button onclick="testQuantizerEncoder(); this.parentElement.querySelectorAll('button').forEach(btn => btn.disabled=true);">Test Quantizer Encoder</button>
<button onclick="testQuantizerDecoder(); this.parentElement.querySelectorAll('button').forEach(btn => btn.disabled=true);">Test Quantizer Decoder</button>
<button onclick="testDecoder(); this.parentElement.querySelectorAll('button').forEach(btn => btn.disabled=true);">Test Decoder</button>
</div>
<p>See browser console for result. Note that this is a WIP - it isn't working yet!</p>
<script>
// Handles to the loaded TFLite models. All four are (re)assigned by every call to
// initModels(): models that were not requested in that call end up `undefined`, and the
// encode/quantize/dequantize/decode helpers throw when their model is missing.
let encoderModel; // read by encode()
let decoderModel; // read by decode()
let quantizerEncoderModel; // read by quantize()
let quantizerDecoderModel; // read by dequantize()
// Loads the TFLite models requested in `opts` for the version currently selected in the
// `modelVersionSelectEl` dropdown and stores them in the four model variables.
//
// `opts` flags: `encoder`, `quantizerEncoder`, `quantizerDecoder`, `decoder`. Every call
// reassigns all four variables, so a model whose flag is not set becomes `undefined`.
// Requesting the 1.3.0 encoder first asks the user for confirmation; declining rejects
// with Error("User cancelled model init.").
async function initModels(opts={}) {
  const encoderNeedsConfirmation = opts.encoder && modelVersionSelectEl.value === "1.3.0";
  if(encoderNeedsConfirmation && !confirm('The encoder is not currently working - seems like an infinite loop or something and may crash your tab. Continue?')) {
    throw new Error("User cancelled model init.");
  }

  // Starts loading `fileName` for the selected version when `wanted` is truthy;
  // otherwise yields `undefined` so the matching model variable gets cleared.
  const loadIfWanted = (wanted, fileName) => {
    if(!wanted) return undefined;
    return tflite.loadTFLiteModel(`https://huggingface.co/rocca/lyra-v2-soundstream/resolve/main/tflite/${modelVersionSelectEl.value}/${fileName}`);
  };

  console.log("Loading models...");
  // All requested downloads are started before any of them is awaited.
  const loaded = await Promise.all([
    loadIfWanted(opts.encoder, "soundstream_encoder.tflite"),
    loadIfWanted(opts.quantizerEncoder, "quantizer_encoder.tflite"),
    loadIfWanted(opts.quantizerDecoder, "quantizer_decoder.tflite"),
    loadIfWanted(opts.decoder, "lyragan.tflite"),
  ]);
  encoderModel = loaded[0];
  quantizerEncoderModel = loaded[1];
  quantizerDecoderModel = loaded[2];
  decoderModel = loaded[3];
  console.log("Models finished loading.");
}
// Runs the encoder model on one frame of audio.
//
// `audioSamples` should be a 320-element Float32Array.
// Resolves to the tensor returned by the model (float32[1,1,64]); the caller owns it.
// Throws if the encoder model has not been initialised via initModels({encoder:true}).
async function encode(audioSamples) {
  if(!encoderModel) throw new Error("You need to init the encoder model first.");
  const audioSamplesTensor = tf.tensor(audioSamples, [1,320], "float32");
  try {
    // `return await` (not a bare `return`) so the `finally` below only runs once inference has settled.
    return await encoderModel.predict({"serving_default_input_audio:0":audioSamplesTensor}); // float32[1,1,64]
  } finally {
    // The input tensor is created in here, so no caller can ever release it: dispose it
    // ourselves (on success and on failure) instead of leaking one tensor per frame.
    audioSamplesTensor.dispose();
  }
}
// Runs the quantizer-encoder model on an embedding.
//
// `embedding` should be a tensor of type float32[1,1,64].
// `numQuantizers` is an int32 scalar and can be either 16, 32, or 46.
// Explanation: https://github.com/google/lyra/issues/92#issuecomment-1265883697
// Resolves to whatever the model's predict() returns (int32[46,1,1]).
// Throws if the quantizer encoder model has not been initialised.
async function quantize(embedding, numQuantizers) {
  if(!quantizerEncoderModel) {
    throw new Error("You need to init the quantizer model first.");
  }
  // Need to work out how to select the encoder subgraph: https://github.com/tensorflow/tfjs/issues/6919
  const modelInputs = {
    "encode_input_frames:0": embedding,
    "encode_num_quantizers:0": numQuantizers,
  };
  const bits = await quantizerEncoderModel.predict(modelInputs);
  return bits; // int32[46,1,1]
}
// Runs the quantizer-decoder model to turn quantized bits back into an embedding.
//
// `bits` should be a tensor of type int32[46,1,1].
// Resolves to whatever the model's predict() returns (float32[1,1,64]).
// Throws if the quantizer decoder model has not been initialised.
async function dequantize(bits) {
  const model = quantizerDecoderModel;
  if(!model) {
    throw new Error("You need to init the quantizer model first.");
  }
  const embedding = await model.predict({"decode_encoding_indices:0": bits});
  return embedding; // float32[1,1,64]
}
// Runs the decoder model on an embedding to produce audio samples.
//
// `embedding` should be a tensor of type float32[1,1,64].
// Resolves to whatever the model's predict() returns (float32[1,320]).
// Throws if the decoder model has not been initialised.
async function decode(embedding) {
  if(!decoderModel) {
    throw new Error("You need to init the decoder model first.");
  }
  // NOTE(review): this is the same input name that encode() feeds the encoder with -
  // confirm it really matches the decoder model's input tensor name.
  const modelInputs = {"serving_default_input_audio:0": embedding};
  const audioTensor = await decoderModel.predict(modelInputs);
  return audioTensor; // float32[1,320]
}
// Decoded audio keyed by URL: each value is the Float32Array of channel-0 samples,
// resampled to 16kHz, for the file at that URL.
let audioCache = {};

// Returns chunk number `chunkI` (0-based) of the audio at `url` as a Float32Array copy.
// Chunks are 320 samples long (320 samples at 16khz = 20ms); a chunk that runs past the
// end of the audio is shorter (possibly empty).
//
// The file is fetched and decoded on the first call for a given `url` and served from
// `audioCache` afterwards. Rejects if the download fails with a non-2xx status or if the
// data cannot be decoded.
async function getAudioSamples(url, chunkI) {
  if(!audioCache[url]) {
    const response = await fetch(url);
    // Fail with a clear message instead of handing an HTTP error page to the audio decoder.
    if(!response.ok) throw new Error(`Failed to fetch audio from ${url}: HTTP ${response.status}`);
    const audioData = await response.arrayBuffer();
    // The context is only used for decodeAudioData(), which decodes the whole file and
    // resamples it to the context's sample rate; the context's own `length` only bounds
    // offline *rendering*, so the minimal length of 1 is enough. This removes the previous
    // duration probe through an <audio> element, which downloaded the file a second time
    // and never settled if the element failed to load.
    const audioCtx = new OfflineAudioContext({sampleRate:16000, length:1, numberOfChannels:1});
    const decodedData = await audioCtx.decodeAudioData(audioData); // resampled to the AudioContext's sampling rate
    audioCache[url] = decodedData.getChannelData(0); // Float32Array for channel 0
  }
  return audioCache[url].slice(chunkI*320, (chunkI+1)*320);
}
// Logged once when the script runs: each test function below initialises exactly one model
// (see the linked tfjs issue for the limitation).
console.warn("CURRENTLY WE CAN ONLY TEST ONE MODEL AT A TIME DUE TO THIS ISSUE: https://github.com/tensorflow/tfjs/issues/6094#issuecomment-1267870990");
// Manual smoke test for the encoder: loads only the encoder model, encodes the first
// 320-sample chunk of the hosted test.wav and logs the input and output to the console.
async function testEncoder() {
  const testAudioUrl = "https://huggingface.co/rocca/lyra-v2-soundstream/resolve/main/test.wav";

  console.log("Testing encoder:");
  await initModels({encoder:true});

  const firstChunk = await getAudioSamples(testAudioUrl, 0);
  console.log("INPUT:", firstChunk); // a 320-element Float32Array

  const encoded = await encode(firstChunk);
  const encodedValues = encoded.dataSync(); // float32[1,1,64]
  console.log("OUTPUT:", encodedValues);
  console.warn("Not sure why output is all zeros...");
}
// Manual smoke test for the decoder: loads only the decoder model, decodes a dummy
// all-ones embedding (float32[1,1,64]) and logs the input and output values to the console.
async function testDecoder() {
  console.log("Testing decoder:");
  await initModels({decoder:true});
  const inputEmbedding = tf.tensor(new Float32Array(64).fill(1), [1,1,64], "float32");
  // Log the values rather than the tensor object, like the quantizer tests do.
  console.log("INPUT:", inputEmbedding.dataSync());
  // The decoder's output is audio, not an embedding, so name it accordingly.
  const outputSamples = await decode(inputEmbedding);
  console.log("OUTPUT:", outputSamples.dataSync()); // a 320-element Float32Array
  console.warn("Possibly all zeros because of the strange input we're giving it. TODO: Once you get encoder working, use encoder's output as the input for this test.");
}
// Manual smoke test for the quantizer encoder: loads only that model, quantizes a dummy
// all-ones embedding (float32[1,1,64]) with 46 quantizers and logs input and output values.
async function testQuantizerEncoder() {
  console.log("Testing quantizer encoder:");
  await initModels({quantizerEncoder:true});

  const ones = new Float32Array(64).fill(1);
  const inputEmbedding = tf.tensor(ones, [1,1,64], "float32");
  console.log("INPUT:", inputEmbedding.dataSync());

  const numQuantizers = tf.scalar(46, "int32");
  const bits = await quantize(inputEmbedding, numQuantizers);
  const bitValues = bits.dataSync(); // int32[46,1,1]
  console.log("OUTPUT:", bitValues);
}
// Manual smoke test for the quantizer decoder: loads only that model, dequantizes a dummy
// all-ones bits tensor (int32[46,1,1]) and logs the input and output values to the console.
async function testQuantizerDecoder() {
  console.log("Testing quantizer decoder:");
  await initModels({quantizerDecoder:true});

  const ones = new Int32Array(46).fill(1);
  const bits = tf.tensor(ones, [46,1,1], "int32");
  console.log("INPUT:", bits.dataSync());

  const outputEmbedding = await dequantize(bits);
  const embeddingValues = outputEmbedding.dataSync(); // float32[1,1,64]
  console.log("OUTPUT:", embeddingValues);
}
</script>
</body>
</html>