Skip to content

Commit

Permalink
refactor(x/participationrewards/keeper): combine GetProtocolData+Unma…
Browse files Browse the repository at this point in the history
…rshalProtocolData (#1681)

This change combines GetProtocolData and types.UnmarshalProtocolData
into a generic function that unifies the functionality and the pattern.

Fixes #1631
  • Loading branch information
odeke-em authored Aug 5, 2024
1 parent de7a37c commit 7419f3a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 120 deletions.
84 changes: 14 additions & 70 deletions x/participationrewards/keeper/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,12 @@ func OsmosisPoolUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, quer
}

poolID := sdk.BigEndianToUint64(query.Request[1:])
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, fmt.Sprintf("%d", poolID))
if !ok {
return fmt.Errorf("unable to find protocol data for osmosispools/%d", poolID)
}
ipool, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, data.Data)
key := fmt.Sprintf("%d", poolID)
data, pool, err := GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, k, key, types.ProtocolDataTypeOsmosisPool)
if err != nil {
return err
}
pool, ok := ipool.(*types.OsmosisPoolProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for osmosispools/%d", poolID)
}

pool.PoolData, err = json.Marshal(pd)
if err != nil {
return err
Expand Down Expand Up @@ -185,18 +179,11 @@ func OsmosisClPoolUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, qu
return err
}

data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisCLPool, fmt.Sprintf("%d", poolID))
if !ok {
return fmt.Errorf("unable to find protocol data for osmosisclpools/%d", poolID)
}
ipool, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisCLPool, data.Data)
data, pool, err := GetAndUnmarshalProtocolData[*types.OsmosisClPoolProtocolData](ctx, k, fmt.Sprintf("%d", poolID), types.ProtocolDataTypeOsmosisCLPool)
if err != nil {
return err
}
pool, ok := ipool.(*types.OsmosisClPoolProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for osmosisclpools/%d", poolID)
}

pool.PoolData, err = json.Marshal(pd)
if err != nil {
return err
Expand All @@ -222,18 +209,11 @@ func UmeeReservesUpdateCallback(ctx sdk.Context, k *Keeper, response []byte, que
}

denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixReserveAmount)
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeReserves, denom)
if !ok {
return fmt.Errorf("unable to find protocol data for umeereserves/%s", denom)
}
ireserves, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeReserves, data.Data)
data, reserves, err := GetAndUnmarshalProtocolData[*types.UmeeReservesProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeReserves)
if err != nil {
return err
}
reserves, ok := ireserves.(*types.UmeeReservesProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for umeereserves/%s", denom)
}

reserves.Data, err = json.Marshal(reserveAmount)
if err != nil {
return err
Expand All @@ -259,18 +239,11 @@ func UmeeTotalBorrowsUpdateCallback(ctx sdk.Context, k *Keeper, response []byte,
}

denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixAdjustedTotalBorrow)
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeTotalBorrows, denom)
if !ok {
return fmt.Errorf("unable to find protocol data for umee-types total borrows/%s", denom)
}
iborrows, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeTotalBorrows, data.Data)
data, borrows, err := GetAndUnmarshalProtocolData[*types.UmeeTotalBorrowsProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeTotalBorrows)
if err != nil {
return err
}
borrows, ok := iborrows.(*types.UmeeTotalBorrowsProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for umee-types total borrows/%s", denom)
}

borrows.Data, err = json.Marshal(totalBorrows)
if err != nil {
return err
Expand All @@ -296,18 +269,11 @@ func UmeeInterestScalarUpdateCallback(ctx sdk.Context, k *Keeper, response []byt
}

denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixInterestScalar)
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeInterestScalar, denom)
if !ok {
return fmt.Errorf("unable to find protocol data for interestscalar/%s", denom)
}
iinterest, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeInterestScalar, data.Data)
data, interest, err := GetAndUnmarshalProtocolData[*types.UmeeInterestScalarProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeInterestScalar)
if err != nil {
return err
}
interest, ok := iinterest.(*types.UmeeInterestScalarProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for interestscalar/%s", denom)
}

