amqtt
Advanced tools
| """Module for contributed plugins.""" | ||
| from dataclasses import asdict, is_dataclass | ||
| from typing import Any, TypeVar | ||
| from sqlalchemy import JSON, TypeDecorator | ||
| T = TypeVar("T") | ||
class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
    """SQLAlchemy column type that persists a list of dataclass instances as JSON.

    Going into the database each dataclass instance is flattened to a plain
    dict; coming back out, each stored dict is rebuilt into the configured
    dataclass type.
    """

    impl = JSON
    cache_ok = True

    def __init__(self, dataclass_type: type[T]) -> None:
        """Remember the dataclass used to (de)serialize column values.

        Raises:
            TypeError: if `dataclass_type` is not a dataclass.
        """
        if not is_dataclass(dataclass_type):
            msg = f"{dataclass_type} must be a dataclass type"
            raise TypeError(msg)
        self.dataclass_type = dataclass_type
        super().__init__()

    def process_bind_param(self, value: list[Any] | None, dialect: Any) -> list[dict[str, Any]] | None:
        """Convert dataclass instances to JSON-serializable dicts (Python -> DB)."""
        return None if value is None else [asdict(entry) for entry in value]

    def process_result_value(self, value: list[dict[str, Any]] | None, dialect: Any) -> list[Any] | None:
        """Rebuild dataclass instances from the stored dicts (DB -> Python)."""
        return None if value is None else [self.dataclass_type(**row) for row in value]

    def process_literal_param(self, value: Any, dialect: Any) -> Any:
        """Render literal parameters unchanged (required by SQLAlchemy)."""
        return value

    @property
    def python_type(self) -> type:
        """Python-side type handled by this column (required by TypeEngine)."""
        return list
| """Plugin to determine authentication of clients with DB storage.""" | ||
| from dataclasses import dataclass | ||
| import click | ||
# Compatibility shim: StrEnum was added to the stdlib in Python 3.11.
try:
    from enum import StrEnum
except ImportError:
    # support for python 3.10
    from enum import Enum

    class StrEnum(str, Enum):  # type: ignore[no-redef]
        # Minimal stand-in: members are both str and Enum, like 3.11's StrEnum.
        pass
| from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin | ||
class DBType(StrEnum):
    """Enumeration for supported relational databases."""

    MARIA = "mariadb"  # MariaDB — mapped to the MySQL async driver in _db_map
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    SQLITE = "sqlite"  # file-based; no host/port/credentials required
@dataclass
class DBInfo:
    """SQLAlchemy database information."""

    # Driver prefix of the SQLAlchemy URL, e.g. "postgresql+asyncpg".
    connect_str: str
    # Default port for the backend; None for file-based sqlite.
    connect_port: int | None
# Map each supported DBType to its async SQLAlchemy driver prefix and default port.
_db_map = {
    DBType.MARIA: DBInfo("mysql+aiomysql", 3306),
    DBType.MYSQL: DBInfo("mysql+aiomysql", 3306),
    DBType.POSTGRESQL: DBInfo("postgresql+asyncpg", 5432),
    DBType.SQLITE: DBInfo("sqlite+aiosqlite", None)
}
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
    """Create sqlalchemy database connection string.

    Args:
        db_type: which database backend to connect to.
        db_username: login user (ignored for sqlite).
        db_host: database server host (ignored for sqlite).
        db_port: explicit port; falls back to the backend default when falsy.
        db_filename: database file path (sqlite only).

    Returns:
        An async-driver SQLAlchemy URL. For server-based backends, prompts
        interactively for the database password.
    """
    db_info = _db_map[db_type]
    if db_type == DBType.SQLITE:
        return f"{db_info.connect_str}:///{db_filename}"
    db_password = click.prompt("Enter the db password (press enter for none)", hide_input=True)
    # The ":" separator belongs to `pwd` itself; a literal colon in the URL
    # template produced a malformed "user::password@host" when a password was set.
    pwd = f":{db_password}" if db_password else ""
    return f"{db_info.connect_str}://{db_username}{pwd}@{db_host}:{db_port or db_info.connect_port}"


__all__ = ["DBType", "TopicAuthDBPlugin", "UserAuthDBPlugin", "db_connection_str"]
| from collections.abc import Iterator | ||
| import logging | ||
| from sqlalchemy import select | ||
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine | ||
| from amqtt.contexts import Action | ||
| from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth | ||
| from amqtt.errors import MQTTError | ||
| logger = logging.getLogger(__name__) | ||
class UserManager:
    """Async CRUD helper for `UserAuth` rows (username + hashed password)."""

    def __init__(self, connection: str) -> None:
        # `connection` is a SQLAlchemy async connection URL (e.g. sqlite+aiosqlite:///auth.db).
        self._engine = create_async_engine(connection)
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema."""
        # Non-destructive: create_all only creates missing tables.
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
        # Fetch the row for `username` or raise MQTTError when it doesn't exist.
        stmt = select(UserAuth).filter(UserAuth.username == username)
        user_auth = await db_session.scalar(stmt)
        if not user_auth:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)
        return user_auth

    async def get_user_auth(self, username: str) -> UserAuth | None:
        """Retrieve a user by username.

        Returns:
            The matching `UserAuth`, or None when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_user_auths(self) -> Iterator[UserAuth]:
        """Return all clients, ordered by username."""
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).order_by(UserAuth.username)
            users = await db_session.scalars(stmt)
            # NOTE(review): scalars() returns a result object that is always
            # truthy, so this empty-table branch likely never fires — confirm.
            if not users:
                msg = "No users exist."
                logger.info(msg)
                raise MQTTError(msg)
            return users

    async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
        """Create a new user.

        Raises:
            MQTTError: if the username already exists.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(UserAuth).filter(UserAuth.username == username)
            user_auth = await db_session.scalar(stmt)
            if user_auth:
                msg = f"Username '{username}' already exists."
                logger.info(msg)
                raise MQTTError(msg)
            user_auth = UserAuth(username=username)
            # Assigning `password` hashes the plain text via the model's write-only setter.
            user_auth.password = plain_password
            db_session.add(user_auth)
            await db_session.commit()
            await db_session.flush()
            return user_auth

    async def delete_user_auth(self, username: str) -> UserAuth | None:
        """Delete a user.

        Returns:
            The deleted `UserAuth`, or None when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                user_auth = await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None
            await db_session.delete(user_auth)
            await db_session.commit()
            await db_session.flush()
            return user_auth

    async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
        """Change a user's password.

        Raises:
            MQTTError: if the username doesn't exist (propagated from lookup).
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            user_auth = await self._get_auth_or_raise(db_session, username)
            # Setter re-hashes the new plain-text password.
            user_auth.password = plain_password
            await db_session.commit()
            await db_session.flush()
            return user_auth
class TopicManager:
    """Async CRUD helper for `TopicAuth` rows (per-client topic ACLs)."""

    def __init__(self, connection: str) -> None:
        # `connection` is a SQLAlchemy async connection URL.
        self._engine = create_async_engine(connection)
        self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def db_sync(self) -> None:
        """Sync the database schema."""
        # Non-destructive: create_all only creates missing tables.
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @staticmethod
    async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
        # Fetch the ACL row for `username` or raise MQTTError when it doesn't exist.
        stmt = select(TopicAuth).filter(TopicAuth.username == username)
        topic_auth = await db_session.scalar(stmt)
        if not topic_auth:
            msg = f"Username '{username}' doesn't exist."
            logger.debug(msg)
            raise MQTTError(msg)
        return topic_auth

    @staticmethod
    def _field_name(action: Action) -> str:
        # Column name holding the ACL for the given action, e.g. "publish_acl".
        return f"{action}_acl"

    async def create_topic_auth(self, username: str) -> TopicAuth | None:
        """Create an empty topic-authorization record for a client.

        Raises:
            MQTTError: if the username already exists.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).filter(TopicAuth.username == username)
            topic_auth = await db_session.scalar(stmt)
            if topic_auth:
                msg = f"Username '{username}' already exists."
                raise MQTTError(msg)
            topic_auth = TopicAuth(username=username)
            db_session.add(topic_auth)
            await db_session.commit()
            await db_session.flush()
            return topic_auth

    async def get_topic_auth(self, username: str) -> TopicAuth | None:
        """Retrieve a client's allowed-topics record by username.

        Returns:
            The matching `TopicAuth`, or None when the username is unknown.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            try:
                return await self._get_auth_or_raise(db_session, username)
            except MQTTError:
                return None

    async def list_topic_auths(self) -> Iterator[TopicAuth]:
        """Return all authorized clients, ordered by username."""
        async with self._db_session_maker() as db_session, db_session.begin():
            stmt = select(TopicAuth).order_by(TopicAuth.username)
            topics = await db_session.scalars(stmt)
            # NOTE(review): scalars() returns a result object that is always
            # truthy, so this empty-table branch likely never fires — confirm.
            if not topics:
                msg = "No topics exist."
                logger.info(msg)
                raise MQTTError(msg)
            return topics

    async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Add allowed topic for the given action to a user's ACL.

        Raises:
            MQTTError: when publishing to a "$" topic is requested, or the
                username is unknown.
        """
        if action == Action.PUBLISH and topic.startswith("$"):
            msg = "MQTT does not allow clients to publish to $ topics."
            raise MQTTError(msg)
        async with self._db_session_maker() as db_session, db_session.begin():
            user_auth = await self._get_auth_or_raise(db_session, username)
            topic_list = getattr(user_auth, self._field_name(action))
            # Rebuild the list so the JSON column registers the change.
            updated_list = [*topic_list, AllowedTopic(topic)]
            setattr(user_auth, self._field_name(action), updated_list)
            await db_session.commit()
            await db_session.flush()
            return updated_list

    async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
        """Remove topic from the given action's ACL for a user.

        Raises:
            MQTTError: if the username is unknown or the topic isn't present.
        """
        async with self._db_session_maker() as db_session, db_session.begin():
            topic_auth = await self._get_auth_or_raise(db_session, username)
            topic_list = topic_auth.get_topic_list(action)
            if AllowedTopic(topic) not in topic_list:
                msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
                logger.debug(msg)
                raise MQTTError(msg)
            # Rebuild the list so the JSON column registers the change.
            updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != AllowedTopic(topic)]
            setattr(topic_auth, f"{action}_acl", updated_list)
            await db_session.commit()
            await db_session.flush()
            return updated_list
| from dataclasses import dataclass | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | ||
| from sqlalchemy import String | ||
| from sqlalchemy.ext.hybrid import hybrid_property | ||
| from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
| from amqtt.contexts import Action | ||
| from amqtt.contrib import DataClassListJSON | ||
| from amqtt.plugins import TopicMatcher | ||
| if TYPE_CHECKING: | ||
| from passlib.context import CryptContext | ||
| logger = logging.getLogger(__name__) | ||
| matcher = TopicMatcher() | ||
@dataclass
class AllowedTopic:
    """Wrapper around an allowed topic filter with wildcard-aware comparison."""

    topic: str

    def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
        """Determine `in`."""
        # Use the == operator (not __eq__ directly) so that NotImplemented is
        # resolved to a proper bool for unsupported operand types.
        return self == item

    def __eq__(self, item: object) -> bool:
        """Determine `==` or `!=`.

        Strings are compared via the topic matcher (wildcard-aware); another
        AllowedTopic compares by exact topic string.
        """
        if isinstance(item, str):
            return matcher.is_topic_allowed(item, self.topic)
        if isinstance(item, AllowedTopic):
            return item.topic == self.topic
        # Returning NotImplemented (rather than raising, as before) follows the
        # data-model contract and lets ==/!= against other types fall back to
        # the reflected comparison instead of crashing.
        return NotImplemented

    def __str__(self) -> str:
        """Display topic."""
        return self.topic

    def __repr__(self) -> str:
        """Display topic."""
        return self.topic
class PasswordHasher:
    """Singleton holding the process-wide passlib CryptContext.

    The auth plugin configures the CryptContext once at startup; the ORM
    models then reuse it for password hashing and verification.
    """

    _instance: Optional["PasswordHasher"] = None

    def __init__(self) -> None:
        # Guard so re-instantiation doesn't wipe an already-configured context.
        if not hasattr(self, "_crypt_context"):
            self._crypt_context: CryptContext | None = None

    def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "PasswordHasher":
        if cls._instance is None:
            # object.__new__ accepts no extra arguments; forwarding
            # *args/**kwargs (as before) raised TypeError whenever any were given.
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def crypt_context(self) -> "CryptContext":
        """Return the configured CryptContext.

        Raises:
            ValueError: if the context has not been set yet.
        """
        if not self._crypt_context:
            msg = "CryptContext is empty"
            raise ValueError(msg)
        return self._crypt_context

    @crypt_context.setter
    def crypt_context(self, value: "CryptContext") -> None:
        self._crypt_context = value
class Base(DeclarativeBase):
    """Declarative base shared by all auth_db ORM models."""

    pass
class UserAuth(Base):
    """ORM model mapping a username to a hashed password and per-action topic ACLs."""

    __tablename__ = "user_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique login name for the MQTT client.
    username: Mapped[str] = mapped_column(String, unique=True)
    # Only the hash is persisted (column "password_hash"); plain passwords never stored.
    _password_hash: Mapped[str] = mapped_column("password_hash", String(128))
    # Allowed topics per action, serialized as JSON lists of AllowedTopic.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    @hybrid_property
    def password(self) -> None:
        """Write-only attribute: reading always raises."""
        msg = "Password is write-only"
        raise AttributeError(msg)

    @password.inplace.setter  # type: ignore[arg-type]
    def _password_setter(self, plain_password: str) -> None:
        # Hash with the process-wide CryptContext configured by the plugin.
        self._password_hash = PasswordHasher().crypt_context.hash(plain_password)

    def verify_password(self, plain_password: str) -> bool:
        """Check `plain_password` against the stored hash."""
        return bool(PasswordHasher().crypt_context.verify(plain_password, self._password_hash))

    def __str__(self) -> str:
        """Display client id and password hash."""
        return f"'{self.username}' with password hash: {self._password_hash}"
class TopicAuth(Base):
    """ORM model mapping a username to its allowed topics for each MQTT action."""

    __tablename__ = "topic_auth"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique client id this ACL record belongs to.
    username: Mapped[str] = mapped_column(String, unique=True)
    # Allowed topics per action, serialized as JSON lists of AllowedTopic.
    publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
    receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)

    def get_topic_list(self, action: Action) -> list[AllowedTopic]:
        """Return the ACL list for the given action (publish/subscribe/receive)."""
        return cast("list[AllowedTopic]", getattr(self, f"{action}_acl"))

    def __str__(self) -> str:
        """Display client id and its per-action topic lists."""
        return f"""'{self.username}':
\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}
"""
| from dataclasses import dataclass, field | ||
| import logging | ||
| from passlib.context import CryptContext | ||
| from sqlalchemy.ext.asyncio import create_async_engine | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import Action | ||
| from amqtt.contrib.auth_db.managers import TopicManager, UserManager | ||
| from amqtt.contrib.auth_db.models import Base, PasswordHasher | ||
| from amqtt.errors import MQTTError | ||
| from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
def default_hash_scheme() -> list[str]:
    """Default passlib hash schemes for the Config dataclass (fresh list per call)."""
    schemes = ("argon2", "bcrypt", "pbkdf2_sha256", "scrypt")
    return list(schemes)
class UserAuthDBPlugin(BaseAuthPlugin):
    """Authenticate MQTT clients against username/password records stored in a database."""

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        # access the singleton and set the proper crypt context
        pwd_hasher = PasswordHasher()
        pwd_hasher.crypt_context = CryptContext(schemes=self.config.hash_schemes, deprecated="auto")
        self._user_manager = UserManager(self.config.connection)
        self._engine = create_async_engine(f"{self.config.connection}")

    async def on_broker_pre_start(self) -> None:
        """Sync the schema (if configured)."""
        if not self.config.sync_schema:
            return
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def authenticate(self, *, session: Session) -> bool | None:
        """Authenticate a client's session.

        Returns:
            False when credentials are missing, the user is unknown, or the
            password doesn't verify; True on success.
        """
        if not session.username or not session.password:
            return False
        user_auth = await self._user_manager.get_user_auth(session.username)
        if not user_auth:
            return False
        # NOTE(review): bool(session.password) is already guaranteed by the
        # guard above, so this reduces to the verify_password() call.
        return bool(session.password) and user_auth.verify_password(session.password)

    @dataclass
    class Config:
        """Configuration for DB authentication."""

        connection: str
        """SQLAlchemy connection string for the asyncio version of the database connector:

        - `mysql+aiomysql://user:password@host:port/dbname`
        - `postgresql+asyncpg://user:password@host:port/dbname`
        - `sqlite+aiosqlite:///dbfilename.db`
        """

        sync_schema: bool = False
        """Use SQLAlchemy to create / update the database schema."""

        hash_schemes: list[str] = field(default_factory=default_hash_scheme)
        """list of hash schemes to use for passwords"""
