diff options
Diffstat (limited to '')
-rw-r--r-- | README.md | 16 | ||||
-rw-r--r-- | control.go | 62 | ||||
-rw-r--r-- | examples/server.go | 6 | ||||
-rw-r--r-- | filter.go | 39 | ||||
-rw-r--r-- | filter_test.go | 24 | ||||
-rw-r--r-- | ldap.go | 44 | ||||
-rw-r--r-- | modify.go | 6 | ||||
-rw-r--r-- | server.go | 544 | ||||
-rw-r--r-- | server_bind.go | 73 | ||||
-rw-r--r-- | server_modify.go | 231 | ||||
-rw-r--r-- | server_modify_test.go | 191 | ||||
-rw-r--r-- | server_search.go | 216 | ||||
-rw-r--r-- | server_search_test.go | 403 | ||||
-rw-r--r-- | server_test.go | 376 | ||||
-rw-r--r-- | tests/add.ldif | 6 | ||||
-rw-r--r-- | tests/add2.ldif | 6 | ||||
-rw-r--r-- | tests/cert_DONOTUSE.pem (renamed from examples/cert_DONOTUSE.pem) | 0 | ||||
-rw-r--r-- | tests/key_DONOTUSE.pem (renamed from examples/key_DONOTUSE.pem) | 0 | ||||
-rw-r--r-- | tests/modify.ldif | 16 | ||||
-rw-r--r-- | tests/modify2.ldif | 10 |
20 files changed, 1515 insertions, 754 deletions
@@ -54,7 +54,7 @@ searchResults, err := l.Search(search) The server library is modeled after net/http - you designate handlers for the LDAP operations you want to support (Bind/Search/etc.), then start the server with ListenAndServe(). You can specify different handlers for different baseDNs - they must implement the interfaces of the operations you want to support: ```go type Binder interface { - Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) + Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) } type Searcher interface { Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) @@ -76,7 +76,7 @@ func main() { } type ldapHandler struct { } -func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) { if bindDN == "" && bindSimplePw == "" { return ldap.LDAPResultSuccess, nil } @@ -89,25 +89,17 @@ func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, e ### LDAP server examples: * examples/server.go: **Basic LDAP authentication (bind and search only)** * examples/proxy.go: **Simple LDAP proxy server.** -* server_test: **The tests have examples of all server functions.** - -*Warning: Do not use the example SSL certificates in production!* +* server_test.go: **The _test.go files have examples of all server functions.** ### Known limitations: * Golang's TLS implementation does not support SSLv2. Some old OSs require SSLv2, and are not able to connect to an LDAP server created with this library's ListenAndServeTLS() function. If you *must* support legacy (read: *insecure*) SSLv2 clients, run your LDAP server behind HAProxy. ### Not implemented: -All of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**: -* 4.1.11. Controls +From the server perspective, all of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**: * 4.5.1.3. SearchRequest.derefAliases * 4.5.1.5. SearchRequest.timeLimit * 4.5.1.6. SearchRequest.typesOnly -* 4.6. Modify Operation -* 4.7. Add Operation -* 4.8. Delete Operation -* 4.9. Modify DN Operation -* 4.10. Compare Operation * 4.14. StartTLS Operation *Server library by: [nmcclain](https://github.com/nmcclain)* @@ -6,7 +6,6 @@ package ldap import ( "fmt" - "github.com/nmcclain/asn1-ber" ) @@ -99,40 +98,41 @@ func FindControl(controls []Control, controlType string) Control { func DecodeControl(packet *ber.Packet) Control { ControlType := packet.Children[0].Value.(string) - Criticality := false - packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" - value := packet.Children[1] - if len(packet.Children) == 3 { - value = packet.Children[2] - packet.Children[1].Description = "Criticality" - Criticality = packet.Children[1].Value.(bool) - } + c := new(ControlString) + c.ControlType = ControlType + c.Criticality = false + + if len(packet.Children) > 1 { + value := packet.Children[1] + if len(packet.Children) == 3 { + value = packet.Children[2] + packet.Children[1].Description = "Criticality" + c.Criticality = packet.Children[1].Value.(bool) + } - value.Description = "Control Value" - switch ControlType { - case ControlTypePaging: - value.Description += " (Paging)" - c := new(ControlPaging) - if value.Value != nil { - valueChildren := ber.DecodePacket(value.Data.Bytes()) - value.Data.Truncate(0) - value.Value = nil - value.AppendChild(valueChildren) + value.Description = "Control Value" + switch ControlType { + case ControlTypePaging: + value.Description += " (Paging)" + c := new(ControlPaging) + if value.Value != nil { + valueChildren := ber.DecodePacket(value.Data.Bytes()) + value.Data.Truncate(0) + value.Value = nil + value.AppendChild(valueChildren) + } + value = value.Children[0] + value.Description = "Search Control Value" + value.Children[0].Description = "Paging Size" + value.Children[1].Description = "Cookie" + c.PagingSize = uint32(value.Children[0].Value.(uint64)) + c.Cookie = value.Children[1].Data.Bytes() + value.Children[1].Value = c.Cookie + return c } - value = value.Children[0] - value.Description = "Search Control Value" - value.Children[0].Description = "Paging Size" - value.Children[1].Description = "Cookie" - c.PagingSize = uint32(value.Children[0].Value.(uint64)) - c.Cookie = value.Children[1].Data.Bytes() - value.Children[1].Value = c.Cookie - return c + c.ControlValue = value.Value.(string) } - c := new(ControlString) - c.ControlType = ControlType - c.Criticality = Criticality - c.ControlValue = value.Value.(string) return c } diff --git a/examples/server.go b/examples/server.go index dca74ed..3341991 100644 --- a/examples/server.go +++ b/examples/server.go @@ -24,7 +24,9 @@ func main() { s.SearchFunc("", handler) // start the server - if err := s.ListenAndServe("localhost:3389"); err != nil { + listen := "localhost:3389" + log.Printf("Starting example LDAP server on %s", listen) + if err := s.ListenAndServe(listen); err != nil { log.Fatal("LDAP Server Failed: %s", err.Error()) } } @@ -33,7 +35,7 @@ type ldapHandler struct { } ///////////// Allow anonymous binds only -func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) { if bindDN == "" && bindSimplePw == "" { return ldap.LDAPResultSuccess, nil } @@ -246,9 +246,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { } } -func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) { - //log.Printf("%# v", pretty.Formatter(entry)) - +func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, LDAPResultCode) { switch FilterMap[f.Tag] { default: //log.Fatalf("Unknown LDAP filter code: %d", f.Tag) @@ -308,30 +306,30 @@ func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) { } else if !ok { return true, LDAPResultSuccess } - case "FilterSubstrings": + case "FilterSubstrings": // TODO return false, LDAPResultOperationsError - case "FilterGreaterOrEqual": + case "FilterGreaterOrEqual": // TODO return false, LDAPResultOperationsError - case "FilterLessOrEqual": + case "FilterLessOrEqual": // TODO return false, LDAPResultOperationsError - case "FilterApproxMatch": + case "FilterApproxMatch": // TODO return false, LDAPResultOperationsError - case "FilterExtensibleMatch": + case "FilterExtensibleMatch": // TODO return false, LDAPResultOperationsError } return false, LDAPResultSuccess } -func GetFilterType(filter string) (string, error) { // TODO <- test this +func GetFilterObjectClass(filter string) (string, error) { f, err := CompileFilter(filter) if err != nil { return "", err } - return parseFilterType(f) + return parseFilterObjectClass(f) } -func parseFilterType(f *ber.Packet) (string, error) { - searchType := "" +func parseFilterObjectClass(f *ber.Packet) (string, error) { + objectClass := "" switch FilterMap[f.Tag] { case "Equality Match": if len(f.Children) != 2 { @@ -339,42 +337,41 @@ func parseFilterType(f *ber.Packet) (string, error) { } attribute := strings.ToLower(f.Children[0].Value.(string)) value := f.Children[1].Value.(string) - if attribute == "objectclass" { - searchType = strings.ToLower(value) + objectClass = strings.ToLower(value) } case "And": for _, child := range f.Children { - subType, err := parseFilterType(child) + subType, err := parseFilterObjectClass(child) if err != nil { return "", err } if len(subType) > 0 { - searchType = subType + objectClass = subType } } case "Or": for _, child := range f.Children { - subType, err := parseFilterType(child) + subType, err := parseFilterObjectClass(child) if err != nil { return "", err } if len(subType) > 0 { - searchType = subType + objectClass = subType } } case "Not": if len(f.Children) != 1 { return "", errors.New("Not filter must have only one child") } - subType, err := parseFilterType(f.Children[0]) + subType, err := parseFilterObjectClass(f.Children[0]) if err != nil { return "", err } if len(subType) > 0 { - searchType = subType + objectClass = subType } } - return strings.ToLower(searchType), nil + return strings.ToLower(objectClass), nil } diff --git a/filter_test.go b/filter_test.go index fb54905..2e62f25 100644 --- a/filter_test.go +++ b/filter_test.go @@ -111,3 +111,27 @@ func BenchmarkFilterDecompile(b *testing.B) { DecompileFilter(filters[i%maxIdx]) } } + +func TestGetFilterObjectClass(t *testing.T) { + c, err := GetFilterObjectClass("(objectClass=*)") + if err != nil { + t.Errorf("GetFilterObjectClass failed") + } + if c != "" { + t.Errorf("GetFilterObjectClass failed") + } + c, err = GetFilterObjectClass("(objectClass=posixAccount)") + if err != nil { + t.Errorf("GetFilterObjectClass failed") + } + if c != "posixaccount" { + t.Errorf("GetFilterObjectClass failed") + } + c, err = GetFilterObjectClass("(&(cn=awesome)(objectClass=posixGroup))") + if err != nil { + t.Errorf("GetFilterObjectClass failed") + } + if c != "posixgroup" { + t.Errorf("GetFilterObjectClass failed") + } +} @@ -107,7 +107,7 @@ const ( ErrorDebugging = 203 ) -var LDAPResultCodeMap = map[uint8]string{ +var LDAPResultCodeMap = map[LDAPResultCode]string{ LDAPResultSuccess: "Success", LDAPResultOperationsError: "Operations Error", LDAPResultProtocolError: "Protocol Error", @@ -155,6 +155,38 @@ const ( LDAPBindAuthSASL = 3 ) +type LDAPResultCode uint8 + +type Attribute struct { + attrType string + attrVals []string +} +type AddRequest struct { + dn string + attributes []Attribute +} +type DeleteRequest struct { + dn string +} +type ModifyDNRequest struct { + dn string + newrdn string + deleteoldrdn bool + newSuperior string +} +type AttributeValueAssertion struct { + attributeDesc string + assertionValue string +} +type CompareRequest struct { + dn string + ava []AttributeValueAssertion +} +type ExtendedRequest struct { + requestName string + requestValue string +} + // Adds descriptions to an LDAP Response packet for debugging func addLDAPDescriptions(packet *ber.Packet) (err error) { defer func() { @@ -259,7 +291,7 @@ func addRequestDescriptions(packet *ber.Packet) { func addDefaultLDAPResponseDescriptions(packet *ber.Packet) { resultCode := packet.Children[1].Children[0].Value.(uint64) - packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[uint8(resultCode)] + ")" + packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[LDAPResultCode(resultCode)] + ")" packet.Children[1].Children[1].Description = "Matched DN" packet.Children[1].Children[2].Description = "Error Message" if len(packet.Children[1].Children) > 3 { @@ -285,22 +317,22 @@ func DebugBinaryFile(fileName string) error { type Error struct { Err error - ResultCode uint8 + ResultCode LDAPResultCode } func (e *Error) Error() string { return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) } -func NewError(resultCode uint8, err error) error { +func NewError(resultCode LDAPResultCode, err error) error { return &Error{ResultCode: resultCode, Err: err} } -func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { +func getLDAPResultCode(packet *ber.Packet) (code LDAPResultCode, description string) { if len(packet.Children) >= 2 { response := packet.Children[1] if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 { - return uint8(response.Children[0].Value.(uint64)), response.Children[2].Value.(string) + return LDAPResultCode(response.Children[0].Value.(uint64)), response.Children[2].Value.(string) } } @@ -42,6 +42,12 @@ const ( ReplaceAttribute = 2 ) +var LDAPModifyAttributeMap = map[uint64]string{ + AddAttribute: "Add", + DeleteAttribute: "Delete", + ReplaceAttribute: "Replace", +} + type PartialAttribute struct { attrType string attrVals []string @@ -2,8 +2,6 @@ package ldap import ( "crypto/tls" - "errors" - "fmt" "github.com/nmcclain/asn1-ber" "io" "log" @@ -13,23 +11,55 @@ import ( ) type Binder interface { - Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) + Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) } type Searcher interface { - Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) + Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) +} +type Adder interface { + Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) +} +type Modifier interface { + Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) +} +type Deleter interface { + Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) +} +type ModifyDNr interface { + ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) +} +type Comparer interface { + Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) +} +type Abandoner interface { + Abandon(boundDN string, conn net.Conn) error +} +type Extender interface { + Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) +} +type Unbinder interface { + Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) } type Closer interface { - Close(conn net.Conn) error + Close(boundDN string, conn net.Conn) error } -///////////////////////// +// type Server struct { - bindFns map[string]Binder - searchFns map[string]Searcher - closeFns map[string]Closer - quit chan bool + BindFns map[string]Binder + SearchFns map[string]Searcher + AddFns map[string]Adder + ModifyFns map[string]Modifier + DeleteFns map[string]Deleter + ModifyDNFns map[string]ModifyDNr + CompareFns map[string]Comparer + AbandonFns map[string]Abandoner + ExtendedFns map[string]Extender + UnbindFns map[string]Unbinder + CloseFns map[string]Closer + Quit chan bool EnforceLDAP bool - stats *Stats + Stats *Stats } type Stats struct { @@ -44,35 +74,75 @@ type ServerSearchResult struct { Entries []*Entry Referrals []string Controls []Control - ResultCode uint64 + ResultCode LDAPResultCode } -///////////////////////// +// func NewServer() *Server { s := new(Server) - s.quit = make(chan bool) + s.Quit = make(chan bool) d := defaultHandler{} - s.bindFns = make(map[string]Binder) - s.searchFns = make(map[string]Searcher) - s.closeFns = make(map[string]Closer) - s.bindFns[""] = d - s.searchFns[""] = d - s.closeFns[""] = d - s.stats = nil + s.BindFns = make(map[string]Binder) + s.SearchFns = make(map[string]Searcher) + s.AddFns = make(map[string]Adder) + s.ModifyFns = make(map[string]Modifier) + s.DeleteFns = make(map[string]Deleter) + s.ModifyDNFns = make(map[string]ModifyDNr) + s.CompareFns = make(map[string]Comparer) + s.AbandonFns = make(map[string]Abandoner) + s.ExtendedFns = make(map[string]Extender) + s.UnbindFns = make(map[string]Unbinder) + s.CloseFns = make(map[string]Closer) + s.BindFunc("", d) + s.SearchFunc("", d) + s.AddFunc("", d) + s.ModifyFunc("", d) + s.DeleteFunc("", d) + s.ModifyDNFunc("", d) + s.CompareFunc("", d) + s.AbandonFunc("", d) + s.ExtendedFunc("", d) + s.UnbindFunc("", d) + s.CloseFunc("", d) + s.Stats = nil return s } -func (server *Server) BindFunc(baseDN string, bindFn Binder) { - server.bindFns[baseDN] = bindFn +func (server *Server) BindFunc(baseDN string, f Binder) { + server.BindFns[baseDN] = f +} +func (server *Server) SearchFunc(baseDN string, f Searcher) { + server.SearchFns[baseDN] = f +} +func (server *Server) AddFunc(baseDN string, f Adder) { + server.AddFns[baseDN] = f +} +func (server *Server) ModifyFunc(baseDN string, f Modifier) { + server.ModifyFns[baseDN] = f } -func (server *Server) SearchFunc(baseDN string, searchFn Searcher) { - server.searchFns[baseDN] = searchFn +func (server *Server) DeleteFunc(baseDN string, f Deleter) { + server.DeleteFns[baseDN] = f } -func (server *Server) CloseFunc(baseDN string, closeFn Closer) { - server.closeFns[baseDN] = closeFn +func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) { + server.ModifyDNFns[baseDN] = f +} +func (server *Server) CompareFunc(baseDN string, f Comparer) { + server.CompareFns[baseDN] = f +} +func (server *Server) AbandonFunc(baseDN string, f Abandoner) { + server.AbandonFns[baseDN] = f +} +func (server *Server) ExtendedFunc(baseDN string, f Extender) { + server.ExtendedFns[baseDN] = f +} +func (server *Server) UnbindFunc(baseDN string, f Unbinder) { + server.UnbindFns[baseDN] = f +} +func (server *Server) CloseFunc(baseDN string, f Closer) { + server.CloseFns[baseDN] = f } func (server *Server) QuitChannel(quit chan bool) { - server.quit = quit + server.Quit = quit } func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error { @@ -95,18 +165,18 @@ func (server *Server) ListenAndServeTLS(listenString string, certFile string, ke func (server *Server) SetStats(enable bool) { if enable { - server.stats = &Stats{} + server.Stats = &Stats{} } else { - server.stats = nil + server.Stats = nil } } func (server *Server) GetStats() Stats { defer func() { - server.stats.statsMutex.Unlock() + server.Stats.statsMutex.Unlock() }() - server.stats.statsMutex.Lock() - return *server.stats + server.Stats.statsMutex.Lock() + return *server.Stats } func (server *Server) ListenAndServe(listenString string) error { @@ -140,9 +210,9 @@ listener: for { select { case c := <-newConn: - server.stats.countConns(1) + server.Stats.countConns(1) go server.handleConnection(c) - case <-server.quit: + case <-server.Quit: ln.Close() break listener } @@ -150,8 +220,7 @@ listener: return nil } -///////////////////////// - +// func (server *Server) handleConnection(conn net.Conn) { boundDN := "" // "" == anonymous @@ -172,40 +241,46 @@ handler: break } // check the message ID and ClassType - messageID := packet.Children[0].Value.(uint64) + messageID, ok := packet.Children[0].Value.(uint64) + if !ok { + log.Print("malformed messageID") + break + } req := packet.Children[1] if req.ClassType != ber.ClassApplication { log.Print("req.ClassType != ber.ClassApplication") break } // handle controls if present + controls := []Control{} if len(packet.Children) > 2 { - controls := packet.Children[2] - ber.PrintPacket(controls) - log.Print("TODO Parse Controls") - /* - Controls ::= SEQUENCE OF control Control - - Control ::= SEQUENCE { - controlType LDAPOID, - criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension - controlValue OCTET STRING OPTIONAL } - */ + for _, child := range packet.Children[2].Children { + controls = append(controls, DecodeControl(child)) + } } + //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + //ber.PrintPacket(packet) // DEBUG + // dispatch the LDAP operation switch req.Tag { // ldap op code default: - //log.Printf("Bound as %s", boundDN) - //ber.PrintPacket(packet) + responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add") + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) break handler case ApplicationBindRequest: - server.stats.countBinds(1) - ldapResultCode := server.handleBindRequest(req, server.bindFns, conn) + server.Stats.countBinds(1) + ldapResultCode := HandleBindRequest(req, server.BindFns, conn) if ldapResultCode == LDAPResultSuccess { - boundDN = req.Children[1].Value.(string) + boundDN, ok = req.Children[1].Value.(string) + if !ok { + log.Printf("Malformed Bind DN") + break handler + } } responsePacket := encodeBindResponse(messageID, ldapResultCode) if err = sendPacket(conn, responsePacket); err != nil { @@ -213,12 +288,13 @@ handler: break handler } case ApplicationSearchRequest: - server.stats.countSearches(1) - if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil { + server.Stats.countSearches(1) + if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil { log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks? e := err.(*Error) - if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil { + if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil { log.Printf("sendPacket error %s", err.Error()) + break handler } break handler } else { @@ -228,181 +304,65 @@ handler: } } case ApplicationUnbindRequest: - server.stats.countUnbinds(1) - break handler // simply disconnect - this IS implemented + server.Stats.countUnbinds(1) + break handler // simply disconnect case ApplicationExtendedRequest: - responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request") + ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) if err = sendPacket(conn, responsePacket); err != nil { log.Printf("sendPacket error %s", err.Error()) + break handler } - break handler case ApplicationAbandonRequest: - log.Printf("Abandoning request!") + HandleAbandonRequest(req, boundDN, server.AbandonFns, conn) break handler - // Unimplemented LDAP operations: - case ApplicationModifyRequest: - log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) - break handler case ApplicationAddRequest: - log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) - break handler - case ApplicationDelRequest: - log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) - break handler - case ApplicationModifyDNRequest: - log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) - break handler - case ApplicationCompareRequest: - log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) - break handler - } - } - - for _, c := range server.closeFns { - c.Close(conn) - } - - conn.Close() -} - -///////////////////////// -func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) { - defer func() { - if r := recover(); r != nil { - resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) - } - }() - - searchReq, err := parseSearchRequest(boundDN, req) - if err != nil { - return NewError(LDAPResultOperationsError, err) - } - - filterPacket, err := CompileFilter(searchReq.Filter) - if err != nil { - return NewError(LDAPResultOperationsError, err) - } - - fnNames := []string{} - for k := range searchFns { - fnNames = append(fnNames, k) - } - searchFn := routeFunc(searchReq.BaseDN, fnNames) - searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn) - if err != nil { - return NewError(uint8(searchResp.ResultCode), err) - } - - if server.EnforceLDAP { - if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} - // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases - } - if len(searchReq.Controls) > 0 { - return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO - } - if searchReq.TimeLimit > 0 { - return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO - } - } - - for i, entry := range searchResp.Entries { - if server.EnforceLDAP { - // size limit - if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { - break + ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler } - - // filter - keep, resultCode := ServerApplyFilter(filterPacket, entry) - if resultCode != LDAPResultSuccess { - return NewError(uint8(resultCode), errors.New("ServerApplyFilter error")) + case ApplicationModifyRequest: + ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler } - if !keep { - continue + case ApplicationDelRequest: + ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler } - - // constrained search scope - switch searchReq.Scope { - case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. - case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. - if entry.DN != searchReq.BaseDN { - continue - } - case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. - parts := strings.Split(entry.DN, ",") - if len(parts) < 2 && entry.DN != searchReq.BaseDN { - continue - } - if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { - continue - } + case ApplicationModifyDNRequest: + ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler } - - // attributes - if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { - entry, err = filterAttributes(entry, searchReq.Attributes) - if err != nil { - return NewError(LDAPResultOperationsError, err) - } + case ApplicationCompareRequest: + ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn) + responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler } } - - // respond - responsePacket := encodeSearchResponse(messageID, searchReq, entry) - if err = sendPacket(conn, responsePacket); err != nil { - return NewError(LDAPResultOperationsError, err) - } } - return nil -} -///////////////////////// -func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) { - defer func() { - if r := recover(); r != nil { - resultCode = LDAPResultOperationsError - } - }() - - // we only support ldapv3 - ldapVersion := req.Children[0].Value.(uint64) - if ldapVersion != 3 { - log.Printf("Unsupported LDAP version: %d", ldapVersion) - return LDAPResultInappropriateAuthentication + for _, c := range server.CloseFns { + c.Close(boundDN, conn) } - // auth types - bindDN := req.Children[1].Value.(string) - bindAuth := req.Children[2] - switch bindAuth.Tag { - default: - log.Print("Unknown LDAP authentication method") - return LDAPResultInappropriateAuthentication - case LDAPBindAuthSimple: - if len(req.Children) == 3 { - fnNames := []string{} - for k := range bindFns { - fnNames = append(fnNames, k) - } - bindFn := routeFunc(bindDN, fnNames) - resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn) - if err != nil { - log.Printf("BindFn Error %s", err.Error()) - } - return resultCode - } else { - log.Print("Simple bind request has wrong # children. len(req.Children) != 3") - return LDAPResultInappropriateAuthentication - } - case LDAPBindAuthSASL: - log.Print("SASL authentication is not supported") - return LDAPResultInappropriateAuthentication - } - return LDAPResultOperationsError + conn.Close() } -///////////////////////// +// func sendPacket(conn net.Conn, packet *ber.Packet) error { _, err := conn.Write(packet.Bytes()) if err != nil { @@ -412,38 +372,7 @@ func sendPacket(conn net.Conn, packet *ber.Packet) error { return nil } -///////////////////////// -func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) { - if len(req.Children) != 8 { - return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) - } - - // Parse the request - baseObject := req.Children[0].Value.(string) - scope := int(req.Children[1].Value.(uint64)) - derefAliases := int(req.Children[2].Value.(uint64)) - sizeLimit := int(req.Children[3].Value.(uint64)) - timeLimit := int(req.Children[4].Value.(uint64)) - typesOnly := false - if req.Children[5].Value != nil { - typesOnly = req.Children[5].Value.(bool) - } - filter, err := DecompileFilter(req.Children[6]) - if err != nil { - return SearchRequest{}, err - } - attributes := []string{} - for _, attr := range req.Children[7].Children { - attributes = append(attributes, attr.Value.(string)) - } - searchReq := SearchRequest{baseObject, scope, - derefAliases, sizeLimit, timeLimit, - typesOnly, filter, attributes, nil} - - return searchReq, nil -} - -///////////////////////// +// func routeFunc(dn string, funcNames []string) string { bestPick := "" for _, fn := range funcNames { @@ -460,109 +389,58 @@ func routeFunc(dn string, funcNames []string) string { return bestPick } -///////////////////////// -func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { - // only return requested attributes - newAttributes := []*EntryAttribute{} - - for _, attr := range entry.Attributes { - for _, requested := range attributes { - if strings.ToLower(attr.Name) == strings.ToLower(requested) { - newAttributes = append(newAttributes, attr) - } - } - } - entry.Attributes = newAttributes - - return entry, nil +// +func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) + reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) + responsePacket.AppendChild(reponse) + return responsePacket } -///////////////////////// +// type defaultHandler struct { } -func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { - return LDAPResultInappropriateAuthentication, nil +func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInvalidCredentials, nil } -func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { +func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) { return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil } -func (h defaultHandler) Close(conn net.Conn) error { - conn.Close() - return nil +func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil } - -///////////////////////// -func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet { - responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") - responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) - - bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") - bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) - bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) - bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) - - responsePacket.AppendChild(bindReponse) - - // ber.PrintPacket(responsePacket) - return responsePacket +func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil } -func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { - responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") - responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) - - searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") - searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) - - attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") - for _, attribute := range res.Attributes { - attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) - } - - searchEntry.AppendChild(attrs) - responsePacket.AppendChild(searchEntry) - - return responsePacket +func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil } - -func encodeSearchAttribute(name string, values []string) *ber.Packet { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") - packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) - - valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") - for _, value := range values { - valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) - } - - packet.AppendChild(valuesPacket) - - return packet +func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil } - -func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet { - responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") - responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) - donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") - donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) - donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) - donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) - responsePacket.AppendChild(donePacket) - - return responsePacket +func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil } - -func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet { - responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") - responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) - reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) - reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) - reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) - reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) - responsePacket.AppendChild(reponse) - return responsePacket +func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error { + return nil +} +func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultProtocolError, nil +} +func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultSuccess, nil +} +func (h defaultHandler) Close(boundDN string, conn net.Conn) error { + conn.Close() + return nil } -///////////////////////// +// func (stats *Stats) countConns(delta int) { if stats != nil { stats.statsMutex.Lock() @@ -592,4 +470,4 @@ func (stats *Stats) countSearches(delta int) { } } -///////////////////////// +// diff --git a/server_bind.go b/server_bind.go new file mode 100644 index 0000000..5a80bf5 --- /dev/null +++ b/server_bind.go @@ -0,0 +1,73 @@ +package ldap + +import ( + "github.com/nmcclain/asn1-ber" + "log" + "net" +) + +func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) { + defer func() { + if r := recover(); r != nil { + resultCode = LDAPResultOperationsError + } + }() + + // we only support ldapv3 + ldapVersion, ok := req.Children[0].Value.(uint64) + if !ok { + return LDAPResultProtocolError + } + if ldapVersion != 3 { + log.Printf("Unsupported LDAP version: %d", ldapVersion) + return LDAPResultInappropriateAuthentication + } + + // auth types + bindDN, ok := req.Children[1].Value.(string) + if !ok { + return LDAPResultProtocolError + } + bindAuth := req.Children[2] + switch bindAuth.Tag { + default: + log.Print("Unknown LDAP authentication method") + return LDAPResultInappropriateAuthentication + case LDAPBindAuthSimple: + if len(req.Children) == 3 { + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(bindDN, fnNames) + resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn) + if err != nil { + log.Printf("BindFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode + } else { + log.Print("Simple bind request has wrong # children. len(req.Children) != 3") + return LDAPResultInappropriateAuthentication + } + case LDAPBindAuthSASL: + log.Print("SASL authentication is not supported") + return LDAPResultInappropriateAuthentication + } + return LDAPResultOperationsError +} + +func encodeBindResponse(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + + responsePacket.AppendChild(bindReponse) + + // ber.PrintPacket(responsePacket) + return responsePacket +} diff --git a/server_modify.go b/server_modify.go new file mode 100644 index 0000000..0dca219 --- /dev/null +++ b/server_modify.go @@ -0,0 +1,231 @@ +package ldap + +import ( + "github.com/nmcclain/asn1-ber" + "log" + "net" +) + +func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) { + if len(req.Children) != 2 { + return LDAPResultProtocolError + } + var ok bool + addReq := AddRequest{} + addReq.dn, ok = req.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + addReq.attributes = []Attribute{} + for _, attr := range req.Children[1].Children { + if len(attr.Children) != 2 { + return LDAPResultProtocolError + } + + a := Attribute{} + a.attrType, ok = attr.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + a.attrVals = []string{} + for _, val := range attr.Children[1].Children { + v, ok := val.Value.(string) + if !ok { + return LDAPResultProtocolError + } + a.attrVals = append(a.attrVals, v) + } + addReq.attributes = append(addReq.attributes, a) + } + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].Add(boundDN, addReq, conn) + if err != nil { + log.Printf("AddFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} + +func HandleDeleteRequest(req *ber.Packet, boundDN string, fns map[string]Deleter, conn net.Conn) (resultCode LDAPResultCode) { + deleteDN := ber.DecodeString(req.Data.Bytes()) + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].Delete(boundDN, deleteDN, conn) + if err != nil { + log.Printf("DeleteFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} + +func HandleModifyRequest(req *ber.Packet, boundDN string, fns map[string]Modifier, conn net.Conn) (resultCode LDAPResultCode) { + if len(req.Children) != 2 { + return LDAPResultProtocolError + } + var ok bool + modReq := ModifyRequest{} + modReq.dn, ok = req.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + for _, change := range req.Children[1].Children { + if len(change.Children) != 2 { + return LDAPResultProtocolError + } + attr := PartialAttribute{} + attrs := change.Children[1].Children + if len(attrs) != 2 { + return LDAPResultProtocolError + } + attr.attrType, ok = attrs[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + for _, val := range attrs[1].Children { + v, ok := val.Value.(string) + if !ok { + return LDAPResultProtocolError + } + attr.attrVals = append(attr.attrVals, v) + } + op, ok := change.Children[0].Value.(uint64) + if !ok { + return LDAPResultProtocolError + } + switch op { + default: + log.Printf("Unrecognized Modify attribute %d", op) + return LDAPResultProtocolError + case AddAttribute: + modReq.Add(attr.attrType, attr.attrVals) + case DeleteAttribute: + modReq.Delete(attr.attrType, attr.attrVals) + case ReplaceAttribute: + modReq.Replace(attr.attrType, attr.attrVals) + } + } + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].Modify(boundDN, modReq, conn) + if err != nil { + log.Printf("ModifyFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} + +func HandleCompareRequest(req *ber.Packet, boundDN string, fns map[string]Comparer, conn net.Conn) (resultCode LDAPResultCode) { + if len(req.Children) != 2 { + return LDAPResultProtocolError + } + var ok bool + compReq := CompareRequest{} + compReq.dn, ok = req.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + ava := req.Children[1] + if len(ava.Children) != 2 { + return LDAPResultProtocolError + } + attr, ok := ava.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + val, ok := ava.Children[1].Value.(string) + if !ok { + return LDAPResultProtocolError + } + compReq.ava = []AttributeValueAssertion{AttributeValueAssertion{attr, val}} + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].Compare(boundDN, compReq, conn) + if err != nil { + log.Printf("CompareFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} + +func HandleExtendedRequest(req *ber.Packet, boundDN string, fns map[string]Extender, conn net.Conn) (resultCode LDAPResultCode) { + if len(req.Children) != 1 && len(req.Children) != 2 { + return LDAPResultProtocolError + } + name := ber.DecodeString(req.Children[0].Data.Bytes()) + var val string + if len(req.Children) == 2 { + val = ber.DecodeString(req.Children[1].Data.Bytes()) + } + extReq := ExtendedRequest{name, val} + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].Extended(boundDN, extReq, conn) + if err != nil { + log.Printf("ExtendedFn Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} + +func HandleAbandonRequest(req *ber.Packet, boundDN string, fns map[string]Abandoner, conn net.Conn) error { + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + err := fns[fn].Abandon(boundDN, conn) + return err +} + +func HandleModifyDNRequest(req *ber.Packet, boundDN string, fns map[string]ModifyDNr, conn net.Conn) (resultCode LDAPResultCode) { + if len(req.Children) != 3 && len(req.Children) != 4 { + return LDAPResultProtocolError + } + var ok bool + mdnReq := ModifyDNRequest{} + mdnReq.dn, ok = req.Children[0].Value.(string) + if !ok { + return LDAPResultProtocolError + } + mdnReq.newrdn, ok = req.Children[1].Value.(string) + if !ok { + return LDAPResultProtocolError + } + mdnReq.deleteoldrdn, ok = req.Children[2].Value.(bool) + if !ok { + return LDAPResultProtocolError + } + if len(req.Children) == 4 { + mdnReq.newSuperior, ok = req.Children[3].Value.(string) + if !ok { + return LDAPResultProtocolError + } + } + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(boundDN, fnNames) + resultCode, err := fns[fn].ModifyDN(boundDN, mdnReq, conn) + if err != nil { + log.Printf("ModifyDN Error %s", err.Error()) + return LDAPResultOperationsError + } + return resultCode +} diff --git a/server_modify_test.go b/server_modify_test.go new file mode 100644 index 0000000..d45b810 --- /dev/null +++ b/server_modify_test.go @@ -0,0 +1,191 @@ +package ldap + +import ( + "net" + "os/exec" + "strings" + "testing" + "time" +) + +// +func TestAdd(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", modifyTestHandler{}) + s.AddFunc("", modifyTestHandler{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + go func() { + cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "modify complete") { + t.Errorf("ldapadd failed: %v", string(out)) + } + cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_add: Insufficient access") { + t.Errorf("ldapadd should have failed: %v", string(out)) + } + if strings.Contains(string(out), "modify complete") { + t.Errorf("ldapadd should have failed: %v", string(out)) + } + done <- true + }() + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapadd command timed out") + } + quit <- true +} + +// +func TestDelete(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", modifyTestHandler{}) + s.DeleteFunc("", modifyTestHandler{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + go func() { + cmd := exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Delete Me,dc=example,dc=com") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "Delete Result: Success (0)") || !strings.Contains(string(out), "Additional info: Success") { + t.Errorf("ldapdelete failed: %v", string(out)) + } + cmd = exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Bob,dc=example,dc=com") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "Success") || !strings.Contains(string(out), "ldap_delete: Insufficient access") { + t.Errorf("ldapdelete should have failed: %v", string(out)) + } + done <- true + }() + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapdelete command timed out") + } + quit <- true +} + +func TestModify(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", modifyTestHandler{}) + s.ModifyFunc("", modifyTestHandler{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + go func() { + cmd := exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify.ldif") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "modify complete") { + t.Errorf("ldapmodify failed: %v", string(out)) + } + cmd = exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify2.ldif") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_modify: Insufficient access") || strings.Contains(string(out), "modify complete") { + t.Errorf("ldapmodify should have failed: %v", string(out)) + } + done <- true + }() + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapadd command timed out") + } + quit <- true +} + +/* +func TestModifyDN(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", modifyTestHandler{}) + s.AddFunc("", modifyTestHandler{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + go func() { + cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif") + //ldapmodrdn -H ldap://localhost:3389 -x "uid=babs,dc=example,dc=com" "uid=babsy,dc=example,dc=com" + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "modify complete") { + t.Errorf("ldapadd failed: %v", string(out)) + } + cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "ldap_add: Insufficient access") { + t.Errorf("ldapadd should have failed: %v", string(out)) + } + if strings.Contains(string(out), "modify complete") { + t.Errorf("ldapadd should have failed: %v", string(out)) + } + done <- true + }() + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapadd command timed out") + } + quit <- true +} +*/ + +// +type modifyTestHandler struct { +} + +func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { + if bindDN == "" && bindSimplePw == "" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} +func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) { + // only succeed on expected contents of add.ldif: + if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" && + req.attributes[2].attrType == "sn" && len(req.attributes[2].attrVals) == 1 && + req.attributes[2].attrVals[0] == "Jensen" { + return LDAPResultSuccess, nil + } + return LDAPResultInsufficientAccessRights, nil +} +func (h modifyTestHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) { + // only succeed on expected deleteDN + if deleteDN == "cn=Delete Me,dc=example,dc=com" { + return LDAPResultSuccess, nil + } + return LDAPResultInsufficientAccessRights, nil +} +func (h modifyTestHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) { + // only succeed on expected contents of modify.ldif: + if req.dn == "cn=testy,dc=example,dc=com" && len(req.addAttributes) == 1 && + len(req.deleteAttributes) == 3 && len(req.replaceAttributes) == 2 && + req.deleteAttributes[2].attrType == "details" && len(req.deleteAttributes[2].attrVals) == 0 { + return LDAPResultSuccess, nil + } + return LDAPResultInsufficientAccessRights, nil +} +func (h modifyTestHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) { + return LDAPResultInsufficientAccessRights, nil +} diff --git a/server_search.go b/server_search.go new file mode 100644 index 0000000..a7d78ac --- /dev/null +++ b/server_search.go @@ -0,0 +1,216 @@ +package ldap + +import ( + "errors" + "fmt" + "github.com/nmcclain/asn1-ber" + "net" + "strings" +) + +func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) { + defer func() { + if r := recover(); r != nil { + resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) + } + }() + + searchReq, err := parseSearchRequest(boundDN, req, controls) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + filterPacket, err := CompileFilter(searchReq.Filter) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + fnNames := []string{} + for k := range server.SearchFns { + fnNames = append(fnNames, k) + } + fn := routeFunc(searchReq.BaseDN, fnNames) + searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn) + if err != nil { + return NewError(searchResp.ResultCode, err) + } + + if server.EnforceLDAP { + if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} + // Server DerefAliases not supported: RFC4511 4.5.1.3 + return NewError(LDAPResultOperationsError, errors.New("Server DerefAliases not supported")) + } + if searchReq.TimeLimit > 0 { + // TODO: Server TimeLimit not implemented + } + } + + for i, entry := range searchResp.Entries { + if server.EnforceLDAP { + // size limit + if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { + break + } + + // filter + keep, resultCode := ServerApplyFilter(filterPacket, entry) + if resultCode != LDAPResultSuccess { + return NewError(resultCode, errors.New("ServerApplyFilter error")) + } + if !keep { + continue + } + + // constrained search scope + switch searchReq.Scope { + case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. + case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. + if entry.DN != searchReq.BaseDN { + continue + } + case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. + parts := strings.Split(entry.DN, ",") + if len(parts) < 2 && entry.DN != searchReq.BaseDN { + continue + } + if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { + continue + } + } + + // attributes + if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { + entry, err = filterAttributes(entry, searchReq.Attributes) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + } + + // respond + responsePacket := encodeSearchResponse(messageID, searchReq, entry) + if err = sendPacket(conn, responsePacket); err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + return nil +} + +///////////////////////// +func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) { + if len(req.Children) != 8 { + return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) + } + + // Parse the request + baseObject, ok := req.Children[0].Value.(string) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + s, ok := req.Children[1].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + scope := int(s) + d, ok := req.Children[2].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + derefAliases := int(d) + s, ok = req.Children[3].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + sizeLimit := int(s) + t, ok := req.Children[4].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + timeLimit := int(t) + typesOnly := false + if req.Children[5].Value != nil { + typesOnly, ok = req.Children[5].Value.(bool) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + } + filter, err := DecompileFilter(req.Children[6]) + if err != nil { + return SearchRequest{}, err + } + attributes := []string{} + for _, attr := range req.Children[7].Children { + a, ok := attr.Value.(string) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + attributes = append(attributes, a) + } + searchReq := SearchRequest{baseObject, scope, + derefAliases, sizeLimit, timeLimit, + typesOnly, filter, attributes, *controls} + + return searchReq, nil +} + +///////////////////////// +func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { + // only return requested attributes + newAttributes := []*EntryAttribute{} + + for _, attr := range entry.Attributes { + for _, requested := range attributes { + if strings.ToLower(attr.Name) == strings.ToLower(requested) { + newAttributes = append(newAttributes, attr) + } + } + } + entry.Attributes = newAttributes + + return entry, nil +} + +///////////////////////// +func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") + searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) + + attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") + for _, attribute := range res.Attributes { + attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) + } + + searchEntry.AppendChild(attrs) + responsePacket.AppendChild(searchEntry) + + return responsePacket +} + +func encodeSearchAttribute(name string, values []string) *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) + + valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") + for _, value := range values { + valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) + } + + packet.AppendChild(valuesPacket) + + return packet +} + +func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") + donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + responsePacket.AppendChild(donePacket) + + return responsePacket +} diff --git a/server_search_test.go b/server_search_test.go new file mode 100644 index 0000000..c3f42b0 --- /dev/null +++ b/server_search_test.go @@ -0,0 +1,403 @@ +package ldap + +import ( + "os/exec" + "strings" + "testing" + "time" +) + +// +func TestSearchSimpleOK(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "uidNumber: 5000") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numResponses: 4") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +func TestSearchSizelimit(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numEntries: 3") { + t.Errorf("ldapsearch sizelimit failed - not enough entries: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch failed: %v", string(out)) + } + if !strings.Contains(string(out), "numEntries: 2") { + t.Errorf("ldapsearch sizelimit failed - too many entries: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestBindSearchMulti(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.BindFunc("", bindSimple{}) + s.BindFunc("c=testz", bindSimple2{}) + s.SearchFunc("", searchSimple{}) + s.SearchFunc("c=testz", searchSimple2{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test", + "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing default bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("search default routing failed: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz", + "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing custom bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { + t.Errorf("search custom routing failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + + quit <- true +} + +///////////////////////// +func TestSearchPanic(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchPanic{}) + s.BindFunc("", bindAnonOK{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 1 Operations error") { + t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +type compileSearchFilterTest struct { + name string + filterStr string + numResponses string +} + +var searchFilterTestFilters = []compileSearchFilterTest{ + compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"}, + compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"}, + compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"}, + compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"}, + compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"}, + compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"}, + compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"}, + compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"}, + compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"}, + compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"}, + compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"}, + compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"}, + compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"}, + compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"}, + compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"}, + compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"}, + compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"}, + /* + compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings}, + compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual}, + compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual}, + compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch}, + */ +} + +///////////////////////// +func TestSearchFiltering(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + for _, i := range searchFilterTestFilters { + t.Log(i.name) + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr) + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "numResponses: "+i.numResponses) { + t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + } + quit <- true +} + +///////////////////////// +func TestSearchAttributes(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + filterString := "" + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn") + out, _ := cmd.CombinedOutput() + + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out)) + } + if !strings.Contains(string(out), "cn: ned") { + t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out)) + } + if strings.Contains(string(out), "uidNumber") { + t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out)) + } + if strings.Contains(string(out), "accountstatus") { + t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +///////////////////////// +func TestSearchScope(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.EnforceLDAP = true + s.QuitChannel(quit) + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { + t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out)) + } + + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} + +func TestSearchControls(t *testing.T) { + quit := make(chan bool) + done := make(chan bool) + go func() { + s := NewServer() + s.QuitChannel(quit) + s.SearchFunc("", searchControls{}) + s.BindFunc("", bindSimple{}) + if err := s.ListenAndServe(listenString); err != nil { + t.Errorf("s.ListenAndServe failed: %s", err.Error()) + } + }() + + serverBaseDN := "o=testers,c=test" + + go func() { + cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-e", "1.2.3.4.5") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { + t.Errorf("ldapsearch with control failed: %v", string(out)) + } + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch with control failed: %v", string(out)) + } + if !strings.Contains(string(out), "numResponses: 2") { + t.Errorf("ldapsearch with control failed: %v", string(out)) + } + + cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", + "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") + out, _ = cmd.CombinedOutput() + if strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { + t.Errorf("ldapsearch without control failed: %v", string(out)) + } + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("ldapsearch without control failed: %v", string(out)) + } + if !strings.Contains(string(out), "numResponses: 1") { + t.Errorf("ldapsearch without control failed: %v", string(out)) + } + + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } + quit <- true +} diff --git a/server_test.go b/server_test.go index 9386a4a..7e813ec 100644 --- a/server_test.go +++ b/server_test.go @@ -61,7 +61,7 @@ func TestBindAnonFail(t *testing.T) { go func() { cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "ldap_bind: Inappropriate authentication (48)") { + if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") { t.Errorf("ldapsearch failed: %v", string(out)) } done <- true @@ -186,7 +186,7 @@ func TestBindSSL(t *testing.T) { s := NewServer() s.QuitChannel(quit) s.BindFunc("", bindAnonOK{}) - if err := s.ListenAndServeTLS(listenString, "examples/cert_DONOTUSE.pem", "examples/key_DONOTUSE.pem"); err != nil { + if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil { t.Errorf("s.ListenAndServeTLS failed: %s", err.Error()) } }() @@ -240,348 +240,6 @@ func TestBindPanic(t *testing.T) { } ///////////////////////// -func TestSearchSimpleOK(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - serverBaseDN := "o=testers,c=test" - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - if !strings.Contains(string(out), "uidNumber: 5000") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - if !strings.Contains(string(out), "numResponses: 4") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - quit <- true -} - -func TestSearchSizelimit(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - if !strings.Contains(string(out), "numEntries: 3") { - t.Errorf("ldapsearch sizelimit failed - not enough entries: %v", string(out)) - } - - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2") - out, _ = cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("ldapsearch failed: %v", string(out)) - } - if !strings.Contains(string(out), "numEntries: 2") { - t.Errorf("ldapsearch sizelimit failed - too many entries: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - quit <- true -} - -///////////////////////// -func TestBindSearchMulti(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindSimple{}) - s.BindFunc("c=testz", bindSimple2{}) - s.SearchFunc("", searchSimple{}) - s.SearchFunc("c=testz", searchSimple2{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test", - "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned") - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("error routing default bind/search functions: %v", string(out)) - } - if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { - t.Errorf("search default routing failed: %v", string(out)) - } - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz", - "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger") - out, _ = cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("error routing custom bind/search functions: %v", string(out)) - } - if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { - t.Errorf("search custom routing failed: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - - quit <- true -} - -///////////////////////// -func TestSearchPanic(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchPanic{}) - s.BindFunc("", bindAnonOK{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test") - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 1 Operations error") { - t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - quit <- true -} - -///////////////////////// -type compileSearchFilterTest struct { - name string - filterStr string - numResponses string -} - -var searchFilterTestFilters = []compileSearchFilterTest{ - compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"}, - compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"}, - compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"}, - compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"}, - compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"}, - compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"}, - compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"}, - compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"}, - compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"}, - compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"}, - compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"}, - compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"}, - compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"}, - compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"}, - compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"}, - compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"}, - compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"}, - /* - compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings}, - compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings}, - compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings}, - compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual}, - compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual}, - compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch}, - */ -} - -///////////////////////// -func TestSearchFiltering(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - for _, i := range searchFilterTestFilters { - t.Log(i.name) - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr) - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "numResponses: "+i.numResponses) { - t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - } - quit <- true -} - -///////////////////////// -func TestSearchAttributes(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - go func() { - filterString := "" - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn") - out, _ := cmd.CombinedOutput() - - if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { - t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out)) - } - if !strings.Contains(string(out), "cn: ned") { - t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out)) - } - if strings.Contains(string(out), "uidNumber") { - t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out)) - } - if strings.Contains(string(out), "accountstatus") { - t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - quit <- true -} - -///////////////////////// -func TestSearchScope(t *testing.T) { - quit := make(chan bool) - done := make(chan bool) - go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) - if err := s.ListenAndServe(listenString); err != nil { - t.Errorf("s.ListenAndServe failed: %s", err.Error()) - } - }() - - go func() { - cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent") - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { - t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out)) - } - - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") - out, _ = cmd.CombinedOutput() - if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { - t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out)) - } - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent") - out, _ = cmd.CombinedOutput() - if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { - t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out)) - } - - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") - out, _ = cmd.CombinedOutput() - if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { - t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out)) - } - cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", - "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent") - out, _ = cmd.CombinedOutput() - if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") { - t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out)) - } - - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } - quit <- true -} - -///////////////////////// type testStatsWriter struct { buffer *bytes.Buffer } @@ -625,7 +283,8 @@ func TestSearchStats(t *testing.T) { } stats := s.GetStats() - if stats.Conns != 1 || stats.Binds != 1 { + log.Println(stats) + if stats.Conns != 2 || stats.Binds != 1 { t.Errorf("Stats data missing or incorrect: %v", w.buffer.String()) } quit <- true @@ -635,7 +294,7 @@ func TestSearchStats(t *testing.T) { type bindAnonOK struct { } -func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "" && bindSimplePw == "" { return LDAPResultSuccess, nil } @@ -645,7 +304,7 @@ func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, er type bindSimple struct { } -func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" { return LDAPResultSuccess, nil } @@ -655,7 +314,7 @@ func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, er type bindSimple2 struct { } -func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" { return LDAPResultSuccess, nil } @@ -665,7 +324,7 @@ func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, e type bindPanic struct { } -func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { +func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { panic("test panic at the disco") return LDAPResultInvalidCredentials, nil } @@ -730,3 +389,22 @@ func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Co panic("this is a test panic") return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil } + +type searchControls struct { +} + +func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{} + if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" { + newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"hamburger"}}, + &EntryAttribute{"o", []string{"testers"}}, + &EntryAttribute{"uidNumber", []string{"5000"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"hamburger"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }} + entries = append(entries, newEntry) + } + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} diff --git a/tests/add.ldif b/tests/add.ldif new file mode 100644 index 0000000..f8cdf71 --- /dev/null +++ b/tests/add.ldif @@ -0,0 +1,6 @@ +dn: cn=Barbara Jensen,dc=example,dc=com +objectClass: person +cn: Barbara Jensen +sn: Jensen +mail: bjensen@example.com +uid: bjensen diff --git a/tests/add2.ldif b/tests/add2.ldif new file mode 100644 index 0000000..ccb71ad --- /dev/null +++ b/tests/add2.ldif @@ -0,0 +1,6 @@ +dn: cn=Big Bob,dc=example,dc=com +objectClass: person +cn: Big Bob +sn: Bob +mail: bob@example.com +uid: bob diff --git a/examples/cert_DONOTUSE.pem b/tests/cert_DONOTUSE.pem index ee14324..ee14324 100644 --- a/examples/cert_DONOTUSE.pem +++ b/tests/cert_DONOTUSE.pem diff --git a/examples/key_DONOTUSE.pem b/tests/key_DONOTUSE.pem index 7feaa11..7feaa11 100644 --- a/examples/key_DONOTUSE.pem +++ b/tests/key_DONOTUSE.pem diff --git a/tests/modify.ldif b/tests/modify.ldif new file mode 100644 index 0000000..ac969cc --- /dev/null +++ b/tests/modify.ldif @@ -0,0 +1,16 @@ +dn: cn=testy,dc=example,dc=com +changetype: modify +replace: mail +mail: modme@example.com +- +delete: manager +- +add: title +title: Grand Poobah +- +delete: description +- +delete: details +- +replace: fullname +fullname: Test Testerson diff --git a/tests/modify2.ldif b/tests/modify2.ldif new file mode 100644 index 0000000..794d7f4 --- /dev/null +++ b/tests/modify2.ldif @@ -0,0 +1,10 @@ +dn: cn=testo,dc=example,dc=com +changetype: modify +replace: mail +mail: modid@example.com +- +delete: manager +- +add: title +title: Other Poobah +- |