forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ui.js
95 lines (80 loc) · 2.97 KB
/
ui.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tfvis from '@tensorflow/tfjs-vis';
const statusElement = document.getElementById('status');
const messageElement = document.getElementById('message');
const imagesElement = document.getElementById('images');
export function isTraining() {
statusElement.innerText = 'Training...';
}
const lossArr = [];
export function trainingLog(loss, iteration) {
messageElement.innerText = `loss[${iteration}]: ${loss}`;
lossArr.push({x: iteration, y: loss});
const container = {name: 'Loss', tab: 'Training'};
const options = {
xLabel: 'Training Step',
yLavel: 'Loss',
};
const data = {values: lossArr, series: ['loss']};
tfvis.render.linechart(container, data, options);
}
export function showTestResults(batch, predictions, labels) {
statusElement.innerText = 'Testing...';
const testExamples = batch.xs.shape[0];
let totalCorrect = 0;
for (let i = 0; i < testExamples; i++) {
const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);
const div = document.createElement('div');
div.className = 'pred-container';
const canvas = document.createElement('canvas');
draw(image.flatten(), canvas);
const pred = document.createElement('div');
const prediction = predictions[i];
const label = labels[i];
const correct = prediction === label;
if (correct) {
totalCorrect++;
}
pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
pred.innerText = `pred: ${prediction}`;
div.appendChild(pred);
div.appendChild(canvas);
imagesElement.appendChild(div);
}
const accuracy = 100 * totalCorrect / testExamples;
const displayStr =
`Accuracy: ${accuracy.toFixed(2)}% (${totalCorrect} / ${testExamples})`;
messageElement.innerText = `${displayStr}\n`;
console.log(displayStr);
}
export function draw(image, canvas) {
const [width, height] = [28, 28];
canvas.width = width;
canvas.height = height;
const ctx = canvas.getContext('2d');
const imageData = new ImageData(width, height);
const data = image.dataSync();
for (let i = 0; i < height * width; ++i) {
const j = i * 4;
imageData.data[j + 0] = data[i] * 255;
imageData.data[j + 1] = data[i] * 255;
imageData.data[j + 2] = data[i] * 255;
imageData.data[j + 3] = 255;
}
ctx.putImageData(imageData, 0, 0);
}