What is the current behavior?
The validation dataloader built in pretraining_utils.py reuses the sampler generated from X_train, which causes errors when X_train and the X_valid passed in eval_set do not have the same number of rows.
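To illustrate the suspected mechanism, here is a minimal NumPy-only sketch. It does not call pytorch_tabnet; the sizes are scaled down and `first_out_of_range` is a hypothetical helper used only for illustration. It assumes the sampler draws indices in the range of the training weights and that those indices are then used to index the smaller validation array.

```python
import numpy as np

rng = np.random.default_rng(0)

# Scaled-down stand-ins for the 100000 / 50000 rows used in the report.
n_train, n_valid = 1000, 500
X_valid = rng.random((n_valid, 10))

# One weight per *training* row (uniform here, for simplicity).
train_weights = np.full(n_train, 1.0 / n_train)

# A weighted sampler draws indices in range(len(weights)), i.e. 0..n_train-1.
sampled_idx = rng.choice(n_train, size=n_train, replace=True, p=train_weights)


def first_out_of_range(indices, n_rows):
    """Return the first index that does not exist in an n_rows array, or None."""
    bad = indices[indices >= n_rows]
    return int(bad[0]) if bad.size else None


bad_index = first_out_of_range(sampled_idx, n_valid)

# Indexing the smaller validation array with training indices fails.
try:
    X_valid[sampled_idx]
    raised = False
except IndexError as err:
    raised = True
    print(err)
```

Any sampled index at or above the validation size triggers the same kind of IndexError as in the report.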
If the current behavior is a bug, please provide the steps to reproduce.
Here's a script to reproduce the issue, where weights is passed as an ndarray with one entry per training row.
from pytorch_tabnet.pretraining import TabNetPretrainer
import numpy as np
import torch
# Set the random seed for reproducibility
np.random.seed(42)
# Generate random features
num_train_samples = 100000
num_valid_samples = 50000
num_features = 10
X_train = np.random.rand(num_train_samples, num_features)
X_valid = np.random.rand(num_valid_samples, num_features)
# Generate random binary labels
y_train = np.random.randint(2, size=num_train_samples)
y_valid = np.random.randint(2, size=num_valid_samples)
# Weight each training sample by the inverse frequency of its class
num_positive_samples = np.sum(y_train)
num_negative_samples = len(y_train) - num_positive_samples
class_weights = np.zeros(len(y_train))
class_weights[y_train == 0] = 1 / num_negative_samples
class_weights[y_train == 1] = 1 / num_positive_samples
# TabNetPretrainer
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',  # alternative: "sparsemax"
    device_name='cpu'
)
# X_valid has fewer rows than X_train; weights has one entry per training row
unsupervised_model.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.5,
    weights=class_weights
)
Expected behavior
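The script above currently raises IndexError: index 94028 is out of bounds for axis 0 with size 50000, which suggests that the sampler built from weights is also applied to X_valid. That should not be necessary: the weights describe X_train only, so the script is expected to run without error, with X_valid evaluated as-is.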
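As a rough sketch of the expected behavior (not the library's implementation), the training rows could be drawn with the weights while the validation rows are simply read in order. The helper name `build_index_orders` and its parameters are hypothetical, and the sizes are scaled down from the report.

```python
import numpy as np


def build_index_orders(n_train, n_valid, weights, seed=0):
    """Hypothetical helper: weighted draw for training rows,
    plain sequential order for validation rows."""
    rng = np.random.default_rng(seed)
    p = np.asarray(weights, dtype=float)
    p = p / p.sum()  # normalise so the weights form a probability vector
    train_idx = rng.choice(n_train, size=n_train, replace=True, p=p)
    valid_idx = np.arange(n_valid)  # no sampler: each validation row once, in order
    return train_idx, valid_idx


n_train, n_valid = 1000, 500
y_train = np.random.default_rng(42).integers(0, 2, size=n_train)

# Inverse class-frequency weights, as in the reproduction script.
n_pos = y_train.sum()
n_neg = n_train - n_pos
weights = np.where(y_train == 1, 1.0 / n_pos, 1.0 / n_neg)

train_idx, valid_idx = build_index_orders(n_train, n_valid, weights)

# Validation indexing stays in bounds regardless of the training set size.
X_valid = np.zeros((n_valid, 10))
batch = X_valid[valid_idx]
```

The design point is only that the validation order is derived from the validation set's own length, so a size mismatch with X_train cannot produce out-of-range indices.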
Other relevant information:
poetry version:
python version: 3.8
Operating System: linux, macos
Additional tools:
Additional context