amqtt
Advanced tools
| from enum import Enum | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
| _LOGGER = logging.getLogger(__name__) | ||
| if TYPE_CHECKING: | ||
| import asyncio | ||
| class BaseContext: | ||
| def __init__(self) -> None: | ||
| self.loop: asyncio.AbstractEventLoop | None = None | ||
| self.logger: logging.Logger = _LOGGER | ||
| self.config: dict[str, Any] | None = None | ||
| class Action(Enum): | ||
| """Actions issued by the broker.""" | ||
| SUBSCRIBE = "subscribe" | ||
| PUBLISH = "publish" |
+5
-0
@@ -37,1 +37,6 @@ #------- Package & Cache Files ------- | ||
| coverage.xml | ||
| #----- generated files ----- | ||
| *.log | ||
| *memray* | ||
| .coverage* |
| """INIT.""" | ||
| __version__ = "0.11.1" | ||
| __version__ = "0.11.2" |
+67
-35
@@ -6,3 +6,2 @@ import asyncio | ||
| import copy | ||
| from enum import Enum | ||
| from functools import partial | ||
@@ -27,2 +26,3 @@ import logging | ||
| ) | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.errors import AMQTTError, BrokerError, MQTTError, NoDataError | ||
@@ -34,4 +34,5 @@ from amqtt.mqtt.protocol.broker_handler import BrokerProtocolHandler | ||
| from .events import BrokerEvents | ||
| from .mqtt.constants import QOS_0, QOS_1, QOS_2 | ||
| from .mqtt.disconnect import DisconnectPacket | ||
| from .plugins.manager import BaseContext, PluginManager | ||
| from .plugins.manager import PluginManager | ||
@@ -50,9 +51,2 @@ _CONFIG_LISTENER: TypeAlias = dict[str, int | bool | dict[str, Any]] | ||
| class Action(Enum): | ||
| """Actions issued by the broker.""" | ||
| SUBSCRIBE = "subscribe" | ||
| PUBLISH = "publish" | ||
| class RetainedApplicationMessage(ApplicationMessage): | ||
@@ -171,2 +165,6 @@ __slots__ = ("data", "qos", "source_session", "topic") | ||
| if config is not None: | ||
| # if 'plugins' isn't in the config but 'auth'/'topic-check' is included, assume this is a legacy config | ||
| if ("auth" in config or "topic-check" in config) and "plugins" not in config: | ||
| # set to None so that the config isn't updated with the new-style default plugin list | ||
| config["plugins"] = None # type: ignore[assignment] | ||
| self.config.update(config) | ||
@@ -182,2 +180,4 @@ self._build_listeners_config(self.config) | ||
| self._topic_filter_matchers: dict[str, re.Pattern[str]] = {} | ||
| # Broadcast queue for outgoing messages | ||
@@ -447,2 +447,3 @@ self._broadcast_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() | ||
| client_session.client_id = gen_client_id() | ||
| client_session.parent = 0 | ||
@@ -501,9 +502,20 @@ # Get session from cache | ||
| await handler.mqtt_connack_authorize(authenticated) | ||
| await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED, client_id=client_session.client_id) | ||
| await self.plugins_manager.fire_event(BrokerEvents.CLIENT_CONNECTED, | ||
| client_id=client_session.client_id, | ||
| client_session=client_session) | ||
| self.logger.debug(f"{client_session.client_id} Start messages handling") | ||
| await handler.start() | ||
| # publish messages that were retained because the client session was disconnected | ||
| self.logger.debug(f"Retained messages queue size: {client_session.retained_messages.qsize()}") | ||
| await self._publish_session_retained_messages(client_session) | ||
| # if this is not a new session, there are subscriptions associated with them; publish any topic retained messages | ||
| self.logger.debug("Publish retained messages to a pre-existing session's subscriptions.") | ||
| for topic in self._subscriptions: | ||
| await self._publish_retained_messages_for_subscription( (topic, QOS_0), client_session) | ||
| await self._client_message_loop(client_session, handler) | ||
@@ -605,3 +617,5 @@ | ||
| client_session.transitions.disconnect() | ||
| await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED, client_id=client_session.client_id) | ||
| await self.plugins_manager.fire_event(BrokerEvents.CLIENT_DISCONNECTED, | ||
| client_id=client_session.client_id, | ||
| client_session=client_session) | ||
@@ -671,2 +685,7 @@ | ||
| return False | ||
| if app_message.topic.startswith("$"): | ||
| self.logger.warning( | ||
| f"[MQTT-4.7.2-1] - {client_session.client_id} cannot use a topic with a leading $ character." | ||
| ) | ||
| return False | ||
@@ -712,14 +731,17 @@ permitted = await self._topic_filtering(client_session, topic=app_message.topic, action=Action.PUBLISH) | ||
| returns = await self.plugins_manager.map_plugin_auth(session=session) | ||
| auth_result = True | ||
| if returns: | ||
| for plugin in returns: | ||
| res = returns[plugin] | ||
| if res is False: | ||
| auth_result = False | ||
| self.logger.debug(f"Authentication failed due to '{plugin.__class__}' plugin result: {res}") | ||
| else: | ||
| self.logger.debug(f"'{plugin.__class__}' plugin result: {res}") | ||
| # If all plugins returned True, authentication is success | ||
| return auth_result | ||
| results = [ result for _, result in returns.items() if result is not None] if returns else [] | ||
| if len(results) < 1: | ||
| self.logger.debug("Authentication failed: no plugin responded with a boolean") | ||
| return False | ||
| if all(results): | ||
| self.logger.debug("Authentication succeeded") | ||
| return True | ||
| for plugin, result in returns.items(): | ||
| self.logger.debug(f"Authentication '{plugin.__class__.__name__}' result: {result}") | ||
| return False | ||
| def retain_message( | ||
@@ -780,9 +802,3 @@ self, | ||
| """ | ||
| topic_config = self.config.get("topic-check", {}) | ||
| enabled = False | ||
| if isinstance(topic_config, dict): | ||
| enabled = topic_config.get("enabled", False) | ||
| if not enabled: | ||
| if not self.plugins_manager.is_topic_filtering_enabled(): | ||
| return True | ||
@@ -882,5 +898,2 @@ | ||
| for k_filter, subscriptions in self._subscriptions.items(): | ||
| if broadcast["topic"].startswith("$") and (k_filter.startswith(("+", "#"))): | ||
| self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #") | ||
| continue | ||
@@ -896,3 +909,8 @@ # Skip all subscriptions which do not match the topic | ||
| # Retain all messages which cannot be broadcasted, due to the session not being connected | ||
| if target_session.transitions.state != "connected": | ||
| # but only when clean session is false and qos is 1 or 2 [MQTT 3.1.2.4] | ||
| # and, if a client used anonymous authentication, there is no expectation that messages should be retained | ||
| if (target_session.transitions.state != "connected" | ||
| and not target_session.clean_session | ||
| and qos in (QOS_1, QOS_2) | ||
| and not target_session.is_anonymous): | ||
| self.logger.debug(f"Session {target_session.client_id} is not connected, retaining message.") | ||
@@ -902,2 +920,6 @@ await self._retain_broadcast_message(broadcast, qos, target_session) | ||
| # Only broadcast the message to connected clients | ||
| if target_session.transitions.state != "connected": | ||
| continue | ||
| self.logger.debug( | ||
@@ -1003,7 +1025,17 @@ f"Broadcasting message from {format_client_message(session=broadcast['session'])}" | ||
| def _matches(self, topic: str, a_filter: str) -> bool: | ||
| if topic.startswith("$") and (a_filter.startswith(("+", "#"))): | ||
| self.logger.debug("[MQTT-4.7.2-1] - ignoring broadcasting $ topic to subscriptions starting with + or #") | ||
| return False | ||
| if "#" not in a_filter and "+" not in a_filter: | ||
| # if filter doesn't contain wildcard, return exact match | ||
| return a_filter == topic | ||
| # else use regex | ||
| match_pattern = re.compile(re.escape(a_filter).replace("\\#", "?.*").replace("\\+", "[^/]*").lstrip("?")) | ||
| # else use regex (re.compile is an expensive operation, store the matcher for future use) | ||
| if a_filter not in self._topic_filter_matchers: | ||
| self._topic_filter_matchers[a_filter] = re.compile(re.escape(a_filter) | ||
| .replace("\\#", "?.*") | ||
| .replace("\\+", "[^/]*") | ||
| .lstrip("?")) | ||
| match_pattern = self._topic_filter_matchers[a_filter] | ||
| return bool(match_pattern.fullmatch(topic)) | ||
@@ -1010,0 +1042,0 @@ |
+3
-2
@@ -22,2 +22,3 @@ import asyncio | ||
| ) | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.errors import ClientError, ConnectError, ProtocolHandlerError | ||
@@ -27,3 +28,3 @@ from amqtt.mqtt.connack import CONNECTION_ACCEPTED | ||
| from amqtt.mqtt.protocol.client_handler import ClientProtocolHandler | ||
| from amqtt.plugins.manager import BaseContext, PluginManager | ||
| from amqtt.plugins.manager import PluginManager | ||
| from amqtt.session import ApplicationMessage, OutgoingApplicationMessage, Session | ||
@@ -602,3 +603,3 @@ from amqtt.utils import gen_client_id, read_yaml_config | ||
| if cleansession is not None: | ||
| broker_conf["cleansession"] = cleansession | ||
| broker_conf["cleansession"] = cleansession # noop? | ||
| session.clean_session = cleansession | ||
@@ -605,0 +606,0 @@ else: |
+7
-2
@@ -32,7 +32,12 @@ from typing import Any | ||
| class PluginImportError(PluginError): | ||
| def __init__(self, plugin: Any) -> None: | ||
| super().__init__(f"Plugin import failed: {plugin!r}") | ||
| """Exceptions thrown when loading plugin.""" | ||
| class PluginCoroError(PluginError): | ||
| """Exceptions thrown when loading a plugin with a non-async call method.""" | ||
| class PluginInitError(PluginError): | ||
| """Exceptions thrown when initializing plugin.""" | ||
| def __init__(self, plugin: Any) -> None: | ||
@@ -39,0 +44,0 @@ super().__init__(f"Plugin init failed: {plugin!r}") |
@@ -195,3 +195,3 @@ from asyncio import StreamReader | ||
| payload.client_id = gen_client_id() | ||
| # indicator to trow exception in case CLEAN_SESSION_FLAG is set to False | ||
| # indicator to throw exception in case CLEAN_SESSION_FLAG is set to False | ||
| payload.client_id_is_random = True | ||
@@ -198,0 +198,0 @@ |
@@ -23,2 +23,3 @@ import asyncio | ||
| from amqtt.adapters import ReaderAdapter, WriterAdapter | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.errors import AMQTTError, MQTTError, NoDataError, ProtocolHandlerError | ||
@@ -61,3 +62,3 @@ from amqtt.events import MQTTEvents | ||
| from amqtt.mqtt.unsubscribe import UnsubscribePacket | ||
| from amqtt.plugins.manager import BaseContext, PluginManager | ||
| from amqtt.plugins.manager import PluginManager | ||
| from amqtt.session import INCOMING, OUTGOING, ApplicationMessage, IncomingApplicationMessage, OutgoingApplicationMessage, Session | ||
@@ -524,4 +525,5 @@ | ||
| elif packet.fixed_header.packet_type == CONNECT and isinstance(packet, ConnectPacket): | ||
| # TODO: why is this not like all other inside create_task? | ||
| await self.handle_connect(packet) # task = asyncio.create_task(self.handle_connect(packet)) | ||
| # q: why is this not like all other inside a create_task? | ||
| # a: the connection needs to be established before any other packet tasks for this new session are scheduled | ||
| await self.handle_connect(packet) | ||
| if task: | ||
@@ -528,0 +530,0 @@ running_tasks.append(task) |
@@ -15,3 +15,3 @@ import asyncio | ||
| super().__init__() | ||
| if "*" in topic_name: | ||
| if "#" in topic_name or "+" in topic_name: | ||
| msg = "[MQTT-3.3.2-2] Topic name in the PUBLISH Packet MUST NOT contain wildcard characters." | ||
@@ -18,0 +18,0 @@ raise MQTTError(msg) |
@@ -0,1 +1,2 @@ | ||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
@@ -6,2 +7,3 @@ | ||
| from amqtt.broker import BrokerContext | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.plugins.base import BaseAuthPlugin | ||
@@ -16,9 +18,15 @@ from amqtt.session import Session | ||
| def __init__(self, context: BaseContext) -> None: | ||
| super().__init__(context) | ||
| # Default to allowing anonymous | ||
| self._allow_anonymous = self._get_config_option("allow-anonymous", True) # noqa: FBT003 | ||
| async def authenticate(self, *, session: Session) -> bool: | ||
| authenticated = await super().authenticate(session=session) | ||
| if authenticated: | ||
| # Default to allowing anonymous | ||
| allow_anonymous = self.auth_config.get("allow-anonymous", True) if isinstance(self.auth_config, dict) else True | ||
| if allow_anonymous: | ||
| if self._allow_anonymous: | ||
| self.context.logger.debug("Authentication success: config allows anonymous") | ||
| session.is_anonymous = True | ||
| return True | ||
@@ -32,3 +40,9 @@ | ||
| @dataclass | ||
| class Config: | ||
| """Allow empty username.""" | ||
| allow_anonymous: bool = field(default=True) | ||
| class FileAuthPlugin(BaseAuthPlugin): | ||
@@ -44,3 +58,3 @@ """Authentication plugin based on a file-stored user database.""" | ||
| """Read the password file and populates the user dictionary.""" | ||
| password_file = self.auth_config.get("password-file") if isinstance(self.auth_config, dict) else None | ||
| password_file = self._get_config_option("password-file", None) | ||
| if not password_file: | ||
@@ -51,3 +65,6 @@ self.context.logger.warning("Configuration parameter 'password-file' not found") | ||
| try: | ||
| with Path(password_file).open(mode="r", encoding="utf-8") as file: | ||
| file = password_file | ||
| if isinstance(file, str): | ||
| file = Path(file) | ||
| with file.open(mode="r", encoding="utf-8") as file: | ||
| self.context.logger.debug(f"Reading user database from {password_file}") | ||
@@ -95,1 +112,7 @@ for _line in file: | ||
| return False | ||
| @dataclass | ||
| class Config: | ||
| """Path to the properly encoded password file.""" | ||
| password_file: str | Path | None = None |
+70
-17
@@ -1,5 +0,5 @@ | ||
| from typing import Any, Generic, TypeVar | ||
| from dataclasses import dataclass, is_dataclass | ||
| from typing import Any, Generic, TypeVar, cast | ||
| from amqtt.broker import Action | ||
| from amqtt.plugins.manager import BaseContext | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.session import Session | ||
@@ -11,7 +11,28 @@ | ||
| class BasePlugin(Generic[C]): | ||
| """The base from which all plugins should inherit.""" | ||
| """The base from which all plugins should inherit. | ||
| Type Parameters | ||
| --------------- | ||
| C: | ||
| A BaseContext: either BrokerContext or ClientContext, depending on plugin usage | ||
| Attributes | ||
| ---------- | ||
| context (C): | ||
| Information about the environment in which this plugin is executed. Modifying | ||
| the broker or client state should happen through methods available here. | ||
| config (self.Config): | ||
| An instance of the Config dataclass defined by the plugin (or an empty dataclass, if not | ||
| defined). If using entrypoint- or mixed-style configuration, use `_get_config_option()` | ||
| to access the variable. | ||
| """ | ||
| def __init__(self, context: C) -> None: | ||
| self.context: C = context | ||
| # since the PluginManager will hydrate the config from a plugin's `Config` class, this is a safe cast | ||
| self.config = cast("self.Config", context.config) # type: ignore[name-defined] | ||
| # Deprecated: included to support entrypoint-style configs. Replaced by dataclass Config class. | ||
| def _get_config_section(self, name: str) -> dict[str, Any] | None: | ||
@@ -23,3 +44,3 @@ | ||
| section_config: int | dict[str, Any] | None = self.context.config.get(name, None) | ||
| # mypy has difficulty excluding int from `config`'s type, unless isinstance` is its own check | ||
| # mypy has difficulty excluding int from `config`'s type, unless there's an explicit check | ||
| if isinstance(section_config, int): | ||
@@ -29,2 +50,18 @@ return None | ||
| # Deprecated : supports entrypoint-style configs as well as dataclass configuration. | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| if not self.context.config: | ||
| return default | ||
| if is_dataclass(self.context.config): | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| if option_name in self.context.config: | ||
| return self.context.config[option_name] | ||
| return default | ||
| @dataclass | ||
| class Config: | ||
| """Override to define the configuration and defaults for plugin.""" | ||
| async def close(self) -> None: | ||
@@ -41,5 +78,16 @@ """Override if plugin needs to clean up resources upon shutdown.""" | ||
| self.topic_config: dict[str, Any] | None = self._get_config_section("topic-check") | ||
| if self.topic_config is None: | ||
| if not bool(self.topic_config) and not is_dataclass(self.context.config): | ||
| self.context.logger.warning("'topic-check' section not found in context configuration") | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| if not self.context.config: | ||
| return default | ||
| if is_dataclass(self.context.config): | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| if self.topic_config and option_name in self.topic_config: | ||
| return self.topic_config[option_name] | ||
| return default | ||
| async def topic_filtering( | ||
@@ -59,7 +107,3 @@ self, *, session: Session | None = None, topic: str | None = None, action: Action | None = None | ||
| """ | ||
| if not self.topic_config: | ||
| # auth config section not found | ||
| self.context.logger.warning("'topic-check' section not found in context configuration") | ||
| return False | ||
| return True | ||
| return bool(self.topic_config) or is_dataclass(self.context.config) | ||
@@ -70,2 +114,13 @@ | ||
| def _get_config_option(self, option_name: str, default: Any=None) -> Any: | ||
| if not self.context.config: | ||
| return default | ||
| if is_dataclass(self.context.config): | ||
| # overloaded context.config for BasePlugin `Config` class, so ignoring static type check | ||
| return getattr(self.context.config, option_name.replace("-", "_"), default) # type: ignore[unreachable] | ||
| if self.auth_config and option_name in self.auth_config: | ||
| return self.auth_config[option_name] | ||
| return default | ||
| def __init__(self, context: BaseContext) -> None: | ||
@@ -75,5 +130,7 @@ super().__init__(context) | ||
| self.auth_config: dict[str, Any] | None = self._get_config_section("auth") | ||
| if not self.auth_config: | ||
| if not bool(self.auth_config) and not is_dataclass(self.context.config): | ||
| # auth config section not found and Config dataclass not provided | ||
| self.context.logger.warning("'auth' section not found in context configuration") | ||
| async def authenticate(self, *, session: Session) -> bool | None: | ||
@@ -90,6 +147,2 @@ """Logic for session authentication. | ||
| """ | ||
| if not self.auth_config: | ||
| # auth config section not found | ||
| self.context.logger.warning("'auth' section not found in context configuration") | ||
| return False | ||
| return True | ||
| return bool(self.auth_config) or is_dataclass(self.context.config) |
@@ -6,2 +6,3 @@ from collections.abc import Callable, Coroutine | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.events import BrokerEvents | ||
@@ -11,3 +12,2 @@ from amqtt.mqtt import MQTTPacket | ||
| from amqtt.plugins.base import BasePlugin | ||
| from amqtt.plugins.manager import BaseContext | ||
| from amqtt.session import Session | ||
@@ -14,0 +14,0 @@ |
+159
-35
@@ -1,2 +0,2 @@ | ||
| __all__ = ["BaseContext", "PluginManager", "get_plugin_manager"] | ||
| __all__ = ["PluginManager", "get_plugin_manager"] | ||
@@ -11,15 +11,15 @@ import asyncio | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar | ||
| from typing import Any, Generic, NamedTuple, Optional, TypeAlias, TypeVar, cast | ||
| import warnings | ||
| from amqtt.errors import PluginImportError, PluginInitError | ||
| from dacite import Config as DaciteConfig, DaciteError, from_dict | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.errors import PluginCoroError, PluginImportError, PluginInitError, PluginLoadError | ||
| from amqtt.events import BrokerEvents, Events, MQTTEvents | ||
| from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin | ||
| from amqtt.session import Session | ||
| from amqtt.utils import import_string | ||
| _LOGGER = logging.getLogger(__name__) | ||
| if TYPE_CHECKING: | ||
| from amqtt.broker import Action | ||
| from amqtt.plugins.base import BaseAuthPlugin, BasePlugin, BaseTopicPlugin | ||
| class Plugin(NamedTuple): | ||
@@ -43,7 +43,7 @@ name: str | ||
| class BaseContext: | ||
| def __init__(self) -> None: | ||
| self.loop: asyncio.AbstractEventLoop | None = None | ||
| self.logger: logging.Logger = _LOGGER | ||
| self.config: dict[str, Any] | None = None | ||
| def safe_issubclass(sub_class: Any, super_class: Any) -> bool: | ||
| try: | ||
| return issubclass(sub_class, super_class) | ||
| except TypeError: | ||
| return False | ||
@@ -75,2 +75,5 @@ | ||
| self._event_plugin_callbacks: dict[str, list[AsyncFunc]] = defaultdict(list) | ||
| self._is_topic_filtering_enabled = False | ||
| self._is_auth_filtering_enabled = False | ||
| self._load_plugins(namespace) | ||
@@ -84,12 +87,75 @@ self._fired_events: list[asyncio.Future[Any]] = [] | ||
| def _load_plugins(self, namespace: str) -> None: | ||
| def _load_plugins(self, namespace: str | None = None) -> None: | ||
| """Load plugins from entrypoint or config dictionary. | ||
| config style is now recommended; entrypoint has been deprecated | ||
| Example: | ||
| config = { | ||
| 'listeners':..., | ||
| 'plugins': { | ||
| 'myproject.myfile.MyPlugin': {} | ||
| } | ||
| """ | ||
| if self.app_context.config and self.app_context.config.get("plugins", None) is not None: | ||
| # plugins loaded directly from config dictionary | ||
| if "auth" in self.app_context.config: | ||
| self.logger.warning("Loading plugins from config will ignore 'auth' section of config") | ||
| if "topic-check" in self.app_context.config: | ||
| self.logger.warning("Loading plugins from config will ignore 'topic-check' section of config") | ||
| plugins_config: list[Any] | dict[str, Any] = self.app_context.config.get("plugins", []) | ||
| # if the config was generated from yaml, the plugins maybe a list instead of a dictionary; transform before loading | ||
| # | ||
| # plugins: | ||
| # - myproject.myfile.MyPlugin: | ||
| if isinstance(plugins_config, list): | ||
| plugins_info: dict[str, Any] = {} | ||
| for plugin_config in plugins_config: | ||
| if isinstance(plugin_config, str): | ||
| plugins_info.update({plugin_config: {}}) | ||
| elif not isinstance(plugin_config, dict): | ||
| msg = "malformed 'plugins' configuration" | ||
| raise PluginLoadError(msg) | ||
| else: | ||
| plugins_info.update(plugin_config) | ||
| self._load_str_plugins(plugins_info) | ||
| elif isinstance(plugins_config, dict): | ||
| self._load_str_plugins(plugins_config) | ||
| else: | ||
| if not namespace: | ||
| msg = "Namespace needs to be provided for EntryPoint plugin definitions" | ||
| raise PluginLoadError(msg) | ||
| warnings.warn( | ||
| "Loading plugins from EntryPoints is deprecated and will be removed in a future version." | ||
| " Use `plugins` section of config instead.", | ||
| DeprecationWarning, | ||
| stacklevel=2 | ||
| ) | ||
| self._load_ep_plugins(namespace) | ||
| # for all the loaded plugins, find all event callbacks | ||
| for plugin in self._plugins: | ||
| for event in list(BrokerEvents) + list(MQTTEvents): | ||
| if awaitable := getattr(plugin, f"on_{event}", None): | ||
| if not iscoroutinefunction(awaitable): | ||
| msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'" | ||
| raise PluginImportError(msg) | ||
| self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'") | ||
| self._event_plugin_callbacks[event].append(awaitable) | ||
| def _load_ep_plugins(self, namespace:str) -> None: | ||
| """Load plugins from `pyproject.toml` entrypoints. Deprecated.""" | ||
| self.logger.debug(f"Loading plugins for namespace {namespace}") | ||
| auth_filter_list = [] | ||
| topic_filter_list = [] | ||
| if self.app_context.config and "auth" in self.app_context.config: | ||
| auth_filter_list = self.app_context.config["auth"].get("plugins", []) | ||
| auth_filter_list = self.app_context.config["auth"].get("plugins", None) | ||
| if self.app_context.config and "topic-check" in self.app_context.config: | ||
| topic_filter_list = self.app_context.config["topic-check"].get("plugins", []) | ||
| topic_filter_list = self.app_context.config["topic-check"].get("plugins", None) | ||
@@ -106,6 +172,8 @@ ep: EntryPoints | list[EntryPoint] = [] | ||
| self._plugins.append(ep_plugin.object) | ||
| if ((not auth_filter_list or ep_plugin.name in auth_filter_list) | ||
| # maintain legacy behavior that if there is no list, use all auth plugins | ||
| if ((auth_filter_list is None or ep_plugin.name in auth_filter_list) | ||
| and hasattr(ep_plugin.object, "authenticate")): | ||
| self._auth_plugins.append(ep_plugin.object) | ||
| if ((not topic_filter_list or ep_plugin.name in topic_filter_list) | ||
| # maintain legacy behavior that if there is no list, use all topic plugins | ||
| if ((topic_filter_list is None or ep_plugin.name in topic_filter_list) | ||
| and hasattr(ep_plugin.object, "topic_filtering")): | ||
@@ -115,12 +183,4 @@ self._topic_plugins.append(ep_plugin.object) | ||
| for plugin in self._plugins: | ||
| for event in list(BrokerEvents) + list(MQTTEvents): | ||
| if awaitable := getattr(plugin, f"on_{event}", None): | ||
| if not iscoroutinefunction(awaitable): | ||
| msg = f"'on_{event}' for '{plugin.__class__.__name__}' is not a coroutine'" | ||
| raise PluginImportError(msg) | ||
| self.logger.debug(f"'{event}' handler found for '{plugin.__class__.__name__}'") | ||
| self._event_plugin_callbacks[event].append(awaitable) | ||
| def _load_ep_plugin(self, ep: EntryPoint) -> Plugin | None: | ||
| """Load plugins from `pyproject.toml` entrypoints. Deprecated.""" | ||
| try: | ||
@@ -145,5 +205,64 @@ self.logger.debug(f" Loading plugin {ep!s}") | ||
| def _load_str_plugins(self, plugins_info: dict[str, Any]) -> None: | ||
| self.logger.info("Loading plugins from config") | ||
| # legacy had a filtering 'enabled' flag, even if plugins were loaded/listed | ||
| self._is_topic_filtering_enabled = True | ||
| self._is_auth_filtering_enabled = True | ||
| for plugin_path, plugin_config in plugins_info.items(): | ||
| plugin = self._load_str_plugin(plugin_path, plugin_config) | ||
| self._plugins.append(plugin) | ||
| # make sure that authenticate and topic filtering plugins have the appropriate async signature | ||
| if isinstance(plugin, BaseAuthPlugin): | ||
| if not iscoroutinefunction(plugin.authenticate): | ||
| msg = f"Auth plugin {plugin_path} has non-async authenticate method." | ||
| raise PluginCoroError(msg) | ||
| self._auth_plugins.append(plugin) | ||
| if isinstance(plugin, BaseTopicPlugin): | ||
| if not iscoroutinefunction(plugin.topic_filtering): | ||
| msg = f"Topic plugin {plugin_path} has non-async topic_filtering method." | ||
| raise PluginCoroError(msg) | ||
| self._topic_plugins.append(plugin) | ||
| def _load_str_plugin(self, plugin_path: str, plugin_cfg: dict[str, Any] | None = None) -> "BasePlugin[C]": | ||
| """Load plugin from string dotted path: mymodule.myfile.MyPlugin.""" | ||
| try: | ||
| plugin_class: Any = import_string(plugin_path) | ||
| except ImportError as ep: | ||
| msg = f"Plugin import failed: {plugin_path}" | ||
| raise PluginImportError(msg) from ep | ||
| if not safe_issubclass(plugin_class, BasePlugin): | ||
| msg = f"Plugin {plugin_path} is not a subclass of 'BasePlugin'" | ||
| raise PluginLoadError(msg) | ||
| plugin_context = copy.copy(self.app_context) | ||
| plugin_context.logger = self.logger.getChild(plugin_class.__name__) | ||
| try: | ||
| # populate the config based on the inner dataclass called `Config` | ||
| # use `dacite` package to type check | ||
| plugin_context.config = from_dict(data_class=plugin_class.Config, | ||
| data=plugin_cfg or {}, | ||
| config=DaciteConfig(strict=True)) | ||
| except DaciteError as e: | ||
| raise PluginLoadError from e | ||
| except TypeError as e: | ||
| msg = f"Could not marshall 'Config' of {plugin_path}; should be a dataclass." | ||
| raise PluginLoadError(msg) from e | ||
| try: | ||
| pc = plugin_class(plugin_context) | ||
| self.logger.debug(f"Loading plugin {plugin_path}") | ||
| return cast("BasePlugin[C]", pc) | ||
| except Exception as e: | ||
| self.logger.debug(f"Plugin init failed: {plugin_class.__name__}", exc_info=True) | ||
| raise PluginInitError(plugin_class) from e | ||
| def get_plugin(self, name: str) -> Optional["BasePlugin[C]"]: | ||
| """Get a plugin by its name from the plugins loaded for the current namespace. | ||
| Only used for testing purposes to verify plugin loading correctly. | ||
| :param name: | ||
@@ -157,2 +276,8 @@ :return: | ||
| def is_topic_filtering_enabled(self) -> bool: | ||
| topic_config = self.app_context.config.get("topic-check", {}) if self.app_context.config else {} | ||
| if isinstance(topic_config, dict): | ||
| return topic_config.get("enabled", False) or self._is_topic_filtering_enabled | ||
| return False or self._is_topic_filtering_enabled | ||
| async def close(self) -> None: | ||
@@ -176,2 +301,6 @@ """Free PluginManager resources and cancel pending event methods.""" | ||
| def _clean_fired_events(self, future: asyncio.Future[Any]) -> None: | ||
| with contextlib.suppress(KeyError, ValueError): | ||
| self._fired_events.remove(future) | ||
| async def fire_event(self, event_name: Events, *, wait: bool = False, **method_kwargs: Any) -> None: | ||
@@ -202,9 +331,4 @@ """Fire an event to plugins. | ||
| tasks.append(asyncio.ensure_future(coro_instance)) | ||
| tasks[-1].add_done_callback(self._clean_fired_events) | ||
| def clean_fired_events(future: asyncio.Future[Any]) -> None: | ||
| with contextlib.suppress(KeyError, ValueError): | ||
| self._fired_events.remove(future) | ||
| tasks[-1].add_done_callback(clean_fired_events) | ||
| self._fired_events.extend(tasks) | ||
@@ -211,0 +335,0 @@ if wait and tasks: |
@@ -5,3 +5,3 @@ import json | ||
| from amqtt.plugins.manager import BaseContext | ||
| from amqtt.contexts import BaseContext | ||
| from amqtt.session import Session | ||
@@ -8,0 +8,0 @@ |
| import asyncio | ||
| from collections import deque # pylint: disable=C0412 | ||
| from dataclasses import dataclass | ||
| from typing import Any, SupportsIndex, SupportsInt, TypeAlias # pylint: disable=C0412 | ||
@@ -73,4 +74,7 @@ | ||
| self._sys_handle: asyncio.Handle | None = None | ||
| self._sys_interval: int = 0 | ||
| self._current_process = psutil.Process() | ||
| def _clear_stats(self) -> None: | ||
@@ -116,10 +120,10 @@ """Initialize broker statistics data structures.""" | ||
| try: | ||
| sys_interval: int = 0 | ||
| x = self.context.config.get("sys_interval") if self.context.config is not None else None | ||
| if isinstance(x, str | Buffer | SupportsInt | SupportsIndex): | ||
| sys_interval = int(x) | ||
| if sys_interval > 0: | ||
| self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds") | ||
| self._sys_interval = self._get_config_option("sys_interval", None) | ||
| if isinstance(self._sys_interval, str | Buffer | SupportsInt | SupportsIndex): | ||
| self._sys_interval = int(self._sys_interval) | ||
| if self._sys_interval > 0: | ||
| self.context.logger.debug(f"Setup $SYS broadcasting every {self._sys_interval} seconds") | ||
| self._sys_handle = ( | ||
| self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics) | ||
| self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics) | ||
| if self.context.loop is not None | ||
@@ -131,3 +135,3 @@ else None | ||
| except KeyError: | ||
| pass | ||
| self.context.logger.debug("could not find 'sys_interval' key: {e!r}") | ||
| # 'sys_interval' config parameter not found | ||
@@ -200,11 +204,5 @@ | ||
| # Reschedule | ||
| sys_interval: int = 0 | ||
| x = self.context.config.get("sys_interval") if self.context.config is not None else None | ||
| if isinstance(x, str | Buffer | SupportsInt | SupportsIndex): | ||
| sys_interval = int(x) | ||
| self.context.logger.debug("Broadcasting $SYS topics") | ||
| self.context.logger.debug(f"Setup $SYS broadcasting every {sys_interval} seconds") | ||
| self.context.logger.debug(f"Broadcast $SYS topics again in {self._sys_interval} seconds.") | ||
| self._sys_handle = ( | ||
| self.context.loop.call_later(sys_interval, self.broadcast_dollar_sys_topics) | ||
| self.context.loop.call_later(self._sys_interval, self.broadcast_dollar_sys_topics) | ||
| if self.context.loop is not None | ||
@@ -232,3 +230,3 @@ else None | ||
| async def on_broker_client_connected(self, client_id: str) -> None: | ||
| async def on_broker_client_connected(self, client_id: str, client_session: Session) -> None: | ||
| """Handle broker client connection.""" | ||
@@ -241,5 +239,11 @@ self._stats[STAT_CLIENTS_CONNECTED] += 1 | ||
| async def on_broker_client_disconnected(self, client_id: str) -> None: | ||
| async def on_broker_client_disconnected(self, client_id: str, client_session: Session) -> None: | ||
| """Handle broker client disconnection.""" | ||
| self._stats[STAT_CLIENTS_CONNECTED] -= 1 | ||
| self._stats[STAT_CLIENTS_DISCONNECTED] += 1 | ||
| @dataclass | ||
| class Config: | ||
| """Configuration struct for plugin.""" | ||
| sys_interval: int = 20 |
@@ -0,6 +1,6 @@ | ||
| from dataclasses import dataclass, field | ||
| from typing import Any | ||
| from amqtt.broker import Action | ||
| from amqtt.contexts import Action, BaseContext | ||
| from amqtt.plugins.base import BaseTopicPlugin | ||
| from amqtt.plugins.manager import BaseContext | ||
| from amqtt.session import Session | ||
@@ -55,3 +55,3 @@ | ||
| # hbmqtt and older amqtt do not support publish filtering | ||
| if action == Action.PUBLISH and self.topic_config is not None and "publish-acl" not in self.topic_config: | ||
| if action == Action.PUBLISH and not self._get_config_option("publish-acl", {}): | ||
| # maintain backward compatibility, assume permitted | ||
@@ -62,3 +62,3 @@ return True | ||
| if not req_topic: | ||
| return False | ||
| return False\ | ||
@@ -70,6 +70,7 @@ username = session.username if session else None | ||
| acl: dict[str, Any] = {} | ||
| if self.topic_config is not None and action == Action.PUBLISH: | ||
| acl = self.topic_config.get("publish-acl", {}) | ||
| elif self.topic_config is not None and action == Action.SUBSCRIBE: | ||
| acl = self.topic_config.get("acl", {}) | ||
| match action: | ||
| case Action.PUBLISH: | ||
| acl = self._get_config_option("publish-acl", {}) | ||
| case Action.SUBSCRIBE: | ||
| acl = self._get_config_option("acl", {}) | ||
@@ -81,1 +82,8 @@ allowed_topics = acl.get(username, []) | ||
| return any(self.topic_ac(req_topic, allowed_topic) for allowed_topic in allowed_topics) | ||
| @dataclass | ||
| class Config: | ||
| """Mappings of username and list of approved topics.""" | ||
| publish_acl: dict[str, list[str]] = field(default_factory=dict) | ||
| acl: dict[str, list[str]] = field(default_factory=dict) |
@@ -6,8 +6,8 @@ --- | ||
| bind: 0.0.0.0:1883 | ||
| sys_interval: 20 | ||
| auth: | ||
| plugins: | ||
| - auth_anonymous | ||
| allow-anonymous: true | ||
| topic-check: | ||
| enabled: False | ||
| plugins: | ||
| amqtt.plugins.logging_amqtt.EventLoggerPlugin: | ||
| amqtt.plugins.logging_amqtt.PacketLoggerPlugin: | ||
| amqtt.plugins.authentication.AnonymousAuthPlugin: | ||
| allow_anonymous: true | ||
| amqtt.plugins.sys.broker.BrokerSysPlugin: | ||
| sys_interval: 20 |
@@ -7,5 +7,8 @@ --- | ||
| auto_reconnect: true | ||
| cleansession: true | ||
| reconnect_max_interval: 10 | ||
| reconnect_retries: 2 | ||
| broker: | ||
| uri: "mqtt://127.0.0.1" | ||
| uri: "mqtt://127.0.0.1" | ||
| plugins: | ||
| amqtt.plugins.logging_amqtt.PacketLoggerPlugin: |
+4
-1
@@ -148,3 +148,3 @@ from asyncio import Queue | ||
| # Stores messages retained for this session | ||
| # Stores messages retained for this session (specifically when the client is disconnected) | ||
| self.retained_messages: Queue[ApplicationMessage] = Queue() | ||
@@ -155,2 +155,5 @@ | ||
| # identify anonymous client sessions or clients which didn't identify themselves | ||
| self.is_anonymous: bool = False | ||
| def _init_states(self) -> None: | ||
@@ -157,0 +160,0 @@ self.transitions = Machine(states=Session.states, initial="new") |
+38
-0
| from __future__ import annotations | ||
| from importlib import import_module | ||
| import logging | ||
@@ -7,2 +8,3 @@ from pathlib import Path | ||
| import string | ||
| import sys | ||
| import typing | ||
@@ -52,1 +54,37 @@ from typing import Any | ||
| return None | ||
| def cached_import(module_path: str, class_name: str | None = None) -> Any: | ||
| """Return cached import of a class from a module path (or retrieve, cache and then return).""" | ||
| # Check whether module is loaded and fully initialized. | ||
| if not ((module := sys.modules.get(module_path)) | ||
| and (spec := getattr(module, "__spec__", None)) | ||
| and getattr(spec, "_initializing", False) is False): | ||
| module = import_module(module_path) | ||
| if class_name: | ||
| return getattr(module, class_name) | ||
| return module | ||
| def import_string(dotted_path: str) -> Any: | ||
| """Import a dotted module path. | ||
| Returns: | ||
| attribute/class designated by the last name in the path | ||
| Raises: | ||
| ImportError (if the import failed) | ||
| """ | ||
| try: | ||
| module_path, class_name = dotted_path.rsplit(".", 1) | ||
| except ValueError as err: | ||
| msg = f"{dotted_path} doesn't look like a module path" | ||
| raise ImportError(msg) from err | ||
| try: | ||
| return cached_import(module_path, class_name) | ||
| except AttributeError as err: | ||
| msg = f'Module "{module_path}" does not define a "{class_name}" attribute/class' | ||
| raise ImportError(msg) from err |
+2
-1
| Metadata-Version: 2.4 | ||
| Name: amqtt | ||
| Version: 0.11.1 | ||
| Version: 0.11.2 | ||
| Summary: Python's asyncio-native MQTT broker and client. | ||
@@ -20,2 +20,3 @@ Author: aMQTT Contributors | ||
| Requires-Python: >=3.10.0 | ||
| Requires-Dist: dacite>=1.9.2 | ||
| Requires-Dist: passlib==1.7.4 | ||
@@ -22,0 +23,0 @@ Requires-Dist: psutil>=7.0.0 |
+2
-1
@@ -23,3 +23,3 @@ [build-system] | ||
| version = "0.11.1" | ||
| version = "0.11.2" | ||
| requires-python = ">=3.10.0" | ||
@@ -37,2 +37,3 @@ readme = "README.md" | ||
| "typer==0.15.4", | ||
| "dacite>=1.9.2", | ||
| "psutil>=7.0.0", | ||
@@ -39,0 +40,0 @@ ] |
| try: | ||
| from datetime import UTC, datetime | ||
| except ImportError: | ||
| from datetime import datetime, timezone | ||
| UTC = timezone.utc | ||
| import logging | ||
| from pathlib import Path | ||
| import shutil | ||
| import subprocess | ||
| import warnings | ||
| import amqtt | ||
| logger = logging.getLogger(__name__) | ||
| def get_version() -> str: | ||
| """Return the version of the amqtt package. | ||
| This function is deprecated. Use amqtt.__version__ instead. | ||
| """ | ||
| warnings.warn( | ||
| "amqtt.version.get_version() is deprecated, use amqtt.__version__ instead", | ||
| stacklevel=3, # Adjusted stack level to better reflect the caller | ||
| ) | ||
| return amqtt.__version__ | ||
| def get_git_changeset() -> str | None: | ||
| """Return a numeric identifier of the latest git changeset. | ||
| The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format. | ||
| This value isn't guaranteed to be unique, but collisions are very unlikely, | ||
| so it's sufficient for generating the development version numbers. | ||
| """ | ||
| # Define the repository directory (two levels above the current script) | ||
| repo_dir = Path(__file__).resolve().parent.parent | ||
| # Ensure the directory exists and is valid | ||
| if not repo_dir.is_dir(): | ||
| logger.error(f"Invalid directory: {repo_dir} is not a valid directory") | ||
| return None | ||
| # Use the system's PATH to locate 'git', or define the full path if necessary | ||
| git_path = "git" # Assuming git is available in the system PATH | ||
| # Ensure 'git' is executable and available | ||
| if not shutil.which(git_path): | ||
| logger.error(f"{git_path} is not found in the system PATH.") | ||
| return None | ||
| # Call git log to get the latest changeset timestamp | ||
| try: | ||
| with subprocess.Popen( # noqa: S603 | ||
| [git_path, "log", "--pretty=format:%ct", "--quiet", "-1", "HEAD"], | ||
| stdout=subprocess.PIPE, | ||
| stderr=subprocess.PIPE, | ||
| cwd=repo_dir, | ||
| universal_newlines=True, | ||
| ) as git_log: | ||
| timestamp_str, stderr = git_log.communicate() | ||
| if git_log.returncode != 0: | ||
| logger.error(f"Git command failed with error: {stderr}") | ||
| return None | ||
| # Convert the timestamp to a datetime object | ||
| timestamp = datetime.fromtimestamp(int(timestamp_str), tz=UTC) | ||
| return timestamp.strftime("%Y%m%d%H%M%S") | ||
| except Exception: | ||
| logger.exception("An error occurred while retrieving the git changeset.") | ||
| return None |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
299568
4.05%5777
3.36%