class TopicAuthDBPlugin(BaseTopicPlugin):
    """Filter topic access (publish/subscribe/receive) against per-client ACLs stored in a database."""

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        self._topic_manager = TopicManager(self.config.connection)
        self._engine = create_async_engine(f"{self.config.connection}")

    async def on_broker_pre_start(self) -> None:
        """Sync the schema (if configured)."""
        if not self.config.sync_schema:
            return
        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Return True/False when access can be decided, None to abstain.

        Abstains when session, username, topic or action is missing; denies
        when the client has no topic-authorization record.
        """
        if not session or not session.username or not topic or action is None:
            return None
        # get_topic_auth() returns None for unknown clients (it swallows
        # MQTTError internally), so the previous except-MQTTError branch never
        # fired and getattr(None, ...) raised AttributeError instead of denying.
        topic_auth = await self._topic_manager.get_topic_auth(session.username)
        if topic_auth is None:
            return False
        # Membership uses AllowedTopic.__eq__, which wildcard-matches the topic.
        return topic in topic_auth.get_topic_list(action)

    @dataclass
    class Config:
        """Configuration for DB topic filtering."""

        connection: str
        """SQLAlchemy connection string for the asyncio version of the database connector:

        - `mysql+aiomysql://user:password@host:port/dbname`
        - `postgresql+asyncpg://user:password@host:port/dbname`
        - `sqlite+aiosqlite:///dbfilename.db`
        """

        sync_schema: bool = False
        """Use SQLAlchemy to create / update the database schema."""
| import asyncio | ||
| import contextlib | ||
| import logging | ||
| from pathlib import Path | ||
| from typing import Annotated | ||
| import typer | ||
| from amqtt.contexts import Action | ||
| from amqtt.contrib.auth_db import DBType, db_connection_str | ||
| from amqtt.contrib.auth_db.managers import TopicManager, UserManager | ||
| from amqtt.errors import MQTTError | ||
| logging.basicConfig(level=logging.INFO, format="%(message)s") | ||
| logger = logging.getLogger(__name__) | ||
| topic_app = typer.Typer(no_args_is_help=True) | ||
@topic_app.callback()
def main(
    ctx: typer.Context,
    db_type: Annotated[DBType, typer.Option("--db", "-d", help="db type", count=False)],
    db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
    db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
    db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
    db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
    """Command line interface to add / remove topic authorization.

    Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
    password (if applicable).

    If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.TopicManager` which provides
    the underlying functionality to this command line interface.
    """
    is_sqlite = db_type == DBType.SQLITE
    sqlite_file_missing = is_sqlite and not Path(db_filename).exists()
    # "sync" is allowed to run against a missing sqlite file (it will create it);
    # every other subcommand requires the file to exist.
    if sqlite_file_missing and ctx.invoked_subcommand != "sync":
        logger.error(f"SQLite option could not find '{db_filename}'")
        raise typer.Exit(code=1)
    if not is_sqlite and not db_username:
        logger.error("DB access requires a username be provided.")
        raise typer.Exit(code=1)
    # Share the parsed connection options with subcommands via the typer context.
    ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@topic_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
    """Create the table and schema for username and topic lists for subscribe, publish or receive.

    Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
    """

    async def run_sync() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
        # Use TopicManager for consistency with the other topic subcommands
        # (db_sync creates all tables via Base.metadata either way).
        mgr = TopicManager(connect)
        try:
            await mgr.db_sync()
        except MQTTError as me:
            logger.critical("Could not sync schema on db.")
            raise typer.Exit(code=1) from me

    asyncio.run(run_sync())
    logger.info("Success: database synced.")
@topic_app.command(name="list")
def list_clients(ctx: typer.Context) -> None:
    """List all client IDs (in alphabetical order) with their allowed topics per action."""

    async def run_list() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
        mgr = TopicManager(connect)
        user_count = 0
        # NOTE(review): list_topic_auths() may raise MQTTError for an empty
        # table, which is unhandled here — confirm intended behavior.
        for user in await mgr.list_topic_auths():
            user_count += 1
            logger.info(user)
        if not user_count:
            logger.info("No client authorizations exist.")

    asyncio.run(run_list())
@topic_app.command(name="add")
def add_topic_allowance(
    ctx: typer.Context,
    topic: Annotated[str, typer.Argument(help="list of topics", show_default=False)],
    client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client", show_default=False)],
    action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow", show_default=False)]
) -> None:
    """Add an allowed topic to a client's ACL for the given action.

    Creates the client's topic-authorization record first if it doesn't exist yet.
    """

    async def run_add() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
                                    ctx.obj["filename"])
        mgr = TopicManager(connect)
        # Ignore "already exists" so this works for both new and known clients.
        with contextlib.suppress(MQTTError):
            await mgr.create_topic_auth(client_id)
        topic_auth = await mgr.get_topic_auth(client_id)
        if not topic_auth:
            logger.info(f"Topic auth doesn't exist for '{client_id}'")
            raise typer.Exit(code=1)
        # Compare raw topic strings to detect an exact duplicate before adding.
        if topic in [allowed_topic.topic for allowed_topic in topic_auth.get_topic_list(action)]:
            logger.info(f"Topic '{topic}' already exists for '{action}'.")
            raise typer.Exit(1)
        await mgr.add_allowed_topic(client_id, topic, action)
        logger.info(f"Success: topic '{topic}' added to {action} for '{client_id}'")

    asyncio.run(run_add())
@topic_app.command(name="rm")
def remove_topic_allowance(ctx: typer.Context,
            client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")],
            action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow")],
            topic: Annotated[str, typer.Argument(help="list of topics")]
            ) -> None:
    """Remove an allowed topic from a client's ACL for the given action."""

    async def run_remove() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
                                    ctx.obj["filename"])
        mgr = TopicManager(connect)
        topic_auth = await mgr.get_topic_auth(client_id)
        if not topic_auth:
            logger.info(f"client '{client_id}' doesn't exist.")
            raise typer.Exit(1)
        if topic not in getattr(topic_auth, f"{action}_acl"):
            logger.info(f"Error: topic '{topic}' not in the {action} allow list for {client_id}.")
            raise typer.Exit(1)
        try:
            await mgr.remove_allowed_topic(client_id, topic, action)
        except MQTTError as me:
            # Dropped the stray leading apostrophe that garbled this message.
            logger.info(f"Error: could not remove '{topic}' for client '{client_id}'.")
            raise typer.Exit(1) from me
        logger.info(f"Success: removed topic '{topic}' from {action} for '{client_id}'")

    asyncio.run(run_remove())
# Allow running the topic CLI directly as a script.
if __name__ == "__main__":
    topic_app()
| import asyncio | ||
| import logging | ||
| from pathlib import Path | ||
| from typing import Annotated | ||
| import click | ||
| import passlib | ||
| import typer | ||
| from amqtt.contrib.auth_db import DBType, db_connection_str | ||
| from amqtt.contrib.auth_db.managers import UserManager | ||
| from amqtt.errors import MQTTError | ||
| logging.basicConfig(level=logging.INFO, format="%(message)s") | ||
| logger = logging.getLogger(__name__) | ||
| user_app = typer.Typer(no_args_is_help=True) | ||
@user_app.callback()
def main(
    ctx: typer.Context,
    db_type: Annotated[DBType, typer.Option(..., "--db", "-d", help="db type", show_default=False)],
    db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
    db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
    db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
    db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
    """Command line interface to list, create, remove and add clients.

    Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
    password (if applicable) and the client id's password.

    If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.UserManager` which provides
    the underlying functionality to this command line interface.
    """
    # "sync" may run against a missing sqlite file (it will be created).
    if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
        pass
    elif db_type == DBType.SQLITE and not Path(db_filename).exists():
        logger.error(f"SQLite option could not find '{db_filename}'")
        raise typer.Exit(code=1)
    elif db_type != DBType.SQLITE and not db_username:
        logger.error("DB access requires a username be provided.")
        raise typer.Exit(code=1)
    # Share the parsed connection options with subcommands via the typer context.
    ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@user_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
    """Create the table and schema for username and hashed password.

    Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
    """

    async def run_sync() -> None:
        opts = ctx.obj
        connect = db_connection_str(opts["type"], opts["username"], opts["host"], opts["port"], opts["filename"])
        manager = UserManager(connect)
        try:
            await manager.db_sync()
        except MQTTError as me:
            logger.critical("Could not sync schema on db.")
            raise typer.Exit(code=1) from me

    asyncio.run(run_sync())
    logger.info("Success: database synced.")
@user_app.command(name="list")
def list_user_auths(ctx: typer.Context) -> None:
    """List all Client IDs (in alphabetical order). Will also display the hashed passwords."""

    async def run_list() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
        mgr = UserManager(connect)
        user_count = 0
        # NOTE(review): list_user_auths() may raise MQTTError for an empty
        # table, which is unhandled here — confirm intended behavior.
        for user in await mgr.list_user_auths():
            user_count += 1
            logger.info(user)
        if not user_count:
            logger.info("No client authentications exist.")

    asyncio.run(run_list())
@user_app.command(name="add")
def create_user_auth(
    ctx: typer.Context,
    client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
    """Create a new user with a client id and password (prompted)."""

    async def run_create() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
                                    ctx.obj["filename"])
        mgr = UserManager(connect)
        # Prompt instead of accepting the password on the command line.
        client_password = click.prompt("Enter the client's password", hide_input=True)
        if not client_password.strip():
            logger.info("Error: client password cannot be empty.")
            raise typer.Exit(1)
        try:
            user = await mgr.create_user_auth(client_id, client_password.strip())
        # NOTE(review): `import passlib` may not make the `exc` submodule
        # available; confirm `passlib.exc` resolves here or import it explicitly.
        except passlib.exc.MissingBackendError as mbe:
            logger.info(f"Please install backend: {mbe}")
            raise typer.Exit(code=1) from mbe
        if not user:
            logger.info(f"Error: could not create user: {client_id}")
            raise typer.Exit(code=1)
        logger.info(f"Success: created {user}")

    asyncio.run(run_create())
@user_app.command(name="rm")
def remove_user_auth(ctx: typer.Context,
                     client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")]) -> None:
    """Remove a client from the authentication database."""

    async def run_remove() -> None:
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
                                    ctx.obj["filename"])
        mgr = UserManager(connect)
        # Check existence first so the confirmation prompt only appears for real clients.
        user = await mgr.get_user_auth(client_id)
        if not user:
            logger.info(f"Error: client '{client_id}' does not exist.")
            raise typer.Exit(1)
        if not click.confirm(f"Please confirm the removal of '{client_id}'?"):
            raise typer.Exit(0)
        # delete_user_auth returns None if the record disappeared in the meantime.
        user = await mgr.delete_user_auth(client_id)
        if not user:
            logger.info(f"Error: client '{client_id}' does not exist.")
            raise typer.Exit(1)
        logger.info(f"Success: '{user.username}' was removed.")

    asyncio.run(run_remove())
@user_app.command(name="pwd")
def change_password(
    ctx: typer.Context,
    client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
    """Update a user's password (prompted)."""

    async def run_password() -> None:
        client_password = click.prompt("Enter the client's new password", hide_input=True)
        if not client_password.strip():
            logger.error("Error: client password cannot be empty.")
            raise typer.Exit(1)
        connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
                                    ctx.obj["filename"])
        mgr = UserManager(connect)
        try:
            await mgr.update_user_auth_password(client_id, client_password.strip())
        except MQTTError as me:
            # update_user_auth_password raises for an unknown client id;
            # surface a clean CLI error instead of an unhandled traceback.
            logger.error(f"Error: client '{client_id}' does not exist.")
            raise typer.Exit(1) from me
        logger.info(f"Success: client '{client_id}' password updated.")

    asyncio.run(run_password())
# Allow running the user CLI directly as a script.
if __name__ == "__main__":
    user_app()
| from dataclasses import dataclass | ||
| from datetime import datetime, timedelta | ||
| try: | ||
| from datetime import UTC | ||
| except ImportError: | ||
| # support for python 3.10 | ||
| from datetime import timezone | ||
| UTC = timezone.utc | ||
| from ipaddress import IPv4Address | ||
| import logging | ||
| from pathlib import Path | ||
| import re | ||
| from cryptography import x509 | ||
| from cryptography.hazmat.backends import default_backend | ||
| from cryptography.hazmat.primitives import hashes, serialization | ||
| from cryptography.hazmat.primitives.asymmetric import rsa | ||
| from cryptography.x509 import Certificate, CertificateSigningRequest | ||
| from cryptography.x509.oid import NameOID | ||
| from amqtt.plugins.base import BaseAuthPlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
class UserAuthCertPlugin(BaseAuthPlugin):
    """Uses a *signed* x509 certificate's `SubjectAlternativeName` or `SAN` to verify client authentication.

    Often used for IoT devices, this method provides the most secure form of identification. A root
    certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
    or a self-signed certificate -- is used to sign a private key and certificate for the server. Each device/client
    also gets a unique private key and certificate signed by the same CA certificate; also included in the device
    certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.

    Since both server and device certificates are signed by the same CA certificate, the client can
    verify the server's authenticity; and the server can verify the client's authenticity. And since
    the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.

    !!! note "URI and Client ID configuration"
        `uri_domain` configuration must be set to the same uri used to generate the device credentials
        when a device is connecting with private key and certificate, the `client_id` must
        match the device id used to generate the device credentials.

    Three scripts are available to help with the key generation and certificate signing: `ca_creds`,
    `server_creds` and `device_creds`.

    !!! note "Configuring broker & client for using Self-signed root CA"
        If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
        configured with `cafile` set to the `ca.crt`.
    """

    async def authenticate(self, *, session: Session) -> bool | None:
        """Verify the client's session using the provided client's x509 certificate.

        Returns `False` on any verification failure, and `None` (no decision) when
        no peer certificate was presented.
        """
        if not session.ssl_object:
            return False
        der_cert = session.ssl_object.getpeercert(binary_form=True)
        if der_cert:
            cert = x509.load_der_x509_certificate(der_cert, backend=default_backend())
            try:
                san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
                uris = san.value.get_values_for_type(x509.UniformResourceIdentifier)
                # guard: a SAN extension may contain no URI entries at all;
                # indexing `uris[0]` unconditionally would raise IndexError
                if not uris:
                    logger.warning("SAN extension contains no URI entries.")
                    return False
                if self.config.uri_domain not in uris[0]:
                    return False
                # the URI must be a spiffe id of the form spiffe://<domain>/device/<id>
                pattern = rf"^spiffe://{re.escape(self.config.uri_domain)}/device/([^/]+)$"
                match = re.match(pattern, uris[0])
                if not match:
                    return False
                # the device id embedded in the certificate must equal the MQTT client id
                return match.group(1) == session.client_id
            except x509.ExtensionNotFound:
                logger.warning("No SAN extension found.")
                return False
        # no peer certificate available: defer the decision
        return None

    @dataclass
    class Config:
        """Configuration for the CertificateAuthPlugin."""

        uri_domain: str
        """The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
def generate_root_creds(country: str, state: str, locality: str,
                        org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
    """Generate CA key and self-signed CA certificate.

    Args:
        country: two-letter country code for the subject.
        state: state or province name.
        locality: locality (city) name.
        org_name: organization name.
        cn: common name of the CA.

    Returns:
        The CA private key and the self-signed CA certificate.
    """
    # generate the CA private key (4096-bit for a long-lived root)
    ca_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=4096,
    )
    # Create certificate subject and issuer (self-signed, so they are identical)
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, country),
        x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
        x509.NameAttribute(NameOID.LOCALITY_NAME, locality),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
        x509.NameAttribute(NameOID.COMMON_NAME, cn),
    ])
    # derive both validity bounds from a single timestamp so the certificate is
    # valid for exactly 10 years (separate now() calls would skew the window)
    now = datetime.now(UTC)
    # Build the self-signed certificate
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(ca_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=3650))  # 10 years
        .add_extension(
            # mark the certificate as a CA so it may sign other certificates
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
            critical=False,
        )
        .add_extension(
            # CA usage only: certificate and CRL signing
            x509.KeyUsage(
                key_cert_sign=True,
                crl_sign=True,
                digital_signature=False,
                key_encipherment=False,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .sign(ca_key, hashes.SHA256())
    )
    return ca_key, cert
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
    """Generate server private key and server certificate-signing-request."""
    server_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    subject = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, country),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
        x509.NameAttribute(NameOID.COMMON_NAME, cn),
    ])
    # the server certificate answers for its hostname and local loopback
    alt_names = x509.SubjectAlternativeName([
        x509.DNSName(cn),
        x509.IPAddress(IPv4Address("127.0.0.1")),
    ])
    server_csr = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(subject)
        .add_extension(alt_names, critical=False)
        .sign(server_key, hashes.SHA256())
    )
    return server_key, server_csr
def generate_device_csr(country: str, org_name: str, common_name: str,
                        uri_san: str, dns_san: str
                        ) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
    """Generate a device key and a csr."""
    device_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    subject = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, country),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])
    # URI SAN carries the device identity; DNS SAN its hostname
    alt_names = x509.SubjectAlternativeName([
        x509.UniformResourceIdentifier(uri_san),
        x509.DNSName(dns_san),
    ])
    device_csr = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(subject)
        .add_extension(alt_names, critical=False)
        .sign(device_key, hashes.SHA256())
    )
    return device_key, device_csr
def sign_csr(csr: CertificateSigningRequest,
             ca_key: rsa.RSAPrivateKey,
             ca_cert: Certificate, validity_days: int = 365) -> Certificate:
    """Sign a csr with CA credentials.

    Args:
        csr: the certificate signing request (server or device).
        ca_key: the CA private key used to sign.
        ca_cert: the CA certificate providing the issuer name.
        validity_days: how long the issued certificate remains valid.

    Returns:
        The signed end-entity certificate.
    """
    # derive both validity bounds from one timestamp so the certificate is
    # valid for exactly `validity_days` (separate now() calls skew the window)
    now = datetime.now(UTC)
    return (
        x509.CertificateBuilder()
        .subject_name(csr.subject)
        .issuer_name(ca_cert.subject)
        .public_key(csr.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=validity_days))
        .add_extension(
            # end-entity certificate: not a CA
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        )
        .add_extension(
            # copy the SubjectAlternativeName straight from the CSR
            csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value,
            critical=False,
        )
        .add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()),  # type: ignore[arg-type]
            critical=False,
        )
        .sign(ca_key, hashes.SHA256())
    )
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
    """Load the CA private key and certificate from PEM-encoded files."""
    key_bytes = Path(ca_key_fn).read_bytes()
    ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
    ca_cert = x509.load_pem_x509_certificate(Path(ca_crt_fn).read_bytes())
    return ca_key, ca_cert
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
                      prefix: str, path: Path | None = None) -> None:
    """Create pem-encoded files for key and certificate."""
    out_dir = path or Path()
    # certificate -> <prefix>.crt
    (out_dir / f"{prefix}.crt").write_bytes(crt.public_bytes(serialization.Encoding.PEM))
    # private key -> <prefix>.key (unencrypted, traditional OpenSSL format)
    key_pem = key.private_bytes(
        serialization.Encoding.PEM,
        serialization.PrivateFormat.TraditionalOpenSSL,
        serialization.NoEncryption(),
    )
    (out_dir / f"{prefix}.key").write_bytes(key_pem)
| from dataclasses import dataclass | ||
| try: | ||
| from enum import StrEnum | ||
| except ImportError: | ||
| # support for python 3.10 | ||
| from enum import Enum | ||
| class StrEnum(str, Enum): # type: ignore[no-redef] | ||
| pass | ||
| import logging | ||
| from typing import Any | ||
| from aiohttp import ClientResponse, ClientSession, FormData | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import Action | ||
| from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
class ResponseMode(StrEnum):
    """How the HTTP server's reply is interpreted: status code only, JSON body, or plain text."""

    STATUS = "status"
    JSON = "json"
    TEXT = "text"
class RequestMethod(StrEnum):
    """HTTP verb used when contacting the auth/ACL server."""

    GET = "get"
    POST = "post"
    PUT = "put"
class ParamsMode(StrEnum):
    """How request parameters are encoded: JSON body or form data / query string."""

    JSON = "json"
    FORM = "form"
class ACLError(Exception):
    """ACL-related error type.

    NOTE(review): not raised anywhere in this module's visible code -- confirm intended use.
    """

    pass
# Inclusive HTTP status-code ranges used to classify responses.
HTTP_2xx_MIN = 200  # lowest "success" status
HTTP_2xx_MAX = 299  # highest "success" status
HTTP_4xx_MIN = 400  # lowest "client error" status
HTTP_4xx_MAX = 499  # highest "client error" status
@dataclass
class HttpConfig:
    """Configuration for the HTTP Auth & ACL Plugin."""

    # `host` and `port` are required; every other field has a sensible default.
    host: str
    """hostname of the server for the auth & acl check"""
    port: int
    """port of the server for the auth & acl check"""
    request_method: RequestMethod = RequestMethod.GET
    """send the request as a GET, POST or PUT"""
    params_mode: ParamsMode = ParamsMode.JSON  # see docs/plugins/http.md for additional details
    """send the request with `JSON` or `FORM` data. *additional details below*"""
    response_mode: ResponseMode = ResponseMode.JSON  # see docs/plugins/http.md for additional details
    """expected response from the auth/acl server. `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""
    with_tls: bool = False
    """http or https"""
    user_agent: str = "amqtt"
    """the 'User-Agent' header sent along with the request"""
    superuser_uri: str | None = None
    """URI to verify if the user is a superuser (e.g. '/superuser'), `None` if superuser is not supported"""
    timeout: int = 5
    """duration, in seconds, to wait for the HTTP server to respond"""
class AuthHttpPlugin(BasePlugin[BrokerContext]):
    """Shared base for the HTTP auth/ACL plugins: owns the aiohttp session and request/response handling."""

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
        # select the aiohttp call matching the configured verb; anything else falls back to POST
        if self.config.request_method == RequestMethod.GET:
            self.method = self.http.get
        elif self.config.request_method == RequestMethod.PUT:
            self.method = self.http.put
        else:
            self.method = self.http.post

    async def on_broker_pre_shutdown(self) -> None:
        """Release the aiohttp session when the broker shuts down."""
        await self.http.close()

    @staticmethod
    def _is_2xx(r: ClientResponse) -> bool:
        """Return True when the response status is in the success range."""
        return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX

    @staticmethod
    def _is_4xx(r: ClientResponse) -> bool:
        """Return True when the response status is in the client-error range."""
        return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX

    def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Build the aiohttp request kwargs according to the configured params mode."""
        if self.config.params_mode == ParamsMode.FORM:
            if self.config.request_method == RequestMethod.GET:
                # GET carries no body: form data travels as query parameters
                return {"params": payload}
            # POST, PUT: encode as multipart/urlencoded form data
            form: Any = FormData(payload)
            return {"data": form}
        # JSON mode (default)
        return {"json": payload}

    async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None:  # pylint: disable=R0911
        """Issue the request and map the response to True/False/None per the configured response mode."""
        request_kwargs = self._get_params(payload)
        async with self.method(url, **request_kwargs) as r:
            logger.debug(f"http request returned {r.status}")
            if self.config.response_mode == ResponseMode.TEXT:
                # success status AND a literal "ok" body
                return self._is_2xx(r) and (await r.text()).lower() == "ok"
            if self.config.response_mode == ResponseMode.STATUS:
                if self._is_2xx(r):
                    return True
                if self._is_4xx(r):
                    return False
                # any other code: no decision
                return None
            # JSON mode
            if not self._is_2xx(r):
                return False
            body: dict[str, Any] = await r.json()
            body = {k.lower(): v for k, v in body.items()}
            return body.get("ok", None)

    def get_url(self, uri: str) -> str:
        """Compose the full URL for *uri* from the configured scheme, host and port."""
        scheme = "https" if self.config.with_tls else "http"
        return f"{scheme}://{self.config.host}:{self.config.port}{uri}"
