SmartEngine  1.6.0
LossTrainer.h
1 // Copyright (C) Entropy Software LLC - All Rights Reserved
2 
3 #pragma once
4 
5 #include "Context.h"
6 #include "Graph.h"
7 #include "Object.h"
8 #include "Resource.h"
9 
10 namespace SmartEngine
11 {
12 
13 #pragma pack(push, 4)
14 struct LossTrainerCInfo : public ResourceCInfo // Creation info for a loss trainer (presumably consumed when creating an ILossTrainer — confirm); declared under #pragma pack(push, 4), so member order/types define its binary layout.
18 {
22  IContext* context = nullptr; // Context this trainer belongs to; only trainers and graph nodes of the same context can be connected. NOTE(review): ownership not stated here — presumably non-owning.
23 
27  IGraph* graph = nullptr; // The graph to be trained.
28 };
29 
30 enum class LossTrainingMethod : byte // Selects the method used while training (the `method` field of LossTrainingMethodInfo); underlying type `byte` (presumably 8-bit) keeps it compact in the packed struct.
31 {
35  WholeDataset, // Train against all data in the buffer in groups of size batchSize.
36 
40  Stochastic, // NOTE(review): not described in this listing; presumably trains on sampled subsets of the data — confirm.
41 };
42 
44 { // Body of LossTrainingMethodInfo (Definition: LossTrainer.h:44). NOTE(review): the declaration line (43) and the `LossTrainingMethod method` member (line 48, "the method to use while training") are elided from this listing.
49 
57  const int64* sessionIndices = nullptr; // Row indices into the input buffer that divide the graph input into logical chunks (sessions); array of sessionCount entries. Lifetime/ownership not stated — presumably caller-owned.
58 
62  int sessionCount = 0; // Number of entries in sessionIndices.
63 
67  int batchSize = 32; // Number of rows to train at a time.
68 
73  int sequenceLength = 1; // Only applies to stepped trainable layers ("such as LS…" — truncated, presumably LSTM). NOTE(review): tooltip says "how many sequences should be trained every step"; the name suggests a per-sequence length — confirm.
74 };
75 #pragma pack(pop)
76 
80 class SMARTENGINE_EXPORT ILossTrainer : public virtual IObject, public IResource // Base class for NeuralNetwork loss trainers. Marked SMARTENGINE_EXPORT (presumably a dllexport/visibility macro): keep the order of the pure virtuals below stable, as it determines vtable layout.
81 {
82 public:
83  SMARTENGINE_DECLARE_CLASS(ILossTrainer)
84 
85 
86  virtual void SetTrainingMethod(const LossTrainingMethodInfo& info) = 0; // Sets the training method and its parameters (sessions, batch size, sequence length) from `info`.
91 
95  virtual float GetLoss() = 0; // NOTE(review): presumably the most recent training loss; non-const unlike GetGenerationCount — confirm whether it computes or mutates state.
96 
100  virtual uint32 GetGenerationCount() const = 0; // Number of generations we have been training for.
101 };
102 
104 extern "C" // C-linkage exports whose names parallel the ILossTrainer methods; `object` is the trainer handle (ObjPtr). Presumably each forwards to the matching virtual — confirm in the implementation.
105 {
106  SMARTENGINE_EXPORT void LossTrainer_SetTrainingMethod(ObjPtr object, const LossTrainingMethodInfo& info); // NOTE(review): a C++ reference is not expressible in C; on common ABIs it is passed as a pointer, so non-C++ callers should pass a pointer to a LossTrainingMethodInfo.
107  SMARTENGINE_EXPORT float LossTrainer_GetLoss(ObjPtr object);
108  SMARTENGINE_EXPORT uint32 LossTrainer_GetGenerationCount(ObjPtr object);
109 }
111 
112 } // namespace SmartEngine
SmartEngine::LossTrainingMethodInfo
Definition: LossTrainer.h:44
SmartEngine::LossTrainingMethod
LossTrainingMethod
Definition: LossTrainer.h:31
SmartEngine::LossTrainerCInfo::graph
IGraph * graph
The graph to be trained.
Definition: LossTrainer.h:27
SmartEngine::LossTrainingMethodInfo::sequenceLength
int sequenceLength
How many sequences should be trained every step. Only applies to stepped trainable layers (such as LS...
Definition: LossTrainer.h:73
SmartEngine::LossTrainingMethod::WholeDataset
@ WholeDataset
Train against all data in the buffer in groups of size BatchSize.
SmartEngine::LossTrainingMethodInfo::sessionIndices
const int64 * sessionIndices
Session indices divide the graph input into logical chunks. These are row indices in the input buffe...
Definition: LossTrainer.h:57
SmartEngine::LossTrainingMethodInfo::method
LossTrainingMethod method
The method to use while training
Definition: LossTrainer.h:48
SmartEngine::LossTrainingMethodInfo::batchSize
int batchSize
The number of rows to train at a time.
Definition: LossTrainer.h:67
SmartEngine
Definition: A2CTrainer.h:10
SmartEngine::LossTrainingMethodInfo::sessionCount
int sessionCount
The session index array count
Definition: LossTrainer.h:62
SmartEngine::IObject
Base class for SmartEngine AI objects. It is not common to deal with this class directly.
Definition: Object.h:19
SmartEngine::ILossTrainer
Base class for NeuralNetwork loss trainers
Definition: LossTrainer.h:81
SmartEngine::IGraph
A graph is a collection of buffers and nodes that together form a neural network. The graph is create...
Definition: Graph.h:61
SmartEngine::ILossTrainer::GetGenerationCount
virtual uint32 GetGenerationCount() const =0
Returns the number of generations we have been training for.
SmartEngine::IResource
Base class for objects that can be loaded from and saved to disk.
Definition: Resource.h:77
SmartEngine::IContext
Every node in the AI graph must belong to the same context.
Definition: Context.h:40
SmartEngine::LossTrainerCInfo::context
IContext * context
Context this trainer belongs to. Only trainers and graph nodes of the same context can be connected.
Definition: LossTrainer.h:22