You're Invited:Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26.RSVP
Socket
Book a DemoSign in
Socket

amqtt

Package Overview
Dependencies
Maintainers
2
Versions
13
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

amqtt - pypi Package Compare versions

Comparing version
0.11.2
to
0.11.3
+47
amqtt/contrib/__init__.py
"""Module for contributed plugins."""
from dataclasses import asdict, is_dataclass
from typing import Any, TypeVar
from sqlalchemy import JSON, TypeDecorator
T = TypeVar("T")
class DataClassListJSON(TypeDecorator[list[dict[str, Any]]]):
impl = JSON
cache_ok = True
def __init__(self, dataclass_type: type[T]) -> None:
if not is_dataclass(dataclass_type):
msg = f"{dataclass_type} must be a dataclass type"
raise TypeError(msg)
self.dataclass_type = dataclass_type
super().__init__()
def process_bind_param(
self,
value: list[Any] | None, # Python -> DB
dialect: Any
) -> list[dict[str, Any]] | None:
if value is None:
return None
return [asdict(item) for item in value]
def process_result_value(
self,
value: list[dict[str, Any]] | None, # DB -> Python
dialect: Any
) -> list[Any] | None:
if value is None:
return None
return [self.dataclass_type(**item) for item in value]
def process_literal_param(self, value: Any, dialect: Any) -> Any:
# Required by SQLAlchemy, typically used for literal SQL rendering.
return value
@property
def python_type(self) -> type:
# Required by TypeEngine to indicate the expected Python type.
return list
"""Plugin to determine authentication of clients with DB storage."""
from dataclasses import dataclass
import click
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from .plugin import TopicAuthDBPlugin, UserAuthDBPlugin
class DBType(StrEnum):
"""Enumeration for supported relational databases."""
MARIA = "mariadb"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
SQLITE = "sqlite"
@dataclass
class DBInfo:
"""SQLAlchemy database information."""
connect_str: str
connect_port: int | None
_db_map = {
DBType.MARIA: DBInfo("mysql+aiomysql", 3306),
DBType.MYSQL: DBInfo("mysql+aiomysql", 3306),
DBType.POSTGRESQL: DBInfo("postgresql+asyncpg", 5432),
DBType.SQLITE: DBInfo("sqlite+aiosqlite", None)
}
def db_connection_str(db_type: DBType, db_username: str, db_host: str, db_port: int | None, db_filename: str) -> str:
"""Create sqlalchemy database connection string."""
db_info = _db_map[db_type]
if db_type == DBType.SQLITE:
return f"{db_info.connect_str}:///{db_filename}"
db_password = click.prompt("Enter the db password (press enter for none)", hide_input=True)
pwd = f":{db_password}" if db_password else ""
return f"{db_info.connect_str}://{db_username}:{pwd}@{db_host}:{db_port or db_info.connect_port}"
__all__ = ["DBType", "TopicAuthDBPlugin", "UserAuthDBPlugin", "db_connection_str"]
from collections.abc import Iterator
import logging
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from amqtt.contexts import Action
from amqtt.contrib.auth_db.models import AllowedTopic, Base, TopicAuth, UserAuth
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
class UserManager:
def __init__(self, connection: str) -> None:
self._engine = create_async_engine(connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def db_sync(self) -> None:
"""Sync the database schema."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@staticmethod
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> UserAuth:
stmt = select(UserAuth).filter(UserAuth.username == username)
user_auth = await db_session.scalar(stmt)
if not user_auth:
msg = f"Username '{username}' doesn't exist."
logger.debug(msg)
raise MQTTError(msg)
return user_auth
async def get_user_auth(self, username: str) -> UserAuth | None:
"""Retrieve a user by username."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
return await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
async def list_user_auths(self) -> Iterator[UserAuth]:
"""Return list of all clients."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(UserAuth).order_by(UserAuth.username)
users = await db_session.scalars(stmt)
if not users:
msg = "No users exist."
logger.info(msg)
raise MQTTError(msg)
return users
async def create_user_auth(self, username: str, plain_password: str) -> UserAuth | None:
"""Create a new user."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(UserAuth).filter(UserAuth.username == username)
user_auth = await db_session.scalar(stmt)
if user_auth:
msg = f"Username '{username}' already exists."
logger.info(msg)
raise MQTTError(msg)
user_auth = UserAuth(username=username)
user_auth.password = plain_password
db_session.add(user_auth)
await db_session.commit()
await db_session.flush()
return user_auth
async def delete_user_auth(self, username: str) -> UserAuth | None:
"""Delete a user."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
user_auth = await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
await db_session.delete(user_auth)
await db_session.commit()
await db_session.flush()
return user_auth
async def update_user_auth_password(self, username: str, plain_password: str) -> UserAuth | None:
"""Change a user's password."""
async with self._db_session_maker() as db_session, db_session.begin():
user_auth = await self._get_auth_or_raise(db_session, username)
user_auth.password = plain_password
await db_session.commit()
await db_session.flush()
return user_auth
class TopicManager:
def __init__(self, connection: str) -> None:
self._engine = create_async_engine(connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def db_sync(self) -> None:
"""Sync the database schema."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@staticmethod
async def _get_auth_or_raise(db_session: AsyncSession, username: str) -> TopicAuth:
stmt = select(TopicAuth).filter(TopicAuth.username == username)
topic_auth = await db_session.scalar(stmt)
if not topic_auth:
msg = f"Username '{username}' doesn't exist."
logger.debug(msg)
raise MQTTError(msg)
return topic_auth
@staticmethod
def _field_name(action: Action) -> str:
return f"{action}_acl"
async def create_topic_auth(self, username: str) -> TopicAuth | None:
"""Create a new user."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(TopicAuth).filter(TopicAuth.username == username)
topic_auth = await db_session.scalar(stmt)
if topic_auth:
msg = f"Username '{username}' already exists."
raise MQTTError(msg)
topic_auth = TopicAuth(username=username)
db_session.add(topic_auth)
await db_session.commit()
await db_session.flush()
return topic_auth
async def get_topic_auth(self, username: str) -> TopicAuth | None:
"""Retrieve a allowed topics by username."""
async with self._db_session_maker() as db_session, db_session.begin():
try:
return await self._get_auth_or_raise(db_session, username)
except MQTTError:
return None
async def list_topic_auths(self) -> Iterator[TopicAuth]:
"""Return list of all authorized clients."""
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(TopicAuth).order_by(TopicAuth.username)
topics = await db_session.scalars(stmt)
if not topics:
msg = "No topics exist."
logger.info(msg)
raise MQTTError(msg)
return topics
async def add_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
"""Add allowed topic from action for user."""
if action == Action.PUBLISH and topic.startswith("$"):
msg = "MQTT does not allow clients to publish to $ topics."
raise MQTTError(msg)
async with self._db_session_maker() as db_session, db_session.begin():
user_auth = await self._get_auth_or_raise(db_session, username)
topic_list = getattr(user_auth, self._field_name(action))
updated_list = [*topic_list, AllowedTopic(topic)]
setattr(user_auth, self._field_name(action), updated_list)
await db_session.commit()
await db_session.flush()
return updated_list
async def remove_allowed_topic(self, username: str, topic: str, action: Action) -> list[AllowedTopic] | None:
"""Remove topic from action for user."""
async with self._db_session_maker() as db_session, db_session.begin():
topic_auth = await self._get_auth_or_raise(db_session, username)
topic_list = topic_auth.get_topic_list(action)
if AllowedTopic(topic) not in topic_list:
msg = f"Client '{username}' doesn't have topic '{topic}' for action '{action}'."
logger.debug(msg)
raise MQTTError(msg)
updated_list = [allowed_topic for allowed_topic in topic_list if allowed_topic != AllowedTopic(topic)]
setattr(topic_auth, f"{action}_acl", updated_list)
await db_session.commit()
await db_session.flush()
return updated_list
from dataclasses import dataclass
import logging
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from sqlalchemy import String
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from amqtt.contexts import Action
from amqtt.contrib import DataClassListJSON
from amqtt.plugins import TopicMatcher
if TYPE_CHECKING:
from passlib.context import CryptContext
logger = logging.getLogger(__name__)
matcher = TopicMatcher()
@dataclass
class AllowedTopic:
topic: str
def __contains__(self, item: Union[str, "AllowedTopic"]) -> bool:
"""Determine `in`."""
return self.__eq__(item)
def __eq__(self, item: object) -> bool:
"""Determine `==` or `!=`."""
if isinstance(item, str):
return matcher.is_topic_allowed(item, self.topic)
if isinstance(item, AllowedTopic):
return item.topic == self.topic
msg = "AllowedTopic can only be compared to another AllowedTopic or string."
raise AttributeError(msg)
def __str__(self) -> str:
"""Display topic."""
return self.topic
def __repr__(self) -> str:
"""Display topic."""
return self.topic
class PasswordHasher:
"""singleton to initialize the CryptContext and then use it elsewhere in the code."""
_instance: Optional["PasswordHasher"] = None
def __init__(self) -> None:
if not hasattr(self, "_crypt_context"):
self._crypt_context: CryptContext | None = None
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "PasswordHasher":
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
@property
def crypt_context(self) -> "CryptContext":
if not self._crypt_context:
msg = "CryptContext is empty"
raise ValueError(msg)
return self._crypt_context
@crypt_context.setter
def crypt_context(self, value: "CryptContext") -> None:
self._crypt_context = value
class Base(DeclarativeBase):
pass
class UserAuth(Base):
__tablename__ = "user_auth"
id: Mapped[int] = mapped_column(primary_key=True)
username: Mapped[str] = mapped_column(String, unique=True)
_password_hash: Mapped[str] = mapped_column("password_hash", String(128))
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
@hybrid_property
def password(self) -> None:
msg = "Password is write-only"
raise AttributeError(msg)
@password.inplace.setter # type: ignore[arg-type]
def _password_setter(self, plain_password: str) -> None:
self._password_hash = PasswordHasher().crypt_context.hash(plain_password)
def verify_password(self, plain_password: str) -> bool:
return bool(PasswordHasher().crypt_context.verify(plain_password, self._password_hash))
def __str__(self) -> str:
"""Display client id and password hash."""
return f"'{self.username}' with password hash: {self._password_hash}"
class TopicAuth(Base):
__tablename__ = "topic_auth"
id: Mapped[int] = mapped_column(primary_key=True)
username: Mapped[str] = mapped_column(String, unique=True)
publish_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
subscribe_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
receive_acl: Mapped[list[AllowedTopic]] = mapped_column(DataClassListJSON(AllowedTopic), default=list)
def get_topic_list(self, action: Action) -> list[AllowedTopic]:
return cast("list[AllowedTopic]", getattr(self, f"{action}_acl"))
def __str__(self) -> str:
"""Display client id and password hash."""
return f"""'{self.username}':
\tpublish: {self.publish_acl}, subscribe: {self.subscribe_acl}, receive: {self.receive_acl}
"""
from dataclasses import dataclass, field
import logging
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import create_async_engine
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
from amqtt.contrib.auth_db.models import Base, PasswordHasher
from amqtt.errors import MQTTError
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
def default_hash_scheme() -> list[str]:
"""Create config dataclass defaults."""
return ["argon2", "bcrypt", "pbkdf2_sha256", "scrypt"]
class UserAuthDBPlugin(BaseAuthPlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# access the singleton and set the proper crypt context
pwd_hasher = PasswordHasher()
pwd_hasher.crypt_context = CryptContext(schemes=self.config.hash_schemes, deprecated="auto")
self._user_manager = UserManager(self.config.connection)
self._engine = create_async_engine(f"{self.config.connection}")
async def on_broker_pre_start(self) -> None:
"""Sync the schema (if configured)."""
if not self.config.sync_schema:
return
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def authenticate(self, *, session: Session) -> bool | None:
"""Authenticate a client's session."""
if not session.username or not session.password:
return False
user_auth = await self._user_manager.get_user_auth(session.username)
if not user_auth:
return False
return bool(session.password) and user_auth.verify_password(session.password)
@dataclass
class Config:
"""Configuration for DB authentication."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
sync_schema: bool = False
"""Use SQLAlchemy to create / update the database schema."""
hash_schemes: list[str] = field(default_factory=default_hash_scheme)
"""list of hash schemes to use for passwords"""
class TopicAuthDBPlugin(BaseTopicPlugin):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self._topic_manager = TopicManager(self.config.connection)
self._engine = create_async_engine(f"{self.config.connection}")
async def on_broker_pre_start(self) -> None:
"""Sync the schema (if configured)."""
if not self.config.sync_schema:
return
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
if not session or not session.username or not topic:
return None
try:
topic_auth = await self._topic_manager.get_topic_auth(session.username)
topic_list = getattr(topic_auth, f"{action}_acl")
except MQTTError:
return False
return topic in topic_list
@dataclass
class Config:
"""Configuration for DB topic filtering."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
sync_schema: bool = False
"""Use SQLAlchemy to create / update the database schema."""
import asyncio
import contextlib
import logging
from pathlib import Path
from typing import Annotated
import typer
from amqtt.contexts import Action
from amqtt.contrib.auth_db import DBType, db_connection_str
from amqtt.contrib.auth_db.managers import TopicManager, UserManager
from amqtt.errors import MQTTError
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
topic_app = typer.Typer(no_args_is_help=True)
@topic_app.callback()
def main(
ctx: typer.Context,
db_type: Annotated[DBType, typer.Option("--db", "-d", help="db type", count=False)],
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
"""Command line interface to add / remove topic authorization.
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
password (if applicable).
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.TopicManager` which provides
the underlying functionality to this command line interface.
"""
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
pass
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
logger.error(f"SQLite option could not find '{db_filename}'")
raise typer.Exit(code=1)
elif db_type != DBType.SQLITE and not db_username:
logger.error("DB access requires a username be provided.")
raise typer.Exit(code=1)
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@topic_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
"""Create the table and schema for username and topic lists for subscribe, publish or receive.
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
"""
async def run_sync() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
try:
await mgr.db_sync()
except MQTTError as me:
logger.critical("Could not sync schema on db.")
raise typer.Exit(code=1) from me
asyncio.run(run_sync())
logger.info("Success: database synced.")
@topic_app.command(name="list")
def list_clients(ctx: typer.Context) -> None:
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
async def run_list() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = TopicManager(connect)
user_count = 0
for user in await mgr.list_topic_auths():
user_count += 1
logger.info(user)
if not user_count:
logger.info("No client authorizations exist.")
asyncio.run(run_list())
@topic_app.command(name="add")
def add_topic_allowance(
ctx: typer.Context,
topic: Annotated[str, typer.Argument(help="list of topics", show_default=False)],
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client", show_default=False)],
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow", show_default=False)]
) -> None:
"""Create a new user with a client id and password (prompted)."""
async def run_add() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = TopicManager(connect)
with contextlib.suppress(MQTTError):
await mgr.create_topic_auth(client_id)
topic_auth = await mgr.get_topic_auth(client_id)
if not topic_auth:
logger.info(f"Topic auth doesn't exist for '{client_id}'")
raise typer.Exit(code=1)
if topic in [allowed_topic.topic for allowed_topic in topic_auth.get_topic_list(action)]:
logger.info(f"Topic '{topic}' already exists for '{action}'.")
raise typer.Exit(1)
await mgr.add_allowed_topic(client_id, topic, action)
logger.info(f"Success: topic '{topic}' added to {action} for '{client_id}'")
asyncio.run(run_add())
@topic_app.command(name="rm")
def remove_topic_allowance(ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")],
action: Annotated[Action, typer.Option("--action", "-a", help="action for topic to allow")],
topic: Annotated[str, typer.Argument(help="list of topics")]
) -> None:
"""Remove a client from the authentication database."""
async def run_remove() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = TopicManager(connect)
topic_auth = await mgr.get_topic_auth(client_id)
if not topic_auth:
logger.info(f"client '{client_id}' doesn't exist.")
raise typer.Exit(1)
if topic not in getattr(topic_auth, f"{action}_acl"):
logger.info(f"Error: topic '{topic}' not in the {action} allow list for {client_id}.")
raise typer.Exit(1)
try:
await mgr.remove_allowed_topic(client_id, topic, action)
except MQTTError as me:
logger.info(f"'Error: could not remove '{topic}' for client '{client_id}'.")
raise typer.Exit(1) from me
logger.info(f"Success: removed topic '{topic}' from {action} for '{client_id}'")
asyncio.run(run_remove())
if __name__ == "__main__":
topic_app()
import asyncio
import logging
from pathlib import Path
from typing import Annotated
import click
import passlib
import typer
from amqtt.contrib.auth_db import DBType, db_connection_str
from amqtt.contrib.auth_db.managers import UserManager
from amqtt.errors import MQTTError
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
user_app = typer.Typer(no_args_is_help=True)
@user_app.callback()
def main(
ctx: typer.Context,
db_type: Annotated[DBType, typer.Option(..., "--db", "-d", help="db type", show_default=False)],
db_username: Annotated[str, typer.Option("--username", "-u", help="db username", show_default=False)] = "",
db_port: Annotated[int, typer.Option("--port", "-p", help="database port (defaults to db type)", show_default=False)] = 0,
db_host: Annotated[str, typer.Option("--host", "-h", help="database host")] = "localhost",
db_filename: Annotated[str, typer.Option("--file", "-f", help="database file name (sqlite only)")] = "auth.db",
) -> None:
"""Command line interface to list, create, remove and add clients.
Passwords are not allowed to be passed via the command line for security reasons. You will be prompted for database
password (if applicable) and the client id's password.
If you need to create users programmatically, see `amqtt.contrib.auth_db.managers.UserManager` which provides
the underlying functionality to this command line interface.
"""
if db_type == DBType.SQLITE and ctx.invoked_subcommand == "sync" and not Path(db_filename).exists():
pass
elif db_type == DBType.SQLITE and not Path(db_filename).exists():
logger.error(f"SQLite option could not find '{db_filename}'")
raise typer.Exit(code=1)
elif db_type != DBType.SQLITE and not db_username:
logger.error("DB access requires a username be provided.")
raise typer.Exit(code=1)
ctx.obj = {"type": db_type, "username": db_username, "host": db_host, "port": db_port, "filename": db_filename}
@user_app.command(name="sync")
def db_sync(ctx: typer.Context) -> None:
"""Create the table and schema for username and hashed password.
Non-destructive if run multiple times. To clear the whole table, need to drop it manually.
"""
async def run_sync() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
try:
await mgr.db_sync()
except MQTTError as me:
logger.critical("Could not sync schema on db.")
raise typer.Exit(code=1) from me
asyncio.run(run_sync())
logger.info("Success: database synced.")
@user_app.command(name="list")
def list_user_auths(ctx: typer.Context) -> None:
"""List all Client IDs (in alphabetical order). Will also display the hashed passwords."""
async def run_list() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"], ctx.obj["filename"])
mgr = UserManager(connect)
user_count = 0
for user in await mgr.list_user_auths():
user_count += 1
logger.info(user)
if not user_count:
logger.info("No client authentications exist.")
asyncio.run(run_list())
@user_app.command(name="add")
def create_user_auth(
ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
"""Create a new user with a client id and password (prompted)."""
async def run_create() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
client_password = click.prompt("Enter the client's password", hide_input=True)
if not client_password.strip():
logger.info("Error: client password cannot be empty.")
raise typer.Exit(1)
try:
user = await mgr.create_user_auth(client_id, client_password.strip())
except passlib.exc.MissingBackendError as mbe:
logger.info(f"Please install backend: {mbe}")
raise typer.Exit(code=1) from mbe
if not user:
logger.info(f"Error: could not create user: {client_id}")
raise typer.Exit(code=1)
logger.info(f"Success: created {user}")
asyncio.run(run_create())
@user_app.command(name="rm")
def remove_user_auth(ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the client to remove")]) -> None:
"""Remove a client from the authentication database."""
async def run_remove() -> None:
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
user = await mgr.get_user_auth(client_id)
if not user:
logger.info(f"Error: client '{client_id}' does not exist.")
raise typer.Exit(1)
if not click.confirm(f"Please confirm the removal of '{client_id}'?"):
raise typer.Exit(0)
user = await mgr.delete_user_auth(client_id)
if not user:
logger.info(f"Error: client '{client_id}' does not exist.")
raise typer.Exit(1)
logger.info(f"Success: '{user.username}' was removed.")
asyncio.run(run_remove())
@user_app.command(name="pwd")
def change_password(
ctx: typer.Context,
client_id: Annotated[str, typer.Option("--client-id", "-c", help="id for the new client")],
) -> None:
"""Update a user's password (prompted)."""
async def run_password() -> None:
client_password = click.prompt("Enter the client's new password", hide_input=True)
if not client_password.strip():
logger.error("Error: client password cannot be empty.")
raise typer.Exit(1)
connect = db_connection_str(ctx.obj["type"], ctx.obj["username"], ctx.obj["host"], ctx.obj["port"],
ctx.obj["filename"])
mgr = UserManager(connect)
await mgr.update_user_auth_password(client_id, client_password.strip())
logger.info(f"Success: client '{client_id}' password updated.")
asyncio.run(run_password())
if __name__ == "__main__":
user_app()
from dataclasses import dataclass
from datetime import datetime, timedelta
try:
from datetime import UTC
except ImportError:
# support for python 3.10
from datetime import timezone
UTC = timezone.utc
from ipaddress import IPv4Address
import logging
from pathlib import Path
import re
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509 import Certificate, CertificateSigningRequest
from cryptography.x509.oid import NameOID
from amqtt.plugins.base import BaseAuthPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class UserAuthCertPlugin(BaseAuthPlugin):
"""Used a *signed* x509 certificate's `Subject AlternativeName` or `SAN` to verify client authentication.
Often used for IoT devices, this method provides the most secure form of identification. A root
certificate, often referenced as a CA certificate -- either issued by a known authority (such as LetsEncrypt)
or a self-signed certificate) is used to sign a private key and certificate for the server. Each device/client
also gets a unique private key and certificate signed by the same CA certificate; also included in the device
certificate is a 'SAN' or SubjectAlternativeName which is the device's unique identifier.
Since both server and device certificates are signed by the same CA certificate, the client can
verify the server's authenticity; and the server can verify the client's authenticity. And since
the device's certificate contains a x509 SAN, the server (with this plugin) can identify the device securely.
!!! note "URI and Client ID configuration"
`uri_domain` configuration must be set to the same uri used to generate the device credentials
when a device is connecting with private key and certificate, the `client_id` must
match the device id used to generate the device credentials.
Available ore three scripts to help with the key generation and certificate signing: `ca_creds`, `server_creds`
and `device_creds`.
!!! note "Configuring broker & client for using Self-signed root CA"
If using self-signed root credentials, the `cafile` configuration for both broker and client need to be
configured with `cafile` set to the `ca.crt`.
"""
async def authenticate(self, *, session: Session) -> bool | None:
"""Verify the client's session using the provided client's x509 certificate."""
if not session.ssl_object:
return False
der_cert = session.ssl_object.getpeercert(binary_form=True)
if der_cert:
cert = x509.load_der_x509_certificate(der_cert, backend=default_backend())
try:
san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
uris = san.value.get_values_for_type(x509.UniformResourceIdentifier)
if self.config.uri_domain not in uris[0]:
return False
pattern = rf"^spiffe://{re.escape(self.config.uri_domain)}/device/([^/]+)$"
match = re.match(pattern, uris[0])
if not match:
return False
return match.group(1) == session.client_id
except x509.ExtensionNotFound:
logger.warning("No SAN extension found.")
return False
@dataclass
class Config:
"""Configuration for the CertificateAuthPlugin."""
uri_domain: str
"""The domain that is expected as part of the device certificate's spiffe (e.g. test.amqtt.io)"""
def generate_root_creds(country: str, state: str, locality: str,
org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Generate CA key and certificate."""
# generate private key for the server
ca_key = rsa.generate_private_key(
public_exponent=65537,
key_size=4096,
)
# Create certificate subject and issuer (self-signed)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
x509.NameAttribute(NameOID.LOCALITY_NAME, locality),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
x509.NameAttribute(NameOID.COMMON_NAME, cn),
])
# 3. Build self-signed certificate
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(ca_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(UTC))
.not_valid_after(datetime.now(UTC) + timedelta(days=3650)) # 10 years
.add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()),
critical=False,
)
.add_extension(
x509.KeyUsage(
key_cert_sign=True,
crl_sign=True,
digital_signature=False,
key_encipherment=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.sign(ca_key, hashes.SHA256())
)
return ca_key, cert
def generate_server_csr(country: str, org_name: str, cn: str) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
"""Generate server private key and server certificate-signing-request."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
csr = (
x509.CertificateSigningRequestBuilder()
.subject_name(x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name),
x509.NameAttribute(NameOID.COMMON_NAME, cn),
]))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName(cn),
x509.IPAddress(IPv4Address("127.0.0.1")),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
return key, csr
def generate_device_csr(country: str, org_name: str, common_name: str,
uri_san: str, dns_san: str
) -> tuple[rsa.RSAPrivateKey, CertificateSigningRequest]:
"""Generate a device key and a csr."""
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
csr = (
x509.CertificateSigningRequestBuilder()
.subject_name(x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, org_name
),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]))
.add_extension(
x509.SubjectAlternativeName([
x509.UniformResourceIdentifier(uri_san),
x509.DNSName(dns_san),
]),
critical=False,
)
.sign(key, hashes.SHA256())
)
return key, csr
def sign_csr(csr: CertificateSigningRequest,
ca_key: rsa.RSAPrivateKey,
ca_cert: Certificate, validity_days: int = 365) -> Certificate:
"""Sign a csr with CA credentials."""
return (
x509.CertificateBuilder()
.subject_name(csr.subject)
.issuer_name(ca_cert.subject)
.public_key(csr.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(UTC))
.not_valid_after(datetime.now(UTC) + timedelta(days=validity_days))
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
)
.add_extension(
csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value,
critical=False,
)
.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()), # type: ignore[arg-type]
critical=False,
)
.sign(ca_key, hashes.SHA256())
)
def load_ca(ca_key_fn: str, ca_crt_fn: str) -> tuple[rsa.RSAPrivateKey, Certificate]:
"""Load server key and certificate."""
with Path(ca_key_fn).open("rb") as f:
ca_key: rsa.RSAPrivateKey = serialization.load_pem_private_key(f.read(), password=None) # type: ignore[assignment]
with Path(ca_crt_fn).open("rb") as f:
ca_cert = x509.load_pem_x509_certificate(f.read())
return ca_key, ca_cert
def write_key_and_crt(key: rsa.RSAPrivateKey, crt: Certificate,
prefix: str, path: Path | None = None) -> None:
"""Create pem-encoded files for key and certificate."""
path = path or Path()
crt_fn = path / f"{prefix}.crt"
key_fn = path / f"{prefix}.key"
with crt_fn.open("wb") as f:
f.write(crt.public_bytes(serialization.Encoding.PEM))
with key_fn.open("wb") as f:
f.write(key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption()
))
from dataclasses import dataclass
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import logging
from typing import Any
from aiohttp import ClientResponse, ClientSession, FormData
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class ResponseMode(StrEnum):
STATUS = "status"
JSON = "json"
TEXT = "text"
class RequestMethod(StrEnum):
GET = "get"
POST = "post"
PUT = "put"
class ParamsMode(StrEnum):
JSON = "json"
FORM = "form"
class ACLError(Exception):
pass
HTTP_2xx_MIN = 200
HTTP_2xx_MAX = 299
HTTP_4xx_MIN = 400
HTTP_4xx_MAX = 499
@dataclass
class HttpConfig:
"""Configuration for the HTTP Auth & ACL Plugin."""
host: str
"""hostname of the server for the auth & acl check"""
port: int
"""port of the server for the auth & acl check"""
request_method: RequestMethod = RequestMethod.GET
"""send the request as a GET, POST or PUT"""
params_mode: ParamsMode = ParamsMode.JSON # see docs/plugins/http.md for additional details
"""send the request with `JSON` or `FORM` data. *additional details below*"""
response_mode: ResponseMode = ResponseMode.JSON # see docs/plugins/http.md for additional details
"""expected response from the auth/acl server. `STATUS` (code), `JSON`, or `TEXT`. *additional details below*"""
with_tls: bool = False
"""http or https"""
user_agent: str = "amqtt"
"""the 'User-Agent' header sent along with the request"""
superuser_uri: str | None = None
"""URI to verify if the user is a superuser (e.g. '/superuser'), `None` if superuser is not supported"""
timeout: int = 5
"""duration, in seconds, to wait for the HTTP server to respond"""
class AuthHttpPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.http = ClientSession(headers={"User-Agent": self.config.user_agent})
match self.config.request_method:
case RequestMethod.GET:
self.method = self.http.get
case RequestMethod.PUT:
self.method = self.http.put
case _:
self.method = self.http.post
async def on_broker_pre_shutdown(self) -> None:
await self.http.close()
@staticmethod
def _is_2xx(r: ClientResponse) -> bool:
return HTTP_2xx_MIN <= r.status <= HTTP_2xx_MAX
@staticmethod
def _is_4xx(r: ClientResponse) -> bool:
return HTTP_4xx_MIN <= r.status <= HTTP_4xx_MAX
def _get_params(self, payload: dict[str, Any]) -> dict[str, Any]:
match self.config.params_mode:
case ParamsMode.FORM:
match self.config.request_method:
case RequestMethod.GET:
kwargs = {"params": payload}
case _: # POST, PUT
d: Any = FormData(payload)
kwargs = {"data": d}
case _: # JSON
kwargs = {"json": payload}
return kwargs
async def _send_request(self, url: str, payload: dict[str, Any]) -> bool | None: # pylint: disable=R0911
kwargs = self._get_params(payload)
async with self.method(url, **kwargs) as r:
logger.debug(f"http request returned {r.status}")
match self.config.response_mode:
case ResponseMode.TEXT:
return self._is_2xx(r) and (await r.text()).lower() == "ok"
case ResponseMode.STATUS:
if self._is_2xx(r):
return True
if self._is_4xx(r):
return False
# any other code
return None
case _:
if not self._is_2xx(r):
return False
data: dict[str, Any] = await r.json()
data = {k.lower(): v for k, v in data.items()}
return data.get("ok", None)
def get_url(self, uri: str) -> str:
return f"{'https' if self.config.with_tls else 'http'}://{self.config.host}:{self.config.port}{uri}"
class UserAuthHttpPlugin(AuthHttpPlugin, BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
d = {"username": session.username, "password": session.password, "client_id": session.client_id}
return await self._send_request(self.get_url(self.config.user_uri), d)
@dataclass
class Config(HttpConfig):
"""Configuration for the HTTP Auth Plugin."""
user_uri: str = "/user"
"""URI of the auth check."""
class TopicAuthHttpPlugin(AuthHttpPlugin, BaseTopicPlugin):
async def topic_filtering(self, *,
session: Session | None = None,
topic: str | None = None,
action: Action | None = None) -> bool | None:
if not session:
return None
acc = 0
match action:
case Action.PUBLISH:
acc = 2
case Action.SUBSCRIBE:
acc = 4
case Action.RECEIVE:
acc = 1
d = {"username": session.username, "client_id": session.client_id, "topic": topic, "acc": acc}
return await self._send_request(self.get_url(self.config.topic_uri), d)
@dataclass
class Config(HttpConfig):
"""Configuration for the HTTP Topic Plugin."""
topic_uri: str = "/acl"
"""URI of the topic check."""
from dataclasses import dataclass
import logging
from typing import ClassVar
import jwt
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class Algorithms(StrEnum):
ES256 = "ES256"
ES256K = "ES256K"
ES384 = "ES384"
ES512 = "ES512"
ES521 = "ES521"
EdDSA = "EdDSA"
HS256 = "HS256"
HS384 = "HS384"
HS512 = "HS512"
PS256 = "PS256"
PS384 = "PS384"
PS512 = "PS512"
RS256 = "RS256"
RS384 = "RS384"
RS512 = "RS512"
class UserAuthJwtPlugin(BaseAuthPlugin):
async def authenticate(self, *, session: Session) -> bool | None:
if not session.username or not session.password:
return None
try:
decoded_payload = jwt.decode(session.password, self.config.secret_key, algorithms=["HS256"])
return bool(decoded_payload.get(self.config.user_claim, None) == session.username)
except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired")
return False
except jwt.InvalidTokenError:
logger.debug(f"jwt for '{session.username}' is invalid")
return False
@dataclass
class Config:
"""Configuration for the JWT user authentication."""
secret_key: str
"""Secret key to decrypt the token."""
user_claim: str
"""Payload key for user name."""
algorithm: str = "HS256"
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
class TopicAuthJwtPlugin(BaseTopicPlugin):
_topic_jwt_claims: ClassVar = {
Action.PUBLISH: "publish_claim",
Action.SUBSCRIBE: "subscribe_claim",
Action.RECEIVE: "receive_claim",
}
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.topic_matcher = TopicMatcher()
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
if not session or not topic or not action:
return None
if not session.password:
return None
try:
decoded_payload = jwt.decode(session.password.encode(), self.config.secret_key, algorithms=["HS256"])
claim = getattr(self.config, self._topic_jwt_claims[action])
return any(self.topic_matcher.is_topic_allowed(topic, a_filter) for a_filter in decoded_payload.get(claim, []))
except jwt.ExpiredSignatureError:
logger.debug(f"jwt for '{session.username}' is expired")
return False
except jwt.InvalidTokenError:
logger.debug(f"jwt for '{session.username}' is invalid")
return False
@dataclass
class Config:
"""Configuration for the JWT topic authorization."""
secret_key: str
"""Secret key to decrypt the token."""
publish_claim: str
"""Payload key for contains a list of permissible publish topics."""
subscribe_claim: str
"""Payload key for contains a list of permissible subscribe topics."""
receive_claim: str
"""Payload key for contains a list of permissible receive topics."""
algorithm: str = "HS256"
"""Algorithm to use for token encryption: 'ES256', 'ES256K', 'ES384', 'ES512', 'ES521', 'EdDSA', 'HS256',
'HS384', 'HS512', 'PS256', 'PS384', 'PS512', 'RS256', 'RS384', 'RS512'"""
from dataclasses import dataclass
import logging
from typing import ClassVar
import ldap
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.errors import PluginInitError
from amqtt.plugins import TopicMatcher
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
@dataclass
class LdapConfig:
"""Configuration for the LDAP Plugins."""
server: str
"""uri formatted server location. e.g `ldap://localhost:389`"""
base_dn: str
"""distinguished name (dn) of the ldap server. e.g. `dc=amqtt,dc=io`"""
user_attribute: str
"""attribute in ldap entry to match the username against"""
bind_dn: str
"""distinguished name (dn) of known, preferably read-only, user. e.g. `cn=admin,dc=amqtt,dc=io`"""
bind_password: str
"""password for known, preferably read-only, user"""
class AuthLdapPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.conn = ldap.initialize(self.config.server)
self.conn.protocol_version = ldap.VERSION3 # pylint: disable=E1101
try:
self.conn.simple_bind_s(self.config.bind_dn, self.config.bind_password)
except ldap.INVALID_CREDENTIALS as e: # pylint: disable=E1101
raise PluginInitError(self.__class__) from e
class UserAuthLdapPlugin(AuthLdapPlugin, BaseAuthPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
async def authenticate(self, *, session: Session) -> bool | None:
# use our initial creds to see if the user exists
search_filter = f"({self.config.user_attribute}={session.username})"
result = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, ["dn"]) # pylint: disable=E1101
if not result:
logger.debug(f"user not found: {session.username}")
return False
try:
# `search_s` responds with list of tuples: (dn, entry); first in list is our match
user_dn = result[0][0]
except IndexError:
return False
try:
user_conn = ldap.initialize(self.config.server)
user_conn.simple_bind_s(user_dn, session.password)
except ldap.INVALID_CREDENTIALS: # pylint: disable=E1101
logger.debug(f"invalid credentials for '{session.username}'")
return False
except ldap.LDAPError as e: # pylint: disable=E1101
logger.debug(f"LDAP error during user bind: {e}")
return False
return True
@dataclass
class Config(LdapConfig):
"""Configuration for the User Auth LDAP Plugin."""
class TopicAuthLdapPlugin(AuthLdapPlugin, BaseTopicPlugin):
"""Plugin to authenticate a user with an LDAP directory server."""
_action_attr_map: ClassVar = {
Action.PUBLISH: "publish_attribute",
Action.SUBSCRIBE: "subscribe_attribute",
Action.RECEIVE: "receive_attribute"
}
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self.topic_matcher = TopicMatcher()
async def topic_filtering(
self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool | None:
# if not provided needed criteria, can't properly evaluate topic filtering
if not session or not action or not topic:
return None
search_filter = f"({self.config.user_attribute}={session.username})"
attrs = [
"cn",
self.config.publish_attribute,
self.config.subscribe_attribute,
self.config.receive_attribute
]
results = self.conn.search_s(self.config.base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs) # pylint: disable=E1101
if not results:
logger.debug(f"user not found: {session.username}")
return False
if len(results) > 1:
found_users = [dn for dn, _ in results]
logger.debug(f"multiple users found: {', '.join(found_users)}")
return False
dn, entry = results[0]
ldap_attribute = getattr(self.config, self._action_attr_map[action])
topic_filters = [t.decode("utf-8") for t in entry.get(ldap_attribute, [])]
logger.debug(f"DN: {dn} - {ldap_attribute}={topic_filters}")
return self.topic_matcher.are_topics_allowed(topic, topic_filters)
@dataclass
class Config(LdapConfig):
"""Configuration for the LDAPAuthPlugin."""
publish_attribute: str
"""LDAP attribute which contains a list of permissible publish topics."""
subscribe_attribute: str
"""LDAP attribute which contains a list of permissible subscribe topics."""
receive_attribute: str
"""LDAP attribute which contains a list of permissible receive topics."""
from dataclasses import dataclass
import logging
from pathlib import Path
from sqlalchemy import Boolean, Integer, LargeBinary, Result, String, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from amqtt.broker import BrokerContext, RetainedApplicationMessage
from amqtt.contrib import DataClassListJSON
from amqtt.errors import PluginError
from amqtt.plugins.base import BasePlugin
from amqtt.session import Session
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
pass
@dataclass
class RetainedMessage:
topic: str
data: str
qos: int
@dataclass
class Subscription:
topic: str
qos: int
class StoredSession(Base):
__tablename__ = "stored_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
client_id: Mapped[str] = mapped_column(String)
clean_session: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
will_flag: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
will_message: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
will_qos: Mapped[int | None] = mapped_column(Integer, nullable=True, default=None)
will_retain: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=None)
will_topic: Mapped[str | None] = mapped_column(String, nullable=True, default=None)
keep_alive: Mapped[int] = mapped_column(Integer, default=0)
retained: Mapped[list[RetainedMessage]] = mapped_column(DataClassListJSON(RetainedMessage), default=list)
subscriptions: Mapped[list[Subscription]] = mapped_column(DataClassListJSON(Subscription), default=list)
class StoredMessage(Base):
__tablename__ = "stored_messages"
id: Mapped[int] = mapped_column(primary_key=True)
topic: Mapped[str] = mapped_column(String)
data: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True, default=None)
qos: Mapped[int] = mapped_column(Integer, default=0)
class SessionDBPlugin(BasePlugin[BrokerContext]):
"""Plugin to store session information and retained topic messages in the event that the broker terminates abnormally.
Configuration:
- file *(string)* path & filename to store the session db. default: `amqtt.db`
- clear_on_shutdown *(bool)* if the broker shutdowns down normally, don't retain any information. default: `True`
"""
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
# bypass the `test_plugins_correct_has_attr` until it can be updated
if not hasattr(self.config, "file"):
logger.warning("`Config` is missing a `file` attribute")
return
self._engine = create_async_engine(f"sqlite+aiosqlite:///{self.config.file}")
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
@staticmethod
async def _get_or_create_session(db_session: AsyncSession, client_id: str) -> StoredSession:
stmt = select(StoredSession).filter(StoredSession.client_id == client_id)
stored_session = await db_session.scalar(stmt)
if stored_session is None:
stored_session = StoredSession(client_id=client_id)
db_session.add(stored_session)
await db_session.flush()
return stored_session
@staticmethod
async def _get_or_create_message(db_session: AsyncSession, topic: str) -> StoredMessage:
stmt = select(StoredMessage).filter(StoredMessage.topic == topic)
stored_message = await db_session.scalar(stmt)
if stored_message is None:
stored_message = StoredMessage(topic=topic)
db_session.add(stored_message)
await db_session.flush()
return stored_message
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
"""Search to see if session already exists."""
# if client id doesn't exist, create (can ignore if session is anonymous)
# update session information (will, clean_session, etc)
# don't store session information for clean or anonymous sessions
if client_session.clean_session in (None, True) or client_session.is_anonymous:
return
async with self._db_session_maker() as db_session, db_session.begin():
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.clean_session = client_session.clean_session
stored_session.will_flag = client_session.will_flag
stored_session.will_message = client_session.will_message # type: ignore[assignment]
stored_session.will_qos = client_session.will_qos
stored_session.will_retain = client_session.will_retain
stored_session.will_topic = client_session.will_topic
stored_session.keep_alive = client_session.keep_alive
await db_session.flush()
async def on_broker_client_subscribed(self, client_id: str, topic: str, qos: int) -> None:
"""Create/update subscription if clean session = false."""
session = self.context.get_session(client_id)
if not session:
logger.warning(f"'{client_id}' is subscribing but doesn't have a session")
return
if session.clean_session:
return
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.subscriptions = [*stored_session.subscriptions, Subscription(topic, qos)]
await db_session.flush()
async def on_broker_client_unsubscribed(self, client_id: str, topic: str) -> None:
"""Remove subscription if clean session = false."""
async def on_broker_retained_message(self, *, client_id: str | None, retained_message: RetainedApplicationMessage) -> None:
"""Update to retained messages.
if retained_message.data is None or '', the message is being cleared
"""
# if client_id is valid, the retained message is for a disconnected client
if client_id is not None:
async with self._db_session_maker() as db_session, db_session.begin():
# stored sessions shouldn't need to be created here, but we'll use the same helper...
stored_session = await self._get_or_create_session(db_session, client_id)
stored_session.retained = [*stored_session.retained, RetainedMessage(retained_message.topic,
retained_message.data.decode(),
retained_message.qos or 0)]
await db_session.flush()
return
async with self._db_session_maker() as db_session, db_session.begin():
# if the retained message has data, we need to store/update for the topic
if retained_message.data:
client_message = await self._get_or_create_message(db_session, retained_message.topic)
client_message.data = retained_message.data # type: ignore[assignment]
client_message.qos = retained_message.qos or 0
await db_session.flush()
return
# if there is no data, clear the stored message (if exists) for the topic
stmt = select(StoredMessage).filter(StoredMessage.topic == retained_message.topic)
topic_message = await db_session.scalar(stmt)
if topic_message is not None:
await db_session.delete(topic_message)
await db_session.flush()
return
async def on_broker_pre_start(self) -> None:
"""Initialize the database and db connection."""
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def on_broker_post_start(self) -> None:
"""Load subscriptions."""
if len(self.context.subscriptions) > 0:
msg = "SessionDBPlugin : broker shouldn't have any subscriptions yet"
raise PluginError(msg)
if len(list(self.context.sessions)) > 0:
msg = "SessionDBPlugin : broker shouldn't have any sessions yet"
raise PluginError(msg)
async with self._db_session_maker() as db_session, db_session.begin():
stmt = select(StoredSession)
stored_sessions = await db_session.execute(stmt)
restored_sessions = 0
for stored_session in stored_sessions.scalars():
await self.context.add_subscription(stored_session.client_id, None, None)
for subscription in stored_session.subscriptions:
await self.context.add_subscription(stored_session.client_id,
subscription.topic,
subscription.qos)
session = self.context.get_session(stored_session.client_id)
if not session:
continue
session.clean_session = stored_session.clean_session
session.will_flag = stored_session.will_flag
session.will_message = stored_session.will_message
session.will_qos = stored_session.will_qos
session.will_retain = stored_session.will_retain
session.will_topic = stored_session.will_topic
session.keep_alive = stored_session.keep_alive
for message in stored_session.retained:
retained_message = RetainedApplicationMessage(
source_session=None,
topic=message.topic,
data=message.data.encode(),
qos=message.qos
)
await session.retained_messages.put(retained_message)
restored_sessions += 1
stmt = select(StoredMessage)
stored_messages: Result[tuple[StoredMessage]] = await db_session.execute(stmt)
restored_messages = 0
retained_messages = self.context.retained_messages
for stored_message in stored_messages.scalars():
retained_messages[stored_message.topic] = (RetainedApplicationMessage(
source_session=None,
topic=stored_message.topic,
data=stored_message.data or b"",
qos=stored_message.qos
))
restored_messages += 1
logger.info(f"Retained messages restored: {restored_messages}")
logger.info(f"Restored {restored_sessions} sessions.")
async def on_broker_pre_shutdown(self) -> None:
"""Clean up the db connection."""
await self._engine.dispose()
async def on_broker_post_shutdown(self) -> None:
if self.config.clear_on_shutdown and self.config.file.exists():
self.config.file.unlink()
@dataclass
class Config:
"""Configuration variables."""
file: str | Path = "amqtt.db"
"""path & filename to store the sqlite session db."""
clear_on_shutdown: bool = True
"""if the broker shutdowns down normally, don't retain any information."""
def __post_init__(self) -> None:
"""Create `Path` from string path."""
if isinstance(self.file, str):
self.file = Path(self.file)
"""Module for the shadow state plugin."""
from .plugin import ShadowPlugin, ShadowTopicAuthPlugin
from .states import ShadowOperation
__all__ = ["ShadowOperation", "ShadowPlugin", "ShadowTopicAuthPlugin"]
from collections.abc import MutableMapping
from dataclasses import dataclass, fields, is_dataclass
import json
from typing import Any
from amqtt.contrib.shadows.states import MetaTimestamp, ShadowOperation, State, StateDocument
def asdict_no_none(obj: Any) -> Any:
"""Create dictionary from dataclass, but eliminate any key set to `None`."""
if is_dataclass(obj):
result = {}
for f in fields(obj):
value = getattr(obj, f.name)
if value is not None:
result[f.name] = asdict_no_none(value)
return result
if isinstance(obj, list):
return [asdict_no_none(item) for item in obj if item is not None]
if isinstance(obj, dict):
return {
key: asdict_no_none(value)
for key, value in obj.items()
if value is not None
}
return obj
def create_shadow_topic(device_id: str, shadow_name: str, message_op: "ShadowOperation") -> str:
"""Create a shadow topic for message type."""
return f"$shadow/{device_id}/{shadow_name}/{message_op}"
class ShadowMessage:
def to_message(self) -> bytes:
return json.dumps(asdict_no_none(self)).encode("utf-8")
@dataclass
class GetAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
metadata: State[MetaTimestamp]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_ACCEPT)
@dataclass
class GetRejectedMessage(ShadowMessage):
code: int
message: str
timestamp: int | None = None
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.GET_REJECT)
@dataclass
class UpdateAcceptedMessage(ShadowMessage):
state: State[dict[str, Any]]
metadata: State[MetaTimestamp]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_ACCEPT)
@dataclass
class UpdateRejectedMessage(ShadowMessage):
code: int
message: str
timestamp: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_REJECT)
@dataclass
class UpdateDeltaMessage(ShadowMessage):
state: MutableMapping[str, Any]
metadata: MutableMapping[str, Any]
timestamp: int
version: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DELTA)
class UpdateIotaMessage(UpdateDeltaMessage):
"""Same format, corollary name."""
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_IOTA)
@dataclass
class UpdateDocumentMessage(ShadowMessage):
previous: StateDocument
current: StateDocument
timestamp: int
@staticmethod
def topic(device_id: str, shadow_name: str) -> str:
return create_shadow_topic(device_id, shadow_name, ShadowOperation.UPDATE_DOCUMENTS)
from collections.abc import Sequence
from dataclasses import asdict
import logging
import time
from typing import Any, Optional
import uuid
from sqlalchemy import JSON, CheckConstraint, Integer, String, UniqueConstraint, desc, event, func, select
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, make_transient, mapped_column
from amqtt.contrib.shadows.states import StateDocument
logger = logging.getLogger(__name__)
class ShadowUpdateError(Exception):
def __init__(self, message: str = "updating an existing Shadow is not allowed") -> None:
super().__init__(message)
class ShadowBase(DeclarativeBase):
pass
async def sync_shadow_base(connection: AsyncConnection) -> None:
"""Create tables and table schemas."""
await connection.run_sync(ShadowBase.metadata.create_all)
def default_state_document() -> dict[str, Any]:
"""Create a default (empty) state document, factory for model field."""
return asdict(StateDocument())
class Shadow(ShadowBase):
__tablename__ = "shadows_shadow"
id: Mapped[str | None] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
device_id: Mapped[str] = mapped_column(String(128), nullable=False)
name: Mapped[str] = mapped_column(String(128), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
_state: Mapped[dict[str, Any]] = mapped_column("state", JSON, nullable=False, default=dict)
created_at: Mapped[int] = mapped_column(Integer, default=lambda: int(time.time()), nullable=False)
__table_args__ = (
CheckConstraint("version > 0", name="check_quantity_positive"),
UniqueConstraint("device_id", "name", "version", name="uq_device_id_name_version"),
)
@property
def state(self) -> StateDocument:
if not self._state:
return StateDocument()
return StateDocument.from_dict(self._state)
@state.setter
def state(self, value: StateDocument) -> None:
self._state = asdict(value)
@classmethod
async def latest_version(cls, session: AsyncSession, device_id: str, name: str) -> Optional["Shadow"]:
"""Get the latest version of the shadow associated with the device and name."""
stmt = (
select(cls).where(
cls.device_id == device_id,
cls.name == name
).order_by(desc(cls.version)).limit(1)
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
@classmethod
async def all(cls, session: AsyncSession, device_id: str, name: str) -> Sequence["Shadow"]:
"""Return a list of all shadows associated with the device and name."""
stmt = (
select(cls).where(
cls.device_id == device_id,
cls.name == name
).order_by(desc(cls.version)))
result = await session.execute(stmt)
return result.scalars().all()
@event.listens_for(Shadow, "before_insert")
def assign_incremental_version(_: Mapper[Any], connection: Session, target: "Shadow") -> None:
"""Get the latest version of the state document."""
stmt = (
select(func.max(Shadow.version))
.where(
Shadow.device_id == target.device_id,
Shadow.name == target.name
)
)
result = connection.execute(stmt).scalar_one_or_none()
target.version = (result or 0) + 1
@event.listens_for(Shadow, "before_update")
def prevent_update(_mapper: Mapper[Any], _session: Session, _instance: "Shadow") -> None:
"""Prevent existing shadow from being updated."""
raise ShadowUpdateError
@event.listens_for(Session, "before_flush")
def convert_update_to_insert(session: Session, _flush_context: object, _instances: object | None) -> None:
"""Force a shadow to insert a new version, instead of updating an existing."""
# Make a copy of the dirty set so we can safely mutate the session
dirty = list(session.dirty)
for obj in dirty:
if not session.is_modified(obj, include_collections=False):
continue # skip unchanged
# You can scope this to a particular class
if not isinstance(obj, Shadow):
continue
# Clone logic: convert update into insert
session.expunge(obj) # remove from session
make_transient(obj) # remove identity and history
obj.id = "" # clear primary key
obj.version += 1 # bump version or modify fields
session.add(obj) # re-add as new object
_listener_example = '''#
# @event.listens_for(Shadow, "before_insert")
# def convert_state_document_to_json(_1: Mapper[Any], _2: Session, target: "Shadow") -> None:
# """Listen for insertion and convert state document to json."""
# if not isinstance(target.state, StateDocument):
# msg = "'state' field needs to be a StateDocument"
# raise TypeError(msg)
#
# target.state = target.state.to_dict()
'''
from collections import defaultdict
from dataclasses import dataclass, field
import json
import re
from typing import Any
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from amqtt.broker import BrokerContext
from amqtt.contexts import Action
from amqtt.contrib.shadows.messages import (
GetAcceptedMessage,
GetRejectedMessage,
UpdateAcceptedMessage,
UpdateDeltaMessage,
UpdateDocumentMessage,
UpdateIotaMessage,
)
from amqtt.contrib.shadows.models import Shadow, sync_shadow_base
from amqtt.contrib.shadows.states import (
ShadowOperation,
StateDocument,
calculate_delta_update,
calculate_iota_update,
)
from amqtt.plugins.base import BasePlugin, BaseTopicPlugin
from amqtt.session import ApplicationMessage, Session
shadow_topic_re = re.compile(r"^\$shadow/(?P<client_id>[a-zA-Z0-9_-]+?)/(?P<shadow_name>[a-zA-Z0-9_-]+?)/(?P<request>get|update)")
DeviceID = str
ShadowName = str
@dataclass
class ShadowTopic:
device_id: DeviceID
name: ShadowName
message_op: ShadowOperation
def shadow_dict() -> dict[DeviceID, dict[ShadowName, StateDocument]]:
"""Nested defaultdict for shadow cache."""
return defaultdict(shadow_dict) # type: ignore[arg-type]
class ShadowPlugin(BasePlugin[BrokerContext]):
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
self._shadows: dict[DeviceID, dict[ShadowName, StateDocument]] = defaultdict(dict)
self._engine = create_async_engine(self.config.connection)
self._db_session_maker = async_sessionmaker(self._engine, expire_on_commit=False)
async def on_broker_pre_start(self) -> None:
"""Sync the schema."""
async with self._engine.begin() as conn:
await sync_shadow_base(conn)
@staticmethod
def shadow_topic_match(topic: str) -> ShadowTopic | None:
"""Check if topic matches the shadow topic format."""
# pattern is "$shadow/<username>/<shadow_name>/get, update, etc
match = shadow_topic_re.search(topic)
if match:
groups = match.groupdict()
return ShadowTopic(groups["client_id"], groups["shadow_name"], ShadowOperation(groups["request"]))
return None
async def _handle_get(self, st: ShadowTopic) -> None:
"""Send 'accepted."""
async with self._db_session_maker() as db_session, db_session.begin():
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
if not shadow:
reject_msg = GetRejectedMessage(
code=404,
message="shadow not found",
)
await self.context.broadcast_message(reject_msg.topic(st.device_id, st.name), reject_msg.to_message())
return
accept_msg = GetAcceptedMessage(
state=shadow.state.state,
metadata=shadow.state.metadata,
timestamp=shadow.created_at,
version=shadow.version
)
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
async def _handle_update(self, st: ShadowTopic, update: dict[str, Any]) -> None:
async with self._db_session_maker() as db_session, db_session.begin():
shadow = await Shadow.latest_version(db_session, st.device_id, st.name)
if not shadow:
shadow = Shadow(device_id=st.device_id, name=st.name)
state_update = StateDocument.from_dict(update)
prev_state = shadow.state or StateDocument()
prev_state.version = shadow.version or 0 # only required when generating shadow messages
prev_state.timestamp = shadow.created_at or 0 # only required when generating shadow messages
next_state = prev_state + state_update
shadow.state = next_state
db_session.add(shadow)
await db_session.commit()
next_state.version = shadow.version
next_state.timestamp = shadow.created_at
accept_msg = UpdateAcceptedMessage(
state=next_state.state,
metadata=next_state.metadata,
timestamp=123,
version=1
)
await self.context.broadcast_message(accept_msg.topic(st.device_id, st.name), accept_msg.to_message())
delta_msg = UpdateDeltaMessage(
state=calculate_delta_update(next_state.state.desired, next_state.state.reported),
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
version=shadow.version,
timestamp=shadow.created_at
)
await self.context.broadcast_message(delta_msg.topic(st.device_id, st.name), delta_msg.to_message())
iota_msg = UpdateIotaMessage(
state=calculate_iota_update(next_state.state.desired, next_state.state.reported),
metadata=calculate_delta_update(next_state.metadata.desired, next_state.metadata.reported),
version=shadow.version,
timestamp=shadow.created_at
)
await self.context.broadcast_message(iota_msg.topic(st.device_id, st.name), iota_msg.to_message())
doc_msg = UpdateDocumentMessage(
previous=prev_state,
current=next_state,
timestamp=shadow.created_at
)
await self.context.broadcast_message(doc_msg.topic(st.device_id, st.name), doc_msg.to_message())
async def on_broker_message_received(self, *, client_id: str, message: ApplicationMessage) -> None:
"""Process a message that was received from a client."""
topic = message.topic
if not topic.startswith("$shadow"): # this is less overhead than do the full regular expression match
return
if not (shadow_topic := self.shadow_topic_match(topic)):
return
match shadow_topic.message_op:
case ShadowOperation.GET:
await self._handle_get(shadow_topic)
case ShadowOperation.UPDATE:
await self._handle_update(shadow_topic, json.loads(message.data.decode("utf-8")))
@dataclass
class Config:
"""Configuration for shadow plugin."""
connection: str
"""SQLAlchemy connection string for the asyncio version of the database connector:
- `mysql+aiomysql://user:password@host:port/dbname`
- `postgresql+asyncpg://user:password@host:port/dbname`
- `sqlite+aiosqlite:///dbfilename.db`
"""
class ShadowTopicAuthPlugin(BaseTopicPlugin):
async def topic_filtering(self, *,
session: Session | None = None,
topic: str | None = None,
action: Action | None = None) -> bool | None:
session = session or Session()
if not topic:
return False
shadow_topic = ShadowPlugin.shadow_topic_match(topic)
if not shadow_topic:
return False
return shadow_topic.device_id == session.username or session.username in self.config.superusers
@dataclass
class Config:
"""Configuration for only allowing devices access to their own shadow topics."""
superusers: list[str] = field(default_factory=list)
"""A list of one or more usernames that can write to any device topic,
primarily for the central app sending updates to devices."""
from collections import Counter
from collections.abc import MutableMapping
from dataclasses import dataclass, field
try:
from enum import StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
import time
from typing import Any, Generic, TypeVar
from mergedeep import merge
C = TypeVar("C", bound=Any)
class StateError(Exception):
def __init__(self, msg: str = "'state' field is required") -> None:
super().__init__(msg)
@dataclass
class MetaTimestamp:
timestamp: int = 0
def __eq__(self, other: object) -> bool:
"""Compare timestamps."""
if isinstance(other, int):
return self.timestamp == other
if isinstance(other, self.__class__):
return self.timestamp == other.timestamp
msg = "needs to be int or MetaTimestamp"
raise ValueError(msg)
# numeric operations to make this dataclass transparent
def __abs__(self) -> int:
"""Absolute timestamp."""
return self.timestamp
def __add__(self, other: int) -> int:
"""Add to a timestamp."""
return self.timestamp + other
def __sub__(self, other: int) -> int:
"""Subtract from a timestamp."""
return self.timestamp - other
def __mul__(self, other: int) -> int:
"""Multiply a timestamp."""
return self.timestamp * other
def __float__(self) -> float:
"""Convert timestamp to float."""
return float(self.timestamp)
def __int__(self) -> int:
"""Convert timestamp to int."""
return int(self.timestamp)
def __lt__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp < other
def __le__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp <= other
def __gt__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp > other
def __ge__(self, other: int) -> bool:
"""Compare timestamp."""
return self.timestamp >= other
def create_metadata(state: MutableMapping[str, Any], timestamp: int) -> dict[str, Any]:
"""Create metadata (timestamps) for each of the keys in 'state'."""
metadata: dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, dict):
metadata[key] = create_metadata(value, timestamp)
elif value is None:
metadata[key] = None
else:
metadata[key] = MetaTimestamp(timestamp)
return metadata
def calculate_delta_update(desired: MutableMapping[str, Any],
reported: MutableMapping[str, Any],
depth: bool = True,
exclude_nones: bool = True,
ordered_lists: bool = True) -> dict[str, Any]:
"""Calculate state differences between desired and reported."""
diff_dict = {}
for key, value in desired.items():
if value is None and exclude_nones:
continue
# if the desired has an element that the reported does not...
if key not in reported:
diff_dict[key] = value
# if the desired has an element that's a list, but the list is
elif isinstance(value, list) and not ordered_lists:
if Counter(value) != Counter(reported[key]):
diff_dict[key] = value
elif isinstance(value, dict) and depth:
# recurse, report when there is a difference
obj_diff = calculate_delta_update(value, reported[key])
if obj_diff:
diff_dict[key] = obj_diff
elif value != reported[key]:
diff_dict[key] = value
return diff_dict
def calculate_iota_update(desired: MutableMapping[str, Any], reported: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
"""Calculate state differences between desired and reported (including missing keys)."""
delta = calculate_delta_update(desired, reported, depth=False, exclude_nones=False)
for key in reported:
if key not in desired:
delta[key] = None
return delta
@dataclass
class State(Generic[C]):
desired: MutableMapping[str, C] = field(default_factory=dict)
reported: MutableMapping[str, C] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: dict[str, C]) -> "State[C]":
"""Create state from dictionary."""
return cls(
desired=data.get("desired", {}),
reported=data.get("reported", {})
)
def __bool__(self) -> bool:
"""Determine if state is empty."""
return bool(self.desired) or bool(self.reported)
def __add__(self, other: "State[C]") -> "State[C]":
"""Merge states together."""
return State(
desired=merge({}, self.desired, other.desired),
reported=merge({}, self.reported, other.reported)
)
@dataclass
class StateDocument:
state: State[dict[str, Any]] = field(default_factory=State)
metadata: State[MetaTimestamp] = field(default_factory=State)
version: int | None = None # only required when generating shadow messages
timestamp: int | None = None # only required when generating shadow messages
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StateDocument":
"""Create state document from dictionary."""
now = int(time.time())
if data and "state" not in data:
raise StateError
state = State.from_dict(data.get("state", {}))
metadata = State(
desired=create_metadata(state.desired, now),
reported=create_metadata(state.reported, now)
)
return cls(state=state, metadata=metadata)
def __post_init__(self) -> None:
"""Initialize meta data if not provided."""
now = int(time.time())
if not self.metadata:
self.metadata = State(
desired=create_metadata(self.state.desired, now),
reported=create_metadata(self.state.reported, now),
)
def __add__(self, other: "StateDocument") -> "StateDocument":
"""Merge two state documents together."""
return StateDocument(
state=self.state + other.state,
metadata=self.metadata + other.metadata
)
class ShadowOperation(StrEnum):
GET = "get"
UPDATE = "update"
GET_ACCEPT = "get/accepted"
GET_REJECT = "get/rejected"
UPDATE_ACCEPT = "update/accepted"
UPDATE_REJECT = "update/rejected"
UPDATE_DOCUMENTS = "update/documents"
UPDATE_DELTA = "update/delta"
UPDATE_IOTA = "update/iota"
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the cli for `ca_creds`."""
app()
@app.command()
def ca_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
state: str = typer.Option(..., "--state", help="x509 'state_or_province_name' attribute"),
locality: str = typer.Option(..., "--locality", help="x509 'locality_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
) -> None:
"""Generate a self-signed key and certificate to be used as the root CA, with a key size of 2048 and a 1-year expiration."""
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import generate_root_creds, write_key_and_crt # pylint: disable=import-outside-toplevel
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = generate_root_creds(country=country, state=state, locality=locality, org_name=org_name, cn=cn)
write_key_and_crt(ca_key, ca_crt, "ca", Path(output_dir))
if __name__ == "__main__":
main()
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the `device_creds` cli."""
app()
@app.command()
def device_creds( # pylint: disable=too-many-locals
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
device_id: str = typer.Option(..., "--device-id", help="device id for the SAN"),
uri: str = typer.Option(..., "--uri", help="domain name for device SAN"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="root key filename used for signing."),
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="root cert filename used for signing."),
) -> None:
"""Generate a key and certificate for each device in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa: E501
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
generate_device_csr,
load_ca,
sign_csr,
write_key_and_crt,
)
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
uri_san = f"spiffe://{uri}/device/{device_id}"
dns_san = f"{device_id}.local"
device_key, device_csr = generate_device_csr(
country=country,
org_name=org_name,
common_name=device_id,
uri_san=uri_san,
dns_san=dns_san
)
device_crt = sign_csr(device_csr, ca_key, ca_crt)
write_key_and_crt(device_key, device_crt, device_id, Path(output_dir))
logger.info(f"✅ Created: {device_id}.crt and {device_id}.key")
if __name__ == "__main__":
main()
import logging
import sys
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.topic_mgr_cli import topic_app # pylint: disable=import-outside-toplevel
try:
topic_app()
except ModuleNotFoundError as mnfe:
logger.critical(f"Please install database-specific dependencies: {mnfe}")
sys.exit(1)
except ValueError as ve:
if "greenlet" in f"{ve}":
logger.critical("Please install database-specific dependencies: 'greenlet'")
sys.exit(1)
logger.critical(f"Unknown error: {ve}")
sys.exit(1)
except MQTTError as me:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()
import logging
import sys
from amqtt.errors import MQTTError
logger = logging.getLogger(__name__)
def main() -> None:
"""Run the auth db cli."""
try:
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
except ImportError:
logger.critical("optional 'contrib' library is missing, please install: `pip install amqtt[contrib]`")
sys.exit(1)
from amqtt.contrib.auth_db.user_mgr_cli import user_app # pylint: disable=import-outside-toplevel
try:
user_app()
except ModuleNotFoundError as mnfe:
logger.critical(f"Please install database-specific dependencies: {mnfe}")
sys.exit(1)
except ValueError as ve:
if "greenlet" in f"{ve}":
logger.critical("Please install database-specific dependencies: 'greenlet'")
sys.exit(1)
logger.critical(f"Unknown error: {ve}")
sys.exit(1)
except MQTTError as me:
logger.critical(f"could not execute command: {me}")
sys.exit(1)
if __name__ == "__main__":
main()
import logging
from pathlib import Path
import sys
import typer
logger = logging.getLogger(__name__)
app = typer.Typer(add_completion=False, rich_markup_mode=None)
def main() -> None:
"""Run the `server_creds` cli."""
app()
@app.command()
def server_creds(
country: str = typer.Option(..., "--country", help="x509 'country_name' attribute"),
org_name: str = typer.Option(..., "--org-name", help="x509 'organization_name' attribute"),
cn: str = typer.Option(..., "--cn", help="x509 'common_name' attribute"),
output_dir: str = typer.Option(Path.cwd().absolute(), "--output-dir", help="output directory"),
ca_key_fn: str = typer.Option("ca.key", "--ca-key", help="server key output filename."),
ca_crt_fn: str = typer.Option("ca.crt", "--ca-crt", help="server cert output filename."),
) -> None:
"""Generate a key and certificate for the broker in pem format, signed by the provided CA credentials. With a key size of 2048 and a 1-year expiration.""" # noqa : E501
formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=formatter)
try:
from amqtt.contrib.cert import ( # pylint: disable=import-outside-toplevel
generate_server_csr,
load_ca,
sign_csr,
write_key_and_crt,
)
except ImportError:
msg = "Requires installation of the optional 'contrib' package: `pip install amqtt[contrib]`"
logger.critical(msg)
sys.exit(1)
ca_key, ca_crt = load_ca(ca_key_fn, ca_crt_fn)
server_key, server_csr = generate_server_csr(country=country, org_name=org_name, cn=cn)
server_crt = sign_csr(server_csr, ca_key, ca_crt)
write_key_and_crt(server_key, server_crt, "server", Path(output_dir))
if __name__ == "__main__":
main()
+3
-0

@@ -7,2 +7,5 @@ #------- Package & Cache Files -------

*.pem
*.crt
*.key
*.patch

@@ -9,0 +12,0 @@ #------- Environment Files -------

+1
-1
"""INIT."""
__version__ = "0.11.2"
__version__ = "0.11.3"

@@ -6,2 +6,4 @@ from abc import ABC, abstractmethod

import logging
import ssl
from typing import cast

@@ -57,2 +59,7 @@ from websockets import ConnectionClosed

@abstractmethod
def get_ssl_info(self) -> ssl.SSLObject | None:
"""Return peer certificate information (if available) used to establish a TLS session."""
raise NotImplementedError
@abstractmethod
async def close(self) -> None:

@@ -126,2 +133,5 @@ """Close the protocol connection."""

def get_ssl_info(self) -> ssl.SSLObject | None:
return cast("ssl.SSLObject", self._protocol.transport.get_extra_info("ssl_object"))
async def close(self) -> None:

@@ -176,2 +186,5 @@ await self._protocol.close()

def get_ssl_info(self) -> ssl.SSLObject | None:
return cast("ssl.SSLObject", self._writer.get_extra_info("ssl_object"))
async def close(self) -> None:

@@ -211,2 +224,5 @@ if not self.is_closed:

def get_ssl_info(self) -> ssl.SSLObject | None:
return None
def __init__(self, buffer: bytes = b"") -> None:

@@ -213,0 +229,0 @@ self._stream = io.BytesIO(buffer)

@@ -5,8 +5,8 @@ import asyncio

from collections.abc import Generator
import copy
from functools import partial
import logging
from pathlib import Path
from math import floor
import re
import ssl
import time
from typing import Any, ClassVar, TypeAlias

@@ -26,7 +26,7 @@

)
from amqtt.contexts import Action, BaseContext
from amqtt.contexts import Action, BaseContext, BrokerConfig, ListenerConfig, ListenerType
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError
from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import format_client_message, gen_client_id, read_yaml_config
from amqtt.utils import format_client_message, gen_client_id

@@ -38,9 +38,4 @@ from .events import BrokerEvents

_CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]
_BROADCAST: TypeAlias = dict[str, Session | str | bytes | bytearray | int | None]
_defaults = read_yaml_config(Path(__file__).parent / "scripts/default_broker.yaml")
# Default port numbers

@@ -63,2 +58,4 @@ DEFAULT_PORTS = {"tcp": 1883, "ws": 8883}

class Server:
"""Used to encapsulate the server associated with a listener. Allows broker to interact with the connection lifecycle."""
def __init__(

@@ -101,18 +98,32 @@ self,

class ExternalServer(Server):
"""For external listeners, the connection lifecycle is handled by that implementation so these are no-ops."""
def __init__(self) -> None:
super().__init__("aiohttp", None) # type: ignore[arg-type]
async def acquire_connection(self) -> None:
pass
def release_connection(self) -> None:
pass
async def close_instance(self) -> None:
pass
class BrokerContext(BaseContext):
"""BrokerContext is used as the context passed to plugins interacting with the broker.
"""Used to provide the server's context as well as public methods for accessing internal state."""
It act as an adapter to broker services from plugins developed for HBMQTT broker.
"""
def __init__(self, broker: "Broker") -> None:
super().__init__()
self.config: _CONFIG_LISTENER | None = None
self.config: BrokerConfig | None = None
self._broker_instance = broker
async def broadcast_message(self, topic: str, data: bytes, qos: int | None = None) -> None:
"""Send message to all client sessions subscribing to `topic`."""
await self._broker_instance.internal_message_broadcast(topic, data, qos)
def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
self._broker_instance.retain_message(None, topic_name, data, qos)
async def retain_message(self, topic_name: str, data: bytes | bytearray, qos: int | None = None) -> None:
await self._broker_instance.retain_message(None, topic_name, data, qos)

@@ -124,2 +135,6 @@ @property

def get_session(self, client_id: str) -> Session | None:
"""Return the session associated with `client_id`, if it exists."""
return self._broker_instance.sessions.get(client_id, (None, None))[0]
@property

@@ -133,3 +148,17 @@ def retained_messages(self) -> dict[str, RetainedApplicationMessage]:

async def add_subscription(self, client_id: str, topic: str | None, qos: int | None) -> None:
"""Create a topic subscription for the given `client_id`.
If a client session doesn't exist for `client_id`, create a disconnected session.
If `topic` and `qos` are both `None`, only create the client session.
"""
if client_id not in self._broker_instance.sessions:
broker_handler, session = self._broker_instance.create_offline_session(client_id)
self._broker_instance._sessions[client_id] = (session, broker_handler) # noqa: SLF001
if topic is not None and qos is not None:
session, _ = self._broker_instance.sessions[client_id]
await self._broker_instance.add_subscription((topic, qos), session)
class Broker:

@@ -139,3 +168,3 @@ """MQTT 3.1.1 compliant broker implementation.