class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
    """Authenticates a connecting client against an external HTTP endpoint."""

    async def authenticate(self, *, session: Session) -> bool | None:
        """Send the client's credentials to the configured user URI and return its verdict."""
        payload = {
            "username": session.username,
            "password": session.password,
            "client_id": session.client_id,
        }
        return await self._send_request(self.get_url(self.config.user_uri), payload)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Auth Plugin."""

        user_uri: str = "/user"
        """URI of the auth check."""
class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
    """Checks a client's topic access (publish/subscribe/receive) against an external HTTP endpoint."""

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Ask the configured topic URI whether *session* may perform *action* on *topic*."""
        if not session:
            return None
        # numeric access codes sent as 'acc' (presumably mosquitto-auth
        # convention: receive=1, publish=2, subscribe=4 -- confirm with server)
        access_codes = {Action.PUBLISH: 2, Action.SUBSCRIBE: 4, Action.RECEIVE: 1}
        acc = access_codes.get(action, 0)
        payload = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
        return await self._send_request(self.get_url(self.config.topic_uri), payload)

    @dataclass
    class Config(HttpConfig):
        """Configuration for the HTTP Topic Plugin."""

        topic_uri: str = "/acl"
        """URI of the topic check."""
| from dataclasses import dataclass | ||
| import logging | ||
| from typing import ClassVar | ||
| import jwt | ||
| try: | ||
| from enum import StrEnum | ||
| except ImportError: | ||
| # support for python 3.10 | ||
| from enum import Enum | ||
| class StrEnum(str, Enum): # type: ignore[no-redef] | ||
| pass | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import Action | ||
| from amqtt.plugins import TopicMatcher | ||
| from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
class Algorithms(StrEnum):
    """JWT signing algorithms accepted in plugin configuration."""

    ES256 = "ES256"
    ES256K = "ES256K"
    ES384 = "ES384"
    ES512 = "ES512"
    ES521 = "ES521"  # NOTE(review): "ES521" is not a registered JWS algorithm (P-521 uses ES512) -- confirm intent
    EdDSA = "EdDSA"
    HS256 = "HS256"
    HS384 = "HS384"
    HS512 = "HS512"
    PS256 = "PS256"
    PS384 = "PS384"
    PS512 = "PS512"
    RS256 = "RS256"
    RS384 = "RS384"
    RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin):
    """Authenticates a client whose password field carries a JWT.

    The token is decoded with the configured secret and the configured user
    claim must equal the session's username.
    """

    async def authenticate(self, *, session: Session) -> bool | None:
        """Return True if the JWT in `session.password` names `session.username`; None if credentials are absent."""
        if not session.username or not session.password:
            return None
        try:
            # honor the configured algorithm (previously hard-coded to HS256,
            # silently ignoring `Config.algorithm`)
            decoded_payload = jwt.decode(session.password, self.config.secret_key,
                                         algorithms=[self.config.algorithm])
            return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
        except jwt.ExpiredSignatureError:
            logger.debug(f"jwt for '{session.username}' is expired")
            return False
        except jwt.InvalidTokenError:
            logger.debug(f"jwt for '{session.username}' is invalid")
            return False

    @dataclass
    class Config:
        """Configuration for the JWT user authentication."""

        secret_key: str
        """Secret key to decrypt the token."""
        user_claim: str
        """Payload key for user name."""
        algorithm: str = "HS256"
        """Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
class TopicAuthJwtPlugin(BaseTopicPlugin):
    """Authorizes topic access from topic-filter lists carried in the client's JWT password."""

    # maps each broker action to the name of the Config field holding its JWT claim key
    _topic_jwt_claims: ClassVar = {
        Action.PUBLISH: "publish_claim",
        Action.SUBSCRIBE: "subscribe_claim",
        Action.RECEIVE: "receive_claim",
    }

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        self.topic_matcher = TopicMatcher()

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Allow *action* on *topic* if any filter in the action's JWT claim matches; None when undecidable."""
        if not session or not topic or not action:
            return None
        if not session.password:
            return None
        try:
            # honor the configured algorithm (previously hard-coded to HS256,
            # silently ignoring `Config.algorithm`)
            decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key,
                                         algorithms=[self.config.algorithm])
            claim = getattr(self.config, self._topic_jwt_claims[action])
            return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
        except jwt.ExpiredSignatureError:
            logger.debug(f"jwt for '{session.username}' is expired")
            return False
        except jwt.InvalidTokenError:
            logger.debug(f"jwt for '{session.username}' is invalid")
            return False

    @dataclass
    class Config:
        """Configuration for the JWT topic authorization."""

        secret_key: str
        """Secret key to decrypt the token."""
        publish_claim: str
        """Payload key for contains a list of permissible publish topics."""
        subscribe_claim: str
        """Payload key for contains a list of permissible subscribe topics."""
        receive_claim: str
        """Payload key for contains a list of permissible receive topics."""
        algorithm: str = "HS256"
        """Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
        'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
| from dataclasses import dataclass | ||
| import logging | ||
| from typing import ClassVar | ||
| import ldap | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import Action | ||
| from amqtt.errors import PluginInitError | ||
| from amqtt.plugins import TopicMatcher | ||
| from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
@dataclass
class LdapConfig:
    """Configuration for the LDAP Plugins."""

    # NOTE: `bind_dn`/`bind_password` are used for the initial directory bind at
    # plugin startup; client credentials are verified separately per connection.
    server: str
    """uri formatted server location. e.g `ldap://localhost:389`"""
    base_dn: str
    """distinguished name (dn) of the ldap server. e.g. `dc=amqtt,dc=io`"""
    user_attribute: str
    """attribute in ldap entry to match the username against"""
    bind_dn: str
    """distinguished name (dn) of known, preferably read-only, user. e.g. `cn=admin,dc=amqtt,dc=io`"""
    bind_password: str
    """password for known, preferably read-only, user"""
class AuthLdapPlugin(BasePlugin[BrokerContext]):
    """Shared base for the LDAP plugins: connects and binds to the directory server at init.

    Raises:
        PluginInitError: if the initial bind fails for any reason (bad credentials,
            unreachable server, ...).
    """

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        self.conn = ldap.initialize(self.config.server)
        self.conn.protocol_version = ldap.VERSION3  # pylint: disable=E1101
        try:
            self.conn.simple_bind_s(self.config.bind_dn, self.config.bind_password)
        except ldap.LDAPError as e:  # pylint: disable=E1101
            # wrap ANY bind failure as an init error -- previously only
            # INVALID_CREDENTIALS was wrapped, so e.g. SERVER_DOWN escaped raw
            raise PluginInitError(self.__class__) from e
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
    """Plugin to authenticate a user with an LDAP directory server."""

    async def authenticate(self, *, session: Session) -> bool | None:
        """Look up the user's DN with the bind credentials, then verify the password by binding as that DN."""
        from ldap.filter import escape_filter_chars  # noqa: PLC0415

        # an empty password would perform an anonymous simple bind, which many
        # servers accept -- reject it outright to avoid an authentication bypass
        if not session.password:
            logger.debug(f"empty password for '{session.username}'")
            return False
        # escape the client-supplied username to prevent LDAP filter injection
        safe_username = escape_filter_chars(str(session.username))
        search_filter = f"({self.config.user_attribute}={safe_username})"
        result = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, ["dn"])  # pylint: disable=E1101
        if not result:
            logger.debug(f"user not found: {session.username}")
            return False
        try:
            # `search_s` responds with list of tuples: (dn, entry); first in list is our match
            user_dn = result[0][0]
        except IndexError:
            return False
        try:
            user_conn = ldap.initialize(self.config.server)
            user_conn.simple_bind_s(user_dn, session.password)
        except ldap.INVALID_CREDENTIALS:  # pylint: disable=E1101
            logger.debug(f"invalid credentials for '{session.username}'")
            return False
        except ldap.LDAPError as e:  # pylint: disable=E1101
            logger.debug(f"LDAP error during user bind: {e}")
            return False
        return True

    @dataclass
    class Config(LdapConfig):
        """Configuration for the User Auth LDAP Plugin."""
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
    """Plugin to authorize topic access from per-user attributes in an LDAP directory."""

    # maps each broker action to the name of the Config field holding its LDAP attribute
    _action_attr_map: ClassVar = {
        Action.PUBLISH: "publish_attribute",
        Action.SUBSCRIBE: "subscribe_attribute",
        Action.RECEIVE: "receive_attribute"
    }

    def __init__(self, context: BrokerContext) -> None:
        super().__init__(context)
        self.topic_matcher = TopicMatcher()

    async def topic_filtering(
        self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
    ) -> bool | None:
        """Allow *action* on *topic* if a filter in the user's LDAP attribute for that action matches."""
        from ldap.filter import escape_filter_chars  # noqa: PLC0415

        # if not provided needed criteria, can't properly evaluate topic filtering
        if not session or not action or not topic:
            return None
        # escape the client-supplied username to prevent LDAP filter injection
        safe_username = escape_filter_chars(str(session.username))
        search_filter = f"({self.config.user_attribute}={safe_username})"
        attrs = [
            "cn",
            self.config.publish_attribute,
            self.config.subscribe_attribute,
            self.config.receive_attribute
        ]
        results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs)  # pylint: disable=E1101
        if not results:
            logger.debug(f"user not found: {session.username}")
            return False
        if len(results) > 1:
            found_users = [dn for dn, _ in results]
            logger.debug(f"multiple users found: {', '.join(found_users)}")
            return False
        dn, entry = results[0]
        ldap_attribute = getattr(self.config, self._action_attr_map[action])
        # attribute values come back as bytes; decode before matching
        topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
        logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
        return self.topic_matcher.are_topics_allowed(topic, topic_filters)

    @dataclass
    class Config(LdapConfig):
        """Configuration for the LDAPAuthPlugin."""

        publish_attribute: str
        """LDAP attribute which contains a list of permissible publish topics."""
        subscribe_attribute: str
        """LDAP attribute which contains a list of permissible subscribe topics."""
        receive_attribute: str
        """LDAP attribute which contains a list of permissible receive topics."""
| from dataclasses import dataclass | ||
| import logging | ||
| from pathlib import Path | ||
| from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select | ||
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine | ||
| from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
| from amqtt.broker import BrokerContext, RetainedApplicationMessage | ||
| from amqtt.contrib import DataClassListJSON | ||
| from amqtt.errors import PluginError | ||
| from amqtt.plugins.base import BasePlugin | ||
| from amqtt.session import Session | ||
| logger = logging.getLogger(__name__) | ||
class Base(DeclarativeBase):
    """Declarative base for this plugin's ORM models."""

    pass
@dataclass
class RetainedMessage:
    """JSON-serializable snapshot of a retained message stored per session."""

    topic: str  # topic the message was retained on
    data: str   # payload, stored as decoded text
    qos: int    # quality-of-service level
@dataclass
class Subscription:
    """JSON-serializable record of a client subscription."""

    topic: str  # subscribed topic filter
    qos: int    # granted quality-of-service level
class StoredSession(Base):
    """ORM model persisting one client's MQTT session (will settings, keep-alive, retained messages, subscriptions)."""

    __tablename__ = "stored_sessions"
    id: Mapped[int] = mapped_column(primary_key=True)
    client_id: Mapped[str] = mapped_column(String)  # MQTT client identifier
    clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    # last-will fields, mirroring the attributes of amqtt.session.Session
    will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
    will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
    will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
    will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
    will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
    keep_alive: Mapped[int] = mapped_column(Integer, default=0)
    # JSON columns holding lists of dataclass instances (see DataClassListJSON)
    retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
    subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
class StoredMessage(Base):
    """ORM model persisting a broker-wide retained message for a topic."""

    __tablename__ = "stored_messages"
    id: Mapped[int] = mapped_column(primary_key=True)
    topic: Mapped[str] = mapped_column(String)  # topic the message is retained on
    data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)  # raw payload
    qos: Mapped[int] = mapped_column(Integer, default=0)
| class SessionDBPlugin(BasePlugin[BrokerContext]): | ||
| """Plugin to store session information and retained topic messages in the event that the broker terminates abnormally. | ||
| Configuration: | ||
| - file *(string)* path & filename to store the session db. default: `amqtt.db` | ||
| - clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True` | ||
| """ | ||
| def __init__(self, context: BrokerContext) -> None: | ||
| super().__init__(context) | ||
| # bypass the `test_plugins_correct_has_attr` until it can be updated | ||
| if not hasattr(self.config, "file"): | ||
| logger.warning("`Config` is missing a `file` attribute") | ||
| return | ||
| self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}") | ||
| self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False) | ||
| @staticmethod | ||
| async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession: | ||
| stmt = select(StoredSession).filter(StoredSession.client_id == client_id) | ||
| stored_session = await db_session.scalar(stmt) | ||
| if stored_session is None: | ||
| stored_session = StoredSession(client_id=client_id) | ||
| db_session.add(stored_session) | ||
| await db_session.flush() | ||
| return stored_session | ||
| @staticmethod | ||
| async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage: | ||
| stmt = select(StoredMessage).filter(StoredMessage.topic == topic) | ||
| stored_message = await db_session.scalar(stmt) | ||
| if stored_message is None: | ||
| stored_message = StoredMessage(topic=topic) | ||
| db_session.add(stored_message) | ||
| await db_session.flush() | ||
| return stored_message | ||
| async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None: | ||
| """Search to see if session already exists.""" | ||
| # if client id doesn't exist, create (can ignore if session is anonymous) | ||
| # update session information (will, clean_session, etc) | ||
| # don't store session information for clean or anonymous sessions | ||
| if client_session.clean_session in (None, True) or client_session.is_anonymous: | ||
| return | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| stored_session = await self._get_or_create_session(db_session, client_id) | ||
| stored_session.clean_session = client_session.clean_session | ||
| stored_session.will_flag = client_session.will_flag | ||
| stored_session.will_message = client_session.will_message # type: ignore[assignment] | ||
| stored_session.will_qos = client_session.will_qos | ||
| stored_session.will_retain = client_session.will_retain | ||
| stored_session.will_topic = client_session.will_topic | ||
| stored_session.keep_alive = client_session.keep_alive | ||
| await db_session.flush() | ||
| async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None: | ||
| """Create/update subscription if clean session = false.""" | ||
| session = self.context.get_session(client_id) | ||
| if not session: | ||
| logger.warning(f"'{client_id}' is subscribing but doesn't have a session") | ||
| return | ||
| if session.clean_session: | ||
| return | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| # stored sessions shouldn't need to be created here, but we'll use the same helper... | ||
| stored_session = await self._get_or_create_session(db_session, client_id) | ||
| stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)] | ||
| await db_session.flush() | ||
| async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None: | ||
| """Remove subscription if clean session = false.""" | ||
| async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None: | ||
| """Update to retained messages. | ||
| if retained_message.data is None or '', the message is being cleared | ||
| """ | ||
| # if client_id is valid, the retained message is for a disconnected client | ||
| if client_id is not None: | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| # stored sessions shouldn't need to be created here, but we'll use the same helper... | ||
| stored_session = await self._get_or_create_session(db_session, client_id) | ||
| stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic, | ||
| retained_message.data.decode(), | ||
| retained_message.qos or 0)] | ||
| await db_session.flush() | ||
| return | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| # if the retained message has data, we need to store/update for the topic | ||
| if retained_message.data: | ||
| client_message = await self._get_or_create_message(db_session, retained_message.topic) | ||
| client_message.data = retained_message.data # type: ignore[assignment] | ||
| client_message.qos = retained_message.qos or 0 | ||
| await db_session.flush() | ||
| return | ||
| # if there is no data, clear the stored message (if exists) for the topic | ||
| stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic) | ||
| topic_message = await db_session.scalar(stmt) | ||
| if topic_message is not None: | ||
| await db_session.delete(topic_message) | ||
| await db_session.flush() | ||
| return | ||
| async def on_broker_pre_start(self) -> None: | ||
| """Initialize the database and db connection.""" | ||
| async with self._engine.begin() as conn: | ||
| await conn.run_sync(Base.metadata.create_all) | ||
    async def on_broker_post_start(self) -> None:
        """Restore sessions, subscriptions and retained messages from the database.

        Raises:
            PluginError: if the broker already has subscriptions or sessions,
                since restoring on top of live state would be ambiguous.
        """
        if len(self.context.subscriptions) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
            raise PluginError(msg)
        if len(list(self.context.sessions)) > 0:
            msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
            raise PluginError(msg)
        async with self._db_session_maker() as db_session, db_session.begin():
            # Restore each stored session and its subscriptions.
            stmt = select(StoredSession)
            stored_sessions = await db_session.execute(stmt)
            restored_sessions = 0
            for stored_session in stored_sessions.scalars():
                # First call registers the client; subsequent calls add topics.
                await self.context.add_subscription(stored_session.client_id, None, None)
                for subscription in stored_session.subscriptions:
                    await self.context.add_subscription(stored_session.client_id,
                                                        subscription.topic,
                                                        subscription.qos)
                session = self.context.get_session(stored_session.client_id)
                if not session:
                    continue
                # Copy the persisted connection/will attributes onto the live session.
                session.clean_session = stored_session.clean_session
                session.will_flag = stored_session.will_flag
                session.will_message = stored_session.will_message
                session.will_qos = stored_session.will_qos
                session.will_retain = stored_session.will_retain
                session.will_topic = stored_session.will_topic
                session.keep_alive = stored_session.keep_alive
                # Re-queue per-session retained messages (stored as text, sent as bytes).
                for message in stored_session.retained:
                    retained_message = RetainedApplicationMessage(
                        source_session=None,
                        topic=message.topic,
                        data=message.data.encode(),
                        qos=message.qos
                    )
                    await session.retained_messages.put(retained_message)
                restored_sessions += 1
            # Restore broker-wide retained messages (one per topic).
            stmt = select(StoredMessage)
            stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
            restored_messages = 0
            retained_messages = self.context.retained_messages
            for stored_message in stored_messages.scalars():
                retained_messages[stored_message.topic] = (RetainedApplicationMessage(
                    source_session=None,
                    topic=stored_message.topic,
                    data=stored_message.data or b"",
                    qos=stored_message.qos
                ))
                restored_messages += 1
            logger.info(f"Retained messages restored: {restored_messages}")
            logger.info(f"Restored {restored_sessions} sessions.")
    async def on_broker_pre_shutdown(self) -> None:
        """Clean up the db connection.

        Disposes the engine's connection pool; sessions are already persisted.
        """
        await self._engine.dispose()
    async def on_broker_post_shutdown(self) -> None:
        """Optionally delete the sqlite db file after a normal shutdown.

        Controlled by `Config.clear_on_shutdown`; `config.file` is a `Path`
        after `Config.__post_init__` runs.
        """
        if self.config.clear_on_shutdown and self.config.file.exists():
            self.config.file.unlink()
    @dataclass
    class Config:
        """Configuration variables for the session DB plugin."""

        file: str | Path = "amqtt.db"
        """path & filename to store the sqlite session db."""
        clear_on_shutdown: bool = True
        """if the broker shutdowns down normally, don't retain any information."""

        def __post_init__(self) -> None:
            """Create `Path` from string path."""
            # Normalize so the rest of the plugin can rely on Path methods.
            if isinstance(self.file, str):
                self.file = Path(self.file)
