Use TabNet with standard PyTorch model API for data too large to fit in memory #378
Comments
Hello @jlehrer1, the `M_loss` part is used to add extra sparsity, as you can see here: `tabnet/pytorch_tabnet/abstract_model.py`, line 499 in 40107a8.
So you can either ignore `M_loss` (this is the same as using `lambda_sparse=0`), or add it to your loss in your custom pipeline to enforce the sparsity constraint (as in the original paper).
Related to #143.
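The comment above can be sketched as a custom training step that mirrors the wrapper's `loss - lambda_sparse * M_loss` combination. `TinyTabNetLike` is a hypothetical stand-in that only reproduces TabNet's two-output forward signature (`logits, M_loss`), not the real network, and the `lambda_sparse` value is an assumed default:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

lambda_sparse = 1e-3  # assumed default; set to 0 to ignore M_loss entirely

class TinyTabNetLike(nn.Module):
    """Hypothetical stand-in mimicking TabNet's forward signature."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        logits = self.fc(x)
        # Placeholder for the entropy-based sparsity term the encoder computes.
        M_loss = x.abs().mean()
        return logits, M_loss

def train_step(model, optimizer, X, y):
    optimizer.zero_grad()
    out, M_loss = model(X)                    # forward returns (logits, M_loss)
    loss = F.cross_entropy(out, y)            # standard supervised loss
    loss = loss - lambda_sparse * M_loss      # sparsity regularization, as in the wrapper
    loss.backward()
    optimizer.step()
    return loss.item()
```

Dropping the `M_loss` term from the combined loss is exactly equivalent to running with `lambda_sparse=0`.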
@Optimox Gotcha, thank you. Is the correct base model
It would be really nice to use the base model with raw
I'm not sure I understand what feature you are requesting: every class in this file inherits from
Feel free to reuse it and insert it in your own pipeline.
Perfect, thanks. I wasn't sure if there was any extra logic outside of
Feature request
Although the sklearn API is quite nice for ease-of-use, it would also be great to use the TabNet model with the standard PyTorch API.
What is the expected behavior?
Call `net = TabNet(100, 10)`, then use `net(sample)` with `loss.backward()` and `optimizer.step()` to train the model via SGD.
What is the motivation or use case for adding/changing the behavior?
There are many cases where writing a manual training loop is preferred, especially when I want to hot-swap this model into an existing pipeline, or when the dataset is too large to fit in memory and can only be accessed sample-wise. This is my entire reason for using TabNet over XGBoost, where creating a dataset distributed in memory is not trivial in certain cases.
How should this be implemented in your opinion?
I see that `pytorch_tabnet.tab_network.TabNet` already exists. What I'm unsure about is the output of the forward pass: it seems to contain both the outputs of a forward pass and the `M_loss` defined in the encoder. Should I be using this loss, or a standard `CrossEntropy` loss for classification?
Are you willing to work on this yourself?
Yes! This should be a simple thing to do, I just need to know if `M_loss` can be ignored in the output of a forward pass of `pytorch_tabnet.tab_network.TabNet`.
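For the too-large-for-memory use case described above, one possible shape for the data side is a PyTorch `IterableDataset` that yields samples lazily, so only one batch lives in memory at a time. This is a sketch under assumptions: `StreamingRows` is a hypothetical class, and the random tensors stand in for rows read from disk or a database:

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class StreamingRows(IterableDataset):
    """Hypothetical streaming dataset: yields (features, label) one sample
    at a time. In a real pipeline, __iter__ would read rows lazily from
    disk or a database instead of generating random placeholders."""
    def __init__(self, n_samples, n_features):
        self.n_samples = n_samples
        self.n_features = n_features

    def __iter__(self):
        for _ in range(self.n_samples):
            x = torch.randn(self.n_features)       # placeholder for a loaded row
            y = torch.randint(0, 2, ()).long()     # placeholder label
            yield x, y

# Each batch can then be fed to the bare network's forward pass
# (which returns both the logits and M_loss) in a manual loop.
loader = DataLoader(StreamingRows(n_samples=100, n_features=16), batch_size=32)
```

With this shape, a manual loop over `loader` never materializes the full dataset, which is the property the feature request asks for.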