You're Invited:Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26.RSVP
Socket
Book a DemoSign in
Socket

amqtt

Package Overview
Dependencies
Maintainers
2
Versions
13
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

amqtt - pypi Package Compare versions

Comparing version
0.11.1
to
0.11.2
+22
amqtt/contexts.py
from enum import Enum
import logging
from typing import TYPE_CHECKING, Any
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
import asyncio
class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
class Action(Enum):
"""Actions issued by the broker."""
SUBSCRIBE = "subscribe"
PUBLISH = "publish"
+5
-0

@@ -37,1 +37,6 @@ #------- Package & Cache Files -------

coverage.xml
#----- generated files -----
*.log
*memray*
.coverage*
+1
-1
"""INIT."""
__version__ = "0.11.1"
__version__ = "0.11.2"

@@ -6,3 +6,2 @@ import asyncio

import copy
from enum import Enum
from functools import partial

@@ -27,2 +26,3 @@ import logging

)
from amqtt.contexts import Action, BaseContext
from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError

@@ -34,4 +34,5 @@ from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler

from .events import BrokerEvents
from .mqtt.constants import QOS_0, QOS_1, QOS_2
from .mqtt.disconnect import DisconnectPacket
from .plugins.manager import BaseContext, PluginManager
from .plugins.manager import PluginManager

@@ -50,9 +51,2 @@ _CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]]

class Action(Enum):
"""Actions issued by the broker."""
SUBSCRIBE = "subscribe"
PUBLISH = "publish"
class RetainedApplicationMessage(ApplicationMessage):

@@ -171,2 +165,6 @@ __slots__ = ("data", "qos", "source_session", "topic")

if config is not None:
# if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config
if ("auth" in config or "topic-check" in config) and "plugins" not in config:
# set to None so that the config isn't updated with the new-style default plugin list
config["plugins"] = None # type: ignore[assignment]
self.config.update(config)

@@ -182,2 +180,4 @@ self._build_listeners_config(self.config)

self._topic_filter_matchers: dict[str, re.Pattern[str]] = {}
# Broadcast queue for outgoing messages

@@ -447,2 +447,3 @@ self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()

client_session.client_id = gen_client_id()
client_session.parent = 0

@@ -501,9 +502,20 @@ # Get session from cache

await handler.mqtt_connack_authorize(authenticated)
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED, client_id=client_session.client_id)
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED,
client_id=client_session.client_id,
client_session=client_session)
self.logger.debug(f"{client_session.client_id} Start messages handling")
await handler.start()
# publish messages that were retained because the client session was disconnected
self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}")
await self._publish_session_retained_messages(client_session)
# if this is not a new session, there are subscriptions associated with them; publish any topic retained messages
self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.")
for topic in self._subscriptions:
await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session)
await self._client_message_loop(client_session, handler)

@@ -605,3 +617,5 @@

client_session.transitions.disconnect()
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED, client_id=client_session.client_id)
await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED,
client_id=client_session.client_id,
client_session=client_session)

@@ -671,2 +685,7 @@

return False
if app_message.topic.startswith("$"):
self.logger.warning(
f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character."
)
return False

@@ -712,14 +731,17 @@ permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH)

