#
# SchoolTool - common information systems platform for school administration
# Copyright (c) 2007 Shuttleworth Foundation
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
"""
LDAP authentication plugin.
"""
import datetime
import ldap
import ldap.filter
import pprint
import re

from zope.authentication.interfaces import ILoginPassword
from zope.authentication.interfaces import IAuthenticatedGroup, IEveryoneGroup
from zope.component import getUtility, queryUtility, queryAdapter, adapter
from zope.interface import implements, implementer
from zope.session.interfaces import ISession
from zope.security.checker import ProxyFactory
from zope.security.proxy import removeSecurityProxy

from schooltool.app.interfaces import ISchoolToolApplication
from schooltool.app.security import Principal
from schooltool.app.security import PersonContainerAuthenticationPlugin
from schooltool.contact.interfaces import IContact
from schooltool.group.interfaces import IGroupContainer
from schooltool.ldap.interfaces import ILDAPClient, ILDAPPersonsClient
from schooltool.ldap.interfaces import ILDAPPersonsConfig, ILDAPConfig
from schooltool.ldap.interfaces import ISchoolToolLDAPObject
from schooltool.ldap.interfaces import ISchoolToolLDAPObjectPart
from schooltool.ldap.config import LDAP_SCOPES, decode_ldap_query
from schooltool.person.interfaces import IPersonFactory
from schooltool.person.interfaces import IPasswordWriter
from schooltool.schoolyear.interfaces import ISchoolYearContainer

class BadLDAPCredentials(ValueError):
    """The LDAP server rejected a bind (see LDAPClient.connect)."""


class LDAPServerDown(Exception):
    """No LDAP URI is configured, or the server could not be reached."""


class LDAPNotConnected(Exception):
    """LDAP-not-connected error.

    NOTE(review): nothing in this module raises it; presumably kept for
    callers elsewhere -- confirm before removing.
    """


class LDAPResult(object):
    """Attribute-style view of one ``(dn, attributes)`` search result.

    Each attribute is set on the instance under its original name and is
    also reachable case-insensitively (``r.objectclass`` finds
    ``objectClass``).  Attributes missing from the entry read as None.
    """

    __result = None
    __keys = None
    dn = None # Distinguished name

    def __init__(self, result):
        for k, v in result[1].items():
            setattr(self, k, v)
        self.__result = dict(result[1])
        # Lower-cased name -> original attribute name.
        self.__keys = dict((k.lower(), k) for k in self.__result)
        self.dn = result[0]

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name.startswith('__') and name.endswith('__'):
            # Protocol probes (copy/pickle look up __setstate__,
            # __getstate__, ...) must see a real miss; returning None would
            # make them call None.
            raise AttributeError(name)
        # __keys is still the class-level None on instances created without
        # __init__ (e.g. by copy/pickle before their state is restored).
        keys = self.__keys or {}
        key = keys.get(name.lower())
        if key is not None:
            return self.__result[key]
        return None

    def __str__(self):
        return pprint.pformat(self.__result)

    def __repr__(self):
        attrs = ('objectclass', 'cn', 'o', 'ou', 'title')
        values = [getattr(self, a, None) for a in attrs]
        return "LDAPResult(%s)" % ', '.join(
            ['%s=%r' % (a, v)
             for (a, v) in zip(attrs, values) if v])


# Sentinel default for LDAPClient.connect(): distinguishes "argument
# omitted, use the client's bind_dn/bind_password" from an explicit None
# (a None dn means no bind is attempted).
_bind_default = object()


