"""Asyncio-based SMTP client for yagmail (module ``yagmail.async_core.aio``)."""

import asyncio
import socket
import ssl
from smtplib import (
    SMTPAuthenticationError,
    SMTPConnectError,
    SMTPDataError,
    SMTPException,
    SMTPHeloError,
    SMTPRecipientsRefused,
    SMTPSenderRefused,
    SMTPServerDisconnected,
)
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import yagmail
from yagmail.dkim import DKIM
from yagmail.headers import AddressInput


async def upgrade_to_tls(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    ssl_context: ssl.SSLContext,
    server_hostname: str
) -> None:
    """Upgrade an existing stream connection to TLS in place (Python 3.8+).

    On Python 3.11+ this delegates to ``StreamWriter.start_tls``. Older versions
    lack that method, so the transport is upgraded with ``loop.start_tls`` and
    the resulting TLS transport is patched into the reader, the writer and the
    stream protocol by hand.

    Args:
        reader: Stream reader of the connection being upgraded.
        writer: Stream writer of the same connection.
        ssl_context: TLS settings (certificate verification, protocol versions).
        server_hostname: Host name used for SNI and certificate matching.
    """
    if hasattr(writer, "start_tls"):
        await writer.start_tls(ssl_context, server_hostname=server_hostname)
        return

    # Fallback for Python 3.8 - 3.10. These private-attribute updates mirror
    # what StreamWriter.start_tls does on 3.11+.
    loop = asyncio.get_running_loop()
    await writer.drain()  # flush anything still buffered in plaintext first
    protocol = writer.transport.get_protocol()
    new_transport = await loop.start_tls(
        writer.transport,
        protocol,
        ssl_context,
        server_side=False,
        server_hostname=server_hostname,
    )
    writer._transport = new_transport  # type: ignore[attr-defined]
    reader._transport = new_transport  # type: ignore[attr-defined]
    if hasattr(protocol, "_replace_writer"):
        protocol._replace_writer(writer)  # type: ignore[attr-defined]
    else:
        # StreamReaderProtocol has no _replace_writer before 3.11, so update the
        # state it would have set. A stale _over_ssl=False makes eof_received()
        # return True over TLS, which asyncio warns has no effect with SSL.
        protocol._transport = new_transport  # type: ignore[attr-defined]
        protocol._over_ssl = (  # type: ignore[attr-defined]
            new_transport.get_extra_info("sslcontext") is not None
        )