returns = await self.plugins_manager.map_plugin_auth(session=session)
auth_result = True
if returns:
for plugin in returns:
res = returns[plugin]
if res is False:
auth_result = False
self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}")
else:
self.logger.debug(f"'{plugin.__class__}' plugin result: {res}")
# If all plugins returned True, authentication is success
return auth_result
results = [ result for _, result in returns.items() if result is not None] if returns else []
if len(results) < 1:
self.logger.debug("Authentication failed: no plugin responded with a boolean")
return False
if all(results):
self.logger.debug("Authentication succeeded")
return True
for plugin, result in returns.items():
self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}")
return False
def retain_message(

@@ -780,9 +802,3 @@ self,

"""
topic_config = self.config.get("topic-check", {})
enabled = False
if isinstance(topic_config, dict):
enabled = topic_config.get("enabled", False)
if not enabled:
if not self.plugins_manager.is_topic_filtering_enabled():
return True

@@ -882,5 +898,2 @@

for k_filter, subscriptions in self._subscriptions.items():
if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
continue

@@ -896,3 +909,8 @@ # Skip all subscriptions which do not match the topic

# Retain all messages which cannot be broadcasted, due to the session not being connected
if target_session.transitions.state != "connected":
# but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4]
# and, if a client used anonymous authentication, there is no expectation that messages should be retained
if (target_session.transitions.state != "connected"
and not target_session.clean_session
and qos in (QOS_1, QOS_2)
and not target_session.is_anonymous):
self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.")

@@ -902,2 +920,6 @@ await self._retain_broadcast_message(broadcast, qos, target_session)

# Only broadcast the message to connected clients
if target_session.transitions.state != "connected":
continue
self.logger.debug(

@@ -1003,7 +1025,17 @@ f"Broadcasting message from {format_client_message(session=broadcast['session'])}"

def _matches(self, topic: str, a_filter: str) -> bool:
if topic.startswith("$") and (a_filter.startswith(("+", "#"))):
self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #")
return False
if "#" not in a_filter and "+" not in a_filter:
# if filter doesn't contain wildcard, return exact match
return a_filter == topic
# else use regex
match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?"))
# else use regex (re.compile is an expensive operation, store the matcher for future use)
if a_filter not in self._topic_filter_matchers:
self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter)
.replace("\\#", "?.*")
.replace("\\+", "[^/]*")
.lstrip("?"))
match_pattern = self._topic_filter_matchers[a_filter]
return bool(match_pattern.fullmatch(topic))

@@ -1010,0 +1042,0 @@

@@ -22,2 +22,3 @@ import asyncio

)
from amqtt.contexts import BaseContext
from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError

@@ -27,3 +28,3 @@ from amqtt.mqtt.connack import CONNECTION_ACCEPTED

from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.manager import PluginManager
from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session

@@ -602,3 +603,3 @@ from amqtt.utils import gen_client_id, read_yaml_config

if cleansession is not None:
broker_conf["cleansession"] = cleansession
broker_conf["cleansession"] = cleansession # noop?
session.clean_session = cleansession

@@ -605,0 +606,0 @@ else:

@@ -32,7 +32,12 @@ from typing import Any

class PluginImportError(PluginError):
def __init__(self, plugin: Any) -> None:
super().__init__(f"Plugin import failed: {plugin!r}")
"""Exceptions thrown when loading plugin."""
class PluginCoroError(PluginError):
"""Exceptions thrown when loading a plugin with a non-async call method."""
class PluginInitError(PluginError):
"""Exceptions thrown when initializing plugin."""
def __init__(self, plugin: Any) -> None:

@@ -39,0 +44,0 @@ super().__init__(f"Plugin init failed: {plugin!r}")

@@ -195,3 +195,3 @@ from asyncio import StreamReader

payload.client_id = gen_client_id()
# indicator to trow exception in case CLEAN_SESSION_FLAG is set to False
# indicator to throw exception in case CLEAN_SESSION_FLAG is set to False
payload.client_id_is_random = True

@@ -198,0 +198,0 @@

@@ -23,2 +23,3 @@ import asyncio

from amqtt.adapters import ReaderAdapter, WriterAdapter
from amqtt.contexts import BaseContext
from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError

@@ -61,3 +62,3 @@ from amqtt.events import MQTTEvents

from amqtt.mqtt.unsubscribe import UnsubscribePacket
from amqtt.plugins.manager import BaseContext, PluginManager
from amqtt.plugins.manager import PluginManager
from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session

@@ -524,4 +525,5 @@

elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket):
# TODO: why is this not like all other inside create_task?
await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet))
# q: why is this not like all other inside a create_task?
# a: the connection needs to be established before any other packet tasks for this new session are scheduled
await self.handle_connect(packet)
if task:

@@ -528,0 +530,0 @@ running_tasks.append(task)

@@ -15,3 +15,3 @@ import asyncio

super().__init__()
if "*" in topic_name:
if "#" in topic_name or "+" in topic_name:
msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters."

@@ -18,0 +18,0 @@ raise MQTTError(msg)

@@ -0,1 +1,2 @@

from dataclasses import dataclass, field
from pathlib import Path

@@ -6,2 +7,3 @@

from amqtt.broker import BrokerContext
from amqtt.contexts import BaseContext
from amqtt.plugins.base import BaseAuthPlugin

@@ -16,9 +18,15 @@ from amqtt.session import Session

def __init__(self, context: BaseContext) -> None:
super().__init__(context)
# Default to allowing anonymous
self._allow_anonymous = self._get_config_option("allow-anonymous", True) # noqa: FBT003
async def authenticate(self, *, session: Session) -> bool:
authenticated = await super().authenticate(session=session)
if authenticated:
# Default to allowing anonymous
allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True
if allow_anonymous:
if self._allow_anonymous:
self.context.logger.debug("Authentication success: config allows anonymous")
session.is_anonymous = True
return True

@@ -32,3 +40,9 @@

@dataclass
class Config:
"""Allow empty username."""
allow_anonymous: bool = field(default=True)
class FileAuthPlugin(BaseAuthPlugin):

@@ -44,3 +58,3 @@ """Authentication plugin based on a file-stored user database."""

"""Read the password file and populates the user dictionary."""
password_file = self.auth_config.get("password-file") if isinstance(self.auth_config, dict) else None
password_file = self._get_config_option("password-file", None)
if not password_file:

@@ -51,3 +65,6 @@ self.context.logger.warning("Configuration parameter 'password-file' not found")

try:
with Path(password_file).open(mode="r", encoding="utf-8") as file:
file = password_file
if isinstance(file, str):
file = Path(file)
with file.open(mode="r", encoding="utf-8") as file:
self.context.logger.debug(f"Reading user database from {password_file}")

@@ -95,1 +112,7 @@ for _line in file:

return False
@dataclass
class Config:
"""Path to the properly encoded password file."""
password_file: str | Path | None = None

@@ -1,5 +0,5 @@

from typing import Any, Generic, TypeVar
from dataclasses import dataclass, is_dataclass
from typing import Any, Generic, TypeVar, cast
from amqtt.broker import Action
from amqtt.plugins.manager import BaseContext
from amqtt.contexts import Action, BaseContext
from amqtt.session import Session

@@ -11,7 +11,28 @@

class BasePlugin(Generic[C]):
"""The base from which all plugins should inherit."""
"""The base from which all plugins should inherit.
Type Parameters
---------------
C:
A BaseContext: either BrokerContext or ClientContext, depending on plugin usage
Attributes
----------
context (C):
Information about the environment in which this plugin is executed. Modifying
the broker or client state should happen through methods available here.
config (self.Config):
An instance of the Config dataclass defined by the plugin (or an empty dataclass, if not
defined). If using entrypoint- or mixed-style configuration, use `_get_config_option()`
to access the variable.
"""
def __init__(self, context: C) -> None:
self.context: C = context
# since the PluginManager will hydrate the config from a plugin's `Config` class, this is a safe cast
self.config = cast("self.Config", context.config) # type: ignore[name-defined]
# Deprecated: included to support entrypoint-style configs. Replaced by dataclass Config class.
def _get_config_section(self, name: str) -> dict[str, Any] | None:

@@ -23,3 +44,3 @@

section_config: int | dict[str, Any] | None = self.context.config.get(name, None)
# mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check
# mypy has difficulty excluding int from `config`'s type, unless there's an explicit check
if isinstance(section_config, int):

@@ -29,2 +50,18 @@ return None

# Deprecated : supports entrypoint-style configs as well as dataclass configuration.
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if option_name in self.context.config:
return self.context.config[option_name]
return default
@dataclass
class Config:
"""Override to define the configuration and defaults for plugin."""
async def close(self) -> None:

@@ -41,5 +78,16 @@ """Override if plugin needs to clean up resources upon shutdown."""

self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check")
if self.topic_config is None:
if not bool(self.topic_config) and not is_dataclass(self.context.config):
self.context.logger.warning("'topic-check' section not found in context configuration")
def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.topic_config and option_name in self.topic_config:
return self.topic_config[option_name]
return default
async def topic_filtering(

@@ -59,7 +107,3 @@ self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None

"""
if not self.topic_config:
# auth config section not found
self.context.logger.warning("'topic-check' section not found in context configuration")
return False
return True
return bool(self.topic_config) or is_dataclass(self.context.config)

@@ -70,2 +114,13 @@

def _get_config_option(self, option_name: str, default: Any=None) -> Any:
if not self.context.config:
return default
if is_dataclass(self.context.config):
# overloaded context.config for BasePlugin `Config` class, so ignoring static type check
return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable]
if self.auth_config and option_name in self.auth_config:
return self.auth_config[option_name]
return default
def __init__(self, context: BaseContext) -> None:

@@ -75,5 +130,7 @@ super().__init__(context)

self.auth_config: dict[str, Any] | None = self._get_config_section("auth")
if not self.auth_config:
if not bool(self.auth_config) and not is_dataclass(self.context.config):
# auth config section not found and Config dataclass not provided
self.context.logger.warning("'auth' section not found in context configuration")
async def authenticate(self, *, session: Session) -> bool | None:

@@ -90,6 +147,2 @@ """Logic for session authentication.

"""
if not self.auth_config:
# auth config section not found
self.context.logger.warning("'auth' section not found in context configuration")
return False
return True
return bool(self.auth_config) or is_dataclass(self.context.config)

@@ -6,2 +6,3 @@ from collections.abc import Callable, Coroutine

from amqtt.contexts import BaseContext
from amqtt.events import BrokerEvents

@@ -11,3 +12,2 @@ from amqtt.mqtt import MQTTPacket

from amqtt.plugins.base import BasePlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session

@@ -14,0 +14,0 @@

@@ -1,2 +0,2 @@

__all__ = ["BaseContext", "PluginManager", "get_plugin_manager"]
__all__ = ["PluginManager", "get_plugin_manager"]

@@ -11,15 +11,15 @@ import asyncio

import logging
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar
from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast
import warnings
from amqtt.errors import PluginImportError, PluginInitError
from dacite import Config as DaciteConfig, DaciteError, from_dict
from amqtt.contexts import Action, BaseContext
from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError
from amqtt.events import BrokerEvents, Events, MQTTEvents
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
from amqtt.session import Session
from amqtt.utils import import_string
_LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING:
from amqtt.broker import Action
from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin
class Plugin(NamedTuple):