class LDAPClient(object):
    """Small wrapper around a python-ldap connection.

    ``search`` connects lazily, binding with ``bind_dn``/``bind_password``.
    ``close`` unbinds and forgets the connection.
    """
    implements(ILDAPClient)

    uri = 'ldap://127.0.0.1:389'
    # The live ldap connection, or None when not connected.
    connection = None
    # Credentials used by connect() when none are passed explicitly.
    bind_dn = None
    bind_password = None
    # Value passed to ldap.OPT_NETWORK_TIMEOUT.
    timeout = -1

    def __init__(self, uri=None, dn=None, password=None, timeout=10):
        # Strip whitespace; a whitespace-only URI becomes '' and therefore
        # counts as "not configured".
        self.uri = uri.strip() if uri else uri
        self.bind_dn = dn
        self.bind_password = password
        self.timeout = timeout

    @property
    def configured(self):
        return bool(self.uri)

    @staticmethod
    def _discard(connection):
        """Unbind *connection*, ignoring LDAP errors (cleanup only)."""
        try:
            connection.unbind()
        except ldap.LDAPError:
            # The connection is being thrown away; a failed unbind (e.g. the
            # server already went away) must not mask the caller's outcome.
            pass

    def connect(self, dn=_bind_default, password=_bind_default):
        """Open a new connection and make it the current one.

        *dn*/*password* default to ``bind_dn``/``bind_password``.  With a
        None dn no bind is attempted.  The previous connection is closed
        only once the new one is ready.

        Raises LDAPServerDown when no URI is set, or when the server is
        unreachable during the bind.  Raises BadLDAPCredentials when the
        bind is rejected.  Other ldap.LDAPError exceptions propagate.
        """
        if dn is _bind_default:
            dn = self.bind_dn
        if password is _bind_default:
            password = self.bind_password

        if not self.uri:
            raise LDAPServerDown(self.uri)

        new_connection = ldap.initialize(self.uri)
        ready = False
        try:
            new_connection.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
            new_connection.set_option(ldap.OPT_NETWORK_TIMEOUT, self.timeout)
            if dn is not None:
                if isinstance(dn, unicode):
                    dn = dn.encode('UTF-8')
                try:
                    new_connection.bind_s(who=dn, cred=password)
                except (ldap.SERVER_DOWN, ldap.TIMEOUT):
                    raise LDAPServerDown(self.uri)
                except ldap.INVALID_CREDENTIALS:
                    # Mask the password; tolerate a None password.
                    raise BadLDAPCredentials(dn, '*' * len(password or ''))
            ready = True
        finally:
            if not ready:
                # Release the connection we are not going to keep.
                self._discard(new_connection)

        self.close()
        self.connection = new_connection

    def search(self, query, scope=None, filter=None, **kw):
        """See ldap.LDAPObject.search_ext_s for more keyword options,
        like timeout.

        *query* is decoded by decode_ldap_query into (base, scope, filter).
        An explicit *scope* or *filter* overrides the decoded one.  A dict
        *filter* is AND-ed from its (escaped) key=value items.  Returns a
        list of LDAPResult, or None if the server returned None.
        """
        if self.connection is None:
            self.connect()
        if isinstance(query, unicode):
            query = query.encode('UTF-8')
        base, _scope, _filter = decode_ldap_query(query)

        scope = scope if scope is not None else _scope
        scope = LDAP_SCOPES.get(scope, LDAP_SCOPES['sub'])

        if filter is None:
            filter = _filter
        elif (isinstance(filter, dict) and filter):
            items = sorted(filter.items())
            filter = ldap.filter.filter_format('%s=%s', items.pop())
            while items:
                filter = '&(%s)(%s)' % (
                    ldap.filter.filter_format('%s=%s', items.pop()), filter)
        if not filter:
            filter = 'objectClass=*'
        if not filter.startswith('('):
            filter = '(%s)' % filter

        result = self.connection.search_ext_s(
            base, scope, filterstr=filter, **kw)
        if result is None:
            return None
        return [LDAPResult(r) for r in result]

    def close(self):
        """Unbind and forget the current connection, if any (best effort)."""
        connection, self.connection = self.connection, None
        if connection is not None:
            self._discard(connection)

    def ping(self, uri=None):
        """Check that the server at *uri* (default: self.uri) answers.

        Returns the elapsed time as a datetime.timedelta.  Raises
        LDAPServerDown when there is no URI or the server is unreachable.
        """
        if uri is None:
            uri = self.uri
        now = datetime.datetime.utcnow()
        if not uri:
            raise LDAPServerDown(uri)

        try:
            test = ldap.initialize(uri)
        except ldap.LDAPError:
            raise LDAPServerDown(uri)
        try:
            test.set_option(ldap.OPT_PROTOCOL_VERSION, ldap.VERSION3)
            test.set_option(ldap.OPT_NETWORK_TIMEOUT, self.timeout)
            try:
                test.whoami_s()
            except (ldap.SERVER_DOWN, ldap.TIMEOUT):
                raise LDAPServerDown(uri)
        finally:
            # Always release the probe connection, even when it failed.
            self._discard(test)
        return datetime.datetime.utcnow() - now


