diff --git a/docs/source/user_guide/configuration_options.rst b/docs/source/user_guide/configuration_options.rst index 3de8831a3..b0eab33da 100644 --- a/docs/source/user_guide/configuration_options.rst +++ b/docs/source/user_guide/configuration_options.rst @@ -62,6 +62,27 @@ Log level - Error - string +Add payload tracing using `RECEPTOR_PAYLOAD_TRACE_LEVEL=int` envorment variable and using log level debug. + +.. list-table:: RECEPTOR_PAYLOAD_TRACE_LEVEL options + :header-rows: 1 + :widths: auto + + * - Tracing level + - Description + * - 0 + - No payload tracing log + * - 1 + - Log connection type + * - 2 + - Log connection type and work unit id + * - 3 + - Log connection type, work unit id and payload + +**Warning: Payload Tracing May Expose Sensitive Data** + +Please be aware that using payload tracing can potentially reveal sensitive information. This includes, but is not limited to, personal data, authentication tokens, and system configurations. Ensure that you only use tracing tools in a secure environment and avoid sharing trace output with unauthorized users. Always follow your organization's data protection policies when handling sensitive information. Proceed with caution! + .. code-block:: yaml log-level: diff --git a/pkg/controlsvc/controlsvc.go b/pkg/controlsvc/controlsvc.go index f483cacde..4ae9f18d0 100644 --- a/pkg/controlsvc/controlsvc.go +++ b/pkg/controlsvc/controlsvc.go @@ -4,6 +4,7 @@ package controlsvc import ( + "bufio" "context" "crypto/tls" "encoding/json" @@ -14,6 +15,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "time" @@ -122,8 +124,38 @@ func (s *SockControl) ReadFromConn(message string, out io.Writer, io Copier) err if err := s.WriteMessage(message); err != nil { return err } - if _, err := io.Copy(out, s.conn); err != nil { - return err + payloadDebug, _ := strconv.Atoi(os.Getenv("RECEPTOR_PAYLOAD_TRACE_LEVEL")) + + if payloadDebug != 0 { + var connectionType string + var payload string + if s.conn.LocalAddr().Network() == "unix" { + connectionType = "unix socket" + } else { + connectionType = "network connection" + } + reader := bufio.NewReader(s.conn) + + for { + response, err := reader.ReadString('\n') + if err != nil { + if err.Error() != "EOF" { + MainInstance.nc.GetLogger().Error("Error reading from conn: %v \n", err) + } + + break + } + payload += response + } + + MainInstance.nc.GetLogger().DebugPayload(payloadDebug, payload, "", connectionType) + if _, err := out.Write([]byte(payload)); err != nil { + return err + } + } else { + if _, err := io.Copy(out, s.conn); err != nil { + return err + } } return nil diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 3f5d52d7b..ffaab5bff 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -148,6 +148,33 @@ func (rl *ReceptorLogger) Debug(format string, v ...interface{}) { rl.Log(DebugLevel, format, v...) } +// Debug payload data. +func (rl *ReceptorLogger) DebugPayload(payloadDebug int, payload string, workUnitID string, connectionType string) { + var payloadMessage string + var workunitIDMessage string + var connectionTypeMessage string + switch payloadDebug { + case 3: + payloadMessage = fmt.Sprintf(" with a payload of: %s", payload) + + fallthrough + case 2: + if workUnitID != "" { + workunitIDMessage = fmt.Sprintf(" with work unit %s", workUnitID) + } else { + workunitIDMessage = ", work unit not created yet" + } + + fallthrough + case 1: + if connectionType != "" { + connectionTypeMessage = fmt.Sprintf("Reading from %s", connectionType) + } + default: + } + rl.Debug(fmt.Sprintf("PACKET TRACING ENABLED: %s%s%s", connectionTypeMessage, workunitIDMessage, payloadMessage)) //nolint:govet +} + // SanitizedDebug contains extra information helpful to developers. func (rl *ReceptorLogger) SanitizedDebug(format string, v ...interface{}) { rl.SanitizedLog(DebugLevel, format, v...) diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index f728068ee..a19452095 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -1,7 +1,9 @@ package logger_test import ( + "bytes" "fmt" + "os" "testing" "github.com/ansible/receptor/pkg/logger" @@ -68,3 +70,53 @@ func TestLogLevelToNameWithError(t *testing.T) { t.Error("should have error") } } + +func TestDebugPayload(t *testing.T) { + logFilePath := "/tmp/test-output" + logger.SetGlobalLogLevel(4) + receptorLogger := logger.NewReceptorLogger("testDebugPayload") + logFile, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o600) + if err != nil { + t.Error("error creating test-output file") + } + + payload := "Testing debugPayload" + workUnitID := "1234" + connectionType := "unix socket" + + debugPayloadTestCases := []struct { + name string + debugPayload int + payload string + workUnitID string + connectionType string + expectedLog string + }{ + {name: "debugPayload no log", debugPayload: 0, payload: "", workUnitID: "", connectionType: "", expectedLog: ""}, + {name: "debugPayload log level 1", debugPayload: 1, payload: "", workUnitID: "", connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v", connectionType)}, + {name: "debugPayload log level 2 with workUnitID", debugPayload: 2, payload: "", workUnitID: workUnitID, connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v with work unit %v", connectionType, workUnitID)}, + {name: "debugPayload log level 2 without workUnitID", debugPayload: 2, payload: "", workUnitID: "", connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v", connectionType)}, + {name: "debugPayload log level 3 with workUnitID", debugPayload: 3, payload: payload, workUnitID: workUnitID, connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v with work unit %v with a payload of: %v", connectionType, workUnitID, payload)}, + {name: "debugPayload log level 3 without workUnitID", debugPayload: 3, payload: payload, workUnitID: "", connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v, work unit not created yet with a payload of: %v", connectionType, payload)}, + {name: "debugPayload log level 3 without workUnitID and payload is new line", debugPayload: 3, payload: "\n", workUnitID: "", connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v, work unit not created yet with a payload of: %v", connectionType, "\n")}, + {name: "debugPayload log level 3 without workUnitID or payload", debugPayload: 3, payload: "", workUnitID: "", connectionType: connectionType, expectedLog: fmt.Sprintf("PACKET TRACING ENABLED: Reading from %v, work unit not created yet with a payload of: %v", connectionType, "")}, + } + + for _, testCase := range debugPayloadTestCases { + t.Run(testCase.name, func(t *testing.T) { + receptorLogger.SetOutput(logFile) + receptorLogger.DebugPayload(testCase.debugPayload, testCase.payload, testCase.workUnitID, testCase.connectionType) + + testOutput, err := os.ReadFile(logFilePath) + if err != nil { + t.Error("error reading test-output file") + } + if !bytes.Contains(testOutput, []byte(testCase.expectedLog)) { + t.Errorf("failed to log correctly, expected: %v got %v", testCase.expectedLog, string(testOutput)) + } + if err := os.Truncate(logFilePath, 0); err != nil { + t.Errorf("failed to truncate: %v", err) + } + }) + } +} diff --git a/pkg/workceptor/command.go b/pkg/workceptor/command.go index 2ff793ce5..f0fe9dfbb 100644 --- a/pkg/workceptor/command.go +++ b/pkg/workceptor/command.go @@ -4,13 +4,16 @@ package workceptor import ( + "bufio" "context" "flag" "fmt" + "io" "os" "os/exec" "os/signal" "path" + "strconv" "strings" "sync" "syscall" @@ -112,7 +115,38 @@ func commandRunner(command string, params string, unitdir string) error { if err != nil { return err } - cmd.Stdin = stdin + payloadDebug, _ := strconv.Atoi(os.Getenv("RECEPTOR_PAYLOAD_TRACE_LEVEL")) + + if payloadDebug != 0 { + splitUnitDir := strings.Split(unitdir, "/") + workUnitID := splitUnitDir[len(splitUnitDir)-1] + stdinStream, err := cmd.StdinPipe() + if err != nil { + return err + } + var payload string + reader := bufio.NewReader(stdin) + if err != nil { + return err + } + + for { + response, err := reader.ReadString('\n') + if err != nil { + if err.Error() != "EOF" { + MainInstance.nc.GetLogger().Error("Error reading work unit %v stdin: %v\n", workUnitID, err) + } + + break + } + payload += response + } + + MainInstance.nc.GetLogger().DebugPayload(payloadDebug, payload, workUnitID, "") + io.WriteString(stdinStream, payload) + } else { + cmd.Stdin = stdin + } stdout, err := os.OpenFile(path.Join(unitdir, "stdout"), os.O_CREATE+os.O_WRONLY+os.O_SYNC, 0o600) if err != nil { return err diff --git a/pkg/workceptor/stdio_utils.go b/pkg/workceptor/stdio_utils.go index 99e50a0b1..7c4f438b8 100644 --- a/pkg/workceptor/stdio_utils.go +++ b/pkg/workceptor/stdio_utils.go @@ -8,6 +8,8 @@ import ( "io" "os" "path" + "strconv" + "strings" "sync" ) @@ -111,6 +113,7 @@ func (sw *STDoutWriter) SetWriter(writer FileWriteCloser) { // STDinReader reads from a stdin file and provides a Done function. type STDinReader struct { reader FileReadCloser + workUnit string lasterr error doneChan chan struct{} doneOnce sync.Once @@ -120,6 +123,8 @@ var errFileSizeZero = errors.New("file is empty") // NewStdinReader allocates a new stdinReader, which reads from a stdin file and provides a Done function. func NewStdinReader(fs FileSystemer, unitdir string) (*STDinReader, error) { + splitUnitDir := strings.Split(unitdir, "/") + workUnitID := splitUnitDir[len(splitUnitDir)-1] stdinpath := path.Join(unitdir, "stdin") stat, err := fs.Stat(stdinpath) if err != nil { @@ -135,6 +140,7 @@ func NewStdinReader(fs FileSystemer, unitdir string) (*STDinReader, error) { return &STDinReader{ reader: reader, + workUnit: workUnitID, lasterr: nil, doneChan: make(chan struct{}), doneOnce: sync.Once{}, @@ -143,6 +149,23 @@ func NewStdinReader(fs FileSystemer, unitdir string) (*STDinReader, error) { // Read reads data from the stdout file, implementing io.Reader. func (sr *STDinReader) Read(p []byte) (n int, err error) { + payloadDebug, _ := strconv.Atoi(os.Getenv("RECEPTOR_PAYLOAD_TRACE_LEVEL")) + + if payloadDebug != 0 { + isNotEmpty := func() bool { + for _, v := range p { + if v != 0 { + return true + } + } + + return false + }() + if isNotEmpty { + payload := string(p) + MainInstance.nc.GetLogger().DebugPayload(payloadDebug, payload, sr.workUnit, "kube api") + } + } n, err = sr.reader.Read(p) if err != nil { sr.lasterr = err