-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
08d4100
commit d956d3d
Showing
5 changed files
with
576 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,354 @@ | ||
#pragma once | ||
|
||
#include "tamm/tensor.hpp" | ||
|
||
namespace tamm { | ||
|
||
// template<typename T> | ||
// class LabeledTensor; | ||
|
||
/// @brief Creates a local copy of the distributed tensor | ||
/// @tparam T Data type for the tensor being made local | ||
template<typename T> | ||
class LocalTensor: public Tensor<T> { // move to another hpp | ||
public: | ||
LocalTensor() = default; | ||
LocalTensor(LocalTensor&&) = default; | ||
LocalTensor(const LocalTensor&) = default; | ||
LocalTensor& operator=(LocalTensor&&) = default; | ||
LocalTensor& operator=(const LocalTensor&) = default; | ||
~LocalTensor() = default; | ||
|
||
// LocalTensor(Tensor<T> dist_tensor): dist_tensor_(dist_tensor) { construct_local_tensor(); } | ||
|
||
LocalTensor(std::initializer_list<TiledIndexSpace> tiss): | ||
Tensor<T>(construct_local_tis_vec(TiledIndexSpaceVec(tiss))) {} | ||
|
||
LocalTensor(std::vector<TiledIndexSpace> tiss): Tensor<T>(construct_local_tis_vec(tiss)) {} | ||
|
||
LocalTensor(std::initializer_list<TiledIndexLabel> tis_labels): | ||
Tensor<T>(construct_local_tis_vec(IndexLabelVec(tis_labels))) {} | ||
|
||
LocalTensor(std::initializer_list<size_t> dim_sizes): | ||
Tensor<T>(construct_tis_vec(std::vector<size_t>(dim_sizes))) {} | ||
|
||
LocalTensor(std::vector<size_t> dim_sizes): Tensor<T>(construct_tis_vec(dim_sizes)) {} | ||
|
||
/// @brief | ||
/// @tparam ...Args | ||
/// @param ...rest | ||
/// @return | ||
template<class... Args> | ||
LabeledTensor<T> operator()(Args&&... rest) const { | ||
return LabeledTensor<T>{*this, std::forward<Args>(rest)...}; | ||
} | ||
|
||
// void write_back_to_dist() { fill_distributed_tensor(); } | ||
|
||
/// @brief | ||
/// @param val | ||
void init(T val) { | ||
EXPECTS_STR(this->is_allocated(), "LocalTensor has to be allocated"); | ||
|
||
auto ec = this->execution_context(); | ||
Scheduler{*ec}((*this)() = val).execute(); | ||
} | ||
|
||
/// @brief | ||
/// @param indices | ||
/// @param val | ||
void set(std::vector<size_t> indices, T val) { | ||
EXPECTS_STR(this->is_allocated(), "LocalTensor has to be allocated"); | ||
EXPECTS_STR(indices.size() == this->num_modes(), | ||
"Number of indices must match the number of dimensions"); | ||
size_t linearIndex = compute_linear_index(indices); | ||
|
||
this->access_local_buf()[linearIndex] = val; | ||
} | ||
|
||
/// @brief | ||
/// @param indices | ||
/// @return | ||
T get(const std::vector<size_t>& indices) const { | ||
EXPECTS_STR(indices.size() == this->num_modes(), | ||
"Number of indices must match the number of dimensions"); | ||
size_t linearIndex = compute_linear_index(indices); | ||
|
||
return this->access_local_buf()[linearIndex]; | ||
} | ||
|
||
/// @brief | ||
/// @tparam ...Args | ||
/// @param ...args | ||
/// @return | ||
template<typename... Args> | ||
T get(Args... args) { | ||
std::vector<size_t> indices; | ||
unpack(indices, args...); | ||
EXPECTS_STR(indices.size() == this->num_modes(), | ||
"Number of indices must match the number of dimensions"); | ||
size_t linearIndex = compute_linear_index(indices); | ||
|
||
return this->access_local_buf()[linearIndex]; | ||
} | ||
|
||
/// @brief | ||
/// @param new_sizes | ||
template<typename... Args> | ||
void resize(Args... args) { | ||
std::vector<size_t> new_sizes; | ||
unpack(new_sizes, args...); | ||
EXPECTS_STR(new_sizes.size() == (*this).num_modes(), | ||
"Number of new sizes must match the number of dimensions"); | ||
resize(std::vector<size_t>{new_sizes}); | ||
} | ||
|
||
/// @brief | ||
/// @param new_sizes | ||
void resize(const std::vector<size_t>& new_sizes) { | ||
EXPECTS_STR((*this).is_allocated(), "LocalTensor has to be allocated!"); | ||
auto num_dims = (*this).num_modes(); | ||
EXPECTS_STR(num_dims == new_sizes.size(), | ||
"Number of new sizes must match the number of dimensions."); | ||
|
||
for(size_t i = 0; i < new_sizes.size(); i++) { | ||
EXPECTS_STR(new_sizes[i] != 0, "New size should be larger than 0."); | ||
} | ||
|
||
LocalTensor<T> resizedTensor; | ||
|
||
auto dimensions = (*this).dim_sizes(); | ||
|
||
if(dimensions == new_sizes) return; | ||
|
||
if(isWithinOldDimensions(new_sizes)) { | ||
std::vector<size_t> offsets(new_sizes.size(), 0); | ||
resizedTensor = (*this).block(offsets, new_sizes); | ||
} | ||
else { | ||
resizedTensor = LocalTensor<T>{new_sizes}; | ||
resizedTensor.allocate((*this).execution_context()); | ||
(*this).copy_to_bigger(resizedTensor); | ||
} | ||
|
||
auto old_tensor = (*this); | ||
(*this) = resizedTensor; | ||
old_tensor.deallocate(); | ||
} | ||
|
||
// /// @brief | ||
// /// @param sbuf | ||
// /// @param block_dims | ||
// /// @param block_offset | ||
// /// @param copy_to_local | ||
// void patch_copy_local(std::vector<T>& sbuf, const std::vector<size_t>& block_dims, | ||
// const std::vector<size_t>& block_offset, bool copy_to_local) { | ||
// auto num_dims = local_tensor_.num_modes(); | ||
// // Compute the total number of elements to copy | ||
// size_t total_elements = 1; | ||
// for(size_t dim: block_dims) { total_elements *= dim; } | ||
|
||
// // Initialize indices to the starting offset | ||
// std::vector<size_t> indices(block_offset); | ||
|
||
// for(size_t c = 0; c < total_elements; ++c) { | ||
// // Access the tensor element at the current indices | ||
// if(copy_to_local) (*this)(indices) = sbuf[c]; | ||
// else sbuf[c] = (*this)(indices); | ||
|
||
// // Increment indices | ||
// for(int dim = num_dims - 1; dim >= 0; --dim) { | ||
// if(++indices[dim] < block_offset[dim] + block_dims[dim]) { break; } | ||
// indices[dim] = block_offset[dim]; | ||
// } | ||
// } | ||
// } | ||
|
||
/// @brief | ||
/// @param bigger_tensor | ||
void copy_to_bigger(LocalTensor& bigger_tensor) const { | ||
auto smallerDims = (*this).dim_sizes(); | ||
|
||
// Helper lambda to iterate over all indices of a tensor | ||
auto iterateIndices = [](const std::vector<size_t>& dims) { | ||
std::vector<size_t> indices(dims.size(), 0); | ||
bool done = false; | ||
return [=]() mutable { | ||
if(done) return std::optional<std::vector<size_t>>{}; | ||
auto current = indices; | ||
for(int i = indices.size() - 1; i >= 0; --i) { | ||
if(++indices[i] < dims[i]) break; | ||
if(i == 0) { | ||
done = true; | ||
break; | ||
} | ||
indices[i] = 0; | ||
} | ||
return std::optional<std::vector<size_t>>{current}; | ||
}; | ||
}; | ||
|
||
auto smallerIt = iterateIndices(smallerDims); | ||
while(auto indices = smallerIt()) { | ||
auto bigIndices = *indices; | ||
bigger_tensor.set(bigIndices, (*this).get(*indices)); | ||
} | ||
} | ||
|
||
/// @brief | ||
/// @param start_offsets | ||
/// @param span_sizes | ||
/// @return | ||
LocalTensor<T> block(const std::vector<size_t>& start_offsets, | ||
const std::vector<size_t>& span_sizes) const { | ||
EXPECTS_STR((*this).is_allocated(), "LocalTensor has to be allocated!"); | ||
auto num_dims = (*this).num_modes(); | ||
EXPECTS_STR(num_dims == start_offsets.size(), | ||
"Number of start offsets should match the number of dimensions."); | ||
EXPECTS_STR(num_dims == span_sizes.size(), | ||
"Number of span sizes should match the number of dimensions."); | ||
|
||
// this has to be allocated | ||
// offsets should be within limits | ||
// offset + span size should be within limit | ||
|
||
// Create a local tensor for the block | ||
LocalTensor<T> blockTensor{span_sizes}; | ||
blockTensor.allocate(this->execution_context()); | ||
|
||
// Iterate over all dimensions to copy the block | ||
std::vector<size_t> indices(num_dims, 0); | ||
std::vector<size_t> source_indices = start_offsets; | ||
|
||
bool done = false; | ||
while(!done) { | ||
// Copy the element | ||
blockTensor.set(indices, (*this).get(source_indices)); | ||
|
||
// Update indices | ||
done = true; | ||
for(size_t i = 0; i < num_dims; ++i) { | ||
if(++indices[i] < span_sizes[i]) { | ||
++source_indices[i]; | ||
done = false; | ||
break; | ||
} | ||
else { | ||
indices[i] = 0; | ||
source_indices[i] = start_offsets[i]; | ||
} | ||
} | ||
} | ||
|
||
return blockTensor; | ||
} | ||
|
||
/// @brief | ||
/// @param x_offset | ||
/// @param y_offset | ||
/// @param x_span | ||
/// @param y_span | ||
/// @return | ||
LocalTensor<T> block(size_t x_offset, size_t y_offset, size_t x_span, size_t y_span) const { | ||
auto num_dims = (*this).num_modes(); | ||
EXPECTS_STR(num_dims == 2, "This block method only works for 2-D tensors!"); | ||
|
||
return block({x_offset, y_offset}, {x_span, y_span}); | ||
} | ||
|
||
/// @brief | ||
/// @return | ||
std::vector<size_t> dim_sizes() const { | ||
std::vector<size_t> dimensions; | ||
|
||
for(const auto& tis: (*this).tiled_index_spaces()) { | ||
dimensions.push_back(tis.max_num_indices()); | ||
} | ||
|
||
return dimensions; | ||
} | ||
|
||
private: | ||
/// @brief | ||
/// @param tiss | ||
/// @return | ||
TiledIndexSpaceVec construct_local_tis_vec(std::vector<TiledIndexSpace> tiss) { | ||
std::vector<size_t> dim_sizes; | ||
|
||
for(const auto& tis: tiss) { dim_sizes.push_back(tis.max_num_indices()); } | ||
|
||
return construct_tis_vec(dim_sizes); | ||
} | ||
|
||
/// @brief | ||
/// @param tis_labels | ||
/// @return | ||
TiledIndexSpaceVec construct_local_tis_vec(std::vector<TiledIndexLabel> tis_labels) { | ||
std::vector<size_t> dim_sizes; | ||
|
||
for(const auto& tis_label: tis_labels) { | ||
dim_sizes.push_back(tis_label.tiled_index_space().max_num_indices()); | ||
} | ||
|
||
return construct_tis_vec(dim_sizes); | ||
} | ||
|
||
/// @brief | ||
/// @param dim_sizes | ||
/// @return | ||
TiledIndexSpaceVec construct_tis_vec(std::vector<size_t> dim_sizes) { | ||
TiledIndexSpaceVec local_tis_vec; | ||
for(const auto& dim_size: dim_sizes) { | ||
local_tis_vec.push_back(TiledIndexSpace{IndexSpace{range(dim_size)}, dim_size}); | ||
} | ||
|
||
return local_tis_vec; | ||
} | ||
|
||
/// @brief Method for constructing the linearized index for a given location on the local tensor | ||
/// @param indices The index for the corresponding location wanted to be accessed | ||
/// @return The linear position to the local memory manager | ||
size_t compute_linear_index(const std::vector<size_t>& indices) const { | ||
auto num_modes = this->num_modes(); | ||
std::vector<size_t> dims = (*this).dim_sizes(); | ||
size_t index = 0; | ||
size_t stride = 1; | ||
|
||
for(size_t i = 0; i < num_modes; ++i) { | ||
index += indices[num_modes - 1 - i] * stride; | ||
stride *= dims[num_modes - 1 - i]; | ||
} | ||
|
||
return index; | ||
} | ||
|
||
/// @brief | ||
/// @param indices | ||
/// @return | ||
bool isWithinOldDimensions(const std::vector<size_t>& indices) const { | ||
std::vector<size_t> dimensions = (*this).dim_sizes(); | ||
|
||
for(size_t i = 0; i < indices.size(); ++i) { | ||
if(indices[i] > dimensions[i]) { return false; } | ||
} | ||
return true; | ||
} | ||
|
||
/// @brief Helper method that will unpack the variadic template for operator() | ||
/// @param indices A reference to the vector of indices | ||
/// @param index The last index that is provided to the operator() | ||
void unpack(std::vector<size_t>& indices, size_t index) { indices.push_back(index); } | ||
|
||
/// @brief Helper method that will unpack the variadic template for operator() | ||
/// @tparam ...Args The variadic template from the arguments to the operator() | ||
/// @param indices A reference to the vector of indices | ||
/// @param next Unpacked index for the operator() | ||
/// @param ...rest The rest of the variadic template that will be unpacked in the recursive calls | ||
template<typename... Args> | ||
void unpack(std::vector<size_t>& indices, size_t next, Args... rest) { | ||
indices.push_back(next); | ||
unpack(indices, rest...); | ||
} | ||
}; | ||
|
||
} // namespace tamm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.