Args:
config: dictionary of configuration options (see [broker configuration](broker_config.md)).
config: `BrokerConfig` or dictionary of equivalent structure options (see [broker configuration](broker_config.md)).
loop: asyncio loop. defaults to `asyncio.new_event_loop()`.

@@ -145,3 +174,5 @@ plugin_namespace: plugin namespace to use when loading plugin entry_points. defaults to `amqtt.broker.plugins`.

Raises:
BrokerError, ParserError, PluginError
BrokerError: problem with broker configuration
PluginImportError: if importing a plugin from configuration
PluginInitError: if initialization plugin fails

@@ -162,3 +193,3 @@ """

self,
config: _CONFIG_LISTENER | None = None,
config: BrokerConfig | dict[str, Any] | None = None,
loop: asyncio.AbstractEventLoop | None = None,

@@ -169,11 +200,11 @@ plugin_namespace: str | None = None,

self.logger = logging.getLogger(__name__)
self.config = copy.deepcopy(_defaults or {})
if config is not None:
# if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
if ("auth" in config or "topic-check" in config) and "plugins" not in config:
# set to None so that the config isn't updated with the new-style default plugin list
config["plugins"] = None # type: ignore[assignment]
self.config.update(config)
self._build_listeners_config(self.config)
if isinstance(config, dict):
self.config = BrokerConfig.from_dict(config)
else:
self.config = config or BrokerConfig()
# listeners are populated from default within BrokerConfig
self.listeners_config = self.config.listeners
self._loop = loop or asyncio.get_running_loop()

@@ -196,2 +227,5 @@ self._servers: dict[str, Server] = {}

# Task for session monitor
self._session_monitor_task: asyncio.Task[Any] | None = None
# Initialize plugins manager

@@ -204,22 +238,2 @@

def _build_listeners_config(self, broker_config: _CONFIG_LISTENER) -> None:
self.listeners_config = {}
try:
listeners_config = broker_config.get("listeners")
if not isinstance(listeners_config, dict):
msg = "Listener config not found or invalid"
raise BrokerError(msg)
defaults = listeners_config.get("default")
if defaults is None:
msg = "Listener config has not default included or is invalid"
raise BrokerError(msg)
for listener_name, listener_conf in listeners_config.items():
config = defaults.copy()
config.update(listener_conf)
self.listeners_config[listener_name] = config
except KeyError as ke:
msg = f"Listener config not found or invalid: {ke}"
raise BrokerError(msg) from ke
def _init_states(self) -> None:

@@ -261,2 +275,3 @@ self.transitions = Machine(states=Broker.states, initial="new")

self._broadcast_task = asyncio.ensure_future(self._broadcast_loop())
self._session_monitor_task = asyncio.create_task(self._session_monitor())
self.logger.debug("Broker started")

@@ -279,14 +294,23 @@ except Exception as e:

try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
# for listeners which are external, don't need to create a server
if listener.type == ListenerType.EXTERNAL:
instance = await self._create_server_instance(listener_name, listener["type"], address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
# broker still needs to associate a new connection to the listener
self.logger.info(f"External listener exists for '{listener_name}' ")
self._servers[listener_name] = ExternalServer()
else:
# for tcp and websockets, start servers to listen for inbound connections
try:
address, port = self._split_bindaddr_port(listener["bind"], DEFAULT_PORTS[listener["type"]])
except ValueError as e:
msg = f"Invalid port value in bind value: {listener['bind']}"
raise BrokerError(msg) from e
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
instance = await self._create_server_instance(listener_name, listener.type, address, port, ssl_context)
self._servers[listener_name] = Server(listener_name, instance, max_connections)
def _create_ssl_context(self, listener: dict[str, Any]) -> ssl.SSLContext:
self.logger.info(f"Listener '{listener_name}' bind to {listener['bind']} (max_connections={max_connections})")
@staticmethod
def _create_ssl_context(listener: ListenerConfig) -> ssl.SSLContext:
"""Create an SSL context for a listener."""

@@ -313,3 +337,3 @@ try:

listener_name: str,
listener_type: str,
listener_type: ListenerType,
address: str | None,

@@ -320,21 +344,52 @@ port: int,

"""Create a server instance for a listener."""
if listener_type == "tcp":
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
if listener_type == "ws":
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
match listener_type:
case ListenerType.TCP:
return await asyncio.start_server(
partial(self.stream_connected, listener_name=listener_name),
address,
port,
reuse_address=True,
ssl=ssl_context,
)
case ListenerType.WS:
return await websockets.serve(
partial(self.ws_connected, listener_name=listener_name),
address,
port,
ssl=ssl_context,
subprotocols=[websockets.Subprotocol("mqtt")],
)
case _:
msg = f"Unsupported listener type: {listener_type}"
raise BrokerError(msg)
async def _session_monitor(self) -> None:
self.logger.info("Starting session expiration monitor.")
while True:
session_count_before = len(self._sessions)
# clean or anonymous sessions don't retain messages (or subscriptions); the session can be filtered out
sessions_to_remove = [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and (session.is_anonymous or session.clean_session)]
# if session expiration is enabled, check to see if any of the sessions are disconnected and past expiration
if self.config.session_expiry_interval is not None:
retain_after = floor(time.time() - self.config.session_expiry_interval)
sessions_to_remove += [client_id for client_id, (session, _) in self._sessions.items()
if session.transitions.state == "disconnected" and
session.last_disconnect_time and
session.last_disconnect_time < retain_after]
for client_id in sessions_to_remove:
await self._cleanup_session(client_id)
if session_count_before > (session_count_after := len(self._sessions)):
self.logger.debug(f"Expired {session_count_before - session_count_after} sessions")
await asyncio.sleep(1)
async def shutdown(self) -> None:

@@ -357,2 +412,4 @@ """Stop broker instance."""

await self._shutdown_broadcast_loop()
if self._session_monitor_task:
self._session_monitor_task.cancel()

@@ -393,2 +450,6 @@ for server in self._servers.values():

async def external_connected(self, reader: ReaderAdapter, writer: WriterAdapter, listener_name: str) -> None:
"""Engage the broker in handling the data stream to/from an established connection."""
await self._client_connected(listener_name, reader, writer)
async def _client_connected(self, listener_name: str, reader: ReaderAdapter, writer: WriterAdapter) -> None:

@@ -465,3 +526,13 @@ """Handle a new client connection."""

self.logger.debug(f"Found old session {self._sessions[client_session.client_id]!r}")
client_session, _ = self._sessions[client_session.client_id]
# even though the session previously existed, the new connection can bring updated configuration and credentials
existing_client_session, _ = self._sessions[client_session.client_id]
existing_client_session.will_flag = client_session.will_flag
existing_client_session.will_message = client_session.will_message
existing_client_session.will_topic = client_session.will_topic
existing_client_session.will_qos = client_session.will_qos
existing_client_session.keep_alive = client_session.keep_alive
existing_client_session.username = client_session.username
existing_client_session.password = client_session.password
client_session = existing_client_session
client_session.parent = 1

@@ -478,2 +549,10 @@ else:

def create_offline_session(self, client_id: str) -> tuple[BrokerProtocolHandler, Session]:
session = Session()
session.client_id = client_id
bph = BrokerProtocolHandler(self.plugins_manager, session)
session.transitions.disconnect()
return bph, session
async def _handle_client_session(

@@ -530,6 +609,4 @@ self,

for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._publish_retained_messages_for_subscription((topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)

@@ -565,3 +642,2 @@

if subscribe_waiter in done:

@@ -621,3 +697,3 @@ await self._handle_subscription(client_session, handler, subscribe_waiter)

if client_session.will_retain:
self.retain_message(
await self.retain_message(
client_session,

@@ -637,3 +713,2 @@ client_session.will_topic,

async def _handle_subscription(

@@ -648,3 +723,3 @@ self,

subscriptions = subscribe_waiter.result()
return_codes = [await self._add_subscription(subscription, client_session) for subscription in subscriptions.topics]
return_codes = [await self.add_subscription(subscription, client_session) for subscription in subscriptions.topics]
await handler.mqtt_acknowledge_subscription(subscriptions.packet_id, return_codes)

@@ -689,2 +764,9 @@ for index, subscription in enumerate(subscriptions.topics):

# notify of a message's receipt, even if a client isn't necessarily allowed to send it
await self.plugins_manager.fire_event(
BrokerEvents.MESSAGE_RECEIVED,
client_id=client_session.client_id,
message=app_message,
)
if app_message is None:

@@ -711,6 +793,7 @@ self.logger.debug("app_message was empty!")

if not permitted:
self.logger.info(f"{client_session.client_id} forbidden TOPIC {app_message.topic} sent in PUBLISH message.")
self.logger.info(f"{client_session.client_id} not allowed to publish to TOPIC {app_message.topic}.")
else:
# notify that a received message is valid and is allowed to be distributed to other clients
await self.plugins_manager.fire_event(
BrokerEvents.MESSAGE_RECEIVED,
BrokerEvents.MESSAGE_BROADCAST,
client_id=client_session.client_id,

@@ -721,3 +804,3 @@ message=app_message,

if app_message.publish_packet and app_message.publish_packet.retain_flag:
self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
await self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
return True

@@ -735,6 +818,7 @@

await handler.stop()
except Exception:
# a failure in stopping a handler shouldn't cause the broker to fail
except asyncio.QueueEmpty:
self.logger.exception("Failed to stop handler")
async def _authenticate(self, session: Session, _: dict[str, Any]) -> bool:
async def _authenticate(self, session: Session, _: ListenerConfig) -> bool:
"""Call the authenticate method on registered plugins to test user authentication.

