Implement Hadamard transform for dim = 20 * 2^k and 28 * 2^k
tridao committed Nov 30, 2023
1 parent 2c81f4e commit de409a6
Showing 9 changed files with 641 additions and 155 deletions.
14 changes: 13 additions & 1 deletion benchmarks/benchmark_fast_hadamard_transform.py
@@ -2,11 +2,13 @@

from flash_attn.utils.benchmark import benchmark_forward, pytorch_profiler
from fast_hadamard_transform import hadamard_transform
from fast_hadamard_transform.fast_hadamard_transform_interface import hadamard_transform_20N
from fast_hadamard_transform.fast_hadamard_transform_interface import hadamard_transform_28N


batch_size = 16
seqlen = 2048
-dim = 8192
+dim = 16384
dtype = torch.float16
device = "cuda"

@@ -16,3 +18,13 @@
pytorch_profiler(hadamard_transform, x)
benchmark_forward(torch.clone, x, desc="torch.clone")
pytorch_profiler(torch.clone, x)

dim = 20 * 512
x = torch.randn(batch_size, seqlen, dim, dtype=dtype, device=device)
benchmark_forward(hadamard_transform_20N, x, 1.0, desc="Hadamard transform 20N")
pytorch_profiler(hadamard_transform_20N, x, 1.0)

dim = 28 * 512
x = torch.randn(batch_size, seqlen, dim, dtype=dtype, device=device)
benchmark_forward(hadamard_transform_28N, x, 1.0, desc="Hadamard transform 28N")
pytorch_profiler(hadamard_transform_28N, x, 1.0)
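For intuition (this is not code from the commit): a transform for dim = 20 * 2^k can be described as multiplying the last dimension by a Kronecker product of the standard 2^k-point Sylvester-Hadamard matrix with a fixed 20 x 20 Hadamard matrix, and the new kernels exploit that factorization. A minimal NumPy reference sketch follows; the factor ordering and the scaling convention are assumptions that may not match the kernel's internal layout, and hadamard_20N_reference is a hypothetical helper, not part of the package.

# Hypothetical dense reference for a dim = 20 * 2^k Hadamard-style transform.
# H20 is a 20 x 20 matrix of +/-1 entries (e.g. parsed from had_20_will in csrc/code_gen.py).
import numpy as np
from scipy.linalg import hadamard  # Sylvester-Hadamard matrices of order 2^k

def hadamard_20N_reference(x, H20, scale=1.0):
    # x has shape (..., 20 * 2^k); apply the transform along the last axis.
    pow2 = x.shape[-1] // 20
    H = np.kron(hadamard(pow2), H20)  # (20 * 2^k, 20 * 2^k), entries +/-1; ordering assumed
    return scale * (x @ H.T)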
107 changes: 107 additions & 0 deletions csrc/code_gen.py
@@ -0,0 +1,107 @@
import math
import re
from pathlib import Path

import numpy as np

# From http://neilsloane.com/hadamard/

had_20_will = """
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
"""


had_28_will = """
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
"""

header = """
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// This file is auto-generated. See "code_gen.py"\n
#pragma once
"""

template = """
__device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {
float out[{N}];
{code}
#pragma unroll
for (int i = 0; i < {N}; i++) { x[i] = out[i]; }
}
"""


def string_to_array(string):
    # Convert strings of + and - to a matrix of +1 / -1 entries
    rows = string.strip().split()
    return np.stack([np.array([1 if c == '+' else -1 for c in row], dtype=np.int32)
                     for row in rows])


def array_code_gen(arr):
    N = arr.shape[0]
    assert arr.shape[0] == arr.shape[1]
    out = []
    for i in range(N):
        out.append(f"out[{i}] = " + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + ";")
    return template.replace("{N}", str(N)).replace("{code}", '\n    '.join(out))



def main():
    output_path = Path(__file__).parent / "fast_hadamard_transform_special.h"
    output_path.write_text(header + array_code_gen(string_to_array(had_20_will)) + array_code_gen(string_to_array(had_28_will)))


if __name__ == '__main__':
    main()
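Each generated hadamard_mult_thread_{N} is a fully unrolled N x N sign-pattern multiply (for N = 20, the first generated statement reads out[0] = + x[0] - x[1] - x[2] ... ;). A quick way to sanity-check the transcribed matrices — a sketch, assuming it is run from inside csrc/ so that code_gen.py is importable — is that a Hadamard matrix H of order n satisfies H H^T = n I:

# Hypothetical sanity check for the +/-1 matrices used by the code generator.
import numpy as np

from code_gen import had_20_will, had_28_will, string_to_array

for s, n in [(had_20_will, 20), (had_28_will, 28)]:
    H = string_to_array(s)
    assert H.shape == (n, n)
    # Rows of a Hadamard matrix are mutually orthogonal with squared norm n.
    assert np.array_equal(H @ H.T, n * np.eye(n, dtype=np.int64))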
99 changes: 97 additions & 2 deletions csrc/fast_hadamard_transform.cpp
@@ -28,10 +28,17 @@
template<typename input_t>
void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream);

void set_hadamard_params(HadamardParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t multiple,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor out,
@@ -43,7 +50,7 @@ void set_hadamard_params(HadamardParamsBase &params,

    params.batch = batch;
    params.dim = dim;
-    params.log_dim = int(ceil(std::log2(dim)));
+    params.log_N = int(ceil(std::log2(dim / multiple)));

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
@@ -84,7 +91,7 @@ fast_hadamard_transform(at::Tensor &x, float scale) {
    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
-    set_hadamard_params(params, batch_size, dim, x, out, scale);
+    set_hadamard_params(params, batch_size, dim, 1, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
@@ -99,6 +106,94 @@ fast_hadamard_transform(at::Tensor &x, float scale) {
    return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_20N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % (4 * 20) != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 20) - dim_og % (4 * 20)}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % (4 * 20) == 0, "fast_hadamard_transform_20N only supports hidden dimension divisible by 80 for now");
    TORCH_CHECK(dim <= 20 * 1024, "fast_hadamard_transform_20N only supports hidden dimension at most 20480 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 20, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_20N_cuda<input_t>(params, stream);
    });
    if (dim_og % (4 * 20) != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}
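The host-side pattern above (and in the 28N variant below) is: flatten to 2-D, make the last dimension contiguous, zero-pad it up to a multiple of 4 * 20 = 80, launch the kernel, slice the padding back off, and restore the original shape. A rough Python-level sketch of that wrapper logic, with kernel_20N as a hypothetical stand-in for the CUDA launch:

import torch
import torch.nn.functional as F

def hadamard_20N_wrapper(x, scale, kernel_20N):
    shapes_og = x.shape
    dim_og = shapes_og[-1]
    x = x.reshape(-1, dim_og)
    if x.stride(-1) != 1:
        x = x.contiguous()
    pad = (-dim_og) % (4 * 20)      # amount needed to reach a multiple of 80
    if pad:
        x = F.pad(x, (0, pad))      # zero-pad the last dimension
    out = kernel_20N(x, scale)      # stand-in for the CUDA kernel launch
    if pad:
        out = out[:, :dim_og]       # drop the padded columns again
    return out.reshape(shapes_og)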

at::Tensor
fast_hadamard_transform_28N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % (4 * 28) != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 28) - dim_og % (4 * 28)}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % (4 * 28) == 0, "fast_hadamard_transform_28N only supports hidden dimension divisible by 112 for now");
    TORCH_CHECK(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports hidden dimension at most 28672 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 28, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_28N_cuda<input_t>(params, stream);
    });
    if (dim_og % (4 * 28) != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fast_hadamard_transform", &fast_hadamard_transform, "Fast Hadamard transform");
    m.def("fast_hadamard_transform_20N", &fast_hadamard_transform_20N, "Fast Hadamard transform with dimension = 20 * power of 2");
    m.def("fast_hadamard_transform_28N", &fast_hadamard_transform_28N, "Fast Hadamard transform with dimension = 28 * power of 2");
}
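The bindings are consumed from Python through the interface module that the benchmark above imports. One property that holds regardless of the internal factor ordering: with scale = 1/sqrt(dim) the transform is orthogonal and therefore preserves L2 norms (unlike the plain power-of-2 transform, the 20N/28N variants are not necessarily involutions, since the 20 x 20 and 28 x 28 factors need not be symmetric). A quick check along those lines — a sketch assuming a CUDA device and the installed package:

import math

import torch
from fast_hadamard_transform.fast_hadamard_transform_interface import hadamard_transform_20N

# dim = 5120 is divisible by 80, so no zero-padding is involved in this check.
x = torch.randn(8, 20 * 256, dtype=torch.float32, device="cuda")
y = hadamard_transform_20N(x, 1.0 / math.sqrt(x.shape[-1]))  # orthonormal scaling
# An orthogonal transform preserves the norm of every row.
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), rtol=1e-3, atol=1e-3))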
2 changes: 1 addition & 1 deletion csrc/fast_hadamard_transform.h
@@ -9,7 +9,7 @@
struct HadamardParamsBase {
    using index_t = int64_t;

-    int batch, dim, log_dim;
+    int batch, dim, log_N;

    index_t x_batch_stride;
    index_t out_batch_stride;