class LDAPPersons(LDAPClient):
    """LDAP client that locates person entries and verifies passwords."""
    implements(ILDAPPersonsClient)

    # Tuple of (username attribute, query) pairs, like:
    # ('uid', 'dc=localhost?sub?(objectClass=inetOrgPerson)')
    queries = ()

    # Tuple of queries to get groups, like:
    # 'dc=localhost?sub?(objectClass=posixGroup)'
    groupQueries = ()

    # Tuple of (schooltool year, schooltool group name, posix group id)
    posixGroups = ()

    def __init__(self, *args, **kw):
        self.queries = tuple(kw.pop('queries', ()))
        self.groupQueries = tuple(kw.pop('groupQueries', ()))
        self.posixGroups = tuple(kw.pop('posixGroups', ()))
        LDAPClient.__init__(self, *args, **kw)

    @property
    def configured(self):
        configured = super(LDAPPersons, self).configured
        return configured and bool(self.queries)

    def find(self, username):
        """Yield the entry for *username* from each query that finds exactly one.

        The username is matched against that query's username attribute,
        AND-ed with the query's own filter.
        """
        for attr, query in self.queries:
            # filter_format escapes its assertion values itself; escaping
            # the username beforehand as well would double-escape it.
            userfilter = ldap.filter.filter_format('%s=%s', (attr, username))
            base, scope, filter = decode_ldap_query(query)
            if filter is None:
                filter = userfilter
            else:
                filter = '&(%s)(%s)' % (filter, userfilter)
            users = self.search(base, scope, filter=filter)
            if users and len(users) == 1:
                yield users[0]

    def login(self, username, password):
        """Bind as the first entry found for *username*.

        Returns True on success and False if no entry was found.  A
        rejected bind raises (see LDAPClient.connect).
        """
        for user in self.find(username):
            self.connect(user.dn, password)
            return True
        return False

    def __iter__(self):
        """Iterate over all entries matched by the configured queries."""
        for _attr, query in self.queries:
            for user in self.search(query) or ():
                yield user