@@ -752,3 +836,3 @@

results = [ result for _, result in returns.items() if result is not None] if returns else []
results = [result for _, result in returns.items() if result is not None] if returns else []
if len(results) < 1:

@@ -767,3 +851,3 @@ self.logger.debug("Authentication failed: no plugin responded with a boolean")

def retain_message(
async def retain_message(
self,

@@ -779,8 +863,21 @@ source_session: Session | None,

self._retained_messages[topic_name] = RetainedApplicationMessage(source_session, topic_name, data, qos)
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=self._retained_messages[topic_name])
# [MQTT-3.3.1-10]
elif topic_name in self._retained_messages:
self.logger.debug(f"Clearing retained messages for topic '{topic_name}'")
cleared_message = self._retained_messages[topic_name]
cleared_message.data = b""
await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=None,
retained_message=cleared_message)
del self._retained_messages[topic_name]
async def _add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
async def add_subscription(self, subscription: tuple[str, int], session: Session) -> int:
topic_filter, qos = subscription

@@ -889,3 +986,4 @@ if "#" in topic_filter and not topic_filter.endswith("#"):

self.logger.info(f"Task has been cancelled: {task}")
except Exception:
# if a task fails, don't want it to cause the broker to fail
except Exception: # pylint: disable=W0718
self.logger.exception(f"Task failed and will be skipped: {task}")