@@ -43,7 +43,7 @@ name: str

class BaseContext:
def __init__(self) -> None:
self.loop: asyncio.AbstractEventLoop | None = None
self.logger: logging.Logger = _LOGGER
self.config: dict[str, Any] | None = None
def safe_issubclass(sub_class: Any, super_class: Any) -> bool:
try:
return issubclass(sub_class, super_class)
except TypeError:
return False

@@ -75,2 +75,5 @@

self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list)
self._is_topic_filtering_enabled = False
self._is_auth_filtering_enabled = False
self._load_plugins(namespace)

@@ -84,12 +87,75 @@ self._fired_events: list[asyncio.Future[Any]] = []

def _load_plugins(self, namespace: str) -> None:
def _load_plugins(self, namespace: str | None = None) -> None:
"""Load plugins from entrypoint or config dictionary.
config style is now recommended; entrypoint has been deprecated
Example:
config = {
'listeners':...,
'plugins': {
'myproject.myfile.MyPlugin': {}
}
"""
if self.app_context.config and self.app_context.config.get("plugins", None) is not None:
# plugins loaded directly from config dictionary
if "auth" in self.app_context.config:
self.logger.warning("Loading plugins from config will ignore 'auth' section of config")
if "topic-check" in self.app_context.config:
self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config")
plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", [])
# if the config was generated from yaml, the plugins maybe a list instead of a dictionary; transform before loading
#
# plugins:
# - myproject.myfile.MyPlugin:
if isinstance(plugins_config, list):
plugins_info: dict[str, Any] = {}
for plugin_config in plugins_config:
if isinstance(plugin_config, str):
plugins_info.update({plugin_config: {}})
elif not isinstance(plugin_config, dict):
msg = "malformed 'plugins' configuration"
raise PluginLoadError(msg)
else:
plugins_info.update(plugin_config)
self._load_str_plugins(plugins_info)
elif isinstance(plugins_config, dict):
self._load_str_plugins(plugins_config)
else:
if not namespace:
msg = "Namespace needs to be provided for EntryPoint plugin definitions"
raise PluginLoadError(msg)
warnings.warn(
"Loading plugins from EntryPoints is deprecated and will be removed in a future version."
" Use `plugins` section of config instead.",
DeprecationWarning,
stacklevel=2
)
self._load_ep_plugins(namespace)
# for all the loaded plugins, find all event callbacks
for plugin in self._plugins:
for event in list(BrokerEvents) + list(MQTTEvents):
if awaitable := getattr(plugin, f"on_{event}", None):
if not iscoroutinefunction(awaitable):
msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'"
raise PluginImportError(msg)
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
self._event_plugin_callbacks[event].append(awaitable)
def _load_ep_plugins(self, namespace:str) -> None:
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
self.logger.debug(f"Loading plugins for namespace {namespace}")
auth_filter_list = []
topic_filter_list = []
if self.app_context.config and "auth" in self.app_context.config:
auth_filter_list = self.app_context.config["auth"].get("plugins", [])
auth_filter_list = self.app_context.config["auth"].get("plugins", None)
if self.app_context.config and "topic-check" in self.app_context.config:
topic_filter_list = self.app_context.config["topic-check"].get("plugins", [])
topic_filter_list = self.app_context.config["topic-check"].get("plugins", None)

