1 #ifndef EDGE_WITH_WEIGHT_H_
2 #define EDGE_WITH_WEIGHT_H_
// Accessor: returns a non-const reference to weights_, so callers can
// modify the matrix in place.
28 Matrix& GetWeight() {
return weights_;}
// Accessor: returns a non-const reference to grad_weights_, so callers can
// modify the matrix in place.
29 Matrix& GetGradWeight() {
return grad_weights_;}
// Accessor: returns a non-const reference to bias_, so callers can modify
// the matrix in place.
// NOTE(review): has_no_bias_ exists on this class; what bias_ holds when that
// flag is set is not visible here - confirm before relying on it.
30 Matrix& GetBias() {
return bias_;}
// Accessor: returns a non-const reference to grad_bias_, so callers can
// modify the matrix in place.
31 Matrix& GetGradBias() {
return grad_bias_;}
// Const query taking a base epsilon and returning a float. The definition is
// not in this header.
// NOTE(review): name suggests it applies a decay schedule to base_epsilon
// (learning rate) - confirm against the implementation.
33 float GetDecayedEpsilon(
float base_epsilon)
const;
// Const query returning a float. The definition is not in this header.
// NOTE(review): presumably the current momentum coefficient - confirm.
34 float GetMomentum()
const;
// Overridable hooks; bodies are defined outside this header.
// NOTE(review): names suggest these maintain Polyak-averaged copies of the
// parameters (see polyak_weights_ / polyak_bias_) and swap between the
// averaged and current values (see weights_backup_ / bias_backup_) - confirm
// the exact semantics and required call order in the implementation.
36 virtual void InsertPolyak();
37 virtual void BackupCurrent();
38 virtual void LoadCurrentOnGPU();
39 virtual void LoadPolyakOnGPU();
// Non-virtual; bodies are defined outside this header.
// NOTE(review): presumably these bump and read num_grads_received_ - confirm.
43 void IncrementNumGradsReceived();
// NOTE(review): not declared const even though the name suggests a pure
// getter - check whether it mutates state before calling on a const object.
44 int GetNumGradsReceived();
// Parameter matrices and their gradient matrices; each is exposed by
// non-const reference through GetWeight / GetGradWeight / GetBias /
// GetGradBias.
46 Matrix weights_, grad_weights_, bias_, grad_bias_;
// NOTE(review): presumably a queue of parameter snapshots used for Polyak
// averaging, bounded by polyak_queue_size_ - confirm.
51 vector<Matrix> polyak_weights_, polyak_bias_;
// NOTE(review): presumably filled by BackupCurrent() - confirm.
52 Matrix weights_backup_, bias_backup_;
// Initialization setting, typed by the config::Edge::Initialization enum;
// immutable after construction (const member).
53 const config::Edge::Initialization initialization_;
// Immutable after construction (const member).
54 const int polyak_queue_size_;
// NOTE(review): presumably set once polyak_queue_size_ snapshots have been
// stored - confirm.
56 bool polyak_queue_full_;
// Immutable after construction. NOTE(review): presumably the scale/value used
// when initializing weights and bias - meaning depends on initialization_;
// confirm.
59 const float init_wt_, init_bias_;
// Immutable after construction. NOTE(review): name suggests the edge carries
// no bias term when true - confirm how bias_ / grad_bias_ are treated.
61 const bool has_no_bias_;
// NOTE(review): num_grads_received_ is presumably the counter behind
// IncrementNumGradsReceived() / GetNumGradsReceived(); the meaning of
// num_shares_ is not visible here - TODO confirm.
62 int num_grads_received_, num_shares_;
// Immutable after construction. NOTE(review): presumably a multiplier applied
// to gradients - confirm.
63 const float scale_gradients_;
// Immutable after construction. NOTE(review): presumably identify a
// pretrained model file and the edge within it to load parameters from -
// confirm.
65 const string pretrained_model_, pretrained_edge_name_;
virtual void SaveParameters(hid_t file)
Write the weights and biases in an hdf5 file.
Definition: edge_with_weight.cc:30
virtual void Initialize()
Initialize the weights and biases.
Definition: edge_with_weight.cc:109
virtual void DisplayWeights()
Displays the weights.
Definition: edge_with_weight.cc:66
virtual void SetTiedTo(Edge *e)
Sets the edge to be tied to another edge.
Definition: edge_with_weight.cc:230
virtual float GetRMSWeight()
Returns the root mean square weight value.
Definition: edge_with_weight.cc:148
virtual bool HasNoParameters() const
Returns whether the edge has any parameters.
Definition: edge_with_weight.cc:161
This class is intended to be used as a base class for implementing edges.
Definition: edge.h:13
virtual void DisplayWeightStats()
Displays the statistics of the weights.
Definition: edge_with_weight.cc:74
Base class for all optimizers.
Definition: optimizer.h:8
virtual void ReduceLearningRate(float factor)
Reduce the learning rate by the given factor.
Definition: edge_with_weight.cc:92
virtual void UpdateWeights()
Update the weights.
Definition: edge_with_weight.cc:97
A GPU matrix class.
Definition: matrix.h:11
virtual int GetNumModules() const
Returns the number of modules.
Definition: edge_with_weight.cc:165
virtual void LoadParameters(hid_t file)
Load the weights and biases from an hdf5 file.
Definition: edge_with_weight.cc:60
Base class for all edges which have weights.
Definition: edge_with_weight.h:9