interest.Data, err = json.Marshal(interestScalar)
if err != nil {
return err
Expand All @@ -333,18 +299,10 @@ func UmeeUTokenSupplyUpdateCallback(ctx sdk.Context, k *Keeper, response []byte,
}

denom := umeetypes.DenomFromKey(query.Request, umeetypes.KeyPrefixUtokenSupply)
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeUTokenSupply, denom)
if !ok {
return fmt.Errorf("unable to find protocol data for umee-types utoken supply/%s", denom)
}
isupply, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeUTokenSupply, data.Data)
data, supply, err := GetAndUnmarshalProtocolData[*types.UmeeUTokenSupplyProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeUTokenSupply)
if err != nil {
return err
}
supply, ok := isupply.(*types.UmeeUTokenSupplyProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for umee-types utoken supply/%s", denom)
}
supply.Data, err = json.Marshal(supplyAmount)
if err != nil {
return err
Expand Down Expand Up @@ -377,18 +335,10 @@ func UmeeLeverageModuleBalanceUpdateCallback(ctx sdk.Context, k *Keeper, respons
}
balanceAmount := balanceCoin.Amount

data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeUmeeLeverageModuleBalance, denom)
if !ok {
return fmt.Errorf("unable to find protocol data for umee-types leverage module/%s", denom)
}
ibalance, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeLeverageModuleBalance, data.Data)
data, balance, err := GetAndUnmarshalProtocolData[*types.UmeeLeverageModuleBalanceProtocolData](ctx, k, denom, types.ProtocolDataTypeUmeeLeverageModuleBalance)
if err != nil {
return err
}
balance, ok := ibalance.(*types.UmeeLeverageModuleBalanceProtocolData)
if !ok {
return fmt.Errorf("unable to unmarshal protocol data for umee-types leverage module/%s", denom)
}
balance.Data, err = json.Marshal(balanceAmount)
if err != nil {
return err
Expand All @@ -405,14 +355,8 @@ func UmeeLeverageModuleBalanceUpdateCallback(ctx sdk.Context, k *Keeper, respons

// SetEpochBlockCallback records the block height of the registered zone at the epoch boundary.
func SetEpochBlockCallback(ctx sdk.Context, k *Keeper, args []byte, query icqtypes.Query) error {
data, ok := k.GetProtocolData(ctx, types.ProtocolDataTypeConnection, query.ChainId)
if !ok {
return fmt.Errorf("unable to find protocol data for connection/%s", query.ChainId)
}
k.Logger(ctx).Debug("epoch callback called")
iConnectionData, err := types.UnmarshalProtocolData(types.ProtocolDataTypeConnection, data.Data)
connectionData, _ := iConnectionData.(*types.ConnectionProtocolData)

data, connectionData, err := GetAndUnmarshalProtocolData[*types.ConnectionProtocolData](ctx, k, query.ChainId, types.ProtocolDataTypeConnection)
if err != nil {
return err
}
Expand Down
52 changes: 8 additions & 44 deletions x/participationrewards/keeper/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,9 @@ func (suite *KeeperTestSuite) TestOsmosisPoolUpdateCallback() {

suite.NoError(err)

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, "944")
suite.True(found)

data, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, pd.GetData())
_, pooldata, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, prk, "944", types.ProtocolDataTypeOsmosisPool)
suite.NoError(err)

pooldata, ok := data.(*types.OsmosisPoolProtocolData)
suite.True(ok)

pool, err := pooldata.GetPool()
suite.NoError(err)

Expand Down Expand Up @@ -137,15 +131,9 @@ func (suite *KeeperTestSuite) TestOsmosisClPoolUpdateCallback() {

suite.NoError(err)

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisCLPool, "1089")
suite.True(found)

data, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisCLPool, pd.GetData())
_, pooldata, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisClPoolProtocolData](ctx, prk, "1089", types.ProtocolDataTypeOsmosisCLPool)
suite.NoError(err)

pooldata, ok := data.(*types.OsmosisClPoolProtocolData)
suite.True(ok)

pool, err := pooldata.GetPool()
suite.NoError(err)

Expand Down Expand Up @@ -198,12 +186,8 @@ func (suite *KeeperTestSuite) executeOsmosisPoolUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisPool, "1")
suite.True(found)

ioppd, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisPool, pd.Data)
_, oppd, err := keeper.GetAndUnmarshalProtocolData[*types.OsmosisPoolProtocolData](ctx, prk, "1", types.ProtocolDataTypeOsmosisPool)
suite.NoError(err)
oppd := ioppd.(*types.OsmosisPoolProtocolData)
suite.Equal(want, oppd)
}

Expand Down Expand Up @@ -321,12 +305,8 @@ func (suite *KeeperTestSuite) executeUmeeReservesUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeReserves, umeeBaseDenom)
suite.True(found)

value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeReserves, pd.Data)
_, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeReservesProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeReserves)
suite.NoError(err)
result := value.(*types.UmeeReservesProtocolData)
suite.Equal(want, result)
}

Expand Down Expand Up @@ -365,12 +345,8 @@ func (suite *KeeperTestSuite) executeUmeeLeverageModuleBalanceUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeLeverageModuleBalance, umeeBaseDenom)
suite.True(found)