| """Module for the shadow state plugin.""" | ||
| from .plugin import ShadowPlugin, ShadowTopicAuthPlugin | ||
| from .states import ShadowOperation | ||
| __all__ = ["ShadowOperation", "ShadowPlugin", "ShadowTopicAuthPlugin"] |
| from collections.abc import MutableMapping | ||
| from dataclasses import dataclass, fields, is_dataclass | ||
| import json | ||
| from typing import Any | ||
| from amqtt.contrib.shadows.states import MetaTimestamp, ShadowOperation, State, StateDocument | ||
def asdict_no_none(obj: Any) -> Any:
    """Recursively convert an object to plain dicts/lists, dropping `None` values.

    Dataclasses become dicts keyed by field name; lists and dicts are filtered
    and converted element-wise; any other value is returned unchanged.
    """
    if is_dataclass(obj):
        return {
            f.name: asdict_no_none(getattr(obj, f.name))
            for f in fields(obj)
            if getattr(obj, f.name) is not None
        }
    if isinstance(obj, list):
        return [asdict_no_none(entry) for entry in obj if entry is not None]
    if isinstance(obj, dict):
        return {key: asdict_no_none(val) for key, val in obj.items() if val is not None}
    return obj
def create_shadow_topic(device_id: str, shadow_name: str, message_op: "ShadowOperation") -> str:
    """Build the `$shadow/...` topic for the given device, shadow and operation."""
    # str.format applies the same __format__ conversion an f-string would,
    # which matters for the StrEnum fallback on older Pythons.
    return "$shadow/{}/{}/{}".format(device_id, shadow_name, message_op)  # noqa: UP032
class ShadowMessage:
    """Base class for shadow payloads that serialize to JSON bytes."""

    def to_message(self) -> bytes:
        """Serialize this message to UTF-8 JSON, omitting `None` fields."""
        payload = asdict_no_none(self)
        return json.dumps(payload).encode("utf-8")
@dataclass
class GetAcceptedMessage(ShadowMessage):
    """Payload published when a shadow 'get' request succeeds."""

    state: State[dict[str, Any]]
    metadata: State[MetaTimestamp]
    timestamp: int
    version: int

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
@dataclass
class GetRejectedMessage(ShadowMessage):
    """Payload published when a shadow 'get' request fails (e.g. not found)."""

    code: int
    message: str
    timestamp: int | None = None

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
@dataclass
class UpdateAcceptedMessage(ShadowMessage):
    """Payload published when a shadow 'update' request is accepted."""

    state: State[dict[str, Any]]
    metadata: State[MetaTimestamp]
    timestamp: int
    version: int

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_ACCEPT)
@dataclass
class UpdateRejectedMessage(ShadowMessage):
    """Payload published when a shadow 'update' request is rejected."""

    code: int
    message: str
    timestamp: int

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
@dataclass
class UpdateDeltaMessage(ShadowMessage):
    """Payload describing the difference between desired and reported state."""

    state: MutableMapping[str, Any]
    metadata: MutableMapping[str, Any]
    timestamp: int
    version: int

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
class UpdateIotaMessage(UpdateDeltaMessage):
    """Same format, corollary name."""

    # Inherits the dataclass fields and generated __init__ from
    # UpdateDeltaMessage; only the publish topic differs.
    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
@dataclass
class UpdateDocumentMessage(ShadowMessage):
    """Payload containing the previous and current full shadow documents."""

    previous: StateDocument
    current: StateDocument
    timestamp: int

    @staticmethod
    def topic(device_id: str, shadow_name: str) -> str:
        """Return the topic this message is published on."""
        return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DOCUMENTS)
| from collections.abc import Sequence | ||
| from dataclasses import asdict | ||
| import logging | ||
| import time | ||
| from typing import Any, Optional | ||
| import uuid | ||
| from sqlalchemy import JSON, CheckConstraint, Integer, String, UniqueConstraint, desc, event, func, select | ||
| from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession | ||
| from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column | ||
| from amqtt.contrib.shadows.states import StateDocument | ||
| logger = logging.getLogger(__name__) | ||
class ShadowUpdateError(Exception):
    """Raised when code attempts to mutate an already-persisted Shadow row."""

    def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
        """Initialize with a human-readable reason."""
        super().__init__(message)
class ShadowBase(DeclarativeBase):
    """Declarative base for all shadow ORM models (owns their shared metadata)."""
async def sync_shadow_base(connection: AsyncConnection) -> None:
    """Create tables and table schemas.

    Idempotent: `create_all` only creates tables that do not yet exist.
    """
    await connection.run_sync(ShadowBase.metadata.create_all)
def default_state_document() -> dict[str, Any]:
    """Create a default (empty) state document; used as a model field factory."""
    empty_document = StateDocument()
    return asdict(empty_document)