@@ -106,6 +172,8 @@ ep: EntryPoints | list[EntryPoint] = []

self._plugins.append(ep_plugin.object)
if ((not auth_filter_list or ep_plugin.name in auth_filter_list)
# maintain legacy behavior that if there is no list, use all auth plugins
if ((auth_filter_list is None or ep_plugin.name in auth_filter_list)
and hasattr(ep_plugin.object, "authenticate")):
self._auth_plugins.append(ep_plugin.object)
if ((not topic_filter_list or ep_plugin.name in topic_filter_list)
# maintain legacy behavior that if there is no list, use all topic plugins
if ((topic_filter_list is None or ep_plugin.name in topic_filter_list)
and hasattr(ep_plugin.object, "topic_filtering")):

@@ -115,12 +183,4 @@ self._topic_plugins.append(ep_plugin.object)

for plugin in self._plugins:
for event in list(BrokerEvents) + list(MQTTEvents):
if awaitable := getattr(plugin, f"on_{event}", None):
if not iscoroutinefunction(awaitable):
msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'"
raise PluginImportError(msg)
self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'")
self._event_plugin_callbacks[event].append(awaitable)
def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None:
"""Load plugins from `pyproject.toml` entrypoints. Deprecated."""
try:

@@ -145,5 +205,64 @@ self.logger.debug(f" Loading plugin {ep!s}")

def _load_str_plugins(self, plugins_info: dict[str, Any]) -> None:
self.logger.info("Loading plugins from config")
# legacy had a filtering 'enabled' flag, even if plugins were loaded/listed
self._is_topic_filtering_enabled = True
self._is_auth_filtering_enabled = True
for plugin_path, plugin_config in plugins_info.items():
plugin = self._load_str_plugin(plugin_path, plugin_config)
self._plugins.append(plugin)
# make sure that authenticate and topic filtering plugins have the appropriate async signature
if isinstance(plugin, BaseAuthPlugin):
if not iscoroutinefunction(plugin.authenticate):
msg = f"Auth plugin {plugin_path} has non-async authenticate method."
raise PluginCoroError(msg)
self._auth_plugins.append(plugin)
if isinstance(plugin, BaseTopicPlugin):
if not iscoroutinefunction(plugin.topic_filtering):
msg = f"Topic plugin {plugin_path} has non-async topic_filtering method."
raise PluginCoroError(msg)
self._topic_plugins.append(plugin)
def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]":
"""Load plugin from string dotted path: mymodule.myfile.MyPlugin."""
try:
plugin_class: Any = import_string(plugin_path)
except ImportError as ep:
msg = f"Plugin import failed: {plugin_path}"
raise PluginImportError(msg) from ep
if not safe_issubclass(plugin_class, BasePlugin):
msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'"
raise PluginLoadError(msg)
plugin_context = copy.copy(self.app_context)
plugin_context.logger = self.logger.getChild(plugin_class.__name__)
try:
# populate the config based on the inner dataclass called `Config`
# use `dacite` package to type check
plugin_context.config = from_dict(data_class=plugin_class.Config,
data=plugin_cfg or {},
config=DaciteConfig(strict=True))
except DaciteError as e:
raise PluginLoadError from e
except TypeError as e:
msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass."
raise PluginLoadError(msg) from e
try:
pc = plugin_class(plugin_context)
self.logger.debug(f"Loading plugin {plugin_path}")
return cast("BasePlugin[C]", pc)
except Exception as e:
self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True)
raise PluginInitError(plugin_class) from e
def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]:
"""Get a plugin by its name from the plugins loaded for the current namespace.
Only used for testing purposes to verify plugin loading correctly.
:param name:

@@ -157,2 +276,8 @@ :return:

def is_topic_filtering_enabled(self) -> bool:
topic_config = self.app_context.config.get("topic-check", {}) if self.app_context.config else {}
if isinstance(topic_config, dict):
return topic_config.get("enabled", False) or self._is_topic_filtering_enabled
return False or self._is_topic_filtering_enabled
async def close(self) -> None:

@@ -176,2 +301,6 @@ """Free PluginManager resources and cancel pending event methods."""

def _clean_fired_events(self, future: asyncio.Future[Any]) -> None:
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)
async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None:

@@ -202,9 +331,4 @@ """Fire an event to plugins.