class LDAPAuthenticationPlugin(PersonContainerAuthenticationPlugin):
    """Authentication plugin that checks credentials against LDAP.

    Falls back to the stock person-container authentication when LDAP is
    unavailable or rejects the credentials.
    """

    ldap_session_name = 'schooltool.ldap'
    ldap_params = None

    def newClient(self, **kw):
        """Return a new ILDAPPersonsClient for the app's LDAP config, or None."""
        app = ISchoolToolApplication(None)
        config = queryAdapter(app, ILDAPConfig, default=None)
        if config is None:
            return None
        return queryAdapter(config, ILDAPPersonsClient, default=None)

    @property
    def enabled(self):
        """True when LDAP is configured and its server answers a ping."""
        client = self.newClient()
        if client is None or not client.configured:
            return False
        try:
            client.ping()
        except (LDAPServerDown, ldap.LDAPError):
            return False
        return True

    def checkPlainTextPassword(self, username, password):
        """Return True if LDAP accepts *username* with *password*."""
        # XXX: do not allow to log in with invalid schooltool usernames!
        client = self.newClient()
        if client is None or not username:
            return False
        try:
            return client.login(username, password or '')
        except (BadLDAPCredentials, LDAPServerDown, ldap.LDAPError):
            return False
        finally:
            client.close()

    def _ldapOrPersonPrincipal(self, username):
        """Principal for *username*: LDAP-backed if possible, else stock."""
        principal_id = 'sb.person.' + username
        principal = None
        if self.enabled:
            principal = self.getPrincipal(principal_id)
        if principal is None:
            principal = PersonContainerAuthenticationPlugin.getPrincipal(
                self, principal_id)
        return principal

    def authenticate(self, request):
        """Identify a principal for request.

        Uses the username stored in the session by setCredentials.
        Otherwise it tries HTTP basic credentials, verified against LDAP.
        """
        session = ISession(request)[self.session_name]
        if 'username' in session:
            self.restorePOSTData(request)
            return self._ldapOrPersonPrincipal(session['username'])

        # Try HTTP basic too
        creds = ILoginPassword(request, None)
        if creds:
            login = creds.getLogin()
            password = creds.getPassword()
            if self.checkPlainTextPassword(login, password):
                # The session holds no username on this path; use the login.
                return self._ldapOrPersonPrincipal(login)
        return None

    def setCredentials(self, request, username, password):
        """Log in via LDAP, falling back to person-container credentials.

        On LDAP success the username is stored in the session.  The LDAP
        password is also written to the local person when its password does
        not already match.
        """
        if not self.checkPlainTextPassword(username, password):
            # Fall-back to default person auth.
            return PersonContainerAuthenticationPlugin.setCredentials(
                self, request, username, password)

        session = ISession(request)[self.session_name]
        session['username'] = username
        principal = self.getPrincipal('sb.person.' + username, update_person=True)
        if principal is None:
            # LDAP accepted the password but the entry could not be loaded
            # afterwards (e.g. the server went away); nothing to sync.
            return
        unproxied_person = removeSecurityProxy(principal._person)
        if (unproxied_person is not None and
            not unproxied_person.checkPassword(password)):
            password_writer = IPasswordWriter(unproxied_person, None)
            if password_writer is not None:
                password_writer.setPassword(password)

    def clearCredentials(self, request):
        session = ISession(request)[self.session_name]
        try:
            del session['username']
        except KeyError:
            pass

    def getPrincipal(self, id, update_person=False):
        """Get principal meta-data for a person found in LDAP.

        Creates the SchoolTool person on first sight.  With *update_person*,
        group memberships and attributes are synced from LDAP first.
        Returns None if *id* is not a person id, LDAP is unavailable, or no
        matching entry exists.
        """
        if not id.startswith(self.person_prefix):
            return None
        app = ISchoolToolApplication(None)
        username = id[len(self.person_prefix):]

        client = self.newClient()
        if client is None:
            return None

        try:
            try:
                results = list(client.find(username))
            except (LDAPServerDown, ldap.LDAPError):
                return None
            if not results:
                return None

            # Only the first matching entry is used.
            ldap_person = LDAPPerson(results[0], client)
            ldap_person.update()

            if username not in app['persons']:
                ldap_person.create(username)

            person = app['persons'][username]
            if update_person:
                ldap_person.update_groups(person)
                ldap_person.update_attrs(person)
            # Built after syncing so the title reflects updated attributes.
            principal = Principal(id, person.title,
                                  person=ProxyFactory(person))

            for group in person.groups:
                principal.groups.append(self.group_prefix + group.__name__)
            authenticated = queryUtility(IAuthenticatedGroup)
            if authenticated:
                principal.groups.append(authenticated.id)
            everyone = queryUtility(IEveryoneGroup)
            if everyone:
                principal.groups.append(everyone.id)
            return principal
        finally:
            # find() and the group lookups open a connection lazily.
            client.close()