class RawAsyncSMTP:
    """A minimal SMTP client that speaks the protocol directly over asyncio streams.

    Implements the greeting, EHLO (with HELO fallback), STARTTLS, AUTH PLAIN /
    AUTH LOGIN, AUTH XOAUTH2, one MAIL/RCPT/DATA transaction per ``sendmail``
    call, and QUIT. Connecting and every reply line are bounded by ``timeout``.
    """

    def __init__(
        self,
        host: str = "smtp.gmail.com",
        port: int = 465,
        local_hostname: Optional[str] = None,
        timeout: float = 30.0,
    ):
        """Store connection settings; no connection is opened until ``connect``.

        Args:
            host: SMTP server name; also used as the TLS server hostname.
            port: SMTP server port.
            local_hostname: Name sent with EHLO/HELO; ``socket.getfqdn()`` if omitted.
            timeout: Seconds allowed for connecting and for each reply line.
        """
        self.host = host
        self.port = port
        self.local_hostname = local_hostname or socket.getfqdn()
        self.timeout = timeout
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _check_single_line(text: str) -> None:
        """Raise ValueError if *text* contains CR or LF.

        A newline inside an address or hostname would let a caller smuggle
        extra SMTP commands onto the wire; smtplib rejects such input the same way.
        """
        if "\r" in text or "\n" in text:
            raise ValueError(f"prohibited newline characters in SMTP command: {text!r}")

    async def _command(self, line: str) -> Tuple[int, str]:
        """Send one command line (CRLF is appended) and return the server's reply."""
        self._check_single_line(line)
        if self.writer is None:
            raise SMTPServerDisconnected("Not connected")
        self.writer.write(f"{line}\r\n".encode())
        await self.writer.drain()
        return await self.read_response()

    @staticmethod
    def _has_extension(ehlo_reply: str, keyword: str) -> bool:
        """Return True if an EHLO reply advertises the extension *keyword*.

        Reply lines look like ``250-KEYWORD params`` (``250 KEYWORD`` on the last
        line); keywords are compared case-insensitively.
        """
        wanted = keyword.upper()
        for reply_line in ehlo_reply.split("\n"):
            words = reply_line[4:].split()
            if words and words[0].upper() == wanted:
                return True
        return False

    @staticmethod
    def _encode_data(msg: str) -> bytes:
        """Return *msg* as a DATA payload: CRLF line endings, dot-stuffed, ``.``-terminated.

        Only CRLF, CR and LF count as line breaks. ``str.splitlines`` would also
        split on form feed, U+2028 and similar characters and corrupt the body.
        """
        text = msg.replace("\r\n", "\n").replace("\r", "\n")
        if text.endswith("\n"):
            text = text[:-1]  # the terminator below supplies the final CRLF
        lines = text.split("\n") if text else []
        stuffed = ["." + line if line.startswith(".") else line for line in lines]
        return ("\r\n".join(stuffed) + "\r\n.\r\n").encode("utf-8")

    async def _reset_after_failure(self, code: int) -> None:
        """Clean up, best effort, after a refused MAIL/RCPT/DATA step (as smtplib does).

        A 421 reply means the server is closing the channel, so the connection is
        just closed. Otherwise RSET clears the half-finished transaction so the
        next ``sendmail`` on this connection starts cleanly. Errors are not
        raised, because the caller is about to raise the more informative refusal.
        """
        if code == 421:
            self.close()
            return
        try:
            await self._command("RSET")
        except (SMTPServerDisconnected, OSError):
            # The link is unusable anyway; drop it so later calls report "Not connected".
            self.close()

    # ---------------------------------------------------------------- protocol

    async def connect(self, use_tls: bool = False, start_tls: bool = False) -> None:
        """Open the connection, read the greeting, send EHLO and optionally STARTTLS.

        Args:
            use_tls: Wrap the socket in TLS from the start (implicit TLS).
            start_tls: Always issue STARTTLS after EHLO. Without it, STARTTLS is
                still issued on a plaintext connection whose EHLO reply
                advertises the STARTTLS extension.

        Raises:
            asyncio.TimeoutError: connecting took longer than ``timeout``.
            SMTPConnectError: the greeting code was not 220.
            SMTPHeloError / SMTPException: EHLO or STARTTLS was rejected.

        If anything after opening the stream fails, the stream is closed before
        the exception propagates.
        """
        ssl_context = ssl.create_default_context() if use_tls else None
        self.reader, self.writer = await asyncio.wait_for(
            asyncio.open_connection(self.host, self.port, ssl=ssl_context),
            timeout=self.timeout
        )
        try:
            code, message = await self.read_response()
            if code != 220:
                raise SMTPConnectError(code, message)

            _, ehlo_reply = await self.ehlo()
            # Capabilities come from the EHLO reply, not from the 220 greeting.
            if start_tls or (not use_tls and self._has_extension(ehlo_reply, "STARTTLS")):
                await self.starttls()
        except BaseException:
            self.close()
            raise

    async def read_response(self) -> Tuple[int, str]:
        """Read one (possibly multi-line) SMTP reply.

        Returns:
            ``(code, text)``. ``code`` is the numeric reply code, or -1 if none
            could be parsed. ``text`` is every received line, including the code
            prefixes, joined with newlines.

        Raises:
            SMTPServerDisconnected: not connected, or the server closed the stream.
            socket.timeout: no line arrived within ``timeout`` seconds.
        """
        if self.reader is None:
            raise SMTPServerDisconnected("Not connected")
        code = -1
        lines: List[str] = []
        while True:
            try:
                raw = await asyncio.wait_for(self.reader.readline(), timeout=self.timeout)
            except asyncio.TimeoutError:
                raise socket.timeout("SMTP connection timed out") from None
            if not raw:
                raise SMTPServerDisconnected("Connection closed unexpectedly")
            text = raw.decode("utf-8", errors="ignore").rstrip("\r\n")
            lines.append(text)
            if len(text) < 3:
                break
            try:
                code = int(text[:3])
            except ValueError:
                pass
            # "250-..." marks a continuation line; anything else ends the reply.
            if len(text) < 4 or text[3] != "-":
                break
        return code, "\n".join(lines)

    async def ehlo(self) -> Tuple[int, str]:
        """Greet the server with EHLO, falling back to HELO if EHLO is rejected.

        Returns:
            The ``(code, message)`` reply of the greeting that succeeded. After a
            successful EHLO the message lists the server's extensions.

        Raises:
            SMTPServerDisconnected: not connected.
            SMTPHeloError: both EHLO and HELO were rejected.
        """
        if self.writer is None:
            raise SMTPServerDisconnected("Not connected")
        code, message = await self._command(f"EHLO {self.local_hostname}")
        if code != 250:
            code, message = await self._command(f"HELO {self.local_hostname}")
            if code != 250:
                raise SMTPHeloError(code, message)
        return code, message

    async def starttls(self, ssl_context: Optional[ssl.SSLContext] = None) -> None:
        """Issue STARTTLS, upgrade the stream to TLS, then send EHLO again.

        Args:
            ssl_context: TLS settings; ``ssl.create_default_context()`` if omitted.

        Raises:
            SMTPServerDisconnected: not connected.
            SMTPException: the server did not answer STARTTLS with 220.
        """
        if self.writer is None or self.reader is None:
            raise SMTPServerDisconnected("Not connected")
        code, message = await self._command("STARTTLS")
        if code != 220:
            raise SMTPException(f"STARTTLS failed: {message}")

        if ssl_context is None:
            ssl_context = ssl.create_default_context()

        await upgrade_to_tls(self.reader, self.writer, ssl_context, self.host)
        # Capabilities may differ once the channel is encrypted, so greet again.
        await self.ehlo()

    async def login(self, user: str, password: str) -> None:
        """Authenticate with AUTH PLAIN, falling back to the AUTH LOGIN exchange.

        Raises:
            SMTPServerDisconnected: not connected.
            SMTPAuthenticationError: AUTH PLAIN was not accepted (reply other than
                235) and a step of the AUTH LOGIN exchange then failed.
        """
        if self.writer is None:
            raise SMTPServerDisconnected("Not connected")
        import base64

        def _b64(text: str) -> str:
            return base64.b64encode(text.encode("utf-8")).decode("utf-8")

        plain_token = _b64(f"\0{user}\0{password}")
        code, _ = await self._command(f"AUTH PLAIN {plain_token}")
        if code == 235:
            return

        code, message = await self._command("AUTH LOGIN")
        if code != 334:
            raise SMTPAuthenticationError(code, message)

        code, message = await self._command(_b64(user))
        if code != 334:
            raise SMTPAuthenticationError(code, message)

        code, message = await self._command(_b64(password))
        if code != 235:
            raise SMTPAuthenticationError(code, message)

    async def login_oauth2(self, user: str, auth_string: str) -> None:
        """Authenticate with AUTH XOAUTH2 using a prepared *auth_string*.

        *user* is accepted for interface symmetry but is not sent; the identity
        is expected to be encoded in *auth_string*.

        Raises:
            SMTPServerDisconnected: not connected.
            SMTPAuthenticationError: the final reply code was not 235.
        """
        if self.writer is None:
            raise SMTPServerDisconnected("Not connected")
        code, message = await self._command(f"AUTH XOAUTH2 {auth_string}")
        if code == 235:
            return
        if code == 334:
            # A continuation: answering with an empty line lets the server finish
            # the exchange with its final status, which is what gets reported.
            code, message = await self._command("")
        raise SMTPAuthenticationError(code, message)

    async def sendmail(
        self, from_addr: str, to_addrs: Union[str, List[str]], msg: str
    ) -> Dict[str, Any]:
        """Send *msg* from *from_addr* to *to_addrs* in one MAIL/RCPT/DATA transaction.

        Returns:
            A dict that maps each refused recipient to ``(code, reply_bytes)``.
            It is empty when every recipient was accepted.

        Raises:
            SMTPServerDisconnected: not connected.
            ValueError: an address contains CR or LF; nothing is sent in that case.
            SMTPSenderRefused: MAIL FROM was not answered with 250.
            SMTPRecipientsRefused: no recipient was accepted.
            SMTPDataError: DATA or the message body was rejected.

        After any refusal the transaction is reset with RSET (or the connection
        is closed on a 421 reply), so the connection can be reused.
        """
        if self.writer is None:
            raise SMTPServerDisconnected("Not connected")
        recipients = [to_addrs] if isinstance(to_addrs, str) else to_addrs

        mail_cmd = f"MAIL FROM:<{from_addr}>"
        rcpt_cmds = [(addr, f"RCPT TO:<{addr}>") for addr in recipients]
        # Validate every command up front so a bad address never leaves a
        # half-open transaction on the server.
        self._check_single_line(mail_cmd)
        for _, rcpt_cmd in rcpt_cmds:
            self._check_single_line(rcpt_cmd)

        code, message = await self._command(mail_cmd)
        if code != 250:
            await self._reset_after_failure(code)
            raise SMTPSenderRefused(code, message.encode("utf-8"), from_addr)

        refused: Dict[str, Any] = {}
        accepted = 0
        for addr, rcpt_cmd in rcpt_cmds:
            code, message = await self._command(rcpt_cmd)
            if code in (250, 251):
                accepted += 1
            else:
                refused[addr] = (code, message.encode("utf-8"))

        if accepted == 0:
            await self._reset_after_failure(code)
            raise SMTPRecipientsRefused(refused)

        code, message = await self._command("DATA")
        if code != 354:
            await self._reset_after_failure(code)
            raise SMTPDataError(code, message)

        self.writer.write(self._encode_data(msg))
        await self.writer.drain()
        code, message = await self.read_response()
        if code != 250:
            await self._reset_after_failure(code)
            raise SMTPDataError(code, message)

        return refused

    async def quit(self) -> None:
        """Send QUIT (best effort) and close the connection whatever the outcome."""
        if self.writer is None:
            return
        try:
            await self._command("QUIT")
        except Exception:
            # Best effort: the connection is being torn down either way.
            pass
        finally:
            self.close()

    def close(self) -> None:
        """Close the stream without waiting and forget the reader and writer."""
        if self.writer is not None:
            try:
                self.writer.close()
            except Exception:
                pass
            self.writer = None
            self.reader = None


