Add embeddings_ layer and supporting utility functions #3021

Merged · 16 commits · Oct 17, 2024
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"githubPullRequests.ignoredPullRequestBranches": [
"master"
]
}
115 changes: 115 additions & 0 deletions dlib/cuda/cpu_dlib.cpp
@@ -2421,6 +2421,121 @@ namespace dlib
}

// ------------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
DLIB_CASSERT(
src.nr() > 0 &&
embs.num_samples() > 0 &&
embs.k() > 0 &&
embs.nr() == 1 &&
embs.nc() == 1,
"\nsrc.num_samples(): " << src.num_samples() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nembs.num_samples(): " << embs.num_samples() <<
"\nembs.k(): " << embs.k() <<
"\nembs.nr(): " << embs.nr() <<
"\nembs.nc(): " << embs.nc()
);

long ns = dest.num_samples(), nk = dest.k(), nr = dest.nr(), nc = dest.nc();
const float* src_data = src.host();
float* dest_data = dest.host();
const float* embs_data = embs.host();
for (long s = 0; s < ns; ++s)
{
for (long k = 0; k < nk; ++k)
{
for (long r = 0; r < nr; ++r)
{
const unsigned long token_idx = static_cast<unsigned long>(src_data[tensor_index(src, s, k, r, 0)]);
if (token_idx < embs.num_samples())
{
for (long c = 0; c < nc; ++c)
dest_data[tensor_index(dest, s, k, r, c)] = embs_data[tensor_index(embs, token_idx, c, 0, 0)];
}
else
{
for (long c = 0; c < nc; ++c)
dest_data[tensor_index(dest, s, k, r, c)] = 0;
}
}
}
}
}

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
DLIB_CASSERT(
prev.nr() > 0 &&
gradient_input.num_samples() == prev.num_samples() &&
gradient_input.k() == prev.k() &&
gradient_input.nr() == prev.nr() &&
gradient_input.nc() == grads.k() &&
grads.num_samples() > 0 &&
grads.k() > 0 &&
grads.nr() == 1 &&
grads.nc() == 1,
"\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
"\ngradient_input.k(): " << gradient_input.k() <<
"\ngradient_input.nr(): " << gradient_input.nr() <<
"\ngradient_input.nc(): " << gradient_input.nc() <<
"\nprev.num_samples(): " << prev.num_samples() <<
"\nprev.k(): " << prev.k() <<
"\nprev.nr(): " << prev.nr() <<
"\nprev.nc(): " << prev.nc() <<
"\ngrads.num_samples(): " << grads.num_samples() <<
"\ngrads.k(): " << grads.k() <<
"\ngrads.nr(): " << grads.nr() <<
"\ngrads.nc(): " << grads.nc()
);

const float* prev_data = prev.host();
const float* gradient_input_data = gradient_input.host();
const float* freqs_data = freqs.host();
float* grads_data = grads.host();
long ns = gradient_input.num_samples(), nk = gradient_input.k();
long nr = gradient_input.nr(), nc = gradient_input.nc();

std::vector<dlib::mutex> embedding_mutexes(grads.num_samples());
parallel_for(0, ns * nk, [&](long i)
{
long s = i / nk;
long k = i % nk;

for (long r = 0; r < nr; ++r)
{
const unsigned long token_idx = static_cast<unsigned long>(prev_data[tensor_index(prev, s, k, r, 0)]);
if (token_idx < grads.num_samples())
{
const float freq_token = freqs_data[token_idx];
float freq_scale = 1.0f;

if (scale && freq_token != 0.0f) freq_scale = std::min(0.15f, std::max(1.0f / freq_token, 1.0f));
auto_mutex locker(embedding_mutexes[token_idx]);
for (long c = 0; c < nc; ++c)
{
const float gradient = gradient_input_data[tensor_index(gradient_input, s, k, r, c)];
grads_data[tensor_index(grads, token_idx, c, 0, 0)] -= (gradient * learning_rate * freq_scale);
}
}
}
});
}

// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------

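For reviewers who want to exercise the new CPU path in isolation, here is a minimal sketch of a direct call. The shapes, the values, and the direct use of dlib::cpu::embeddings outside the layer are illustrative assumptions, not code from this PR:

#include <dlib/dnn.h>

int main()
{
    // src holds integer token ids stored as floats, one id per (sample, k, row);
    // embs is the table: one row per token id, embs.k() floats per embedding.
    dlib::resizable_tensor src(2, 1, 4, 1);    // 2 samples, sequence length 4
    dlib::resizable_tensor embs(10, 8, 1, 1);  // vocabulary of 10 tokens, dimension 8
    dlib::resizable_tensor dest(2, 1, 4, 8);   // one 8-float embedding per token position

    src = 0; embs = 1;
    src.host()[0] = 3;    // first position of the first sample holds token id 3
    src.host()[1] = 42;   // out-of-vocabulary id (>= embs.num_samples())

    dlib::cpu::embeddings(dest, src, embs);
    // dest(0,0,0,c) == embs(3,c,0,0) for every c, while dest(0,0,1,*) stays all zeros,
    // because ids outside [0, embs.num_samples()) are mapped to a zero row.
    return 0;
}

The forward pass is a pure gather, so it needs no locking; serialization only matters in embeddings_gradient above, where several positions can update the same table row.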
17 changes: 17 additions & 0 deletions dlib/cuda/cpu_dlib.h
@@ -517,6 +517,23 @@ namespace dlib
const tensor& gradient_input
);

// -----------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
);

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
);

// -----------------------------------------------------------------------------------

class pooling
120 changes: 120 additions & 0 deletions dlib/cuda/cuda_dlib.cu
@@ -2088,6 +2088,126 @@ namespace dlib
row_stride, col_stride, add_to);
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc,
float* d, const float* s, const float* e, size_t es
)
{
for (auto i : grid_stride_range(0, dsize))
{
const auto n = i / (dk * dr * dc);
const auto s_idx = i % (dk * dr * dc);
const auto k = (s_idx / (dr * dc)) % dk;
const auto r = (s_idx / dc) % dr;
const auto c = s_idx % dc;

const unsigned long t_idx = static_cast<unsigned long>(s[(n * dk + k) * dr + r]);

if (t_idx < es)
d[i] = e[t_idx * dc + c];
else
d[i] = 0.0f;
}
}

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
DLIB_CASSERT(
src.nr() > 0 &&
embs.num_samples() > 0 &&
embs.k() > 0 &&
embs.nr() == 1 &&
embs.nc() == 1,
"\nsrc.num_samples(): " << src.num_samples() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\nembs.num_samples(): " << embs.num_samples() <<
"\nembs.k(): " << embs.k() <<
"\nembs.nr(): " << embs.nr() <<
"\nembs.nc(): " << embs.nc()
);

const long dk = dest.k();
const long dr = dest.nr();
const long dc = dest.nc();

launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc,
dest.device(), src.device(), embs.device(), embs.num_samples());
}

__global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc,
const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es
)
{
for (auto i : grid_stride_range(0, ssize))
{
const auto n = i / (sk * sr * sc);
const auto s_idx = i % (sk * sr * sc);
const auto k = (s_idx / (sr * sc)) % sk;
const auto r = (s_idx / sc) % sr;
const auto c = s_idx % sc;

const unsigned long t_idx = static_cast<unsigned long>(o[(n * sk + k) * sr + r]);
if (t_idx < es)
{
const float f_t = f[t_idx];
float f_s = 1.0f;

if (sl && f_t != 0.0f) f_s = fminf(0.15f, fmaxf(1.0f / f_t, 1.0f));
if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s);
else g[t_idx * sc + c] -= gi[i] * lr * f_s;
}
}
}

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
DLIB_CASSERT(
prev.nr() > 0 &&
gradient_input.num_samples() == prev.num_samples() &&
gradient_input.k() == prev.k() &&
gradient_input.nr() == prev.nr() &&
gradient_input.nc() == grads.k() &&
grads.num_samples() > 0 &&
grads.k() > 0 &&
grads.nr() == 1 &&
grads.nc() == 1,
"\ngradient_input.num_samples(): " << gradient_input.num_samples() <<
"\ngradient_input.k(): " << gradient_input.k() <<
"\ngradient_input.nr(): " << gradient_input.nr() <<
"\ngradient_input.nc(): " << gradient_input.nc() <<
"\nprev.num_samples(): " << prev.num_samples() <<
"\nprev.k(): " << prev.k() <<
"\nprev.nr(): " << prev.nr() <<
"\nprev.nc(): " << prev.nc() <<
"\ngrads.num_samples(): " << grads.num_samples() <<
"\ngrads.k(): " << grads.k() <<
"\ngrads.nr(): " << grads.nr() <<
"\ngrads.nc(): " << grads.nc()
);

const long sk = gradient_input.k();
const long sr = gradient_input.nr();
const long sc = gradient_input.nc();

launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc,
prev.device(), gradient_input.device(), grads.device(), freqs.device(),
learning_rate, scale, grads.num_samples());
}

// ----------------------------------------------------------------------------------------

__global__ void _cuda_layer_normalize(
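The forward and backward kernels share the same flat-index arithmetic. As a reading aid only (not part of the PR), the decomposition of a flat index i over an (ns, dk, dr, dc) destination tensor can be written out on the host like this:

// Host-side mirror of the kernels' index math (illustrative sketch only).
inline void decompose_index(size_t i, size_t dk, size_t dr, size_t dc,
                            size_t& n, size_t& k, size_t& r, size_t& c)
{
    const size_t s_idx = i % (dk * dr * dc);
    n = i / (dk * dr * dc);        // sample
    k = (s_idx / (dr * dc)) % dk;  // channel
    r = (s_idx / dc) % dr;         // row, i.e. token position
    c = s_idx % dc;                // column, i.e. embedding dimension
}
// The token id for element i is read from src at (n, k, r, 0), i.e. the flat offset
// (n * dk + k) * dr + r, and dest[i] is copied from embs[t_idx * dc + c] when t_idx is in range.

Each thread writes exactly one dest element, so the forward kernel needs no synchronization; the gradient kernel accumulates into shared rows of grads, which is presumably why it switches to atomicAdd when the recorded frequency of a token is above one.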
17 changes: 17 additions & 0 deletions dlib/cuda/cuda_dlib.h
@@ -561,6 +561,23 @@ namespace dlib
const tensor& gradient_input
);

// -----------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
);

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
);

// ----------------------------------------------------------------------------------------

void copy_tensor(
31 changes: 31 additions & 0 deletions dlib/cuda/tensor_tools.cpp
@@ -1296,6 +1296,37 @@ namespace dlib { namespace tt
#endif
}

// ----------------------------------------------------------------------------------------

void embeddings(
resizable_tensor& dest,
const tensor& src,
const tensor& embs
)
{
#ifdef DLIB_USE_CUDA
cuda::embeddings(dest, src, embs);
#else
cpu::embeddings(dest, src, embs);
#endif
}

void embeddings_gradient(
const tensor& prev,
const tensor& gradient_input,
tensor& grads,
const tensor& freqs,
float learning_rate,
bool scale
)
{
#ifdef DLIB_USE_CUDA
cuda::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#else
cpu::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale);
#endif
}

// ----------------------------------------------------------------------------------------

}}
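Putting the two dispatch entry points together, a minimal end-to-end sketch of a forward lookup followed by the scaled update looks like the following. The shapes and values are assumptions, the declarations are assumed to live in tensor_tools.h (not shown in this excerpt), and updating the table in place through the grads argument is only one possible way the layer could wire these calls:

#include <dlib/dnn.h>

int main()
{
    using namespace dlib;
    resizable_tensor prev(2, 1, 4, 1);     // token ids used in the forward pass
    resizable_tensor grad_in(2, 1, 4, 8);  // dLoss/dOutput from the layer above
    resizable_tensor embs(10, 8, 1, 1);    // embedding table, updated in place below
    resizable_tensor freqs(10, 1, 1, 1);   // per-token occurrence counts used for scaling
    resizable_tensor out(2, 1, 4, 8);

    prev = 0; grad_in = 1; embs = 0; freqs = 5;

    tt::embeddings(out, prev, embs);       // gather: out(s,k,r,c) = embs(id, c, 0, 0)

    // Every referenced row of embs is decremented by gradient * learning_rate * freq_scale,
    // where freq_scale is derived from freqs when scale == true.
    tt::embeddings_gradient(prev, grad_in, embs, freqs, 0.01f /*learning_rate*/, true /*scale*/);
    return 0;
}

Whichever backend is compiled in, DLIB_USE_CUDA selects the CUDA kernels above and everything else falls back to the cpu:: implementations, so the call sites stay identical.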