Patch 0.8.2 (#1079)
* Fix kappa (#1047)

* fix kappa
* add more tests and rename regression variable
* add cross_entropy test for binary class model

(cherry picked from commit 83df531)

* Typo in type #1067 (#1069)
(cherry picked from commit 99352d0)

* Bump to 0.8.2

* Add CI to release branches

Co-authored-by: Aakash Kumar Nain <[email protected]>
Co-authored-by: failure-to-thrive <[email protected]>
3 people authored Feb 12, 2020
1 parent a68be5c commit fcaa672
Showing 5 changed files with 191 additions and 71 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci_test.yml
@@ -4,9 +4,11 @@ on:
push:
branches:
- master
- r*
pull_request:
branches:
- master
- r*

env:
BAZEL_VERSION: 1.1.0
116 changes: 60 additions & 56 deletions tensorflow_addons/layers/normalizations.py
@@ -22,7 +22,7 @@
from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.keras.utils.register_keras_serializable(package="Addons")
class GroupNormalization(tf.keras.layers.Layer):
"""Group normalization layer.
@@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
"""

@typechecked
def __init__(self,
groups: int = 2,
axis: int = -1,
epsilon: int = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: types.Initializer = 'zeros',
gamma_initializer: types.Initializer = 'ones',
beta_regularizer: types.Regularizer = None,
gamma_regularizer: types.Regularizer = None,
beta_constraint: types.Constraint = None,
gamma_constraint: types.Constraint = None,
**kwargs):
def __init__(
self,
groups: int = 2,
axis: int = -1,
epsilon: float = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: types.Initializer = "zeros",
gamma_initializer: types.Initializer = "ones",
beta_regularizer: types.Regularizer = None,
gamma_regularizer: types.Regularizer = None,
beta_constraint: types.Constraint = None,
gamma_constraint: types.Constraint = None,
**kwargs
):
super().__init__(**kwargs)
self.supports_masking = True
self.groups = groups
@@ -117,39 +119,32 @@ def call(self, inputs):
tensor_input_shape = tf.shape(inputs)

reshaped_inputs, group_shape = self._reshape_into_groups(
inputs, input_shape, tensor_input_shape)
inputs, input_shape, tensor_input_shape
)

normalized_inputs = self._apply_normalization(reshaped_inputs,
input_shape)
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)

outputs = tf.reshape(normalized_inputs, tensor_input_shape)

return outputs

def get_config(self):
config = {
'groups':
self.groups,
'axis':
self.axis,
'epsilon':
self.epsilon,
'center':
self.center,
'scale':
self.scale,
'beta_initializer':
tf.keras.initializers.serialize(self.beta_initializer),
'gamma_initializer':
tf.keras.initializers.serialize(self.gamma_initializer),
'beta_regularizer':
tf.keras.regularizers.serialize(self.beta_regularizer),
'gamma_regularizer':
tf.keras.regularizers.serialize(self.gamma_regularizer),
'beta_constraint':
tf.keras.constraints.serialize(self.beta_constraint),
'gamma_constraint':
tf.keras.constraints.serialize(self.gamma_constraint)
"groups": self.groups,
"axis": self.axis,
"epsilon": self.epsilon,
"center": self.center,
"scale": self.scale,
"beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
"gamma_initializer": tf.keras.initializers.serialize(
self.gamma_initializer
),
"beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
"gamma_regularizer": tf.keras.regularizers.serialize(
self.gamma_regularizer
),
"beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
"gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
}
base_config = super().get_config()
return {**base_config, **config}
@@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
group_reduction_axes.pop(axis)

mean, variance = tf.nn.moments(
reshaped_inputs, group_reduction_axes, keepdims=True)
reshaped_inputs, group_reduction_axes, keepdims=True
)

gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
@@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=self.epsilon)
variance_epsilon=self.epsilon,
)
return normalized_inputs

def _get_reshaped_weights(self, input_shape):
@@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
raise ValueError(
"Axis " + str(self.axis) + " of "
"input tensor should have a defined dimension "
"but the layer received an input with shape " + str(input_shape) + "."
)

def _set_number_of_groups_for_instance_norm(self, input_shape):
dim = input_shape[self.axis]
@@ -216,26 +214,30 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
'Number of groups (' + str(self.groups) + ') cannot be '
'more than the number of channels (' + str(dim) + ').')
"Number of groups (" + str(self.groups) + ") cannot be "
"more than the number of channels (" + str(dim) + ")."
)

if dim % self.groups != 0:
raise ValueError(
'Number of groups (' + str(self.groups) + ') must be a '
'multiple of the number of channels (' + str(dim) + ').')
"Number of groups (" + str(self.groups) + ") must be a "
"multiple of the number of channels (" + str(dim) + ")."
)

def _check_axis(self):

if self.axis == 0:
raise ValueError(
"You are trying to normalize your batch axis. Do you want to "
"use tf.layer.batch_normalization instead")
"use tf.layer.batch_normalization instead"
)

def _create_input_spec(self, input_shape):

dim = input_shape[self.axis]
self.input_spec = tf.keras.layers.InputSpec(
ndim=len(input_shape), axes={self.axis: dim})
ndim=len(input_shape), axes={self.axis: dim}
)

def _add_gamma_weight(self, input_shape):

@@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name='gamma',
name="gamma",
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
constraint=self.gamma_constraint,
)
else:
self.gamma = None

@@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
if self.center:
self.beta = self.add_weight(
shape=shape,
name='beta',
name="beta",
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
constraint=self.beta_constraint,
)
else:
self.beta = None

@@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
return broadcast_shape


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.keras.utils.register_keras_serializable(package="Addons")
class InstanceNormalization(GroupNormalization):
"""Instance normalization layer.
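For context, a minimal usage sketch (not part of this diff) of the two layers touched here, assuming TensorFlow Addons 0.8.2 is installed and imported as `tfa`; the model architecture and the `groups` value are illustrative only:

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, input_shape=(28, 28, 1)),
    # Channel count (16) must be divisible by `groups`; _check_size_of_dimensions enforces this.
    tfa.layers.GroupNormalization(groups=4, axis=-1),
    tf.keras.layers.Conv2D(16, 3),
    # InstanceNormalization is GroupNormalization with one channel per group (groups=-1).
    tfa.layers.InstanceNormalization(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.summary()

Because both classes are registered via @tf.keras.utils.register_keras_serializable(package="Addons"), a model containing them can be saved and reloaded with the standard Keras serialization APIs, provided tensorflow_addons is imported at load time.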
62 changes: 51 additions & 11 deletions tensorflow_addons/metrics/cohens_kappa.py
@@ -67,18 +67,26 @@ def __init__(self,
num_classes: FloatTensorLike,
name: str = 'cohen_kappa',
weightage: Optional[str] = None,
sparse_labels: bool = False,
regression: bool = False,
dtype: AcceptableDTypes = None,
**kwargs):
"""Creates a `CohenKappa` instance.
Args:
num_classes: Number of unique classes in your dataset.
name: (Optional) String name of the metric instance.
weightage: (Optional) Weighting to be considered for calculating
weightage: (optional) Weighting to be considered for calculating
kappa statistics. A valid value is one of
[None, 'linear', 'quadratic']. Defaults to `None`.
dtype: (Optional) Data type of the metric result.
Defaults to `None`.
[None, 'linear', 'quadratic']. Defaults to `None`
sparse_labels: (bool) Valid only for the multi-class scenario.
If True, ground truth labels are expected to be integers
and not one-hot encoded.
regression: (bool) If True, the problem is treated as a regression
problem where you are regressing the predictions.
**Note:** If you are regressing the values, the output layer
should contain a single unit.
name: (optional) String name of the metric instance
dtype: (optional) Data type of the metric result. Defaults to `None`
Raises:
ValueError: If the value passed for `weightage` is invalid
@@ -89,8 +97,18 @@ def __init__(self,
if weightage not in (None, 'linear', 'quadratic'):
raise ValueError("Unknown kappa weighting type.")

if num_classes == 2:
self._update = self._update_binary_class_model
elif num_classes > 2:
self._update = self._update_multi_class_model
else:
raise ValueError("""Number of classes must be
greater than or euqal to two""")

self.weightage = weightage
self.num_classes = num_classes
self.regression = regression
self.sparse_labels = sparse_labels
self.conf_mtx = self.add_weight(
'conf_mtx',
shape=(self.num_classes, self.num_classes),
@@ -114,22 +132,42 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Returns:
Update op.
"""
return self._update(y_true, y_pred, sample_weight)

def _update_binary_class_model(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(y_true, dtype=tf.int64)
y_pred = tf.cast(y_pred, dtype=tf.int64)
y_pred = tf.cast(y_pred, dtype=tf.float32)
y_pred = tf.cast(y_pred > 0.5, dtype=tf.int64)
return self._update_confusion_matrix(y_true, y_pred, sample_weight)

def _update_multi_class_model(self, y_true, y_pred, sample_weight=None):
if not self.sparse_labels:
y_true = tf.cast(tf.argmax(y_true, axis=-1), dtype=tf.int64)
else:
y_true = tf.cast(y_true, dtype=tf.int64)

if tf.rank(y_pred) > 1:
if not self.regression:
y_pred = tf.cast(tf.argmax(y_pred, axis=-1), dtype=tf.int64)
else:
y_pred = tf.math.round(tf.math.abs(y_pred))
y_pred = tf.cast(y_pred, dtype=tf.int64)
else:
y_pred = tf.cast(y_pred, dtype=tf.int64)

return self._update_confusion_matrix(y_true, y_pred, sample_weight)

if y_true.shape != y_pred.shape:
raise ValueError(
"Number of samples in `y_true` and `y_pred` are different")
def _update_confusion_matrix(self, y_true, y_pred, sample_weight):
y_true = tf.squeeze(y_true)
y_pred = tf.squeeze(y_pred)

# compute the new values of the confusion matrix
new_conf_mtx = tf.math.confusion_matrix(
labels=y_true,
predictions=y_pred,
num_classes=self.num_classes,
weights=sample_weight,
dtype=tf.float32)

# update the values in the original confusion matrix
return self.conf_mtx.assign_add(new_conf_mtx)

def result(self):
@@ -179,6 +217,8 @@ def get_config(self):
config = {
"num_classes": self.num_classes,
"weightage": self.weightage,
"sparse_labels": self.sparse_labels,
"regression": self.regression
}
base_config = super().get_config()
return {**base_config, **config}
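For context, a small usage sketch (not part of this diff) of the fixed metric, assuming TensorFlow Addons 0.8.2 imported as `tfa`; the labels and probabilities below are made-up illustrative values:

import tensorflow as tf
import tensorflow_addons as tfa

# Binary model: probabilities are thresholded at 0.5 by _update_binary_class_model.
binary_kappa = tfa.metrics.CohenKappa(num_classes=2)
binary_kappa.update_state(
    tf.constant([1, 0, 1, 1], dtype=tf.int32),
    tf.constant([0.9, 0.2, 0.4, 0.7], dtype=tf.float32),
)
print("binary kappa:", binary_kappa.result().numpy())

# Multi-class model with integer labels via the new sparse_labels flag;
# _update_multi_class_model takes the argmax of the class probabilities.
multi_kappa = tfa.metrics.CohenKappa(
    num_classes=3, sparse_labels=True, weightage="quadratic")
multi_kappa.update_state(
    tf.constant([0, 2, 1, 2], dtype=tf.int32),
    tf.constant([[0.7, 0.2, 0.1],
                 [0.1, 0.2, 0.7],
                 [0.2, 0.6, 0.2],
                 [0.2, 0.6, 0.2]], dtype=tf.float32),
)
print("weighted kappa:", multi_kappa.result().numpy())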