value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeLeverageModuleBalance, pd.Data)
_, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeLeverageModuleBalanceProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeLeverageModuleBalance)
suite.NoError(err)
result := value.(*types.UmeeLeverageModuleBalanceProtocolData)
suite.Equal(want, result)
}

Expand Down Expand Up @@ -407,12 +383,8 @@ func (suite *KeeperTestSuite) executeUmeeUTokenSupplyUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeUTokenSupply, leveragetypes.UTokenPrefix+umeeBaseDenom)
suite.True(found)

value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeUTokenSupply, pd.Data)
_, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeUTokenSupplyProtocolData](ctx, prk, leveragetypes.UTokenPrefix+umeeBaseDenom, types.ProtocolDataTypeUmeeUTokenSupply)
suite.NoError(err)
result := value.(*types.UmeeUTokenSupplyProtocolData)
suite.Equal(want, result)
}

Expand Down Expand Up @@ -449,12 +421,8 @@ func (suite *KeeperTestSuite) executeUmeeTotalBorrowsUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeTotalBorrows, umeeBaseDenom)
suite.True(found)

value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeTotalBorrows, pd.Data)
_, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeTotalBorrowsProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeTotalBorrows)
suite.NoError(err)
result := value.(*types.UmeeTotalBorrowsProtocolData)
suite.Equal(want, result)
}

Expand Down Expand Up @@ -491,11 +459,7 @@ func (suite *KeeperTestSuite) executeUmeeInterestScalarUpdateCallback() {
},
}

pd, found := prk.GetProtocolData(ctx, types.ProtocolDataTypeUmeeInterestScalar, umeeBaseDenom)
suite.True(found)

value, err := types.UnmarshalProtocolData(types.ProtocolDataTypeUmeeInterestScalar, pd.Data)
_, result, err := keeper.GetAndUnmarshalProtocolData[*types.UmeeInterestScalarProtocolData](ctx, prk, umeeBaseDenom, types.ProtocolDataTypeUmeeInterestScalar)
suite.NoError(err)
result := value.(*types.UmeeInterestScalarProtocolData)
suite.Equal(want, result)
}
8 changes: 2 additions & 6 deletions x/participationrewards/keeper/distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,12 @@ func DepthFirstSearch(graph AssetGraph, visited map[string]struct{}, asset strin
func (k *Keeper) CalcTokenValues(ctx sdk.Context) (TokenValues, error) {
k.Logger(ctx).Info("calcTokenValues")

data, found := k.GetProtocolData(ctx, types.ProtocolDataTypeOsmosisParams, "osmosisparams")
if !found {
return TokenValues{}, errors.New("could not find osmosisparams protocol data")
}
osmoParams, err := types.UnmarshalProtocolData(types.ProtocolDataTypeOsmosisParams, data.Data)
_, osmoParams, err := GetAndUnmarshalProtocolData[*types.OsmosisParamsProtocolData](ctx, k, "osmosisparams", types.ProtocolDataTypeOsmosisParams)
if err != nil {
return TokenValues{}, err
}

baseDenom := osmoParams.(*types.OsmosisParamsProtocolData).BaseDenom
baseDenom := osmoParams.BaseDenom

tvs := make(TokenValues)
graph := make(AssetGraphSlice)
Expand Down
18 changes: 18 additions & 0 deletions x/participationrewards/keeper/protocol_data.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package keeper

import (
"fmt"

"github.com/cosmos/cosmos-sdk/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"

Expand Down Expand Up @@ -38,6 +40,22 @@ func (k Keeper) SetProtocolData(ctx sdk.Context, key []byte, data *types.Protoco
store.Set(types.GetProtocolDataKey(types.ProtocolDataType(pdType), key), bz)
}

func GetAndUnmarshalProtocolData[T any](ctx sdk.Context, k *Keeper, key string, pdType types.ProtocolDataType) (dt types.ProtocolData, tt T, err error) {
data, ok := k.GetProtocolData(ctx, pdType, key)
if !ok {
return dt, tt, fmt.Errorf("unable to find protocol data for %q", key)
}
pd, err := types.UnmarshalProtocolData(pdType, data.Data)
if err != nil {
return dt, tt, err
}
asType, ok := pd.(T)
if !ok {
return dt, tt, fmt.Errorf("could not retrieve type of %T, actual type: %T", (*T)(nil), pd)
}
return data, asType, nil
}

// DeleteProtocolData deletes protocol data info.
func (k *Keeper) DeleteProtocolData(ctx sdk.Context, key []byte) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefixProtocolData)
Expand Down

0 comments on commit 7419f3a

Please sign in to comment.