Skip to content

Commit

Permalink
major cleanup of env interface: only essential methods, remove counte…
Browse files Browse the repository at this point in the history
…rs, which are now primarily on the looper; updated names for looper too: Ctr -> Counter.
  • Loading branch information
rcoreilly committed Aug 12, 2024
1 parent b11aea1 commit 2e85a3e
Show file tree
Hide file tree
Showing 20 changed files with 136 additions and 520 deletions.
6 changes: 3 additions & 3 deletions econfig/econfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ type TestConfig struct {
NData int `default:"16"`

// if true, save final weights after each run
SaveWts bool
SaveWeights bool

// if true, save train epoch log to file, as .epc.tsv typically
EpochLog bool `default:"true"`
Expand Down Expand Up @@ -145,15 +145,15 @@ func TestArgs(t *testing.T) {
cfg := &TestConfig{}
SetFromDefaults(cfg)
// note: cannot use "-Includes=testcfg.toml",
args := []string{"-save-wts", "-nogui", "-no-epoch-log", "--NoRunLog", "--runs=5", "--run", "1", "--TAG", "nice", "--PatParams.Sparseness=0.1", "--Network", "{'.PFCLayer:Layer.Inhib.Gi' = '2.4', '#VSPatchPath:Path.Learn.LRate' = '0.01'}", "-Enum=TestValue2", "-Slice=[3.2, 2.4, 1.9]", "leftover1", "leftover2"}
args := []string{"-save-weights", "-nogui", "-no-epoch-log", "--NoRunLog", "--runs=5", "--run", "1", "--TAG", "nice", "--PatParams.Sparseness=0.1", "--Network", "{'.PFCLayer:Layer.Inhib.Gi' = '2.4', '#VSPatchPath:Path.Learn.LRate' = '0.01'}", "-Enum=TestValue2", "-Slice=[3.2, 2.4, 1.9]", "leftover1", "leftover2"}
allArgs := make(map[string]reflect.Value)
FieldArgNames(cfg, allArgs)
leftovers, err := ParseArgs(cfg, args, allArgs, true)
if err != nil {
t.Errorf(err.Error())
}
fmt.Println(leftovers)
if cfg.Runs != 5 || cfg.Run != 1 || cfg.Tag != "nice" || cfg.PatParams.Sparseness != 0.1 || cfg.SaveWts != true || cfg.GUI != false || cfg.EpochLog != false || cfg.RunLog != false {
if cfg.Runs != 5 || cfg.Run != 1 || cfg.Tag != "nice" || cfg.PatParams.Sparseness != 0.1 || cfg.SaveWeights != true || cfg.GUI != false || cfg.EpochLog != false || cfg.RunLog != false {
t.Errorf("args not set properly: %#v", cfg)
}
if cfg.Enum != TestValue2 {
Expand Down
18 changes: 10 additions & 8 deletions env/ctr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

package env

// Ctr is a counter that counts increments at a given time scale.
import "github.com/emer/emergent/v2/etime"

// Counter is a counter that counts increments at a given time scale.
// It keeps track of when it has been incremented or not, and
// retains the previous value.
type Ctr struct {
type Counter struct {

// current counter value
Cur int
Expand All @@ -22,25 +24,25 @@ type Ctr struct {
Max int

// the unit of time scale represented by this counter (just FYI)
Scale TimeScales `display:"-"`
Scale etime.Times `display:"-"`
}

// Init initializes counter -- Cur = 0, Prv = -1
func (ct *Ctr) Init() {
func (ct *Counter) Init() {
ct.Prv = -1
ct.Cur = 0
ct.Chg = false
}

// Same resets Chg = false -- good idea to call this on all counters at start of Step
// or can put in an else statement, but that is more error-prone.
func (ct *Ctr) Same() {
func (ct *Counter) Same() {
ct.Chg = false
}

// Incr increments the counter by 1. If Max > 0 then if Incr >= Max
// the counter is reset to 0 and true is returned. Otherwise false.
func (ct *Ctr) Incr() bool {
func (ct *Counter) Incr() bool {
ct.Chg = true
ct.Prv = ct.Cur
ct.Cur++
Expand All @@ -54,7 +56,7 @@ func (ct *Ctr) Incr() bool {
// Set sets the Cur value if different from Cur, while preserving previous value
// and setting Chg appropriately. Returns true if changed.
// does NOT check Cur vs. Max.
func (ct *Ctr) Set(cur int) bool {
func (ct *Counter) Set(cur int) bool {
if ct.Cur == cur {
ct.Chg = false
return false
Expand All @@ -66,6 +68,6 @@ func (ct *Ctr) Set(cur int) bool {
}

// Query returns the current, previous and changed values for this counter
func (ct *Ctr) Query() (cur, prv int, chg bool) {
func (ct *Counter) Query() (cur, prv int, chg bool) {
return ct.Cur, ct.Prv, ct.Chg
}
37 changes: 19 additions & 18 deletions env/ctrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,56 +8,57 @@ import (
"fmt"

"github.com/emer/emergent/v2/estats"
"github.com/emer/emergent/v2/etime"
)

// Ctrs contains an ordered slice of timescales,
// Counters contains an ordered slice of timescales,
// and a lookup map of counters by timescale
// used to manage counters in the Env.
type Ctrs struct {
type Counters struct {

// ordered list of the counter timescales, from outer-most (highest) to inner-most (lowest)
Order []TimeScales
Order []etime.Times

// map of the counters by timescale
Ctrs map[TimeScales]*Ctr
Counters map[etime.Times]*Counter
}

// SetTimes initializes Ctrs for given mode
// SetTimes initializes Counters for given mode
// and list of times ordered from highest to lowest
func (cs *Ctrs) SetTimes(mode string, times ...TimeScales) {
func (cs *Counters) SetTimes(mode string, times ...etime.Times) {
cs.Order = times
cs.Ctrs = make(map[TimeScales]*Ctr, len(times))
cs.Counters = make(map[etime.Times]*Counter, len(times))
for _, tm := range times {
cs.Ctrs[tm] = &Ctr{Scale: tm}
cs.Counters[tm] = &Counter{Scale: tm}
}
}

// ByTime returns counter by timescale key -- nil if not found
func (cs *Ctrs) ByScope(tm TimeScales) *Ctr {
return cs.Ctrs[tm]
func (cs *Counters) ByScope(tm etime.Times) *Counter {
return cs.Counters[tm]
}

// ByTimeTry returns counter by timescale key -- returns nil, error if not found
func (cs *Ctrs) ByTimeTry(tm TimeScales) (*Ctr, error) {
ct, ok := cs.Ctrs[tm]
func (cs *Counters) ByTimeTry(tm etime.Times) (*Counter, error) {
ct, ok := cs.Counters[tm]
if ok {
return ct, nil
}
err := fmt.Errorf("env.Ctrs: scope not found: %s", tm.String())
err := fmt.Errorf("env.Counters: scope not found: %s", tm.String())
return nil, err
}

// Init does Init on all the counters
func (cs *Ctrs) Init() {
for _, ct := range cs.Ctrs {
func (cs *Counters) Init() {
for _, ct := range cs.Counters {
ct.Init()
}
}

// CtrsToStats sets the current counter values to estats Int values
// CountersToStats sets the current counter values to estats Int values
// by their time names only (no eval Mode).
func (cs *Ctrs) CtrsToStats(mode string, stats *estats.Stats) {
for _, ct := range cs.Ctrs {
func (cs *Counters) CountersToStats(mode string, stats *estats.Stats) {
for _, ct := range cs.Counters {
tm := ct.Scale.String()
stats.SetInt(mode+":"+tm, ct.Cur)
}
Expand Down
101 changes: 12 additions & 89 deletions env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

package env

import "cogentcore.org/core/tensor"
import (
"cogentcore.org/core/base/labels"
"cogentcore.org/core/tensor"
)

//go:generate core generate -add-types

// Env defines an interface for environments, which determine the nature and
// sequence of States that can be used as inputs to a model, and the Env
// also can accept Action responses from the model that affect state evolution.
//
// The Env encapsulates all of the counter management logic to advance
// the temporal state of the environment, using TimeScales standard
// intervals.
// The Env manages [Counter] values to advance the temporal state of the
// environment, using [etime.Times] standard intervals.
//
// State is comprised of one or more Elements, each of which consists of an
// tensor.Tensor chunk of values that can be obtained by the model.
Expand All @@ -27,25 +29,7 @@ import "cogentcore.org/core/tensor"
// multiple parameters etc that can be modified to control env behavior --
// all of this is paradigm-specific and outside the scope of this basic interface.
type Env interface {
// Name returns a name for this environment, which can be useful
// for selecting from a list of options etc.
Name() string

// Desc returns an (optional) brief description of this particular
// environment
Desc() string

// Validate checks if the various specific parameters for this
// Env have been properly set -- if not, error message(s) will
// be returned. If everything is OK, nil is returned, in which
// case calls to Counters(), States(), and Actions() should all
// return valid data. It is essential that a model *always* check
// this as a first step, because the Env will not generally check
// for errors on any subsequent calls (for greater efficiency
// and simplicity) and this call can also establish certain general
// initialization settings that are not run-specific and thus make
// sense to do once at this point, not every time during Init().
Validate() error
labels.Labeler

// Init initializes the environment for a given run of the model.
// The environment may not care about the run number, but may implement
Expand All @@ -61,93 +45,32 @@ type Env interface {
Init(run int)

// Step generates the next step of environment state.
// This is the main API for how the model interacts with the environment --
// the env should update all other levels of state internally over
// This is the main API for how the model interacts with the environment.
// The env should update all other levels of state internally over
// repeated calls to the Step method.
// If there are no further inputs available, it returns false (most envs
// typically only return true and just continue running as long as needed).
//
// The Env thus always reflects the *current* state of things, and this
// call increments that current state, such that subsequent calls to
// State(), Counter() etc will return this current state.
// State() will return this current state.
//
// This implies that the state just after Init and prior to first Step
// call should be an *initialized* state that then allows the first Step
// call to establish the proper *first* state. Typically this means that
// one or more counters will be set to -1 during Init and then get incremented
// to 0 on the first Step call.
Step() bool

// Counter(scale TimeScales) returns current counter state for given time scale,
// the immediate previous counter state, and whether that time scale changed
// during the last Step() function call (this may be true even if cur == prv, if
// the Max = 1). Use the Ctr struct for each counter, which manages all of this.
// See external Counter* methods for Python-safe single-return-value versions.
Counter(scale TimeScales) (cur, prv int, changed bool)

// State returns the given element's worth of tensor data from the environment
// based on the current state of the env, as a function of having called Step().
// If no output is available on that element, then nil is returned.
// The returned tensor must be treated as read-only as it likely points to original
// source data -- please make a copy before modifying (e.g., Clone() methdod)
// source data -- please make a copy before modifying (e.g., Clone() methdod).
State(element string) tensor.Tensor

// Action sends tensor data about e.g., responses from model back to act
// on the environment and influence its subsequent evolution.
// The nature and timing of this input is paradigm dependent.
Action(element string, input tensor.Tensor)
}

// EnvDesc is an interface that defines methods that describe an Env.
// These are optional for basic Env, but in cases where an Env
// should be fully self-describing, these methods can be implemented.
type EnvDesc interface {
// Counters returns []TimeScales list of counters supported by this env.
// These should be consistent within a paradigm and most models
// will just expect particular sets of counters, but this can be
// useful for sanity checking that a suitable env has been selected.
// See SchemaFromScales function that takes this list of time
// scales and returns an table.Schema for Table columns to record
// these counters in a log.
Counters() []TimeScales

// States returns a list of Elements of tensor outputs that this env
// generates, specifying the unique Name and Shape of the data.
// This information can be derived directly from an table.Schema
// and used for configuring model input / output pathways to fit
// with those provided by the environment. Depending on the
// env paradigm, all elements may not be always available at every
// point in time e.g., an env might alternate between Action and Reward
// elements. This may return nil if Env has not been properly
// configured.
States() Elements

// Actions returns a list of elements of tensor inputs that this env
// accepts, specifying the unique Name and Shape of the data.
// Specific paradigms of envs can establish the timing and function
// of these inputs, and how they then affect subsequent outputs
// e.g., if the model is required to make a particular choice
// response and then it can receive a reward or not contingent
// on that choice.
Actions() Elements
}

// CounterCur returns current counter state for given time scale
// this Counter for Python because it cannot process multiple return values
func CounterCur(en Env, scale TimeScales) int {
cur, _, _ := en.Counter(scale)
return cur
}

// CounterPrv returns previous counter state for given time scale
// this Counter for Python because it cannot process multiple return values
func CounterPrv(en Env, scale TimeScales) int {
_, prv, _ := en.Counter(scale)
return prv
}

// CounterChg returns whether counter changed during last Step()
// this Counter for Python because it cannot process multiple return values
func CounterChg(en Env, scale TimeScales) bool {
_, _, chg := en.Counter(scale)
return chg
}
4 changes: 2 additions & 2 deletions env/envs.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ func (es *Envs) Init() {
}
}

// Add adds Env(s), using its Name as the key
// Add adds Env(s), using its Label as the key
func (es *Envs) Add(evs ...Env) {
es.Init()
for _, ev := range evs {
(*es)[ev.Name()] = ev
(*es)[ev.Label()] = ev
}
}

Expand Down
Loading

0 comments on commit 2e85a3e

Please sign in to comment.