Skip to content

Commit

Permalink
Revert "🐛 Load provider flags from environment variables (#4847)" (#4857
Browse files Browse the repository at this point in the history
)

This reverts commit 0b4c641.
  • Loading branch information
chris-rock authored Nov 14, 2024
1 parent 405a999 commit e2cc11e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 53 deletions.
31 changes: 22 additions & 9 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package providers

import (
"encoding/json"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
"os"
"strings"

Expand All @@ -19,8 +21,6 @@ import (
"go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/recording"
"go.mondoo.com/cnquery/v11/types"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
)

type Command struct {
Expand Down Expand Up @@ -318,22 +318,35 @@ func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
}
}

func getFlagValue(flag plugin.Flag) *llx.Primitive {
func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
switch flag.Type {
case plugin.FlagType_Bool:
return llx.BoolPrimitive(viper.GetBool(flag.Long))
v, err := cmd.Flags().GetBool(flag.Long)
if err == nil {
return llx.BoolPrimitive(v)
}
log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
case plugin.FlagType_Int:
return llx.IntPrimitive(viper.GetInt64(flag.Long))
if v, err := cmd.Flags().GetInt(flag.Long); err == nil {
return llx.IntPrimitive(int64(v))
}
case plugin.FlagType_String:
return llx.StringPrimitive(viper.GetString(flag.Long))
if v, err := cmd.Flags().GetString(flag.Long); err == nil {
return llx.StringPrimitive(v)
}
case plugin.FlagType_List:
return llx.ArrayPrimitiveT(viper.GetStringSlice(flag.Long), llx.StringPrimitive, types.String)
if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil {
return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String)
}
case plugin.FlagType_KeyValue:
return llx.MapPrimitiveT(viper.GetStringMapString(flag.Long), llx.StringPrimitive, types.String)
if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil {
return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String)
}
default:
log.Warn().Msg("unknown flag type for " + flag.Long)
return nil
}
return nil
}

func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) {
Expand Down Expand Up @@ -408,7 +421,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
continue
}

if v := getFlagValue(flag); v != nil {
if v := getFlagValue(flag, cmd); v != nil {
flagVals[flag.Long] = v
}
}
Expand Down
47 changes: 3 additions & 44 deletions test/providers/os_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
package providers

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
"log"
"os"
"os/exec"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
)

var once sync.Once
Expand Down Expand Up @@ -186,43 +185,3 @@ func TestOsProviderSharedTests(t *testing.T) {
}
}
}

func TestProvidersEnvVarsLoading(t *testing.T) {
t.Run("command WITHOUT path should not find any package", func(t *testing.T) {
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// No packages
assert.Empty(t, c)
})
t.Run("command WITH path should find packages", func(t *testing.T) {
os.Setenv("MONDOO_PATH", "./testdata/fs")
defer os.Unsetenv("MONDOO_PATH")
// Note we are not passing the flag "--path ./testdata/fs"
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// Should have packages
if assert.NotEmpty(t, c) {
x := c[0]
assert.NotNil(t, x.Packages)
assert.True(t, len(x.Packages) > 0)
}
})
}

0 comments on commit e2cc11e

Please sign in to comment.