dqnAgent.h
/*
 * http://github.com/dusty-nv/jetson-reinforcement
 */

#ifndef __DEEP_Q_LEARNING_AGENT_H_
#define __DEEP_Q_LEARNING_AGENT_H_

#include "rlAgent.h"


/**
 * Deep Q-Learner Agent
 */
class dqnAgent : public rlAgent
{
public:
	/**
	 * Create a new DQN agent training instance;
	 * the dimensions of a 2D input image are expected.
	 */
	static dqnAgent* Create( uint32_t width, uint32_t height, uint32_t channels, uint32_t numActions,
						const char* optimizer="RMSprop", float learning_rate=0.001f,
						uint32_t replay_mem=10000, uint32_t batch_size=64, float gamma=0.9f,
						float epsilon_start=0.9f, float epsilon_end=0.05f, float epsilon_decay=200.0f,
						bool use_lstm=true, int lstm_size=256, bool allow_random=true, bool debug_mode=false );
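	// Note (editorial): the epsilon_* parameters configure epsilon-greedy
	// exploration. In DQN recipes like the PyTorch tutorial this project is
	// modeled on, epsilon is typically annealed per step as
	//     epsilon = epsilon_end + (epsilon_start - epsilon_end) * exp(-step / epsilon_decay)
	// This formula is an assumption about how these parameters are consumed;
	// check the repo's Python-side implementation for the exact schedule.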
	/**
	 * Destructor
	 */
	virtual ~dqnAgent();

	/**
	 * From the input state, predict the next action (inference).
	 * During training, each call is followed by NextReward() below,
	 * which issues the reward for the chosen action.
	 */
	virtual bool NextAction( Tensor* state, int* action );

	/**
	 * Issue the reward for the previous action (training)
	 */
	virtual bool NextReward( float reward, bool end_episode );

	/**
	 * GetType
	 */
	virtual TypeID GetType() const	{ return TYPE_DQN; }

	/**
	 * TypeID
	 */
	const TypeID TYPE_DQN = TYPE_RL | (1 << 2);

protected:
	dqnAgent();
};

#endif
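
Usage sketch (editorial): a minimal loop showing how the API above is typically driven, assuming the Tensor type from rlAgent.h. envGetState() and envStep() are hypothetical placeholders for your environment, not part of this repo.

	#include "dqnAgent.h"

	// hypothetical environment hooks -- replace with your own
	Tensor* envGetState();                       // capture the current observation
	float   envStep( int action, bool* done );   // apply the action, return the reward

	int main()
	{
		// agent for 64x64 RGB input and 4 discrete actions,
		// remaining hyperparameters left at their defaults
		dqnAgent* agent = dqnAgent::Create(64, 64, 3, 4);

		if( !agent )
			return 1;

		bool done = false;

		while( !done )
		{
			Tensor* state = envGetState();
			int action    = 0;

			// predict the next action from the current state
			if( !agent->NextAction(state, &action) )
				break;

			// apply the action in the environment
			const float reward = envStep(action, &done);

			// issue the reward (runs a training iteration while learning)
			if( !agent->NextReward(reward, done) )
				break;
		}

		delete agent;
		return 0;
	}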