@@ -929,2 +1027,8 @@

sendable = await self._topic_filtering(target_session, topic=broadcast["topic"], action=Action.RECEIVE)
if not sendable:
self.logger.info(
f"{target_session.client_id} not allowed to receive messages from TOPIC {broadcast['topic']}.")
continue
# Retain all messages which cannot be broadcasted, due to the session not being connected

@@ -972,2 +1076,6 @@ # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]

await self.plugins_manager.fire_event(BrokerEvents.RETAINED_MESSAGE,
client_id=target_session.client_id,
retained_message=retained_message)
if self.logger.isEnabledFor(logging.DEBUG):

@@ -974,0 +1082,0 @@ self.logger.debug(f"target_session.retained_messages={target_session.retained_messages.qsize()}")

@@ -5,6 +5,4 @@ import asyncio

import contextlib
import copy
from functools import wraps
import logging
from pathlib import Path
import ssl

@@ -23,3 +21,3 @@ from typing import TYPE_CHECKING, Any, TypeAlias, cast

)
from amqtt.contexts import BaseContext
from amqtt.contexts import BaseContext, ClientConfig
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError

@@ -31,3 +29,3 @@ from amqtt.mqtt.connack import CONNECTION_ACCEPTED

from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session
from amqtt.utils import gen_client_id, read_yaml_config
from amqtt.utils import gen_client_id

