Skip to content

Commit

Permalink
Get cost value for each data segment in the trimmed set
Browse files Browse the repository at this point in the history
  • Loading branch information
doccstat committed Mar 30, 2024
1 parent af09cc1 commit fa40d37
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 55 deletions.
125 changes: 70 additions & 55 deletions src/fastcpd_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,73 @@ void Fastcpd::update_theta_sum(ucolvec pruned_left) {
theta_sum = theta_sum.cols(pruned_left);
}

// Evaluate the cost of the data segment made of rows `tau` to `t - 1` of
// `data`, where `tau` is entry `i - 1` of `r_t_set`.
//
// @param r_t_set Vector whose entry `i - 1` gives the first row `tau` of the
//   segment.
// @param i One-based position inside `r_t_set`.
// @param start Matrix of starting values. It is received by value, so the
//   write to column `tau` below only touches this local copy; the column is
//   handed back through the returned list.
// @param t Current time step; the segment ends at row `t - 1`.
// @param lambda Penalty forwarded to the cost routines. It is recomputed
//   locally when `family` is "lasso".
//
// @return A list with `cval` (the segment cost, left at 0 when a built-in
//   family segment is too short in the fastcpd branch) and `start` (column
//   `tau` of the local `start`).
List Fastcpd::get_cval_for_r_t_set(
  const ucolvec r_t_set,
  const unsigned int i,
  mat start,
  const int t,
  double lambda
) {
  DEBUG_RCOUT(i);
  const unsigned int column = i - 1;
  int tau = r_t_set(column);
  const int segment_length = t - tau;

  if (family == "lasso") {
    // Averaging `err_sd` is only valid when the error sd does not change.
    lambda = mean(err_sd) * sqrt(2 * std::log(p) / segment_length);
  }

  mat data_segment = data.rows(tau, t - 1);
  DEBUG_RCOUT(data_segment);

  double cval = 0;
  if (t > vanilla_percentage * n) {
    // fastcpd: refresh the running parameters first, then evaluate the cost
    // at the averaged theta.
    update_cost_parameters(t, tau, i, k.get(), lambda, line_search);
    colvec theta = theta_sum.col(column) / segment_length;
    DEBUG_RCOUT(theta);
    if (contain(FASTCPD_FAMILIES, family)) {
      // Built-in families need at least `p` rows (3 rows for lasso);
      // otherwise `cval` keeps its initial value of 0.
      const bool long_enough =
        (family == "lasso") ? (segment_length >= 3) : (segment_length >= p);
      if (long_enough) {
        cval = cost_function_wrapper(
          data_segment, wrap(theta), lambda, false, R_NilValue
        ).value;
      }
    } else {
      // User-supplied cost function.
      Function user_cost = cost.get();
      cval = as<double>(user_cost(data_segment, theta));
    }
  } else {
    // vanilla PELT
    CostResult fit;
    if (!contain(FASTCPD_FAMILIES, family)) {
      fit = get_optimized_cost(data_segment);
    } else if (warm_start && segment_length >= 10 * p) {
      fit = cost_function_wrapper(
        data_segment, R_NilValue, lambda, false,
        wrap(segment_theta_hat[segment_indices(t - 1) - 1])
        // `wrap(start.col(tau))` is an alternative warm start.
      );
      start.col(tau) = colvec(fit.par);
    } else {
      fit = cost_function_wrapper(
        data_segment, R_NilValue, lambda, false, R_NilValue
      );
    }
    cval = fit.value;

    // When `vanilla_percentage` is below 1, later `fastcpd` steps need the
    // thetas, so record them here.
    if (vanilla_percentage < 1 && t <= vanilla_percentage * n) {
      update_theta_hat(column, fit.par);
      update_theta_sum(column, fit.par);
    }
  }

  return List::create(Named("cval") = cval, Named("start") = start.col(tau));
}

List Fastcpd::process_cp_set(const colvec raw_cp_set, const double lambda) {
colvec cp_set = trim_cp_set(raw_cp_set);

Expand Down Expand Up @@ -469,61 +536,9 @@ List Fastcpd::run() {

// For tau in R_t \ {t-1}.
for (unsigned int i = 1; i < r_t_count; i++) {
DEBUG_RCOUT(i);
int tau = r_t_set(i - 1);
if (family == "lasso") {
// Mean of `err_sd` only works if error sd is unchanged.
lambda = mean(err_sd) * sqrt(2 * std::log(p) / (t - tau));
}
mat data_segment = data.rows(tau, t - 1);
DEBUG_RCOUT(data_segment);
if (t > vanilla_percentage * n) {
// fastcpd
update_cost_parameters(t, tau, i, k.get(), lambda, line_search);
colvec theta = theta_sum.col(i - 1) / (t - tau);
DEBUG_RCOUT(theta);
if (!contain(FASTCPD_FAMILIES, family)) {
Function cost_non_null = cost.get();
SEXP cost_result = cost_non_null(data_segment, theta);
cval(i - 1) = as<double>(cost_result);
} else if (
(family != "lasso" && t - tau >= p) ||
(family == "lasso" && t - tau >= 3)
) {
cval(i - 1) = cost_function_wrapper(
data_segment, wrap(theta), lambda, false, R_NilValue
).value;
} else {
// t - tau < p or for lasso t - tau < 3
}
} else {
// vanilla PELT
CostResult cost_result;
if (!contain(FASTCPD_FAMILIES, family)) {
cost_result = get_optimized_cost(data_segment);
} else {
if (warm_start && t - tau >= 10 * p) {
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false,
wrap(segment_theta_hat[segment_indices(t - 1) - 1])
// Or use `wrap(start.col(tau))` for warm start.
);
start.col(tau) = colvec(cost_result.par);
} else {
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false, R_NilValue
);
}
}
cval(i - 1) = cost_result.value;

// If `vanilla_percentage` is not 1, then we need to keep track of
// thetas for later `fastcpd` steps.
if (vanilla_percentage < 1 && t <= vanilla_percentage * n) {
update_theta_hat(i - 1, cost_result.par);
update_theta_sum(i - 1, cost_result.par);
}
}
List r_t_set_result = get_cval_for_r_t_set(r_t_set, i, start, t, lambda);
cval(i - 1) = as<double>(r_t_set_result["cval"]);
start.col(r_t_set(i - 1)) = as<colvec>(r_t_set_result["start"]);
}

DEBUG_RCOUT(cval);
Expand Down
8 changes: 8 additions & 0 deletions src/fastcpd_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ class Fastcpd {

double get_cost_adjustment_value(const unsigned nrows);

// Compute the cost value of the data segment that starts at row
// \code{r_t_set(i - 1)} and ends at row \code{t - 1}.
//
// @param r_t_set Vector whose entry \code{i - 1} is the first row of the
//   segment.
// @param i One-based position inside \code{r_t_set}.
// @param start Matrix of starting values, passed by value; the caller's
//   matrix is not modified and the relevant column is returned instead.
// @param t Current time step.
// @param lambda Penalty value; recomputed internally when the family is
//   "lasso".
//
// @return A list with elements \code{cval} (cost of the segment) and
//   \code{start} (one column of \code{start}).
List get_cval_for_r_t_set(
const ucolvec r_t_set,
const unsigned int i,
mat start,
const int t,
double lambda
);

// Update \code{theta_hat}, \code{theta_sum}, and \code{hessian}.
//
// @param data_segment A data frame containing a segment of the data.
Expand Down

0 comments on commit fa40d37

Please sign in to comment.