5 #include "CuriosityModule.h"
90 virtual float GetPolicyLoss() = 0;
113 SMARTENGINE_EXPORT ObjPtr A2CTrainer_CreateInstance(
const A2CTrainerCInfo& cinfo);
114 SMARTENGINE_EXPORT
float A2CTrainer_GetPolicyLoss(ObjPtr
object);
115 SMARTENGINE_EXPORT
float A2CTrainer_GetValueLoss(ObjPtr
object);
116 SMARTENGINE_EXPORT
float A2CTrainer_GetEntropyLoss(ObjPtr
object);
IGraph * graph
The graph being trained. It should contain both the policy network and the value network.
Definition: A2CTrainer.h:22
The A2C Trainer is a reinforcement learning trainer that is composed of two parts: an actor sub graph...
Definition: A2CTrainer.h:85
float valueCoefficient
How much weight the value term contributes to the loss.
Definition: A2CTrainer.h:38
Data used to construct an IA2CTrainer instance
Definition: A2CTrainer.h:17
const char * valueNodeName
The name of the output of the critic node. This node should be a linear layer with one output neuron ...
Definition: A2CTrainer.h:33
int minBatchSize
How many data samples to try to train on at a time.
Definition: A2CTrainer.h:55
Base class for all reinforcement learning trainers.
Definition: RLTrainer.h:69
float entropyCoefficient
How much weight the entropy contributes to the loss. Entropy is a measure of how random our output is...
Definition: A2CTrainer.h:45
Smart pointer to an IObject. Automatic ref counting.
Definition: ObjectPtr.h:16
Definition: A2CTrainer.h:10
int lookAheadSteps
How many actual experiences we should look at before using an estimate for total rewards this episode...
Definition: A2CTrainer.h:50
A curiosity module is a way of rewarding an agent for behavior not yet seen. Rewards are given based ...
Definition: CuriosityModule.h:108
SMARTENGINE_EXPORT ObjectPtr< IA2CTrainer > CreateA2CTrainer(const A2CTrainerCInfo &cinfo)
Creates an instance of IA2CTrainer
A graph is a collection of buffers and nodes that together form a neural network. The graph is create...
Definition: Graph.h:61
ICuriosityModule * curiosityModule
Optional curiosity module that provides additional exploration rewards.
Definition: A2CTrainer.h:27
virtual float GetValueLoss()=0
Returns the loss of the value sub-graph.
virtual float GetEntropyLoss()=0
Returns the entropy loss, a measure of how random the network's output is.
Data used to construct an IRLTrainer instance
Definition: RLTrainer.h:20