@@ -37,5 +35,3 @@ if TYPE_CHECKING:

_defaults: dict[str, Any] | None = read_yaml_config(Path(__file__).parent / "scripts/default_client.yaml")
class ClientContext(BaseContext):

@@ -49,3 +45,3 @@ """ClientContext is used as the context passed to plugins interacting with the client.

super().__init__()
self.config = None
self.config: ClientConfig | None = None

@@ -87,22 +83,23 @@

class MQTTClient:
"""MQTT client implementation.
"""MQTT client implementation, providing an API for connecting to a broker and send/receive messages using the MQTT protocol.
MQTTClient instances provides API for connecting to a broker and send/receive
messages using the MQTT protocol.
Args:
client_id: MQTT client ID to use when connecting to the broker. If none,
it will be generated randomly by `amqtt.utils.gen_client_id`
config: dictionary of configuration options (see [client configuration](client_config.md)).
config: `ClientConfig` or dictionary of equivalent structure options (see [client configuration](client_config.md)).
Raises:
PluginError
PluginImportError: if importing a plugin from configuration fails
PluginInitError: if initialization plugin fails
"""
def __init__(self, client_id: str | None = None, config: dict[str, Any] | None = None) -> None:
def __init__(self, client_id: str | None = None, config: ClientConfig | dict[str, Any] | None = None) -> None:
self.logger = logging.getLogger(__name__)
self.config = copy.deepcopy(_defaults or {})
if config is not None:
self.config.update(config)
if isinstance(config, dict):
self.config = ClientConfig.from_dict(config)
else:
self.config = config or ClientConfig()
self.client_id = client_id if client_id is not None else gen_client_id()

