Skip to content

Commit

Permalink
weight saving / loading moved over to emer to extent possible. api cl…
Browse files Browse the repository at this point in the history
…eanup should be done now.
  • Loading branch information
rcoreilly committed Aug 13, 2024
1 parent 5037e6e commit ac258bd
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 30 deletions.
27 changes: 11 additions & 16 deletions emer/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ type Layer interface {
// not found.
UnitVarIndex(varNm string) (int, error)

// UnitVal1D returns value of given variable index on given unit,
// UnitValue1D returns value of given variable index on given unit,
// using 1-dimensional index, and a data parallel index di,
// for networks capable of processing multiple input patterns
// in parallel. Returns NaN on invalid index.
// This is the core unit var access method used by other methods,
// so it is the only one that needs to be updated for derived layer types.
UnitVal1D(varIndex int, idx, di int) float32
UnitValue1D(varIndex int, idx, di int) float32

// VarRange returns the min / max values for given variable
VarRange(varNm string) (min, max float32, err error)
Expand Down Expand Up @@ -123,18 +123,8 @@ type Layer interface {

// WriteWeightsJSON writes the weights from this layer from the
// receiver-side perspective in a JSON text format.
// We build in the indentation logic to make it much faster and
// more efficient.
WriteWeightsJSON(w io.Writer, depth int)

// ReadWeightsJSON reads the weights from this layer from the
// receiver-side perspective in a JSON text format.
// This is for a set of weights that were saved
// *for one layer only* and is not used for the
// network-level ReadWeightsJSON, which reads into a separate
// structure -- see SetWeights method.
ReadWeightsJSON(r io.Reader) error

// SetWeights sets the weights for this layer from weights.Layer
// decoded values
SetWeights(lw *weights.Layer) error
Expand Down Expand Up @@ -212,6 +202,11 @@ type LayerBase struct {

// provides a history of parameters applied to the layer
ParamsHistory params.HistoryImpl `table:"-"`

// optional metadata that is saved in network weights files,
// e.g., can indicate number of epochs that were trained,
// or any other information about this network that would be useful to save.
MetaData map[string]string
}

// InitLayer initializes the layer, setting the EmerLayer interface
Expand Down Expand Up @@ -371,7 +366,7 @@ func (ly *LayerBase) UnitValues(vals *[]float32, varNm string, di int) error {
return err
}
for lni := range nn {
(*vals)[lni] = ly.EmerLayer.UnitVal1D(vidx, lni, di)
(*vals)[lni] = ly.EmerLayer.UnitValue1D(vidx, lni, di)
}
return nil
}
Expand Down Expand Up @@ -400,7 +395,7 @@ func (ly *LayerBase) UnitValuesTensor(tsr tensor.Tensor, varNm string, di int) e
return err
}
for lni := 0; lni < nn; lni++ {
v := ly.EmerLayer.UnitVal1D(vidx, lni, di)
v := ly.EmerLayer.UnitValue1D(vidx, lni, di)
if math32.IsNaN(v) {
tsr.SetFloat1D(lni, math.NaN())
} else {
Expand Down Expand Up @@ -446,7 +441,7 @@ func (ly *LayerBase) UnitValuesSampleTensor(tsr tensor.Tensor, varNm string, di
return err
}
for i, ui := range ly.SampleIndexes {
v := ly.EmerLayer.UnitVal1D(vidx, ui, di)
v := ly.EmerLayer.UnitValue1D(vidx, ui, di)
if math32.IsNaN(v) {
tsr.SetFloat1D(i, math.NaN())
} else {
Expand All @@ -467,7 +462,7 @@ func (ly *LayerBase) UnitValue(varNm string, idx []int, di int) float32 {
return math32.NaN()
}
fidx := ly.Shape.Offset(idx)
return ly.EmerLayer.UnitVal1D(vidx, fidx, di)
return ly.EmerLayer.UnitValue1D(vidx, fidx, di)
}

// CenterPoolIndexes returns the indexes for n x n center pools of given 4D layer.
Expand Down
9 changes: 0 additions & 9 deletions emer/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,8 @@ type Path interface {

// WriteWeightsJSON writes the weights from this pathway
// from the receiver-side perspective in a JSON text format.
// We build in the indentation logic to make it much faster and
// more efficient.
WriteWeightsJSON(w io.Writer, depth int)

// ReadWeightsJSON reads the weights from this pathway
// from the receiver-side perspective in a JSON text format.
// This is for a set of weights that were saved *for one path only*
// and is not used for the network-level ReadWeightsJSON,
// which reads into a separate structure -- see SetWeights method.
ReadWeightsJSON(r io.Reader) error

// SetWeights sets the weights for this pathway from weights.Path
// decoded values
SetWeights(pw *weights.Path) error
Expand Down
2 changes: 1 addition & 1 deletion emer/typegen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ac258bd

Please sign in to comment.