tasks.append(asyncio.ensure_future(coro_instance))
tasks[-1].add_done_callback(self._clean_fired_events)
def clean_fired_events(future: asyncio.Future[Any]) -> None:
with contextlib.suppress(KeyError, ValueError):
self._fired_events.remove(future)
tasks[-1].add_done_callback(clean_fired_events)
self._fired_events.extend(tasks)

@@ -211,0 +335,0 @@ if wait and tasks:

@@ -5,3 +5,3 @@ import json

from amqtt.plugins.manager import BaseContext
from amqtt.contexts import BaseContext
from amqtt.session import Session

@@ -8,0 +8,0 @@

import asyncio
from collections import deque # pylint: disable=C0412
from dataclasses import dataclass
from typing import Any, SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412

@@ -73,4 +74,7 @@

self._sys_handle: asyncio.Handle | None = None
self._sys_interval: int = 0
self._current_process = psutil.Process()
def _clear_stats(self) -> None:

@@ -116,10 +120,10 @@ """Initialize broker statistics data structures."""

try:
sys_interval: int = 0
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
sys_interval = int(x)
if sys_interval > 0:
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
self._sys_interval = self._get_config_option("sys_interval", None)
if isinstance(self._sys_interval, str | Buffer | SupportsInt | SupportsIndex):
self._sys_interval = int(self._sys_interval)
if self._sys_interval > 0:
self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds")
self._sys_handle = (
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
if self.context.loop is not None

@@ -131,3 +135,3 @@ else None

except KeyError:
pass
self.context.logger.debug("could not find 'sys_interval' key: {e!r}")
# 'sys_interval' config parameter not found

@@ -200,11 +204,5 @@

# Reschedule
sys_interval: int = 0
x = self.context.config.get("sys_interval") if self.context.config is not None else None
if isinstance(x, str | Buffer | SupportsInt | SupportsIndex):
sys_interval = int(x)
self.context.logger.debug("Broadcasting $SYS topics")
self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds")
self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.")
self._sys_handle = (
self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics)
self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics)
if self.context.loop is not None

