diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 09de98cf79..e7c38a3d46 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -16,6 +16,9 @@ jobs:
name: Checks
runs-on: ubuntu-22.04
+ env:
+ PMM_ENCRYPTION_KEY_PATH: pmm-encryption.key
+
steps:
- name: Check out code
uses: actions/checkout@v4
diff --git a/go.mod b/go.mod
index f2e9f2d32a..f25d8d8e8f 100644
--- a/go.mod
+++ b/go.mod
@@ -38,6 +38,7 @@ require (
github.com/go-sql-driver/mysql v1.7.1
github.com/gogo/status v1.1.1
github.com/golang-migrate/migrate/v4 v4.17.0
+ github.com/google/tink/go v1.7.0
github.com/google/uuid v1.6.0
github.com/grafana/grafana-api-golang-client v0.27.0
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
@@ -101,7 +102,6 @@ require (
github.com/google/btree v1.0.0 // indirect
github.com/hashicorp/go-hclog v1.6.2 // indirect
github.com/hashicorp/go-msgpack/v2 v2.1.1 // indirect
- github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/miekg/dns v1.1.26 // indirect
diff --git a/go.sum b/go.sum
index eec4683a48..c236c1d13f 100644
--- a/go.sum
+++ b/go.sum
@@ -239,6 +239,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/tink/go v1.7.0 h1:6Eox8zONGebBFcCBqkVmt60LaWZa6xg1cl/DwAh/J1w=
+github.com/google/tink/go v1.7.0/go.mod h1:GAUOd+QE3pgj9q8VKIGTCP33c/B7eb4NhxLcgTJZStM=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
diff --git a/managed/Makefile b/managed/Makefile
index fa0ab06fa8..dcf35ec550 100644
--- a/managed/Makefile
+++ b/managed/Makefile
@@ -37,7 +37,6 @@ clean: ## Remove generated files
release: ## Build pmm-managed release binaries
env CGO_ENABLED=0 go build -v $(PMM_LD_FLAGS) -o $(PMM_RELEASE_PATH)/ ./cmd/...
- $(PMM_RELEASE_PATH)/pmm-managed --version
release-starlark:
env CGO_ENABLED=0 go build -v $(PMM_LD_FLAGS) -o $(PMM_RELEASE_PATH)/ ./cmd/pmm-managed-starlark/...
diff --git a/managed/models/agent_helpers.go b/managed/models/agent_helpers.go
index 041fbe4499..eb677fc791 100644
--- a/managed/models/agent_helpers.go
+++ b/managed/models/agent_helpers.go
@@ -229,7 +229,8 @@ func FindAgents(q *reform.Querier, filters AgentFilters) ([]*Agent, error) {
agents := make([]*Agent, len(structs))
for i, s := range structs {
- agents[i] = s.(*Agent) //nolint:forcetypeassert
+ decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert
+ agents[i] = &decryptedAgent
}
return agents, nil
@@ -249,8 +250,9 @@ func FindAgentByID(q *reform.Querier, id string) (*Agent, error) {
}
return nil, errors.WithStack(err)
}
+ decryptedAgent := DecryptAgent(*agent)
- return agent, nil
+ return &decryptedAgent, nil
}
// FindAgentsByIDs finds Agents by IDs.
@@ -272,7 +274,8 @@ func FindAgentsByIDs(q *reform.Querier, ids []string) ([]*Agent, error) {
res := make([]*Agent, len(structs))
for i, s := range structs {
- res[i] = s.(*Agent) //nolint:forcetypeassert
+ decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert
+ res[i] = &decryptedAgent
}
return res, nil
}
@@ -323,7 +326,8 @@ func FindDBConfigForService(q *reform.Querier, serviceID string) (*DBConfig, err
res := make([]*Agent, len(structs))
for i, s := range structs {
- res[i] = s.(*Agent) //nolint:forcetypeassert
+ decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert
+ res[i] = &decryptedAgent
}
if len(res) == 0 {
@@ -350,8 +354,8 @@ func FindPMMAgentsRunningOnNode(q *reform.Querier, nodeID string) ([]*Agent, err
res := make([]*Agent, 0, len(structs))
for _, str := range structs {
- row := str.(*Agent) //nolint:forcetypeassert
- res = append(res, row)
+ decryptedAgent := DecryptAgent(*str.(*Agent)) //nolint:forcetypeassert
+ res = append(res, &decryptedAgent)
}
return res, nil
@@ -395,8 +399,8 @@ func FindPMMAgentsForService(q *reform.Querier, serviceID string) ([]*Agent, err
}
res := make([]*Agent, 0, len(pmmAgentRecords))
for _, str := range pmmAgentRecords {
- row := str.(*Agent) //nolint:forcetypeassert
- res = append(res, row)
+ decryptedAgent := DecryptAgent(*str.(*Agent)) //nolint:forcetypeassert
+ res = append(res, &decryptedAgent)
}
return res, nil
@@ -477,7 +481,8 @@ func FindAgentsForScrapeConfig(q *reform.Querier, pmmAgentID *string, pushMetric
res := make([]*Agent, len(allAgents))
for i, s := range allAgents {
- res[i] = s.(*Agent) //nolint:forcetypeassert
+ decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert
+ res[i] = &decryptedAgent
}
return res, nil
}
@@ -641,11 +646,14 @@ func CreateNodeExporter(q *reform.Querier,
if err := row.SetCustomLabels(customLabels); err != nil {
return nil, err
}
- if err := q.Insert(row); err != nil {
+
+ encryptedAgent := EncryptAgent(*row)
+ if err := q.Insert(&encryptedAgent); err != nil {
return nil, errors.WithStack(err)
}
+ agent := DecryptAgent(encryptedAgent)
- return row, nil
+ return &agent, nil
}
// CreateExternalExporterParams params for add external exporter.
@@ -725,11 +733,14 @@ func CreateExternalExporter(q *reform.Querier, params *CreateExternalExporterPar
if err := row.SetCustomLabels(params.CustomLabels); err != nil {
return nil, err
}
- if err := q.Insert(row); err != nil {
+
+ encryptedAgent := EncryptAgent(*row)
+ if err := q.Insert(&encryptedAgent); err != nil {
return nil, errors.WithStack(err)
}
+ agent := DecryptAgent(encryptedAgent)
- return row, nil
+ return &agent, nil
}
// CreateAgentParams params for add common exporter.
@@ -912,15 +923,17 @@ func CreateAgent(q *reform.Querier, agentType AgentType, params *CreateAgentPara
DisabledCollectors: params.DisableCollectors,
LogLevel: pointer.ToStringOrNil(params.LogLevel),
}
-
if err := row.SetCustomLabels(params.CustomLabels); err != nil {
return nil, err
}
- if err := q.Insert(row); err != nil {
+
+ encryptedAgent := EncryptAgent(*row)
+ if err := q.Insert(&encryptedAgent); err != nil {
return nil, errors.WithStack(err)
}
+ agent := DecryptAgent(encryptedAgent)
- return row, nil
+ return &agent, nil
}
// ChangeCommonAgentParams contains parameters that can be changed for all Agents.
diff --git a/managed/models/database.go b/managed/models/database.go
index 8f45a0364c..d2c2066c7f 100644
--- a/managed/models/database.go
+++ b/managed/models/database.go
@@ -27,6 +27,7 @@ import (
"net"
"net/url"
"os"
+ "slices"
"strconv"
"strings"
@@ -36,6 +37,8 @@ import (
"google.golang.org/grpc/status"
"gopkg.in/reform.v1"
"gopkg.in/reform.v1/dialects/postgresql"
+
+ "github.com/percona/pmm/managed/utils/encryption"
)
const (
@@ -1146,12 +1149,87 @@ func SetupDB(ctx context.Context, sqlDB *sql.DB, params SetupDBParams) (*reform.
return nil, errCV
}
- if err := migrateDB(db, params); err != nil {
+ agentColumnsToEncrypt := []encryption.Column{
+ {Name: "username"},
+ {Name: "password"},
+ {Name: "aws_access_key"},
+ {Name: "aws_secret_key"},
+ {Name: "mongo_db_tls_options", CustomHandler: EncryptMongoDBOptionsHandler},
+ {Name: "azure_options", CustomHandler: EncryptAzureOptionsHandler},
+ {Name: "mysql_options", CustomHandler: EncryptMySQLOptionsHandler},
+ {Name: "postgresql_options", CustomHandler: EncryptPostgreSQLOptionsHandler},
+ {Name: "agent_password"},
+ }
+
+ itemsToEncrypt := []encryption.Table{
+ {
+ Name: "agents",
+ Identifiers: []string{"agent_id"},
+ Columns: agentColumnsToEncrypt,
+ },
+ }
+
+ if err := migrateDB(db, params, itemsToEncrypt); err != nil {
return nil, err
}
+
return db, nil
}
+// EncryptDB encrypts a set of columns in a specific database and table.
+func EncryptDB(tx *reform.TX, params SetupDBParams, itemsToEncrypt []encryption.Table) error {
+ if len(itemsToEncrypt) == 0 {
+ return nil
+ }
+
+ settings, err := GetSettings(tx)
+ if err != nil {
+ return err
+ }
+ alreadyEncrypted := make(map[string]bool)
+ for _, v := range settings.EncryptedItems {
+ alreadyEncrypted[v] = true
+ }
+
+ notEncrypted := []encryption.Table{}
+ newlyEncrypted := []string{}
+ for _, table := range itemsToEncrypt {
+ columns := []encryption.Column{}
+ for _, column := range table.Columns {
+ dbTableColumn := fmt.Sprintf("%s.%s.%s", params.Name, table.Name, column.Name)
+ if alreadyEncrypted[dbTableColumn] {
+ continue
+ }
+
+ columns = append(columns, column)
+ newlyEncrypted = append(newlyEncrypted, dbTableColumn)
+ }
+ if len(columns) == 0 {
+ continue
+ }
+
+ table.Columns = columns
+ notEncrypted = append(notEncrypted, table)
+ }
+
+ if len(notEncrypted) == 0 {
+ return nil
+ }
+
+ err = encryption.EncryptItems(tx, notEncrypted)
+ if err != nil {
+ return err
+ }
+ _, err = UpdateSettings(tx, &ChangeSettingsParams{
+ EncryptedItems: slices.Concat(settings.EncryptedItems, newlyEncrypted),
+ })
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
// checkVersion checks minimal required PostgreSQL server version.
func checkVersion(ctx context.Context, db reform.DBTXContext) error {
PGVersion, err := GetPostgreSQLVersion(ctx, db)
@@ -1211,7 +1289,7 @@ func initWithRoot(params SetupDBParams) error {
}
// migrateDB runs PostgreSQL database migrations.
-func migrateDB(db *reform.DB, params SetupDBParams) error {
+func migrateDB(db *reform.DB, params SetupDBParams, itemsToEncrypt []encryption.Table) error {
var currentVersion int
errDB := db.QueryRow("SELECT id FROM schema_migrations ORDER BY id DESC LIMIT 1").Scan(¤tVersion)
// undefined_table (see https://www.postgresql.org/docs/current/errcodes-appendix.html)
@@ -1247,6 +1325,11 @@ func migrateDB(db *reform.DB, params SetupDBParams) error {
}
}
+ err := EncryptDB(tx, params, itemsToEncrypt)
+ if err != nil {
+ return err
+ }
+
if params.SetupFixtures == SkipFixtures {
return nil
}
@@ -1260,14 +1343,16 @@ func migrateDB(db *reform.DB, params SetupDBParams) error {
return err
}
- if err = setupFixture1(tx.Querier, params); err != nil {
+ err = setupPMMServerAgents(tx.Querier, params)
+ if err != nil {
return err
}
+
return nil
})
}
-func setupFixture1(q *reform.Querier, params SetupDBParams) error {
+func setupPMMServerAgents(q *reform.Querier, params SetupDBParams) error {
// create PMM Server Node and associated Agents
node, err := createNodeWithID(q, PMMServerNodeID, GenericNodeType, &CreateNodeParams{
NodeName: "pmm-server",
diff --git a/managed/models/database_test.go b/managed/models/database_test.go
index 2b487629d5..efafd5e2f7 100644
--- a/managed/models/database_test.go
+++ b/managed/models/database_test.go
@@ -21,7 +21,6 @@ import (
"database/sql"
"fmt"
"testing"
- "time"
"github.com/AlekSi/pointer"
"github.com/lib/pq"
@@ -327,60 +326,6 @@ func TestDatabaseChecks(t *testing.T) {
}
func TestDatabaseMigrations(t *testing.T) {
- t.Run("Update metrics resolutions", func(t *testing.T) {
- sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(9))
- defer sqlDB.Close() //nolint:errcheck
- settings, err := models.GetSettings(sqlDB)
- require.NoError(t, err)
- metricsResolutions := models.MetricsResolutions{
- HR: 5 * time.Second,
- MR: 5 * time.Second,
- LR: 60 * time.Second,
- }
- settings.MetricsResolutions = metricsResolutions
- err = models.SaveSettings(sqlDB, settings)
- require.NoError(t, err)
-
- settings, err = models.GetSettings(sqlDB)
- require.NoError(t, err)
- require.Equal(t, metricsResolutions, settings.MetricsResolutions)
-
- testdb.SetupDB(t, sqlDB, models.SkipFixtures, pointer.ToInt(10))
- settings, err = models.GetSettings(sqlDB)
- require.NoError(t, err)
- require.Equal(t, models.MetricsResolutions{
- HR: 5 * time.Second,
- MR: 10 * time.Second,
- LR: 60 * time.Second,
- }, settings.MetricsResolutions)
- })
- t.Run("Shouldn' update metrics resolutions if it's already changed", func(t *testing.T) {
- sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(9))
- defer sqlDB.Close() //nolint:errcheck
- settings, err := models.GetSettings(sqlDB)
- require.NoError(t, err)
- metricsResolutions := models.MetricsResolutions{
- HR: 1 * time.Second,
- MR: 5 * time.Second,
- LR: 60 * time.Second,
- }
- settings.MetricsResolutions = metricsResolutions
- err = models.SaveSettings(sqlDB, settings)
- require.NoError(t, err)
-
- settings, err = models.GetSettings(sqlDB)
- require.NoError(t, err)
- require.Equal(t, metricsResolutions, settings.MetricsResolutions)
-
- testdb.SetupDB(t, sqlDB, models.SkipFixtures, pointer.ToInt(10))
- settings, err = models.GetSettings(sqlDB)
- require.NoError(t, err)
- require.Equal(t, models.MetricsResolutions{
- HR: 1 * time.Second,
- MR: 5 * time.Second,
- LR: 60 * time.Second,
- }, settings.MetricsResolutions)
- })
t.Run("stats_collections field migration: string to string array", func(t *testing.T) {
sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(57))
defer sqlDB.Close() //nolint:errcheck
diff --git a/managed/models/encryption_helpers.go b/managed/models/encryption_helpers.go
new file mode 100644
index 0000000000..f3ae2e01a8
--- /dev/null
+++ b/managed/models/encryption_helpers.go
@@ -0,0 +1,296 @@
+// Copyright (C) 2023 Percona LLC
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package models
+
+import (
+ "database/sql"
+ "encoding/json"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/percona/pmm/managed/utils/encryption"
+)
+
+// EncryptAgent encrypt agent.
+func EncryptAgent(agent Agent) Agent {
+ return agentEncryption(agent, encryption.Encrypt)
+}
+
+// DecryptAgent decrypt agent.
+func DecryptAgent(agent Agent) Agent {
+ return agentEncryption(agent, encryption.Decrypt)
+}
+
+func agentEncryption(agent Agent, handler func(string) (string, error)) Agent {
+ if agent.Username != nil {
+ username, err := handler(*agent.Username)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.Username = &username
+ }
+
+ if agent.Password != nil {
+ password, err := handler(*agent.Password)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.Password = &password
+ }
+
+ if agent.AgentPassword != nil {
+ agentPassword, err := handler(*agent.AgentPassword)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AgentPassword = &agentPassword
+ }
+
+ if agent.AWSAccessKey != nil {
+ awsAccessKey, err := handler(*agent.AWSAccessKey)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AWSAccessKey = &awsAccessKey
+ }
+
+ if agent.AWSSecretKey != nil {
+ awsSecretKey, err := handler(*agent.AWSSecretKey)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AWSSecretKey = &awsSecretKey
+ }
+
+ var err error
+ if agent.MySQLOptions != nil {
+ agent.MySQLOptions.TLSCert, err = handler(agent.MySQLOptions.TLSCert)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.MySQLOptions.TLSKey, err = handler(agent.MySQLOptions.TLSKey)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ }
+
+ if agent.PostgreSQLOptions != nil {
+ agent.PostgreSQLOptions.SSLCert, err = handler(agent.PostgreSQLOptions.SSLCert)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.PostgreSQLOptions.SSLKey, err = handler(agent.PostgreSQLOptions.SSLKey)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ }
+
+ if agent.MongoDBOptions != nil {
+ agent.MongoDBOptions.TLSCertificateKey, err = handler(agent.MongoDBOptions.TLSCertificateKey)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.MongoDBOptions.TLSCertificateKeyFilePassword, err = handler(agent.MongoDBOptions.TLSCertificateKeyFilePassword)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ }
+
+ if agent.AzureOptions != nil {
+ agent.AzureOptions.ClientID, err = handler(agent.AzureOptions.ClientID)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AzureOptions.ClientSecret, err = handler(agent.AzureOptions.ClientSecret)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AzureOptions.SubscriptionID, err = handler(agent.AzureOptions.SubscriptionID)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ agent.AzureOptions.TenantID, err = handler(agent.AzureOptions.TenantID)
+ if err != nil {
+ logrus.Warning(err)
+ }
+ }
+
+ return agent
+}
+
+// EncryptMySQLOptionsHandler returns encrypted MySQL Options.
+func EncryptMySQLOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return mySQLOptionsHandler(val, e.Encrypt)
+}
+
+// DecryptMySQLOptionsHandler returns decrypted MySQL Options.
+func DecryptMySQLOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return mySQLOptionsHandler(val, e.Decrypt)
+}
+
+func mySQLOptionsHandler(val any, handler func(string) (string, error)) (any, error) {
+ o := MySQLOptions{}
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return sql.NullString{}, nil
+ }
+
+ err := json.Unmarshal([]byte(value.String), &o)
+ if err != nil {
+ return nil, err
+ }
+
+ o.TLSCert, err = handler(o.TLSCert)
+ if err != nil {
+ return nil, err
+ }
+ o.TLSKey, err = handler(o.TLSKey)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := json.Marshal(o)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+// EncryptPostgreSQLOptionsHandler returns encrypted PostgreSQL Options.
+func EncryptPostgreSQLOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return postgreSQLOptionsHandler(val, e.Encrypt)
+}
+
+// DecryptPostgreSQLOptionsHandler returns decrypted PostgreSQL Options.
+func DecryptPostgreSQLOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return postgreSQLOptionsHandler(val, e.Decrypt)
+}
+
+func postgreSQLOptionsHandler(val any, handler func(string) (string, error)) (any, error) {
+ o := PostgreSQLOptions{}
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return sql.NullString{}, nil
+ }
+
+ err := json.Unmarshal([]byte(value.String), &o)
+ if err != nil {
+ return nil, err
+ }
+
+ o.SSLCert, err = handler(o.SSLCert)
+ if err != nil {
+ return nil, err
+ }
+ o.SSLKey, err = handler(o.SSLKey)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := json.Marshal(o)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+// EncryptMongoDBOptionsHandler returns encrypted MongoDB Options.
+func EncryptMongoDBOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return mongoDBOptionsHandler(val, e.Encrypt)
+}
+
+// DecryptMongoDBOptionsHandler returns decrypted MongoDB Options.
+func DecryptMongoDBOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return mongoDBOptionsHandler(val, e.Decrypt)
+}
+
+func mongoDBOptionsHandler(val any, handler func(string) (string, error)) (any, error) {
+ o := MongoDBOptions{}
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return sql.NullString{}, nil
+ }
+
+ err := json.Unmarshal([]byte(value.String), &o)
+ if err != nil {
+ return nil, err
+ }
+
+ o.TLSCertificateKey, err = handler(o.TLSCertificateKey)
+ if err != nil {
+ return nil, err
+ }
+ o.TLSCertificateKeyFilePassword, err = handler(o.TLSCertificateKeyFilePassword)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := json.Marshal(o)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+// EncryptAzureOptionsHandler returns encrypted Azure Options.
+func EncryptAzureOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return azureOptionsHandler(val, e.Encrypt)
+}
+
+// DecryptAzureOptionsHandler returns decrypted Azure Options.
+func DecryptAzureOptionsHandler(e *encryption.Encryption, val any) (any, error) {
+ return azureOptionsHandler(val, e.Decrypt)
+}
+
+func azureOptionsHandler(val any, handler func(string) (string, error)) (any, error) {
+ o := AzureOptions{}
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return sql.NullString{}, nil
+ }
+
+ err := json.Unmarshal([]byte(value.String), &o)
+ if err != nil {
+ return nil, err
+ }
+
+ o.ClientID, err = handler(o.ClientID)
+ if err != nil {
+ return nil, err
+ }
+ o.ClientSecret, err = handler(o.ClientSecret)
+ if err != nil {
+ return nil, err
+ }
+ o.SubscriptionID, err = handler(o.SubscriptionID)
+ if err != nil {
+ return nil, err
+ }
+ o.TenantID, err = handler(o.TenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := json.Marshal(o)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
diff --git a/managed/models/settings.go b/managed/models/settings.go
index cf86398765..1b12259c23 100644
--- a/managed/models/settings.go
+++ b/managed/models/settings.go
@@ -109,6 +109,9 @@ type Settings struct {
// Enabled is true if access control is enabled.
Enabled *bool `json:"enabled"`
} `json:"access_control"`
+
+ // Contains all encrypted tables in format 'db.table.column'.
+ EncryptedItems []string `json:"encrypted_items"`
}
// IsAlertingEnabled returns true if alerting is enabled.
diff --git a/managed/models/settings_helpers.go b/managed/models/settings_helpers.go
index 351202dd04..fb5125f336 100644
--- a/managed/models/settings_helpers.go
+++ b/managed/models/settings_helpers.go
@@ -92,6 +92,9 @@ type ChangeSettingsParams struct {
// DefaultRoleID sets a default role to be assigned to new users.
DefaultRoleID *int
+
+ // List of items in format 'db.table.column' to be encrypted.
+ EncryptedItems []string
}
// SetPMMServerID should be run on start up to generate unique PMM Server ID.
@@ -223,6 +226,9 @@ func UpdateSettings(q reform.DBTX, params *ChangeSettingsParams) (*Settings, err
settings.DefaultRoleID = *params.DefaultRoleID
}
+ if len(params.EncryptedItems) != 0 {
+ settings.EncryptedItems = params.EncryptedItems
+ }
err = SaveSettings(q, settings)
if err != nil {
return nil, err
diff --git a/managed/models/settings_helpers_test.go b/managed/models/settings_helpers_test.go
index bd92db48cd..14112c30cd 100644
--- a/managed/models/settings_helpers_test.go
+++ b/managed/models/settings_helpers_test.go
@@ -37,6 +37,7 @@ func TestSettings(t *testing.T) {
t.Run("Defaults", func(t *testing.T) {
actual, err := models.GetSettings(sqlDB)
require.NoError(t, err)
+ require.NotEmpty(t, actual.EncryptedItems)
expected := &models.Settings{
MetricsResolutions: models.MetricsResolutions{
HR: 5 * time.Second,
@@ -52,7 +53,8 @@ func TestSettings(t *testing.T) {
FrequentInterval: 4 * time.Hour,
},
},
- DefaultRoleID: 1,
+ DefaultRoleID: 1,
+ EncryptedItems: actual.EncryptedItems,
}
assert.Equal(t, expected, actual)
})
diff --git a/managed/services/agents/agents.go b/managed/services/agents/agents.go
index 71ffbefadb..451b34f09b 100644
--- a/managed/services/agents/agents.go
+++ b/managed/services/agents/agents.go
@@ -109,6 +109,25 @@ func redactWords(agent *models.Agent) []string {
words = append(words, s)
}
}
+ if agent.MySQLOptions != nil {
+ if s := agent.MySQLOptions.TLSKey; s != "" {
+ words = append(words, s)
+ }
+ }
+ if agent.PostgreSQLOptions != nil {
+ if s := agent.PostgreSQLOptions.SSLKey; s != "" {
+ words = append(words, s)
+ }
+ }
+ if agent.MongoDBOptions != nil {
+ if s := agent.MongoDBOptions.TLSCertificateKey; s != "" {
+ words = append(words, s)
+ }
+ if s := agent.MongoDBOptions.TLSCertificateKeyFilePassword; s != "" {
+ words = append(words, s)
+ }
+ }
+
return words
}
diff --git a/managed/services/agents/connection_checker.go b/managed/services/agents/connection_checker.go
index 447dbb9ee8..93e5c92a73 100644
--- a/managed/services/agents/connection_checker.go
+++ b/managed/services/agents/connection_checker.go
@@ -86,9 +86,9 @@ func (c *ConnectionChecker) CheckConnectionToService(ctx context.Context, q *ref
return err
}
- var sanitizedDSN string
+ sanitizedDSN := request.Dsn
for _, word := range redactWords(agent) {
- sanitizedDSN = strings.ReplaceAll(request.Dsn, word, "****")
+ sanitizedDSN = strings.ReplaceAll(sanitizedDSN, word, "****")
}
l.Infof("CheckConnectionRequest: type: %s, DSN: %s timeout: %s.", request.Type, sanitizedDSN, request.Timeout)
diff --git a/managed/services/agents/mysql_test.go b/managed/services/agents/mysql_test.go
index 5a8a3b7b41..6cee95abeb 100644
--- a/managed/services/agents/mysql_test.go
+++ b/managed/services/agents/mysql_test.go
@@ -199,7 +199,7 @@ func TestMySQLdExporterConfigTablestatsGroupDisabled(t *testing.T) {
"DATA_SOURCE_NAME=username:s3cur3 p@$$w0r4.@tcp(1.2.3.4:3306)/?timeout=1s&tls=custom",
"HTTP_AUTH=pmm:agent-id",
},
- RedactWords: []string{"s3cur3 p@$$w0r4."},
+ RedactWords: []string{"s3cur3 p@$$w0r4.", "content-of-tls-key"},
TextFiles: map[string]string{
"tlsCa": "content-of-tls-ca",
"tlsCert": "content-of-tls-cert",
diff --git a/managed/services/agents/service_info_broker.go b/managed/services/agents/service_info_broker.go
index e8aa3e88f7..6802de66f8 100644
--- a/managed/services/agents/service_info_broker.go
+++ b/managed/services/agents/service_info_broker.go
@@ -154,9 +154,9 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu
return err
}
- var sanitizedDSN string
+ sanitizedDSN := request.Dsn
for _, word := range redactWords(agent) {
- sanitizedDSN = strings.ReplaceAll(request.Dsn, word, "****")
+ sanitizedDSN = strings.ReplaceAll(sanitizedDSN, word, "****")
}
l.Infof("ServiceInfoRequest: type: %s, DSN: %s timeout: %s.", request.Type, sanitizedDSN, request.Timeout)
@@ -182,9 +182,11 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu
case models.MySQLServiceType:
agent.TableCount = &sInfo.TableCount
l.Debugf("Updating table count: %d.", sInfo.TableCount)
- if err = q.Update(agent); err != nil {
+ encryptedAgent := models.EncryptAgent(*agent)
+ if err = q.Update(&encryptedAgent); err != nil {
return errors.Wrap(err, "failed to update table count")
}
+
return updateServiceVersion(ctx, q, resp, service)
case models.PostgreSQLServiceType:
if agent.PostgreSQLOptions == nil {
@@ -206,9 +208,11 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu
agent.PostgreSQLOptions.DatabaseCount = int32(databaseCount - excludedDatabaseCount)
l.Debugf("Updating PostgreSQL options, database count: %d.", agent.PostgreSQLOptions.DatabaseCount)
- if err = q.Update(agent); err != nil {
+ encryptedAgent := models.EncryptAgent(*agent)
+ if err = q.Update(&encryptedAgent); err != nil {
return errors.Wrap(err, "failed to update database count")
}
+
return updateServiceVersion(ctx, q, resp, service)
case models.MongoDBServiceType,
models.ProxySQLServiceType:
diff --git a/managed/services/management/agent.go b/managed/services/management/agent.go
index fc13caeb90..c2e74b6960 100644
--- a/managed/services/management/agent.go
+++ b/managed/services/management/agent.go
@@ -130,9 +130,9 @@ func (s *ManagementService) agentToAPI(agent *models.Agent) (*managementv1.Unive
Disabled: agent.Disabled,
DisabledCollectors: agent.DisabledCollectors,
IsConnected: s.r.IsConnected(agent.AgentID),
- IsAgentPasswordSet: agent.AgentPassword != nil,
- IsAwsSecretKeySet: agent.AWSSecretKey != nil,
- IsPasswordSet: agent.Password != nil,
+ IsAgentPasswordSet: pointer.GetString(agent.AgentPassword) != "",
+ IsAwsSecretKeySet: pointer.GetString(agent.AWSSecretKey) != "",
+ IsPasswordSet: pointer.GetString(agent.Password) != "",
ListenPort: uint32(pointer.GetUint16(agent.ListenPort)),
LogLevel: pointer.GetString(agent.LogLevel),
MaxQueryLength: agent.MaxQueryLength,
diff --git a/managed/services/management/mongodb.go b/managed/services/management/mongodb.go
index 6ce4d94072..fe7782c194 100644
--- a/managed/services/management/mongodb.go
+++ b/managed/services/management/mongodb.go
@@ -57,8 +57,6 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad
}
mongodb.Service = invService.(*inventoryv1.MongoDBService) //nolint:forcetypeassert
- mongoDBOptions := models.MongoDBOptionsFromRequest(req)
-
req.MetricsMode, err = supportedMetricsMode(tx.Querier, req.MetricsMode, req.PmmAgentId)
if err != nil {
return err
@@ -72,7 +70,7 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad
AgentPassword: req.AgentPassword,
TLS: req.Tls,
TLSSkipVerify: req.TlsSkipVerify,
- MongoDBOptions: mongoDBOptions,
+ MongoDBOptions: models.MongoDBOptionsFromRequest(req),
PushMetrics: isPushMode(req.MetricsMode),
ExposeExporter: req.ExposeExporter,
DisableCollectors: req.DisableCollectors,
@@ -106,7 +104,7 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad
Password: req.Password,
TLS: req.Tls,
TLSSkipVerify: req.TlsSkipVerify,
- MongoDBOptions: mongoDBOptions,
+ MongoDBOptions: models.MongoDBOptionsFromRequest(req),
MaxQueryLength: req.MaxQueryLength,
LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL),
// TODO QueryExamplesDisabled https://jira.percona.com/browse/PMM-7860
diff --git a/managed/services/management/postgresql.go b/managed/services/management/postgresql.go
index 94ccff5845..e1f1a473da 100644
--- a/managed/services/management/postgresql.go
+++ b/managed/services/management/postgresql.go
@@ -64,7 +64,6 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1
return err
}
- options := models.PostgreSQLOptionsFromRequest(req)
row, err := models.CreateAgent(tx.Querier, models.PostgresExporterType, &models.CreateAgentParams{
PMMAgentID: req.PmmAgentId,
ServiceID: service.ServiceID,
@@ -76,7 +75,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1
PushMetrics: isPushMode(req.MetricsMode),
ExposeExporter: req.ExposeExporter,
DisableCollectors: req.DisableCollectors,
- PostgreSQLOptions: options,
+ PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req),
LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_ERROR),
})
if err != nil {
@@ -117,7 +116,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1
CommentsParsingDisabled: req.DisableCommentsParsing,
TLS: req.Tls,
TLSSkipVerify: req.TlsSkipVerify,
- PostgreSQLOptions: options,
+ PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req),
LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL),
})
if err != nil {
@@ -142,7 +141,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1
CommentsParsingDisabled: req.DisableCommentsParsing,
TLS: req.Tls,
TLSSkipVerify: req.TlsSkipVerify,
- PostgreSQLOptions: options,
+ PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req),
LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL),
})
if err != nil {
diff --git a/managed/utils/encryption/encryption.go b/managed/utils/encryption/encryption.go
new file mode 100644
index 0000000000..a7cba048df
--- /dev/null
+++ b/managed/utils/encryption/encryption.go
@@ -0,0 +1,205 @@
+// Copyright (C) 2023 Percona LLC
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+// Package encryption contains functions to encrypt/decrypt items or DB.
+package encryption
+
+import (
+ "encoding/base64"
+ "os"
+ "slices"
+
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+ "gopkg.in/reform.v1"
+)
+
+// DefaultEncryptionKeyPath contains default PMM encryption key path.
+const DefaultEncryptionKeyPath = "/srv/pmm-encryption.key"
+
+var (
+ // ErrEncryptionNotInitialized is error in case of encryption is not initialized.
+ ErrEncryptionNotInitialized = errors.New("encryption is not initialized")
+ // DefaultEncryption is the default implementation of encryption.
+ DefaultEncryption = New(DefaultEncryptionKeyPath)
+)
+
+// New creates an encryption; if key on path doesn't exist, it will be generated.
+func New(keyPath string) *Encryption {
+ e := &Encryption{}
+ customKeyPath := os.Getenv("PMM_ENCRYPTION_KEY_PATH")
+ if customKeyPath != "" {
+ e.Path = customKeyPath
+ } else {
+ e.Path = keyPath
+ }
+
+ bytes, err := os.ReadFile(e.Path)
+ switch {
+ case os.IsNotExist(err):
+ err = e.generateKey()
+ if err != nil {
+ logrus.Panicf("Encryption: %v", err)
+ }
+ case err != nil:
+ logrus.Panicf("Encryption: %v", err)
+ default:
+ e.Key = string(bytes)
+ }
+
+ primitive, err := e.getPrimitive()
+ if err != nil {
+ logrus.Panicf("Encryption: %v", err)
+ }
+ e.Primitive = primitive
+
+ return e
+}
+
+// Encrypt is a wrapper around DefaultEncryption.Encrypt.
+func Encrypt(secret string) (string, error) {
+ return DefaultEncryption.Encrypt(secret)
+}
+
+// Encrypt returns input string encrypted.
+func (e *Encryption) Encrypt(secret string) (string, error) {
+ if e == nil || e.Primitive == nil {
+ return secret, ErrEncryptionNotInitialized
+ }
+ if secret == "" {
+ return secret, nil
+ }
+ cipherText, err := e.Primitive.Encrypt([]byte(secret), []byte(""))
+ if err != nil {
+ return secret, err
+ }
+
+ return base64.StdEncoding.EncodeToString(cipherText), nil
+}
+
+// EncryptItems is a wrapper around DefaultEncryption.EncryptItems.
+func EncryptItems(tx *reform.TX, tables []Table) error {
+ return DefaultEncryption.EncryptItems(tx, tables)
+}
+
+// EncryptItems will encrypt all columns provided in DB connection.
+func (e *Encryption) EncryptItems(tx *reform.TX, tables []Table) error {
+ if len(tables) == 0 {
+ return nil
+ }
+
+ for _, table := range tables {
+ res, err := table.read(tx)
+ if err != nil {
+ return err
+ }
+
+ for k, v := range res.SetValues {
+ for i, val := range v {
+ var encrypted any
+ var err error
+ switch table.Columns[i].CustomHandler {
+ case nil:
+ encrypted, err = encryptColumnStringHandler(e, val)
+ default:
+ encrypted, err = table.Columns[i].CustomHandler(e, val)
+ }
+
+ if err != nil {
+ return err
+ }
+ res.SetValues[k][i] = encrypted
+ }
+ data := slices.Concat([]any{}, v)
+ data = slices.Concat(data, res.WhereValues[k])
+ _, err := tx.Exec(res.Query, data...)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// Decrypt is wrapper around DefaultEncryption.Decrypt.
+func Decrypt(cipherText string) (string, error) {
+ return DefaultEncryption.Decrypt(cipherText)
+}
+
+// Decrypt returns input string decrypted.
+func (e *Encryption) Decrypt(cipherText string) (string, error) {
+ if e == nil || e.Primitive == nil {
+ return cipherText, ErrEncryptionNotInitialized
+ }
+ if cipherText == "" {
+ return cipherText, nil
+ }
+ decoded, err := base64.StdEncoding.DecodeString(cipherText)
+ if err != nil {
+ return cipherText, err
+ }
+ secret, err := e.Primitive.Decrypt(decoded, []byte(""))
+ if err != nil {
+ return cipherText, err
+ }
+
+ return string(secret), nil
+}
+
+// DecryptItems is wrapper around DefaultEncryption.DecryptItems.
+func DecryptItems(tx *reform.TX, tables []Table) error {
+ return DefaultEncryption.DecryptItems(tx, tables)
+}
+
+// DecryptItems will decrypt all columns provided in DB connection.
+func (e *Encryption) DecryptItems(tx *reform.TX, tables []Table) error {
+ if len(tables) == 0 {
+ return nil
+ }
+
+ for _, table := range tables {
+ res, err := table.read(tx)
+ if err != nil {
+ return err
+ }
+
+ for k, v := range res.SetValues {
+ for i, val := range v {
+ var decrypted any
+ var err error
+ switch table.Columns[i].CustomHandler {
+ case nil:
+ decrypted, err = decryptColumnStringHandler(e, val)
+ default:
+ decrypted, err = table.Columns[i].CustomHandler(e, val)
+ }
+
+ if err != nil {
+ return err
+ }
+ res.SetValues[k][i] = decrypted
+ }
+ data := slices.Concat([]any{}, v)
+ data = slices.Concat(data, res.WhereValues[k])
+ _, err := tx.Exec(res.Query, data...)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/managed/utils/encryption/helpers.go b/managed/utils/encryption/helpers.go
new file mode 100644
index 0000000000..c8c11ab4e3
--- /dev/null
+++ b/managed/utils/encryption/helpers.go
@@ -0,0 +1,177 @@
+// Copyright (C) 2023 Percona LLC
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package encryption
+
+import (
+ "bytes"
+ "database/sql"
+ "encoding/base64"
+ "fmt"
+ "os"
+ "slices"
+ "strings"
+
+ "github.com/google/tink/go/aead"
+ "github.com/google/tink/go/insecurecleartextkeyset"
+ "github.com/google/tink/go/keyset"
+ "github.com/google/tink/go/tink"
+ "gopkg.in/reform.v1"
+)
+
+func prepareRowPointers(rows *sql.Rows) ([]any, error) {
+ columnTypes, err := rows.ColumnTypes()
+ if err != nil {
+ return nil, err
+ }
+ columns := make(map[string]string)
+ for _, columnType := range columnTypes {
+ columns[columnType.Name()] = columnType.DatabaseTypeName()
+ }
+
+ row := []any{}
+ for _, t := range columns {
+ switch t {
+ case "VARCHAR", "JSONB":
+ row = append(row, &sql.NullString{})
+ default:
+ return nil, fmt.Errorf("unsupported identificator type %s", t)
+ }
+ }
+
+ return row, nil
+}
+
+func encryptColumnStringHandler(e *Encryption, val any) (any, error) {
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return sql.NullString{}, nil
+ }
+
+ encrypted, err := e.Encrypt(value.String)
+ if err != nil {
+ return nil, err
+ }
+
+ return encrypted, nil
+}
+
+func decryptColumnStringHandler(e *Encryption, val any) (any, error) {
+ value := val.(*sql.NullString) //nolint:forcetypeassert
+ if !value.Valid {
+ return nil, nil //nolint:nilnil
+ }
+
+ decrypted, err := e.Decrypt(value.String)
+ if err != nil {
+ return nil, err
+ }
+
+ return decrypted, nil
+}
+
+func (e *Encryption) getPrimitive() (tink.AEAD, error) { //nolint:ireturn
+ serializedKeyset, err := base64.StdEncoding.DecodeString(e.Key)
+ if err != nil {
+ return nil, err
+ }
+
+ binaryReader := keyset.NewBinaryReader(bytes.NewBuffer(serializedKeyset))
+ parsedHandle, err := insecurecleartextkeyset.Read(binaryReader)
+ if err != nil {
+ return nil, err
+ }
+
+ return aead.New(parsedHandle)
+}
+
+func (e *Encryption) generateKey() error {
+ handle, err := keyset.NewHandle(aead.AES256GCMKeyTemplate())
+ if err != nil {
+ return err
+ }
+
+ buff := &bytes.Buffer{}
+ err = insecurecleartextkeyset.Write(handle, keyset.NewBinaryWriter(buff))
+ if err != nil {
+ return err
+ }
+ e.Key = base64.StdEncoding.EncodeToString(buff.Bytes())
+
+ return e.saveKeyToFile()
+}
+
+func (e *Encryption) saveKeyToFile() error {
+ return os.WriteFile(e.Path, []byte(e.Key), 0o644) //nolint:gosec
+}
+
+func (table Table) columnsList() []string {
+ res := []string{}
+ for _, c := range table.Columns {
+ res = append(res, c.Name)
+ }
+
+ return res
+}
+
+func (table Table) read(tx *reform.TX) (*QueryValues, error) {
+ what := slices.Concat(table.Identifiers, table.columnsList())
+ query := fmt.Sprintf("SELECT %s FROM %s", strings.Join(what, ", "), table.Name)
+ rows, err := tx.Query(query)
+ if err != nil {
+ return nil, err
+ }
+
+ q := &QueryValues{}
+ for rows.Next() {
+ row, err := prepareRowPointers(rows)
+ if err != nil {
+ return nil, err
+ }
+ err = rows.Scan(row...)
+ if err != nil {
+ return nil, err
+ }
+
+ i := 1
+ set := []string{}
+ setValues := []any{}
+ for k, v := range row[len(table.Identifiers):] {
+ set = append(set, fmt.Sprintf("%s = $%d", table.Columns[k].Name, i))
+ setValues = append(setValues, v)
+ i++
+ }
+ setSQL := fmt.Sprintf("SET %s", strings.Join(set, ", "))
+ q.SetValues = append(q.SetValues, setValues)
+
+ where := []string{}
+ whereValues := []any{}
+ for k, id := range table.Identifiers {
+ where = append(where, fmt.Sprintf("%s = $%d", id, i))
+ whereValues = append(whereValues, row[k])
+ i++
+ }
+ whereSQL := "WHERE " + strings.Join(where, " AND ")
+ q.WhereValues = append(q.WhereValues, whereValues)
+
+ q.Query = fmt.Sprintf("UPDATE %s %s %s", table.Name, setSQL, whereSQL)
+ }
+ err = rows.Close() //nolint:sqlclosecheck
+ if err != nil {
+ return nil, err
+ }
+
+ return q, nil
+}
diff --git a/managed/utils/encryption/models.go b/managed/utils/encryption/models.go
new file mode 100644
index 0000000000..257b49b1de
--- /dev/null
+++ b/managed/utils/encryption/models.go
@@ -0,0 +1,45 @@
+// Copyright (C) 2023 Percona LLC
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package encryption
+
+import "github.com/google/tink/go/tink"
+
+// Encryption contains fields required for encryption.
+type Encryption struct {
+ Path string
+ Key string
+ Primitive tink.AEAD
+}
+
+// Table represents table name, it's identifiers and columns to be encrypted/decrypted.
+type Table struct {
+ Name string
+ Identifiers []string
+ Columns []Column
+}
+
+// Column represents column name and column's custom handler (if needed).
+type Column struct {
+ Name string
+ CustomHandler func(e *Encryption, val any) (any, error)
+}
+
+// QueryValues represents query to update row after encrypt/decrypt.
+type QueryValues struct {
+ Query string
+ SetValues [][]any
+ WhereValues [][]any
+}
diff --git a/managed/utils/testdb/db.go b/managed/utils/testdb/db.go
index d5707e367a..2f6ee84cc7 100644
--- a/managed/utils/testdb/db.go
+++ b/managed/utils/testdb/db.go
@@ -70,7 +70,8 @@ func Open(tb testing.TB, setupFixtures models.SetupFixturesMode, migrationVersio
// Please use Open method to recreate DB for each test if you don't need to control migrations.
func SetupDB(tb testing.TB, db *sql.DB, setupFixtures models.SetupFixturesMode, migrationVersion *int) {
tb.Helper()
- _, err := models.SetupDB(context.TODO(), db, models.SetupDBParams{
+ ctx := context.TODO()
+ params := models.SetupDBParams{
// Uncomment to see all setup queries:
// Logf: tb.Logf,
Address: models.DefaultPostgreSQLAddr,
@@ -79,7 +80,9 @@ func SetupDB(tb testing.TB, db *sql.DB, setupFixtures models.SetupFixturesMode,
Password: password,
SetupFixtures: setupFixtures,
MigrationVersion: migrationVersion,
- })
+ }
+
+ _, err := models.SetupDB(ctx, db, params)
require.NoError(tb, err)
}