ldap_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. """LDAP authentication service for BamBuddy (#794).
  2. Supports:
  3. - LDAP bind authentication (simple bind with user's credentials)
  4. - StartTLS, LDAPS, and plaintext connections
  5. - User search with configurable filter
  6. - Group membership resolution for role mapping
  7. """
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field

from ldap3 import ALL, SUBTREE, Connection, Server, Tls
  13. logger = logging.getLogger(__name__)
  14. @dataclass
  15. class LDAPUserInfo:
  16. """User information retrieved from LDAP after successful authentication."""
  17. username: str
  18. email: str | None
  19. display_name: str | None
  20. groups: list[str] # List of group DNs the user belongs to
  21. @dataclass
  22. class LDAPConfig:
  23. """LDAP configuration parsed from settings."""
  24. server_url: str
  25. bind_dn: str
  26. bind_password: str
  27. search_base: str
  28. user_filter: str # e.g. "(sAMAccountName={username})"
  29. security: str # "none", "starttls", "ldaps"
  30. group_mapping: dict[str, str] # LDAP group DN -> BamBuddy group name
  31. auto_provision: bool
  32. ca_cert_path: str # Path to CA certificate file (empty = skip verification)
  33. default_group: str # Fallback BamBuddy group assigned when user has no mapped groups (empty = no fallback)
  34. def parse_ldap_config(settings: dict[str, str]) -> LDAPConfig | None:
  35. """Parse LDAP config from settings key-value pairs. Returns None if LDAP not enabled."""
  36. if settings.get("ldap_enabled", "false").lower() != "true":
  37. return None
  38. server_url = settings.get("ldap_server_url", "").strip()
  39. if not server_url:
  40. return None
  41. group_mapping_raw = settings.get("ldap_group_mapping", "")
  42. try:
  43. group_mapping = json.loads(group_mapping_raw) if group_mapping_raw else {}
  44. except json.JSONDecodeError:
  45. group_mapping = {}
  46. return LDAPConfig(
  47. server_url=server_url,
  48. bind_dn=settings.get("ldap_bind_dn", "").strip(),
  49. bind_password=settings.get("ldap_bind_password", ""),
  50. search_base=settings.get("ldap_search_base", "").strip(),
  51. user_filter=settings.get("ldap_user_filter", "(sAMAccountName={username})").strip(),
  52. security=settings.get("ldap_security", "starttls").strip(),
  53. group_mapping=group_mapping if isinstance(group_mapping, dict) else {},
  54. auto_provision=settings.get("ldap_auto_provision", "false").lower() == "true",
  55. ca_cert_path=settings.get("ldap_ca_cert_path", "").strip(),
  56. default_group=settings.get("ldap_default_group", "").strip(),
  57. )
  58. def _create_server(config: LDAPConfig) -> Server:
  59. """Create an ldap3 Server instance from config.
  60. Always uses TLS — either LDAPS (TLS from start) or StartTLS (upgrade after connect).
  61. Plaintext LDAP is not supported.
  62. """
  63. import ssl
  64. use_ssl = config.security == "ldaps" or config.server_url.startswith("ldaps://")
  65. if config.ca_cert_path:
  66. tls = Tls(validate=ssl.CERT_REQUIRED, ca_certs_file=config.ca_cert_path)
  67. else:
  68. tls = Tls(validate=ssl.CERT_NONE)
  69. return Server(config.server_url, use_ssl=use_ssl, tls=tls, get_info=ALL, connect_timeout=10)
  70. def authenticate_ldap_user(config: LDAPConfig, username: str, password: str) -> LDAPUserInfo | None:
  71. """Authenticate a user via LDAP bind.
  72. 1. Bind with service account to search for the user DN
  73. 2. Attempt bind with the user's DN and provided password
  74. 3. On success, retrieve user attributes and group memberships
  75. Returns LDAPUserInfo on success, None on failure.
  76. """
  77. if not password:
  78. return None
  79. server = _create_server(config)
  80. # Step 1: Service account bind + user search
  81. try:
  82. service_conn = Connection(
  83. server,
  84. user=config.bind_dn,
  85. password=config.bind_password,
  86. auto_bind=False,
  87. raise_exceptions=True,
  88. read_only=True,
  89. )
  90. service_conn.open()
  91. if config.security == "starttls" and not config.server_url.startswith("ldaps://"):
  92. service_conn.start_tls()
  93. service_conn.bind()
  94. except Exception as e:
  95. logger.warning("LDAP service account bind failed: %s", e)
  96. return None
  97. try:
  98. # Search for the user
  99. search_filter = config.user_filter.replace("{username}", _ldap_escape(username))
  100. service_conn.search(
  101. search_base=config.search_base,
  102. search_filter=search_filter,
  103. search_scope=SUBTREE,
  104. attributes=["*"],
  105. )
  106. if not service_conn.entries:
  107. logger.info("LDAP user not found: %s", username)
  108. return None
  109. user_entry = service_conn.entries[0]
  110. user_dn = str(user_entry.entry_dn)
  111. # Step 2: Bind as the user to verify password
  112. try:
  113. user_conn = Connection(
  114. server,
  115. user=user_dn,
  116. password=password,
  117. auto_bind=False,
  118. raise_exceptions=True,
  119. read_only=True,
  120. )
  121. user_conn.open()
  122. if config.security == "starttls" and not config.server_url.startswith("ldaps://"):
  123. user_conn.start_tls()
  124. user_conn.bind()
  125. user_conn.unbind()
  126. except Exception as e:
  127. logger.info("LDAP bind failed for user %s: %s", username, e)
  128. return None
  129. # Step 3: Extract user info
  130. email = str(user_entry.mail) if hasattr(user_entry, "mail") and user_entry.mail else None
  131. display_name = (
  132. str(user_entry.displayName) if hasattr(user_entry, "displayName") and user_entry.displayName else None
  133. )
  134. # Collect groups from memberOf attribute (Active Directory / groupOfNames)
  135. groups = (
  136. [str(g) for g in user_entry.memberOf] if hasattr(user_entry, "memberOf") and user_entry.memberOf else []
  137. )
  138. # Also search for POSIX groups (memberUid-based) using the service account
  139. canonical_username = username
  140. if hasattr(user_entry, "sAMAccountName") and user_entry.sAMAccountName:
  141. canonical_username = str(user_entry.sAMAccountName)
  142. elif hasattr(user_entry, "uid") and user_entry.uid:
  143. canonical_username = str(user_entry.uid)
  144. posix_filter = f"(&(objectClass=posixGroup)(memberUid={_ldap_escape(canonical_username)}))"
  145. service_conn.search(
  146. search_base=config.search_base,
  147. search_filter=posix_filter,
  148. search_scope=SUBTREE,
  149. attributes=["cn"],
  150. )
  151. for entry in service_conn.entries:
  152. groups.append(str(entry.entry_dn))
  153. # POSIX primary group: user's gidNumber matches a posixGroup's gidNumber.
  154. # Standard Unix semantics treat this as full group membership, so we need
  155. # to resolve it to a group DN alongside the memberUid results.
  156. if hasattr(user_entry, "gidNumber") and user_entry.gidNumber:
  157. primary_gid = str(user_entry.gidNumber)
  158. primary_filter = f"(&(objectClass=posixGroup)(gidNumber={_ldap_escape(primary_gid)}))"
  159. service_conn.search(
  160. search_base=config.search_base,
  161. search_filter=primary_filter,
  162. search_scope=SUBTREE,
  163. attributes=["cn"],
  164. )
  165. for entry in service_conn.entries:
  166. groups.append(str(entry.entry_dn))
  167. # Dedupe group DNs (user may be in a group via both memberUid and primary gidNumber).
  168. # Case-insensitive comparison — LDAP DNs are case-insensitive by spec.
  169. seen_lower: set[str] = set()
  170. deduped_groups: list[str] = []
  171. for g in groups:
  172. key = g.lower()
  173. if key not in seen_lower:
  174. seen_lower.add(key)
  175. deduped_groups.append(g)
  176. groups = deduped_groups
  177. logger.info(
  178. "LDAP authentication successful for user: %s (DN: %s, groups: %d)", canonical_username, user_dn, len(groups)
  179. )
  180. return LDAPUserInfo(
  181. username=canonical_username,
  182. email=email,
  183. display_name=display_name,
  184. groups=groups,
  185. )
  186. finally:
  187. service_conn.unbind()
  188. def resolve_group_mapping(ldap_groups: list[str], group_mapping: dict[str, str]) -> list[str]:
  189. """Map LDAP group DNs to BamBuddy group names.
  190. Returns list of BamBuddy group names that the user should be added to.
  191. Comparison is case-insensitive on the LDAP group DN.
  192. """
  193. if not group_mapping:
  194. return []
  195. # Build case-insensitive lookup
  196. mapping_lower = {k.lower(): v for k, v in group_mapping.items()}
  197. result = []
  198. for ldap_group in ldap_groups:
  199. bambuddy_group = mapping_lower.get(ldap_group.lower())
  200. if bambuddy_group:
  201. result.append(bambuddy_group)
  202. return result
  203. def test_ldap_connection(config: LDAPConfig) -> tuple[bool, str]:
  204. """Test LDAP connection and service account bind.
  205. Returns (success, message).
  206. """
  207. try:
  208. server = _create_server(config)
  209. conn = Connection(
  210. server,
  211. user=config.bind_dn,
  212. password=config.bind_password,
  213. auto_bind=False,
  214. raise_exceptions=True,
  215. read_only=True,
  216. )
  217. conn.open()
  218. if config.security == "starttls" and not config.server_url.startswith("ldaps://"):
  219. conn.start_tls()
  220. conn.bind()
  221. # Try a search to verify search base
  222. conn.search(
  223. search_base=config.search_base,
  224. search_filter="(objectClass=*)",
  225. search_scope=SUBTREE,
  226. size_limit=1,
  227. )
  228. conn.unbind()
  229. return True, "LDAP connection successful"
  230. except Exception as e:
  231. return False, f"LDAP connection failed: {e}"
  232. def _ldap_escape(value: str) -> str:
  233. """Escape special characters in LDAP search filter values (RFC 4515)."""
  234. replacements = {
  235. "\\": "\\5c",
  236. "*": "\\2a",
  237. "(": "\\28",
  238. ")": "\\29",
  239. "\x00": "\\00",
  240. }
  241. for char, escaped in replacements.items():
  242. value = value.replace(char, escaped)
  243. return value