Skip to content

Commit

Permalink
customizable prediction (#11)
Browse files Browse the repository at this point in the history
* feat: add parameters 'predictor/max_candidates: {int}' and 'predictor/max_iterations: {int}' to set the maximum number of prediction candidates and the maximum number of consecutive predictions
* feat: add parameter 'predictor/db: {string}' for customizing the prediction db filename

---------

Co-authored-by: Qijia Liu <[email protected]>
  • Loading branch information
fxliang and eagleoflqj authored Oct 29, 2023
1 parent d80be7d commit 1979c00
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 9 deletions.
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@
librime plugin. predict next word.

## Usage
* Put `predict.db` in rime user directory.
* In `*.schema.yaml`, add `predictor` to the list of `engine/processors` before `key_binder`.
* Put the db file (by default `predict.db`) in rime user directory.
* In `*.schema.yaml`, add `predictor` to the list of `engine/processors` before `key_binder`,
or patch the schema with: `engine/processors/@before 0: predictor`
* Config items for your predictor:
```yaml
predictor:
# predict db file, located in the user directory or the shared directory
# defaults to 'predict.db'
db: predict.db
# max number of prediction candidates shown each time
# defaults to 0, which means all candidates are shown
# you may set it to the same value as page_size so that the period key doesn't trigger the next page
max_candidates: 5
# max number of consecutive predictions
# defaults to 0, which means no limit
max_iterations: 1
```
* Deploy and enjoy.
3 changes: 3 additions & 0 deletions src/predict_db.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class PredictDb : public MappedFile {
public:
PredictDb(const string& file_name)
: MappedFile(file_name),
file_name_(file_name),
key_trie_(new Darts::DoubleArray),
value_trie_(new StringTable) {}

Expand All @@ -44,11 +45,13 @@ class PredictDb : public MappedFile {
bool Build(const predict::RawData& data);
predict::Candidates* Lookup(const string& query);
string GetEntryText(const ::rime::table::Entry& entry);
const string file_name() { return file_name_; }

private:
int WriteCandidates(const vector<predict::RawEntry>& candidates,
const table::Entry* entry);

const string file_name_;
predict::Metadata* metadata_ = nullptr;
the<Darts::DoubleArray> key_trie_;
the<StringTable> value_trie_;
Expand Down
57 changes: 51 additions & 6 deletions src/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@
#include <rime/segmentation.h>
#include <rime/service.h>
#include <rime/translation.h>
#include <rime/schema.h>

namespace rime {

Predictor::Predictor(const Ticket& ticket, PredictDb* db)
: Processor(ticket), db_(db) {
Predictor::Predictor(const Ticket& ticket,
PredictDb* db,
int max_candidates,
int max_iterations)
: Processor(ticket),
db_(db),
iteration_counter_(0),
max_iterations_(max_iterations),
max_candidates_(max_candidates) {
// update prediction on context change.
auto* context = engine_->context();
select_connection_ = context->select_notifier().connect(
Expand Down Expand Up @@ -60,8 +68,21 @@ void Predictor::OnContextUpdate(Context* ctx) {
auto last_commit = ctx->commit_history().back();
if (last_commit.type == "punct" || last_commit.type == "raw" ||
last_commit.type == "thru") {
iteration_counter_ = 0;
return;
}
if (last_commit.type == "prediction") {
iteration_counter_++;
if (max_iterations_ > 0 && iteration_counter_ >= max_iterations_) {
iteration_counter_ = 0;
auto* ctx = engine_->context();
if (!ctx->composition().empty() &&
ctx->composition().back().HasTag("prediction")) {
ctx->Clear();
}
return;
}
}
Predict(ctx, last_commit.text);
}

Expand All @@ -75,9 +96,13 @@ void Predictor::Predict(Context* ctx, const string& context_query) {
ctx->composition().back().tags.erase("raw");

auto translation = New<FifoTranslation>();
int i = 0;
for (auto* it = candidates->begin(); it != candidates->end(); ++it) {
translation->Append(
New<SimpleCandidate>("prediction", end, end, db_->GetEntryText(*it)));
i++;
if (max_candidates_ > 0 && i >= max_candidates_)
break;
}
auto menu = New<Menu>();
menu->AddTranslation(translation);
Expand All @@ -90,16 +115,36 @@ PredictorComponent::PredictorComponent() {}
PredictorComponent::~PredictorComponent() {}

Predictor* PredictorComponent::Create(const Ticket& ticket) {
if (!db_) {
int max_iterations = 0, max_candidates = 0;
string db_file = "predict.db";
// load config items from schema
auto* schema = ticket.schema;
if (schema) {
auto* config = schema->config();
string customized_db;
if (!config->GetInt("predictor/max_iterations", &max_iterations)) {
LOG(INFO) << "predictor/max_iterations is not set in schema";
}
if (!config->GetInt("predictor/max_candidates", &max_candidates)) {
LOG(INFO) << "predictor/max_candidates is not set in schema";
}
if (config->GetString("predictor/db", &customized_db) &&
!customized_db.empty()) {
db_file = customized_db;
}
}
if (!db_ || db_->file_name() != db_file) {
the<ResourceResolver> res(
Service::instance().CreateResourceResolver({"predict_db", "", ""}));
auto db =
std::make_unique<PredictDb>(res->ResolvePath("predict.db").string());
string path = res->ResolvePath(db_file).string();
auto db = std::make_unique<PredictDb>(path);
if (db && db->Load()) {
db_ = std::move(db);
} else {
LOG(ERROR) << "failed to load db file: " << path;
}
}
return new Predictor(ticket, db_.get());
return new Predictor(ticket, db_.get(), max_candidates, max_iterations);
}

} // namespace rime
8 changes: 7 additions & 1 deletion src/predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ class PredictDb;

class Predictor : public Processor {
public:
Predictor(const Ticket& ticket, PredictDb* db);
Predictor(const Ticket& ticket,
PredictDb* db,
int max_candidates,
int max_iterations);
~Predictor();

ProcessResult ProcessKeyEvent(const KeyEvent& key_event) override;
Expand All @@ -23,6 +26,9 @@ class Predictor : public Processor {
private:
enum Action { kUnspecified, kSelect, kDelete };
Action last_action_ = kUnspecified;
int max_iterations_; // prediction times limit
int max_candidates_; // prediction candidate count limit
int iteration_counter_; // times has been predicted

PredictDb* db_;
connection select_connection_;
Expand Down

0 comments on commit 1979c00

Please sign in to comment.