Skip to content

Commit

Permalink
[utils] add a set routine
Browse files Browse the repository at this point in the history
  • Loading branch information
ajaypanyala committed Jul 31, 2024
1 parent 809a55f commit d8f0fac
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/tamm/tamm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,8 +1030,20 @@ Tensor<TensorType> sqrt(Tensor<TensorType> tensor) {
return sqrt(tensor(), false);
}

// Normally, a setop is used instead of this routine.
template<typename TensorType>
void random_ip(LabeledTensor<TensorType> ltensor, bool is_lt = true) {
void set_val_ip(LabeledTensor<TensorType> ltensor, TensorType alpha) {
std::function<TensorType(TensorType)> func = [&](TensorType a) { return alpha; };
apply_ewise_ip(ltensor, func);
}

template<typename TensorType>
void set_val_ip(Tensor<TensorType> tensor, TensorType alpha) {
set_val_ip(tensor(), alpha);
}

template<typename TensorType>
void random_ip(LabeledTensor<TensorType> ltensor) {
std::mt19937 generator(get_ec(ltensor).pg().rank().value());
std::uniform_real_distribution<double> tensor_rand_dist(0.0, 1.0);

Expand All @@ -1051,7 +1063,7 @@ void random_ip(LabeledTensor<TensorType> ltensor, bool is_lt = true) {

template<typename TensorType>
void random_ip(Tensor<TensorType> tensor) {
random_ip(tensor(), false);
random_ip(tensor());
}

template<typename TensorType>
Expand Down

0 comments on commit d8f0fac

Please sign in to comment.