forked from PAIR-code/lit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.ts
230 lines (207 loc) · 6.23 KB
/
utils.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
/**
* @license
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Shared helper functions used across the app.
*/
import * as d3 from 'd3'; // Used for array helpers.
import {FacetMap, LitName, LitType, ModelsMap, Spec} from './types';
/**
 * Picks a random integer in the half-open range [min, max).
 *
 * Both bounds are expected to be integers; results for non-integer bounds
 * are not specified.
 *
 * @param min Inclusive lower bound.
 * @param max Exclusive upper bound.
 */
export function randInt(min: number, max: number) {
  const span = max - min;
  return Math.floor(Math.random() * span + min);
}
/**
 * Reports whether two sets hold exactly the same members.
 *
 * Equal sizes plus "every member of A is in B" is sufficient for equality,
 * so only one direction needs to be checked.
 */
export function setEquals<T>(setA: Set<T>, setB: Set<T>) {
  return setA.size === setB.size && [...setA].every(member => setB.has(member));
}
/**
 * Reports whether two arrays contain the same distinct items.
 *
 * Duplicates and ordering are ignored: both arrays are reduced to sets
 * before comparison.
 */
export function arrayContainsSame<T>(arrayA: T[], arrayB: T[]) {
  const distinctA = new Set<T>(arrayA);
  const distinctB = new Set<T>(arrayB);
  return setEquals(distinctA, distinctB);
}
/**
 * Checks whether a spec field (LitType) is an instance of any of the given
 * type names.
 *
 * This mirrors Python's isinstance(litType, typesToFind) and relies on the
 * Python class hierarchy being exported in the `__mro__` field.
 *
 * @param litType Spec field to test; a null/undefined value yields false.
 * @param typesToFind A single type name or a list of names; matching any one
 *     of them is enough.
 */
export function isLitSubtype(litType: LitType, typesToFind: LitName|LitName[]) {
  // TODO(lit-dev): figure out why this is occasionally called on an invalid
  // spec. Likely due to skew between keys and specs in specific modules when
  // dataset is changed, but worth diagnosing to make sure this doesn't mask a
  // larger issue.
  if (litType == null) {
    return false;
  }
  // Accept either a single name or a list by normalizing to a list first.
  const candidates =
      typeof typesToFind === 'string' ? [typesToFind] : typesToFind;
  return candidates.some(name => litType.__mro__.includes(name));
}
/**
 * Lists the keys of `spec` whose field matches any of the given type names.
 *
 * @param spec Mapping from field name to LitType.
 * @param typesToFind A single type name or a list of names.
 * @returns Matching keys, in Object.keys order.
 */
export function findSpecKeys(
    spec: Spec, typesToFind: LitName|LitName[]): string[] {
  // Narrowing (rather than an `as` cast) gives a list of names in both cases.
  const wanted: LitName[] =
      typeof typesToFind === 'string' ? [typesToFind] : typesToFind;
  const matches: string[] = [];
  for (const key of Object.keys(spec)) {
    if (isLitSubtype(spec[key], wanted)) {
      matches.push(key);
    }
  }
  return matches;
}
/**
 * Flattens an array of arrays by exactly one level.
 *
 * Items of the inner arrays are copied, in order, into a new array; items
 * that are themselves arrays are not flattened further.
 */
export function flatten<T>(arr: T[][]): T[] {
  const flat: T[] = [];
  for (const inner of arr) {
    for (const item of inner) {
      flat.push(item);
    }
  }
  return flat;
}
/**
 * Reorders an array according to an index permutation.
 *
 * Position i of the result holds arr[perm[i]]. The result has the same
 * length as `arr`; `arr` itself is not modified.
 */
export function permute<T>(arr: T[], perm: number[]): T[] {
  return Array.from({length: arr.length}, (_, position) => arr[perm[position]]);
}
/**
 * Keystroke handler that invokes `callback` only when Enter was pressed.
 *
 * @param e Original keyboard event.
 * @param callback User-defined callback to run on Enter.
 */
export function handleEnterKey(e: KeyboardEvent, callback: () => void) {
  if (e.key !== 'Enter') {
    return;
  }
  callback();
}
/**
 * Converts a margin value into a binary-classification threshold.
 *
 * A missing (null/undefined) or zero margin gives 0.5. Otherwise the
 * threshold is the logistic sigmoid of the margin, 1 / (1 + e^-margin).
 */
