diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..058954976e --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "githubPullRequests.ignoredPullRequestBranches": [ + "master" + ] +} \ No newline at end of file diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp index 58c322ae7b..0e5ca5cee6 100644 --- a/dlib/cuda/cpu_dlib.cpp +++ b/dlib/cuda/cpu_dlib.cpp @@ -2421,6 +2421,121 @@ namespace dlib } // ------------------------------------------------------------------------------------ + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ) + { + DLIB_CASSERT( + src.nr() > 0 && + embs.num_samples() > 0 && + embs.k() > 0 && + embs.nr() == 1 && + embs.nc() == 1, + "\nsrc.num_samples(): " << src.num_samples() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\nembs.num_samples(): " << embs.num_samples() << + "\nembs.k(): " << embs.k() << + "\nembs.nr(): " << embs.nr() << + "\nembs.nc(): " << embs.nc() + ); + + long ns = dest.num_samples(), nk = dest.k(), nr = dest.nr(), nc = dest.nc(); + const float* src_data = src.host(); + float* dest_data = dest.host(); + const float* embs_data = embs.host(); + for (long s = 0; s < ns; ++s) + { + for (long k = 0; k < nk; ++k) + { + for (long r = 0; r < nr; ++r) + { + const unsigned long token_idx = static_cast(src_data[tensor_index(src, s, k, r, 0)]); + if (token_idx < embs.num_samples()) + { + for (long c = 0; c < nc; ++c) + dest_data[tensor_index(dest, s, k, r, c)] = embs_data[tensor_index(embs, token_idx, c, 0, 0)]; + } + else + { + for (long c = 0; c < nc; ++c) + dest_data[tensor_index(dest, s, k, r, c)] = 0; + } + } + } + } + } + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ) + { + DLIB_CASSERT( + prev.nr() > 0 && + gradient_input.num_samples() == prev.num_samples() && + gradient_input.k() == prev.k() && + gradient_input.nr() == prev.nr() && + gradient_input.nc() == grads.k() && + grads.num_samples() > 0 && + grads.k() > 0 && + grads.nr() == 1 && + grads.nc() == 1, + "\ngradient_input.num_samples(): " << gradient_input.num_samples() << + "\ngradient_input.k(): " << gradient_input.k() << + "\ngradient_input.nr(): " << gradient_input.nr() << + "\ngradient_input.nc(): " << gradient_input.nc() << + "\nprev.num_samples(): " << prev.num_samples() << + "\nprev.k(): " << prev.k() << + "\nprev.nr(): " << prev.nr() << + "\nprev.nc(): " << prev.nc() << + "\ngrads.num_samples(): " << grads.num_samples() << + "\ngrads.k(): " << grads.k() << + "\ngrads.nr(): " << grads.nr() << + "\ngrads.nc(): " << grads.nc() + ); + + const float* prev_data = prev.host(); + const float* gradient_input_data = gradient_input.host(); + const float* freqs_data = freqs.host(); + float* grads_data = grads.host(); + long ns = gradient_input.num_samples(), nk = gradient_input.k(); + long nr = gradient_input.nr(), nc = gradient_input.nc(); + + std::vector embedding_mutexes(grads.num_samples()); + parallel_for(0, ns * nk, [&](long i) + { + long s = i / nk; + long k = i % nk; + + for (long r = 0; r < nr; ++r) + { + const unsigned long token_idx = static_cast(prev_data[tensor_index(prev, s, k, r, 0)]); + if (token_idx < grads.num_samples()) + { + const float freg_token = freqs_data[token_idx]; + float freq_scale = 1.0f; + + if (scale && freg_token != 0.0f) freq_scale = std::min(0.15f, std::max(1.0f / freg_token, 1.0f)); + auto_mutex locker(embedding_mutexes[token_idx]); + for (long c = 0; c < nc; ++c) + { + const float gradient = gradient_input_data[tensor_index(gradient_input, s, k, r, c)]; + grads_data[tensor_index(grads, token_idx, c, 0, 0)] -= (gradient * learning_rate * freq_scale); + } + } + } + }); + } + // ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------ diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h index e0896193c3..f26795445d 100644 --- a/dlib/cuda/cpu_dlib.h +++ b/dlib/cuda/cpu_dlib.h @@ -517,6 +517,23 @@ namespace dlib const tensor& gradient_input ); + // ----------------------------------------------------------------------------------- + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ); + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ); + // ----------------------------------------------------------------------------------- class pooling diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu index 171bb55ff1..c650e89bee 100644 --- a/dlib/cuda/cuda_dlib.cu +++ b/dlib/cuda/cuda_dlib.cu @@ -2088,6 +2088,126 @@ namespace dlib row_stride, col_stride, add_to); } + // ---------------------------------------------------------------------------------------- + + __global__ void _cuda_embeddings(size_t dsize, size_t dk, size_t dr, size_t dc, + float* d, const float* s, const float* e, size_t es + ) + { + for (auto i : grid_stride_range(0, dsize)) + { + const auto n = i / (dk * dr * dc); + const auto s_idx = i % (dk * dr * dc); + const auto k = (s_idx / (dr * dc)) % dk; + const auto r = (s_idx / dc) % dr; + const auto c = s_idx % dc; + + const unsigned long t_idx = static_cast(s[(n * dk + k) * dr + r]); + + if (t_idx < es) + d[i] = e[t_idx * dc + c]; + else + d[i] = 0.0f; + } + } + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ) + { + DLIB_CASSERT( + src.nr() > 0 && + embs.num_samples() > 0 && + embs.k() > 0 && + embs.nr() == 1 && + embs.nc() == 1, + "\nsrc.num_samples(): " << src.num_samples() << + "\nsrc.k(): " << src.k() << + "\nsrc.nr(): " << src.nr() << + "\nsrc.nc(): " << src.nc() << + "\nembs.num_samples(): " << embs.num_samples() << + "\nembs.k(): " << embs.k() << + "\nembs.nr(): " << embs.nr() << + "\nembs.nc(): " << embs.nc() + ); + + const long dk = dest.k(); + const long dr = dest.nr(); + const long dc = dest.nc(); + + launch_kernel(_cuda_embeddings, dest.size(), dk, dr, dc, + dest.device(), src.device(), embs.device(), embs.num_samples()); + } + + __global__ void _cuda_embeddings_gradient(size_t ssize, size_t sk, size_t sr, size_t sc, + const float* o, const float* gi, float* g, const float* f, float lr, bool sl, size_t es + ) + { + for (auto i : grid_stride_range(0, ssize)) + { + const auto n = i / (sk * sr * sc); + const auto s_idx = i % (sk * sr * sc); + const auto k = (s_idx / (sr * sc)) % sk; + const auto r = (s_idx / sc) % sr; + const auto c = s_idx % sc; + + const unsigned long t_idx = static_cast(o[(n * sk + k) * sr + r]); + if (t_idx < es) + { + const float f_t = f[t_idx]; + float f_s = 1.0f; + + if (sl && f_t != 0.0f) f_s = fminf(0.15f, fmaxf(1.0f / f_t, 1.0f)); + if (f_t > 1) atomicAdd(&g[t_idx * sc + c], -gi[i] * lr * f_s); + else g[t_idx * sc + c] -= gi[i] * lr * f_s; + } + } + } + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ) + { + DLIB_CASSERT( + prev.nr() > 0 && + gradient_input.num_samples() == prev.num_samples() && + gradient_input.k() == prev.k() && + gradient_input.nr() == prev.nr() && + gradient_input.nc() == grads.k() && + grads.num_samples() > 0 && + grads.k() > 0 && + grads.nr() == 1 && + grads.nc() == 1, + "\ngradient_input.num_samples(): " << gradient_input.num_samples() << + "\ngradient_input.k(): " << gradient_input.k() << + "\ngradient_input.nr(): " << gradient_input.nr() << + "\ngradient_input.nc(): " << gradient_input.nc() << + "\nprev.num_samples(): " << prev.num_samples() << + "\nprev.k(): " << prev.k() << + "\nprev.nr(): " << prev.nr() << + "\nprev.nc(): " << prev.nc() << + "\ngrads.num_samples(): " << grads.num_samples() << + "\ngrads.k(): " << grads.k() << + "\ngrads.nr(): " << grads.nr() << + "\ngrads.nc(): " << grads.nc() + ); + + const long sk = gradient_input.k(); + const long sr = gradient_input.nr(); + const long sc = gradient_input.nc(); + + launch_kernel(_cuda_embeddings_gradient, gradient_input.size(), sk, sr, sc, + prev.device(), gradient_input.device(), grads.device(), freqs.device(), + learning_rate, scale, grads.num_samples()); + } + // ---------------------------------------------------------------------------------------- __global__ void _cuda_layer_normalize( diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h index 7aa2e74a84..dab3627b1b 100644 --- a/dlib/cuda/cuda_dlib.h +++ b/dlib/cuda/cuda_dlib.h @@ -561,6 +561,23 @@ namespace dlib const tensor& gradient_input ); + // ----------------------------------------------------------------------------------- + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ); + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ); + // ---------------------------------------------------------------------------------------- void copy_tensor( diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp index 516e150b3c..90a09a2884 100644 --- a/dlib/cuda/tensor_tools.cpp +++ b/dlib/cuda/tensor_tools.cpp @@ -1296,6 +1296,37 @@ namespace dlib { namespace tt #endif } +// ---------------------------------------------------------------------------------------- + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ) + { +#ifdef DLIB_USE_CUDA + cuda::embeddings(dest, src, embs); +#else + cpu::embeddings(dest, src, embs); +#endif + } + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ) + { +#ifdef DLIB_USE_CUDA + cuda::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale); +#else + cpu::embeddings_gradient(prev, gradient_input, grads, freqs, learning_rate, scale); +#endif + } + // ---------------------------------------------------------------------------------------- }} diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h index b44b863560..8ea593a429 100644 --- a/dlib/cuda/tensor_tools.h +++ b/dlib/cuda/tensor_tools.h @@ -2050,6 +2050,78 @@ namespace dlib { namespace tt from the channel dimension of gradient_input to the spatial dimensions of grad. !*/ +// ---------------------------------------------------------------------------------------- + + void embeddings( + resizable_tensor& dest, + const tensor& src, + const tensor& embs + ); + /*! + requires + - src.nr() > 0 + - embs.num_samples() > 0 + - embs.k() > 0 + - embs.nr() == 1 + - embs.nc() == 1 + - dest.num_samples() == src.num_samples() + - dest.k() == src.k() + - dest.nr() == src.nr() + - dest.nc() == embs.k() + ensures + - Projects tokens from the input tensor `src` into embeddings stored in `embs`. + - The resulting embeddings are stored in the `dest` tensor. + - For all valid s (0 <= s < dest.num_samples()), + k (0 <= k < dest.k()), + r (0 <= r < dest.nr()), + c (0 <= c < dest.nc()): + - Let token_idx = static_cast(src(s,k,r,0)) + - If token_idx < embs.num_samples(): + - #dest(s,k,r,c) = embs(token_idx, c, 0, 0) + - Else: + - #dest(s,k,r,c) = 0 + - The function iterates over all elements of src and populates dest accordingly. + - If a token index in src is out of range (>= embs.num_samples()), + the corresponding embedding in dest is filled with 0's. + */ + + void embeddings_gradient( + const tensor& prev, + const tensor& gradient_input, + tensor& grads, + const tensor& freqs, + float learning_rate, + bool scale + ); + /*! + requires + - prev.nr() > 0 + - gradient_input.num_samples() == prev.num_samples() + - gradient_input.k() == prev.k() + - gradient_input.nr() == prev.nr() + - gradient_input.nc() == grads.k() + - grads.num_samples() > 0 + - grads.k() > 0 + - grads.nr() == 1 + - grads.nc() == 1 + - freqs.num_samples() == grads.num_samples() + - freqs.k() == 1 + - freqs.nr() == 1 + - freqs.nc() == 1 + ensures + - Updates the `grads` tensor based on the gradients in `gradient_input`. + - For each sample s, channel k, and row r in prev: + - Retrieves the token index from prev[s,k,r,0] + - If the token index is valid (< grads.num_samples()): + - If scale is true: + - Computes a frequency scale factor based on freqs[token_idx] + - The scale factor is min(0.15, max(1.0 / freqs[token_idx], 1.0)) + - For each column c in gradient_input: + - Updates grads[token_idx, c] -= gradient_input[s,k,r,c] * learning_rate * freq_scale + - The updates to grads are performed atomically to handle concurrent updates to the same embedding. + - The function is thread-safe and processes samples in parallel. + */ + // ---------------------------------------------------------------------------------------- class multi_device_tensor_averager diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index d0a0ccaa5a..f34e7a8390 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -4640,7 +4640,7 @@ namespace dlib class transpose_ { public: transpose_() {} - template void setup(const SUBNET& /* sub */) {} + template void setup(const SUBNET& /*sub*/) {} template void forward(const SUBNET& sub, resizable_tensor& output) { auto& prev = sub.get_output(); @@ -4672,21 +4672,21 @@ namespace dlib const tensor& get_layer_params() const { return params; } tensor& get_layer_params() { return params; } - friend void serialize(const transpose_& /* item */, std::ostream& out) { + friend void serialize(const transpose_& /*item*/, std::ostream& out) { serialize("transpose_", out); } - friend void deserialize(transpose_& /* item */, std::istream& in) { + friend void deserialize(transpose_& /*item*/, std::istream& in) { std::string version; deserialize(version, in); if (version != "transpose_") throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::transpose_."); } - friend std::ostream& operator<<(std::ostream& out, const transpose_& /* item */) { + friend std::ostream& operator<<(std::ostream& out, const transpose_& /*item*/) { out << "transpose"; return out; } - friend void to_xml(const transpose_& /* item */, std::ostream& out) { + friend void to_xml(const transpose_& /*item*/, std::ostream& out) { out << "\n"; } @@ -4698,6 +4698,268 @@ namespace dlib // ---------------------------------------------------------------------------------------- + class positional_encodings_ { + public: + positional_encodings_(unsigned long sequence_dim_ = 1, unsigned long embedding_dim_ = 1) : + sequence_dim(sequence_dim_), embedding_dim(embedding_dim_) + { + } + positional_encodings_(const positional_encodings_& item) : + pe(item.pe), sequence_dim(item.sequence_dim), embedding_dim(item.embedding_dim) + { + } + positional_encodings_& operator= (const positional_encodings_& item) { + if (this == &item) return *this; + pe = item.pe; + sequence_dim = item.sequence_dim; + embedding_dim = item.embedding_dim; + return *this; + } + + template + void setup(const SUBNET& sub) + { + auto& prev = sub.get_output(); + + sequence_dim = prev.nr(); + embedding_dim = prev.nc(); + const unsigned long ns = prev.num_samples(); + const unsigned long nk = prev.k(); + const float n = 10000.0f; + + pe.set_size(ns, nk, sequence_dim, embedding_dim); + for (unsigned long s = 0; s < ns; ++s) + { + for (unsigned long k = 0; k < nk; ++k) + { + for (unsigned long r = 0; r < sequence_dim; ++r) + { + for (unsigned long c = 0; c < embedding_dim; ++c) + { + float theta = static_cast(r) / std::pow(n, static_cast(c) / embedding_dim); + if (c % 2 == 0) pe.host()[tensor_index(pe, s, k, r, c)] = std::sin(theta); + else pe.host()[tensor_index(pe, s, k, r, c)] = std::cos(theta); + } + } + } + } + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + const auto& prev_output = sub.get_output(); + if (!have_same_dimensions(pe, prev_output)) setup(sub); + + output.set_size(prev_output.num_samples(), prev_output.k(), sequence_dim, embedding_dim); + tt::add(output, prev_output, pe); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + auto& prev_grad = sub.get_gradient_input(); + tt::add(prev_grad, prev_grad, gradient_input); + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + const tensor& get_positional_encodings() const { return pe; } + tensor& get_positional_encodings() { return pe; } + + friend void serialize(const positional_encodings_& /*item*/, std::ostream& out) + { + serialize("positional_encodings_", out); + } + friend void deserialize(positional_encodings_& /*item*/, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "positional_encodings_") + throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::positional_encodings_."); + } + + friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& /*item*/) + { + out << "positional_encodings"; + return out; + } + friend void to_xml(const positional_encodings_& /*item*/, std::ostream& out) + { + out << "\n"; + } + + private: + resizable_tensor params; // unused + resizable_tensor pe; + unsigned long sequence_dim, embedding_dim; + }; + + template + using positional_encodings = add_layer; + +// ---------------------------------------------------------------------------------------- + + template< + unsigned long num_embeddings_, + unsigned long embedding_dim_ + > + class embeddings_ + { + static_assert(num_embeddings_ > 0, "The size of the embedding dictionary must be > 0"); + static_assert(embedding_dim_ > 0, "The size of each embedding vector must be > 0"); + + public: + embeddings_() : num_embeddings(num_embeddings_), + embedding_dim(embedding_dim_), + learning_rate_multiplier(1.0f), + scale_by_freq(true) + { + } + + double get_learning_rate_multiplier() const { return learning_rate_multiplier; } + void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } + + void set_scale_by_freq(bool val) { scale_by_freq = val; } + bool get_scale_by_freq() const { return scale_by_freq; } + + unsigned long get_num_embeddings() const { return num_embeddings; } + void set_num_embeddings(unsigned long num) + { + DLIB_CASSERT(num > 0); + if (num != num_embeddings) + { + DLIB_CASSERT(get_embeddings().size() == 0, + "It is not possible to change the size of the embedding dictionary if the parameter has already been assigned."); + } + } + + unsigned long get_embedding_dim() const { return embedding_dim; } + void set_embedding_dim(unsigned long dim) + { + DLIB_CASSERT(dim > 0); + if (dim != embedding_dim) + { + DLIB_CASSERT(get_embeddings().size() == 0, + "It is not possible to change the size of the embedding dictionary if the parameter has already been assigned."); + } + } + + template + void setup(const SUBNET& /*sub*/) + { + embs.set_size(num_embeddings, embedding_dim); + tt::tensor_rand rnd(std::rand()); + rnd.fill_gaussian(embs); + } + + template + void forward(const SUBNET& sub, resizable_tensor& output) + { + const auto& prev = sub.get_output(); + output.set_size(prev.num_samples(), prev.k(), prev.nr(), embedding_dim); + + tt::embeddings(output, prev, embs); + } + + template + void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) + { + // Because this class is expected to be directly after an layer, + // it's not necessary to propagate the gradient. + // Additionally, this layer is treated as constant during backpropagation, + // so it technically doesn't contribute to the gradient computation. + if (learning_rate_multiplier != 0) + { + auto& prev_src = sub.get_output(); + + calc_token_freqs(prev_src, gradient_input); + tt::embeddings_gradient(prev_src, gradient_input, embs, freqs, learning_rate_multiplier, scale_by_freq); + } + } + + const tensor& get_layer_params() const { return params; } + tensor& get_layer_params() { return params; } + + const tensor& get_embeddings() const { return embs; } + tensor& get_embeddings() { return embs; } + + friend void serialize(const embeddings_& item, std::ostream& out) + { + serialize("embeddings_", out); + serialize(item.embs, out); + serialize(item.num_embeddings, out); + serialize(item.embedding_dim, out); + serialize(item.learning_rate_multiplier, out); + serialize(item.scale_by_freq, out); + } + friend void deserialize(embeddings_& item, std::istream& in) + { + std::string version; + deserialize(version, in); + if (version != "embeddings_") + throw serialization_error("Unexpected version found while deserializing dlib::embeddings_."); + deserialize(item.embs, in); + deserialize(item.num_embeddings, in); + deserialize(item.embedding_dim, in); + deserialize(item.learning_rate_multiplier, in); + deserialize(item.scale_by_freq, in); + } + + friend std::ostream& operator<<(std::ostream& out, const embeddings_& item) + { + out << "embeddings (num_embeddings=" << item.num_embeddings + << ", embedding_dim=" << item.embedding_dim + << ") learning_rate_mult=" << item.learning_rate_multiplier; + return out; + } + friend void to_xml(const embeddings_& item, std::ostream& out) + { + out << "\n"; + out << mat(item.embs); + out << "\n"; + } + + private: + void calc_token_freqs(const tensor& prev, const tensor& input) { + if (freqs.size() == 0) freqs.set_size(num_embeddings, 1, 1, 1); + freqs = 0; + + const float* prev_data = prev.host(); + float* freqs_data = freqs.host(); + for (long s = 0; s < input.num_samples(); ++s) + { + for (long k = 0; k < input.k(); ++k) + { + for (long r = 0; r < input.nr(); ++r) + { + const unsigned long token_idx = static_cast(prev_data[tensor_index(prev, s, k, r, 0)]); + if (token_idx < num_embeddings) freqs_data[tensor_index(freqs, token_idx, 0, 0, 0)]++; + } + } + } + } + + resizable_tensor params; // unused + resizable_tensor embs, freqs; + unsigned long num_embeddings, embedding_dim; + double learning_rate_multiplier; + bool scale_by_freq; + }; + + template < + unsigned long nb_embeddings, + unsigned long embedding_length, + typename SUBNET + > + using embeddings = add_layer, SUBNET>; + +// ---------------------------------------------------------------------------------------- + struct neg_infinity_tag {}; struct zero_tag {}; @@ -4826,6 +5088,4 @@ namespace dlib } -#endif // DLIB_DNn_LAYERS_H_ - - +#endif // DLIB_DNn_LAYERS_H_ \ No newline at end of file diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h index db61423a9a..0d951e7804 100644 --- a/dlib/dnn/layers_abstract.h +++ b/dlib/dnn/layers_abstract.h @@ -3711,6 +3711,226 @@ namespace dlib template using transpose = add_layer; +// ---------------------------------------------------------------------------------------- + + class positional_encodings_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + It defines a positional encoding layer that adds position information to + the input tensor. This is particularly useful in transformer architectures + where the order of the sequence matters. + + The dimensions of the tensors output by this layer are the same as the input + tensor dimensions. + + This implementation is based on the positional encoding described in: + Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., + Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In Advances + in neural information processing systems (pp. 5998-6008). + + The encoding uses sine and cosine functions of different frequencies: + PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) + PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) + where pos is the position and i is the dimension. + !*/ + + public: + + positional_encodings_( + unsigned long sequence_dim_ = 1, + unsigned long embedding_dim_ = 1 + ); + /*! + ensures + - #sequence_dim == sequence_dim_ + - #embedding_dim == embedding_dim_ + !*/ + + positional_encodings_ ( + const positional_encodings_& item + ); + /*! + ensures + - EXAMPLE_COMPUTATIONAL_LAYER_ objects are copy constructable + !*/ + + positional_encodings_& operator=( + const positional_encodings_& item + ); + /*! + ensures + - EXAMPLE_COMPUTATIONAL_LAYER_ objects are assignable + !*/ + + template + void setup ( + const SUBNET& sub + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + ensures + - performs any necessary setup for the layer, including the calculation + of positional encodings based on the dimensions of the input. + !*/ + + template + void forward( + const SUBNET& sub, + resizable_tensor& output + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + - setup() has been called. + ensures + - Adds the positional encodings to the output of the subnetwork and + stores the results into #output. + !*/ + + template + void backward( + const tensor& gradient_input, + SUBNET& sub, + tensor& params_grad + ); + /*! + requires + - SUBNET implements the SUBNET interface defined at the top of this file. + - setup() has been called. + - #params_grad is unused in this layer as there are no learnable parameters. + ensures + - Computes the gradient of the layer with respect to the input, which + is simply the input gradient itself as positional encodings are constant. + !*/ + + const tensor& get_layer_params( + ) const; + /*! + ensures + - returns the parameters that define the behavior of forward(). + Note: This layer has no learnable parameters, so this returns an empty tensor. + !*/ + + tensor& get_layer_params( + ); + /*! + ensures + - returns the parameters that define the behavior of forward(). + Note: This layer has no learnable parameters, so this returns an empty tensor. + !*/ + + const tensor& get_positional_encodings( + ) const; + /*! + ensures + - returns the computed positional encodings. + !*/ + + tensor& get_positional_encodings( + ); + /*! + ensures + - returns the computed positional encodings. + !*/ + + friend void serialize(const positional_encodings_& item, std::ostream& out); + friend void deserialize(positional_encodings_& item, std::istream& in); + /*! + provides serialization support + !*/ + + friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& item); + /*! + print a string describing this layer. + !*/ + + friend void to_xml(const positional_encodings_& item, std::ostream& out); + /*! + This function is optional, but required if you want to print your networks with + net_to_xml(). It prints a layer as XML. + !*/ + }; + + template + using positional_encodings = add_layer; + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long num_embeddings_, + unsigned long embedding_dim_ + > + class embeddings_ + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an embedding layer in a neural network. It maps discrete + tokens to continuous vector representations. This is a fundamental technique in + natural language processing and other domains dealing with categorical data. + + The layer takes as input a tensor of integer indices and outputs a tensor of + the same shape (except for the last dimension) where each index is replaced by + its corresponding embedding vector. + + For more information on embeddings, see: + Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., & Dean, J. (2013). + Distributed representations of words and phrases and their compositionality. + In Advances in neural information processing systems (pp. 3111-3119). + + TEMPLATE PARAMETERS + - num_embeddings_: The size of the embedding dictionary, i.e., the number of + discrete tokens that can be embedded. + - embedding_dim_: The dimensionality of each embedding vector. + + CONVENTION + - get_embeddings() returns the tensor of embedding vectors. + - get_num_embeddings() == num_embeddings_ + - get_embedding_dim() == embedding_dim_ + - get_learning_rate_multiplier() returns the learning rate multiplier for this layer. + - get_scale_by_freq() returns whether to scale gradients by token frequency. + */ + public: + embeddings_() = default; + + unsigned long get_num_embeddings() const; + unsigned long get_embedding_dim() const; + double get_learning_rate_multiplier() const; + bool get_scale_by_freq() const; + + void set_num_embeddings(unsigned long num); + void set_embedding_dim(unsigned long dim); + void set_learning_rate_multiplier(double val); + void set_scale_by_freq(bool val); + + template void setup(const SUBNET& sub); + template void forward(const SUBNET& sub, resizable_tensor& output); + template void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); + + const tensor& get_layer_params() const; + tensor& get_layer_params(); + const tensor& get_embeddings() const; + tensor& get_embeddings(); + + friend void serialize(const embeddings_& item, std::ostream& out); + friend void deserialize(embeddings_& item, std::istream& in); + friend std::ostream& operator<<(std::ostream& out, const embeddings_& item); + friend void to_xml(const embeddings_& item, std::ostream& out); + + /*! + These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface. + !*/ + }; + + template < + unsigned long num_embeddings, + unsigned long embedding_dim, + typename SUBNET + > + using embeddings = add_layer, SUBNET>; + // ---------------------------------------------------------------------------------------- struct neg_infinity_tag {}; @@ -3865,7 +4085,7 @@ namespace dlib using tril_mask = add_layer, SUBNET>; template - using tril_diag = add_layer, SUBNET>; + using tril_diag = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h index d383967c36..726f3b200e 100644 --- a/dlib/dnn/visitors.h +++ b/dlib/dnn/visitors.h @@ -1029,6 +1029,24 @@ namespace dlib update(i); } + template + void operator()(size_t i, const add_layer, U, E>& l) + { + start_node(i, "embeddings"); + out << " | {num_embeddings|{" << l.layer_details().get_num_embeddings() << "}}"; + out << " | {embedding_dim|{" << l.layer_details().get_embedding_dim() << "}}"; + end_node(); + update(i); + } + + template + void operator()(size_t i, const add_layer&) + { + start_node(i, "positional_encodings"); + end_node(); + update(i); + } + template void operator()(size_t i, const add_layer, U, E>&) { @@ -1043,7 +1061,7 @@ namespace dlib out << "}}"; end_node(); update(i); - } + } template void operator()(size_t i, const add_layer&) diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index 9ca69ce286..6d3c6c94b4 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -779,6 +779,89 @@ namespace #endif } +// ---------------------------------------------------------------------------------------- + +void test_positional_encodings() +{ + print_spinner(); + using net_type = tag1>>>; + net_type net; + + const unsigned long sequence_dim = 4; + const unsigned long embedding_dim = 6; + const unsigned long n_samples = 1, n_channels = 1; + matrix input_data(sequence_dim, embedding_dim); + input_data = 0.0f; + + resizable_tensor input_tensor(n_samples, n_channels, sequence_dim, embedding_dim); + std::vector> x(n_samples); + x[0] = input_data; + net.to_tensor(&x[0], &x[0] + n_samples, input_tensor); + net.forward(input_tensor); + + matrix expected_output(sequence_dim, embedding_dim); + const float n = 10000.0f; + for (long r = 0; r < sequence_dim; ++r) { + for (long c = 0; c < embedding_dim; ++c) { + float theta = static_cast(r) / std::pow(n, static_cast(c) / embedding_dim); + expected_output(r, c) = (c % 2 == 0) ? std::sin(theta) : std::cos(theta); + } + } + + auto& net_output = layer(net).get_output(); + DLIB_TEST(max(abs(mat(net_output) - expected_output)) < 1e-5); +} + +// ---------------------------------------------------------------------------------------- + +void test_embeddings() +{ + print_spinner(); + const size_t num_sequences = 100, sequence_length = 7, num_classes = 3, num_tokens = 50, embedding_length = 5; + using net_type = loss_multiclass_log>>>>>>>>; + net_type net; + dnn_trainer trainer(net, sgd(0, 0.9)); + trainer.set_learning_rate(1e-1); + trainer.set_min_learning_rate(1e-4); + trainer.set_mini_batch_size(16); + trainer.set_max_num_epochs(500); + + dlib::rand rnd(std::rand()); + auto generate_sequences = [&](size_t num_sequences, size_t sequence_length, size_t num_tokens) { + std::vector> sequences; + for (size_t i = 0; i < num_sequences; ++i) + { + matrix seq(sequence_length, 1); + for (size_t j = 0; j < sequence_length; ++j) + seq(j, 0) = rnd.get_random_32bit_number() % num_tokens; + sequences.push_back(seq); + } + return sequences; + }; + + auto generate_labels = [&](size_t num_sequences, size_t num_classes) { + std::vector labels; + for (size_t i = 0; i < num_sequences; ++i) + labels.push_back(rnd.get_random_32bit_number() % num_classes); + return labels; + }; + + auto sequences = generate_sequences(num_sequences, sequence_length, num_tokens); + auto labels = generate_labels(num_sequences, num_classes); + + trainer.train(sequences, labels); + std::vector predicted_labels = net(sequences); + size_t num_correct = 0; + for (size_t i = 0; i < labels.size(); ++i) + if (predicted_labels[i] == labels[i]) ++num_correct; + + double acc = static_cast(num_correct) / labels.size(); + DLIB_TEST(acc > 0.9); +} + // ---------------------------------------------------------------------------------------- void test_basic_tensor_ops() @@ -2322,6 +2405,18 @@ namespace transpose_ l; auto res = test_layer(l); DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + positional_encodings_ l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); + } + { + print_spinner(); + embeddings_<7, 12> l; + auto res = test_layer(l); + DLIB_TEST_MSG(res, res); } } @@ -4574,6 +4669,8 @@ namespace test_layer_normalize(); test_rms_normalize(); test_transpose(); + test_positional_encodings(); + test_embeddings(); test_tril(); test_basic_tensor_ops(); test_layers();