@@ -232,3 +230,3 @@ else None

async def on_broker_client_connected(self, client_id: str) -> None:
async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None:
"""Handle broker client connection."""

@@ -241,5 +239,11 @@ self._stats[STAT_CLIENTS_CONNECTED] += 1

async def on_broker_client_disconnected(self, client_id: str) -> None:
async def on_broker_client_disconnected(self, client_id: str, client_session: Session) -> None:
"""Handle broker client disconnection."""
self._stats[STAT_CLIENTS_CONNECTED] -= 1
self._stats[STAT_CLIENTS_DISCONNECTED] += 1
@dataclass
class Config:
"""Configuration struct for plugin."""
sys_interval: int = 20

@@ -0,6 +1,6 @@

from dataclasses import dataclass, field
from typing import Any
from amqtt.broker import Action
from amqtt.contexts import Action, BaseContext
from amqtt.plugins.base import BaseTopicPlugin
from amqtt.plugins.manager import BaseContext
from amqtt.session import Session

@@ -55,3 +55,3 @@

# hbmqtt and older amqtt do not support publish filtering
if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config:
if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}):
# maintain backward compatibility, assume permitted

@@ -62,3 +62,3 @@ return True

if not req_topic:
return False
return False\

@@ -70,6 +70,7 @@ username = session.username if session else None

acl: dict[str, Any] = {}
if self.topic_config is not None and action == Action.PUBLISH:
acl = self.topic_config.get("publish-acl", {})
elif self.topic_config is not None and action == Action.SUBSCRIBE:
acl = self.topic_config.get("acl", {})
match action:
case Action.PUBLISH:
acl = self._get_config_option("publish-acl", {})
case Action.SUBSCRIBE:
acl = self._get_config_option("acl", {})

@@ -81,1 +82,8 @@ allowed_topics = acl.get(username, [])

return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics)
@dataclass
class Config:
"""Mappings of username and list of approved topics."""
publish_acl: dict[str, list[str]] = field(default_factory=dict)
acl: dict[str, list[str]] = field(default_factory=dict)

@@ -6,8 +6,8 @@ ---

bind: 0.0.0.0:1883
sys_interval: 20
auth:
plugins:
- auth_anonymous
allow-anonymous: true
topic-check:
enabled: False
plugins:
amqtt.plugins.logging_amqtt.EventLoggerPlugin:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:
amqtt.plugins.authentication.AnonymousAuthPlugin:
allow_anonymous: true
amqtt.plugins.sys.broker.BrokerSysPlugin:
sys_interval: 20

@@ -7,5 +7,8 @@ ---

auto_reconnect: true
cleansession: true
reconnect_max_interval: 10
reconnect_retries: 2
broker:
uri: "mqtt://127.0.0.1"
uri: "mqtt://127.0.0.1"
plugins:
amqtt.plugins.logging_amqtt.PacketLoggerPlugin:

@@ -148,3 +148,3 @@ from asyncio import Queue

# Stores messages retained for this session
# Stores messages retained for this session (specifically when the client is disconnected)
self.retained_messages: Queue[ApplicationMessage] = Queue()

@@ -155,2 +155,5 @@

