add reduce_sum, reduce_prod, reduce_max, reduce_mean, test: [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py)
hughperkins committed Oct 28, 2016
1 parent cf87006 commit 3d6b2ac
Showing 7 changed files with 54 additions and 81 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -18,12 +18,12 @@ Please see the main repository for full Tensorflow documentation. This readme w
- Variables can be placed on GPU
- `matmul` (using [CLBlast](https://github.com/CNugteren/CLBlast))
- some gradients
- `reduction_sum` (not tested)
- `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_mean` at least partly working, [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py)
- training works :-)))

### To do

- reduction operations
- `reduce_min`
- convolutions

## Installation
@@ -92,6 +92,7 @@ python ~/git/tensorflow-cl/tensorflow/stream_executor/cl/test/test_gradients.py

- Oct 28:
- training working :-) [test_gradients.py](tensorflow/stream_executor/cl/test/test_gradients.py)
- `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_mean` added, in beta [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py)
- Oct 25:
- fixed BLAS wrapper, working now, on GPU, test script: [test_blas.py](tensorflow/stream_executor/cl/test/test_blas.py)
- int32 constant works on gpu now, [test_ints.py](tensorflow/stream_executor/cl/test/test_ints.py)
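For reference, a minimal sketch of how the newly listed reductions can be exercised on the OpenCL device, following the same pattern as [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py); the random input and the lack of a session config are assumptions, not part of this commit:

from __future__ import print_function
import numpy as np
import tensorflow as tf

# Minimal check of the four reductions on /gpu:0, compared against numpy.
np.random.seed(123)
a = np.random.uniform(0, 2, (1, 10)).astype(np.float32)
with tf.Session() as sess:
    with tf.device('/gpu:0'):
        tf_a = tf.placeholder(tf.float32, [None, None], 'a')
        tf_outs = [tf.reduce_sum(tf_a), tf.reduce_prod(tf_a),
                   tf.reduce_max(tf_a), tf.reduce_mean(tf_a)]
    outs = sess.run(tf_outs, {tf_a: a})
for name, got, want in zip(['sum', 'prod', 'max', 'mean'], outs,
                           [a.sum(), a.prod(), a.max(), a.mean()]):
    print(name, got, 'diff', abs(got - want))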
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/reduction_ops_max.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
// #if GOOGLE_CUDA

#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@@ -34,7 +34,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
// REGISTER_GPU_KERNELS(double);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -50,6 +50,6 @@ REGISTER_KERNEL_BUILDER(

#undef REGISTER_GPU_KERNELS

#endif
// #endif

} // namespace tensorflow
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/reduction_ops_mean.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
// #if GOOGLE_CUDA

#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@@ -33,11 +33,11 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<type>("T") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MeanReducer<type>>);
REGISTER_GPU_KERNELS(Eigen::half);
// REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
// REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif
// #endif

} // namespace tensorflow
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/reduction_ops_min.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
// #if GOOGLE_CUDA

#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@@ -34,7 +34,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
// REGISTER_GPU_KERNELS(double);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -50,6 +50,6 @@ REGISTER_KERNEL_BUILDER(

#undef REGISTER_GPU_KERNELS

#endif
// #endif

} // namespace tensorflow
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/reduction_ops_prod.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
// #if GOOGLE_CUDA

#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@@ -33,12 +33,12 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<type>("T") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
REGISTER_GPU_KERNELS(Eigen::half);
// REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
// REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif
// #endif

} // namespace tensorflow
65 changes: 0 additions & 65 deletions tensorflow/stream_executor/cl/test/test_reduction.py

This file was deleted.

37 changes: 37 additions & 0 deletions tensorflow/stream_executor/cl/test/test_reductions.py
@@ -0,0 +1,37 @@
from __future__ import print_function
import tensorflow as tf
import numpy as np


def test(tf_func, py_func):
    print('func', tf_func)
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        with tf.device('/gpu:0'):
            tf_a = tf.placeholder(tf.float32, [None, None], 'a')
            # tf_b = tf.placeholder(tf.float32, [None, None], 'b')
            tf_c = tf.__dict__[tf_func](tf_a, name="c")

            np.random.seed(123)
            shape = (1, 10)
            a = np.random.choice(50, shape) / 25
            # b = np.random.choice(50, shape) / 25

            ar, cr = sess.run((tf_a, tf_c), {tf_a: a})
            print('ar', ar)
            # print('br', br)
            print('cr', cr)
            c_py = eval(py_func)
            diff = np.abs(c_py - cr).max()
            print('diff', diff)
            assert diff < 1e-4, 'failed for %s' % tf_func


funcs = {
    'reduce_sum': 'np.sum(a)',
    'reduce_max': 'np.max(a)',
    # 'reduce_min': 'np.min(a)',
    'reduce_prod': 'np.prod(a)',
    'reduce_mean': 'np.mean(a)'
}
for tf_func, py_func in funcs.items():
    test(tf_func, py_func)
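The harness above only checks whole-tensor reductions. A possible follow-up check, not part of this commit, would reduce along a single axis; `reduction_indices` is the TensorFlow-0.x spelling of the axis argument, and whether per-axis reductions already work on the OpenCL backend is an assumption here:

import numpy as np
import tensorflow as tf

# Hypothetical extension (not in this commit): row-wise reduce_sum on /gpu:0,
# compared against numpy's axis-1 sum.
a = np.random.uniform(0, 2, (3, 10)).astype(np.float32)
with tf.Session() as sess:
    with tf.device('/gpu:0'):
        tf_a = tf.placeholder(tf.float32, [None, None], 'a')
        tf_row_sum = tf.reduce_sum(tf_a, reduction_indices=[1])
    row_sum = sess.run(tf_row_sum, {tf_a: a})
print('max diff', np.abs(row_sum - a.sum(axis=1)).max())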