export function getThresholdFromMargin(margin: number) {
  const isNeutral = margin == null || margin === 0;
  if (isNeutral) {
    return .5;
  }
  return 1 / (1 + Math.exp(-margin));
}
/**
 * Truncates an input id to its first six characters for display in the UI.
 *
 * @returns The shortened id, or undefined when `id` is null/undefined.
 */
export function shortenId(id: string|null) {
  return id == null ? undefined : id.slice(0, 6);
}
/**
 * Returns true for finite numbers.
 *
 * Strings are coerced as well (e.g. "2" counts), but blank or
 * whitespace-only strings are rejected even though they would coerce to 0.
 * Every other type returns false.
 */
// tslint:disable-next-line:no-any
export function isNumber(num: any) {
  switch (typeof num) {
    case 'number':
      // Excludes NaN and +/-Infinity.
      return Number.isFinite(num);
    case 'string':
      return num.trim() !== '' && Number.isFinite(Number(num));
    default:
      return false;
  }
}
/**
 * Builds the array [0, 1, ..., size - 1].
 *
 * Non-positive sizes produce an empty array.
 */
export function range(size: number) {
  return Array.from({length: size}, (_, index) => index);
}
/**
 * Adds up the items of an array; an empty array sums to 0.
 */
export function sumArray(array: number[]) {
  let total = 0;
  array.forEach(value => {
    total += value;
  });
  return total;
}
/**
 * Cumulative (running) sum of an array.
 *
 * Element i of the result is array[0] + ... + array[i]. The input array is
 * not modified.
 *
 * @param array Numbers to accumulate.
 * @returns A new array, the same length as `array`, of running totals.
 */
export function cumSumArray(array: number[]) {
  // An explicit accumulator with map replaces the previous reduce, which
  // discarded its result and was used only to assign into an outer array.
  let runningTotal = 0;
  return array.map(value => {
    runningTotal += value;
    return runningTotal;
  });
}
/**
 * Python-style (lexicographic) array comparison.
 * Compare on first element, then second, and so on until a mismatch is found.
 * If one array is a prefix of another, the longer one is treated as larger.
 * Example:
 * [1] < [1,2] < [1,3] < [1,3,0] < [2]
 *
 * Elements are compared with d3.ascending. The first non-zero element
 * comparison is returned as-is, including any non-comparable result it may
 * produce.
 *
 * @returns A comparator value suitable for Array.prototype.sort.
 */
export function compareArrays(a: d3.Primitive[], b: d3.Primitive[]): number {
  // Walk the shared prefix by index instead of recursing on slice(1), which
  // copied the remaining tail at every step (quadratic work, deep recursion).
  const sharedLength = Math.min(a.length, b.length);
  for (let i = 0; i < sharedLength; i++) {
    const comparison = d3.ascending(a[i], b[i]);
    if (comparison !== 0) {
      return comparison;
    }
  }
  // One array is a prefix of the other (or they are equal): longer wins.
  return d3.ascending(a.length, b.length);
}
/**
 * Reports whether any model's output spec has a field of one of the given
 * types.
 *
 * @param models Models to inspect; each model's `spec.output` is searched.
 * @param typesToCheck A single type name or a list of names.
 */
export function doesOutputSpecContain(
    models: ModelsMap, typesToCheck: LitName|LitName[]): boolean {
  return Object.keys(models).some(modelName => {
    const outputSpec = models[modelName].spec.output;
    return findSpecKeys(outputSpec, typesToCheck).length > 0;
  });
}
/**
 * Turns a facet map into a human-readable key string.
 *
 * Keys are sorted first, so the result does not depend on property order.
 * Each entry is rendered as `key:value` (value stringified via a template
 * literal), and entries are joined with '/'.
 */
export function objToDictKey(dict: FacetMap) {
  const parts: string[] = [];
  for (const key of Object.keys(dict).sort()) {
    parts.push(`${key}:${dict[key]}`);
  }
  return parts.join('/');
}
/**
 * Rounds a number to the nearest value with at most `places` decimal places.
 *
 * This rounds to nearest, not upward. Rounding follows Math.round, so exact
 * halves go toward +Infinity. Before scaling, Number.EPSILON is added so that
 * values whose binary form sits just below a half (e.g. 1.005) round as they
 * read in decimal (1.01 at two places).
 *
 * @param num Value to round.
 * @param places Number of decimal places to keep; a negative value returns
 *     `num` unchanged.
 */
export function roundToDecimalPlaces(num: number, places: number) {
  if (places < 0) {
    return num;
  }
  const numForPlaces = Math.pow(10, places);
  return Math.round((num + Number.EPSILON) * numForPlaces) / numForPlaces;
}