Skip to content

Commit

Permalink
Update the readme and fix bugs in custom-dataset example (#1214)
Browse files Browse the repository at this point in the history
amend

amend
  • Loading branch information
lancerts authored Jan 13, 2024
1 parent 5921fc1 commit 97adea1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cpp/custom-dataset/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This folder contains an example of loading a custom image dataset with OpenCV and training a model to label images, using the PyTorch C++ frontend.

The dataset used here is [Caltech 101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) dataset.
The dataset used here is [Caltech 101](https://data.caltech.edu/records/mzrjq-6wc02) dataset.

The entire training code is contained in custom-data.cpp.

Expand Down
12 changes: 8 additions & 4 deletions cpp/custom-dataset/custom-dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <iostream>
#include <string>
#include <vector>
#include <random>

struct Options {
int image_size = 224;
Expand Down Expand Up @@ -55,7 +56,7 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
auto tdata = torch::cat({R, G, B})
.view({3, options.image_size, options.image_size})
.to(torch::kFloat);
auto tlabel = torch::from_blob(&data[index].second, {1}, torch::kLong);
auto tlabel = torch::tensor(data[index].second, torch::kLong);
return {tdata, tlabel};
}

Expand All @@ -65,6 +66,8 @@ class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {
};

std::pair<Data, Data> readInfo() {
std::random_device randomDevice;
std::mt19937 mersenneTwisterGenerator(randomDevice());
Data train, test;

std::ifstream stream(options.infoFilePath);
Expand All @@ -87,8 +90,8 @@ std::pair<Data, Data> readInfo() {
break;
}

std::random_shuffle(train.begin(), train.end());
std::random_shuffle(test.begin(), test.end());
std::shuffle(train.begin(), train.end(), mersenneTwisterGenerator);
std::shuffle(test.begin(), test.end(), mersenneTwisterGenerator);
return std::make_pair(train, test);
}

Expand Down Expand Up @@ -119,7 +122,8 @@ struct NetworkImpl : torch::nn::SequentialImpl {
push_back(Linear(4096, 4096));
push_back(Functional(torch::relu));
push_back(Linear(4096, 102));
push_back(Functional(torch::log_softmax, 1, torch::nullopt));
push_back(Functional(
[](torch::Tensor input) { return torch::log_softmax(input, 1); }));
}
};
TORCH_MODULE(Network);
Expand Down

0 comments on commit 97adea1

Please sign in to comment.