class Shadow(ShadowBase):
    """Append-only, versioned persistence of a device's shadow document.

    Each update produces a new row; (device_id, name, version) is unique and
    version must be positive.
    """

    __tablename__ = "shadows_shadow"

    id: Mapped[str | None] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    device_id: Mapped[str] = mapped_column(String(128), nullable=False)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    version: Mapped[int] = mapped_column(Integer, nullable=False)
    # Raw JSON column; use the `state` property for StateDocument access.
    _state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
    created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)

    __table_args__ = (
        CheckConstraint("version > 0", name="check_quantity_positive"),
        UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
    )

    @property
    def state(self) -> StateDocument:
        """Deserialize the JSON column into a StateDocument (empty when unset)."""
        if not self._state:
            return StateDocument()
        return StateDocument.from_dict(self._state)

    @state.setter
    def state(self, value: StateDocument) -> None:
        # Stored as a plain dict so the JSON column can serialize it.
        self._state = asdict(value)

    @classmethod
    async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
        """Get the latest version of the shadow associated with the device and name."""
        stmt = (
            select(cls).where(
                cls.device_id == device_id,
                cls.name == name
            ).order_by(desc(cls.version)).limit(1)
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    @classmethod
    async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
        """Return a list of all shadows associated with the device and name."""
        stmt = (
            select(cls).where(
                cls.device_id == device_id,
                cls.name == name
            ).order_by(desc(cls.version)))
        result = await session.execute(stmt)
        return result.scalars().all()
@event.listens_for(Shadow, "before_insert")
def assign_incremental_version(_: Mapper[Any], connection: Session, target: "Shadow") -> None:
    """Assign the next version number before a Shadow row is inserted.

    Sets target.version to max(existing versions) + 1, or 1 for the first row.
    """
    # NOTE(review): SQLAlchemy passes a Connection (not a Session) to
    # before_insert listeners; the `Session` annotation looks wrong — confirm.
    stmt = (
        select(func.max(Shadow.version))
        .where(
            Shadow.device_id == target.device_id,
            Shadow.name == target.name
        )
    )
    result = connection.execute(stmt).scalar_one_or_none()
    target.version = (result or 0) + 1
@event.listens_for(Shadow, "before_update")
def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow") -> None:
    """Prevent existing shadow from being updated.

    Raises:
        ShadowUpdateError: always; shadows are append-only (see
            the before_flush listener which converts updates to inserts).
    """
    raise ShadowUpdateError
@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
    """Force a shadow to insert a new version, instead of updating an existing.

    Any modified, persistent Shadow in the flush is detached, stripped of its
    identity and re-added so the flush emits an INSERT rather than an UPDATE.
    """
    # Make a copy of the dirty set so we can safely mutate the session
    dirty = list(session.dirty)
    for obj in dirty:
        if not session.is_modified(obj, include_collections=False):
            continue  # skip unchanged
        # Only Shadow rows use insert-instead-of-update versioning
        if not isinstance(obj, Shadow):
            continue
        # Clone logic: convert update into insert
        session.expunge(obj)  # remove from session
        make_transient(obj)  # remove identity and history
        # Clear the primary key with None (not "") so the column default
        # generates a fresh UUID; an empty string would be inserted verbatim
        # and collide on the next converted update.
        obj.id = None
        obj.version += 1  # bump version; before_insert recomputes it anyway
        session.add(obj)  # re-add as new object
# Retained for documentation only: an alternative listener-based approach to
# serializing the state document (superseded by the `state` property above).
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
#     """Listen for insertion and convert state document to json."""
#     if not isinstance(target.state, StateDocument):
#         msg = "'state' field needs to be a StateDocument"
#         raise TypeError(msg)
#
#     target.state = target.state.to_dict()
'''
| from collections import defaultdict | ||
| from dataclasses import dataclass, field | ||
| import json | ||
| import re | ||
| from typing import Any | ||
| from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import Action | ||
| from amqtt.contrib.shadows.messages import ( | ||
| GetAcceptedMessage, | ||
| GetRejectedMessage, | ||
| UpdateAcceptedMessage, | ||
| UpdateDeltaMessage, | ||
| UpdateDocumentMessage, | ||
| UpdateIotaMessage, | ||
| ) | ||
| from amqtt.contrib.shadows.models import Shadow, sync_shadow_base | ||
| from amqtt.contrib.shadows.states import ( | ||
| ShadowOperation, | ||
| StateDocument, | ||
| calculate_delta_update, | ||
| calculate_iota_update, | ||
| ) | ||
| from amqtt.plugins.base import BasePlugin, BaseTopicPlugin | ||
| from amqtt.session import ApplicationMessage, Session | ||
# Matches "$shadow/<device_id>/<shadow_name>/<get|update>..." topics; named
# groups feed ShadowTopic construction in ShadowPlugin.shadow_topic_match.
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")

# Type aliases for readability in the shadow lookup tables below.
DeviceID = str
ShadowName = str
@dataclass
class ShadowTopic:
    """Parsed components of a `$shadow/...` topic."""

    device_id: DeviceID
    name: ShadowName
    message_op: ShadowOperation
def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
    """Nested defaultdict for shadow cache."""
    # Recursive factory: any missing key materializes another nested level.
    return defaultdict(shadow_dict)  # type: ignore[arg-type]
| class ShadowPlugin(BasePlugin[BrokerContext]): | ||
| def __init__(self, context: BrokerContext) -> None: | ||
| super().__init__(context) | ||
| self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict) | ||
| self._engine = create_async_engine(self.config.connection) | ||
| self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False) | ||
| async def on_broker_pre_start(self) -> None: | ||
| """Sync the schema.""" | ||
| async with self._engine.begin() as conn: | ||
| await sync_shadow_base(conn) | ||
| @staticmethod | ||
| def shadow_topic_match(topic: str) -> ShadowTopic | None: | ||
| """Check if topic matches the shadow topic format.""" | ||
| # pattern is "$shadow/<username>/<shadow_name>/get, update, etc | ||
| match = shadow_topic_re.search(topic) | ||
| if match: | ||
| groups = match.groupdict() | ||
| return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"])) | ||
| return None | ||
| async def _handle_get(self, st: ShadowTopic) -> None: | ||
| """Send 'accepted.""" | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| shadow = await Shadow.latest_version(db_session, st.device_id, st.name) | ||
| if not shadow: | ||
| reject_msg = GetRejectedMessage( | ||
| code=404, | ||
| message="shadow not found", | ||
| ) | ||
| await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message()) | ||
| return | ||
| accept_msg = GetAcceptedMessage( | ||
| state=shadow.state.state, | ||
| metadata=shadow.state.metadata, | ||
| timestamp=shadow.created_at, | ||
| version=shadow.version | ||
| ) | ||
| await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message()) | ||
| async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None: | ||
| async with self._db_session_maker() as db_session, db_session.begin(): | ||
| shadow = await Shadow.latest_version(db_session, st.device_id, st.name) | ||
| if not shadow: | ||
| shadow = Shadow(device_id=st.device_id, name=st.name) | ||
| state_update = StateDocument.from_dict(update) | ||
| prev_state = shadow.state or StateDocument() | ||
| prev_state.version = shadow.version or 0 # only required when generating shadow messages | ||
| prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages | ||
| next_state = prev_state + state_update | ||
| shadow.state = next_state | ||
| db_session.add(shadow) | ||
| await db_session.commit() | ||
| next_state.version = shadow.version | ||
| next_state.timestamp = shadow.created_at | ||
| accept_msg = UpdateAcceptedMessage( | ||
| state=next_state.state, | ||
| metadata=next_state.metadata, | ||
| timestamp=123, | ||
| version=1 | ||
| ) | ||
| await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message()) | ||
| delta_msg = UpdateDeltaMessage( | ||
| state=calculate_delta_update(next_state.state.desired, next_state.state.reported), | ||
| metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported), | ||
| version=shadow.version, | ||
| timestamp=shadow.created_at | ||
| ) | ||
| await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message()) | ||
| iota_msg = UpdateIotaMessage( | ||
| state=calculate_iota_update(next_state.state.desired, next_state.state.reported), | ||
| metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported), | ||
| version=shadow.version, | ||
| timestamp=shadow.created_at | ||
| ) | ||
| await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message()) | ||
| doc_msg = UpdateDocumentMessage( | ||
| previous=prev_state, | ||
| current=next_state, | ||
| timestamp=shadow.created_at | ||
| ) | ||
| await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message()) | ||
| async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None: | ||
| """Process a message that was received from a client.""" | ||
| topic = message.topic | ||
| if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match | ||
| return | ||
| if not (shadow_topic := self.shadow_topic_match(topic)): | ||
| return | ||
| match shadow_topic.message_op: | ||
| case ShadowOperation.GET: | ||
| await self._handle_get(shadow_topic) | ||
| case ShadowOperation.UPDATE: | ||
| await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8"))) | ||
| @dataclass | ||
| class Config: | ||
| """Configuration for shadow plugin.""" | ||
| connection: str | ||
| """SQLAlchemy connection string for the asyncio version of the database connector: | ||
| - `mysql+aiomysql://user:password@host:port/dbname` | ||
| - `postgresql+asyncpg://user:password@host:port/dbname` | ||
| - `sqlite+aiosqlite:///dbfilename.db` | ||
| """ | ||
class ShadowTopicAuthPlugin(BaseTopicPlugin):
    """Restrict `$shadow` topics so each device may only access its own shadows."""

    async def topic_filtering(self, *,
                              session: Session | None = None,
                              topic: str | None = None,
                              action: Action | None = None) -> bool | None:
        """Allow access when the topic's device id matches the session user (or a superuser)."""
        active_session = session or Session()
        if not topic:
            return False
        parsed = ShadowPlugin.shadow_topic_match(topic)
        if parsed is None:
            return False
        username = active_session.username
        # Devices may access their own shadow; superusers may access any.
        return username == parsed.device_id or username in self.config.superusers

    @dataclass
    class Config:
        """Configuration for only allowing devices access to their own shadow topics."""

        superusers: list[str] = field(default_factory=list)
        """A list of one or more usernames that can write to any device topic,
        primarily for the central app sending updates to devices."""
| from collections import Counter | ||
| from collections.abc import MutableMapping | ||
| from dataclasses import dataclass, field | ||
| try: | ||
| from enum import StrEnum | ||
| except ImportError: | ||
| # support for python 3.10 | ||
| from enum import Enum | ||
| class StrEnum(str, Enum): # type: ignore[no-redef] | ||
| pass | ||
| import time | ||
| from typing import Any, Generic, TypeVar | ||
| from mergedeep import merge | ||
| C = TypeVar("C", bound=Any) | ||
class StateError(Exception):
    """Raised when a state-document payload is missing its required 'state' field."""

    def __init__(self, msg: str = "'state' field is required") -> None:
        """Initialize with a human-readable reason."""
        super().__init__(msg)
@dataclass
class MetaTimestamp:
    """Wrapper around an integer unix timestamp with transparent numeric behavior."""

    timestamp: int = 0

    def __eq__(self, other: object) -> bool:
        """Compare timestamps against ints or other MetaTimestamps."""
        if isinstance(other, int):
            return self.timestamp == other
        if isinstance(other, self.__class__):
            return self.timestamp == other.timestamp
        # Data-model convention: defer to the other operand instead of raising,
        # so comparisons with unrelated types (e.g. a nested metadata dict
        # during delta calculation) evaluate to False rather than crashing.
        return NotImplemented

    # numeric operations to make this dataclass transparent
    def __abs__(self) -> int:
        """Absolute timestamp."""
        return self.timestamp

    def __add__(self, other: int) -> int:
        """Add to a timestamp."""
        return self.timestamp + other

    def __sub__(self, other: int) -> int:
        """Subtract from a timestamp."""
        return self.timestamp - other

    def __mul__(self, other: int) -> int:
        """Multiply a timestamp."""
        return self.timestamp * other

    def __float__(self) -> float:
        """Convert timestamp to float."""
        return float(self.timestamp)

    def __int__(self) -> int:
        """Convert timestamp to int."""
        return int(self.timestamp)

    def __lt__(self, other: int) -> bool:
        """Compare timestamp."""
        return self.timestamp < other

    def __le__(self, other: int) -> bool:
        """Compare timestamp."""
        return self.timestamp <= other

    def __gt__(self, other: int) -> bool:
        """Compare timestamp."""
        return self.timestamp > other

    def __ge__(self, other: int) -> bool:
        """Compare timestamp."""
        return self.timestamp >= other
def create_metadata(state: MutableMapping[str, Any], timestamp: int) -> dict[str, Any]:
    """Build a timestamp tree mirroring the structure of `state`.

    Nested dicts are stamped recursively; None leaves stay None; every other
    leaf becomes a MetaTimestamp carrying `timestamp`.
    """
    stamped: dict[str, Any] = {}
    for key, value in state.items():
        if isinstance(value, dict):
            stamped[key] = create_metadata(value, timestamp)
        else:
            stamped[key] = None if value is None else MetaTimestamp(timestamp)
    return stamped
def calculate_delta_update(desired: MutableMapping[str, Any],
                           reported: MutableMapping[str, Any],
                           depth: bool = True,
                           exclude_nones: bool = True,
                           ordered_lists: bool = True) -> dict[str, Any]:
    """Calculate state differences between desired and reported.

    Args:
        desired: target state.
        reported: currently reported state.
        depth: recurse into nested mappings when True; otherwise nested
            mappings are compared with plain equality.
        exclude_nones: skip keys whose desired value is None.
        ordered_lists: when False, lists compare as multisets (order ignored);
            list elements must then be hashable.

    Returns:
        Mapping of keys whose desired value differs from the reported one.
    """
    diff_dict: dict[str, Any] = {}
    for key, value in desired.items():
        if value is None and exclude_nones:
            continue
        # if the desired has an element that the reported does not...
        if key not in reported:
            diff_dict[key] = value
        # order-insensitive list comparison when requested
        elif isinstance(value, list) and not ordered_lists:
            if Counter(value) != Counter(reported[key]):
                diff_dict[key] = value
        elif isinstance(value, dict) and depth:
            # recurse, report when there is a difference; propagate the flags
            # (previously the recursion silently reset them to their defaults)
            obj_diff = calculate_delta_update(value, reported[key],
                                              depth=depth,
                                              exclude_nones=exclude_nones,
                                              ordered_lists=ordered_lists)
            if obj_diff:
                diff_dict[key] = obj_diff
        elif value != reported[key]:
            diff_dict[key] = value
    return diff_dict
def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
    """Calculate state differences between desired and reported (including missing keys).

    Like the shallow delta (nones included), plus an explicit None entry for
    every reported key absent from desired.
    """
    iota = calculate_delta_update(desired, reported, depth=False, exclude_nones=False)
    iota.update({key: None for key in reported if key not in desired})
    return iota
@dataclass
class State(Generic[C]):
    """Desired/reported halves of a shadow state document."""

    desired: MutableMapping[str, C] = field(default_factory=dict)
    reported: MutableMapping[str, C] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, C]) -> "State[C]":
        """Create state from dictionary."""
        desired_part = data.get("desired", {})
        reported_part = data.get("reported", {})
        return cls(desired=desired_part, reported=reported_part)

    def __bool__(self) -> bool:
        """True when either half holds any data."""
        return any((self.desired, self.reported))

    def __add__(self, other: "State[C]") -> "State[C]":
        """Merge states together; `other` takes precedence (mergedeep default)."""
        merged_desired = merge({}, self.desired, other.desired)
        merged_reported = merge({}, self.reported, other.reported)
        return State(desired=merged_desired, reported=merged_reported)
@dataclass
class StateDocument:
    """Full shadow document: state plus per-key metadata timestamps."""

    state: State[dict[str, Any]] = field(default_factory=State)
    metadata: State[MetaTimestamp] = field(default_factory=State)
    version: int | None = None  # only required when generating shadow messages
    timestamp: int | None = None  # only required when generating shadow messages

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
        """Create state document from dictionary.

        Raises:
            StateError: if `data` is non-empty but has no 'state' key.
        """
        now = int(time.time())
        if data and "state" not in data:
            raise StateError
        state = State.from_dict(data.get("state", {}))
        # Stamp every leaf of the incoming state with the current time.
        metadata = State(
            desired=create_metadata(state.desired, now),
            reported=create_metadata(state.reported, now)
        )
        return cls(state=state, metadata=metadata)

    def __post_init__(self) -> None:
        """Initialize meta data if not provided."""
        now = int(time.time())
        if not self.metadata:
            self.metadata = State(
                desired=create_metadata(self.state.desired, now),
                reported=create_metadata(self.state.reported, now),
            )

    def __add__(self, other: "StateDocument") -> "StateDocument":
        """Merge two state documents together."""
        # version/timestamp are intentionally not carried over; callers assign
        # them when building outgoing messages.
        return StateDocument(
            state=self.state + other.state,
            metadata=self.metadata + other.metadata
        )
class ShadowOperation(StrEnum):
    """Shadow operations; values are the topic suffixes under `$shadow/<device>/<name>/`."""

    GET = "get"
    UPDATE = "update"
    GET_ACCEPT = "get/accepted"
    GET_REJECT = "get/rejected"
    UPDATE_ACCEPT = "update/accepted"
    UPDATE_REJECT = "update/rejected"
    UPDATE_DOCUMENTS = "update/documents"
    UPDATE_DELTA = "update/delta"
    UPDATE_IOTA = "update/iota"
| import logging | ||
| from pathlib import Path | ||
| import sys | ||
| import typer | ||
| logger = logging.getLogger(__name__) | ||
| app = typer.Typer(add_completion=False, rich_markup_mode=None) | ||
def main() -> None:
    """Run the cli for `ca_creds`."""
    # Delegate to the typer application defined at module level.
    app()
@app.command()
def ca_creds(
    country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
    state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
    locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
    org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
    cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
    output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
) -> None:
    """Generate a self-signed key and certificate to be used as the root CA, with a key size of 2048 and a 1-year expiration."""
    log_format = "[%(asctime)s] :: %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    # The cert helpers live in the optional 'contrib' extra.
    try:
        from amqtt.contrib.cert import generate_root_creds, write_key_and_crt  # pylint: disable=import-outside-toplevel
    except ImportError:
        msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
        logger.critical(msg)
        sys.exit(1)
    root_key, root_crt = generate_root_creds(country=country, state=state, locality=locality, org_name=org_name, cn=cn)
    write_key_and_crt(root_key, root_crt, "ca", Path(output_dir))
# Script entry point.
if __name__ == "__main__":
    main()
| import logging | ||
| from pathlib import Path | ||
| import sys | ||
| import typer | ||
| logger = logging.getLogger(__name__) | ||
| app = typer.Typer(add_completion=False, rich_markup_mode=None) | ||
def main() -> None:
    """Run the `device_creds` cli."""
    # Delegate to the typer application defined at module level.
    app()
@app.command()
def device_creds(  # pylint: disable=too-many-locals
    country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
    org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
    device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
    uri: str = typer.Option(..., "--uri", help="domain name for device SAN"),
    output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
    ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="root key filename used for signing."),
    ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="root cert filename used for signing."),
) -> None:
    """Generate a key and certificate for each device in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration."""  # noqa: E501
    log_format = "[%(asctime)s] :: %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    # The cert helpers live in the optional 'contrib' extra.
    try:
        from amqtt.contrib.cert import (  # pylint: disable=import-outside-toplevel
            generate_device_csr,
            load_ca,
            sign_csr,
            write_key_and_crt,
        )
    except ImportError:
        msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
        logger.critical(msg)
        sys.exit(1)
    # Load the signing CA, then build the device's SANs.
    ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
    spiffe_san = f"spiffe://{uri}/device/{device_id}"
    local_dns_san = f"{device_id}.local"
    device_key, device_csr = generate_device_csr(
        country=country,
        org_name=org_name,
        common_name=device_id,
        uri_san=spiffe_san,
        dns_san=local_dns_san
    )
    device_crt = sign_csr(device_csr, ca_key, ca_crt)
    write_key_and_crt(device_key, device_crt, device_id, Path(output_dir))
    logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
# Script entry point.
if __name__ == "__main__":
    main()
| import logging | ||
| import sys | ||
| from amqtt.errors import MQTTError | ||
| logger = logging.getLogger(__name__) | ||
def main() -> None:
    """Run the auth db cli.

    Exits with status 1 when optional dependencies are missing or the
    command fails with a known error.
    """
    try:
        from amqtt.contrib.auth_db.topic_mgr_cli import topic_app  # pylint: disable=import-outside-toplevel
    except ImportError:
        logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
        sys.exit(1)
    # (fixed) removed a redundant second import of topic_app — the import in
    # the try block above already bound the name on success.
    try:
        topic_app()
    except ModuleNotFoundError as mnfe:
        logger.critical(f"Please install database-specific dependencies: {mnfe}")
        sys.exit(1)
    except ValueError as ve:
        # sqlalchemy's async support raises ValueError mentioning 'greenlet'
        # when that dependency is absent.
        if "greenlet" in f"{ve}":
            logger.critical("Please install database-specific dependencies: 'greenlet'")
            sys.exit(1)
        logger.critical(f"Unknown error: {ve}")
        sys.exit(1)
    except MQTTError as me:
        logger.critical(f"could not execute command: {me}")
        sys.exit(1)
# Script entry point.
if __name__ == "__main__":
    main()
| import logging | ||
| import sys | ||
| from amqtt.errors import MQTTError | ||
| logger = logging.getLogger(__name__) | ||
def main() -> None:
    """Run the auth db cli.

    Exits with status 1 when optional dependencies are missing or the
    command fails with a known error.
    """
    try:
        from amqtt.contrib.auth_db.user_mgr_cli import user_app  # pylint: disable=import-outside-toplevel
    except ImportError:
        logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
        sys.exit(1)
    # (fixed) removed a redundant second import of user_app — the import in
    # the try block above already bound the name on success.
    try:
        user_app()
    except ModuleNotFoundError as mnfe:
        logger.critical(f"Please install database-specific dependencies: {mnfe}")
        sys.exit(1)
    except ValueError as ve:
        # sqlalchemy's async support raises ValueError mentioning 'greenlet'
        # when that dependency is absent.
        if "greenlet" in f"{ve}":
            logger.critical("Please install database-specific dependencies: 'greenlet'")
            sys.exit(1)
        logger.critical(f"Unknown error: {ve}")
        sys.exit(1)
    except MQTTError as me:
        logger.critical(f"could not execute command: {me}")
        sys.exit(1)
# Script entry point.
if __name__ == "__main__":
    main()
| import logging | ||
| from pathlib import Path | ||
| import sys | ||
| import typer | ||
| logger = logging.getLogger(__name__) | ||
| app = typer.Typer(add_completion=False, rich_markup_mode=None) | ||
def main() -> None:
    """Run the `server_creds` cli."""
    # Delegate to the typer application defined at module level.
    app()
@app.command()
def server_creds(
    country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
    org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
    cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
    output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
    # (fixed) help texts said "server key/cert output filename", but these
    # options are the CA *inputs* consumed by load_ca() for signing —
    # wording now matches the device_creds command.
    ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="root key filename used for signing."),
    ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="root cert filename used for signing."),
) -> None:
    """Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration."""  # noqa: E501
    formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=formatter)
    # The cert helpers live in the optional 'contrib' extra.
    try:
        from amqtt.contrib.cert import (  # pylint: disable=import-outside-toplevel
            generate_server_csr,
            load_ca,
            sign_csr,
            write_key_and_crt,
        )
    except ImportError:
        msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
        logger.critical(msg)
        sys.exit(1)
    ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
    server_key, server_csr = generate_server_csr(country=country, org_name=org_name, cn=cn)
    server_crt = sign_csr(server_csr, ca_key, ca_crt)
    write_key_and_crt(server_key, server_crt, "server", Path(output_dir))
# Script entry point.
if __name__ == "__main__":
    main()
+3
-0
@@ -7,2 +7,5 @@ #------- Package & Cache Files ------- | ||
| *.pem | ||
| *.crt | ||
| *.key | ||
| *.patch | ||
@@ -9,0 +12,0 @@ #------- Environment Files ------- |
| """INIT.""" | ||
| __version__ = "0.11.2" | ||
| __version__ = "0.11.3" |
+16
-0
@@ -6,2 +6,4 @@ from abc import ABC, abstractmethod | ||
| import logging | ||
| import ssl | ||
| from typing import cast | ||
@@ -57,2 +59,7 @@ from websockets import ConnectionClosed | ||
| @abstractmethod | ||
| def get_ssl_info(self) -> ssl.SSLObject | None: | ||
| """Return peer certificate information (if available) used to establish a TLS session.""" | ||
| raise NotImplementedError | ||
| @abstractmethod | ||
| async def close(self) -> None: | ||
@@ -126,2 +133,5 @@ """Close the protocol connection.""" | ||
| def get_ssl_info(self) -> ssl.SSLObject | None: | ||
| return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object")) | ||
| async def close(self) -> None: | ||
@@ -176,2 +186,5 @@ await self._protocol.close() | ||
| def get_ssl_info(self) -> ssl.SSLObject | None: | ||
| return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object")) | ||
| async def close(self) -> None: | ||
@@ -211,2 +224,5 @@ if not self.is_closed: | ||
| def get_ssl_info(self) -> ssl.SSLObject | None: | ||
| return None | ||
| def __init__(self, buffer: bytes = b"") -> None: | ||
@@ -213,0 +229,0 @@ self._stream = io.BytesIO(buffer) |
+200
-92
@@ -5,8 +5,8 @@ import asyncio | ||
| from collections.abc import Generator | ||
| import copy | ||
| from functools import partial | ||
| import logging | ||
| from pathlib import Path | ||
| from math import floor | ||
| import re | ||
| import ssl | ||
| import time | ||
| from typing import Any, ClassVar, TypeAlias | ||
@@ -26,7 +26,7 @@ | ||
| ) | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType | ||
| from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError | ||
| from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler | ||
| from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session | ||
| from amqtt.utils import format_client_message, gen_client_id, read_yaml_config | ||
| from amqtt.utils import format_client_message, gen_client_id | ||
@@ -38,9 +38,4 @@ from .events import BrokerEvents | ||
| _CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]] | ||
| _BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None] | ||
| _defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml") | ||
| # Default port numbers | ||
@@ -63,2 +58,4 @@ DEFAULT_PORTS = {"tcp": 1883, "ws": 8883} | ||
| class Server: | ||
| """Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle.""" | ||
| def __init__( | ||
@@ -101,18 +98,32 @@ self, | ||
class ExternalServer(Server):
    """Listener whose connection lifecycle is managed by an external implementation; all lifecycle hooks are no-ops."""

    def __init__(self) -> None:
        super().__init__("aiohttp", None)  # type: ignore[arg-type]

    async def acquire_connection(self) -> None:
        """No-op: the external implementation tracks its own connections."""

    def release_connection(self) -> None:
        """No-op: the external implementation tracks its own connections."""

    async def close_instance(self) -> None:
        """No-op: the external implementation owns its own shutdown."""
| class BrokerContext(BaseContext): | ||
| """BrokerContext is used as the context passed to plugins interacting with the broker. | ||
| """Used to provide the server's context as well as public methods for accessing internal state.""" | ||
| It act as an adapter to broker services from plugins developed for HBMQTT broker. | ||
| """ | ||
| def __init__(self, broker: "Broker") -> None: | ||
| super().__init__() | ||
| self.config: _CONFIG_LISTENER | None = None | ||
| self.config: BrokerConfig | None = None | ||
| self._broker_instance = broker | ||
| async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None: | ||
| """Send message to all client sessions subscribing to `topic`.""" | ||
| await self._broker_instance.internal_message_broadcast(topic, data, qos) | ||
| def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None: | ||
| self._broker_instance.retain_message(None, topic_name, data, qos) | ||
| async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None: | ||
| await self._broker_instance.retain_message(None, topic_name, data, qos) | ||
@@ -124,2 +135,6 @@ @property | ||
| def get_session(self, client_id: str) -> Session | None: | ||
| """Return the session associated with `client_id`, if it exists.""" | ||
| return self._broker_instance.sessions.get(client_id, (None, None))[0] | ||
| @property | ||
@@ -133,3 +148,17 @@ def retained_messages(self) -> dict[str, RetainedApplicationMessage]: | ||
| async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None: | ||
| """Create a topic subscription for the given `client_id`. | ||
| If a client session doesn't exist for `client_id`, create a disconnected session. | ||
| If `topic` and `qos` are both `None`, only create the client session. | ||
| """ | ||
| if client_id not in self._broker_instance.sessions: | ||
| broker_handler, session = self._broker_instance.create_offline_session(client_id) | ||
| self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001 | ||
| if topic is not None and qos is not None: | ||
| session, _ = self._broker_instance.sessions[client_id] | ||
| await self._broker_instance.add_subscription((topic, qos), session) | ||
| class Broker: | ||
@@ -139,3 +168,3 @@ """MQTT 3.1.1 compliant broker implementation. | ||
| Args: | ||
| config: dictionary of configuration options (see [broker configuration](broker_config.md)). | ||
| config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)). | ||
| loop: asyncio loop. defaults to `asyncio.new_event_loop()`. | ||
@@ -145,3 +174,5 @@ plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`. | ||
| Raises: | ||
| BrokerError, ParserError, PluginError | ||
| BrokerError: problem with broker configuration | ||
| PluginImportError: if importing a plugin from configuration | ||
| PluginInitError: if initialization plugin fails | ||
@@ -162,3 +193,3 @@ """ | ||
| self, | ||
| config: _CONFIG_LISTENER | None = None, | ||
| config: BrokerConfig | dict[str, Any] | None = None, | ||
| loop: asyncio.AbstractEventLoop | None = None, | ||
@@ -169,11 +200,11 @@ plugin_namespace: str | None = None, | ||
| self.logger = logging.getLogger(__name__) | ||
| self.config = copy.deepcopy(_defaults or {}) | ||
| if config is not None: | ||
| # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config | ||
| if ("auth" in config or "topic-check" in config) and "plugins" not in config: | ||
| # set to None so that the config isn't updated with the new-style default plugin list | ||
| config["plugins"] = None # type: ignore[assignment] | ||
| self.config.update(config) | ||
| self._build_listeners_config(self.config) | ||
| if isinstance(config, dict): | ||
| self.config = BrokerConfig.from_dict(config) | ||
| else: | ||
| self.config = config or BrokerConfig() | ||
| # listeners are populated from default within BrokerConfig | ||
| self.listeners_config = self.config.listeners | ||
| self._loop = loop or asyncio.get_running_loop() | ||
@@ -196,2 +227,5 @@ self._servers: dict[str, Server] = {} | ||
| # Task for session monitor | ||
| self._session_monitor_task: asyncio.Task[Any] | None = None | ||
| # Initialize plugins manager | ||
@@ -204,22 +238,2 @@ | ||
| def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None: | ||
| self.listeners_config = {} | ||
| try: | ||
| listeners_config = broker_config.get("listeners") | ||
| if not isinstance(listeners_config, dict): | ||
| msg = "Listener config not found or invalid" | ||
| raise BrokerError(msg) | ||
| defaults = listeners_config.get("default") | ||
| if defaults is None: | ||
| msg = "Listener config has not default included or is invalid" | ||
| raise BrokerError(msg) | ||
| for listener_name, listener_conf in listeners_config.items(): | ||
| config = defaults.copy() | ||
| config.update(listener_conf) | ||
| self.listeners_config[listener_name] = config | ||
| except KeyError as ke: | ||
| msg = f"Listener config not found or invalid: {ke}" | ||
| raise BrokerError(msg) from ke | ||
| def _init_states(self) -> None: | ||
@@ -261,2 +275,3 @@ self.transitions = Machine(states=Broker.states, initial="new") | ||
| self._broadcast_task = asyncio.ensure_future(self._broadcast_loop()) | ||
| self._session_monitor_task = asyncio.create_task(self._session_monitor()) | ||
| self.logger.debug("Broker started") | ||
@@ -279,14 +294,23 @@ except Exception as e: | ||
| try: | ||
| address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) | ||
| except ValueError as e: | ||
| msg = f"Invalid port value in bind value: {listener['bind']}" | ||
| raise BrokerError(msg) from e | ||
| # for listeners which are external, don't need to create a server | ||
| if listener.type == ListenerType.EXTERNAL: | ||
| instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context) | ||
| self._servers[listener_name] = Server(listener_name, instance, max_connections) | ||
| # broker still needs to associate a new connection to the listener | ||
| self.logger.info(f"External listener exists for '{listener_name}' ") | ||
| self._servers[listener_name] = ExternalServer() | ||
| else: | ||
| # for tcp and websockets, start servers to listen for inbound connections | ||
| try: | ||
| address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]]) | ||
| except ValueError as e: | ||
| msg = f"Invalid port value in bind value: {listener['bind']}" | ||
| raise BrokerError(msg) from e | ||
| self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") | ||
| instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context) | ||
| self._servers[listener_name] = Server(listener_name, instance, max_connections) | ||
| def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext: | ||
| self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})") | ||
| @staticmethod | ||
| def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext: | ||
| """Create an SSL context for a listener.""" | ||
@@ -313,3 +337,3 @@ try: | ||
| listener_name: str, | ||
| listener_type: str, | ||
| listener_type: ListenerType, | ||
| address: str | None, | ||
@@ -320,21 +344,52 @@ port: int, | ||
| """Create a server instance for a listener.""" | ||
| if listener_type == "tcp": | ||
| return await asyncio.start_server( | ||
| partial(self.stream_connected, listener_name=listener_name), | ||
| address, | ||
| port, | ||
| reuse_address=True, | ||
| ssl=ssl_context, | ||
| ) | ||
| if listener_type == "ws": | ||
| return await websockets.serve( | ||
| partial(self.ws_connected, listener_name=listener_name), | ||
| address, | ||
| port, | ||
| ssl=ssl_context, | ||
| subprotocols=[websockets.Subprotocol("mqtt")], | ||
| ) | ||
| msg = f"Unsupported listener type: {listener_type}" | ||
| raise BrokerError(msg) | ||
| match listener_type: | ||
| case ListenerType.TCP: | ||
| return await asyncio.start_server( | ||
| partial(self.stream_connected, listener_name=listener_name), | ||
| address, | ||
| port, | ||
| reuse_address=True, | ||
| ssl=ssl_context, | ||
| ) | ||
| case ListenerType.WS: | ||
| return await websockets.serve( | ||
| partial(self.ws_connected, listener_name=listener_name), | ||
| address, | ||
| port, | ||
| ssl=ssl_context, | ||
| subprotocols=[websockets.Subprotocol("mqtt")], | ||
| ) | ||
| case _: | ||
| msg = f"Unsupported listener type: {listener_type}" | ||
| raise BrokerError(msg) | ||
| async def _session_monitor(self) -> None: | ||
| self.logger.info("Starting session expiration monitor.") | ||
| while True: | ||
| session_count_before = len(self._sessions) | ||
| # clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out | ||
| sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items() | ||
| if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)] | ||
| # if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration | ||
| if self.config.session_expiry_interval is not None: | ||
| retain_after = floor(time.time() - self.config.session_expiry_interval) | ||
| sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items() | ||
| if session.transitions.state == "disconnected" and | ||
| session.last_disconnect_time and | ||
| session.last_disconnect_time < retain_after] | ||
| for client_id in sessions_to_remove: | ||
| await self._cleanup_session(client_id) | ||
| if session_count_before > (session_count_after := len(self._sessions)): | ||
| self.logger.debug(f"Expired {session_count_before - session_count_after} sessions") | ||
| await asyncio.sleep(1) | ||
| async def shutdown(self) -> None: | ||
@@ -357,2 +412,4 @@ """Stop broker instance.""" | ||
| await self._shutdown_broadcast_loop() | ||
| if self._session_monitor_task: | ||
| self._session_monitor_task.cancel() | ||
@@ -393,2 +450,6 @@ for server in self._servers.values(): | ||
| async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None: | ||
| """Engage the broker in handling the data stream to/from an established connection.""" | ||
| await self._client_connected(listener_name, reader, writer) | ||
| async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None: | ||
@@ -465,3 +526,13 @@ """Handle a new client connection.""" | ||
| self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}") | ||
| client_session, _ = self._sessions[client_session.client_id] | ||
| # even though the session previously existed, the new connection can bring updated configuration and credentials | ||
| existing_client_session, _ = self._sessions[client_session.client_id] | ||
| existing_client_session.will_flag = client_session.will_flag | ||
| existing_client_session.will_message = client_session.will_message | ||
| existing_client_session.will_topic = client_session.will_topic | ||
| existing_client_session.will_qos = client_session.will_qos | ||
| existing_client_session.keep_alive = client_session.keep_alive | ||
| existing_client_session.username = client_session.username | ||
| existing_client_session.password = client_session.password | ||
| client_session = existing_client_session | ||
| client_session.parent = 1 | ||
@@ -478,2 +549,10 @@ else: | ||
| def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]: | ||
| session = Session() | ||
| session.client_id = client_id | ||
| bph = BrokerProtocolHandler(self.plugins_manager, session) | ||
| session.transitions.disconnect() | ||
| return bph, session | ||
| async def _handle_client_session( | ||
@@ -530,6 +609,4 @@ self, | ||
| for topic in self._subscriptions: | ||
| await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session) | ||
| await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session) | ||
| await self._client_message_loop(client_session, handler) | ||
@@ -565,3 +642,2 @@ | ||
| if subscribe_waiter in done: | ||
@@ -621,3 +697,3 @@ await self._handle_subscription(client_session, handler, subscribe_waiter) | ||
| if client_session.will_retain: | ||
| self.retain_message( | ||
| await self.retain_message( | ||
| client_session, | ||
@@ -637,3 +713,2 @@ client_session.will_topic, | ||
| async def _handle_subscription( | ||
@@ -648,3 +723,3 @@ self, | ||
| subscriptions = subscribe_waiter.result() | ||
| return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics] | ||
| return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics] | ||
| await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes) | ||
@@ -689,2 +764,9 @@ for index, subscription in enumerate(subscriptions.topics): | ||
| # notify of a message's receipt, even if a client isn't necessarily allowed to send it | ||
| await self.plugins_manager.fire_event( | ||
| BrokerEvents.MESSAGE_RECEIVED, | ||
| client_id=client_session.client_id, | ||
| message=app_message, | ||
| ) | ||
| if app_message is None: | ||
@@ -711,6 +793,7 @@ self.logger.debug("app_message was empty!") | ||
| if not permitted: | ||
| self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.") | ||
| self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.") | ||
| else: | ||
| # notify that a received message is valid and is allowed to be distributed to other clients | ||
| await self.plugins_manager.fire_event( | ||
| BrokerEvents.MESSAGE_RECEIVED, | ||
| BrokerEvents.MESSAGE_BROADCAST, | ||
| client_id=client_session.client_id, | ||
@@ -721,3 +804,3 @@ message=app_message, | ||
| if app_message.publish_packet and app_message.publish_packet.retain_flag: | ||
| self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) | ||
| await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos) | ||
| return True | ||
@@ -735,6 +818,7 @@ | ||
| await handler.stop() | ||
| except Exception: | ||
| # a failure in stopping a handler shouldn't cause the broker to fail | ||
| except asyncio.QueueEmpty: | ||
| self.logger.exception("Failed to stop handler") | ||
| async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool: | ||
| async def _authenticate(self, session: Session, _: ListenerConfig) -> bool: | ||
| """Call the authenticate method on registered plugins to test user authentication. | ||
@@ -752,3 +836,3 @@ | ||
| results = [ result for _, result in returns.items() if result is not None] if returns else [] | ||
| results = [result for _, result in returns.items() if result is not None] if returns else [] | ||
| if len(results) < 1: | ||
@@ -767,3 +851,3 @@ self.logger.debug("Authentication failed: no plugin responded with a boolean") | ||
| def retain_message( | ||
| async def retain_message( | ||
| self, | ||
@@ -779,8 +863,21 @@ source_session: Session | None, | ||
| self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos) | ||
| await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, | ||
| client_id=None, | ||
| retained_message=self._retained_messages[topic_name]) | ||
| # [MQTT-3.3.1-10] | ||
| elif topic_name in self._retained_messages: | ||
| self.logger.debug(f"Clearing retained messages for topic '{topic_name}'") | ||
| cleared_message = self._retained_messages[topic_name] | ||
| cleared_message.data = b"" | ||
| await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, | ||
| client_id=None, | ||
| retained_message=cleared_message) | ||
| del self._retained_messages[topic_name] | ||
| async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int: | ||
| async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int: | ||
| topic_filter, qos = subscription | ||
@@ -889,3 +986,4 @@ if "#" in topic_filter and not topic_filter.endswith("#"): | ||
| self.logger.info(f"Task has been cancelled: {task}") | ||
| except Exception: | ||
| # if a task fails, don't want it to cause the broker to fail | ||
| except Exception: # pylint: disable=W0718 | ||
| self.logger.exception(f"Task failed and will be skipped: {task}") | ||
@@ -929,2 +1027,8 @@ | ||
| sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE) | ||
| if not sendable: | ||
| self.logger.info( | ||
| f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.") | ||
| continue | ||
| # Retain all messages which cannot be broadcasted, due to the session not being connected | ||
@@ -972,2 +1076,6 @@ # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4] | ||
| await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE, | ||
| client_id=target_session.client_id, | ||
| retained_message=retained_message) | ||
| if self.logger.isEnabledFor(logging.DEBUG): | ||
@@ -974,0 +1082,0 @@ self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}") |
+44
-43
@@ -5,6 +5,4 @@ import asyncio | ||
| import contextlib | ||
| import copy | ||
| from functools import wraps | ||
| import logging | ||
| from pathlib import Path | ||
| import ssl | ||
@@ -23,3 +21,3 @@ from typing import TYPE_CHECKING, Any, TypeAlias, cast | ||
| ) | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.contexts import BaseContext, ClientConfig | ||
| from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError | ||
@@ -31,3 +29,3 @@ from amqtt.mqtt.connack import CONNECTION_ACCEPTED | ||
| from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session | ||
| from amqtt.utils import gen_client_id, read_yaml_config | ||
| from amqtt.utils import gen_client_id | ||
@@ -37,5 +35,3 @@ if TYPE_CHECKING: | ||
| _defaults: dict[str, Any] | None = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml") | ||
| class ClientContext(BaseContext): | ||
@@ -49,3 +45,3 @@ """ClientContext is used as the context passed to plugins interacting with the client. | ||
| super().__init__() | ||
| self.config = None | ||
| self.config: ClientConfig | None = None | ||
@@ -87,22 +83,23 @@ | ||
| class MQTTClient: | ||
| """MQTT client implementation. | ||
| """MQTT client implementation, providing an API for connecting to a broker and send/receive messages using the MQTT protocol. | ||
| MQTTClient instances provides API for connecting to a broker and send/receive | ||
| messages using the MQTT protocol. | ||
| Args: | ||
| client_id: MQTT client ID to use when connecting to the broker. If none, | ||
| it will be generated randomly by `amqtt.utils.gen_client_id` | ||
| config: dictionary of configuration options (see [client configuration](client_config.md)). | ||
| config: `ClientConfig` or dictionary of equivalent structure options (see [client configuration](client_config.md)). | ||
| Raises: | ||
| PluginError | ||
| PluginImportError: if importing a plugin from configuration fails | ||
| PluginInitError: if initialization plugin fails | ||
| """ | ||
| def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None: | ||
| def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None: | ||
| self.logger = logging.getLogger(__name__) | ||
| self.config = copy.deepcopy(_defaults or {}) | ||
| if config is not None: | ||
| self.config.update(config) | ||
| if isinstance(config, dict): | ||
| self.config = ClientConfig.from_dict(config) | ||
| else: | ||
| self.config = config or ClientConfig() | ||
| self.client_id = client_id if client_id is not None else gen_client_id() | ||
@@ -155,3 +152,3 @@ | ||
| Raises: | ||
| ClientError, ConnectError | ||
| ConnectError: could not connect to broker | ||
@@ -169,3 +166,4 @@ """ | ||
| raise ConnectError(msg) from e | ||
| except Exception as e: | ||
| # no matter the failure mode, still try to reconnect | ||
| except Exception as e: # pylint: disable=W0718 | ||
| self.logger.warning(f"Connection failed: {e!r}") | ||
@@ -244,3 +242,4 @@ if not self.config.get("auto_reconnect", False): | ||
| raise ConnectError(msg) from e | ||
| except Exception as e: | ||
| # no matter the failure mode, still try to reconnect | ||
| except Exception as e: # pylint: disable=W0718 | ||
| self.logger.warning(f"Reconnection attempt failed: {e!r}") | ||
@@ -393,2 +392,3 @@ self.logger.debug("", exc_info=True) | ||
| asyncio.TimeoutError: if timeout occurs before a message is delivered | ||
| ClientError: if client is not connected | ||
@@ -437,10 +437,6 @@ """ | ||
| self.session.username = ( | ||
| self.session.username | ||
| if self.session.username | ||
| else (str(uri_attributes.username) if uri_attributes.username else None) | ||
| self.session.username or (str(uri_attributes.username) if uri_attributes.username else None) | ||
| ) | ||
| self.session.password = ( | ||
| self.session.password | ||
| if self.session.password | ||
| else (str(uri_attributes.password) if uri_attributes.password else None) | ||
| self.session.password or (str(uri_attributes.password) if uri_attributes.password else None) | ||
| ) | ||
@@ -476,11 +472,11 @@ self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None | ||
| ssl.Purpose.SERVER_AUTH, | ||
| cafile=self.session.cafile, | ||
| capath=self.session.capath, | ||
| cadata=self.session.cadata, | ||
| cafile=self.session.cafile | ||
| ) | ||
| if "certfile" in self.config: | ||
| sc.load_verify_locations(cafile=self.config["certfile"]) | ||
| if "check_hostname" in self.config and isinstance(self.config["check_hostname"], bool): | ||
| sc.check_hostname = self.config["check_hostname"] | ||
| if self.config.connection.certfile and self.config.connection.keyfile: | ||
| sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile) | ||
| if self.config.connection.cafile: | ||
| sc.load_verify_locations(cafile=self.config.connection.cafile) | ||
| if self.config.check_hostname is not None: | ||
| sc.check_hostname = self.config.check_hostname | ||
| sc.verify_mode = ssl.CERT_REQUIRED | ||
@@ -540,3 +536,3 @@ kwargs["ssl"] = sc | ||
| except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError) as e: | ||
| except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError, asyncio.TimeoutError) as e: | ||
| self.logger.debug(f"Connection failed : {self.session.broker_uri} [{e!r}]") | ||
@@ -597,7 +593,15 @@ self.session.transitions.disconnect() | ||
| """Initialize the MQTT session.""" | ||
| broker_conf = self.config.get("broker", {}).copy() | ||
| broker_conf.update( | ||
| {k: v for k, v in {"uri": uri, "cafile": cafile, "capath": capath, "cadata": cadata}.items() if v is not None}, | ||
| ) | ||
| broker_conf = self.config.get("connection", {}).copy() | ||
| if uri is not None: | ||
| broker_conf.uri = uri | ||
| if cleansession is not None: | ||
| self.config.cleansession = cleansession | ||
| if cafile is not None: | ||
| broker_conf.cafile = cafile | ||
| if capath is not None: | ||
| broker_conf.capath = capath | ||
| if cadata is not None: | ||
| broker_conf.cadata = cadata | ||
| if not broker_conf.get("uri"): | ||
@@ -610,2 +614,3 @@ msg = "Missing connection parameter 'uri'" | ||
| session.client_id = self.client_id | ||
| session.cafile = broker_conf.get("cafile") | ||
@@ -615,7 +620,3 @@ session.capath = broker_conf.get("capath") | ||
| if cleansession is not None: | ||
| broker_conf["cleansession"] = cleansession # noop? | ||
| session.clean_session = cleansession | ||
| else: | ||
| session.clean_session = self.config.get("cleansession", True) | ||
| session.clean_session = self.config.get("cleansession", True) | ||
@@ -622,0 +623,0 @@ session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"] |
@@ -145,6 +145,6 @@ import asyncio | ||
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
    """Convert a float value to a bytes array containing the numeric characters.

    Args:
        value: number to render.
        places: number of decimal places to keep (rounded half-up).

    Returns:
        UTF-8 encoded decimal string, e.g. ``b"1.500"`` for ``(1.5, 3)``.
    """
    # Quantization template, e.g. places=3 -> Decimal("0.001").
    quant = Decimal(f"0.{'0' * (places - 1)}1")
    rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
    return str(rounded).encode("utf-8")
+363
-6
@@ -1,19 +0,36 @@ | ||
| from enum import Enum | ||
| from dataclasses import dataclass, field, fields, replace | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| import warnings | ||
| _LOGGER = logging.getLogger(__name__) | ||
| try: | ||
| from enum import Enum, StrEnum | ||
| except ImportError: | ||
| # support for python 3.10 | ||
| from enum import Enum | ||
| class StrEnum(str, Enum): # type: ignore[no-redef] | ||
| pass | ||
| from collections.abc import Iterator | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any, Literal | ||
| from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass | ||
| from amqtt.mqtt.constants import QOS_0, QOS_2 | ||
| if TYPE_CHECKING: | ||
| import asyncio | ||
| logger = logging.getLogger(__name__) | ||
| class BaseContext: | ||
| def __init__(self) -> None: | ||
| self.loop: asyncio.AbstractEventLoop | None = None | ||
| self.logger: logging.Logger = _LOGGER | ||
| self.config: dict[str, Any] | None = None | ||
| self.logger: logging.Logger = logging.getLogger(__name__) | ||
| # cleanup with a `Generic` type | ||
| self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None | ||
| class Action(Enum): | ||
| class Action(StrEnum): | ||
| """Actions issued by the broker.""" | ||
@@ -23,1 +40,341 @@ | ||
| PUBLISH = "publish" | ||
| RECEIVE = "receive" | ||
| class ListenerType(StrEnum): | ||
| """Types of mqtt listeners.""" | ||
| TCP = "tcp" | ||
| WS = "ws" | ||
| EXTERNAL = "external" | ||
| def __repr__(self) -> str: | ||
| """Display the string value, instead of the enum member.""" | ||
| return f'"{self.value!s}"' | ||
| class Dictable: | ||
| """Add dictionary methods to a dataclass.""" | ||
| def __getitem__(self, key: str) -> Any: | ||
| """Allow dict-style `[]` access to a dataclass.""" | ||
| return self.get(key) | ||
| def get(self, name: str, default: Any = None) -> Any: | ||
| """Allow dict-style access to a dataclass.""" | ||
| name = name.replace("-", "_") | ||
| if hasattr(self, name): | ||
| return getattr(self, name) | ||
| if default is not None: | ||
| return default | ||
| msg = f"'{name}' is not defined" | ||
| raise ValueError(msg) | ||
| def __contains__(self, name: str) -> bool: | ||
| """Provide dict-style 'in' check.""" | ||
| return getattr(self, name.replace("-", "_"), None) is not None | ||
| def __iter__(self) -> Iterator[Any]: | ||
| """Provide dict-style iteration.""" | ||
| for f in fields(self): # type: ignore[arg-type] | ||
| yield getattr(self, f.name) | ||
| def copy(self) -> dataclass: # type: ignore[valid-type] | ||
| """Return a copy of the dataclass.""" | ||
| return replace(self) # type: ignore[type-var] | ||
| @staticmethod | ||
| def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]: | ||
| if isinstance(value, list): | ||
| return value # It's already a list of dicts | ||
| if isinstance(value, dict): | ||
| return [value] # Promote single dict to a list | ||
| msg = "Could not convert 'list' to 'list[dict[str, Any]]'" | ||
| raise ValueError(msg) | ||
@dataclass
class ListenerConfig(Dictable):
    """Structured configuration for a broker's listeners."""

    type: ListenerType = ListenerType.TCP
    """Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml.'"""
    bind: str | None = "0.0.0.0:1883"
    """address and port for the listener to bind to"""
    max_connections: int = 0
    """max number of connections allowed for this listener"""
    ssl: bool = False
    """secured by ssl"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | Path | None = None
    """Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the server's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the server's private key."""
    reader: str | None = None
    writer: str | None = None

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: if only one of certfile/keyfile is provided.
            FileNotFoundError: if a provided certificate/key path does not exist.
        """
        # A TLS identity needs both halves: a cert without a key (or vice versa) is an error.
        if (self.certfile is None) ^ (self.keyfile is None):
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)
        # Normalize str paths to Path objects and fail fast on missing files.
        # NOTE(review): 'cadata' is skipped here — presumably it may hold inline
        # PEM/DER data rather than a path; confirm against callers.
        for fn in ("cafile", "capath", "certfile", "keyfile"):
            if isinstance(getattr(self, fn), str):
                setattr(self, fn, Path(getattr(self, fn)))
            if getattr(self, fn) and not getattr(self, fn).exists():
                msg = f"'{fn}' does not exist : {getattr(self, fn)}"
                raise FileNotFoundError(msg)

    def apply(self, other: "ListenerConfig") -> None:
        """Apply the field from 'other', if 'self' field is default.

        Used to overlay the 'default' listener onto named listeners: only fields
        still holding their declared default are overwritten, so explicit
        per-listener settings take precedence.
        """
        for f in fields(self):
            if getattr(self, f.name) == f.default:
                setattr(self, f.name, other[f.name])
def default_listeners() -> dict[str, Any]:
    """Default `listeners` mapping for `BrokerConfig`: one unconfigured 'default' listener."""
    default = ListenerConfig()
    return {"default": default}
def default_broker_plugins() -> dict[str, Any]:
    """Default plugin set for `BrokerConfig`: event/packet logging, anonymous auth, and $SYS stats."""
    plugins: dict[str, Any] = {
        "amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
        "amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
        "amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
        "amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20},
    }
    return plugins
@dataclass
class BrokerConfig(Dictable):
    """Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""

    listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners)  # noqa: PYI051
    """Network of listeners used by the services. a 'default' named listener is required; if another listener
    does not set a value, the 'default' settings are applied. See
    [`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
    sys_interval: int | None = None
    """*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
    for recommended configuration.*"""
    timeout_disconnect_delay: int | None = 0
    """Client disconnect timeout without a keep-alive."""
    session_expiry_interval: int | None = None
    """Seconds for an inactive session to be retained."""
    auth: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
    [`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
    topic_check: dict[str, Any] | None = None
    """*Deprecated field used to config EntryPoint-loaded plugins. See
    [`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
    [`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
    plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
    or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
    [custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
    to support legacy use cases."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use."""
        if self.sys_interval is not None:
            logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
        if self.auth is not None or self.topic_check is not None:
            logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
        # Overlay the 'default' listener onto every other listener's unset fields.
        default_listener = self.listeners["default"]
        for listener_name, listener in self.listeners.items():
            if listener_name == "default":
                continue
            listener.apply(default_listener)
        # Legacy list-shaped plugin config: normalize into the dict form.
        if isinstance(self.plugins, list):
            _plugins: dict[str, Any] = {}
            for plugin in self.plugins:
                # in case a plugin in a yaml file is listed without config map
                if isinstance(plugin, str):
                    _plugins |= {plugin: {}}
                    continue
                _plugins |= plugin
            self.plugins = _plugins

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
        """Create a broker config from a dictionary.

        Note: `d` is mutated in place (legacy key renames, plugin default suppression).
        """
        if d is None:
            return BrokerConfig()
        # patch the incoming dictionary so it can be loaded correctly
        legacy_topic_check = "topic-check" in d
        if legacy_topic_check:
            d["topic_check"] = d.pop("topic-check")
        # identify EntryPoint plugin loading and prevent 'plugins' from getting defaults.
        # BUG FIX: previously this re-tested '"topic-check" in d', which was always
        # False because the key was renamed to 'topic_check' just above — so a
        # legacy topic-check-only config still received the default plugin set.
        if ("auth" in d or legacy_topic_check) and "plugins" not in d:
            d["plugins"] = None
        return dict_to_dataclass(data_class=BrokerConfig,
                                 data=d,
                                 config=DaciteConfig(
                                     cast=[StrEnum, ListenerType],
                                     strict=True,
                                     type_hooks={list[dict[str, Any]]: cls._coerce_lists}
                                 ))
@dataclass
class ConnectionConfig(Dictable):
    """Properties for connecting to the broker."""

    uri: str | None = "mqtt://127.0.0.1:1883"
    """URI of the broker"""
    cafile: str | Path | None = None
    """Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
    [Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
    capath: str | Path | None = None
    """Path to a directory containing one or more CA certificates in PEM format, following the
    [OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
    cadata: str | None = None
    """The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
    certificates or a bytes-like object of DER-encoded certificates."""
    certfile: str | Path | None = None
    """Full path to file in PEM format containing the client's certificate (as well as any number of CA
    certificates needed to establish the certificate's authenticity.)"""
    keyfile: str | Path | None = None
    """Full path to file in PEM format containing the client's private key associated with the certfile."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: if only one of certfile/keyfile is provided.
        """
        # BUG FIX: this hook was misspelled `__post__init__` (extra underscore),
        # so dataclasses never invoked it and the validation/normalization below
        # was silently skipped.
        if (self.certfile is None) ^ (self.keyfile is None):
            msg = "If specifying the 'certfile' or 'keyfile', both are required."
            raise ValueError(msg)
        # Normalize str paths to Path objects (mirrors ListenerConfig behavior).
        for fn in ("cafile", "capath", "certfile", "keyfile"):
            if isinstance(getattr(self, fn), str):
                setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
    """Configuration of how messages to specific topics are published.

    The topic name is specified as the key in the dictionary of the `ClientConfig.topics`.
    """

    qos: int = 0
    """The quality of service associated with the publishing to this topic."""
    retain: bool = False
    """Determines if the message should be retained by the topic it was published."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: if `qos` is outside the valid 0-2 range.
        """
        # BUG FIX: this hook was misspelled `__post__init__` (extra underscore),
        # so dataclasses never invoked it and invalid QoS values were accepted.
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Topic config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
@dataclass
class WillConfig(Dictable):
    """Configuration of the 'last will & testament' of the client upon improper disconnection."""

    topic: str
    """The will message will be published to this topic."""
    message: str
    """The contents of the message to be published."""
    qos: int | None = QOS_0
    """The quality of service associated with sending this message."""
    retain: bool | None = False
    """Determines if the message should be retained by the topic it was published."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: if `qos` is outside the valid 0-2 range.
        """
        # BUG FIX: this hook was misspelled `__post__init__` (extra underscore),
        # so dataclasses never invoked it and invalid QoS values were accepted.
        if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
            msg = "Will config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
def default_client_plugins() -> dict[str, Any]:
    """Default plugin set for `ClientConfig`: packet logging only."""
    plugins: dict[str, Any] = {"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}}
    return plugins
@dataclass
class ClientConfig(Dictable):
    """Structured configuration for a client. Can be passed directly to the MQTT client or created from a dictionary."""

    keep_alive: int | None = 10
    """Keep-alive timeout sent to the broker."""
    ping_delay: int | None = 1
    """Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
    default_qos: int | None = QOS_0
    """Default QoS for messages published."""
    default_retain: bool | None = False
    """Default retain value to messages published."""
    auto_reconnect: bool | None = True
    """Enable or disable auto-reconnect if connection with the broker is interrupted."""
    connection_timeout: int | None = 60
    """The number of seconds before a connection times out"""
    reconnect_retries: int | None = 2
    """Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
    reconnect_max_interval: int | None = 10
    """Maximum seconds to wait before retrying a connection."""
    cleansession: bool | None = True
    """Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
    topics: dict[str, TopicConfig] | None = field(default_factory=dict)
    """Specify the topics and what flags should be set for messages published to them."""
    broker: ConnectionConfig | None = None
    """*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
    connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    """Configuration for connecting to the broker. See
    [`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
    plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
    """The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
    a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
    more information. `list[str | dict[str,Any]]` is deprecated but available to support legacy use cases."""
    check_hostname: bool | None = True
    """If establishing a secure connection, should the hostname of the certificate be verified."""
    will: WillConfig | None = None
    """Message, topic and flags that should be sent to if the client disconnects. See
    [`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""

    def __post_init__(self) -> None:
        """Check config for errors and transform fields for easier use.

        Raises:
            ValueError: if `default_qos` is out of range, or only one of the
                connection key/certificate files is supplied.
        """
        # Reject out-of-range publish QoS early.
        if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
            msg = "Client config: default QoS must be 0, 1 or 2."
            raise ValueError(msg)
        # Deprecated 'broker' section, when present, replaces 'connection' entirely.
        if self.broker is not None:
            warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
            self.connection = self.broker
        # TLS client identity requires both key and certificate, or neither.
        if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile):
            msg = "Connection key and certificate files are _both_ required."
            raise ValueError(msg)

    @classmethod
    def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
        """Create a client config from a dictionary; `None` yields all defaults."""
        if d is None:
            return ClientConfig()
        # strict=True: unknown keys in `d` raise instead of being ignored.
        return dict_to_dataclass(data_class=ClientConfig,
                                 data=d,
                                 config=DaciteConfig(
                                     cast=[StrEnum],
                                     strict=True)
                                 )
+2
-0
@@ -19,2 +19,3 @@ from typing import Any | ||
| class ZeroLengthReadError(NoDataError): | ||
@@ -24,2 +25,3 @@ def __init__(self) -> None: | ||
| class BrokerError(Exception): | ||
@@ -26,0 +28,0 @@ """Exceptions thrown by broker.""" |
+3
-1
@@ -6,3 +6,3 @@ try: | ||
| from enum import Enum | ||
| class StrEnum(str, Enum): #type: ignore[no-redef] | ||
| class StrEnum(str, Enum): # type: ignore[no-redef] | ||
| pass | ||
@@ -35,2 +35,4 @@ | ||
| CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed" | ||
| RETAINED_MESSAGE = "broker_retained_message" | ||
| MESSAGE_RECEIVED = "broker_message_received" | ||
| MESSAGE_BROADCAST = "broker_message_broadcast" |
@@ -170,3 +170,3 @@ from abc import ABC, abstractmethod | ||
| class MQTTPayload(Generic[_VH], ABC): | ||
| class MQTTPayload(ABC, Generic[_VH]): | ||
| """Abstract base class for MQTT payloads.""" | ||
@@ -173,0 +173,0 @@ |
@@ -34,2 +34,3 @@ import asyncio | ||
| class Subscription: | ||
@@ -239,2 +240,3 @@ def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None: | ||
| incoming_session.remote_port = remote_port | ||
| incoming_session.ssl_object = writer.get_ssl_info() | ||
@@ -241,0 +243,0 @@ incoming_session.keep_alive = max(connect.keep_alive, 0) |
@@ -22,2 +22,3 @@ import asyncio | ||
| class ClientProtocolHandler(ProtocolHandler["ClientContext"]): | ||
@@ -24,0 +25,0 @@ def __init__( |
@@ -7,9 +7,9 @@ import asyncio | ||
| # Fallback for Python < 3.12 | ||
| class InvalidStateError(Exception): # type: ignore[no-redef] | ||
| class InvalidStateError(Exception): # type: ignore[no-redef] | ||
| pass | ||
| class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818 | ||
| class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818 | ||
| pass | ||
| class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818 | ||
| class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818 | ||
| pass | ||
@@ -67,2 +67,3 @@ | ||
| class ProtocolHandler(Generic[C]): | ||
@@ -204,3 +205,3 @@ """Class implementing the MQTT communication protocol using asyncio features.""" | ||
| topic: str, | ||
| data: bytes | bytearray , | ||
| data: bytes | bytearray, | ||
| qos: int | None, | ||
@@ -541,3 +542,3 @@ retain: bool, | ||
| self.logger.debug(f"{self.session.client_id} No data available") | ||
| except Exception as e: # noqa: BLE001 | ||
| except Exception as e: # noqa: BLE001, pylint: disable=W0718 | ||
| self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}") | ||
@@ -544,0 +545,0 @@ break |
| """INIT.""" | ||
| import re | ||
| from typing import Any, Optional | ||
class TopicMatcher:
    """Singleton that matches MQTT topics against subscription filters.

    Compiled filter regexes are cached on the single shared instance, since
    `re.compile` is an expensive operation.
    """

    _instance: Optional["TopicMatcher"] = None

    def __new__(cls, *args: Any, **kwargs: Any) -> "TopicMatcher":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            # BUG FIX: previously forwarded *args/**kwargs to object.__new__,
            # which raises a confusing TypeError ("object.__new__() takes exactly
            # one argument") whenever arguments are passed, because __new__ is
            # overridden here. Ignore them and let __init__ report any mismatch.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every TopicMatcher() call even though __new__ returns
        # the same object, so only initialize the cache once.
        if not hasattr(self, "_topic_filter_matchers"):
            self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}

    def is_topic_allowed(self, topic: str, a_filter: str) -> bool:
        """Return True if `topic` matches the MQTT subscription filter `a_filter`."""
        # Topics starting with '$' (e.g. $SYS) must not match '+' or '#' at the filter root.
        if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
            return False
        if "#" not in a_filter and "+" not in a_filter:
            # No wildcards: exact comparison is enough.
            return a_filter == topic
        # Translate the filter into a regex once and cache it:
        #   '#' -> optional preceding '/' then anything; '+' -> one non-'/' level.
        if a_filter not in self._topic_filter_matchers:
            self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
                                                               .replace("\\#", "?.*")
                                                               .replace("\\+", "[^/]*")
                                                               .lstrip("?"))
        match_pattern = self._topic_filter_matchers[a_filter]
        return bool(match_pattern.fullmatch(topic))

    def are_topics_allowed(self, topic: str, many_filters: list[str]) -> bool:
        """Return True if `topic` matches any filter in `many_filters`."""
        return any(self.is_topic_allowed(topic, a_filter) for a_filter in many_filters)
@@ -40,5 +40,6 @@ from dataclasses import dataclass, field | ||
| class Config: | ||
| """Allow empty username.""" | ||
| """Configuration for AnonymousAuthPlugin.""" | ||
| allow_anonymous: bool = field(default=True) | ||
| """Allow all anonymous authentication (even with _no_ username).""" | ||
@@ -82,3 +83,3 @@ | ||
| self.context.logger.exception(f"Malformed password file '{password_file}'") | ||
| except Exception: | ||
| except OSError: | ||
| self.context.logger.exception(f"Unexpected error reading password file '{password_file}'") | ||
@@ -112,4 +113,5 @@ | ||
| class Config: | ||
| """Path to the properly encoded password file.""" | ||
| """Configuration for FileAuthPlugin.""" | ||
| password_file: str | Path | None = None | ||
| """Path to file with `username:password` pairs, one per line. All passwords are encoded using sha-512.""" |
+12
-12
| from dataclasses import dataclass, is_dataclass | ||
| from typing import Any, Generic, TypeVar, cast | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.contexts import Action, BaseContext, BrokerConfig | ||
| from amqtt.session import Session | ||
@@ -49,3 +49,3 @@ | ||
| # Deprecated : supports entrypoint-style configs as well as dataclass configuration. | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| def _get_config_option(self, option_name: str, default: Any = None) -> Any: | ||
| if not self.context.config: | ||
@@ -56,3 +56,3 @@ return default | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) | ||
| if option_name in self.context.config: | ||
@@ -80,9 +80,10 @@ return self.context.config[option_name] | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| def _get_config_option(self, option_name: str, default: Any = None) -> Any: | ||
| if not self.context.config: | ||
| return default | ||
| if is_dataclass(self.context.config): | ||
| # overloaded context.config with either BrokerConfig or plugin's Config | ||
| if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig): | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) | ||
| if self.topic_config and option_name in self.topic_config: | ||
@@ -94,3 +95,3 @@ return self.topic_config[option_name] | ||
| self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None | ||
| ) -> bool: | ||
| ) -> bool | None: | ||
| """Logic for filtering out topics. | ||
@@ -104,3 +105,3 @@ | ||
| Returns: | ||
| bool: `True` if topic is allowed, `False` otherwise | ||
| bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined | ||
@@ -114,9 +115,9 @@ """ | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| def _get_config_option(self, option_name: str, default: Any = None) -> Any: | ||
| if not self.context.config: | ||
| return default | ||
| if is_dataclass(self.context.config): | ||
| if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig): | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) | ||
| if self.auth_config and option_name in self.auth_config: | ||
@@ -134,3 +135,2 @@ return self.auth_config[option_name] | ||
| async def authenticate(self, *, session: Session) -> bool | None: | ||
@@ -137,0 +137,0 @@ """Logic for session authentication. |
@@ -11,2 +11,4 @@ __all__ = ["PluginManager", "get_plugin_manager"] | ||
| import logging | ||
| import sys | ||
| import traceback | ||
| from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast | ||
@@ -53,2 +55,3 @@ import warnings | ||
| class PluginManager(Generic[C]): | ||
@@ -100,6 +103,5 @@ """Wraps contextlib Entry point mechanism to provide a basic plugin system. | ||
| if "auth" in self.app_context.config: | ||
| if "auth" in self.app_context.config and self.app_context.config["auth"] is not None: | ||
| self.logger.warning("Loading plugins from config will ignore 'auth' section of config") | ||
| if "topic-check" in self.app_context.config: | ||
| if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None: | ||
| self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config") | ||
@@ -136,3 +138,3 @@ | ||
| DeprecationWarning, | ||
| stacklevel=2 | ||
| stacklevel=4 | ||
| ) | ||
@@ -152,3 +154,3 @@ | ||
| def _load_ep_plugins(self, namespace:str) -> None: | ||
| def _load_ep_plugins(self, namespace: str) -> None: | ||
| """Load plugins from `pyproject.toml` entrypoints. Deprecated.""" | ||
@@ -230,3 +232,3 @@ self.logger.debug(f"Loading plugins for namespace {namespace}") | ||
| try: | ||
| plugin_class: Any = import_string(plugin_path) | ||
| plugin_class: Any = import_string(plugin_path) | ||
| except ImportError as ep: | ||
@@ -300,2 +302,11 @@ msg = f"Plugin import failed: {plugin_path}" | ||
| def _clean_fired_events(self, future: asyncio.Future[Any]) -> None: | ||
| if self.logger.getEffectiveLevel() <= logging.DEBUG: | ||
| try: | ||
| future.result() | ||
| except asyncio.CancelledError: | ||
| self.logger.warning("fired event was cancelled") | ||
| # display plugin fault; don't allow it to cause a broker failure | ||
| except Exception as exc: # noqa: BLE001, pylint: disable=W0718 | ||
| traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr) | ||
| with contextlib.suppress(KeyError, ValueError): | ||
@@ -376,3 +387,3 @@ self._fired_events.remove(future) | ||
| return await self._map_plugin_method( | ||
| self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type] | ||
| self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type] | ||
@@ -379,0 +390,0 @@ async def map_plugin_topic( |
@@ -1,85 +0,11 @@ | ||
| import json | ||
| import sqlite3 | ||
| from typing import Any | ||
| import warnings | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.session import Session | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.plugins.base import BasePlugin | ||
| class SQLitePlugin: | ||
| def __init__(self, context: BaseContext) -> None: | ||
| self.context: BaseContext = context | ||
| self.conn: sqlite3.Connection | None = None | ||
| self.cursor: sqlite3.Cursor | None = None | ||
| self.db_file: str | None = None | ||
| self.persistence_config: dict[str, Any] | ||
| class SQLitePlugin(BasePlugin[BrokerContext]): | ||
| if ( | ||
| persistence_config := self.context.config.get("persistence") if self.context.config is not None else None | ||
| ) is not None: | ||
| self.persistence_config = persistence_config | ||
| self.init_db() | ||
| else: | ||
| self.context.logger.warning("'persistence' section not found in context configuration") | ||
| def init_db(self) -> None: | ||
| self.db_file = self.persistence_config.get("file") | ||
| if not self.db_file: | ||
| self.context.logger.warning("'file' persistence parameter not found") | ||
| else: | ||
| try: | ||
| self.conn = sqlite3.connect(self.db_file) | ||
| self.cursor = self.conn.cursor() | ||
| self.context.logger.info(f"Database file '{self.db_file}' opened") | ||
| except Exception: | ||
| self.context.logger.exception(f"Error while initializing database '{self.db_file}'") | ||
| if self.cursor: | ||
| self.cursor.execute( | ||
| "CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)", | ||
| ) | ||
| self.cursor.execute("PRAGMA table_info(session)") | ||
| columns = {col[1] for col in self.cursor.fetchall()} | ||
| required_columns = {"client_id", "data"} | ||
| if not required_columns.issubset(columns): | ||
| self.context.logger.error("Database schema for 'session' table is incompatible.") | ||
| async def save_session(self, session: Session) -> None: | ||
| if self.cursor and self.conn: | ||
| dump: str = json.dumps(session, default=str) | ||
| try: | ||
| self.cursor.execute( | ||
| "INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)", | ||
| (session.client_id, dump), | ||
| ) | ||
| self.conn.commit() | ||
| except Exception: | ||
| self.context.logger.exception(f"Failed saving session '{session}'") | ||
| async def find_session(self, client_id: str) -> Session | None: | ||
| if self.cursor: | ||
| row = self.cursor.execute( | ||
| "SELECT data FROM session where client_id=?", | ||
| (client_id,), | ||
| ).fetchone() | ||
| return json.loads(row[0]) if row else None | ||
| return None | ||
| async def del_session(self, client_id: str) -> None: | ||
| if self.cursor and self.conn: | ||
| try: | ||
| exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone() | ||
| if exists: | ||
| self.cursor.execute("DELETE FROM session where client_id=?", (client_id,)) | ||
| self.conn.commit() | ||
| except Exception: | ||
| self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'") | ||
| async def on_broker_post_shutdown(self) -> None: | ||
| if self.conn: | ||
| try: | ||
| self.conn.close() | ||
| self.context.logger.info(f"Database file '{self.db_file}' closed") | ||
| except Exception: | ||
| self.context.logger.exception("Error closing database connection") | ||
| finally: | ||
| self.conn = None | ||
| def __init__(self, context: BrokerContext) -> None: | ||
| super().__init__(context) | ||
| warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1) |
@@ -17,3 +17,3 @@ import asyncio | ||
| @runtime_checkable | ||
| class Buffer(Protocol): # type: ignore[no-redef] | ||
| class Buffer(Protocol): # type: ignore[no-redef] | ||
| def __buffer__(self, flags: int = ...) -> memoryview: | ||
@@ -79,3 +79,2 @@ """Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12.""" | ||
| def _clear_stats(self) -> None: | ||
@@ -117,3 +116,3 @@ """Initialize broker statistics data structures.""" | ||
| version = f"aMQTT version {amqtt.__version__}" | ||
| self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode()) | ||
| await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode()) | ||
@@ -120,0 +119,0 @@ # Start $SYS topics management |
| from dataclasses import dataclass, field | ||
| from typing import Any | ||
| import warnings | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.errors import PluginInitError | ||
| from amqtt.plugins.base import BaseTopicPlugin | ||
@@ -16,3 +18,3 @@ from amqtt.session import Session | ||
| self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None | ||
| ) -> bool: | ||
| ) -> bool | None: | ||
| filter_result = await super().topic_filtering(session=session, topic=topic, action=action) | ||
@@ -23,3 +25,3 @@ if filter_result: | ||
| return not (topic and topic in self._taboo) | ||
| return filter_result | ||
| return bool(filter_result) | ||
@@ -29,2 +31,12 @@ | ||
| def __init__(self, context: BaseContext) -> None: | ||
| super().__init__(context) | ||
| if self._get_config_option("acl", None): | ||
| warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1) | ||
| if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None): | ||
| msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included" | ||
| raise PluginInitError(msg) | ||
| @staticmethod | ||
@@ -52,3 +64,3 @@ def topic_ac(topic_requested: str, topic_allowed: str) -> bool: | ||
| self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None | ||
| ) -> bool: | ||
| ) -> bool | None: | ||
| filter_result = await super().topic_filtering(session=session, topic=topic, action=action) | ||
@@ -65,3 +77,3 @@ if not filter_result: | ||
| if not req_topic: | ||
| return False\ | ||
| return False | ||
@@ -72,9 +84,17 @@ username = session.username if session else None | ||
| acl: dict[str, Any] = {} | ||
| acl: dict[str, Any] | None = None | ||
| match action: | ||
| case Action.PUBLISH: | ||
| acl = self._get_config_option("publish-acl", {}) | ||
| acl = self._get_config_option("publish-acl", None) | ||
| case Action.SUBSCRIBE: | ||
| acl = self._get_config_option("acl", {}) | ||
| acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None)) | ||
| case Action.RECEIVE: | ||
| acl = self._get_config_option("receive-acl", None) | ||
| case _: | ||
| msg = "Received an invalid action type." | ||
| raise ValueError(msg) | ||
| if acl is None: | ||
| return True | ||
| allowed_topics = acl.get(username, []) | ||
@@ -81,0 +101,0 @@ if not allowed_topics: |
@@ -24,3 +24,3 @@ import asyncio | ||
| def _version(v:bool) -> None: | ||
| def _version(v: bool) -> None: | ||
| if v: | ||
@@ -45,5 +45,8 @@ typer.echo(f"{amqtt_version}") | ||
| formatter = "[%(asctime)s] :: %(levelname)s - %(message)s" | ||
| if debug: | ||
| formatter = "[%(asctime)s] %(name)s:%(lineno)d :: %(levelname)s - %(message)s" | ||
| level = logging.DEBUG if debug else logging.INFO | ||
| logging.basicConfig(level=level, format=formatter) | ||
| logging.getLogger("transitions").setLevel(logging.WARNING) | ||
| try: | ||
@@ -67,3 +70,3 @@ if config_file: | ||
| _ = loop.create_task(broker.start()) #noqa : RUF006 | ||
| _ = loop.create_task(broker.start()) # noqa : RUF006 | ||
| try: | ||
@@ -70,0 +73,0 @@ loop.run_forever() |
@@ -10,5 +10,5 @@ --- | ||
| reconnect_retries: 2 | ||
| broker: | ||
| connection: | ||
| uri: "mqtt://127.0.0.1" | ||
| plugins: | ||
| amqtt.plugins.logging_amqtt.PacketLoggerPlugin: |
@@ -55,3 +55,3 @@ import asyncio | ||
| yield line.encode(encoding="utf-8") | ||
| except Exception: | ||
| except (FileNotFoundError, OSError): | ||
| logger.exception(f"Failed to read file '{self.file}'") | ||
@@ -122,2 +122,3 @@ if self.lines: | ||
| app = typer.Typer(add_completion=False, rich_markup_mode=None) | ||
@@ -136,4 +137,5 @@ | ||
| @app.command() | ||
| def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 | ||
| def publisher_main( # pylint: disable=R0914,R0917 | ||
| url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"), | ||
@@ -161,3 +163,3 @@ config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"), | ||
| debug: bool = typer.Option(False, "-d", help="Enable debug messages"), | ||
| version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001 | ||
| version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001 | ||
| ) -> None: | ||
@@ -164,0 +166,0 @@ """Command-line MQTT client for publishing simple messages.""" |
@@ -103,3 +103,3 @@ import asyncio | ||
| def _version(v:bool) -> None: | ||
| def _version(v: bool) -> None: | ||
| if v: | ||
@@ -111,6 +111,6 @@ typer.echo(f"{amqtt_version}") | ||
| @app.command() | ||
| def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913 | ||
| def subscribe_main( # pylint: disable=R0914,R0917 | ||
| url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False), | ||
| config_file: str | None = typer.Option(None, "-c", help="Client configuration file"), | ||
| client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"), | ||
| client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"), | ||
| qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"), | ||
@@ -117,0 +117,0 @@ topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008 |
+28
-1
| from asyncio import Queue | ||
| from collections import OrderedDict | ||
| from typing import Any, ClassVar | ||
| import logging | ||
| from math import floor | ||
| import time | ||
| from typing import TYPE_CHECKING, Any, ClassVar | ||
@@ -13,3 +16,8 @@ from transitions import Machine | ||
| if TYPE_CHECKING: | ||
| import ssl | ||
| logger = logging.getLogger(__name__) | ||
| class ApplicationMessage: | ||
@@ -142,2 +150,5 @@ """ApplicationMessage and subclasses are used to store published message information flow. | ||
| self.parent: int = 0 | ||
| self.last_connect_time: int | None = None | ||
| self.ssl_object: ssl.SSLObject | None = None | ||
| self.last_disconnect_time: int | None = None | ||
@@ -166,2 +177,3 @@ # Used to store outgoing ApplicationMessage while publish protocol flows | ||
| ) | ||
| self.transitions.on_enter_connected(self._on_enter_connected) | ||
| self.transitions.add_transition( | ||
@@ -177,2 +189,3 @@ trigger="connect", | ||
| ) | ||
| self.transitions.on_enter_disconnected(self._on_enter_disconnected) | ||
| self.transitions.add_transition( | ||
@@ -189,2 +202,16 @@ trigger="disconnect", | ||
| def _on_enter_connected(self) -> None: | ||
| cur_time = floor(time.time()) | ||
| if self.last_disconnect_time is not None: | ||
| logger.debug(f"Session reconnected after {cur_time - self.last_disconnect_time} seconds.") | ||
| self.last_connect_time = cur_time | ||
| self.last_disconnect_time = None | ||
| def _on_enter_disconnected(self) -> None: | ||
| cur_time = floor(time.time()) | ||
| if self.last_connect_time is not None: | ||
| logger.debug(f"Session disconnected after {cur_time - self.last_connect_time} seconds.") | ||
| self.last_disconnect_time = cur_time | ||
| @property | ||
@@ -191,0 +218,0 @@ def next_packet_id(self) -> int: |
+26
-3
| Metadata-Version: 2.4 | ||
| Name: amqtt | ||
| Version: 0.11.2 | ||
| Version: 0.11.3 | ||
| Summary: Python's asyncio-native MQTT broker and client. | ||
@@ -29,2 +29,14 @@ Author: aMQTT Contributors | ||
| Requires-Dist: coveralls==4.0.1; extra == 'ci' | ||
| Provides-Extra: contrib | ||
| Requires-Dist: aiohttp>=3.12.13; extra == 'contrib' | ||
| Requires-Dist: aiosqlite>=0.21.0; extra == 'contrib' | ||
| Requires-Dist: argon2-cffi>=25.1.0; extra == 'contrib' | ||
| Requires-Dist: cryptography>=45.0.3; extra == 'contrib' | ||
| Requires-Dist: greenlet>=3.2.3; extra == 'contrib' | ||
| Requires-Dist: jsonschema>=4.25.0; extra == 'contrib' | ||
| Requires-Dist: mergedeep>=1.3.4; extra == 'contrib' | ||
| Requires-Dist: pyjwt>=2.10.1; extra == 'contrib' | ||
| Requires-Dist: pyopenssl>=25.1.0; extra == 'contrib' | ||
| Requires-Dist: python-ldap>=3.4.4; extra == 'contrib' | ||
| Requires-Dist: sqlalchemy[asyncio]>=2.0.41; extra == 'contrib' | ||
| Description-Content-Type: text/markdown | ||
@@ -48,7 +60,18 @@ | ||
| - Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications | ||
| - Communication over TCP and/or websocket, including support for SSL/TLS | ||
| - Communication over multiple TCP and/or websocket ports, including support for SSL/TLS | ||
| - Support QoS 0, QoS 1 and QoS 2 messages flow | ||
| - Client auto-reconnection on network lost | ||
| - Functionality expansion; plugins included: authentication and `$SYS` topic publishing | ||
| - Plugin framework for functionality expansion; included plugins: | ||
| - `$SYS` topic publishing | ||
| - AWS IOT-style shadow states | ||
| - x509 certificate authentication (including cli cert creation) | ||
| - Secure file-based password authentication | ||
| - Configuration-based topic authorization | ||
| - MySQL, Postgres & SQLite user and/or topic auth (including cli manager) | ||
| - External server (HTTP) user and/or topic auth | ||
| - LDAP user and/or topic auth | ||
| - JWT user and/or topic auth | ||
| - Fail over session persistence | ||
| ## Installation | ||
@@ -55,0 +78,0 @@ |
+58
-23
@@ -23,3 +23,3 @@ [build-system] | ||
| version = "0.11.2" | ||
| version = "0.11.3" | ||
| requires-python = ">=3.10.0" | ||
@@ -38,3 +38,3 @@ readme = "README.md" | ||
| "dacite>=1.9.2", | ||
| "psutil>=7.0.0", | ||
| "psutil>=7.0.0" | ||
| ] | ||
@@ -44,2 +44,4 @@ | ||
| dev = [ | ||
| "aiosqlite>=0.21.0", | ||
| "greenlet>=3.2.3", | ||
| "hatch>=1.14.1", | ||
@@ -52,5 +54,8 @@ "hypothesis>=6.130.8", | ||
| "psutil>=7.0.0", # https://pypi.org/project/psutil | ||
| "pyhamcrest>=2.1.0", | ||
| "pylint>=3.3.6", # https://pypi.org/project/pylint | ||
| "pyopenssl>=25.1.0", | ||
| "pytest-asyncio>=0.26.0", # https://pypi.org/project/pytest-asyncio | ||
| "pytest-cov>=6.1.0", # https://pypi.org/project/pytest-cov | ||
| "pytest-docker>=3.2.3", | ||
| "pytest-logdog>=0.1.0", # https://pypi.org/project/pytest-logdog | ||
@@ -61,2 +66,3 @@ "pytest-timeout>=2.3.1", # https://pypi.org/project/pytest-timeout | ||
| "setuptools>=78.1.0", | ||
| "sqlalchemy[mypy]>=2.0.41", | ||
| "types-mock>=5.2.0.20250306", # https://pypi.org/project/types-mock | ||
@@ -68,2 +74,3 @@ "types-PyYAML>=6.0.12.20250402", # https://pypi.org/project/types-PyYAML | ||
| docs = [ | ||
| "griffe>=1.11.1", | ||
| "markdown-callouts>=0.4", | ||
@@ -89,5 +96,17 @@ "markdown-exec>=1.8", | ||
| ci = ["coveralls==4.0.1"] | ||
| contrib = [ | ||
| "cryptography>=45.0.3", | ||
| "aiosqlite>=0.21.0", | ||
| "greenlet>=3.2.3", | ||
| "sqlalchemy[asyncio]>=2.0.41", | ||
| "argon2-cffi>=25.1.0", | ||
| "aiohttp>=3.12.13", | ||
| "pyjwt>=2.10.1", | ||
| "python-ldap>=3.4.4", | ||
| "mergedeep>=1.3.4", | ||
| "jsonschema>=4.25.0", | ||
| "pyopenssl>=25.1.0" | ||
| ] | ||
| [project.scripts] | ||
@@ -97,3 +116,9 @@ amqtt = "amqtt.scripts.broker_script:main" | ||
| amqtt_sub = "amqtt.scripts.sub_script:main" | ||
| ca_creds = "amqtt.scripts.ca_creds:main" | ||
| server_creds = "amqtt.scripts.server_creds:main" | ||
| device_creds = "amqtt.scripts.device_creds:main" | ||
| user_mgr = "amqtt.scripts.manage_users:main" | ||
| topic_mgr = "amqtt.scripts.manage_topics:main" | ||
| [tool.hatch.build] | ||
@@ -152,26 +177,36 @@ exclude = [ | ||
| [tool.ruff.lint] | ||
| preview = true | ||
| select = ["ALL"] | ||
| extend-select = [ | ||
| "UP", # pyupgrade | ||
| "D", # pydocstyle | ||
| "UP", # pyupgrade | ||
| "D", # pydocstyle, | ||
| ] | ||
| ignore = [ | ||
| "FBT001", # Checks for the use of boolean positional arguments in function definitions. | ||
| "FBT002", # Checks for the use of boolean positional arguments in function definitions. | ||
| "G004", # Logging statement uses f-string | ||
| "D100", # Missing docstring in public module | ||
| "D101", # Missing docstring in public class | ||
| "D102", # Missing docstring in public method | ||
| "D107", # Missing docstring in `__init__` | ||
| "D203", # Incorrect blank line before class (mutually exclusive D211) | ||
| "D213", # Multi-line summary second line (mutually exclusive D212) | ||
| "FIX002", # Checks for "TODO" comments. | ||
| "TD002", # TODO Missing author. | ||
| "TD003", # TODO Missing issue link for this TODO. | ||
| "ANN401", # Dynamically typed expressions (typing.Any) are disallowed | ||
| "ARG002", # Unused method argument | ||
| "PERF203",# try-except penalty within loops (3.10 only), | ||
| "COM812" # rule causes conflicts when used with the formatter | ||
| "FBT001", # Checks for the use of boolean positional arguments in function definitions. | ||
| "FBT002", # Checks for the use of boolean positional arguments in function definitions. | ||
| "G004", # Logging statement uses f-string | ||
| "D100", # Missing docstring in public module | ||
| "D101", # Missing docstring in public class | ||
| "D102", # Missing docstring in public method | ||
| "D107", # Missing docstring in `__init__` | ||
| "D203", # Incorrect blank line before class (mutually exclusive D211) | ||
| "D213", # Multi-line summary second line (mutually exclusive D212) | ||
| "FIX002", # Checks for "TODO" comments. | ||
| "TD002", # TODO Missing author. | ||
| "TD003", # TODO Missing issue link for this TODO. | ||
| "ANN401", # Dynamically typed expressions (typing.Any) are disallowed | ||
| "ARG002", # Unused method argument | ||
| "PERF203",# try-except penalty within loops (3.10 only), | ||
| "COM812", # rule causes conflicts when used with the formatter, | ||
| # ignore certain preview rules | ||
| "DOC", | ||
| "PLW", | ||
| "PLR", | ||
| "CPY", | ||
| "PLC", | ||
| "RUF052", | ||
| "B903" | ||
| ] | ||
@@ -209,3 +244,3 @@ | ||
| asyncio_mode = "auto" | ||
| timeout = 10 | ||
| timeout = 15 | ||
| asyncio_default_fixture_loop_scope = "function" | ||
@@ -244,2 +279,3 @@ #addopts = ["--tb=short", "--capture=tee-sys"] | ||
| max-line-length = 130 | ||
| ignore-paths = ["amqtt/plugins/persistence.py"] | ||
@@ -255,3 +291,2 @@ [tool.pylint.BASIC] | ||
| disable = [ | ||
| "broad-exception-caught", # TODO: improve later | ||
| "duplicate-code", | ||
@@ -258,0 +293,0 @@ "fixme", |
+13
-2
@@ -17,7 +17,18 @@ [](https://amqtt.readthedocs.io/en/latest/) | ||
| - Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications | ||
| - Communication over TCP and/or websocket, including support for SSL/TLS | ||
| - Communication over multiple TCP and/or websocket ports, including support for SSL/TLS | ||
| - Support QoS 0, QoS 1 and QoS 2 messages flow | ||
| - Client auto-reconnection on network lost | ||
| - Functionality expansion; plugins included: authentication and `$SYS` topic publishing | ||
| - Plugin framework for functionality expansion; included plugins: | ||
| - `$SYS` topic publishing | ||
| - AWS IOT-style shadow states | ||
| - x509 certificate authentication (including cli cert creation) | ||
| - Secure file-based password authentication | ||
| - Configuration-based topic authorization | ||
| - MySQL, Postgres & SQLite user and/or topic auth (including cli manager) | ||
| - External server (HTTP) user and/or topic auth | ||
| - LDAP user and/or topic auth | ||
| - JWT user and/or topic auth | ||
| - Fail over session persistence | ||
| ## Installation | ||
@@ -24,0 +35,0 @@ |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
425378
42%74
42.31%8260
42.98%