@adapter(ILDAPPersonsConfig)
@implementer(ILDAPPersonsClient)
def makeLDAPPersonsClient(config):
    """Build an LDAPPersons client from the textual persons config.

    ``config.queries``: one "<username attribute> <query>" per line.
    ``config.groupQueries``: one group query per line.
    ``config.posixGroups``: one comma-separated
    "<school year>, <group name>, <posix group id>" triple per line.
    Blank lines are ignored everywhere.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    tokenize = re.compile(r'\s')
    query_lines = [l.strip() for l in (config.queries or '').splitlines()]
    # Split each line at its first whitespace: (attribute, query).
    queries = tuple([
            tuple([token.strip() for token in tokenize.split(line, 1)])
            for line in query_lines
            if line])

    group_lines = [l.strip() for l in (config.groupQueries or '').splitlines()]
    groupQueries = [line for line in group_lines if line]

    posixGroups = config.posixGroups or ''
    groups = [l.strip() for l in posixGroups.strip().splitlines()]
    groups = tuple([
            tuple([i.strip() for i in line.split(',')])
            for line in groups
            if line])
    client = LDAPPersons(
        uri=config.uri, dn=config.bind_dn, password=config.bind_password,
        queries=queries, groupQueries=groupQueries, posixGroups=groups,
        timeout=config.timeout)
    return client


@adapter(ILDAPConfig)
@implementer(ILDAPClient)
def makeLDAPClient(config):
    """Build a plain LDAPClient from the LDAP config.

    Honours ``config.timeout`` like makeLDAPPersonsClient does.  Configs
    without that attribute keep LDAPClient's default of 10.
    """
    client = LDAPClient(uri=config.uri,
                        dn=config.bind_dn, password=config.bind_password,
                        timeout=getattr(config, 'timeout', 10))
    return client


class SchoolToolLDAPObjectPart(object):
    """Adapter handling one objectClass of a SchoolToolLDAPObject.

    Subclasses extend ``update`` to copy values from ``data`` (the target's
    LDAPResult) into attributes of ``target``.
    """
    implements(ISchoolToolLDAPObjectPart)

    # The SchoolToolLDAPObject this part contributes to.
    target = None
    # The target's LDAP search result, captured at construction.
    data = None

    def __init__(self, target):
        self.target = target
        self.data = target.ldap_result

    def update(self):
        """Contribute nothing; subclasses add their own values."""


class SchoolToolLDAPObject(object):
    """An LDAP entry plus the client it came from, populated by parts.

    ``update()`` looks up one ISchoolToolLDAPObjectPart adapter for each
    objectClass value of the entry (named by the lowercased class) and
    runs it.
    """
    implements(ISchoolToolLDAPObject)

    # LDAPResult for the entry.
    ldap_result = None
    # Client the entry came from; parts and subclasses use it for lookups.
    client = None

    def __init__(self, ldap_result, client):
        self.ldap_result = ldap_result
        self.client = client

    def update(self):
        # LDAPResult reads missing attributes as None; an entry without
        # objectClass simply has no parts to run.
        object_classes = self.ldap_result.objectClass or ()
        # NOTE(review): parts run in the order the server returns the
        # objectClass values, and a later part may overwrite values set by
        # an earlier one (presumably e.g. a 'person' part resetting
        # first_name after 'inetOrgPerson'); confirm this is acceptable.
        parts = [queryAdapter(self, ISchoolToolLDAPObjectPart,
                              default=None, name=ldap_class.lower())
                 for ldap_class in object_classes]
        for part in parts:
            if part:
                part.update()


class LDAPPerson(SchoolToolLDAPObject):
    """LDAP person entry that can create and update a SchoolTool person.

    ``update()`` clears the collected values and runs the object-class
    parts.  The parts fill ``person_attrs``, ``contact_attrs`` and
    ``group_member_id``.
    """

    # Values to copy onto the SchoolTool person object.
    person_attrs = None
    # Values to copy onto the person's IContact adapter.
    contact_attrs = None
    # Memoized ``groups`` result; None means "not computed yet".
    cached_groups = None
    # Id searched for in LDAP group member lists
    # (set by PersonLDAPPosixAccountPart).
    group_member_id = None

    def update(self):
        self.cached_groups = None
        self.person_attrs = {}
        self.contact_attrs = {}
        super(LDAPPerson, self).update()

    def create(self, username):
        """Create SchoolTool person *username* and copy LDAP values onto it."""
        app = ISchoolToolApplication(None)
        make_person = getUtility(IPersonFactory)
        attrs = self.person_attrs
        new_person = make_person(username,
                                 attrs.get('first_name'),
                                 attrs.get('last_name'))
        app['persons'][username] = new_person
        self.update_attrs(new_person)

    @property
    def groups(self):
        """LDAPGroup entries that list this person as a member.

        Only groups whose id appears as a posix group id in the client's
        posixGroups are considered.  Ids are compared case-insensitively.
        The result is memoized until the next ``update()``.
        """
        if self.cached_groups is not None:
            return self.cached_groups
        member_id = self.group_member_id
        if member_id is None:
            return ()

        client = self.client
        known_ids = set([posix_id.lower()
                         for _year, _name, posix_id in client.posixGroups])
        me = member_id.lower()
        matches = []
        for query in client.groupQueries:
            for entry in client.search(query):
                group = LDAPGroup(entry, client)
                group.update()
                gid = group.group_id
                if not gid or gid.lower() not in known_ids:
                    continue
                if me in [m.lower() for m in group.member_ids]:
                    matches.append(group)
        self.cached_groups = tuple(matches)
        return self.cached_groups

    @staticmethod
    def _assign_changed(target, values):
        """setattr each item of *values* on *target* where it differs."""
        for name, value in values.items():
            if getattr(target, name, None) != value:
                setattr(target, name, value)

    def update_attrs(self, person):
        """Copy collected person and contact values onto *person*."""
        self._assign_changed(person, self.person_attrs)
        self._assign_changed(IContact(person), self.contact_attrs)

    def update_groups(self, person):
        """Add *person* to the SchoolTool groups mapped from its LDAP groups.

        Uses the client's (school year, group name, posix group id) triples.
        An empty school year means the active one.  Missing years and
        groups are skipped; existing memberships are never removed.
        """
        ldap_ids = set([g.group_id.lower() for g in self.groups])
        wanted = {}
        for year, st_group, posix_id in self.client.posixGroups:
            if posix_id.lower() in ldap_ids:
                wanted.setdefault(year, []).append(st_group)

        syc = ISchoolYearContainer(ISchoolToolApplication(None))
        bare_person = removeSecurityProxy(person)

        for year_id, group_names in wanted.items():
            if year_id == '':
                school_year = syc.getActiveSchoolYear()
            else:
                school_year = syc.get(year_id)
            if school_year is None:
                continue
            container = IGroupContainer(school_year)
            for name in group_names:
                group = container.get(name)
                if group is None:
                    continue
                bare_group = removeSecurityProxy(group)
                if bare_group not in bare_person.groups:
                    person.groups.add(bare_group)


class PersonLDAPPersonPart(SchoolToolLDAPObjectPart):
    """Reads cn, sn and telephoneNumber from a 'person' entry."""

    # Attributes of the LDAP 'person' object class.
    all_attrs = (
        'sn', # surname
        'cn', # common name
        'userPassword', 'telephoneNumber', 'seeAlso', 'description')

    def oneline(self, t):
        """Join the values of a multi-valued attribute into one line.

        Each value is decoded from UTF-8 and the values are joined with
        spaces.  An empty or missing attribute yields u''.
        """
        if t:
            return u' '.join([value.decode('UTF-8') for value in t]).strip()
        return u''

    def update(self):
        super(PersonLDAPPersonPart, self).update()
        data = self.data
        self.target.person_attrs.update(
            first_name=self.oneline(data.cn),
            last_name=self.oneline(data.sn))
        self.target.contact_attrs.update(
            work_phone=self.oneline(data.telephoneNumber) or None)


class PersonLDAPOrganizationalPersonPart(PersonLDAPPersonPart):
    """Adds organizationalPerson values: title, address, city, state, zip."""

    all_attrs = PersonLDAPPersonPart.all_attrs + (
        'title',
        'x121Address', 'registeredAddress', 'destinationIndicator',
        'preferredDeliveryMethod', 'telexNumber', 'teletexTerminalIdentifier',
        # telephoneNumber is already listed by PersonLDAPPersonPart.
        'internationaliSDNNumber',
        'facsimileTelephoneNumber',
        'street',
        'postOfficeBox', 'postalCode', 'postalAddress',
        'physicalDeliveryOfficeName',
        'ou', # organizational unit
        'st', # state or province
        'l', # town or city
        )

    def update(self):
        super(PersonLDAPOrganizationalPersonPart, self).update()
        data = self.data
        # postalAddress separates its lines with '$'.  Collapse them (and
        # any runs of whitespace) into one line, falling back to street.
        address = u' '.join(
            self.oneline(data.postalAddress).replace('$', ' ').split())
        address = address or self.oneline(data.street)
        self.target.person_attrs.update({
            'prefix': self.oneline(data.title) or None,
            })
        self.target.contact_attrs.update({
            'state': self.oneline(data.st) or None,
            'city': self.oneline(data.l) or None,
            # Empty -> None, like the other contact fields.
            'address_line_1': address or None,
            'postal_code': self.oneline(data.postalCode) or None,
            })


class PersonLDAPInetOrgPersonPart(PersonLDAPOrganizationalPersonPart):
    """Adds inetOrgPerson values: uid, givenName, phones and mail."""

    all_attrs = PersonLDAPOrganizationalPersonPart.all_attrs + (
        'audio', 'businessCategory', 'carLicense', 'departmentNumber',
        'displayName',
        'employeeNumber', 'employeeType',
        'givenName',
        'homePhone',
        'homePostalAddress',
        'initials',
        'jpegPhoto', # we could use this
        'labeledURI',
        'mail',
        'manager',
        'mobile',
        'o', # organization name
        'pager', 'photo',
        'roomNumber', 'secretary',
        'uid',
        'userCertificate', 'x500uniqueIdentifier',
        'preferredLanguage', # maybe we can use this
        'userSMIMECertificate', 'userPKCS12')

    def update(self):
        super(PersonLDAPInetOrgPersonPart, self).update()
        data = self.data
        person_attrs = self.target.person_attrs
        person_attrs.update({
            # NOTE(review): LDAPPerson.update_attrs copies this onto the
            # person object; confirm a uid that differs from the login name
            # is meant to overwrite the person's username.
            'username': (self.oneline(data.uid) or
                         person_attrs.get('username')),
            # Prefer givenName over the cn-derived value set by the base part.
            'first_name': (self.oneline(data.givenName) or
                           person_attrs.get('first_name')),
            })
        self.target.contact_attrs.update({
            'home_phone': self.oneline(data.homePhone) or None,
            'mobile_phone': self.oneline(data.mobile) or None,
            'email': self.oneline(data.mail) or None,
            })


class PersonLDAPPosixAccountPart(SchoolToolLDAPObjectPart):
    """Takes a posixAccount's group-membership id from its uid."""

    @property
    def registered_groups(self):
        """Map posix group id -> (school year id, SchoolTool group id).

        The client's posixGroups are (school year, group name, posix group
        id) triples, the order LDAPPerson.update_groups unpacks them in.
        """
        return dict([
            (posix_id, (year_id, group_id))
            for year_id, group_id, posix_id in self.target.client.posixGroups])

    def update(self):
        super(PersonLDAPPosixAccountPart, self).update()
        uids = self.data.uid
        # First uid value; an empty or missing uid leaves no member id.
        self.target.group_member_id = (uids[0] or None) if uids else None


class LDAPGroup(SchoolToolLDAPObject):
    """An LDAP group entry; parts fill in ``group_id`` and ``member_ids``."""

    # Member identifiers collected by the object-class parts.
    member_ids = None
    # Group identifier (gidNumber for posixGroup entries).
    group_id = None

    def update(self):
        self.group_id = None
        self.member_ids = []
        super(LDAPGroup, self).update()

    def __repr__(self):
        name = type(self).__name__
        return '%s(%r)' % (name, self.group_id)


class LDAPPosixGroupPart(SchoolToolLDAPObjectPart):
    """Reads gidNumber and memberUid from a posixGroup entry."""

    def update(self):
        super(LDAPPosixGroupPart, self).update()
        gid_numbers = self.data.gidNumber
        # First gidNumber value, or '' when the entry has none.
        if gid_numbers:
            self.target.group_id = gid_numbers[0]
        else:
            self.target.group_id = ''
        extra_members = list(self.data.memberUID or ())
        self.target.member_ids += extra_members