@@ -155,3 +152,3 @@

Raises:
ClientError, ConnectError
ConnectError: could not connect to broker

@@ -169,3 +166,4 @@ """

raise ConnectError(msg) from e
except Exception as e:
# no matter the failure mode, still try to reconnect
except Exception as e: # pylint: disable=W0718
self.logger.warning(f"Connection failed: {e!r}")

@@ -244,3 +242,4 @@ if not self.config.get("auto_reconnect", False):

raise ConnectError(msg) from e
except Exception as e:
# no matter the failure mode, still try to reconnect
except Exception as e: # pylint: disable=W0718
self.logger.warning(f"Reconnection attempt failed: {e!r}")

@@ -393,2 +392,3 @@ self.logger.debug("", exc_info=True)

asyncio.TimeoutError: if timeout occurs before a message is delivered
ClientError: if client is not connected

@@ -437,10 +437,6 @@ """

self.session.username = (
self.session.username
if self.session.username
else (str(uri_attributes.username) if uri_attributes.username else None)
self.session.username or (str(uri_attributes.username) if uri_attributes.username else None)
)
self.session.password = (
self.session.password
if self.session.password
else (str(uri_attributes.password) if uri_attributes.password else None)
self.session.password or (str(uri_attributes.password) if uri_attributes.password else None)
)

@@ -476,11 +472,11 @@ self.session.remote_address = str(uri_attributes.hostname) if uri_attributes.hostname else None

ssl.Purpose.SERVER_AUTH,
cafile=self.session.cafile,
capath=self.session.capath,
cadata=self.session.cadata,
cafile=self.session.cafile
)
if "certfile" in self.config:
sc.load_verify_locations(cafile=self.config["certfile"])
if "check_hostname" in self.config and isinstance(self.config["check_hostname"], bool):
sc.check_hostname = self.config["check_hostname"]
if self.config.connection.certfile and self.config.connection.keyfile:
sc.load_cert_chain(certfile=self.config.connection.certfile, keyfile=self.config.connection.keyfile)
if self.config.connection.cafile:
sc.load_verify_locations(cafile=self.config.connection.cafile)
if self.config.check_hostname is not None:
sc.check_hostname = self.config.check_hostname
sc.verify_mode = ssl.CERT_REQUIRED

@@ -540,3 +536,3 @@ kwargs["ssl"] = sc

except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError) as e:
except (InvalidURI, InvalidHandshake, ProtocolHandlerError, ConnectionError, OSError, asyncio.TimeoutError) as e:
self.logger.debug(f"Connection failed : {self.session.broker_uri} [{e!r}]")

@@ -597,7 +593,15 @@ self.session.transitions.disconnect()

"""Initialize the MQTT session."""
broker_conf = self.config.get("broker", {}).copy()
broker_conf.update(
{k: v for k, v in {"uri": uri, "cafile": cafile, "capath": capath, "cadata": cadata}.items() if v is not None},
)
broker_conf = self.config.get("connection", {}).copy()
if uri is not None:
broker_conf.uri = uri
if cleansession is not None:
self.config.cleansession = cleansession
if cafile is not None:
broker_conf.cafile = cafile
if capath is not None:
broker_conf.capath = capath
if cadata is not None:
broker_conf.cadata = cadata
if not broker_conf.get("uri"):

@@ -610,2 +614,3 @@ msg = "Missing connection parameter 'uri'"

session.client_id = self.client_id
session.cafile = broker_conf.get("cafile")

@@ -615,7 +620,3 @@ session.capath = broker_conf.get("capath")

if cleansession is not None:
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession
else:
session.clean_session = self.config.get("cleansession", True)
session.clean_session = self.config.get("cleansession", True)

@@ -622,0 +623,0 @@ session.keep_alive = self.config["keep_alive"] - self.config["ping_delay"]

@@ -145,6 +145,6 @@ import asyncio

def float_to_bytes_str(value: float, places:int=3) -> bytes:
def float_to_bytes_str(value: float, places: int = 3) -> bytes:
"""Convert an float value to a bytes array containing the numeric character."""
quant = Decimal(f"0.{''.join(['0' for i in range(places-1)])}1")
quant = Decimal(f"0.{''.join(['0' for i in range(places - 1)])}1")
rounded = Decimal(value).quantize(quant, rounding=ROUND_HALF_UP)
return str(rounded).encode("utf-8")

@@ -1,19 +0,36 @@

from enum import Enum
from dataclasses import dataclass, field, fields, replace
import logging
from typing import TYPE_CHECKING, Any
import warnings
_LOGGER = logging.getLogger(__name__)
try:
from enum import Enum, StrEnum
except ImportError:
# support for python 3.10
from enum import Enum
class StrEnum(str, Enum): # type: ignore[no-redef]
pass
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from dacite import Config as DaciteConfig, from_dict as dict_to_dataclass
from amqtt.mqtt.constants import QOS_0, QOS_2
if TYPE_CHECKING:
import asyncio
logger = logging.getLogger(__name__)
class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
self.logger: logging.Logger = logging.getLogger(__name__)
# cleanup with a `Generic` type
self.config: ClientConfig | BrokerConfig | dict[str, Any] | None = None
class Action(Enum):
class Action(StrEnum):
"""Actions issued by the broker."""

@@ -23,1 +40,341 @@

