SmartEngine  1.6.0
RLTrainer.h
1 // Copyright (C) Entropy Software LLC - All Rights Reserved
2 
3 #pragma once
4 
5 #include "Agent.h"
6 #include "GradientDescentTrainer.h"
7 #include "Resource.h"
8 
9 namespace SmartEngine
10 {
11 
12 class IContext;
13 class IGraphNode;
14 
15 #pragma pack(push, 4)
16 struct RLTrainerCInfo : public ResourceCInfo
20 {
24  IContext* context = nullptr;
25 
30 
34  const char* agentName = "";
35 
41  const char** policyNodeNames = nullptr;
42 
47 
51  float gamma = 0.99f; // Reward discount
52 
57 
61  int sequenceLength = 1;
62 };
63 #pragma pack(pop)
64 
68 class SMARTENGINE_EXPORT IRLTrainer : public virtual IObject, public IAgentFactory, public IResource
69 {
70 public:
71  SMARTENGINE_DECLARE_CLASS(IRLTrainer)
72 
73 
74  virtual int GetGenerationCount() const = 0;
77 
82  virtual float GetLoss() = 0;
83 
87  virtual void Reset() = 0;
88 
92  virtual void Step() = 0;
93 };
94 
96 extern "C"
97 {
98  SMARTENGINE_EXPORT int RLTrainer_GetGenerationCount(ObjPtr object);
99  SMARTENGINE_EXPORT float RLTrainer_GetLoss(ObjPtr object);
100  SMARTENGINE_EXPORT void RLTrainer_Reset(ObjPtr object);
101  SMARTENGINE_EXPORT void RLTrainer_Step(ObjPtr object);
102  SMARTENGINE_EXPORT ObjPtr RLTrainer_CreateAgent(ObjPtr object);
103 }
105 
106 } // namespace SmartEngine
SmartEngine::RLTrainerCInfo::sequenceLength
int sequenceLength
LSTM sequence length. Can be ignored if there is no LSTM in the graphs.
Definition: RLTrainer.h:61
SmartEngine::IAgentFactory
RL trainers implement this to create agents.
Definition: Agent.h:129
SmartEngine::RLTrainerCInfo::policyNodeNameCount
int policyNodeNameCount
The number of elements in the policy node name array
Definition: RLTrainer.h:46
SmartEngine::RLTrainerCInfo::policyNodeNames
const char ** policyNodeNames
The names of the output nodes of the actor (the network used to manipulate the environment)....
Definition: RLTrainer.h:41
SmartEngine::RLTrainerCInfo::trainingInfo
GradientDescentTrainingInfo trainingInfo
Gradient descent training parameters
Definition: RLTrainer.h:56
SmartEngine::RLTrainerCInfo::gamma
float gamma
Reward decay over time
Definition: RLTrainer.h:51
SmartEngine::RLTrainerCInfo::dataStore
IAgentDataStore * dataStore
The data store used to save experience state
Definition: RLTrainer.h:29
SmartEngine::IRLTrainer
Base class for all reinforcement learning trainers.
Definition: RLTrainer.h:69
SmartEngine::GradientDescentTrainingInfo
GradientDescentTrainer training info
Definition: GradientDescentTrainer.h:74
SmartEngine
Definition: A2CTrainer.h:10
SmartEngine::IObject
Base class for SmartEngine AI objects. It is not common to deal with this class directly.
Definition: Object.h:19
SmartEngine::RLTrainerCInfo::context
IContext * context
The context to perform graph operations within.
Definition: RLTrainer.h:24
SmartEngine::IRLTrainer::GetLoss
virtual float GetLoss()=0
This value will mean different things to different trainers. See each trainer's description for the v...
SmartEngine::IResource
Base class for objects that can be loaded from and saved to disk.
Definition: Resource.h:77
SmartEngine::IAgentDataStore
The agent data store keeps experience data for the purpose of training. Some RL trainers don't store ...
Definition: Agent.h:57
SmartEngine::IContext
Every node in the AI graph must belong to the same context.
Definition: Context.h:40
SmartEngine::RLTrainerCInfo::agentName
const char * agentName
Should be a unique name across the data store
Definition: RLTrainer.h:34
SmartEngine::IRLTrainer::Step
virtual void Step()=0
Steps training. May not actually result in any training if there is not enough data available yet.
SmartEngine::IRLTrainer::Reset
virtual void Reset()=0
Resets the trainer to a fresh state, initializing any internal weights to random values.