diff --git a/README.md b/README.md index 51e3486..322dd06 100644 --- a/README.md +++ b/README.md @@ -194,8 +194,8 @@ See `test/main.go` for an example. - [x] Bind request - [ ] Compare request - [x] Extended requests -- [ ] Modify request -- [ ] ModifyDN request +- [x] Modify request +- [x] ModifyDN request - [x] Search request (concurrent) - [x] StartTLS request - [x] Unbind request diff --git a/ber_test.go b/ber_test.go index 0cad6ee..400e542 100644 --- a/ber_test.go +++ b/ber_test.go @@ -38,37 +38,37 @@ func slicesEqual[T comparable](a []T, b []T) bool { func TestBerTypes(t *testing.T) { if ldapserver.BerType(0b00000000).Class() != ldapserver.BerClassUniversal { - t.Error("invalid BER type reported") + t.Fatal("invalid BER type reported") } if ldapserver.BerType(0b01000000).Class() != ldapserver.BerClassApplication { - t.Error("invalid BER type reported") + t.Fatal("invalid BER type reported") } if ldapserver.BerType(0b10000000).Class() != ldapserver.BerClassContextSpecific { - t.Error("invalid BER type reported") + t.Fatal("invalid BER type reported") } if ldapserver.BerType(0b11000000).Class() != ldapserver.BerClassPrivate { - t.Error("invalid BER type reported") + t.Fatal("invalid BER type reported") } if ldapserver.BerType(0b00100000).IsPrimitive() { - t.Error("invalid primitive flag reported") + t.Fatal("invalid primitive flag reported") } if !ldapserver.BerType(0b00000000).IsPrimitive() { - t.Error("invalid primitive flag reported") + t.Fatal("invalid primitive flag reported") } if ldapserver.BerType(0b00000000).IsConstructed() { - t.Error("invalid constructed flag reported") + t.Fatal("invalid constructed flag reported") } if !ldapserver.BerType(0b00100000).IsConstructed() { - t.Error("invalid constructed flag reported") + t.Fatal("invalid constructed flag reported") } if ldapserver.BerType(0b11111111).TagNumber() != 0b00011111 { - t.Error("invalid tag number reported") + t.Fatal("invalid tag number reported") } if ldapserver.BerType(0b00000000).TagNumber() != 0b00000000 { - t.Error("invalid tag number reported") + t.Fatal("invalid tag number reported") } if ldapserver.BerType(0b10101010).TagNumber() != 0b00001010 { - t.Error("invalid tag number reported") + t.Fatal("invalid tag number reported") } } func TestBerSizes(t *testing.T) { @@ -93,10 +93,10 @@ func TestBerSizes(t *testing.T) { } { size, err := ldapserver.BerReadSize(bytes.NewReader(st.repr)) if size != st.size { - t.Error("invalid size read") + t.Fatal("invalid size read") } if !errors.Is(err, st.err) { - t.Error("Expected error", st.err, ", got error", err) + t.Fatal("Expected error", st.err, ", got error", err) } } } @@ -116,26 +116,26 @@ func TestBerReadElement(t *testing.T) { } { elmt, err := ldapserver.BerReadElement(bytes.NewReader(et.repr)) if elmt.Type != et.res.Type { - t.Error("invalid type read") + t.Fatal("invalid type read") } if !bytes.Equal(elmt.Data, et.res.Data) { - t.Error("invalid data read") + t.Fatal("invalid data read") } if err != et.err { - t.Error("Expected error", et.err, ", got error", err) + t.Fatal("Expected error", et.err, ", got error", err) } } } func TestBerBoolean(t *testing.T) { if getBooleanSimple([]byte{0x00}, false) { - t.Error("invalid boolean read") + t.Fatal("invalid boolean read") } if !getBooleanSimple([]byte{0x01}, true) { - t.Error("invalid boolean read") + t.Fatal("invalid boolean read") } if !getBooleanSimple([]byte{0xff}, true) { - t.Error("invalid boolean read") + t.Fatal("invalid boolean read") } } @@ -143,34 +143,34 @@ func TestBerInteger(t *testing.T) { BerGetInteger := func(data []byte) int64 { res, err := ldapserver.BerGetInteger(data) if err != nil { - t.Error("Error reading integer:", err.Error()) + t.Fatal("Error reading integer:", err.Error()) } return res } if BerGetInteger([]byte{0x00}) != 0 { - t.Error("invalid integer read") + t.Fatal("invalid integer read") } if BerGetInteger([]byte{0x32}) != 50 { - t.Error("invalid integer read") + t.Fatal("invalid integer read") } if BerGetInteger([]byte{0x00, 0xc3, 0x50}) != 50000 { - t.Error("invalid integer read") + t.Fatal("invalid integer read") } if BerGetInteger([]byte{0xcf, 0xc7}) != -12345 { - t.Error("invalid integer read") + t.Fatal("invalid integer read") } _, err := ldapserver.BerGetInteger([]byte{0x12, 0x34, 0x56, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00}) if !errors.Is(err, ldapserver.ErrIntegerTooLarge) { - t.Error("Expected error", ldapserver.ErrIntegerTooLarge, ", got error", err) + t.Fatal("Expected error", ldapserver.ErrIntegerTooLarge, ", got error", err) } } func TestBerOctetString(t *testing.T) { if ldapserver.BerGetOctetString([]byte{}) != "" { - t.Error("invalid octet string read") + t.Fatal("invalid octet string read") } if ldapserver.BerGetOctetString([]byte("This is a test!")) != "This is a test!" { - t.Error("invalid octet string read") + t.Fatal("invalid octet string read") } } @@ -178,19 +178,19 @@ func TestBerSequence(t *testing.T) { seq, err := ldapserver.BerGetSequence( []byte{0x04, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x01, 0x01, 0xff, 0x02, 0x01, 0x05}) if err != nil { - t.Error(err) + t.Fatal(err) } if len(seq) != 3 { - t.Error("wrong length of sequence", len(seq)) + t.Fatal("wrong length of sequence", len(seq)) } if seq[0].Type != ldapserver.BerTypeOctetString && ldapserver.BerGetOctetString(seq[0].Data) != "Hello!" { - t.Error("wrong first item of sequence", seq[0]) + t.Fatal("wrong first item of sequence", seq[0]) } if seq[1].Type != ldapserver.BerTypeBoolean && getBooleanSimple(seq[1].Data, true) != true { - t.Error("wrong second item of sequence", seq[1]) + t.Fatal("wrong second item of sequence", seq[1]) } if seq[2].Type != ldapserver.BerTypeInteger && getIntegerSimple(seq[2].Data, 5) != 5 { - t.Error("wrong third item of sequence", seq[2]) + t.Fatal("wrong third item of sequence", seq[2]) } } @@ -202,26 +202,26 @@ func TestParseDeleteRequest(t *testing.T) { 0x2e, 0x38, 0x30, 0x35, 0x01, 0x01, 0xff} m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(deleteRequest)) if err != nil { - t.Error("Failed to parse LDAP message:", err) + t.Fatal("Failed to parse LDAP message:", err) } if m.MessageID != 5 { - t.Error("invalid message ID") + t.Fatal("invalid message ID") } if m.ProtocolOp.Type != ldapserver.TypeDeleteRequestOp { - t.Error("invalid protocol op type") + t.Fatal("invalid protocol op type") } // m.ProtocolOp.Data should be "dc=example,dc=com" if len(m.Controls) != 1 { - t.Error("invalid number of controls") + t.Fatal("invalid number of controls") } if m.Controls[0].OID != "1.2.840.113556.1.4.805" { - t.Error("invalid control OID") + t.Fatal("invalid control OID") } if m.Controls[0].Criticality != true { - t.Error("invalid criticality") + t.Fatal("invalid criticality") } if m.Controls[0].ControlValue != "" { - t.Error("invalid control value") + t.Fatal("invalid control value") } } @@ -229,32 +229,32 @@ func TestParseEmptySuccessResult(t *testing.T) { emptySuccess := []byte{0x30, 0x0c, 0x02, 0x01, 0x03, 0x69, 0x07, 0x0a, 0x01, 0x00, 0x04, 0x00, 0x04, 0x00} m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(emptySuccess)) if err != nil { - t.Error("Failed to read LDAPMessage:", err) + t.Fatal("Failed to read LDAPMessage:", err) } if m.MessageID != 3 { - t.Error("invalid message ID") + t.Fatal("invalid message ID") } if m.ProtocolOp.Type != ldapserver.TypeAddResponseOp { - t.Error("invalid protocol op type") + t.Fatal("invalid protocol op type") } if len(m.Controls) != 0 { - t.Error("invalid number of controls") + t.Fatal("invalid number of controls") } r, err := ldapserver.GetResult(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse LDAPResult:", err) + t.Fatal("Failed to parse LDAPResult:", err) } if r.ResultCode != ldapserver.ResultSuccess { - t.Error("invalid result code") + t.Fatal("invalid result code") } if r.MatchedDN != "" { - t.Error("invalid matchedDN") + t.Fatal("invalid matchedDN") } if r.DiagnosticMessage != "" { - t.Error("invalid diagnostic message") + t.Fatal("invalid diagnostic message") } if len(r.Referral) != 0 { - t.Error("invalid referral") + t.Fatal("invalid referral") } } @@ -286,32 +286,32 @@ func TestParseNoSuchObjectResult(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(noSuchObject)) if err != nil { - t.Error("Failed to read LDAPMessage:", err) + t.Fatal("Failed to read LDAPMessage:", err) } if m.MessageID != 3 { - t.Error("invalid message ID") + t.Fatal("invalid message ID") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } if m.ProtocolOp.Type != ldapserver.TypeAddResponseOp { - t.Error("invalid protocol op type") + t.Fatal("invalid protocol op type") } r, err := ldapserver.GetResult(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to read LDAPResult:", err) + t.Fatal("Failed to read LDAPResult:", err) } if r.ResultCode != ldapserver.LDAPResultNoSuchObject { - t.Error("wrong result code") + t.Fatal("wrong result code") } if r.MatchedDN != "ou=People, dc=example, dc=com" { - t.Error("wrong matched DN:", r.MatchedDN) + t.Fatal("wrong matched DN:", r.MatchedDN) } if r.DiagnosticMessage != "Entry uid=missing1, ou=missing2, ou=People, dc=example, dc=com cannot be created because its parent does not exist." { - t.Error("wrong diagnostic message:", r.DiagnosticMessage) + t.Fatal("wrong diagnostic message:", r.DiagnosticMessage) } if len(r.Referral) != 0 { - t.Error("wrong referral") + t.Fatal("wrong referral") } } @@ -350,38 +350,62 @@ func TestParseReferralResult(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(referral)) if err != nil { - t.Error("Failed to read LDAPMessage:", err) + t.Fatal("Failed to read LDAPMessage:", err) } if m.MessageID != 3 { - t.Error("wrong message ID") + t.Fatal("wrong message ID") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } if m.ProtocolOp.Type != ldapserver.TypeAddResponseOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } r, err := ldapserver.GetResult(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to get LDAPResult:", err) + t.Fatal("Failed to get LDAPResult:", err) } if r.ResultCode != ldapserver.LDAPResultReferral { - t.Error("wrong result code") + t.Fatal("wrong result code") } if r.MatchedDN != "" { - t.Error("wrong matched DN") + t.Fatal("wrong matched DN") } if r.DiagnosticMessage != "This server is read-only. Try a different one." { - t.Error("wrong diagnostic message:", r.DiagnosticMessage) + t.Fatal("wrong diagnostic message:", r.DiagnosticMessage) } if len(r.Referral) != 2 { - t.Error("wrong referral length", len(r.Referral)) + t.Fatal("wrong referral length", len(r.Referral)) } if r.Referral[0] != "ldap://alternate1.example.com:389/uid=jdoe,ou=Remote,dc=example,dc=com" { - t.Error("wrong first referral", r.Referral[0]) + t.Fatal("wrong first referral", r.Referral[0]) } if r.Referral[1] != "ldap://alternate2.example.com:389/uid=jdoe,ou=Remote,dc=example,dc=com" { - t.Error("wrong first referral", r.Referral[1]) + t.Fatal("wrong first referral", r.Referral[1]) + } +} + +func TestParseAbandonRequest(t *testing.T) { + abandonRequest := []byte{0x30, 0x06, 0x02, 0x01, 0x06, 0x50, 0x01, 0x05} + m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(abandonRequest)) + if err != nil { + t.Fatal("Failed to read LDAPMessage:", err) + } + if m.MessageID != 6 { + t.Fatal("wrong message ID") + } + if m.ProtocolOp.Type != ldapserver.TypeAbandonRequestOp { + t.Fatal("wrong protocol op type") + } + if len(m.Controls) != 0 { + t.Fatal("wrong number of controls") + } + messageID, err := ldapserver.BerGetInteger(m.ProtocolOp.Data) + if err != nil { + t.Fatal("Failed to read integer:", err) + } + if messageID != 5 { + t.Fatal("wrong abandon ID") } } @@ -407,35 +431,35 @@ func TestParseAddRequest(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(addrequest)) if err != nil { - t.Error("Failed to parse LDAPMessage:", err) + t.Fatal("Failed to parse LDAPMessage:", err) } if m.MessageID != 2 { - t.Error("wrong message ID") + t.Fatal("wrong message ID") } if m.ProtocolOp.Type != ldapserver.TypeAddRequestOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } r_add, err := ldapserver.GetAddRequest(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse LDAPAddRequest", err) + t.Fatal("Failed to parse LDAPAddRequest", err) } if r_add.Entry != "dc=example,dc=com" { - t.Error("wrong entry", r_add.Entry) + t.Fatal("wrong entry", r_add.Entry) } if len(r_add.Attributes) != 2 { - t.Error("wrong number of attributes") + t.Fatal("wrong number of attributes") } if r_add.Attributes[0].Description != "objectClass" { - t.Error("wrong attribute description") + t.Fatal("wrong attribute description") } if !slicesEqual(r_add.Attributes[0].Values, []string{"top", "domain"}) { - t.Error("wrong attribute values") + t.Fatal("wrong attribute values") } if r_add.Attributes[1].Description != "dc" { - t.Error("wrong attribute description") + t.Fatal("wrong attribute description") } if !slicesEqual(r_add.Attributes[1].Values, []string{"example"}) { - t.Error("wrong attribute values") + t.Fatal("wrong attribute values") } } @@ -450,32 +474,32 @@ func TestParseAnonymousSimpleBindRequest(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(bindrequest)) if err != nil { - t.Error("Failed to parse LDAPMessage:", err) + t.Fatal("Failed to parse LDAPMessage:", err) } if m.MessageID != 1 { - t.Error("wrong message id") + t.Fatal("wrong message id") } if m.ProtocolOp.Type != ldapserver.TypeBindRequestOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } req, err := ldapserver.GetBindRequest(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse bind request:", err) + t.Fatal("Failed to parse bind request:", err) } if req.Version != 3 { - t.Error("wrong protocol version") + t.Fatal("wrong protocol version") } if req.Name != "" { - t.Error("wrong bind DN") + t.Fatal("wrong bind DN") } if req.AuthType != ldapserver.AuthenticationTypeSimple { - t.Error("wrong auth type") + t.Fatal("wrong auth type") } if req.Credentials.(string) != "" { - t.Error("wrong password", req.Credentials) + t.Fatal("wrong password", req.Credentials) } } @@ -494,32 +518,32 @@ func TestParseAuthenticatedSimpleBindRequest(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(bindrequest)) if err != nil { - t.Error("Failed to parse LDAPMessage:", err) + t.Fatal("Failed to parse LDAPMessage:", err) } if m.MessageID != 1 { - t.Error("wrong message id") + t.Fatal("wrong message id") } if m.ProtocolOp.Type != ldapserver.TypeBindRequestOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } req, err := ldapserver.GetBindRequest(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse bind request:", err) + t.Fatal("Failed to parse bind request:", err) } if req.Version != 3 { - t.Error("wrong protocol version") + t.Fatal("wrong protocol version") } if req.Name != "uid=jdoe,ou=People,dc=example,dc=com" { - t.Error("wrong bind DN") + t.Fatal("wrong bind DN") } if req.AuthType != ldapserver.AuthenticationTypeSimple { - t.Error("wrong auth type") + t.Fatal("wrong auth type") } if req.Credentials.(string) != "secret123" { - t.Error("wrong password", req.Credentials) + t.Fatal("wrong password", req.Credentials) } } func TestParseSASLCRAMMD5InitialBindRequest(t *testing.T) { @@ -534,36 +558,36 @@ func TestParseSASLCRAMMD5InitialBindRequest(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(bindrequest)) if err != nil { - t.Error("Failed to parse LDAPMessage:", err) + t.Fatal("Failed to parse LDAPMessage:", err) } if m.MessageID != 1 { - t.Error("wrong message id") + t.Fatal("wrong message id") } if m.ProtocolOp.Type != ldapserver.TypeBindRequestOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } req, err := ldapserver.GetBindRequest(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse bind request:", err) + t.Fatal("Failed to parse bind request:", err) } if req.Version != 3 { - t.Error("wrong protocol version") + t.Fatal("wrong protocol version") } if req.Name != "" { - t.Error("wrong bind DN") + t.Fatal("wrong bind DN") } if req.AuthType != ldapserver.AuthenticationTypeSASL { - t.Error("wrong auth type") + t.Fatal("wrong auth type") } cr := req.Credentials.(*ldapserver.SASLCredentials) if cr.Mechanism != "CRAM-MD5" { - t.Error("wrong mechanism") + t.Fatal("wrong mechanism") } if cr.Credentials != "" { - t.Error("wrong credentials", req.Credentials) + t.Fatal("wrong credentials", req.Credentials) } } func TestParseSASLCRAMMD5BindRequest(t *testing.T) { @@ -583,35 +607,263 @@ func TestParseSASLCRAMMD5BindRequest(t *testing.T) { } m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(bindrequest)) if err != nil { - t.Error("Failed to parse LDAPMessage:", err) + t.Fatal("Failed to parse LDAPMessage:", err) } if m.MessageID != 2 { - t.Error("wrong message id") + t.Fatal("wrong message id") } if m.ProtocolOp.Type != ldapserver.TypeBindRequestOp { - t.Error("wrong protocol op type") + t.Fatal("wrong protocol op type") } if len(m.Controls) != 0 { - t.Error("wrong number of controls") + t.Fatal("wrong number of controls") } req, err := ldapserver.GetBindRequest(m.ProtocolOp.Data) if err != nil { - t.Error("Failed to parse bind request:", err) + t.Fatal("Failed to parse bind request:", err) } if req.Version != 3 { - t.Error("wrong protocol version") + t.Fatal("wrong protocol version") } if req.Name != "" { - t.Error("wrong bind DN") + t.Fatal("wrong bind DN") } if req.AuthType != ldapserver.AuthenticationTypeSASL { - t.Error("wrong auth type") + t.Fatal("wrong auth type") } cr := req.Credentials.(*ldapserver.SASLCredentials) if cr.Mechanism != "CRAM-MD5" { - t.Error("wrong mechanism") + t.Fatal("wrong mechanism") } if cr.Credentials != "u:jdoe d52116c87c31d9cc747600f9486d2a1d" { - t.Error("wrong credentials", req.Credentials) + t.Fatal("wrong credentials", req.Credentials) + } +} + +func TestParseModifyRequest(t *testing.T) { + modifyrequest := []byte{ + 0x30, 0x81, 0x80, + 0x02, 0x01, 0x02, + 0x66, 0x7b, + 0x04, 0x24, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x64, 0x6f, 0x65, + 0x2c, 0x6f, 0x75, 0x3d, 0x50, 0x65, 0x6f, 0x70, + 0x6c, 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2c, 0x64, 0x63, + 0x3d, 0x63, 0x6f, 0x6d, + 0x30, 0x53, + 0x30, 0x18, + 0x0a, 0x01, 0x01, + 0x30, 0x13, + 0x04, 0x09, 0x67, 0x69, 0x76, 0x65, 0x6e, 0x4e, 0x61, 0x6d, + 0x65, + 0x31, 0x06, + 0x04, 0x04, 0x4a, 0x6f, 0x68, 0x6e, + 0x30, 0x1c, + 0x0a, 0x01, 0x00, + 0x30, 0x17, + 0x04, 0x09, 0x67, 0x69, 0x76, 0x65, 0x6e, 0x4e, 0x61, 0x6d, + 0x65, + 0x31, 0x0a, + 0x04, 0x08, 0x4a, 0x6f, 0x6e, 0x61, 0x74, 0x68, 0x61, 0x6e, + 0x30, 0x19, + 0x0a, 0x01, 0x02, + 0x30, 0x14, + 0x04, 0x02, 0x63, 0x6e, + 0x31, 0x0e, + 0x04, 0x0c, 0x4a, 0x6f, 0x6e, 0x61, 0x74, 0x68, 0x61, 0x6e, + 0x20, 0x44, 0x6f, 0x65, + } + m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(modifyrequest)) + if err != nil { + t.Fatal("Failed to parse LDAPMessage:", err) + } + if m.MessageID != 2 { + t.Fatal("wrong message id") + } + if m.ProtocolOp.Type != ldapserver.TypeModifyRequestOp { + t.Fatal("wrong protocol op type") + } + if len(m.Controls) != 0 { + t.Fatal("wrong number of controls") + } + req, err := ldapserver.GetModifyRequest(m.ProtocolOp.Data) + if err != nil { + t.Fatal("Failed to parse modify request:", err) + } + if req.Object != "uid=jdoe,ou=People,dc=example,dc=com" { + t.Fatal("wrong object") + } + if len(req.Changes) != 3 { + t.Fatal("wrong number of modifications") + } + if req.Changes[0].Operation != ldapserver.ModifyDelete { + t.Fatal("wrong operation") + } + if req.Changes[0].Modification.Description != "givenName" { + t.Fatal("wrong attribute") + } + if !slicesEqual(req.Changes[0].Modification.Values, []string{"John"}) { + t.Fatal("wrong values") + } + if req.Changes[1].Operation != ldapserver.ModifyAdd { + t.Fatal("wrong operation") + } + if req.Changes[1].Modification.Description != "givenName" { + t.Fatal("wrong attribute") + } + if !slicesEqual(req.Changes[1].Modification.Values, []string{"Jonathan"}) { + t.Fatal("wrong values") + } + if req.Changes[2].Operation != ldapserver.ModifyReplace { + t.Fatal("wrong operation") + } + if req.Changes[2].Modification.Description != "cn" { + t.Fatal("wrong attribute") + } + if !slicesEqual(req.Changes[2].Modification.Values, []string{"Jonathan Doe"}) { + t.Fatal("wrong values") + } +} + +func TestParseModifyDNRenameRequest(t *testing.T) { + modifyDNRequest := []byte{ + 0x30, 0x3c, + 0x02, 0x01, 0x02, + 0x6c, 0x37, + 0x04, 0x24, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x64, 0x6f, 0x65, + 0x2c, 0x6f, 0x75, 0x3d, 0x50, 0x65, 0x6f, 0x70, + 0x6c, 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2c, 0x64, 0x63, + 0x3d, 0x63, 0x6f, 0x6d, + 0x04, 0x0c, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x6f, 0x68, 0x6e, + 0x2e, 0x64, 0x6f, 0x65, + 0x01, 0x01, 0xff, + } + m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(modifyDNRequest)) + if err != nil { + t.Fatal("Failed to parse LDAPMessage:", err) + } + if m.MessageID != 2 { + t.Fatal("wrong message id") + } + if m.ProtocolOp.Type != ldapserver.TypeModifyDNRequestOp { + t.Fatal("wrong protocol op type") + } + if len(m.Controls) != 0 { + t.Fatal("wrong number of controls") + } + req, err := ldapserver.GetModifyDNRequest(m.ProtocolOp.Data) + if err != nil { + t.Fatal("Failed to parse modify DN request:", err) + } + if req.Object != "uid=jdoe,ou=People,dc=example,dc=com" { + t.Fatal("wrong object") + } + if req.NewRDN != "uid=john.doe" { + t.Fatal("wrong new RDN") + } + if req.DeleteOldRDN != true { + t.Fatal("wrong delete old RDN") + } + if req.NewSuperior != "" { + t.Fatal("wrong new superior") + } +} + +func TestParseModifyDNMoveRequest(t *testing.T) { + moveRequest := []byte{ + 0x30, 0x5c, + 0x02, 0x01, 0x03, + 0x6c, 0x57, + 0x04, 0x28, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x6f, 0x68, 0x6e, + 0x2e, 0x64, 0x6f, 0x65, 0x2c, 0x6f, 0x75, 0x3d, + 0x50, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x2c, 0x64, + 0x63, 0x3d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x63, 0x6f, 0x6d, + 0x04, 0x0c, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x6f, 0x68, 0x6e, + 0x2e, 0x64, 0x6f, 0x65, + 0x01, 0x01, 0x00, + 0x80, 0x1a, 0x6f, 0x75, 0x3d, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x2c, 0x64, 0x63, 0x3d, 0x65, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x63, + 0x6f, 0x6d, + } + m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(moveRequest)) + if err != nil { + t.Fatal("Failed to parse LDAPMessage:", err) + } + if m.MessageID != 3 { + t.Fatal("wrong message id") + } + if m.ProtocolOp.Type != ldapserver.TypeModifyDNRequestOp { + t.Fatal("wrong protocol op type") + } + if len(m.Controls) != 0 { + t.Fatal("wrong number of controls") + } + req, err := ldapserver.GetModifyDNRequest(m.ProtocolOp.Data) + if err != nil { + t.Fatal("Failed to parse modify DN request:", err) + } + if req.Object != "uid=john.doe,ou=People,dc=example,dc=com" { + t.Fatal("wrong object") + } + if req.NewRDN != "uid=john.doe" { + t.Fatal("wrong new RDN") + } + if req.DeleteOldRDN != false { + t.Fatal("wrong delete old RDN") + } + if req.NewSuperior != "ou=Users,dc=example,dc=com" { + t.Fatal("wrong new superior") + } +} + +func TestParseModifyDNRenameAndMoveRequest(t *testing.T) { + renameAndMoveRequest := []byte{ + 0x30, 0x58, + 0x02, 0x01, 0x02, + 0x6c, 0x53, + 0x04, 0x24, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x64, 0x6f, 0x65, + 0x2c, 0x6f, 0x75, 0x3d, 0x50, 0x65, 0x6f, 0x70, + 0x6c, 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2c, 0x64, 0x63, + 0x3d, 0x63, 0x6f, 0x6d, + 0x04, 0x0c, 0x75, 0x69, 0x64, 0x3d, 0x6a, 0x6f, 0x68, 0x6e, + 0x2e, 0x64, 0x6f, 0x65, + 0x01, 0x01, 0xff, + 0x80, 0x1a, 0x6f, 0x75, 0x3d, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x2c, 0x64, 0x63, 0x3d, 0x65, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x2c, 0x64, 0x63, 0x3d, 0x63, + 0x6f, 0x6d, + } + m, err := ldapserver.ReadLDAPMessage(bytes.NewReader(renameAndMoveRequest)) + if err != nil { + t.Fatal("Failed to parse LDAPMessage:", err) + } + if m.MessageID != 2 { + t.Fatal("wrong message id") + } + if m.ProtocolOp.Type != ldapserver.TypeModifyDNRequestOp { + t.Fatal("wrong protocol op type") + } + if len(m.Controls) != 0 { + t.Fatal("wrong number of controls") + } + req, err := ldapserver.GetModifyDNRequest(m.ProtocolOp.Data) + if err != nil { + t.Fatal("Failed to parse modify DN request:", err) + } + if req.Object != "uid=jdoe,ou=People,dc=example,dc=com" { + t.Fatal("wrong object") + } + if req.NewRDN != "uid=john.doe" { + t.Fatal("wrong new RDN") + } + if req.DeleteOldRDN != true { + t.Fatal("wrong delete old RDN") + } + if req.NewSuperior != "ou=Users,dc=example,dc=com" { + t.Fatal("wrong new superior") } } diff --git a/handler.go b/handler.go index 228d021..4a283ad 100644 --- a/handler.go +++ b/handler.go @@ -16,6 +16,10 @@ type Handler interface { Bind(*Conn, *Message, *BindRequest) // Perform an Extended request Extended(*Conn, *Message, *ExtendedRequest) + // Perform a Modify request + Modify(*Conn, *Message, *ModifyRequest) + // Perform a ModifyDN request + ModifyDN(*Conn, *Message, *ModifyDNRequest) // Perform a Search request Search(*Conn, *Message, *SearchRequest) } @@ -38,6 +42,14 @@ func (*BaseHandler) Bind(conn *Conn, msg *Message, req *BindRequest) { conn.SendResult(msg.MessageID, nil, TypeBindResponseOp, UnsupportedOperation) } +func (*BaseHandler) Modify(conn *Conn, msg *Message, req *ModifyRequest) { + conn.SendResult(msg.MessageID, nil, TypeModifyResponseOp, UnsupportedOperation) +} + +func (*BaseHandler) ModifyDN(conn *Conn, msg *Message, req *ModifyDNRequest) { + conn.SendResult(msg.MessageID, nil, TypeModifyDNResponseOp, UnsupportedOperation) +} + func (*BaseHandler) Search(conn *Conn, msg *Message, req *SearchRequest) { conn.SendResult(msg.MessageID, nil, TypeSearchResultDoneOp, UnsupportedOperation) } diff --git a/modify.go b/modify.go new file mode 100644 index 0000000..2bf9027 --- /dev/null +++ b/modify.go @@ -0,0 +1,80 @@ +package ldapserver + +// ModifyRequest ::= [APPLICATION 6] SEQUENCE { +// object LDAPDN, +// changes SEQUENCE OF change SEQUENCE { +// operation ENUMERATED { +// add (0), +// delete (1), +// replace (2) }, +// modification Attribute } +type ModifyRequest struct { + Object string + Changes []ModifyChange +} + +type ModifyChange struct { + Operation ModifyOperation + Modification Attribute +} + +type ModifyOperation uint8 + +// Defined operations +const ( + ModifyAdd ModifyOperation = 0 + ModifyDelete ModifyOperation = 1 + ModifyReplace ModifyOperation = 2 + // extensible, more possible +) + +// Return a ModifyRequest from BER-encoded data +func GetModifyRequest(data []byte) (*ModifyRequest, error) { + seq, err := BerGetSequence(data) + if err != nil { + return nil, err + } + if len(seq) != 2 { + return nil, ErrWrongSequenceLength.WithInfo("ModifyRequest sequence length", len(seq)) + } + if seq[0].Type != BerTypeOctetString { + return nil, ErrWrongElementType.WithInfo("ModifyRequest object type", seq[0].Type) + } + object := BerGetOctetString(seq[0].Data) + if seq[1].Type != BerTypeSequence { + return nil, ErrWrongElementType.WithInfo("ModifyRequest changes type", seq[1].Type) + } + ch_seq, err := BerGetSequence(seq[1].Data) + if err != nil { + return nil, err + } + var changes []ModifyChange + for _, c := range ch_seq { + if c.Type != BerTypeSequence { + return nil, ErrWrongElementType.WithInfo("ModifyRequest change type", c.Type) + } + c_seq, err := BerGetSequence(c.Data) + if err != nil { + return nil, err + } + if len(c_seq) != 2 { + return nil, ErrWrongSequenceLength.WithInfo("ModifyRequest change sequence length", len(c_seq)) + } + if c_seq[0].Type != BerTypeEnumerated { + return nil, ErrWrongElementType.WithInfo("ModifyRequest change operation type", c_seq[0].Type) + } + op, err := BerGetEnumerated(c_seq[0].Data) + if err != nil { + return nil, err + } + if c_seq[1].Type != BerTypeSequence { + return nil, ErrWrongElementType.WithInfo("ModifyRequest change modification type", c_seq[1].Type) + } + attr, err := GetAttribute(c_seq[1].Data) + if err != nil { + return nil, err + } + changes = append(changes, ModifyChange{Operation: ModifyOperation(op), Modification: attr}) + } + return &ModifyRequest{Object: object, Changes: changes}, nil +} diff --git a/modifyDN.go b/modifyDN.go new file mode 100644 index 0000000..5e5e836 --- /dev/null +++ b/modifyDN.go @@ -0,0 +1,47 @@ +package ldapserver + +// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE { +// entry LDAPDN, +// newrdn RelativeLDAPDN, +// deleteoldrdn BOOLEAN, +// newSuperior [0] LDAPDN OPTIONAL } +type ModifyDNRequest struct { + Object string + NewRDN string + DeleteOldRDN bool + NewSuperior string +} + +// Return a ModifyDNRequest from BER-encoded data +func GetModifyDNRequest(data []byte) (*ModifyDNRequest, error) { + seq, err := BerGetSequence(data) + if err != nil { + return nil, err + } + if len(seq) != 3 && len(seq) != 4 { + return nil, ErrWrongSequenceLength.WithInfo("ModifyDNRequest sequence length", len(seq)) + } + if seq[0].Type != BerTypeOctetString { + return nil, ErrWrongElementType.WithInfo("ModifyDNRequest entry type", seq[0].Type) + } + entry := BerGetOctetString(seq[0].Data) + if seq[1].Type != BerTypeOctetString { + return nil, ErrWrongElementType.WithInfo("ModifyDNRequest new RDN type", seq[1].Type) + } + newRDN := BerGetOctetString(seq[1].Data) + if seq[2].Type != BerTypeBoolean { + return nil, ErrWrongElementType.WithInfo("ModifyDNRequest delete old RDN type", seq[2].Type) + } + deleteOldRDN, err := BerGetBoolean(seq[2].Data) + if err != nil { + return nil, err + } + newSuperior := "" + if len(seq) == 4 { + if seq[3].Type != BerContextSpecificType(0, false) { + return nil, ErrWrongElementType.WithInfo("ModifyDNRequest new superior type", seq[3].Type) + } + newSuperior = BerGetOctetString(seq[3].Data) + } + return &ModifyDNRequest{entry, newRDN, deleteOldRDN, newSuperior}, nil +} diff --git a/server.go b/server.go index 55e4cee..8272ca8 100644 --- a/server.go +++ b/server.go @@ -190,10 +190,22 @@ func (s *LDAPServer) handleMessage(conn *Conn, msg *Message) { s.Handler.Extended(conn, msg, req) case TypeModifyDNRequestOp: log.Println("ModifyDN request") - conn.SendResult(msg.MessageID, nil, TypeModifyDNResponseOp, UnsupportedOperation) + req, err := GetModifyDNRequest(msg.ProtocolOp.Data) + if err != nil { + log.Println("Error parsing ModifyDN request:", err) + conn.SendResult(msg.MessageID, nil, TypeModifyDNResponseOp, ProtocolError) + return + } + s.Handler.ModifyDN(conn, msg, req) case TypeModifyRequestOp: log.Println("Modify request") - conn.SendResult(msg.MessageID, nil, TypeModifyResponseOp, UnsupportedOperation) + req, err := GetModifyRequest(msg.ProtocolOp.Data) + if err != nil { + log.Println("Error parsing Modify request:", err) + conn.SendResult(msg.MessageID, nil, TypeModifyResponseOp, ProtocolError) + return + } + s.Handler.Modify(conn, msg, req) case TypeSearchRequestOp: log.Println("Search request") req, err := GetSearchRequest(msg.ProtocolOp.Data) diff --git a/test/main.go b/test/main.go index df3e1d7..56be66f 100644 --- a/test/main.go +++ b/test/main.go @@ -143,3 +143,27 @@ func (t *TestHandler) Search(conn *ldapserver.Conn, msg *ldapserver.Message, req } conn.SendResult(msg.MessageID, nil, ldapserver.TypeSearchResultDoneOp, res) } + +func (t *TestHandler) Modify(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.ModifyRequest) { + log.Println("Modify DN:", req.Object) + for _, change := range req.Changes { + log.Println(" Operation:", change.Operation) + log.Println(" Modification attribute:", change.Modification.Description) + log.Println(" Values:", change.Modification.Values) + } + res := &ldapserver.Result{ + ResultCode: ldapserver.ResultSuccess, + } + conn.SendResult(msg.MessageID, nil, ldapserver.TypeModifyResponseOp, res) +} + +func (t *TestHandler) ModifyDN(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.ModifyDNRequest) { + log.Println("Modify DN:", req.Object) + log.Println(" New RDN:", req.NewRDN) + log.Println(" Delete old RDN:", req.DeleteOldRDN) + log.Println(" New superior:", req.NewSuperior) + res := &ldapserver.Result{ + ResultCode: ldapserver.ResultSuccess, + } + conn.SendResult(msg.MessageID, nil, ldapserver.TypeModifyDNResponseOp, res) +}