Skip to content

Commit

Permalink
[WIP] changes is non zero logic to more general use of reference indi…
Browse files Browse the repository at this point in the history
…ces in TiledIndexSpaces
  • Loading branch information
erdalmutlu committed Nov 12, 2024
1 parent c586dd7 commit d8accf3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
56 changes: 30 additions & 26 deletions src/tamm/block_sparse_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,50 +79,54 @@ class BlockSparseTensor: public Tensor<T> { // move to another hpp
/// @return
NonZeroCheck construct_is_non_zero_check(const TiledIndexSpaceVec& tis_vec,
BlockSparseInfo sparse_info) const {
auto is_within_tis = [](size_t block_offset, size_t tis_lo, size_t tis_hi) -> bool {
return (block_offset >= tis_lo && block_offset < tis_hi);
};

auto is_in_allowed_blocks = [tis_vec, allowed_tis_vecs = sparse_info.allowed_tis_vecs,
is_within_tis](const IndexVector& blockid) -> bool {
std::vector<size_t> blockid_offsets;
auto is_in_allowed_blocks = [tis_vec, allowed_tis_vecs = sparse_info.allowed_tis_vecs](
const IndexVector& blockid) -> bool {
std::vector<size_t> ref_indices;
for(size_t i = 0; i < blockid.size(); i++) {
blockid_offsets.push_back(tis_vec[i].tile_offset(blockid[i]));
ref_indices.push_back(tis_vec[i].ref_indices()[blockid[i]]);
}

for(size_t i = 0; i < allowed_tis_vecs.size(); i++) {
auto curr_tis_vec = allowed_tis_vecs[i];
for(size_t j = 0; j < blockid_offsets.size(); j++) {
if(!is_within_tis(blockid_offsets[j], curr_tis_vec[j].tile_offsets().front(),
curr_tis_vec[j].tile_offsets().back())) {
return false;
auto curr_tis_vec = allowed_tis_vecs[i];
bool is_disallowed = false;
for(size_t j = 0; j < ref_indices.size(); j++) {
auto allowed_ref_indices = curr_tis_vec[j].ref_indices();

if(std::find(allowed_ref_indices.begin(), allowed_ref_indices.end(), ref_indices[j]) ==
allowed_ref_indices.end()) {
is_disallowed = true;
break;
}
}
if(!is_disallowed) { return true; }
}

return true;
return false;
};

auto is_in_disallowed_blocks = [tis_vec, disallowed_tis_vecs = sparse_info.disallowed_tis_vecs,
is_within_tis](const IndexVector& blockid) -> bool {
std::vector<size_t> blockid_offsets;
auto is_in_disallowed_blocks = [tis_vec, disallowed_tis_vecs = sparse_info.disallowed_tis_vecs](
const IndexVector& blockid) -> bool {
std::vector<size_t> ref_indices;
for(size_t i = 0; i < blockid.size(); i++) {
blockid_offsets.push_back(tis_vec[i].tile_offset(blockid[i]));
ref_indices.push_back(tis_vec[i].ref_indices()[blockid[i]]);
}

for(size_t i = 0; i < disallowed_tis_vecs.size(); i++) {
std::cerr << __FUNCTION__ << " " << __LINE__ << std::endl;

auto curr_tis_vec = disallowed_tis_vecs[i];
for(size_t j = 0; j < blockid_offsets.size(); j++) {
if(!is_within_tis(blockid_offsets[j], curr_tis_vec[j].tile_offsets().front(),
curr_tis_vec[j].tile_offsets().back())) {
return false;
auto curr_tis_vec = disallowed_tis_vecs[i];
bool is_disallowed = false;
for(size_t j = 0; j < ref_indices.size(); j++) {
auto allowed_ref_indices = curr_tis_vec[j].ref_indices();

if(std::find(allowed_ref_indices.begin(), allowed_ref_indices.end(), ref_indices[j]) ==
allowed_ref_indices.end()) {
is_disallowed = true;
break;
}
}
if(!is_disallowed) { return true; }
}

return true;
return false;
};

auto non_zero_check =
Expand Down
28 changes: 11 additions & 17 deletions tests/tamm/Test_Tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ TEST_CASE("Block Sparse Tensor Construction") {

TiledIndexSpace MO{MO_IS, 5};

auto [i, j, k, l] = MO("occ").labels<4>();
auto [a, b, c, d] = MO("virt").labels<4>();

BlockSparseInfo sparse_info{
{MO, MO, MO, MO}, // Tensor dims
{"ijab", "iajb", "ijka", "ijkl", "iabc", "abcd"}, // Allowed blocks
Expand All @@ -138,33 +141,24 @@ TEST_CASE("Block Sparse Tensor Construction") {
{'a', "virt"},
{'b', "virt"},
{'c', "virt"},
{'d', "virt"}}
// , // Char to named sub-space string
// {"abij", "aibj"} // Disallowed blocks
{'d', "virt"}}, // Char to named sub-space string
{"abij",
"aibj"} // Disallowed blocks - note that allowed blocks will precedence over disallowed blocks
};

failed = false;
try {
std::cerr << __FUNCTION__ << " " << __LINE__ << std::endl;

BlockSparseTensor<T> tensor{{MO, MO, MO, MO}, sparse_info};
// std::cerr << __FUNCTION__ << " " << __LINE__ << std::endl;

// for(const auto& blockid: tensor.loop_nest()) {
// std::cout << "blockid: [ ";
// for(size_t i = 0; i < blockid.size(); i++) { std::cout << blockid[i] << " "; }
// std::cout << "] -> " << std::endl;

// if(tensor.base_ptr()->is_non_zero(blockid)) { std::cout << "is non zero" << std::endl; }
// else { std::cout << "is zero" << std::endl; }
// }
// std::cerr << __FUNCTION__ << " " << __LINE__ << std::endl;

tensor.allocate(ec);
std::cerr << __FUNCTION__ << " " << __LINE__ << std::endl;

Scheduler{*ec}(tensor() = 42).execute();
check_value(tensor, (T) 42);

Scheduler{*ec}(tensor(i, j, a, b) = 1.0)(tensor(i, a, j, b) = 2.0)(tensor(i, j, k, a) = 3.0)(
tensor(i, j, k, l) = 4.0)(tensor(i, a, b, c) = 5.0)(tensor(a, b, c, d) = 6.0)
.execute();

print_tensor_all(tensor);

tensor.deallocate();
Expand Down

0 comments on commit d8accf3

Please sign in to comment.