PUBLISH = "publish"
RECEIVE = "receive"
class ListenerType(StrEnum):
"""Types of mqtt listeners."""
TCP = "tcp"
WS = "ws"
EXTERNAL = "external"
def __repr__(self) -> str:
"""Display the string value, instead of the enum member."""
return f'"{self.value!s}"'
class Dictable:
"""Add dictionary methods to a dataclass."""
def __getitem__(self, key: str) -> Any:
"""Allow dict-style `[]` access to a dataclass."""
return self.get(key)
def get(self, name: str, default: Any = None) -> Any:
"""Allow dict-style access to a dataclass."""
name = name.replace("-", "_")
if hasattr(self, name):
return getattr(self, name)
if default is not None:
return default
msg = f"'{name}' is not defined"
raise ValueError(msg)
def __contains__(self, name: str) -> bool:
"""Provide dict-style 'in' check."""
return getattr(self, name.replace("-", "_"), None) is not None
def __iter__(self) -> Iterator[Any]:
"""Provide dict-style iteration."""
for f in fields(self): # type: ignore[arg-type]
yield getattr(self, f.name)
def copy(self) -> dataclass: # type: ignore[valid-type]
"""Return a copy of the dataclass."""
return replace(self) # type: ignore[type-var]
@staticmethod
def _coerce_lists(value: list[Any] | dict[str, Any] | Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return value # It's already a list of dicts
if isinstance(value, dict):
return [value] # Promote single dict to a list
msg = "Could not convert 'list' to 'list[dict[str, Any]]'"
raise ValueError(msg)
@dataclass
class ListenerConfig(Dictable):
"""Structured configuration for a broker's listeners."""
type: ListenerType = ListenerType.TCP
"""Type of listener: `tcp` for 'mqtt' or `ws` for 'websocket' when specified in dictionary or yaml.'"""
bind: str | None = "0.0.0.0:1883"
"""address and port for the listener to bind to"""
max_connections: int = 0
"""max number of connections allowed for this listener"""
ssl: bool = False
"""secured by ssl"""
cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format. See
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
capath: str | Path | None = None
"""Path to a directory containing one or more CA certificates in PEM format, following the
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
cadata: str | Path | None = None
"""Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates."""
certfile: str | Path | None = None
"""Full path to file in PEM format containing the server's certificate (as well as any number of CA
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the server's private key."""
reader: str | None = None
writer: str | None = None
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
if getattr(self, fn) and not getattr(self, fn).exists():
msg = f"'{fn}' does not exist : {getattr(self, fn)}"
raise FileNotFoundError(msg)
def apply(self, other: "ListenerConfig") -> None:
"""Apply the field from 'other', if 'self' field is default."""
for f in fields(self):
if getattr(self, f.name) == f.default:
setattr(self, f.name, other[f.name])
def default_listeners() -> dict[str, Any]:
"""Create defaults for BrokerConfig.listeners."""
return {
"default": ListenerConfig()
}
def default_broker_plugins() -> dict[str, Any]:
"""Create defaults for BrokerConfig.plugins."""
return {
"amqtt.plugins.logging_amqtt.EventLoggerPlugin": {},
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {},
"amqtt.plugins.authentication.AnonymousAuthPlugin": {"allow_anonymous": True},
"amqtt.plugins.sys.broker.BrokerSysPlugin": {"sys_interval": 20}
}
@dataclass
class BrokerConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
listeners: dict[Literal["default"] | str, ListenerConfig] = field(default_factory=default_listeners) # noqa: PYI051
"""Network of listeners used by the services. a 'default' named listener is required; if another listener
does not set a value, the 'default' settings are applied. See
[`ListenerConfig`](broker_config.md#amqtt.contexts.ListenerConfig) for more information."""
sys_interval: int | None = None
"""*Deprecated field to configure the `BrokerSysPlugin`. See [`BrokerSysPlugin`](../plugins/packaged_plugins.md#sys-topics)
for recommended configuration.*"""
timeout_disconnect_delay: int | None = 0
"""Client disconnect timeout without a keep-alive."""
session_expiry_interval: int | None = None
"""Seconds for an inactive session to be retained."""
auth: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`AnonymousAuthPlugin`](../plugins/packaged_plugins.md#anonymous-auth-plugin) and
[`FileAuthPlugin`](../plugins/packaged_plugins.md#password-file-auth-plugin) for recommended configuration.*"""
topic_check: dict[str, Any] | None = None
"""*Deprecated field used to config EntryPoint-loaded plugins. See
[`TopicTabooPlugin`](../plugins/packaged_plugins.md#taboo-topic-plugin) and
[`TopicACLPlugin`](../plugins/packaged_plugins.md#acl-topic-plugin) for recommended configuration method.*"""
plugins: dict[str, Any] | list[str | dict[str, Any]] | None = field(default_factory=default_broker_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`, `BaseAuthPlugin`
or `BaseTopicPlugin`; the value is a dictionary of configuration options for that plugin. See
[custom plugins](../plugins/custom_plugins.md) for more information. `list[str | dict[str,Any]]` is deprecated but available
to support legacy use cases."""
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.sys_interval is not None:
logger.warning("sys_interval is deprecated, use 'plugins' to define configuration")
if self.auth is not None or self.topic_check is not None:
logger.warning("'auth' and 'topic-check' are deprecated, use 'plugins' to define configuration")
default_listener = self.listeners["default"]
for listener_name, listener in self.listeners.items():
if listener_name == "default":
continue
listener.apply(default_listener)
if isinstance(self.plugins, list):
_plugins: dict[str, Any] = {}
for plugin in self.plugins:
# in case a plugin in a yaml file is listed without config map
if isinstance(plugin, str):
_plugins |= {plugin: {}}
continue
_plugins |= plugin
self.plugins = _plugins
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> "BrokerConfig":
"""Create a broker config from a dictionary."""
if d is None:
return BrokerConfig()
# patch the incoming dictionary so it can be loaded correctly
if "topic-check" in d:
d["topic_check"] = d["topic-check"]
del d["topic-check"]
# identify EntryPoint plugin loading and prevent 'plugins' from getting defaults
if ("auth" in d or "topic-check" in d) and "plugins" not in d:
d["plugins"] = None
return dict_to_dataclass(data_class=BrokerConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum, ListenerType],
strict=True,
type_hooks={list[dict[str, Any]]: cls._coerce_lists}
))
@dataclass
class ConnectionConfig(Dictable):
"""Properties for connecting to the broker."""
uri: str | None = "mqtt://127.0.0.1:1883"
"""URI of the broker"""
cafile: str | Path | None = None
"""Path to a file of concatenated CA certificates in PEM format to verify broker's authenticity. See
[Certificates](https://docs.python.org/3/library/ssl.html#ssl-certificates) for more info."""
capath: str | Path | None = None
"""Path to a directory containing one or more CA certificates in PEM format, following the
[OpenSSL-specific layout](https://docs.openssl.org/master/man3/SSL_CTX_load_verify_locations/)."""
cadata: str | None = None
"""The certificate to verify the broker's authenticity in an ASCII string format of one or more PEM-encoded
certificates or a bytes-like object of DER-encoded certificates."""
certfile: str | Path | None = None
"""Full path to file in PEM format containing the client's certificate (as well as any number of CA
certificates needed to establish the certificate's authenticity.)"""
keyfile: str | Path | None = None
"""Full path to file in PEM format containing the client's private key associated with the certfile."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if (self.certfile is None) ^ (self.keyfile is None):
msg = "If specifying the 'certfile' or 'keyfile', both are required."
raise ValueError(msg)
for fn in ("cafile", "capath", "certfile", "keyfile"):
if isinstance(getattr(self, fn), str):
setattr(self, fn, Path(getattr(self, fn)))
@dataclass
class TopicConfig(Dictable):
"""Configuration of how messages to specific topics are published.
The topic name is specified as the key in the dictionary of the `ClientConfig.topics.
"""
qos: int = 0
"""The quality of service associated with the publishing to this topic."""
retain: bool = False
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Topic config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
@dataclass
class WillConfig(Dictable):
"""Configuration of the 'last will & testament' of the client upon improper disconnection."""
topic: str
"""The will message will be published to this topic."""
message: str
"""The contents of the message to be published."""
qos: int | None = QOS_0
"""The quality of service associated with sending this message."""
retain: bool | None = False
"""Determines if the message should be retained by the topic it was published."""
def __post__init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.qos is not None and (self.qos < QOS_0 or self.qos > QOS_2):
msg = "Will config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
def default_client_plugins() -> dict[str, Any]:
"""Create defaults for `ClientConfig.plugins`."""
return {
"amqtt.plugins.logging_amqtt.PacketLoggerPlugin": {}
}
@dataclass
class ClientConfig(Dictable):
"""Structured configuration for a broker. Can be passed directly to `amqtt.broker.Broker` or created from a dictionary."""
keep_alive: int | None = 10
"""Keep-alive timeout sent to the broker."""
ping_delay: int | None = 1
"""Auto-ping delay before keep-alive timeout. Setting to 0 will disable which may lead to broker disconnection."""
default_qos: int | None = QOS_0
"""Default QoS for messages published."""
default_retain: bool | None = False
"""Default retain value to messages published."""
auto_reconnect: bool | None = True
"""Enable or disable auto-reconnect if connection with the broker is interrupted."""
connection_timeout: int | None = 60
"""The number of seconds before a connection times out"""
reconnect_retries: int | None = 2
"""Number of reconnection retry attempts. Negative value will cause client to reconnect indefinitely."""
reconnect_max_interval: int | None = 10
"""Maximum seconds to wait before retrying a connection."""
cleansession: bool | None = True
"""Upon reconnect, should subscriptions be cleared. Can be overridden by `MQTTClient.connect`"""
topics: dict[str, TopicConfig] | None = field(default_factory=dict)
"""Specify the topics and what flags should be set for messages published to them."""
broker: ConnectionConfig | None = None
"""*Deprecated* Configuration for connecting to the broker. Use `connection` field instead."""
connection: ConnectionConfig = field(default_factory=ConnectionConfig)
"""Configuration for connecting to the broker. See
[`ConnectionConfig`](client_config.md#amqtt.contexts.ConnectionConfig) for more information."""
plugins: dict[str, Any] | list[dict[str, Any]] | None = field(default_factory=default_client_plugins)
"""The dictionary has a key of the dotted-module path of a class derived from `BasePlugin`; the value is
a dictionary of configuration options for that plugin. See [custom plugins](../plugins/custom_plugins.md) for
more information. `list[str | dict[str,Any]]` is deprecated but available to support legacy use cases."""
check_hostname: bool | None = True
"""If establishing a secure connection, should the hostname of the certificate be verified."""
will: WillConfig | None = None
"""Message, topic and flags that should be sent to if the client disconnects. See
[`WillConfig`](client_config.md#amqtt.contexts.WillConfig) for more information."""
def __post_init__(self) -> None:
"""Check config for errors and transform fields for easier use."""
if self.default_qos is not None and (self.default_qos < QOS_0 or self.default_qos > QOS_2):
msg = "Client config: default QoS must be 0, 1 or 2."
raise ValueError(msg)
if self.broker is not None:
warnings.warn("The 'broker' option is deprecated, please use 'connection' instead.", stacklevel=2)
self.connection = self.broker
if bool(not self.connection.keyfile) ^ bool(not self.connection.certfile):
msg = "Connection key and certificate files are _both_ required."
raise ValueError(msg)
@classmethod
def from_dict(cls, d: dict[str, Any] | None) -> "ClientConfig":
"""Create a client config from a dictionary."""
if d is None:
return ClientConfig()
return dict_to_dataclass(data_class=ClientConfig,
data=d,
config=DaciteConfig(
cast=[StrEnum],
strict=True)
)

@@ -19,2 +19,3 @@ from typing import Any

class ZeroLengthReadError(NoDataError):

@@ -24,2 +25,3 @@ def __init__(self) -> None:

class BrokerError(Exception):

@@ -26,0 +28,0 @@ """Exceptions thrown by broker."""

@@ -6,3 +6,3 @@ try:

from enum import Enum
class StrEnum(str, Enum): #type: ignore[no-redef]
class StrEnum(str, Enum): # type: ignore[no-redef]
pass

@@ -35,2 +35,4 @@

CLIENT_UNSUBSCRIBED = "broker_client_unsubscribed"
RETAINED_MESSAGE = "broker_retained_message"
MESSAGE_RECEIVED = "broker_message_received"
MESSAGE_BROADCAST = "broker_message_broadcast"

@@ -170,3 +170,3 @@ from abc import ABC, abstractmethod

class MQTTPayload(Generic[_VH], ABC):
class MQTTPayload(ABC, Generic[_VH]):
"""Abstract base class for MQTT payloads."""

@@ -173,0 +173,0 @@

@@ -34,2 +34,3 @@ import asyncio

class Subscription:

@@ -239,2 +240,3 @@ def __init__(self, packet_id: int, topics: list[tuple[str, int]]) -> None:

incoming_session.remote_port = remote_port
incoming_session.ssl_object = writer.get_ssl_info()

@@ -241,0 +243,0 @@ incoming_session.keep_alive = max(connect.keep_alive, 0)

@@ -22,2 +22,3 @@ import asyncio

class ClientProtocolHandler(ProtocolHandler["ClientContext"]):

@@ -24,0 +25,0 @@ def __init__(

@@ -7,9 +7,9 @@ import asyncio

# Fallback for Python < 3.12
class InvalidStateError(Exception): # type: ignore[no-redef]
class InvalidStateError(Exception): # type: ignore[no-redef]
pass
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
class QueueFull(Exception): # type: ignore[no-redef] # noqa : N818
pass
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
class QueueShutDown(Exception): # type: ignore[no-redef] # noqa : N818
pass

@@ -67,2 +67,3 @@

class ProtocolHandler(Generic[C]):

@@ -204,3 +205,3 @@ """Class implementing the MQTT communication protocol using asyncio features."""

topic: str,
data: bytes | bytearray ,
data: bytes | bytearray,
qos: int | None,

@@ -541,3 +542,3 @@ retain: bool,

self.logger.debug(f"{self.session.client_id} No data available")
except Exception as e: # noqa: BLE001
except Exception as e: # noqa: BLE001, pylint: disable=W0718
self.logger.warning(f"{type(self).__name__} Unhandled exception in reader coro: {e!r}")

@@ -544,0 +545,0 @@ break

"""INIT."""
import re
from typing import Any, Optional
class TopicMatcher:
_instance: Optional["TopicMatcher"] = None
def __init__(self) -> None:
if not hasattr(self, "_topic_filter_matchers"):
self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]) -> "TopicMatcher":
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def is_topic_allowed(self, topic: str, a_filter: str) -> bool:
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
return False
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex (re.compile is an expensive operation, store the matcher for future use)
if a_filter not in self._topic_filter_matchers:
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
.replace("\\#", "?.*")
.replace("\\+", "[^/]*")
.lstrip("?"))
match_pattern = self._topic_filter_matchers[a_filter]
return bool(match_pattern.fullmatch(topic))
def are_topics_allowed(self, topic: str, many_filters: list[str]) -> bool:
return any(self.is_topic_allowed(topic, a_filter) for a_filter in many_filters)

@@ -40,5 +40,6 @@ from dataclasses import dataclass, field

class Config:
"""Allow empty username."""
"""Configuration for AnonymousAuthPlugin."""
allow_anonymous: bool = field(default=True)
"""Allow all anonymous authentication (even with _no_ username)."""

@@ -82,3 +83,3 @@

self.context.logger.exception(f"Malformed password file '{password_file}'")
except Exception:
except OSError:
self.context.logger.exception(f"Unexpected error reading password file '{password_file}'")

@@ -112,4 +113,5 @@

class Config:
"""Path to the properly encoded password file."""
"""Configuration for FileAuthPlugin."""
password_file: str | Path | None = None
"""Path to file with `username:password` pairs, one per line. All passwords are encoded using sha-512."""
from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar, cast
from amqtt.contexts import Action, BaseContext
from amqtt.contexts import Action, BaseContext, BrokerConfig
from amqtt.session import Session

@@ -49,3 +49,3 @@

# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:

@@ -56,3 +56,3 @@ return default

# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if option_name in self.context.config:

@@ -80,9 +80,10 @@ return self.context.config[option_name]

def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config with either BrokerConfig or plugin's Config
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.topic_config and option_name in self.topic_config:

@@ -94,3 +95,3 @@ return self.topic_config[option_name]

self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
"""Logic for filtering out topics.

@@ -104,3 +105,3 @@

Returns:
bool: `True` if topic is allowed, `False` otherwise
bool: `True` if topic is allowed, `False` otherwise. `None` if it can't be determined

@@ -114,9 +115,9 @@ """

def _get_config_option(self, option_name: str, default: Any=None) -> Any:
def _get_config_option(self, option_name: str, default: Any = None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
if is_dataclass(self.context.config) and not isinstance(self.context.config, BrokerConfig):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
return getattr(self.context.config, option_name.replace("-", "_"), default)
if self.auth_config and option_name in self.auth_config:

@@ -134,3 +135,2 @@ return self.auth_config[option_name]

async def authenticate(self, *, session: Session) -> bool | None:

@@ -137,0 +137,0 @@ """Logic for session authentication.

@@ -11,2 +11,4 @@ __all__ = ["PluginManager", "get_plugin_manager"]

import logging
import sys
import traceback
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast

@@ -53,2 +55,3 @@ import warnings

class PluginManager(Generic[C]):

@@ -100,6 +103,5 @@ """Wraps contextlib Entry point mechanism to provide a basic plugin system.

if "auth" in self.app_context.config:
if "auth" in self.app_context.config and self.app_context.config["auth"] is not None:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config:
if "topic-check" in self.app_context.config and self.app_context.config["topic-check"] is not None:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")

@@ -136,3 +138,3 @@

DeprecationWarning,
stacklevel=2
stacklevel=4
)

@@ -152,3 +154,3 @@

def _load_ep_plugins(self, namespace:str) -> None:
def _load_ep_plugins(self, namespace: str) -> None:
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""

@@ -230,3 +232,3 @@ self.logger.debug(f"Loading plugins for namespace {namespace}")

try:
plugin_class: Any = import_string(plugin_path)
plugin_class: Any = import_string(plugin_path)
except ImportError as ep:

@@ -300,2 +302,11 @@ msg = f"Plugin import failed: {plugin_path}"

def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
if self.logger.getEffectiveLevel() <= logging.DEBUG:
try:
future.result()
except asyncio.CancelledError:
self.logger.warning("fired event was cancelled")
# display plugin fault; don't allow it to cause a broker failure
except Exception as exc: # noqa: BLE001, pylint: disable=W0718
traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
with contextlib.suppress(KeyError, ValueError):

@@ -376,3 +387,3 @@ self._fired_events.remove(future)

return await self._map_plugin_method(
self._auth_plugins, "authenticate", {"session": session }) # type: ignore[arg-type]
self._auth_plugins, "authenticate", {"session": session}) # type: ignore[arg-type]

@@ -379,0 +390,0 @@ async def map_plugin_topic(

@@ -1,85 +0,11 @@

import json
import sqlite3
from typing import Any
import warnings
from amqtt.contexts import BaseContext
from amqtt.session import Session
from amqtt.broker import BrokerContext
from amqtt.plugins.base import BasePlugin
class SQLitePlugin:
def __init__(self, context: BaseContext) -> None:
self.context: BaseContext = context
self.conn: sqlite3.Connection | None = None
self.cursor: sqlite3.Cursor | None = None
self.db_file: str | None = None
self.persistence_config: dict[str, Any]
class SQLitePlugin(BasePlugin[BrokerContext]):
if (
persistence_config := self.context.config.get("persistence") if self.context.config is not None else None
) is not None:
self.persistence_config = persistence_config
self.init_db()
else:
self.context.logger.warning("'persistence' section not found in context configuration")
def init_db(self) -> None:
self.db_file = self.persistence_config.get("file")
if not self.db_file:
self.context.logger.warning("'file' persistence parameter not found")
else:
try:
self.conn = sqlite3.connect(self.db_file)
self.cursor = self.conn.cursor()
self.context.logger.info(f"Database file '{self.db_file}' opened")
except Exception:
self.context.logger.exception(f"Error while initializing database '{self.db_file}'")
if self.cursor:
self.cursor.execute(
"CREATE TABLE IF NOT EXISTS session(client_id TEXT PRIMARY KEY, data BLOB)",
)
self.cursor.execute("PRAGMA table_info(session)")
columns = {col[1] for col in self.cursor.fetchall()}
required_columns = {"client_id", "data"}
if not required_columns.issubset(columns):
self.context.logger.error("Database schema for 'session' table is incompatible.")
async def save_session(self, session: Session) -> None:
if self.cursor and self.conn:
dump: str = json.dumps(session, default=str)
try:
self.cursor.execute(
"INSERT OR REPLACE INTO session (client_id, data) VALUES (?, ?)",
(session.client_id, dump),
)
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed saving session '{session}'")
async def find_session(self, client_id: str) -> Session | None:
if self.cursor:
row = self.cursor.execute(
"SELECT data FROM session where client_id=?",
(client_id,),
).fetchone()
return json.loads(row[0]) if row else None
return None
async def del_session(self, client_id: str) -> None:
if self.cursor and self.conn:
try:
exists = self.cursor.execute("SELECT 1 FROM session WHERE client_id=?", (client_id,)).fetchone()
if exists:
self.cursor.execute("DELETE FROM session where client_id=?", (client_id,))
self.conn.commit()
except Exception:
self.context.logger.exception(f"Failed deleting session with client_id '{client_id}'")
async def on_broker_post_shutdown(self) -> None:
if self.conn:
try:
self.conn.close()
self.context.logger.info(f"Database file '{self.db_file}' closed")
except Exception:
self.context.logger.exception("Error closing database connection")
finally:
self.conn = None
def __init__(self, context: BrokerContext) -> None:
super().__init__(context)
warnings.warn("SQLitePlugin is deprecated, use amqtt.contrib.persistence.SessionDBPlugin", stacklevel=1)

@@ -17,3 +17,3 @@ import asyncio

@runtime_checkable
class Buffer(Protocol): # type: ignore[no-redef]
class Buffer(Protocol): # type: ignore[no-redef]
def __buffer__(self, flags: int = ...) -> memoryview:

@@ -79,3 +79,2 @@ """Mimic the behavior of `collections.abc.Buffer` for python 3.10-3.12."""

def _clear_stats(self) -> None:

@@ -117,3 +116,3 @@ """Initialize broker statistics data structures."""

version = f"aMQTT version {amqtt.__version__}"
self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())
await self.context.retain_message(DOLLAR_SYS_ROOT + "version", version.encode())

@@ -120,0 +119,0 @@ # Start $SYS topics management

from dataclasses import dataclass, field
from typing import Any
import warnings
from amqtt.contexts import Action, BaseContext
from amqtt.errors import PluginInitError
from amqtt.plugins.base import BaseTopicPlugin

@@ -16,3 +18,3 @@ from amqtt.session import Session

self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)

