SmartEngine
1.6.0
Native
Library
API
LossTrainer.h
1
// Copyright (C) Entropy Software LLC - All Rights Reserved
2
3
#pragma once
4
5
#include "Context.h"
6
#include "Graph.h"
7
#include "Object.h"
8
#include "Resource.h"
9
10
namespace
SmartEngine
11
{
12
13
#pragma pack(push, 4)
14
struct
LossTrainerCInfo :
public
ResourceCInfo
18
{
22
IContext
*
context
=
nullptr
;
23
27
IGraph
*
graph
=
nullptr
;
28
};
29
30
enum class
LossTrainingMethod
: byte
31
{
35
WholeDataset,
36
40
Stochastic
,
41
};
42
43
struct
LossTrainingMethodInfo
44
{
48
LossTrainingMethod
method
=
LossTrainingMethod::Stochastic
;
49
57
const
int64*
sessionIndices
=
nullptr
;
58
62
int
sessionCount
= 0;
63
67
int
batchSize
= 32;
68
73
int
sequenceLength
= 1;
74
};
75
#pragma pack(pop)
76
80
/// Base class for NeuralNetwork loss trainers.
///
/// Pure interface: every member function below is pure virtual.
class SMARTENGINE_EXPORT ILossTrainer : public virtual IObject, public IResource
{
public:
    SMARTENGINE_DECLARE_CLASS(ILossTrainer)

    /// Supplies the training method settings to the trainer.
    /// @param info Settings structure; see LossTrainingMethodInfo.
    // NOTE(review): whether `info` (and its sessionIndices array) is copied
    // or retained is not visible here - confirm with the implementation.
    virtual void SetTrainingMethod(const LossTrainingMethodInfo& info) = 0;

    /// Returns a loss value as a float.
    // NOTE(review): which loss (last step, running average, ...) is not
    // stated in this header - confirm with the implementation.
    virtual float GetLoss() = 0;

    /// Returns the number of generations we have been training for.
    virtual uint32 GetGenerationCount() const = 0;
};
102
104
extern
"C"
105
{
106
SMARTENGINE_EXPORT
void
LossTrainer_SetTrainingMethod(ObjPtr
object
,
const
LossTrainingMethodInfo
& info);
107
SMARTENGINE_EXPORT
float
LossTrainer_GetLoss(ObjPtr
object
);
108
SMARTENGINE_EXPORT uint32 LossTrainer_GetGenerationCount(ObjPtr
object
);
109
}
111
112
}
// namespace SmartEngine
SmartEngine::LossTrainingMethodInfo
Definition:
LossTrainer.h:44
SmartEngine::LossTrainingMethod
LossTrainingMethod
Definition:
LossTrainer.h:31
SmartEngine::LossTrainerCInfo::graph
IGraph * graph
The graph to be trained.
Definition:
LossTrainer.h:27
SmartEngine::LossTrainingMethodInfo::sequenceLength
int sequenceLength
How many sequences should be trained every step. Only applies to stepped trainable layers (such as LS...
Definition:
LossTrainer.h:73
SmartEngine::LossTrainingMethod::WholeDataset
@ WholeDataset
Train against all data in the buffer in groups of size BatchSize.
SmartEngine::LossTrainingMethodInfo::sessionIndices
const int64 * sessionIndices
Session indices divide the graph input into logical chunks. These are row indices in the input buffe...
Definition:
LossTrainer.h:57
SmartEngine::LossTrainingMethodInfo::method
LossTrainingMethod method
The method to use while training
Definition:
LossTrainer.h:48
SmartEngine::LossTrainingMethodInfo::batchSize
int batchSize
The number of rows to train at a time.
Definition:
LossTrainer.h:67
SmartEngine
Definition:
A2CTrainer.h:10
SmartEngine::LossTrainingMethodInfo::sessionCount
int sessionCount
The session index array count
Definition:
LossTrainer.h:62
SmartEngine::IObject
Base class for SmartEngine AI objects. It is not common to deal with this class directly.
Definition:
Object.h:19
SmartEngine::ILossTrainer
Base class for NeuralNetwork loss trainers
Definition:
LossTrainer.h:81
SmartEngine::IGraph
A graph is a collection of buffers and nodes that together form a neural network. The graph is create...
Definition:
Graph.h:61
SmartEngine::ILossTrainer::GetGenerationCount
virtual uint32 GetGenerationCount() const =0
Returns the number of generations we have been training for.
SmartEngine::IResource
Base class for objects that can be loaded from and saved to disk.
Definition:
Resource.h:77
SmartEngine::IContext
Every node in the AI graph must belong to the same context.
Definition:
Context.h:40
SmartEngine::LossTrainerCInfo::context
IContext * context
Context this trainer belongs to. Only trainers and graph nodes of the same context can be connected.
Definition:
LossTrainer.h:22
Generated by
1.8.20