From e45f83457931e08f9f6d5aec48f51fd390a01eb8 Mon Sep 17 00:00:00 2001 From: tmfkams Date: Sun, 19 Jan 2014 23:04:16 +0100 Subject: refactoring, fixes for go1.2, modify added --- .gitignore | 2 + README | 21 +- bind.go | 88 ++++---- conn.go | 428 ++++++++++++++++++------------------- control.go | 228 ++++++++++---------- debugging.go | 31 +++ examples/enterprise.ldif | 63 ++++++ examples/modify.go | 89 ++++++++ examples/search.go | 45 ++++ examples/slapd.conf | 67 ++++++ filter.go | 432 +++++++++++++++++++------------------- filter_test.go | 117 +++++------ ldap.go | 514 ++++++++++++++++++++++----------------------- ldap_test.go | 216 +++++++++---------- modify.go | 154 ++++++++++++++ search.go | 535 +++++++++++++++++++++++++++-------------------- 16 files changed, 1777 insertions(+), 1253 deletions(-) create mode 100644 .gitignore create mode 100644 debugging.go create mode 100644 examples/enterprise.ldif create mode 100644 examples/modify.go create mode 100644 examples/search.go create mode 100644 examples/slapd.conf create mode 100644 modify.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b33b5d8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +examples/modify +examples/search diff --git a/README b/README index 8a4256d..f91521d 100644 --- a/README +++ b/README @@ -1,4 +1,4 @@ -Basic LDAP v3 functionality for the GO programming language. +Basic LDAP v3 functionality for the GO programming language. Required Librarys: github.com/mmitton/asn1-ber @@ -9,19 +9,26 @@ Working: Searching for entries Compiling string filters to LDAP filters Paging Search Results - Mulitple internal goroutines to handle network traffic - Makes library goroutine safe - Can perform multiple search requests at the same time and return - the results to the proper goroutine. All requests are blocking - requests, so the goroutine does not need special handling + Modify Requests / Responses + +Examples: + search + modify Tests Implemented: Filter Compile / Decompile TODO: - Modify Requests / Responses Add Requests / Responses Delete Requests / Responses Modify DN Requests / Responses Compare Requests / Responses Implement Tests / Benchmarks + +This feature is disabled at the moment, because in some cases the "Search Request Done" packet +will be handled before the last "Search Request Entry". Looking for another solution: + Mulitple internal goroutines to handle network traffic + Makes library goroutine safe + Can perform multiple search requests at the same time and return + the results to the proper goroutine. All requests are blocking + requests, so the goroutine does not need special handling diff --git a/bind.go b/bind.go index b941f0c..30bc83f 100644 --- a/bind.go +++ b/bind.go @@ -6,50 +6,50 @@ package ldap import ( - "github.com/mmitton/asn1-ber" - "os" + "errors" + "github.com/tmfkams/asn1-ber" ) -func (l *Conn) Bind( username, password string ) *Error { - messageID := l.nextMessageID() - - packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) - packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) - bindRequest := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request" ) - bindRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, 3, "Version" ) ) - bindRequest.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, username, "User Name" ) ) - bindRequest.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, 0, password, "Password" ) ) - packet.AppendChild( bindRequest ) - - if l.Debug { - ber.PrintPacket( packet ) - } - - channel, err := l.sendMessage( packet ) - if err != nil { - return err - } - if channel == nil { - return NewError( ErrorNetwork, os.NewError( "Could not send message" ) ) - } - defer l.finishMessage( messageID ) - packet = <-channel - - if packet == nil { - return NewError( ErrorNetwork, os.NewError( "Could not retrieve response" ) ) - } - - if l.Debug { - if err := addLDAPDescriptions( packet ); err != nil { - return NewError( ErrorDebugging, err ) - } - ber.PrintPacket( packet ) - } - - result_code, result_description := getLDAPResultCode( packet ) - if result_code != 0 { - return NewError( result_code, os.NewError( result_description ) ) - } - - return nil +func (l *Conn) Bind(username, password string) *Error { + messageID := l.nextMessageID() + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) + bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, 3, "Version")) + bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, username, "User Name")) + bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, 0, password, "Password")) + packet.AppendChild(bindRequest) + + if l.Debug { + ber.PrintPacket(packet) + } + + channel, err := l.sendMessage(packet) + if err != nil { + return err + } + if channel == nil { + return NewError(ErrorNetwork, errors.New("Could not send message")) + } + defer l.finishMessage(messageID) + packet = <-channel + + if packet == nil { + return NewError(ErrorNetwork, errors.New("Could not retrieve response")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return NewError(ErrorDebugging, err.Err) + } + ber.PrintPacket(packet) + } + + result_code, result_description := getLDAPResultCode(packet) + if result_code != 0 { + return NewError(result_code, errors.New(result_description)) + } + + return nil } diff --git a/conn.go b/conn.go index 41e69fb..cfa8772 100644 --- a/conn.go +++ b/conn.go @@ -6,296 +6,284 @@ package ldap import ( - "github.com/mmitton/asn1-ber" - "crypto/tls" - "fmt" - "net" - "os" - "sync" + "crypto/tls" + "errors" + "github.com/tmfkams/asn1-ber" + "log" + "net" + "sync" ) // LDAP Connection type Conn struct { - conn net.Conn - isSSL bool - Debug bool - - chanResults map[ uint64 ] chan *ber.Packet - chanProcessMessage chan *messagePacket - chanMessageID chan uint64 - - closeLock sync.Mutex + conn net.Conn + isSSL bool + Debug debugging + chanResults map[uint64]chan *ber.Packet + chanProcessMessage chan *messagePacket + chanMessageID chan uint64 + closeLock sync.Mutex } // Dial connects to the given address on the given network using net.Dial // and then returns a new Conn for the connection. func Dial(network, addr string) (*Conn, *Error) { - c, err := net.Dial(network, "", addr) + c, err := net.Dial(network, addr) if err != nil { - return nil, NewError( ErrorNetwork, err ) + return nil, NewError(ErrorNetwork, err) } - conn := NewConn(c) - conn.start() + conn := NewConn(c) + conn.start() return conn, nil } // Dial connects to the given address on the given network using net.Dial // and then sets up SSL connection and returns a new Conn for the connection. func DialSSL(network, addr string) (*Conn, *Error) { - c, err := tls.Dial(network, "", addr, nil) + c, err := tls.Dial(network, addr, nil) if err != nil { - return nil, NewError( ErrorNetwork, err ) + return nil, NewError(ErrorNetwork, err) } - conn := NewConn(c) - conn.isSSL = true + conn := NewConn(c) + conn.isSSL = true - conn.start() + conn.start() return conn, nil } // Dial connects to the given address on the given network using net.Dial // and then starts a TLS session and returns a new Conn for the connection. func DialTLS(network, addr string) (*Conn, *Error) { - c, err := net.Dial(network, "", addr) + c, err := net.Dial(network, addr) if err != nil { - return nil, NewError( ErrorNetwork, err ) + return nil, NewError(ErrorNetwork, err) + } + conn := NewConn(c) + + if err := conn.startTLS(); err != nil { + conn.Close() + return nil, NewError(ErrorNetwork, err.Err) } - conn := NewConn(c) - - err = conn.startTLS() - if err != nil { - conn.Close() - return nil, NewError( ErrorNetwork, err ) - } - conn.start() + conn.start() return conn, nil } // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn) *Conn { return &Conn{ - conn: conn, - isSSL: false, - Debug: false, - chanResults: map[uint64] chan *ber.Packet{}, - chanProcessMessage: make( chan *messagePacket ), - chanMessageID: make( chan uint64 ), + conn: conn, + isSSL: false, + Debug: false, + chanResults: map[uint64]chan *ber.Packet{}, + chanProcessMessage: make(chan *messagePacket), + chanMessageID: make(chan uint64), } } func (l *Conn) start() { - go l.reader() - go l.processMessages() + go l.reader() + go l.processMessages() } // Close closes the connection. func (l *Conn) Close() *Error { - l.closeLock.Lock() - defer l.closeLock.Unlock() - - l.sendProcessMessage( &messagePacket{ Op: MessageQuit } ) - - if l.conn != nil { - err := l.conn.Close() - if err != nil { - return NewError( ErrorNetwork, err ) - } - l.conn = nil - } + l.closeLock.Lock() + defer l.closeLock.Unlock() + + l.sendProcessMessage(&messagePacket{Op: MessageQuit}) + + if l.conn != nil { + err := l.conn.Close() + if err != nil { + return NewError(ErrorNetwork, err) + } + l.conn = nil + } return nil } // Returns the next available messageID func (l *Conn) nextMessageID() (messageID uint64) { - defer func() { if r := recover(); r != nil { messageID = 0 } }() - messageID = <-l.chanMessageID - return + defer func() { + if r := recover(); r != nil { + messageID = 0 + } + }() + messageID = <-l.chanMessageID + return } // StartTLS sends the command to start a TLS session and then creates a new TLS Client func (l *Conn) startTLS() *Error { - messageID := l.nextMessageID() - - if l.isSSL { - return NewError( ErrorNetwork, os.NewError( "Already encrypted" ) ) - } - - packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) - packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) - startTLS := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS" ) - startTLS.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command" ) ) - packet.AppendChild( startTLS ) - if l.Debug { - ber.PrintPacket( packet ) - } - - _, err := l.conn.Write( packet.Bytes() ) - if err != nil { - return NewError( ErrorNetwork, err ) - } - - packet, err = ber.ReadPacket( l.conn ) - if err != nil { - return NewError( ErrorNetwork, err ) - } - - if l.Debug { - if err := addLDAPDescriptions( packet ); err != nil { - return NewError( ErrorDebugging, err ) - } - ber.PrintPacket( packet ) - } - - if packet.Children[ 1 ].Children[ 0 ].Value.(uint64) == 0 { - conn := tls.Client( l.conn, nil ) - l.isSSL = true - l.conn = conn - } - - return nil + messageID := l.nextMessageID() + + if l.isSSL { + return NewError(ErrorNetwork, errors.New("Already encrypted")) + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) + startTLS := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") + startTLS.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) + packet.AppendChild(startTLS) + l.Debug.PrintPacket(packet) + + _, err := l.conn.Write(packet.Bytes()) + if err != nil { + return NewError(ErrorNetwork, err) + } + + packet, err = ber.ReadPacket(l.conn) + if err != nil { + return NewError(ErrorNetwork, err) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return NewError(ErrorDebugging, err.Err) + } + ber.PrintPacket(packet) + } + + if packet.Children[1].Children[0].Value.(uint64) == 0 { + conn := tls.Client(l.conn, nil) + l.isSSL = true + l.conn = conn + } + + return nil } const ( - MessageQuit = 0 - MessageRequest = 1 - MessageResponse = 2 - MessageFinish = 3 + MessageQuit = 0 + MessageRequest = 1 + MessageResponse = 2 + MessageFinish = 3 ) type messagePacket struct { - Op int - MessageID uint64 - Packet *ber.Packet - Channel chan *ber.Packet + Op int + MessageID uint64 + Packet *ber.Packet + Channel chan *ber.Packet } -func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) { - message_id := p.Children[ 0 ].Value.(uint64) - out = make(chan *ber.Packet) - - if l.chanProcessMessage == nil { - err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) ) - return - } - message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out } - l.sendProcessMessage( message_packet ) - return +func (l *Conn) sendMessage(p *ber.Packet) (out chan *ber.Packet, err *Error) { + message_id := p.Children[0].Value.(uint64) + out = make(chan *ber.Packet) + + if l.chanProcessMessage == nil { + err = NewError(ErrorNetwork, errors.New("Connection closed")) + return + } + message_packet := &messagePacket{Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out} + l.sendProcessMessage(message_packet) + return } func (l *Conn) processMessages() { - defer l.closeAllChannels() - - var message_id uint64 = 1 - var message_packet *messagePacket - for { - select { - case l.chanMessageID <- message_id: - if l.conn == nil { - return - } - message_id++ - case message_packet = <-l.chanProcessMessage: - if l.conn == nil { - return - } - switch message_packet.Op { - case MessageQuit: - // Close all channels and quit - if l.Debug { - fmt.Printf( "Shutting down\n" ) - } - return - case MessageRequest: - // Add to message list and write to network - if l.Debug { - fmt.Printf( "Sending message %d\n", message_packet.MessageID ) - } - l.chanResults[ message_packet.MessageID ] = message_packet.Channel - buf := message_packet.Packet.Bytes() - for len( buf ) > 0 { - n, err := l.conn.Write( buf ) - if err != nil { - if l.Debug { - fmt.Printf( "Error Sending Message: %s\n", err.String() ) - } - return - } - if n == len( buf ) { - break - } - buf = buf[n:] - } - case MessageResponse: - // Pass back to waiting goroutine - if l.Debug { - fmt.Printf( "Receiving message %d\n", message_packet.MessageID ) - } - chanResult := l.chanResults[ message_packet.MessageID ] - if chanResult == nil { - fmt.Printf( "Unexpected Message Result: %d\n", message_id ) - ber.PrintPacket( message_packet.Packet ) - } else { - go func() { chanResult <- message_packet.Packet }() - // chanResult <- message_packet.Packet - } - case MessageFinish: - // Remove from message list - if l.Debug { - fmt.Printf( "Finished message %d\n", message_packet.MessageID ) - } - l.chanResults[ message_packet.MessageID ] = nil, false - } - } - } + defer l.closeAllChannels() + + var message_id uint64 = 1 + var message_packet *messagePacket + for { + select { + case l.chanMessageID <- message_id: + if l.conn == nil { + return + } + message_id++ + case message_packet = <-l.chanProcessMessage: + if l.conn == nil { + return + } + switch message_packet.Op { + case MessageQuit: + // Close all channels and quit + l.Debug.Printf("Shutting down\n") + return + case MessageRequest: + // Add to message list and write to network + l.Debug.Printf("Sending message %d\n", message_packet.MessageID) + l.chanResults[message_packet.MessageID] = message_packet.Channel + buf := message_packet.Packet.Bytes() + for len(buf) > 0 { + n, err := l.conn.Write(buf) + if err != nil { + l.Debug.Printf("Error Sending Message: %s\n", err.Error()) + return + } + if n == len(buf) { + break + } + buf = buf[n:] + } + case MessageResponse: + // Pass back to waiting goroutine + l.Debug.Printf("Receiving message %d\n", message_packet.MessageID) + if chanResult, ok := l.chanResults[message_packet.MessageID]; ok { + // If the "Search Result Done" is read before the + // "Search Result Entry" no Entry can be returned + // go func() { chanResult <- message_packet.Packet }() + chanResult <- message_packet.Packet + } else { + log.Printf("Unexpected Message Result: %d\n", message_id) + ber.PrintPacket(message_packet.Packet) + } + case MessageFinish: + // Remove from message list + l.Debug.Printf("Finished message %d\n", message_packet.MessageID) + l.chanResults[message_packet.MessageID] = nil + } + } + } } func (l *Conn) closeAllChannels() { -fmt.Printf( "closeAllChannels\n" ) - for MessageID, Channel := range l.chanResults { - if l.Debug { - fmt.Printf( "Closing channel for MessageID %d\n", MessageID ); - } - close( Channel ) - l.chanResults[ MessageID ] = nil, false - } - close( l.chanMessageID ) - l.chanMessageID = nil - - close( l.chanProcessMessage ) - l.chanProcessMessage = nil + log.Printf("closeAllChannels\n") + for messageID, channel := range l.chanResults { + if channel != nil { + l.Debug.Printf("Closing channel for MessageID %d\n", messageID) + close(channel) + l.chanResults[messageID] = nil + } + } + close(l.chanMessageID) + l.chanMessageID = nil + + close(l.chanProcessMessage) + l.chanProcessMessage = nil } -func (l *Conn) finishMessage( MessageID uint64 ) { - message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID } - l.sendProcessMessage( message_packet ) +func (l *Conn) finishMessage(MessageID uint64) { + message_packet := &messagePacket{Op: MessageFinish, MessageID: MessageID} + l.sendProcessMessage(message_packet) } func (l *Conn) reader() { - defer l.Close() - for { - p, err := ber.ReadPacket( l.conn ) - if err != nil { - if l.Debug { - fmt.Printf( "ldap.reader: %s\n", err.String() ) - } - return - } - - addLDAPDescriptions( p ) - - message_id := p.Children[ 0 ].Value.(uint64) - message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p } - if l.chanProcessMessage != nil { - l.chanProcessMessage <- message_packet - } else { - fmt.Printf( "ldap.reader: Cannot return message\n" ) - return - } - } + defer l.Close() + for { + p, err := ber.ReadPacket(l.conn) + if err != nil { + l.Debug.Printf("ldap.reader: %s\n", err.Error()) + return + } + + addLDAPDescriptions(p) + + message_id := p.Children[0].Value.(uint64) + message_packet := &messagePacket{Op: MessageResponse, MessageID: message_id, Packet: p} + if l.chanProcessMessage != nil { + l.chanProcessMessage <- message_packet + } else { + log.Printf("ldap.reader: Cannot return message\n") + return + } + } } -func (l *Conn) sendProcessMessage( message *messagePacket ) { - if l.chanProcessMessage != nil { - go func() { l.chanProcessMessage <- message }() - } +func (l *Conn) sendProcessMessage(message *messagePacket) { + if l.chanProcessMessage != nil { + go func() { l.chanProcessMessage <- message }() + } } diff --git a/control.go b/control.go index af145b2..cada47a 100644 --- a/control.go +++ b/control.go @@ -6,152 +6,152 @@ package ldap import ( - "github.com/mmitton/asn1-ber" - "fmt" + "fmt" + "github.com/tmfkams/asn1-ber" ) const ( - ControlTypePaging = "1.2.840.113556.1.4.319" + ControlTypePaging = "1.2.840.113556.1.4.319" ) -var ControlTypeMap = map[ string ] string { - ControlTypePaging : "Paging", +var ControlTypeMap = map[string]string{ + ControlTypePaging: "Paging", } type Control interface { - GetControlType() string - Encode() *ber.Packet - String() string + GetControlType() string + Encode() *ber.Packet + String() string } type ControlString struct { - ControlType string - Criticality bool - ControlValue string + ControlType string + Criticality bool + ControlValue string } func (c *ControlString) GetControlType() string { - return c.ControlType + return c.ControlType } func (c *ControlString) Encode() (p *ber.Packet) { - p = ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control" ) - p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlType, "Control Type (" + ControlTypeMap[ c.ControlType ] + ")" ) ) - if c.Criticality { - p.AppendChild( ber.NewBoolean( ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, c.Criticality, "Criticality" ) ) - } - p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlValue, "Control Value" ) ) - return + p = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")")) + if c.Criticality { + p.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, c.Criticality, "Criticality")) + } + p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlValue, "Control Value")) + return } func (c *ControlString) String() string { - return fmt.Sprintf( "Control Type: %s (%q) Criticality: %s Control Value: %s", ControlTypeMap[ c.ControlType ], c.ControlType, c.Criticality, c.ControlValue ) + return fmt.Sprintf("Control Type: %s (%q) Criticality: %s Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue) } type ControlPaging struct { - PagingSize uint32 - Cookie []byte + PagingSize uint32 + Cookie []byte } func (c *ControlPaging) GetControlType() string { - return ControlTypePaging + return ControlTypePaging } func (c *ControlPaging) Encode() (p *ber.Packet) { - p = ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control" ) - p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, ControlTypePaging, "Control Type (" + ControlTypeMap[ ControlTypePaging ] + ")" ) ) + p = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")")) - p2 := ber.Encode( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, nil, "Control Value (Paging)" ) - seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value" ) - seq.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(c.PagingSize), "Paging Size" ) ) - cookie := ber.Encode( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, nil, "Cookie" ) - cookie.Value = c.Cookie - cookie.Data.Write( c.Cookie ) - seq.AppendChild( cookie ) - p2.AppendChild( seq ) + p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, nil, "Control Value (Paging)") + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value") + seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(c.PagingSize), "Paging Size")) + cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, nil, "Cookie") + cookie.Value = c.Cookie + cookie.Data.Write(c.Cookie) + seq.AppendChild(cookie) + p2.AppendChild(seq) - p.AppendChild( p2 ) - return + p.AppendChild(p2) + return } func (c *ControlPaging) String() string { - return fmt.Sprintf( - "Control Type: %s (%q) Criticality: %s PagingSize: %d Cookie: %q", - ControlTypeMap[ ControlTypePaging ], - ControlTypePaging, - false, - c.PagingSize, - c.Cookie ) -} - -func (c *ControlPaging) SetCookie( Cookie []byte ) { - c.Cookie = Cookie -} - -func FindControl( Controls []Control, ControlType string ) Control { - for _, c := range Controls { - if c.GetControlType() == ControlType { - return c - } - } - return nil -} - -func DecodeControl( p *ber.Packet ) Control { - ControlType := p.Children[ 0 ].Value.(string) - Criticality := false - - p.Children[ 0 ].Description = "Control Type (" + ControlTypeMap[ ControlType ] + ")" - value := p.Children[ 1 ] - if len( p.Children ) == 3 { - value = p.Children[ 2 ] - p.Children[ 1 ].Description = "Criticality" - Criticality = p.Children[ 1 ].Value.(bool) - } - - value.Description = "Control Value" - switch ControlType { - case ControlTypePaging: - value.Description += " (Paging)" - c := new( ControlPaging ) - if value.Value != nil { - value_children := ber.DecodePacket( value.Data.Bytes() ) - value.Data.Truncate( 0 ) - value.Value = nil - value.AppendChild( value_children ) - } - value = value.Children[ 0 ] - value.Description = "Search Control Value" - value.Children[ 0 ].Description = "Paging Size" - value.Children[ 1 ].Description = "Cookie" - c.PagingSize = uint32( value.Children[ 0 ].Value.(uint64) ) - c.Cookie = value.Children[ 1 ].Data.Bytes() - value.Children[ 1 ].Value = c.Cookie - return c - } - c := new( ControlString ) - c.ControlType = ControlType - c.Criticality = Criticality - c.ControlValue = value.Value.(string) - return c -} - -func NewControlString( ControlType string, Criticality bool, ControlValue string ) *ControlString { - return &ControlString{ - ControlType: ControlType, - Criticality: Criticality, - ControlValue: ControlValue, - } -} - -func NewControlPaging( PagingSize uint32 ) *ControlPaging { - return &ControlPaging{ PagingSize: PagingSize } -} - -func encodeControls( Controls []Control ) *ber.Packet { - p := ber.Encode( ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls" ) - for _, control := range Controls { - p.AppendChild( control.Encode() ) - } - return p + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %s PagingSize: %d Cookie: %q", + ControlTypeMap[ControlTypePaging], + ControlTypePaging, + false, + c.PagingSize, + c.Cookie) +} + +func (c *ControlPaging) SetCookie(Cookie []byte) { + c.Cookie = Cookie +} + +func FindControl(Controls []Control, ControlType string) Control { + for _, c := range Controls { + if c.GetControlType() == ControlType { + return c + } + } + return nil +} + +func DecodeControl(p *ber.Packet) Control { + ControlType := p.Children[0].Value.(string) + Criticality := false + + p.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" + value := p.Children[1] + if len(p.Children) == 3 { + value = p.Children[2] + p.Children[1].Description = "Criticality" + Criticality = p.Children[1].Value.(bool) + } + + value.Description = "Control Value" + switch ControlType { + case ControlTypePaging: + value.Description += " (Paging)" + c := new(ControlPaging) + if value.Value != nil { + value_children := ber.DecodePacket(value.Data.Bytes()) + value.Data.Truncate(0) + value.Value = nil + value.AppendChild(value_children) + } + value = value.Children[0] + value.Description = "Search Control Value" + value.Children[0].Description = "Paging Size" + value.Children[1].Description = "Cookie" + c.PagingSize = uint32(value.Children[0].Value.(uint64)) + c.Cookie = value.Children[1].Data.Bytes() + value.Children[1].Value = c.Cookie + return c + } + c := new(ControlString) + c.ControlType = ControlType + c.Criticality = Criticality + c.ControlValue = value.Value.(string) + return c +} + +func NewControlString(ControlType string, Criticality bool, ControlValue string) *ControlString { + return &ControlString{ + ControlType: ControlType, + Criticality: Criticality, + ControlValue: ControlValue, + } +} + +func NewControlPaging(PagingSize uint32) *ControlPaging { + return &ControlPaging{PagingSize: PagingSize} +} + +func encodeControls(Controls []Control) *ber.Packet { + p := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls") + for _, control := range Controls { + p.AppendChild(control.Encode()) + } + return p } diff --git a/debugging.go b/debugging.go new file mode 100644 index 0000000..cefbfad --- /dev/null +++ b/debugging.go @@ -0,0 +1,31 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// File contains debugging functionality +package ldap + +import ( + "github.com/tmfkams/asn1-ber" + "log" +) + +/* + debbuging type + - has a Printf method to write the debug output +*/ +type debugging bool + +// write debug output +func (debug debugging) Printf(format string, args ...interface{}) { + if debug { + // TODO: DEBUG prefix + log.Printf(format, args...) + } +} + +func (debug debugging) PrintPacket(packet *ber.Packet) { + if debug { + ber.PrintPacket(packet) + } +} diff --git a/examples/enterprise.ldif b/examples/enterprise.ldif new file mode 100644 index 0000000..f0ec28f --- /dev/null +++ b/examples/enterprise.ldif @@ -0,0 +1,63 @@ +dn: dc=enterprise,dc=org +objectClass: dcObject +objectClass: organization +o: acme + +dn: cn=admin,dc=enterprise,dc=org +objectClass: person +cn: admin +sn: admin +description: "LDAP Admin" + +dn: ou=crew,dc=enterprise,dc=org +ou: crew +objectClass: organizationalUnit + + +dn: cn=kirkj,ou=crew,dc=enterprise,dc=org +cn: kirkj +sn: Kirk +gn: James Tiberius +mail: james.kirk@enterprise.org +objectClass: inetOrgPerson + +dn: cn=spock,ou=crew,dc=enterprise,dc=org +cn: spock +sn: Spock +mail: spock@enterprise.org +objectClass: inetOrgPerson + +dn: cn=mccoyl,ou=crew,dc=enterprise,dc=org +cn: mccoyl +sn: McCoy +gn: Leonard +mail: leonard.mccoy@enterprise.org +objectClass: inetOrgPerson + +dn: cn=scottm,ou=crew,dc=enterprise,dc=org +cn: scottm +sn: Scott +gn: Montgomery +mail: Montgomery.scott@enterprise.org +objectClass: inetOrgPerson + +dn: cn=uhuran,ou=crew,dc=enterprise,dc=org +cn: uhuran +sn: Uhura +gn: Nyota +mail: nyota.uhura@enterprise.org +objectClass: inetOrgPerson + +dn: cn=suluh,ou=crew,dc=enterprise,dc=org +cn: suluh +sn: Sulu +gn: Hikaru +mail: hikaru.sulu@enterprise.org +objectClass: inetOrgPerson + +dn: cn=chekovp,ou=crew,dc=enterprise,dc=org +cn: chekovp +sn: Chekov +gn: pavel +mail: pavel.chekov@enterprise.org +objectClass: inetOrgPerson diff --git a/examples/modify.go b/examples/modify.go new file mode 100644 index 0000000..7af8e06 --- /dev/null +++ b/examples/modify.go @@ -0,0 +1,89 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// File contains a modify example +package main + +import ( + "errors" + "fmt" + "github.com/tmfkams/ldap" + "log" +) + +var ( + LdapServer string = "localhost" + LdapPort uint16 = 389 + BaseDN string = "dc=enterprise,dc=org" + BindDN string = "cn=admin,dc=enterprise,dc=org" + BindPW string = "enterprise" + Filter string = "(cn=kirkj)" +) + +func search(l *ldap.Conn, filter string, attributes []string) (*ldap.Entry, *ldap.Error) { + search := ldap.NewSearchRequest( + BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + filter, + attributes, + nil) + + sr, err := l.Search(search) + if err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + return nil, err + } + + log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries)) + if len(sr.Entries) == 0 { + return nil, ldap.NewError(ldap.ErrorDebugging, errors.New(fmt.Sprintf("no entries found for: %s", filter))) + } + return sr.Entries[0], nil +} + +func main() { + l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort)) + if err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + } + defer l.Close() + // l.Debug = true + + l.Bind(BindDN, BindPW) + + log.Printf("The Search for Kirk ... %s\n", Filter) + entry, err := search(l, Filter, []string{}) + if err != nil { + log.Fatal("could not get entry") + } + entry.PrettyPrint(0) + + log.Printf("modify the mail address and add a description ... \n") + modify := ldap.NewModifyRequest(entry.DN) + modify.Add("description", []string{"Captain of the USS Enterprise"}) + modify.Replace("mail", []string{"captain@enterprise.org"}) + if err := l.Modify(modify); err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + } + + entry, err = search(l, Filter, []string{}) + if err != nil { + log.Fatal("could not get entry") + } + entry.PrettyPrint(0) + + log.Printf("reset the entry ... \n") + modify = ldap.NewModifyRequest(entry.DN) + modify.Delete("description", []string{}) + modify.Replace("mail", []string{"james.kirk@enterprise.org"}) + if err := l.Modify(modify); err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + } + + entry, err = search(l, Filter, []string{}) + if err != nil { + log.Fatal("could not get entry") + } + entry.PrettyPrint(0) +} diff --git a/examples/search.go b/examples/search.go new file mode 100644 index 0000000..b7d4943 --- /dev/null +++ b/examples/search.go @@ -0,0 +1,45 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// File contains a search example +package main + +import ( + "fmt" + "github.com/tmfkams/ldap" + "log" +) + +var ( + LdapServer string = "localhost" + LdapPort uint16 = 389 + BaseDN string = "dc=enterprise,dc=org" + Filter string = "(cn=kirkj)" + Attributes []string = []string{"mail"} +) + +func main() { + l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort)) + if err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + } + defer l.Close() + // l.Debug = true + + search := ldap.NewSearchRequest( + BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + Filter, + Attributes, + nil) + + sr, err := l.Search(search) + if err != nil { + log.Fatalf("ERROR: %s\n", err.String()) + return + } + + log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries)) + sr.PrettyPrint(0) +} diff --git a/examples/slapd.conf b/examples/slapd.conf new file mode 100644 index 0000000..5a66be0 --- /dev/null +++ b/examples/slapd.conf @@ -0,0 +1,67 @@ +# +# See slapd.conf(5) for details on configuration options. +# This file should NOT be world readable. +# +include /private/etc/openldap/schema/core.schema +include /private/etc/openldap/schema/cosine.schema +include /private/etc/openldap/schema/inetorgperson.schema + +# Define global ACLs to disable default read access. + +# Do not enable referrals until AFTER you have a working directory +# service AND an understanding of referrals. +#referral ldap://root.openldap.org + +pidfile /private/var/db/openldap/run/slapd.pid +argsfile /private/var/db/openldap/run/slapd.args + +# Load dynamic backend modules: +# modulepath /usr/libexec/openldap +# moduleload back_bdb.la +# moduleload back_hdb.la +# moduleload back_ldap.la + +# Sample security restrictions +# Require integrity protection (prevent hijacking) +# Require 112-bit (3DES or better) encryption for updates +# Require 63-bit encryption for simple bind +# security ssf=1 update_ssf=112 simple_bind=64 + +# Sample access control policy: +# Root DSE: allow anyone to read it +# Subschema (sub)entry DSE: allow anyone to read it +# Other DSEs: +# Allow self write access +# Allow authenticated users read access +# Allow anonymous users to authenticate +# Directives needed to implement policy: +# access to dn.base="" by * read +# access to dn.base="cn=Subschema" by * read +# access to * +# by self write +# by users read +# by anonymous auth +# +# if no access controls are present, the default policy +# allows anyone and everyone to read anything but restricts +# updates to rootdn. (e.g., "access to * by * read") +# +# rootdn can always read and write EVERYTHING! + +####################################################################### +# BDB database definitions +####################################################################### + +database bdb +suffix "dc=enterprise,dc=org" +rootdn "cn=admin,dc=enterprise,dc=org" +# Cleartext passwords, especially for the rootdn, should +# be avoid. See slappasswd(8) and slapd.conf(5) for details. +# Use of strong authentication encouraged. +rootpw {SSHA}laO00HsgszhK1O0Z5qR0/i/US69Osfeu +# The database directory MUST exist prior to running slapd AND +# should only be accessible by the slapd and slap tools. +# Mode 700 recommended. +directory /private/var/db/openldap/openldap-data +# Indices to maintain +index objectClass eq diff --git a/filter.go b/filter.go index 1b8fd1e..66c8afe 100644 --- a/filter.go +++ b/filter.go @@ -6,244 +6,244 @@ package ldap import ( + "errors" "fmt" - "os" - "github.com/mmitton/asn1-ber" + "github.com/tmfkams/asn1-ber" ) const ( - FilterAnd = 0 - FilterOr = 1 - FilterNot = 2 - FilterEqualityMatch = 3 - FilterSubstrings = 4 - FilterGreaterOrEqual = 5 - FilterLessOrEqual = 6 - FilterPresent = 7 - FilterApproxMatch = 8 - FilterExtensibleMatch = 9 + FilterAnd = 0 + FilterOr = 1 + FilterNot = 2 + FilterEqualityMatch = 3 + FilterSubstrings = 4 + FilterGreaterOrEqual = 5 + FilterLessOrEqual = 6 + FilterPresent = 7 + FilterApproxMatch = 8 + FilterExtensibleMatch = 9 ) -var FilterMap = map[ uint64 ] string { - FilterAnd : "And", - FilterOr : "Or", - FilterNot : "Not", - FilterEqualityMatch : "Equality Match", - FilterSubstrings : "Substrings", - FilterGreaterOrEqual : "Greater Or Equal", - FilterLessOrEqual : "Less Or Equal", - FilterPresent : "Present", - FilterApproxMatch : "Approx Match", - FilterExtensibleMatch : "Extensible Match", +var FilterMap = map[uint64]string{ + FilterAnd: "And", + FilterOr: "Or", + FilterNot: "Not", + FilterEqualityMatch: "Equality Match", + FilterSubstrings: "Substrings", + FilterGreaterOrEqual: "Greater Or Equal", + FilterLessOrEqual: "Less Or Equal", + FilterPresent: "Present", + FilterApproxMatch: "Approx Match", + FilterExtensibleMatch: "Extensible Match", } const ( - FilterSubstringsInitial = 0 - FilterSubstringsAny = 1 - FilterSubstringsFinal = 2 + FilterSubstringsInitial = 0 + FilterSubstringsAny = 1 + FilterSubstringsFinal = 2 ) -var FilterSubstringsMap = map[ uint64 ] string { - FilterSubstringsInitial : "Substrings Initial", - FilterSubstringsAny : "Substrings Any", - FilterSubstringsFinal : "Substrings Final", +var FilterSubstringsMap = map[uint64]string{ + FilterSubstringsInitial: "Substrings Initial", + FilterSubstringsAny: "Substrings Any", + FilterSubstringsFinal: "Substrings Final", } -func CompileFilter( filter string ) ( *ber.Packet, *Error ) { - if len( filter ) == 0 || filter[ 0 ] != '(' { - return nil, NewError( ErrorFilterCompile, os.NewError( "Filter does not start with an '('" ) ) - } - packet, pos, err := compileFilter( filter, 1 ) - if err != nil { - return nil, err - } - if pos != len( filter ) { - return nil, NewError( ErrorFilterCompile, os.NewError( "Finished compiling filter with extra at end.\n" + fmt.Sprint( filter[pos:] ) ) ) - } - return packet, nil +func CompileFilter(filter string) (*ber.Packet, *Error) { + if len(filter) == 0 || filter[0] != '(' { + return nil, NewError(ErrorFilterCompile, errors.New("Filter does not start with an '('")) + } + packet, pos, err := compileFilter(filter, 1) + if err != nil { + return nil, err + } + if pos != len(filter) { + return nil, NewError(ErrorFilterCompile, errors.New("Finished compiling filter with extra at end.\n"+fmt.Sprint(filter[pos:]))) + } + return packet, nil } -func DecompileFilter( packet *ber.Packet ) (ret string, err *Error) { - defer func() { - if r := recover(); r != nil { - err = NewError( ErrorFilterDecompile, os.NewError( "Error decompiling filter" ) ) - } - }() - ret = "(" - err = nil - child_str := "" +func DecompileFilter(packet *ber.Packet) (ret string, err *Error) { + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorFilterDecompile, errors.New("Error decompiling filter")) + } + }() + ret = "(" + err = nil + child_str := "" - switch packet.Tag { - case FilterAnd: - ret += "&" - for _, child := range packet.Children { - child_str, err = DecompileFilter( child ) - if err != nil { - return - } - ret += child_str - } - case FilterOr: - ret += "|" - for _, child := range packet.Children { - child_str, err = DecompileFilter( child ) - if err != nil { - return - } - ret += child_str - } - case FilterNot: - ret += "!" - child_str, err = DecompileFilter( packet.Children[ 0 ] ) - if err != nil { - return - } - ret += child_str + switch packet.Tag { + case FilterAnd: + ret += "&" + for _, child := range packet.Children { + child_str, err = DecompileFilter(child) + if err != nil { + return + } + ret += child_str + } + case FilterOr: + ret += "|" + for _, child := range packet.Children { + child_str, err = DecompileFilter(child) + if err != nil { + return + } + ret += child_str + } + case FilterNot: + ret += "!" + child_str, err = DecompileFilter(packet.Children[0]) + if err != nil { + return + } + ret += child_str - case FilterSubstrings: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += "=" - switch packet.Children[ 1 ].Children[ 0 ].Tag { - case FilterSubstringsInitial: - ret += ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) + "*" - case FilterSubstringsAny: - ret += "*" + ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) + "*" - case FilterSubstringsFinal: - ret += "*" + ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) - } - case FilterEqualityMatch: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += "=" - ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) - case FilterGreaterOrEqual: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += ">=" - ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) - case FilterLessOrEqual: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += "<=" - ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) - case FilterPresent: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += "=*" - case FilterApproxMatch: - ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) - ret += "~=" - ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) - } + case FilterSubstrings: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "=" + switch packet.Children[1].Children[0].Tag { + case FilterSubstringsInitial: + ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*" + case FilterSubstringsAny: + ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*" + case FilterSubstringsFinal: + ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + } + case FilterEqualityMatch: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "=" + ret += ber.DecodeString(packet.Children[1].Data.Bytes()) + case FilterGreaterOrEqual: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += ">=" + ret += ber.DecodeString(packet.Children[1].Data.Bytes()) + case FilterLessOrEqual: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "<=" + ret += ber.DecodeString(packet.Children[1].Data.Bytes()) + case FilterPresent: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "=*" + case FilterApproxMatch: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "~=" + ret += ber.DecodeString(packet.Children[1].Data.Bytes()) + } - ret += ")" - return + ret += ")" + return } -func compileFilterSet( filter string, pos int, parent *ber.Packet ) ( int, *Error ) { - for pos < len( filter ) && filter[ pos ] == '(' { - child, new_pos, err := compileFilter( filter, pos + 1 ) - if err != nil { - return pos, err - } - pos = new_pos - parent.AppendChild( child ) - } - if pos == len( filter ) { - return pos, NewError( ErrorFilterCompile, os.NewError( "Unexpected end of filter" ) ) - } +func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, *Error) { + for pos < len(filter) && filter[pos] == '(' { + child, new_pos, err := compileFilter(filter, pos+1) + if err != nil { + return pos, err + } + pos = new_pos + parent.AppendChild(child) + } + if pos == len(filter) { + return pos, NewError(ErrorFilterCompile, errors.New("Unexpected end of filter")) + } - return pos + 1, nil + return pos + 1, nil } -func compileFilter( filter string, pos int ) ( p *ber.Packet, new_pos int, err *Error ) { - defer func() { - if r := recover(); r != nil { - err = NewError( ErrorFilterCompile, os.NewError( "Error compiling filter" ) ) - } - }() - p = nil - new_pos = pos - err = nil +func compileFilter(filter string, pos int) (p *ber.Packet, new_pos int, err *Error) { + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorFilterCompile, errors.New("Error compiling filter")) + } + }() + p = nil + new_pos = pos + err = nil - switch filter[pos] { - case '(': - p, new_pos, err = compileFilter( filter, pos + 1 ) - new_pos++ - return - case '&': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[ FilterAnd ] ) - new_pos, err = compileFilterSet( filter, pos + 1, p ) - return - case '|': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[ FilterOr ] ) - new_pos, err = compileFilterSet( filter, pos + 1, p ) - return - case '!': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[ FilterNot ] ) - var child *ber.Packet - child, new_pos, err = compileFilter( filter, pos + 1 ) - p.AppendChild( child ) - return - default: - attribute := "" - condition := "" - for new_pos < len( filter ) && filter[ new_pos ] != ')' { - switch { - case p != nil: - condition += fmt.Sprintf( "%c", filter[ new_pos ] ) - case filter[ new_pos ] == '=': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[ FilterEqualityMatch ] ) - case filter[ new_pos ] == '>' && filter[ new_pos + 1 ] == '=': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[ FilterGreaterOrEqual ] ) - new_pos++ - case filter[ new_pos ] == '<' && filter[ new_pos + 1 ] == '=': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[ FilterLessOrEqual ] ) - new_pos++ - case filter[ new_pos ] == '~' && filter[ new_pos + 1 ] == '=': - p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[ FilterLessOrEqual ] ) - new_pos++ - case p == nil: - attribute += fmt.Sprintf( "%c", filter[ new_pos ] ) - } - new_pos++ - } - if new_pos == len( filter ) { - err = NewError( ErrorFilterCompile, os.NewError( "Unexpected end of filter" ) ) - return - } - if p == nil { - err = NewError( ErrorFilterCompile, os.NewError( "Error parsing filter" ) ) - return - } - p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute" ) ) - switch { - case p.Tag == FilterEqualityMatch && condition == "*": - p.Tag = FilterPresent - p.Description = FilterMap[ uint64(p.Tag) ] - case p.Tag == FilterEqualityMatch && condition[ 0 ] == '*' && condition[ len( condition ) - 1 ] == '*': - // Any - p.Tag = FilterSubstrings - p.Description = FilterMap[ uint64(p.Tag) ] - seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) - seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, condition[ 1 : len( condition ) - 1 ], "Any Substring" ) ) - p.AppendChild( seq ) - case p.Tag == FilterEqualityMatch && condition[ 0 ] == '*': - // Final - p.Tag = FilterSubstrings - p.Description = FilterMap[ uint64(p.Tag) ] - seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) - seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, condition[ 1: ], "Final Substring" ) ) - p.AppendChild( seq ) - case p.Tag == FilterEqualityMatch && condition[ len( condition ) - 1 ] == '*': - // Initial - p.Tag = FilterSubstrings - p.Description = FilterMap[ uint64(p.Tag) ] - seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) - seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, condition[ :len( condition ) - 1 ], "Initial Substring" ) ) - p.AppendChild( seq ) - default: - p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition" ) ) - } - new_pos++ - return - } - err = NewError( ErrorFilterCompile, os.NewError( "Reached end of filter without closing parens" ) ) - return + switch filter[pos] { + case '(': + p, new_pos, err = compileFilter(filter, pos+1) + new_pos++ + return + case '&': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) + new_pos, err = compileFilterSet(filter, pos+1, p) + return + case '|': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) + new_pos, err = compileFilterSet(filter, pos+1, p) + return + case '!': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) + var child *ber.Packet + child, new_pos, err = compileFilter(filter, pos+1) + p.AppendChild(child) + return + default: + attribute := "" + condition := "" + for new_pos < len(filter) && filter[new_pos] != ')' { + switch { + case p != nil: + condition += fmt.Sprintf("%c", filter[new_pos]) + case filter[new_pos] == '=': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) + case filter[new_pos] == '>' && filter[new_pos+1] == '=': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) + new_pos++ + case filter[new_pos] == '<' && filter[new_pos+1] == '=': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) + new_pos++ + case filter[new_pos] == '~' && filter[new_pos+1] == '=': + p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual]) + new_pos++ + case p == nil: + attribute += fmt.Sprintf("%c", filter[new_pos]) + } + new_pos++ + } + if new_pos == len(filter) { + err = NewError(ErrorFilterCompile, errors.New("Unexpected end of filter")) + return + } + if p == nil { + err = NewError(ErrorFilterCompile, errors.New("Error parsing filter")) + return + } + p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute")) + switch { + case p.Tag == FilterEqualityMatch && condition == "*": + p.Tag = FilterPresent + p.Description = FilterMap[uint64(p.Tag)] + case p.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*': + // Any + p.Tag = FilterSubstrings + p.Description = FilterMap[uint64(p.Tag)] + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") + seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring")) + p.AppendChild(seq) + case p.Tag == FilterEqualityMatch && condition[0] == '*': + // Final + p.Tag = FilterSubstrings + p.Description = FilterMap[uint64(p.Tag)] + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") + seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, condition[1:], "Final Substring")) + p.AppendChild(seq) + case p.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*': + // Initial + p.Tag = FilterSubstrings + p.Description = FilterMap[uint64(p.Tag)] + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") + seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring")) + p.AppendChild(seq) + default: + p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition")) + } + new_pos++ + return + } + err = NewError(ErrorFilterCompile, errors.New("Reached end of filter without closing parens")) + return } diff --git a/filter_test.go b/filter_test.go index 4b403e4..c56a937 100644 --- a/filter_test.go +++ b/filter_test.go @@ -1,78 +1,77 @@ package ldap import ( - "github.com/mmitton/asn1-ber" - "testing" + "github.com/tmfkams/asn1-ber" + "testing" ) type compile_test struct { - filter_str string - filter_type int + filter_str string + filter_type int } - -var test_filters = []compile_test { - compile_test{ filter_str: "(&(sn=Miller)(givenName=Bob))", filter_type: FilterAnd }, - compile_test{ filter_str: "(|(sn=Miller)(givenName=Bob))", filter_type: FilterOr }, - compile_test{ filter_str: "(!(sn=Miller))", filter_type: FilterNot }, - compile_test{ filter_str: "(sn=Miller)", filter_type: FilterEqualityMatch }, - compile_test{ filter_str: "(sn=Mill*)", filter_type: FilterSubstrings }, - compile_test{ filter_str: "(sn=*Mill)", filter_type: FilterSubstrings }, - compile_test{ filter_str: "(sn=*Mill*)", filter_type: FilterSubstrings }, - compile_test{ filter_str: "(sn>=Miller)", filter_type: FilterGreaterOrEqual }, - compile_test{ filter_str: "(sn<=Miller)", filter_type: FilterLessOrEqual }, - compile_test{ filter_str: "(sn=*)", filter_type: FilterPresent }, - compile_test{ filter_str: "(sn~=Miller)", filter_type: FilterApproxMatch }, - // compile_test{ filter_str: "()", filter_type: FilterExtensibleMatch }, +var test_filters = []compile_test{ + compile_test{filter_str: "(&(sn=Miller)(givenName=Bob))", filter_type: FilterAnd}, + compile_test{filter_str: "(|(sn=Miller)(givenName=Bob))", filter_type: FilterOr}, + compile_test{filter_str: "(!(sn=Miller))", filter_type: FilterNot}, + compile_test{filter_str: "(sn=Miller)", filter_type: FilterEqualityMatch}, + compile_test{filter_str: "(sn=Mill*)", filter_type: FilterSubstrings}, + compile_test{filter_str: "(sn=*Mill)", filter_type: FilterSubstrings}, + compile_test{filter_str: "(sn=*Mill*)", filter_type: FilterSubstrings}, + compile_test{filter_str: "(sn>=Miller)", filter_type: FilterGreaterOrEqual}, + compile_test{filter_str: "(sn<=Miller)", filter_type: FilterLessOrEqual}, + compile_test{filter_str: "(sn=*)", filter_type: FilterPresent}, + compile_test{filter_str: "(sn~=Miller)", filter_type: FilterApproxMatch}, + // compile_test{ filter_str: "()", filter_type: FilterExtensibleMatch }, } -func TestFilter( t *testing.T ) { - // Test Compiler and Decompiler - for _, i := range test_filters { - filter, err := CompileFilter( i.filter_str ) - if err != nil { - t.Errorf( "Problem compiling %s - %s", err.String() ) - } else if filter.Tag != uint8(i.filter_type) { - t.Errorf( "%q Expected %q got %q", i.filter_str, FilterMap[ uint64(i.filter_type) ], FilterMap[ uint64(filter.Tag) ] ) - } else { - o, err := DecompileFilter( filter ) - if err != nil { - t.Errorf( "Problem compiling %s - %s", i, err.String() ) - } else if i.filter_str != o { - t.Errorf( "%q expected, got %q", i.filter_str, o ) - } - } - } +func TestFilter(t *testing.T) { + // Test Compiler and Decompiler + for _, i := range test_filters { + filter, err := CompileFilter(i.filter_str) + if err != nil { + t.Errorf("Problem compiling %s - %s", err.String()) + } else if filter.Tag != uint8(i.filter_type) { + t.Errorf("%q Expected %q got %q", i.filter_str, FilterMap[uint64(i.filter_type)], FilterMap[uint64(filter.Tag)]) + } else { + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i, err.String()) + } else if i.filter_str != o { + t.Errorf("%q expected, got %q", i.filter_str, o) + } + } + } } -func BenchmarkFilterCompile( b *testing.B ) { - b.StopTimer() - filters := make([]string, len( test_filters ) ) +func BenchmarkFilterCompile(b *testing.B) { + b.StopTimer() + filters := make([]string, len(test_filters)) - // Test Compiler and Decompiler - for idx, i := range test_filters { - filters[ idx ] = i.filter_str - } + // Test Compiler and Decompiler + for idx, i := range test_filters { + filters[idx] = i.filter_str + } - max_idx := len( filters ) - b.StartTimer() - for i := 0; i < b.N; i++ { - CompileFilter( filters[ i % max_idx ] ) - } + max_idx := len(filters) + b.StartTimer() + for i := 0; i < b.N; i++ { + CompileFilter(filters[i%max_idx]) + } } -func BenchmarkFilterDecompile( b *testing.B ) { - b.StopTimer() - filters := make([]*ber.Packet, len( test_filters ) ) +func BenchmarkFilterDecompile(b *testing.B) { + b.StopTimer() + filters := make([]*ber.Packet, len(test_filters)) - // Test Compiler and Decompiler - for idx, i := range test_filters { - filters[ idx ], _ = CompileFilter( i.filter_str ) - } + // Test Compiler and Decompiler + for idx, i := range test_filters { + filters[idx], _ = CompileFilter(i.filter_str) + } - max_idx := len( filters ) - b.StartTimer() - for i := 0; i < b.N; i++ { - DecompileFilter( filters[ i % max_idx ] ) - } + max_idx := len(filters) + b.StartTimer() + for i := 0; i < b.N; i++ { + DecompileFilter(filters[i%max_idx]) + } } diff --git a/ldap.go b/ldap.go index d331c81..71916f8 100644 --- a/ldap.go +++ b/ldap.go @@ -6,301 +6,301 @@ package ldap import ( - "github.com/mmitton/asn1-ber" - "fmt" - "io/ioutil" - "os" + "errors" + "fmt" + "github.com/tmfkams/asn1-ber" + "io/ioutil" ) // LDAP Application Codes const ( - ApplicationBindRequest = 0 - ApplicationBindResponse = 1 - ApplicationUnbindRequest = 2 - ApplicationSearchRequest = 3 - ApplicationSearchResultEntry = 4 - ApplicationSearchResultDone = 5 - ApplicationModifyRequest = 6 - ApplicationModifyResponse = 7 - ApplicationAddRequest = 8 - ApplicationAddResponse = 9 - ApplicationDelRequest = 10 - ApplicationDelResponse = 11 - ApplicationModifyDNRequest = 12 - ApplicationModifyDNResponse = 13 - ApplicationCompareRequest = 14 - ApplicationCompareResponse = 15 - ApplicationAbandonRequest = 16 - ApplicationSearchResultReference = 19 - ApplicationExtendedRequest = 23 - ApplicationExtendedResponse = 24 + ApplicationBindRequest = 0 + ApplicationBindResponse = 1 + ApplicationUnbindRequest = 2 + ApplicationSearchRequest = 3 + ApplicationSearchResultEntry = 4 + ApplicationSearchResultDone = 5 + ApplicationModifyRequest = 6 + ApplicationModifyResponse = 7 + ApplicationAddRequest = 8 + ApplicationAddResponse = 9 + ApplicationDelRequest = 10 + ApplicationDelResponse = 11 + ApplicationModifyDNRequest = 12 + ApplicationModifyDNResponse = 13 + ApplicationCompareRequest = 14 + ApplicationCompareResponse = 15 + ApplicationAbandonRequest = 16 + ApplicationSearchResultReference = 19 + ApplicationExtendedRequest = 23 + ApplicationExtendedResponse = 24 ) -var ApplicationMap = map[ uint8 ] string { - ApplicationBindRequest : "Bind Request", - ApplicationBindResponse : "Bind Response", - ApplicationUnbindRequest : "Unbind Request", - ApplicationSearchRequest : "Search Request", - ApplicationSearchResultEntry : "Search Result Entry", - ApplicationSearchResultDone : "Search Result Done", - ApplicationModifyRequest : "Modify Request", - ApplicationModifyResponse : "Modify Response", - ApplicationAddRequest : "Add Request", - ApplicationAddResponse : "Add Response", - ApplicationDelRequest : "Del Request", - ApplicationDelResponse : "Del Response", - ApplicationModifyDNRequest : "Modify DN Request", - ApplicationModifyDNResponse : "Modify DN Response", - ApplicationCompareRequest : "Compare Request", - ApplicationCompareResponse : "Compare Response", - ApplicationAbandonRequest : "Abandon Request", - ApplicationSearchResultReference : "Search Result Reference", - ApplicationExtendedRequest : "Extended Request", - ApplicationExtendedResponse : "Extended Response", +var ApplicationMap = map[uint8]string{ + ApplicationBindRequest: "Bind Request", + ApplicationBindResponse: "Bind Response", + ApplicationUnbindRequest: "Unbind Request", + ApplicationSearchRequest: "Search Request", + ApplicationSearchResultEntry: "Search Result Entry", + ApplicationSearchResultDone: "Search Result Done", + ApplicationModifyRequest: "Modify Request", + ApplicationModifyResponse: "Modify Response", + ApplicationAddRequest: "Add Request", + ApplicationAddResponse: "Add Response", + ApplicationDelRequest: "Del Request", + ApplicationDelResponse: "Del Response", + ApplicationModifyDNRequest: "Modify DN Request", + ApplicationModifyDNResponse: "Modify DN Response", + ApplicationCompareRequest: "Compare Request", + ApplicationCompareResponse: "Compare Response", + ApplicationAbandonRequest: "Abandon Request", + ApplicationSearchResultReference: "Search Result Reference", + ApplicationExtendedRequest: "Extended Request", + ApplicationExtendedResponse: "Extended Response", } // LDAP Result Codes const ( - LDAPResultSuccess = 0 - LDAPResultOperationsError = 1 - LDAPResultProtocolError = 2 - LDAPResultTimeLimitExceeded = 3 - LDAPResultSizeLimitExceeded = 4 - LDAPResultCompareFalse = 5 - LDAPResultCompareTrue = 6 - LDAPResultAuthMethodNotSupported = 7 - LDAPResultStrongAuthRequired = 8 - LDAPResultReferral = 10 - LDAPResultAdminLimitExceeded = 11 - LDAPResultUnavailableCriticalExtension = 12 - LDAPResultConfidentialityRequired = 13 - LDAPResultSaslBindInProgress = 14 - LDAPResultNoSuchAttribute = 16 - LDAPResultUndefinedAttributeType = 17 - LDAPResultInappropriateMatching = 18 - LDAPResultConstraintViolation = 19 - LDAPResultAttributeOrValueExists = 20 - LDAPResultInvalidAttributeSyntax = 21 - LDAPResultNoSuchObject = 32 - LDAPResultAliasProblem = 33 - LDAPResultInvalidDNSyntax = 34 - LDAPResultAliasDereferencingProblem = 36 - LDAPResultInappropriateAuthentication = 48 - LDAPResultInvalidCredentials = 49 - LDAPResultInsufficientAccessRights = 50 - LDAPResultBusy = 51 - LDAPResultUnavailable = 52 - LDAPResultUnwillingToPerform = 53 - LDAPResultLoopDetect = 54 - LDAPResultNamingViolation = 64 - LDAPResultObjectClassViolation = 65 - LDAPResultNotAllowedOnNonLeaf = 66 - LDAPResultNotAllowedOnRDN = 67 - LDAPResultEntryAlreadyExists = 68 - LDAPResultObjectClassModsProhibited = 69 - LDAPResultAffectsMultipleDSAs = 71 - LDAPResultOther = 80 + LDAPResultSuccess = 0 + LDAPResultOperationsError = 1 + LDAPResultProtocolError = 2 + LDAPResultTimeLimitExceeded = 3 + LDAPResultSizeLimitExceeded = 4 + LDAPResultCompareFalse = 5 + LDAPResultCompareTrue = 6 + LDAPResultAuthMethodNotSupported = 7 + LDAPResultStrongAuthRequired = 8 + LDAPResultReferral = 10 + LDAPResultAdminLimitExceeded = 11 + LDAPResultUnavailableCriticalExtension = 12 + LDAPResultConfidentialityRequired = 13 + LDAPResultSaslBindInProgress = 14 + LDAPResultNoSuchAttribute = 16 + LDAPResultUndefinedAttributeType = 17 + LDAPResultInappropriateMatching = 18 + LDAPResultConstraintViolation = 19 + LDAPResultAttributeOrValueExists = 20 + LDAPResultInvalidAttributeSyntax = 21 + LDAPResultNoSuchObject = 32 + LDAPResultAliasProblem = 33 + LDAPResultInvalidDNSyntax = 34 + LDAPResultAliasDereferencingProblem = 36 + LDAPResultInappropriateAuthentication = 48 + LDAPResultInvalidCredentials = 49 + LDAPResultInsufficientAccessRights = 50 + LDAPResultBusy = 51 + LDAPResultUnavailable = 52 + LDAPResultUnwillingToPerform = 53 + LDAPResultLoopDetect = 54 + LDAPResultNamingViolation = 64 + LDAPResultObjectClassViolation = 65 + LDAPResultNotAllowedOnNonLeaf = 66 + LDAPResultNotAllowedOnRDN = 67 + LDAPResultEntryAlreadyExists = 68 + LDAPResultObjectClassModsProhibited = 69 + LDAPResultAffectsMultipleDSAs = 71 + LDAPResultOther = 80 - ErrorNetwork = 200 - ErrorFilterCompile = 201 - ErrorFilterDecompile = 202 - ErrorDebugging = 203 + ErrorNetwork = 200 + ErrorFilterCompile = 201 + ErrorFilterDecompile = 202 + ErrorDebugging = 203 ) -var LDAPResultCodeMap = map[uint8] string { - LDAPResultSuccess : "Success", - LDAPResultOperationsError : "Operations Error", - LDAPResultProtocolError : "Protocol Error", - LDAPResultTimeLimitExceeded : "Time Limit Exceeded", - LDAPResultSizeLimitExceeded : "Size Limit Exceeded", - LDAPResultCompareFalse : "Compare False", - LDAPResultCompareTrue : "Compare True", - LDAPResultAuthMethodNotSupported : "Auth Method Not Supported", - LDAPResultStrongAuthRequired : "Strong Auth Required", - LDAPResultReferral : "Referral", - LDAPResultAdminLimitExceeded : "Admin Limit Exceeded", - LDAPResultUnavailableCriticalExtension : "Unavailable Critical Extension", - LDAPResultConfidentialityRequired : "Confidentiality Required", - LDAPResultSaslBindInProgress : "Sasl Bind In Progress", - LDAPResultNoSuchAttribute : "No Such Attribute", - LDAPResultUndefinedAttributeType : "Undefined Attribute Type", - LDAPResultInappropriateMatching : "Inappropriate Matching", - LDAPResultConstraintViolation : "Constraint Violation", - LDAPResultAttributeOrValueExists : "Attribute Or Value Exists", - LDAPResultInvalidAttributeSyntax : "Invalid Attribute Syntax", - LDAPResultNoSuchObject : "No Such Object", - LDAPResultAliasProblem : "Alias Problem", - LDAPResultInvalidDNSyntax : "Invalid DN Syntax", - LDAPResultAliasDereferencingProblem : "Alias Dereferencing Problem", - LDAPResultInappropriateAuthentication : "Inappropriate Authentication", - LDAPResultInvalidCredentials : "Invalid Credentials", - LDAPResultInsufficientAccessRights : "Insufficient Access Rights", - LDAPResultBusy : "Busy", - LDAPResultUnavailable : "Unavailable", - LDAPResultUnwillingToPerform : "Unwilling To Perform", - LDAPResultLoopDetect : "Loop Detect", - LDAPResultNamingViolation : "Naming Violation", - LDAPResultObjectClassViolation : "Object Class Violation", - LDAPResultNotAllowedOnNonLeaf : "Not Allowed On Non Leaf", - LDAPResultNotAllowedOnRDN : "Not Allowed On RDN", - LDAPResultEntryAlreadyExists : "Entry Already Exists", - LDAPResultObjectClassModsProhibited : "Object Class Mods Prohibited", - LDAPResultAffectsMultipleDSAs : "Affects Multiple DSAs", - LDAPResultOther : "Other", +var LDAPResultCodeMap = map[uint8]string{ + LDAPResultSuccess: "Success", + LDAPResultOperationsError: "Operations Error", + LDAPResultProtocolError: "Protocol Error", + LDAPResultTimeLimitExceeded: "Time Limit Exceeded", + LDAPResultSizeLimitExceeded: "Size Limit Exceeded", + LDAPResultCompareFalse: "Compare False", + LDAPResultCompareTrue: "Compare True", + LDAPResultAuthMethodNotSupported: "Auth Method Not Supported", + LDAPResultStrongAuthRequired: "Strong Auth Required", + LDAPResultReferral: "Referral", + LDAPResultAdminLimitExceeded: "Admin Limit Exceeded", + LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension", + LDAPResultConfidentialityRequired: "Confidentiality Required", + LDAPResultSaslBindInProgress: "Sasl Bind In Progress", + LDAPResultNoSuchAttribute: "No Such Attribute", + LDAPResultUndefinedAttributeType: "Undefined Attribute Type", + LDAPResultInappropriateMatching: "Inappropriate Matching", + LDAPResultConstraintViolation: "Constraint Violation", + LDAPResultAttributeOrValueExists: "Attribute Or Value Exists", + LDAPResultInvalidAttributeSyntax: "Invalid Attribute Syntax", + LDAPResultNoSuchObject: "No Such Object", + LDAPResultAliasProblem: "Alias Problem", + LDAPResultInvalidDNSyntax: "Invalid DN Syntax", + LDAPResultAliasDereferencingProblem: "Alias Dereferencing Problem", + LDAPResultInappropriateAuthentication: "Inappropriate Authentication", + LDAPResultInvalidCredentials: "Invalid Credentials", + LDAPResultInsufficientAccessRights: "Insufficient Access Rights", + LDAPResultBusy: "Busy", + LDAPResultUnavailable: "Unavailable", + LDAPResultUnwillingToPerform: "Unwilling To Perform", + LDAPResultLoopDetect: "Loop Detect", + LDAPResultNamingViolation: "Naming Violation", + LDAPResultObjectClassViolation: "Object Class Violation", + LDAPResultNotAllowedOnNonLeaf: "Not Allowed On Non Leaf", + LDAPResultNotAllowedOnRDN: "Not Allowed On RDN", + LDAPResultEntryAlreadyExists: "Entry Already Exists", + LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", + LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", + LDAPResultOther: "Other", } // Adds descriptions to an LDAP Response packet for debugging -func addLDAPDescriptions( packet *ber.Packet ) (err *Error) { - defer func() { - if r := recover(); r != nil { - err = NewError( ErrorDebugging, os.NewError( "Cannot process packet to add descriptions" ) ) - } - }() - packet.Description = "LDAP Response" - packet.Children[ 0 ].Description = "Message ID"; +func addLDAPDescriptions(packet *ber.Packet) (err *Error) { + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorDebugging, errors.New("Cannot process packet to add descriptions")) + } + }() + packet.Description = "LDAP Response" + packet.Children[0].Description = "Message ID" - application := packet.Children[ 1 ].Tag - packet.Children[ 1 ].Description = ApplicationMap[ application ] + application := packet.Children[1].Tag + packet.Children[1].Description = ApplicationMap[application] - switch application { - case ApplicationBindRequest: - addRequestDescriptions( packet ) - case ApplicationBindResponse: - addDefaultLDAPResponseDescriptions( packet ) - case ApplicationUnbindRequest: - addRequestDescriptions( packet ) - case ApplicationSearchRequest: - addRequestDescriptions( packet ) - case ApplicationSearchResultEntry: - packet.Children[ 1 ].Children[ 0 ].Description = "Object Name" - packet.Children[ 1 ].Children[ 1 ].Description = "Attributes" - for _, child := range packet.Children[ 1 ].Children[ 1 ].Children { - child.Description = "Attribute" - child.Children[ 0 ].Description = "Attribute Name" - child.Children[ 1 ].Description = "Attribute Values" - for _, grandchild := range child.Children[ 1 ].Children { - grandchild.Description = "Attribute Value" - } - } - if len( packet.Children ) == 3 { - addControlDescriptions( packet.Children[ 2 ] ) - } - case ApplicationSearchResultDone: - addDefaultLDAPResponseDescriptions( packet ) - case ApplicationModifyRequest: - addRequestDescriptions( packet ) - case ApplicationModifyResponse: - case ApplicationAddRequest: - addRequestDescriptions( packet ) - case ApplicationAddResponse: - case ApplicationDelRequest: - addRequestDescriptions( packet ) - case ApplicationDelResponse: - case ApplicationModifyDNRequest: - addRequestDescriptions( packet ) - case ApplicationModifyDNResponse: - case ApplicationCompareRequest: - addRequestDescriptions( packet ) - case ApplicationCompareResponse: - case ApplicationAbandonRequest: - addRequestDescriptions( packet ) - case ApplicationSearchResultReference: - case ApplicationExtendedRequest: - addRequestDescriptions( packet ) - case ApplicationExtendedResponse: - } + switch application { + case ApplicationBindRequest: + addRequestDescriptions(packet) + case ApplicationBindResponse: + addDefaultLDAPResponseDescriptions(packet) + case ApplicationUnbindRequest: + addRequestDescriptions(packet) + case ApplicationSearchRequest: + addRequestDescriptions(packet) + case ApplicationSearchResultEntry: + packet.Children[1].Children[0].Description = "Object Name" + packet.Children[1].Children[1].Description = "Attributes" + for _, child := range packet.Children[1].Children[1].Children { + child.Description = "Attribute" + child.Children[0].Description = "Attribute Name" + child.Children[1].Description = "Attribute Values" + for _, grandchild := range child.Children[1].Children { + grandchild.Description = "Attribute Value" + } + } + if len(packet.Children) == 3 { + addControlDescriptions(packet.Children[2]) + } + case ApplicationSearchResultDone: + addDefaultLDAPResponseDescriptions(packet) + case ApplicationModifyRequest: + addRequestDescriptions(packet) + case ApplicationModifyResponse: + case ApplicationAddRequest: + addRequestDescriptions(packet) + case ApplicationAddResponse: + case ApplicationDelRequest: + addRequestDescriptions(packet) + case ApplicationDelResponse: + case ApplicationModifyDNRequest: + addRequestDescriptions(packet) + case ApplicationModifyDNResponse: + case ApplicationCompareRequest: + addRequestDescriptions(packet) + case ApplicationCompareResponse: + case ApplicationAbandonRequest: + addRequestDescriptions(packet) + case ApplicationSearchResultReference: + case ApplicationExtendedRequest: + addRequestDescriptions(packet) + case ApplicationExtendedResponse: + } - return nil + return nil } -func addControlDescriptions( packet *ber.Packet ) { - packet.Description = "Controls" - for _, child := range packet.Children { - child.Description = "Control" - child.Children[ 0 ].Description = "Control Type (" + ControlTypeMap[ child.Children[ 0 ].Value.(string) ] + ")" - value := child.Children[ 1 ] - if len( child.Children ) == 3 { - child.Children[ 1 ].Description = "Criticality" - value = child.Children[ 2 ] - } - value.Description = "Control Value" +func addControlDescriptions(packet *ber.Packet) { + packet.Description = "Controls" + for _, child := range packet.Children { + child.Description = "Control" + child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")" + value := child.Children[1] + if len(child.Children) == 3 { + child.Children[1].Description = "Criticality" + value = child.Children[2] + } + value.Description = "Control Value" - switch child.Children[ 0 ].Value.(string) { - case ControlTypePaging: - value.Description += " (Paging)" - if value.Value != nil { - value_children := ber.DecodePacket( value.Data.Bytes() ) - value.Data.Truncate( 0 ) - value.Value = nil - value_children.Children[ 1 ].Value = value_children.Children[ 1 ].Data.Bytes() - value.AppendChild( value_children ) - } - value.Children[ 0 ].Description = "Real Search Control Value" - value.Children[ 0 ].Children[ 0 ].Description = "Paging Size" - value.Children[ 0 ].Children[ 1 ].Description = "Cookie" - } - } + switch child.Children[0].Value.(string) { + case ControlTypePaging: + value.Description += " (Paging)" + if value.Value != nil { + value_children := ber.DecodePacket(value.Data.Bytes()) + value.Data.Truncate(0) + value.Value = nil + value_children.Children[1].Value = value_children.Children[1].Data.Bytes() + value.AppendChild(value_children) + } + value.Children[0].Description = "Real Search Control Value" + value.Children[0].Children[0].Description = "Paging Size" + value.Children[0].Children[1].Description = "Cookie" + } + } } -func addRequestDescriptions( packet *ber.Packet ) { - packet.Description = "LDAP Request" - packet.Children[ 0 ].Description = "Message ID" - packet.Children[ 1 ].Description = ApplicationMap[ packet.Children[ 1 ].Tag ]; - if len( packet.Children ) == 3 { - addControlDescriptions( packet.Children[ 2 ] ) - } +func addRequestDescriptions(packet *ber.Packet) { + packet.Description = "LDAP Request" + packet.Children[0].Description = "Message ID" + packet.Children[1].Description = ApplicationMap[packet.Children[1].Tag] + if len(packet.Children) == 3 { + addControlDescriptions(packet.Children[2]) + } } -func addDefaultLDAPResponseDescriptions( packet *ber.Packet ) { - resultCode := packet.Children[ 1 ].Children[ 0 ].Value.(uint64) - packet.Children[ 1 ].Children[ 0 ].Description = "Result Code (" + LDAPResultCodeMap[ uint8(resultCode) ] + ")"; - packet.Children[ 1 ].Children[ 1 ].Description = "Matched DN"; - packet.Children[ 1 ].Children[ 2 ].Description = "Error Message"; - if len( packet.Children[ 1 ].Children ) > 3 { - packet.Children[ 1 ].Children[ 3 ].Description = "Referral"; - } - if len( packet.Children ) == 3 { - addControlDescriptions( packet.Children[ 2 ] ) - } +func addDefaultLDAPResponseDescriptions(packet *ber.Packet) { + resultCode := packet.Children[1].Children[0].Value.(uint64) + packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[uint8(resultCode)] + ")" + packet.Children[1].Children[1].Description = "Matched DN" + packet.Children[1].Children[2].Description = "Error Message" + if len(packet.Children[1].Children) > 3 { + packet.Children[1].Children[3].Description = "Referral" + } + if len(packet.Children) == 3 { + addControlDescriptions(packet.Children[2]) + } } -func DebugBinaryFile( FileName string ) *Error { - file, err := ioutil.ReadFile( FileName ) - if err != nil { - return NewError( ErrorDebugging, err ) - } - ber.PrintBytes( file, "" ) - packet := ber.DecodePacket( file ) - addLDAPDescriptions( packet ) - ber.PrintPacket( packet ) +func DebugBinaryFile(FileName string) *Error { + file, err := ioutil.ReadFile(FileName) + if err != nil { + return NewError(ErrorDebugging, err) + } + ber.PrintBytes(file, "") + packet := ber.DecodePacket(file) + addLDAPDescriptions(packet) + ber.PrintPacket(packet) - return nil + return nil } type Error struct { - Err os.Error - ResultCode uint8 + Err error + ResultCode uint8 } func (e *Error) String() string { - return fmt.Sprintf( "LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[ e.ResultCode ], e.Err.String() ) + return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) } -func NewError( ResultCode uint8, Err os.Error ) (* Error) { - return &Error{ ResultCode: ResultCode, Err: Err } +func NewError(ResultCode uint8, Err error) *Error { + return &Error{ResultCode: ResultCode, Err: Err} } -func getLDAPResultCode( p *ber.Packet ) ( code uint8, description string ) { - if len( p.Children ) >= 2 { - response := p.Children[ 1 ] - if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len( response.Children ) == 3 { - code = uint8(response.Children[ 0 ].Value.(uint64)) - description = response.Children[ 2 ].Value.(string) - return - } - } +func getLDAPResultCode(p *ber.Packet) (code uint8, description string) { + if len(p.Children) >= 2 { + response := p.Children[1] + if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 { + code = uint8(response.Children[0].Value.(uint64)) + description = response.Children[2].Value.(string) + return + } + } - code = ErrorNetwork - description = "Invalid packet format" - return + code = ErrorNetwork + description = "Invalid packet format" + return } diff --git a/ldap_test.go b/ldap_test.go index 708dde6..f21a8a6 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -1,125 +1,125 @@ package ldap import ( - "fmt" - "testing" + "fmt" + "testing" ) var ldap_server string = "ldap.itd.umich.edu" var ldap_port uint16 = 389 var base_dn string = "dc=umich,dc=edu" -var filter []string = []string { - "(cn=cis-fac)", - "(&(objectclass=rfc822mailgroup)(cn=*Computer*))", - "(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))" } -var attributes []string = []string { - "cn", - "description" } - -func TestConnect( t *testing.T ) { - fmt.Printf( "TestConnect: starting...\n" ) - l, err := Dial( "tcp", fmt.Sprintf( "%s:%d", ldap_server, ldap_port ) ) - if err != nil { - t.Errorf( err.String() ) - return - } - defer l.Close() - fmt.Printf( "TestConnect: finished...\n" ) +var filter []string = []string{ + "(cn=cis-fac)", + "(&(objectclass=rfc822mailgroup)(cn=*Computer*))", + "(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))"} +var attributes []string = []string{ + "cn", + "description"} + +func TestConnect(t *testing.T) { + fmt.Printf("TestConnect: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldap_server, ldap_port)) + if err != nil { + t.Errorf(err.String()) + return + } + defer l.Close() + fmt.Printf("TestConnect: finished...\n") } -func TestSearch( t *testing.T ) { - fmt.Printf( "TestSearch: starting...\n" ) - l, err := Dial( "tcp", fmt.Sprintf( "%s:%d", ldap_server, ldap_port ) ) - if err != nil { - t.Errorf( err.String() ) - return - } - defer l.Close() - - search_request := NewSearchRequest( - base_dn, - ScopeWholeSubtree, DerefAlways, 0, 0, false, - filter[0], - attributes, - nil ) - - sr, err := l.Search( search_request ) - if err != nil { - t.Errorf( err.String() ) - return - } - - fmt.Printf( "TestSearch: %s -> num of entries = %d\n", search_request.Filter, len( sr.Entries ) ) +func TestSearch(t *testing.T) { + fmt.Printf("TestSearch: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldap_server, ldap_port)) + if err != nil { + t.Errorf(err.String()) + return + } + defer l.Close() + + search_request := NewSearchRequest( + base_dn, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + sr, err := l.Search(search_request) + if err != nil { + t.Errorf(err.String()) + return + } + + fmt.Printf("TestSearch: %s -> num of entries = %d\n", search_request.Filter, len(sr.Entries)) } -func TestSearchWithPaging( t *testing.T ) { - fmt.Printf( "TestSearchWithPaging: starting...\n" ) - l, err := Dial( "tcp", fmt.Sprintf( "%s:%d", ldap_server, ldap_port ) ) - if err != nil { - t.Errorf( err.String() ) - return - } - defer l.Close() - - err = l.Bind( "", "" ) - if err != nil { - t.Errorf( err.String() ) - return - } - - search_request := NewSearchRequest( - base_dn, - ScopeWholeSubtree, DerefAlways, 0, 0, false, - filter[1], - attributes, - nil ) - sr, err := l.SearchWithPaging( search_request, 5 ) - if err != nil { - t.Errorf( err.String() ) - return - } - - fmt.Printf( "TestSearchWithPaging: %s -> num of entries = %d\n", search_request.Filter, len( sr.Entries ) ) +func TestSearchWithPaging(t *testing.T) { + fmt.Printf("TestSearchWithPaging: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldap_server, ldap_port)) + if err != nil { + t.Errorf(err.String()) + return + } + defer l.Close() + + err = l.Bind("", "") + if err != nil { + t.Errorf(err.String()) + return + } + + search_request := NewSearchRequest( + base_dn, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[1], + attributes, + nil) + sr, err := l.SearchWithPaging(search_request, 5) + if err != nil { + t.Errorf(err.String()) + return + } + + fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", search_request.Filter, len(sr.Entries)) } -func testMultiGoroutineSearch( t *testing.T, l *Conn, results chan *SearchResult, i int ) { - search_request := NewSearchRequest( - base_dn, - ScopeWholeSubtree, DerefAlways, 0, 0, false, - filter[i], - attributes, - nil ) - sr, err := l.Search( search_request ) - - if err != nil { - t.Errorf( err.String() ) - results <- nil - return - } - - results <- sr +func testMultiGoroutineSearch(t *testing.T, l *Conn, results chan *SearchResult, i int) { + search_request := NewSearchRequest( + base_dn, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[i], + attributes, + nil) + sr, err := l.Search(search_request) + + if err != nil { + t.Errorf(err.String()) + results <- nil + return + } + + results <- sr } -func TestMultiGoroutineSearch( t *testing.T ) { - fmt.Printf( "TestMultiGoroutineSearch: starting...\n" ) - l, err := Dial( "tcp", fmt.Sprintf( "%s:%d", ldap_server, ldap_port ) ) - if err != nil { - t.Errorf( err.String() ) - return - } - defer l.Close() - - results := make( []chan *SearchResult, len( filter ) ) - for i := range filter { - results[ i ] = make( chan *SearchResult ) - go testMultiGoroutineSearch( t, l, results[ i ], i ) - } - for i := range filter { - sr := <-results[ i ] - if sr == nil { - t.Errorf( "Did not receive results from goroutine for %q", filter[ i ] ) - } else { - fmt.Printf( "TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[ i ], len( sr.Entries ) ) - } - } +func TestMultiGoroutineSearch(t *testing.T) { + fmt.Printf("TestMultiGoroutineSearch: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldap_server, ldap_port)) + if err != nil { + t.Errorf(err.String()) + return + } + defer l.Close() + + results := make([]chan *SearchResult, len(filter)) + for i := range filter { + results[i] = make(chan *SearchResult) + go testMultiGoroutineSearch(t, l, results[i], i) + } + for i := range filter { + sr := <-results[i] + if sr == nil { + t.Errorf("Did not receive results from goroutine for %q", filter[i]) + } else { + fmt.Printf("TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[i], len(sr.Entries)) + } + } } diff --git a/modify.go b/modify.go new file mode 100644 index 0000000..ad49f94 --- /dev/null +++ b/modify.go @@ -0,0 +1,154 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// File contains Modify functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// ModifyRequest ::= [APPLICATION 6] SEQUENCE { +// object LDAPDN, +// changes SEQUENCE OF change SEQUENCE { +// operation ENUMERATED { +// add (0), +// delete (1), +// replace (2), +// ... }, +// modification PartialAttribute } } +// +// PartialAttribute ::= SEQUENCE { +// type AttributeDescription, +// vals SET OF value AttributeValue } +// +// AttributeDescription ::= LDAPString +// -- Constrained to +// -- [RFC4512] +// +// AttributeValue ::= OCTET STRING +// +package ldap + +import ( + "errors" + "github.com/tmfkams/asn1-ber" + "log" +) + +const ( + AddAttribute = 0 + DeleteAttribute = 1 + ReplaceAttribute = 2 +) + +type PartialAttribute struct { + attrType string + attrVals []string +} + +func (p *PartialAttribute) encode() *ber.Packet { + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute") + seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, p.attrType, "Type")) + set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") + for _, value := range p.attrVals { + set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, value, "Vals")) + } + seq.AppendChild(set) + return seq +} + +type ModifyRequest struct { + dn string + addAttributes []PartialAttribute + deleteAttributes []PartialAttribute + replaceAttributes []PartialAttribute +} + +func (m *ModifyRequest) Add(attrType string, attrVals []string) { + m.addAttributes = append(m.addAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) +} + +func (m *ModifyRequest) Delete(attrType string, attrVals []string) { + m.deleteAttributes = append(m.deleteAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) +} + +func (m *ModifyRequest) Replace(attrType string, attrVals []string) { + m.replaceAttributes = append(m.replaceAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) +} + +func (m ModifyRequest) encode() *ber.Packet { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, m.dn, "DN")) + changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") + for _, attribute := range m.addAttributes { + change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") + change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(AddAttribute), "Operation")) + change.AppendChild(attribute.encode()) + changes.AppendChild(change) + } + for _, attribute := range m.deleteAttributes { + change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") + change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(DeleteAttribute), "Operation")) + change.AppendChild(attribute.encode()) + changes.AppendChild(change) + } + for _, attribute := range m.replaceAttributes { + change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") + change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation")) + change.AppendChild(attribute.encode()) + changes.AppendChild(change) + } + request.AppendChild(changes) + return request +} + +func NewModifyRequest( + dn string, +) *ModifyRequest { + return &ModifyRequest{ + dn: dn, + } +} + +func (l *Conn) Modify(modifyRequest *ModifyRequest) *Error { + messageID := l.nextMessageID() + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(modifyRequest.encode()) + + l.Debug.PrintPacket(packet) + + channel, err := l.sendMessage(packet) + if err != nil { + return err + } + if channel == nil { + return NewError(ErrorNetwork, errors.New("Could not send message")) + } + defer l.finishMessage(messageID) + + l.Debug.Printf("%d: waiting for response\n", messageID) + packet = <-channel + l.Debug.Printf("%d: got response %p\n", messageID, packet) + if packet == nil { + return NewError(ErrorNetwork, errors.New("Could not retrieve message")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return NewError(ErrorDebugging, err.Err) + } + ber.PrintPacket(packet) + } + + if packet.Children[1].Tag == ApplicationModifyResponse { + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return NewError(resultCode, errors.New(resultDescription)) + } + } else { + log.Printf("Unexpected Response: %d\n", packet.Children[1].Tag) + } + + l.Debug.Printf("%d: returning\n", messageID) + return nil +} diff --git a/search.go b/search.go index 83d1584..e92149f 100644 --- a/search.go +++ b/search.go @@ -1,269 +1,348 @@ // Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. - +// // File contains Search functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// SearchRequest ::= [APPLICATION 3] SEQUENCE { +// baseObject LDAPDN, +// scope ENUMERATED { +// baseObject (0), +// singleLevel (1), +// wholeSubtree (2), +// ... }, +// derefAliases ENUMERATED { +// neverDerefAliases (0), +// derefInSearching (1), +// derefFindingBaseObj (2), +// derefAlways (3) }, +// sizeLimit INTEGER (0 .. maxInt), +// timeLimit INTEGER (0 .. maxInt), +// typesOnly BOOLEAN, +// filter Filter, +// attributes AttributeSelection } +// +// AttributeSelection ::= SEQUENCE OF selector LDAPString +// -- The LDAPString is constrained to +// -- in Section 4.5.1.8 +// +// Filter ::= CHOICE { +// and [0] SET SIZE (1..MAX) OF filter Filter, +// or [1] SET SIZE (1..MAX) OF filter Filter, +// not [2] Filter, +// equalityMatch [3] AttributeValueAssertion, +// substrings [4] SubstringFilter, +// greaterOrEqual [5] AttributeValueAssertion, +// lessOrEqual [6] AttributeValueAssertion, +// present [7] AttributeDescription, +// approxMatch [8] AttributeValueAssertion, +// extensibleMatch [9] MatchingRuleAssertion, +// ... } +// +// SubstringFilter ::= SEQUENCE { +// type AttributeDescription, +// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE { +// initial [0] AssertionValue, -- can occur at most once +// any [1] AssertionValue, +// final [2] AssertionValue } -- can occur at most once +// } +// +// MatchingRuleAssertion ::= SEQUENCE { +// matchingRule [1] MatchingRuleId OPTIONAL, +// type [2] AttributeDescription OPTIONAL, +// matchValue [3] AssertionValue, +// dnAttributes [4] BOOLEAN DEFAULT FALSE } +// +// package ldap import ( - "github.com/mmitton/asn1-ber" - "fmt" - "os" + "errors" + "fmt" + "github.com/tmfkams/asn1-ber" + "strings" ) const ( - ScopeBaseObject = 0 - ScopeSingleLevel = 1 - ScopeWholeSubtree = 2 + ScopeBaseObject = 0 + ScopeSingleLevel = 1 + ScopeWholeSubtree = 2 ) -var ScopeMap = map[ int ] string { - ScopeBaseObject : "Base Object", - ScopeSingleLevel : "Single Level", - ScopeWholeSubtree : "Whole Subtree", +var ScopeMap = map[int]string{ + ScopeBaseObject: "Base Object", + ScopeSingleLevel: "Single Level", + ScopeWholeSubtree: "Whole Subtree", } const ( - NeverDerefAliases = 0 - DerefInSearching = 1 - DerefFindingBaseObj = 2 - DerefAlways = 3 + NeverDerefAliases = 0 + DerefInSearching = 1 + DerefFindingBaseObj = 2 + DerefAlways = 3 ) -var DerefMap = map[ int ] string { - NeverDerefAliases : "NeverDerefAliases", - DerefInSearching : "DerefInSearching", - DerefFindingBaseObj : "DerefFindingBaseObj", - DerefAlways : "DerefAlways", +var DerefMap = map[int]string{ + NeverDerefAliases: "NeverDerefAliases", + DerefInSearching: "DerefInSearching", + DerefFindingBaseObj: "DerefFindingBaseObj", + DerefAlways: "DerefAlways", } type Entry struct { - DN string - Attributes []*EntryAttribute + DN string + Attributes []*EntryAttribute +} + +func (e *Entry) GetAttributeValues(Attribute string) []string { + for _, attr := range e.Attributes { + if attr.Name == Attribute { + return attr.Values + } + } + return []string{} +} + +func (e *Entry) GetAttributeValue(Attribute string) string { + values := e.GetAttributeValues(Attribute) + if len(values) == 0 { + return "" + } + return values[0] +} + +func (e *Entry) Print() { + fmt.Printf("DN: %s\n", e.DN) + for _, attr := range e.Attributes { + attr.Print() + } +} + +func (e *Entry) PrettyPrint(indent int) { + fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN) + for _, attr := range e.Attributes { + attr.PrettyPrint(indent + 2) + } } type EntryAttribute struct { - Name string - Values []string + Name string + Values []string } -type SearchResult struct { - Entries []*Entry - Referrals []string - Controls []Control +func (e *EntryAttribute) Print() { + fmt.Printf("%s: %s\n", e.Name, e.Values) } -func (e *Entry) GetAttributeValues( Attribute string ) []string { - for _, attr := range e.Attributes { - if attr.Name == Attribute { - return attr.Values - } - } +func (e *EntryAttribute) PrettyPrint(indent int) { + fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values) +} + +type SearchResult struct { + Entries []*Entry + Referrals []string + Controls []Control +} - return []string{ } +func (s *SearchResult) Print() { + for _, entry := range s.Entries { + entry.Print() + } } -func (e *Entry) GetAttributeValue( Attribute string ) string { - values := e.GetAttributeValues( Attribute ) - if len( values ) == 0 { - return "" - } - return values[ 0 ] +func (s *SearchResult) PrettyPrint(indent int) { + for _, entry := range s.Entries { + entry.PrettyPrint(indent) + } } type SearchRequest struct { - BaseDN string - Scope int - DerefAliases int - SizeLimit int - TimeLimit int - TypesOnly bool - Filter string - Attributes []string - Controls []Control + BaseDN string + Scope int + DerefAliases int + SizeLimit int + TimeLimit int + TypesOnly bool + Filter string + Attributes []string + Controls []Control +} + +func (s *SearchRequest) encode() (*ber.Packet, *Error) { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, s.BaseDN, "Base DN")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(s.Scope), "Scope")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(s.SizeLimit), "Size Limit")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(s.TimeLimit), "Time Limit")) + request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, s.TypesOnly, "Types Only")) + // compile and encode filter + filterPacket, err := CompileFilter(s.Filter) + if err != nil { + return nil, err + } + request.AppendChild(filterPacket) + // encode attributes + attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") + for _, attribute := range s.Attributes { + attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute")) + } + request.AppendChild(attributesPacket) + return request, nil } func NewSearchRequest( - BaseDN string, - Scope, DerefAliases, SizeLimit, TimeLimit int, - TypesOnly bool, - Filter string, - Attributes []string, - Controls []Control, - ) (*SearchRequest) { - return &SearchRequest{ - BaseDN: BaseDN, - Scope: Scope, - DerefAliases: DerefAliases, - SizeLimit: SizeLimit, - TimeLimit: TimeLimit, - TypesOnly: TypesOnly, - Filter: Filter, - Attributes: Attributes, - Controls: Controls, - } + BaseDN string, + Scope, DerefAliases, SizeLimit, TimeLimit int, + TypesOnly bool, + Filter string, + Attributes []string, + Controls []Control, +) *SearchRequest { + return &SearchRequest{ + BaseDN: BaseDN, + Scope: Scope, + DerefAliases: DerefAliases, + SizeLimit: SizeLimit, + TimeLimit: TimeLimit, + TypesOnly: TypesOnly, + Filter: Filter, + Attributes: Attributes, + Controls: Controls, + } } -func (l *Conn) SearchWithPaging( SearchRequest *SearchRequest, PagingSize uint32 ) (*SearchResult, *Error) { - if SearchRequest.Controls == nil { - SearchRequest.Controls = make( []Control, 0 ) - } - - PagingControl := NewControlPaging( PagingSize ) - SearchRequest.Controls = append( SearchRequest.Controls, PagingControl ) - SearchResult := new( SearchResult ) - for { - result, err := l.Search( SearchRequest ) - if l.Debug { - fmt.Printf( "Looking for Paging Control...\n" ) - } - if err != nil { - return SearchResult, err - } - if result == nil { - return SearchResult, NewError( ErrorNetwork, os.NewError( "Packet not received" ) ) - } - - for _, entry := range result.Entries { - SearchResult.Entries = append( SearchResult.Entries, entry ) - } - for _, referral := range result.Referrals { - SearchResult.Referrals = append( SearchResult.Referrals, referral ) - } - for _, control := range result.Controls { - SearchResult.Controls = append( SearchResult.Controls, control ) - } - - if l.Debug { - fmt.Printf( "Looking for Paging Control...\n" ) - } - paging_result := FindControl( result.Controls, ControlTypePaging ) - if paging_result == nil { - PagingControl = nil - if l.Debug { - fmt.Printf( "Could not find paging control. Breaking...\n" ) - } - break - } - - cookie := paging_result.(*ControlPaging).Cookie - if len( cookie ) == 0 { - PagingControl = nil - if l.Debug { - fmt.Printf( "Could not find cookie. Breaking...\n" ) - } - break - } - PagingControl.SetCookie( cookie ) - } - - if PagingControl != nil { - if l.Debug { - fmt.Printf( "Abandoning Paging...\n" ) - } - PagingControl.PagingSize = 0 - l.Search( SearchRequest ) - } - - return SearchResult, nil +func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, *Error) { + if searchRequest.Controls == nil { + searchRequest.Controls = make([]Control, 0) + } + + pagingControl := NewControlPaging(pagingSize) + searchRequest.Controls = append(searchRequest.Controls, pagingControl) + searchResult := new(SearchResult) + for { + result, err := l.Search(searchRequest) + l.Debug.Printf("Looking for Paging Control...\n") + if err != nil { + return searchResult, err + } + if result == nil { + return searchResult, NewError(ErrorNetwork, errors.New("Packet not received")) + } + + for _, entry := range result.Entries { + searchResult.Entries = append(searchResult.Entries, entry) + } + for _, referral := range result.Referrals { + searchResult.Referrals = append(searchResult.Referrals, referral) + } + for _, control := range result.Controls { + searchResult.Controls = append(searchResult.Controls, control) + } + + l.Debug.Printf("Looking for Paging Control...\n") + pagingResult := FindControl(result.Controls, ControlTypePaging) + if pagingResult == nil { + pagingControl = nil + l.Debug.Printf("Could not find paging control. Breaking...\n") + break + } + + cookie := pagingResult.(*ControlPaging).Cookie + if len(cookie) == 0 { + pagingControl = nil + l.Debug.Printf("Could not find cookie. Breaking...\n") + break + } + pagingControl.SetCookie(cookie) + } + + if pagingControl != nil { + l.Debug.Printf("Abandoning Paging...\n") + pagingControl.PagingSize = 0 + l.Search(searchRequest) + } + + return searchResult, nil } -func (l *Conn) Search( SearchRequest *SearchRequest ) (*SearchResult, *Error) { - messageID := l.nextMessageID() - - packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) - packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) - searchRequest := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request" ) - searchRequest.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, SearchRequest.BaseDN, "Base DN" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(SearchRequest.Scope), "Scope" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(SearchRequest.DerefAliases), "Deref Aliases" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(SearchRequest.SizeLimit), "Size Limit" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(SearchRequest.TimeLimit), "Time Limit" ) ) - searchRequest.AppendChild( ber.NewBoolean( ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, SearchRequest.TypesOnly, "Types Only" ) ) - filterPacket, err := CompileFilter( SearchRequest.Filter ) - if err != nil { - return nil, err - } - searchRequest.AppendChild( filterPacket ) - attributesPacket := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes" ) - for _, attribute := range SearchRequest.Attributes { - attributesPacket.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute" ) ) - } - searchRequest.AppendChild( attributesPacket ) - packet.AppendChild( searchRequest ) - if SearchRequest.Controls != nil { - packet.AppendChild( encodeControls( SearchRequest.Controls ) ) - } - - if l.Debug { - ber.PrintPacket( packet ) - } - - channel, err := l.sendMessage( packet ) - if err != nil { - return nil, err - } - if channel == nil { - return nil, NewError( ErrorNetwork, os.NewError( "Could not send message" ) ) - } - defer l.finishMessage( messageID ) - - result := &SearchResult{ - Entries: make( []*Entry, 0 ), - Referrals: make( []string, 0 ), - Controls: make( []Control, 0 ) } - - foundSearchResultDone := false - for !foundSearchResultDone { - if l.Debug { - fmt.Printf( "%d: waiting for response\n", messageID ) - } - packet = <-channel - if l.Debug { - fmt.Printf( "%d: got response %p\n", messageID, packet ) - } - if packet == nil { - return nil, NewError( ErrorNetwork, os.NewError( "Could not retrieve message" ) ) - } - - if l.Debug { - if err := addLDAPDescriptions( packet ); err != nil { - return nil, NewError( ErrorDebugging, err ) - } - ber.PrintPacket( packet ) - } - - switch packet.Children[ 1 ].Tag { - case 4: - entry := new( Entry ) - entry.DN = packet.Children[ 1 ].Children[ 0 ].Value.(string) - for _, child := range packet.Children[ 1 ].Children[ 1 ].Children { - attr := new( EntryAttribute ) - attr.Name = child.Children[ 0 ].Value.(string) - for _, value := range child.Children[ 1 ].Children { - attr.Values = append( attr.Values, value.Value.(string) ) - } - entry.Attributes = append( entry.Attributes, attr ) - } - result.Entries = append( result.Entries, entry ) - case 5: - result_code, result_description := getLDAPResultCode( packet ) - if result_code != 0 { - return result, NewError( result_code, os.NewError( result_description ) ) - } - if len( packet.Children ) == 3 { - for _, child := range packet.Children[ 2 ].Children { - result.Controls = append( result.Controls, DecodeControl( child ) ) - } - } - foundSearchResultDone = true - case 19: - result.Referrals = append( result.Referrals, packet.Children[ 1 ].Children[ 0 ].Value.(string) ) - } - } - if l.Debug { - fmt.Printf( "%d: returning\n", messageID ) - } - - return result, nil +func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, *Error) { + messageID := l.nextMessageID() + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) + // encode search request + encodedSearchRequest, err := searchRequest.encode() + if err != nil { + return nil, err + } + packet.AppendChild(encodedSearchRequest) + // encode search controls + if searchRequest.Controls != nil { + packet.AppendChild(encodeControls(searchRequest.Controls)) + } + + l.Debug.PrintPacket(packet) + + channel, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + if channel == nil { + return nil, NewError(ErrorNetwork, errors.New("Could not send message")) + } + defer l.finishMessage(messageID) + + result := &SearchResult{ + Entries: make([]*Entry, 0), + Referrals: make([]string, 0), + Controls: make([]Control, 0)} + + foundSearchResultDone := false + for !foundSearchResultDone { + l.Debug.Printf("%d: waiting for response\n", messageID) + packet = <-channel + l.Debug.Printf("%d: got response %p\n", messageID, packet) + if packet == nil { + return nil, NewError(ErrorNetwork, errors.New("Could not retrieve message")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return nil, NewError(ErrorDebugging, err.Err) + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case 4: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + } + entry.Attributes = append(entry.Attributes, attr) + } + result.Entries = append(result.Entries, entry) + case 5: + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return result, NewError(resultCode, errors.New(resultDescription)) + } + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + result.Controls = append(result.Controls, DecodeControl(child)) + } + } + foundSearchResultDone = true + case 19: + result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) + } + } + l.Debug.Printf("%d: returning\n", messageID) + return result, nil } -- cgit v1.2.3