diff --git a/README.md b/README.md
index 81f343520..a54ed443b 100644
--- a/README.md
+++ b/README.md
@@ -18,12 +18,12 @@ Please see the main repository for full Tensorflow documentation. This readme w
 - Variables can be placed on GPU
 - `matmul` (using [CLBlast](https://github.com/CNugteren/CLBlast))
 - some gradients
-- `reduction_sum` (not tested)
+- `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_mean` at least partly working, [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py)
 - training works :-)))
 
 ### To do
 
-- reduction operations
+- `reduce_min`
 - convolutions
 
 ## Installation
@@ -92,6 +92,7 @@ python ~/git/tensorflow-cl/tensorflow/stream_executor/cl/test/test_gradients.py
 
 - Oct 28:
   - training working :-) [test_gradients.py](tensorflow/stream_executor/cl/test/test_gradients.py)
+  - `reduce_sum`, `reduce_prod`, `reduce_max`, `reduce_mean` added, in beta [test_reductions.py](tensorflow/stream_executor/cl/test/test_reductions.py)
 - Oct 25:
   - fixed BLAS wrapper, working now, on GPU, test script: [test_blas.py](tensorflow/stream_executor/cl/test/test_blas.py)
   - int32 constant works on gpu now, [test_ints.py](tensorflow/stream_executor/cl/test/test_ints.py)
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 6d3feeb66..4f169498e 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+// #if GOOGLE_CUDA
 
 #define REGISTER_GPU_KERNELS(type)          \
   REGISTER_KERNEL_BUILDER(                  \
@@ -34,7 +34,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
           .HostMemory("reduction_indices"), \
       ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
 REGISTER_GPU_KERNELS(float);
-REGISTER_GPU_KERNELS(double);
+// REGISTER_GPU_KERNELS(double);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -50,6 +50,6 @@ REGISTER_KERNEL_BUILDER(
 
 #undef REGISTER_GPU_KERNELS
 
-#endif
+// #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
index 6c75dff3e..447dd291d 100644
--- a/tensorflow/core/kernels/reduction_ops_mean.cc
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+// #if GOOGLE_CUDA
 
 #define REGISTER_GPU_KERNELS(type)          \
   REGISTER_KERNEL_BUILDER(                  \
@@ -33,11 +33,11 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
           .TypeConstraint<type>("T")        \
           .HostMemory("reduction_indices"), \
       ReductionOp<GPUDevice, type, Eigen::internal::MeanReducer<type>>);
-REGISTER_GPU_KERNELS(Eigen::half);
+// REGISTER_GPU_KERNELS(Eigen::half);
 REGISTER_GPU_KERNELS(float);
-REGISTER_GPU_KERNELS(double);
+// REGISTER_GPU_KERNELS(double);
 #undef REGISTER_GPU_KERNELS
 
-#endif
+// #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index c567aca0b..c1483b58d 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+// #if GOOGLE_CUDA
 
 #define REGISTER_GPU_KERNELS(type)          \
   REGISTER_KERNEL_BUILDER(                  \
@@ -34,7 +34,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
           .HostMemory("reduction_indices"), \
       ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
 REGISTER_GPU_KERNELS(float);
-REGISTER_GPU_KERNELS(double);
+// REGISTER_GPU_KERNELS(double);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -50,6 +50,6 @@ REGISTER_KERNEL_BUILDER(
 
 #undef REGISTER_GPU_KERNELS
 
-#endif
+// #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
index e824fe426..1ce03668d 100644
--- a/tensorflow/core/kernels/reduction_ops_prod.cc
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
 #undef REGISTER_CPU_KERNELS
 
-#if GOOGLE_CUDA
+// #if GOOGLE_CUDA
 
 #define REGISTER_GPU_KERNELS(type)          \
   REGISTER_KERNEL_BUILDER(                  \
@@ -33,12 +33,12 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
           .TypeConstraint<type>("T")        \
           .HostMemory("reduction_indices"), \
       ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
-REGISTER_GPU_KERNELS(Eigen::half);
+// REGISTER_GPU_KERNELS(Eigen::half);
 REGISTER_GPU_KERNELS(int32);
 REGISTER_GPU_KERNELS(float);
-REGISTER_GPU_KERNELS(double);
+// REGISTER_GPU_KERNELS(double);
 #undef REGISTER_GPU_KERNELS
 
-#endif
+// #endif
 
 }  // namespace tensorflow
diff --git a/tensorflow/stream_executor/cl/test/test_reduction.py b/tensorflow/stream_executor/cl/test/test_reduction.py
deleted file mode 100644
index 224d696cb..000000000
--- a/tensorflow/stream_executor/cl/test/test_reduction.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from __future__ import print_function
-
-import tensorflow as tf
-import numpy as np
-
-learning_rate = 0.1
-
-# lets learn or
-# we'll use one-hot, with 2 binary inputs, so 4 input neurons in total
-# output is one binary value, so 2 output neurons (since one-hot)
-data = [
-    {'input': [False, False], 'output': False},
-    {'input': [False, True], 'output': True},
-    {'input': [True, False], 'output': True},
-    {'input': [True, True], 'output': True}
-]
-batch_size = len(data)
-X = np.zeros((batch_size, 4), dtype=np.float32)
-y = np.zeros((batch_size, 2), dtype=np.float32)
-for n, ex in enumerate(data):
-    input = ex['input']
-    output = ex['output']
-    if input[0]:
-        X[n][1] = 1
-    else:
-        X[n][0] = 1
-    if input[1]:
-        X[n][3] = 1
-    else:
-        X[n][2] = 1
-    if output:
-        y[n][1] = 1
-    else:
-        y[n][0] = 1
-print('X', X)
-print('y', y)
-
-with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
-    tf_x = tf.placeholder(tf.float32, [None, 4], 'x')
-    tf_W = tf.Variable(tf.zeros([4, 2]), 'W')
-    tf_bias = tf.Variable(tf.zeros(2,), 'bias')
-    tf_out = tf.matmul(tf_x, tf_W, name="out") + tf_bias
-
-    np.random.seed(123)
-
-    W_init = np.random.uniform(size=(4, 2)).astype(np.float32)
-    sess.run(tf.assign(tf_W, W_init))
-    # print(sess.run(tf_W))
-    bias_init = np.random.uniform(size=(2,)).astype(np.float32)
-    sess.run(tf.assign(tf_bias, bias_init))
-
-    tf_y = tf.placeholder(tf.float32, [None, 2], 'y')
-    tf_loss = tf.square(tf_y - tf_out)
-    tf_red = tf.reduce_sum(tf_out)
-    print('red', sess.run(tf_red, {tf_x: X, tf_y: y}))
-
-    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
-    train_op = optimizer.minimize(tf_loss)
-
-    for epoch in range(4):
-        loss, out, _ = sess.run((tf_loss, tf_out, train_op), {tf_x: X, tf_y: y})
-        if epoch % 1 == 0:
-            print('epoch', epoch)
-            print('loss', loss)
-            print(np.argmax(out, 1))
diff --git a/tensorflow/stream_executor/cl/test/test_reductions.py b/tensorflow/stream_executor/cl/test/test_reductions.py
new file mode 100644
index 000000000..52770e761
--- /dev/null
+++ b/tensorflow/stream_executor/cl/test/test_reductions.py
@@ -0,0 +1,37 @@
+from __future__ import print_function
+import tensorflow as tf
+import numpy as np
+
+
+def test(tf_func, py_func):
+    print('func', tf_func)
+    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
+        with tf.device('/gpu:0'):
+            tf_a = tf.placeholder(tf.float32, [None, None], 'a')
+            # tf_b = tf.placeholder(tf.float32, [None, None], 'b')
+            tf_c = tf.__dict__[tf_func](tf_a, name="c")
+
+        np.random.seed(123)
+        shape = (1, 10)
+        a = np.random.choice(50, shape) / 25
+        # b = np.random.choice(50, shape) / 25
+
+        ar, cr = sess.run((tf_a, tf_c), {tf_a: a})
+        print('ar', ar)
+        # print('br', br)
+        print('cr', cr)
+        c_py = eval(py_func)
+        diff = np.abs(c_py - cr).max()
+        print('diff', diff)
+        assert diff < 1e-4, 'failed for %s' % tf_func
+
+
+funcs = {
+    'reduce_sum': 'np.sum(a)',
+    'reduce_max': 'np.max(a)',
+    # 'reduce_min': 'np.min(a)',
+    'reduce_prod': 'np.prod(a)',
+    'reduce_mean': 'np.mean(a)'
+}
+for tf_func, py_func in funcs.items():
+    test(tf_func, py_func)
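
For reference, a minimal standalone check of one of the newly registered reductions might look like the sketch below. It mirrors the structure of the added test_reductions.py and assumes the same TF 0.x-era API (`tf.Session`, `tf.placeholder`, `tf.device('/gpu:0')`, `log_device_placement`) used in the repo's other test scripts; `reduce_min` is left out since it is still listed under "To do".

```python
from __future__ import print_function

import numpy as np
import tensorflow as tf

# Sketch: run a single reduce_sum on the OpenCL GPU device and compare the
# result against numpy. log_device_placement shows whether the op actually
# landed on /gpu:0.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    with tf.device('/gpu:0'):
        tf_a = tf.placeholder(tf.float32, [None, None], 'a')
        tf_c = tf.reduce_sum(tf_a, name='c')

    a = np.random.uniform(size=(1, 10)).astype(np.float32)
    cr = sess.run(tf_c, {tf_a: a})
    print('gpu:', cr, 'numpy:', np.sum(a), 'diff:', abs(np.sum(a) - cr))
```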