@@ -23,3 +25,3 @@ if filter_result:

return not (topic and topic in self._taboo)
return filter_result
return bool(filter_result)

@@ -29,2 +31,12 @@

def __init__(self, context: BaseContext) -> None:
super().__init__(context)
if self._get_config_option("acl", None):
warnings.warn("The 'acl' option is deprecated, please use 'subscribe-acl' instead.", stacklevel=1)
if self._get_config_option("acl", None) and self._get_config_option("subscribe-acl", None):
msg = "'acl' has been replaced with 'subscribe-acl'; only one may be included"
raise PluginInitError(msg)
@staticmethod

@@ -52,3 +64,3 @@ def topic_ac(topic_requested: str, topic_allowed: str) -> bool:

self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None
) -> bool:
) -> bool | None:
filter_result = await super().topic_filtering(session=session, topic=topic, action=action)

@@ -65,3 +77,3 @@ if not filter_result:

if not req_topic:
return False\
return False

@@ -72,9 +84,17 @@ username = session.username if session else None

acl: dict[str, Any] = {}
acl: dict[str, Any] | None = None
match action:
case Action.PUBLISH:
acl = self._get_config_option("publish-acl", {})
acl = self._get_config_option("publish-acl", None)
case Action.SUBSCRIBE:
acl = self._get_config_option("acl", {})
acl = self._get_config_option("subscribe-acl", self._get_config_option("acl", None))
case Action.RECEIVE:
acl = self._get_config_option("receive-acl", None)
case _:
msg = "Received an invalid action type."
raise ValueError(msg)
if acl is None:
return True
allowed_topics = acl.get(username, [])

@@ -81,0 +101,0 @@ if not allowed_topics:

@@ -24,3 +24,3 @@ import asyncio

def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:

@@ -45,5 +45,8 @@ typer.echo(f"{amqtt_version}")

formatter = "[%(asctime)s] :: %(levelname)s - %(message)s"
if debug:
formatter = "[%(asctime)s] %(name)s:%(lineno)d :: %(levelname)s - %(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(level=level, format=formatter)
logging.getLogger("transitions").setLevel(logging.WARNING)
try:

@@ -67,3 +70,3 @@ if config_file:

_ = loop.create_task(broker.start()) #noqa : RUF006
_ = loop.create_task(broker.start()) # noqa : RUF006
try:

@@ -70,0 +73,0 @@ loop.run_forever()

@@ -10,5 +10,5 @@ ---

reconnect_retries: 2
broker:
connection:
uri: "mqtt://127.0.0.1"
plugins:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:

@@ -55,3 +55,3 @@ import asyncio

yield line.encode(encoding="utf-8")
except Exception:
except (FileNotFoundError, OSError):
logger.exception(f"Failed to read file '{self.file}'")

@@ -122,2 +122,3 @@ if self.lines:

app = typer.Typer(add_completion=False, rich_markup_mode=None)

@@ -136,4 +137,5 @@

@app.command()
def publisher_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def publisher_main( # pylint: disable=R0914,R0917
url: str | None = typer.Option(None, "--url", help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*"),

@@ -161,3 +163,3 @@ config_file: str | None = typer.Option(None, "-c", "--config-file", help="Client configuration file"),

debug: bool = typer.Option(False, "-d", help="Enable debug messages"),
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
version: bool = typer.Option(False, "--version", callback=_version, is_eager=True, help="Show version and exit"), # noqa : ARG001
) -> None:

@@ -164,0 +166,0 @@ """Command-line MQTT client for publishing simple messages."""

@@ -103,3 +103,3 @@ import asyncio

def _version(v:bool) -> None:
def _version(v: bool) -> None:
if v:

@@ -111,6 +111,6 @@ typer.echo(f"{amqtt_version}")

@app.command()
def subscribe_main( # pylint: disable=R0914,R0917 # noqa : PLR0913
def subscribe_main( # pylint: disable=R0914,R0917
url: str = typer.Option(None, help="Broker connection URL, *must conform to MQTT or URI scheme: `[mqtt(s)|ws(s)]://<username:password>@HOST:port`*", show_default=False),
config_file: str | None = typer.Option(None, "-c", help="Client configuration file"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
client_id: str | None = typer.Option(None, "-i", "--client-id", help="client identification for mqtt connection. *default: process id and the hostname of the client*"), max_count: int | None = typer.Option(None, "-n", help="Number of messages to read before ending *default: read indefinitely*"),
qos: int = typer.Option(0, "--qos", "-q", help="Quality of service (0, 1, or 2)"),

@@ -117,0 +117,0 @@ topics: list[str] = typer.Option(..., "-t", help="Topic filter to subscribe, can be used multiple times."), # noqa: B008

from asyncio import Queue
from collections import OrderedDict
from typing import Any, ClassVar
import logging
from math import floor
import time
from typing import TYPE_CHECKING, Any, ClassVar

@@ -13,3 +16,8 @@ from transitions import Machine

if TYPE_CHECKING:
import ssl
logger = logging.getLogger(__name__)
class ApplicationMessage:

@@ -142,2 +150,5 @@ """ApplicationMessage and subclasses are used to store published message information flow.

self.parent: int = 0
self.last_connect_time: int | None = None
self.ssl_object: ssl.SSLObject | None = None
self.last_disconnect_time: int | None = None

@@ -166,2 +177,3 @@ # Used to store outgoing ApplicationMessage while publish protocol flows

)
self.transitions.on_enter_connected(self._on_enter_connected)
self.transitions.add_transition(

@@ -177,2 +189,3 @@ trigger="connect",

)
self.transitions.on_enter_disconnected(self._on_enter_disconnected)
self.transitions.add_transition(

@@ -189,2 +202,16 @@ trigger="disconnect",

def _on_enter_connected(self) -> None:
cur_time = floor(time.time())
if self.last_disconnect_time is not None:
logger.debug(f"Session reconnected after {cur_time - self.last_disconnect_time} seconds.")
self.last_connect_time = cur_time
self.last_disconnect_time = None
def _on_enter_disconnected(self) -> None:
cur_time = floor(time.time())
if self.last_connect_time is not None:
logger.debug(f"Session disconnected after {cur_time - self.last_connect_time} seconds.")
self.last_disconnect_time = cur_time
@property

@@ -191,0 +218,0 @@ def next_packet_id(self) -> int:

Metadata-Version: 2.4
Name: amqtt
Version: 0.11.2
Version: 0.11.3
Summary: Python's asyncio-native MQTT broker and client.

@@ -29,2 +29,14 @@ Author: aMQTT Contributors

Requires-Dist: coveralls==4.0.1; extra == 'ci'
Provides-Extra: contrib
Requires-Dist: aiohttp>=3.12.13; extra == 'contrib'
Requires-Dist: aiosqlite>=0.21.0; extra == 'contrib'
Requires-Dist: argon2-cffi>=25.1.0; extra == 'contrib'
Requires-Dist: cryptography>=45.0.3; extra == 'contrib'
Requires-Dist: greenlet>=3.2.3; extra == 'contrib'
Requires-Dist: jsonschema>=4.25.0; extra == 'contrib'
Requires-Dist: mergedeep>=1.3.4; extra == 'contrib'
Requires-Dist: pyjwt>=2.10.1; extra == 'contrib'
Requires-Dist: pyopenssl>=25.1.0; extra == 'contrib'
Requires-Dist: python-ldap>=3.4.4; extra == 'contrib'
Requires-Dist: sqlalchemy[asyncio]>=2.0.41; extra == 'contrib'
Description-Content-Type: text/markdown

@@ -48,7 +60,18 @@

- Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications
- Communication over TCP and/or websocket, including support for SSL/TLS
- Communication over multiple TCP and/or websocket ports, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
- Plugin framework for functionality expansion; included plugins:
- `$SYS` topic publishing
- AWS IOT-style shadow states
- x509 certificate authentication (including cli cert creation)
- Secure file-based password authentication
- Configuration-based topic authorization
- MySQL, Postgres & SQLite user and/or topic auth (including cli manager)
- External server (HTTP) user and/or topic auth
- LDAP user and/or topic auth
- JWT user and/or topic auth
- Fail over session persistence
## Installation

@@ -55,0 +78,0 @@

@@ -23,3 +23,3 @@ [build-system]

version = "0.11.2"
version = "0.11.3"
requires-python = ">=3.10.0"

@@ -38,3 +38,3 @@ readme = "README.md"

"dacite>=1.9.2",
"psutil>=7.0.0",
"psutil>=7.0.0"
]

@@ -44,2 +44,4 @@

dev = [
"aiosqlite>=0.21.0",
"greenlet>=3.2.3",
"hatch>=1.14.1",

@@ -52,5 +54,8 @@ "hypothesis>=6.130.8",

"psutil>=7.0.0", # https://pypi.org/project/psutil
"pyhamcrest>=2.1.0",
"pylint>=3.3.6", # https://pypi.org/project/pylint
"pyopenssl>=25.1.0",
"pytest-asyncio>=0.26.0", # https://pypi.org/project/pytest-asyncio
"pytest-cov>=6.1.0", # https://pypi.org/project/pytest-cov
"pytest-docker>=3.2.3",
"pytest-logdog>=0.1.0", # https://pypi.org/project/pytest-logdog

@@ -61,2 +66,3 @@ "pytest-timeout>=2.3.1", # https://pypi.org/project/pytest-timeout

"setuptools>=78.1.0",
"sqlalchemy[mypy]>=2.0.41",
"types-mock>=5.2.0.20250306", # https://pypi.org/project/types-mock

@@ -68,2 +74,3 @@ "types-PyYAML>=6.0.12.20250402", # https://pypi.org/project/types-PyYAML

docs = [
"griffe>=1.11.1",
"markdown-callouts>=0.4",

@@ -89,5 +96,17 @@ "markdown-exec>=1.8",

ci = ["coveralls==4.0.1"]
contrib = [
"cryptography>=45.0.3",
"aiosqlite>=0.21.0",
"greenlet>=3.2.3",
"sqlalchemy[asyncio]>=2.0.41",
"argon2-cffi>=25.1.0",
"aiohttp>=3.12.13",
"pyjwt>=2.10.1",
"python-ldap>=3.4.4",
"mergedeep>=1.3.4",
"jsonschema>=4.25.0",
"pyopenssl>=25.1.0"
]
[project.scripts]

@@ -97,3 +116,9 @@ amqtt = "amqtt.scripts.broker_script:main"

amqtt_sub = "amqtt.scripts.sub_script:main"
ca_creds = "amqtt.scripts.ca_creds:main"
server_creds = "amqtt.scripts.server_creds:main"
device_creds = "amqtt.scripts.device_creds:main"
user_mgr = "amqtt.scripts.manage_users:main"
topic_mgr = "amqtt.scripts.manage_topics:main"
[tool.hatch.build]

@@ -152,26 +177,36 @@ exclude = [

[tool.ruff.lint]
preview = true
select = ["ALL"]
extend-select = [
"UP", # pyupgrade
"D", # pydocstyle
"UP", # pyupgrade
"D", # pydocstyle,
]
ignore = [
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
"G004", # Logging statement uses f-string
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D107", # Missing docstring in `__init__`
"D203", # Incorrect blank line before class (mutually exclusive D211)
"D213", # Multi-line summary second line (mutually exclusive D212)
"FIX002", # Checks for "TODO" comments.
"TD002", # TODO Missing author.
"TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203",# try-except penalty within loops (3.10 only),
"COM812" # rule causes conflicts when used with the formatter
"FBT001", # Checks for the use of boolean positional arguments in function definitions.
"FBT002", # Checks for the use of boolean positional arguments in function definitions.
"G004", # Logging statement uses f-string
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D107", # Missing docstring in `__init__`
"D203", # Incorrect blank line before class (mutually exclusive D211)
"D213", # Multi-line summary second line (mutually exclusive D212)
"FIX002", # Checks for "TODO" comments.
"TD002", # TODO Missing author.
"TD003", # TODO Missing issue link for this TODO.
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG002", # Unused method argument
"PERF203",# try-except penalty within loops (3.10 only),
"COM812", # rule causes conflicts when used with the formatter,
# ignore certain preview rules
"DOC",
"PLW",
"PLR",
"CPY",
"PLC",
"RUF052",
"B903"
]

@@ -209,3 +244,3 @@

asyncio_mode = "auto"
timeout = 10
timeout = 15
asyncio_default_fixture_loop_scope = "function"

@@ -244,2 +279,3 @@ #addopts = ["--tb=short", "--capture=tee-sys"]

max-line-length = 130
ignore-paths = ["amqtt/plugins/persistence.py"]

@@ -255,3 +291,2 @@ [tool.pylint.BASIC]

disable = [
"broad-exception-caught", # TODO: improve later
"duplicate-code",

@@ -258,0 +293,0 @@ "fixme",

@@ -17,7 +17,18 @@ [![MIT licensed](https://img.shields.io/github/license/Yakifo/amqtt?style=plastic)](https://amqtt.readthedocs.io/en/latest/)

- Full set of [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) protocol specifications
- Communication over TCP and/or websocket, including support for SSL/TLS
- Communication over multiple TCP and/or websocket ports, including support for SSL/TLS
- Support QoS 0, QoS 1 and QoS 2 messages flow
- Client auto-reconnection on network lost
- Functionality expansion; plugins included: authentication and `$SYS` topic publishing
- Plugin framework for functionality expansion; included plugins:
- `$SYS` topic publishing
- AWS IOT-style shadow states
- x509 certificate authentication (including cli cert creation)
- Secure file-based password authentication
- Configuration-based topic authorization
- MySQL, Postgres & SQLite user and/or topic auth (including cli manager)
- External server (HTTP) user and/or topic auth
- LDAP user and/or topic auth
- JWT user and/or topic auth
- Fail over session persistence
## Installation

@@ -24,0 +35,0 @@