# identify anonymous client sessions or clients which didn't identify themselves
self.is_anonymous: bool = False
def _init_states(self) -> None:

@@ -157,0 +160,0 @@ self.transitions = Machine(states=Session.states, initial="new")

from __future__ import annotations
from importlib import import_module
import logging

@@ -7,2 +8,3 @@ from pathlib import Path

import string
import sys
import typing

@@ -52,1 +54,37 @@ from typing import Any

return None
def cached_import(module_path: str, class_name: str | None = None) -> Any:
"""Return cached import of a class from a module path (or retrieve, cache and then return)."""
# Check whether module is loaded and fully initialized.
if not ((module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None))
and getattr(spec, "_initializing", False) is False):
module = import_module(module_path)
if class_name:
return getattr(module, class_name)
return module
def import_string(dotted_path: str) -> Any:
"""Import a dotted module path.
Returns:
attribute/class designated by the last name in the path
Raises:
ImportError (if the import failed)
"""
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError as err:
msg = f"{dotted_path} doesn't look like a module path"
raise ImportError(msg) from err
try:
return cached_import(module_path, class_name)
except AttributeError as err:
msg = f'Module "{module_path}" does not define a "{class_name}" attribute/class'
raise ImportError(msg) from err
Metadata-Version: 2.4
Name: amqtt
Version: 0.11.1
Version: 0.11.2
Summary: Python's asyncio-native MQTT broker and client.

@@ -20,2 +20,3 @@ Author: aMQTT Contributors

Requires-Python: >=3.10.0
Requires-Dist: dacite>=1.9.2
Requires-Dist: passlib==1.7.4

@@ -22,0 +23,0 @@ Requires-Dist: psutil>=7.0.0

@@ -23,3 +23,3 @@ [build-system]

version = "0.11.1"
version = "0.11.2"
requires-python = ">=3.10.0"

@@ -37,2 +37,3 @@ readme = "README.md"

"typer==0.15.4",
"dacite>=1.9.2",
"psutil>=7.0.0",

@@ -39,0 +40,0 @@ ]

try:
from datetime import UTC, datetime
except ImportError:
from datetime import datetime, timezone
UTC = timezone.utc
import logging
from pathlib import Path
import shutil
import subprocess
import warnings
import amqtt
logger = logging.getLogger(__name__)
def get_version() -> str:
"""Return the version of the amqtt package.
This function is deprecated. Use amqtt.__version__ instead.
"""
warnings.warn(
"amqtt.version.get_version() is deprecated, use amqtt.__version__ instead",
stacklevel=3, # Adjusted stack level to better reflect the caller
)
return amqtt.__version__
def get_git_changeset() -> str | None:
"""Return a numeric identifier of the latest git changeset.
The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format.
This value isn't guaranteed to be unique, but collisions are very unlikely,
so it's sufficient for generating the development version numbers.
"""
# Define the repository directory (two levels above the current script)
repo_dir = Path(__file__).resolve().parent.parent
# Ensure the directory exists and is valid
if not repo_dir.is_dir():
logger.error(f"Invalid directory: {repo_dir} is not a valid directory")
return None
# Use the system's PATH to locate 'git', or define the full path if necessary
git_path = "git" # Assuming git is available in the system PATH
# Ensure 'git' is executable and available
if not shutil.which(git_path):
logger.error(f"{git_path} is not found in the system PATH.")
return None
# Call git log to get the latest changeset timestamp
try:
with subprocess.Popen( # noqa: S603
[git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=repo_dir,
universal_newlines=True,
) as git_log:
timestamp_str, stderr = git_log.communicate()
if git_log.returncode != 0:
logger.error(f"Git command failed with error: {stderr}")
return None
# Convert the timestamp to a datetime object
timestamp = datetime.fromtimestamp(int(timestamp_str), tz=UTC)
return timestamp.strftime("%Y%m%d%H%M%S")
except Exception:
logger.exception("An error occurred while retrieving the git changeset.")
return None