class AsyncClient(yagmail.Client):
    """
    Asynchronous version of yagmail.Client.

    Provides non-blocking versions of login, send, send_unsent, and close using
    Python's built-in asyncio event loop and raw socket streams.
    """

    def __init__(
        self,
        user: Optional[str] = None,
        password: Optional[Union[str, Dict[str, Any]]] = None,
        host: str = "smtp.gmail.com",
        port: Optional[Union[int, str]] = None,
        smtp_starttls: Optional[Union[bool, dict]] = None,
        smtp_ssl: bool = True,
        smtp_set_debuglevel: int = 0,
        smtp_skip_login: bool = False,
        encoding: str = "utf-8",
        oauth2_file: Optional[str] = None,
        soft_email_validation: bool = True,
        dkim: Optional[DKIM] = None,
        **kwargs: Any
    ):
        """Configure the client; all arguments are forwarded to ``yagmail.Client``.

        No connection is opened here. ``login`` (or ``async with``) connects.
        """
        super().__init__(
            user=user,
            password=password,
            host=host,
            port=port,
            smtp_starttls=smtp_starttls,
            smtp_ssl=smtp_ssl,
            smtp_set_debuglevel=smtp_set_debuglevel,
            smtp_skip_login=smtp_skip_login,
            encoding=encoding,
            oauth2_file=oauth2_file,
            soft_email_validation=soft_email_validation,
            dkim=dkim,
            **kwargs
        )
        self.is_closed = True
        self.smtp: Optional[RawAsyncSMTP] = None  # type: ignore[assignment]

    @property
    def send_lock(self) -> asyncio.Lock:
        """Lock serializing send/send_unsent/aclose; created lazily on first use."""
        if not hasattr(self, "_send_lock"):
            self._send_lock = asyncio.Lock()
        return self._send_lock

    @property
    def login_lock(self) -> asyncio.Lock:
        """Lock serializing connection setup; created lazily on first use."""
        if not hasattr(self, "_login_lock"):
            self._login_lock = asyncio.Lock()
        return self._login_lock

    async def login(self) -> None:  # type: ignore[override]
        """Connect and login to the SMTP server asynchronously.

        Does nothing if an open connection already exists. Implicit TLS is used
        when the port is 465; STARTTLS is requested when ``self.starttls`` is
        truthy. Authentication is skipped when ``smtp_skip_login`` is set.

        Raises:
            TypeError: OAuth2 is configured but the credentials are not a dict.
            Any connection or authentication error from ``RawAsyncSMTP``; the new
            connection is closed before the error propagates.
        """
        async with self.login_lock:
            if self.smtp is not None and not self.is_closed:
                return
            if self.smtp is not None:
                # A stale connection (e.g. after a disconnect) would otherwise be
                # replaced without closing its transport.
                self.smtp.close()

            use_tls = str(self.port) == "465"
            smtp = RawAsyncSMTP(
                host=self.host,
                port=int(self.port),
                timeout=self.kwargs.get("timeout", 30.0),
            )
            self.smtp = smtp
            try:
                await smtp.connect(use_tls=use_tls, start_tls=bool(self.starttls))
                if not self.smtp_skip_login:
                    if self.oauth2_file is not None:
                        if not isinstance(self.credentials, dict):
                            raise TypeError("OAuth2 credentials must be a dictionary")
                        auth_string = self.get_oauth_string(self.user, self.credentials)
                        await smtp.login_oauth2(self.user, auth_string)
                    else:
                        password = self.handle_password(
                            self.user,
                            self.credentials if isinstance(self.credentials, str) else None,
                        )
                        await smtp.login(self.user, password)
            except BaseException:
                # Do not leak a half-set-up connection; later sends see "Not connected".
                smtp.close()
                raise
            self.is_closed = False

    async def send(  # type: ignore[override]
        self,
        to: Optional[AddressInput] = None,
        subject: Optional[Union[str, List[str]]] = None,
        contents: Optional[Any] = None,
        attachments: Optional[Any] = None,
        cc: Optional[AddressInput] = None,
        bcc: Optional[AddressInput] = None,
        preview_only: bool = False,
        headers: Optional[Dict[str, str]] = None,
        prettify_html: bool = True,
        message_id: Optional[str] = None,
        group_messages: bool = True,
    ) -> Union[Tuple[List[str], str], Dict[str, Any], bool]:
        """Send an email asynchronously.

        The message is built by the inherited ``prepare_send``.

        Returns:
            ``(recipients, message)`` when ``preview_only`` is set. Otherwise the
            refused-recipients dict from the server, or ``False`` when every
            attempt failed and the message was queued in ``self.unsent``.
        """
        recipients, msg_strings = self.prepare_send(
            to=to,
            subject=subject,
            contents=contents,
            attachments=attachments,
            cc=cc,
            bcc=bcc,
            headers=headers,
            prettify_html=prettify_html,
            message_id=message_id,
            group_messages=group_messages,
        )
        if preview_only:
            return recipients, msg_strings

        async with self.send_lock:
            if self.smtp is None or self.is_closed:
                await self.login()
            return await self._attempt_send_async(recipients, msg_strings)

    async def _attempt_send_async(
        self, recipients: List[str], msg_strings: str
    ) -> Union[Dict[str, Any], bool]:
        """Try to send up to 3 times, reconnecting between attempts on link errors.

        A dropped connection can surface as ``SMTPServerDisconnected`` (EOF while
        reading) or as ``ConnectionError`` (e.g. ``drain()`` on a closed stream);
        both trigger a back-off of 3 s, then 6 s, and a reconnect. After the last
        failure the message is appended to ``self.unsent`` and ``False`` is returned.

        Raises:
            SMTPServerDisconnected: there is no connection object at all on entry.
        """
        if self.smtp is None:
            raise SMTPServerDisconnected("Not connected")
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            smtp = self.smtp
            try:
                if smtp is None:
                    raise SMTPServerDisconnected("Not connected")
                result = await smtp.sendmail(self.user, recipients, msg_strings)
            except (SMTPServerDisconnected, ConnectionError) as err:
                self.log.error(err)
                self.is_closed = True
                if attempt == max_attempts:
                    break  # no point sleeping when no retry follows
                await asyncio.sleep(attempt * 3)
                try:
                    await self.login()
                except Exception as reconnect_err:
                    self.log.error("Failed to reconnect during retry: %s", reconnect_err)
                continue
            self.log.info("Message sent to %s", recipients)
            self.num_mail_sent += 1
            return result

        self.unsent.append((recipients, msg_strings))
        return False

    async def send_unsent(self) -> None:  # type: ignore[override]
        """Attempt to send unsent emails asynchronously.

        The queue is swapped out first; messages that fail again are re-queued
        by ``_attempt_send_async``.
        """
        async with self.send_lock:
            if self.smtp is None or self.is_closed:
                await self.login()
            unsent_copy = list(self.unsent)
            self.unsent.clear()
            for recipients, msg_strings in unsent_copy:
                await self._attempt_send_async(recipients, msg_strings)

    async def close(self) -> None:  # type: ignore[override]
        """Synchronous-like close method that raises error to match aioyagmail API."""
        raise ValueError("Should be `async with` or use `await aclose()`")

    async def aclose(self) -> None:
        """Close the SMTP connection asynchronously (QUIT, then drop the stream)."""
        async with self.send_lock:
            self.is_closed = True
            if self.smtp is not None:
                try:
                    await self.smtp.quit()
                except Exception:
                    pass  # best effort: the connection is discarded either way
                finally:
                    self.smtp = None

    async def __aenter__(self) -> "AsyncClient":
        """Log in and return the client for use in ``async with``."""
        await self.login()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
        """Close the connection on exit; exceptions are never suppressed."""
        if not self.is_closed:
            await self.aclose()
        return False

    def __del__(self) -> None:
        """Synchronously drop any remaining connection when garbage-collected."""
        try:
            self.is_closed = True
            if self.smtp is not None:
                # Also covers a leftover connection object after a failed reconnect.
                self.smtp.close()
                self.smtp = None
        except Exception:
            # E.g. a half-initialised instance without these attributes.
            pass
# For backward compatibility: older code imports the client under these names.
AsyncSMTP = AsyncClient
AIOSMTP = AsyncClient