You're Invited:Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26.RSVP
Socket
Book a DemoSign in
Socket

auth0-python

Package Overview
Dependencies
Maintainers
1
Versions
69
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

auth0-python - pypi Package Compare versions

Comparing version
4.2.0
to
4.3.0
+5
auth0/types.py
from typing import Any, Dict, List, Tuple, Union
TimeoutType = Union[float, Tuple[float, float]]
RequestData = Union[Dict[str, Any], List[Any]]
+1
-1
Metadata-Version: 2.1
Name: auth0-python
Version: 4.2.0
Version: 4.3.0
Summary: Auth0 Python SDK

@@ -5,0 +5,0 @@ Home-page: https://github.com/auth0/auth0-python

@@ -9,2 +9,3 @@ LICENSE

auth0/rest_async.py
auth0/types.py
auth0/utils.py

@@ -11,0 +12,0 @@ auth0/authentication/__init__.py

@@ -1,2 +0,2 @@

__version__ = "4.2.0"
__version__ = "4.3.0"

@@ -3,0 +3,0 @@ from auth0.exceptions import Auth0Error, RateLimitError, TokenValidationError

import aiohttp
from auth0.authentication.base import AuthenticationBase
from auth0.rest import RestClientOptions
from auth0.rest_async import AsyncRestClient

@@ -22,3 +24,3 @@

class AsyncClient(cls):
class AsyncManagementClient(cls):
def __init__(

@@ -33,10 +35,3 @@ self,

):
if token is None:
# Wrap the auth client
super().__init__(domain, telemetry, timeout, protocol)
else:
# Wrap the mngtmt client
super().__init__(
domain, token, telemetry, timeout, protocol, rest_options
)
super().__init__(domain, token, telemetry, timeout, protocol, rest_options)
self.client = AsyncRestClient(

@@ -46,24 +41,38 @@ jwt=token, telemetry=telemetry, timeout=timeout, options=rest_options

class Wrapper(cls):
class AsyncAuthenticationClient(cls):
def __init__(
self,
domain,
token=None,
client_id,
client_secret=None,
client_assertion_signing_key=None,
client_assertion_signing_alg=None,
telemetry=True,
timeout=5.0,
protocol="https",
rest_options=None,
):
if token is None:
# Wrap the auth client
super().__init__(domain, telemetry, timeout, protocol)
super().__init__(
domain,
client_id,
client_secret,
client_assertion_signing_key,
client_assertion_signing_alg,
telemetry,
timeout,
protocol,
)
self.client = AsyncRestClient(
None,
options=RestClientOptions(
telemetry=telemetry, timeout=timeout, retries=0
),
)
class Wrapper(cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if AuthenticationBase in cls.__bases__:
self._async_client = AsyncAuthenticationClient(*args, **kwargs)
else:
# Wrap the mngtmt client
super().__init__(
domain, token, telemetry, timeout, protocol, rest_options
)
self._async_client = AsyncClient(
domain, token, telemetry, timeout, protocol, rest_options
)
self._async_client = AsyncManagementClient(*args, **kwargs)
for method in methods:

@@ -70,0 +79,0 @@ setattr(

"""Token Verifier module"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from .. import TokenValidationError

@@ -6,3 +10,7 @@ from ..rest_async import AsyncRestClient

if TYPE_CHECKING:
from aiohttp import ClientSession
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
class AsyncAsymmetricSignatureVerifier(AsymmetricSignatureVerifier):

@@ -16,7 +24,7 @@ """Async verifier for RSA signatures, which rely on public key certificates.

def __init__(self, jwks_url, algorithm="RS256"):
def __init__(self, jwks_url: str, algorithm: str = "RS256") -> None:
super().__init__(jwks_url, algorithm)
self._fetcher = AsyncJwksFetcher(jwks_url)
def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

@@ -37,3 +45,3 @@

async def verify_signature(self, token):
async def verify_signature(self, token) -> dict[str, Any]:
"""Verifies the signature of the given JSON web token.

@@ -63,7 +71,7 @@

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_client = AsyncRestClient(None)
def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

@@ -77,3 +85,3 @@

async def _fetch_jwks(self, force=False):
async def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.

@@ -98,3 +106,3 @@ When not, it will perform a network request to the jwks_url to obtain a fresh result

async def get_key(self, key_id):
async def get_key(self, key_id: str) -> RSAPublicKey:
"""Obtains the JWK associated with the given key id.

@@ -135,3 +143,9 @@

def __init__(self, signature_verifier, issuer, audience, leeway=0):
def __init__(
self,
signature_verifier: AsyncAsymmetricSignatureVerifier,
issuer: str,
audience: str,
leeway: int = 0,
) -> None:
if not signature_verifier or not isinstance(

@@ -150,3 +164,3 @@ signature_verifier, AsyncAsymmetricSignatureVerifier

def set_session(self, session):
def set_session(self, session: ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

@@ -160,3 +174,9 @@

async def verify(self, token, nonce=None, max_age=None, organization=None):
async def verify(
self,
token: str,
nonce: str | None = None,
max_age: int | None = None,
organization: str | None = None,
) -> dict[str, Any]:
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

@@ -163,0 +183,0 @@

@@ -0,2 +1,7 @@

from __future__ import annotations
from typing import Any
from auth0.rest import RestClient, RestClientOptions
from auth0.types import RequestData, TimeoutType

@@ -24,11 +29,11 @@ from .client_authentication import add_client_authentication

self,
domain,
client_id,
client_secret=None,
client_assertion_signing_key=None,
client_assertion_signing_alg=None,
telemetry=True,
timeout=5.0,
protocol="https",
):
domain: str,
client_id: str,
client_secret: str | None = None,
client_assertion_signing_key: str | None = None,
client_assertion_signing_alg: str | None = None,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
) -> None:
self.domain = domain

@@ -45,3 +50,3 @@ self.client_id = client_id

def _add_client_authentication(self, payload):
def _add_client_authentication(self, payload: dict[str, Any]) -> dict[str, Any]:
return add_client_authentication(

@@ -56,6 +61,16 @@ payload,

def post(self, url, data=None, headers=None):
def post(
self,
url: str,
data: RequestData | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(url, data=data, headers=headers)
def authenticated_post(self, url, data=None, headers=None):
def authenticated_post(
self,
url: str,
data: dict[str, Any],
headers: dict[str, str] | None = None,
) -> Any:
return self.client.post(

@@ -65,3 +80,8 @@ url, data=self._add_client_authentication(data), headers=headers

def get(self, url, params=None, headers=None):
def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
return self.client.get(url, params, headers)

@@ -0,3 +1,6 @@

from __future__ import annotations
import datetime
import uuid
from typing import Any

@@ -8,4 +11,7 @@ import jwt

def create_client_assertion_jwt(
domain, client_id, client_assertion_signing_key, client_assertion_signing_alg
):
domain: str,
client_id: str,
client_assertion_signing_key: str,
client_assertion_signing_alg: str | None,
) -> str:
"""Creates a JWT for the client_assertion field.

@@ -16,3 +22,3 @@

client_id (str): Your application's client ID
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT
client_assertion_signing_key (str): Private key used to sign the client assertion JWT
client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')

@@ -40,9 +46,9 @@

def add_client_authentication(
payload,
domain,
client_id,
client_secret,
client_assertion_signing_key,
client_assertion_signing_alg,
):
payload: dict[str, Any],
domain: str,
client_id: str,
client_secret: str | None,
client_assertion_signing_key: str | None,
client_assertion_signing_alg: str | None,
) -> dict[str, Any]:
"""Adds the client_assertion or client_secret fields to authenticate a payload.

@@ -54,3 +60,3 @@

client_id (str): Your application's client ID
client_secret (str): Your application's client secret
client_secret (str, optional): Your application's client secret
client_assertion_signing_key (str, optional): Private key used to sign the client assertion JWT

@@ -57,0 +63,0 @@ client_assertion_signing_alg (str, optional): Algorithm used to sign the client assertion JWT (defaults to 'RS256')

@@ -1,3 +0,5 @@

import warnings
from __future__ import annotations
from typing import Any
from .base import AuthenticationBase

@@ -15,13 +17,13 @@

self,
email,
password,
connection,
username=None,
user_metadata=None,
given_name=None,
family_name=None,
name=None,
nickname=None,
picture=None,
):
email: str,
password: str,
connection: str,
username: str | None = None,
user_metadata: dict[str, Any] | None = None,
given_name: str | None = None,
family_name: str | None = None,
name: str | None = None,
nickname: str | None = None,
picture: str | None = None,
) -> dict[str, Any]:
"""Signup using email and password.

@@ -54,3 +56,3 @@

"""
body = {
body: dict[str, Any] = {
"client_id": self.client_id,

@@ -76,7 +78,10 @@ "email": email,

return self.post(
data: dict[str, Any] = self.post(
f"{self.protocol}://{self.domain}/dbconnections/signup", data=body
)
return data
def change_password(self, email, connection, password=None):
def change_password(
self, email: str, connection: str, password: str | None = None
) -> str:
"""Asks to change a password for a given user.

@@ -94,5 +99,6 @@

return self.post(
data: str = self.post(
f"{self.protocol}://{self.domain}/dbconnections/change_password",
data=body,
)
return data

@@ -0,1 +1,5 @@

from __future__ import annotations
from typing import Any
from .base import AuthenticationBase

@@ -13,9 +17,9 @@

self,
target,
api_type,
grant_type,
id_token=None,
refresh_token=None,
scope="openid",
):
target: str,
api_type: str,
grant_type: str,
id_token: str | None = None,
refresh_token: str | None = None,
scope: str = "openid",
) -> Any:
"""Obtain a delegation token."""

@@ -22,0 +26,0 @@

@@ -0,1 +1,3 @@

from typing import Any
from .base import AuthenticationBase

@@ -12,3 +14,3 @@

def saml_metadata(self):
def saml_metadata(self) -> Any:
"""Get SAML2.0 Metadata."""

@@ -22,3 +24,3 @@

def wsfed_metadata(self):
def wsfed_metadata(self) -> Any:
"""Returns the WS-Federation Metadata."""

@@ -25,0 +27,0 @@

@@ -0,3 +1,6 @@

from __future__ import annotations
from typing import Any
from .base import AuthenticationBase
from .client_authentication import add_client_authentication

@@ -15,6 +18,6 @@

self,
code,
redirect_uri,
grant_type="authorization_code",
):
code: str,
redirect_uri: str | None,
grant_type: str = "authorization_code",
) -> Any:
"""Authorization code grant

@@ -51,7 +54,7 @@

self,
code_verifier,
code,
redirect_uri,
grant_type="authorization_code",
):
code_verifier: str,
code: str,
redirect_uri: str | None,
grant_type: str = "authorization_code",
) -> Any:
"""Authorization code pkce grant

@@ -91,5 +94,5 @@

self,
audience,
grant_type="client_credentials",
):
audience: str,
grant_type: str = "client_credentials",
) -> Any:
"""Client credentials grant

@@ -122,9 +125,10 @@

self,
username,
password,
scope=None,
realm=None,
audience=None,
grant_type="http://auth0.com/oauth/grant-type/password-realm",
):
username: str,
password: str,
scope: str | None = None,
realm: str | None = None,
audience: str | None = None,
grant_type: str = "http://auth0.com/oauth/grant-type/password-realm",
forwarded_for: str | None = None,
) -> Any:
"""Calls /oauth/token endpoint with password-realm grant type

@@ -156,5 +160,12 @@

forwarded_for (str, optional): End-user IP as a string value. Set this if you want
brute-force protection to work in server-side scenarios.
See https://auth0.com/docs/get-started/authentication-and-authorization-flow/avoid-common-issues-with-resource-owner-password-flow-and-attack-protection
Returns:
access_token, id_token
"""
headers = None
if forwarded_for:
headers = {"auth0-forwarded-for": forwarded_for}

@@ -172,2 +183,3 @@ return self.authenticated_post(

},
headers=headers,
)

@@ -177,6 +189,6 @@

self,
refresh_token,
scope="",
grant_type="refresh_token",
):
refresh_token: str,
scope: str = "",
grant_type: str = "refresh_token",
) -> Any:
"""Calls /oauth/token endpoint with refresh token grant type

@@ -209,3 +221,5 @@

def passwordless_login(self, username, otp, realm, scope, audience):
def passwordless_login(
self, username: str, otp: str, realm: str, scope: str, audience: str
) -> Any:
"""Calls /oauth/token endpoint with http://auth0.com/oauth/grant-type/passwordless/otp grant type

@@ -212,0 +226,0 @@

@@ -1,3 +0,5 @@

import warnings
from __future__ import annotations
from typing import Any
from .base import AuthenticationBase

@@ -14,3 +16,5 @@

def email(self, email, send="link", auth_params=None):
def email(
self, email: str, send: str = "link", auth_params: dict[str, str] | None = None
) -> Any:
"""Start flow sending an email.

@@ -39,3 +43,3 @@

data = {
data: dict[str, Any] = {
"client_id": self.client_id,

@@ -53,3 +57,3 @@ "connection": "email",

def sms(self, phone_number):
def sms(self, phone_number: str) -> Any:
"""Start flow sending an SMS message.

@@ -56,0 +60,0 @@

@@ -0,1 +1,3 @@

from typing import Any
from .base import AuthenticationBase

@@ -11,3 +13,3 @@

def revoke_refresh_token(self, token):
def revoke_refresh_token(self, token: str) -> Any:
"""Revokes a Refresh Token if it has been compromised

@@ -14,0 +16,0 @@

@@ -0,1 +1,3 @@

from typing import Any
from .base import AuthenticationBase

@@ -12,3 +14,3 @@

def login(self, access_token, connection, scope="openid"):
def login(self, access_token: str, connection: str, scope: str = "openid") -> Any:
"""Login using a social provider's access token

@@ -15,0 +17,0 @@

"""Token Verifier module"""
from __future__ import annotations
import json
import time
from typing import TYPE_CHECKING, Any, ClassVar

@@ -10,3 +13,6 @@ import jwt

if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
class SignatureVerifier:

@@ -20,3 +26,3 @@ """Abstract class that will verify a given JSON web token's signature

DISABLE_JWT_CHECKS = {
DISABLE_JWT_CHECKS: ClassVar[dict[str, bool]] = {
"verify_signature": True,

@@ -33,3 +39,3 @@ "verify_exp": False,

def __init__(self, algorithm):
def __init__(self, algorithm: str) -> None:
if not algorithm or type(algorithm) != str:

@@ -39,3 +45,3 @@ raise ValueError("algorithm must be specified.")

def _fetch_key(self, key_id=None):
def _fetch_key(self, key_id: str) -> str | RSAPublicKey:
"""Obtains the key associated to the given key id.

@@ -45,3 +51,3 @@ Must be implemented by subclasses.

Args:
key_id (str, optional): The id of the key to fetch.
key_id (str): The id of the key to fetch.

@@ -53,3 +59,3 @@ Returns:

def _get_kid(self, token):
def _get_kid(self, token: str) -> str | None:
"""Gets the key id from the kid claim of the header of the token

@@ -81,3 +87,3 @@

def _decode_jwt(self, token, secret_or_certificate):
def _decode_jwt(self, token: str, secret_or_certificate: str) -> dict[str, Any]:
"""Verifies and decodes the given JSON web token with the given public key or shared secret.

@@ -104,3 +110,3 @@

def verify_signature(self, token):
def verify_signature(self, token: str) -> dict[str, Any]:
"""Verifies the signature of the given JSON web token.

@@ -116,5 +122,7 @@

kid = self._get_kid(token)
if kid is None:
kid = ""
secret_or_certificate = self._fetch_key(key_id=kid)
return self._decode_jwt(token, secret_or_certificate)
return self._decode_jwt(token, secret_or_certificate) # type: ignore[arg-type]

@@ -130,7 +138,7 @@

def __init__(self, shared_secret, algorithm="HS256"):
def __init__(self, shared_secret: str, algorithm: str = "HS256") -> None:
super().__init__(algorithm)
self._shared_secret = shared_secret
def _fetch_key(self, key_id=None):
def _fetch_key(self, key_id: str = "") -> str:
return self._shared_secret

@@ -148,16 +156,15 @@

CACHE_TTL = 600 # 10 min cache lifetime
CACHE_TTL: ClassVar[int] = 600 # 10 min cache lifetime
def __init__(self, jwks_url, cache_ttl=CACHE_TTL):
def __init__(self, jwks_url: str, cache_ttl: int = CACHE_TTL) -> None:
self._jwks_url = jwks_url
self._init_cache(cache_ttl)
return
def _init_cache(self, cache_ttl):
self._cache_value = {}
self._cache_date = 0
def _init_cache(self, cache_ttl: int) -> None:
self._cache_value: dict[str, RSAPublicKey] = {}
self._cache_date = 0.0
self._cache_ttl = cache_ttl
self._cache_is_fresh = False
def _cache_expired(self):
def _cache_expired(self) -> bool:
"""Checks if the cache is expired

@@ -170,3 +177,3 @@

def _cache_jwks(self, jwks):
def _cache_jwks(self, jwks: dict[str, Any]) -> None:
"""Cache the response of the JWKS request

@@ -181,3 +188,3 @@

def _fetch_jwks(self, force=False):
def _fetch_jwks(self, force: bool = False) -> dict[str, RSAPublicKey]:
"""Attempts to obtain the JWK set from the cache, as long as it's still valid.

@@ -194,3 +201,3 @@ When not, it will perform a network request to the jwks_url to obtain a fresh result

if response.ok:
jwks = response.json()
jwks: dict[str, Any] = response.json()
self._cache_jwks(jwks)

@@ -203,7 +210,7 @@ return self._cache_value

@staticmethod
def _parse_jwks(jwks):
def _parse_jwks(jwks: dict[str, Any]) -> dict[str, RSAPublicKey]:
"""
Converts a JWK string representation into a binary certificate in PEM format.
"""
keys = {}
keys: dict[str, RSAPublicKey] = {}

@@ -213,7 +220,9 @@ for key in jwks["keys"]:

# requirement already includes cryptography -> pyjwt[crypto]
rsa_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
rsa_key: RSAPublicKey = jwt.algorithms.RSAAlgorithm.from_jwk(
json.dumps(key)
)
keys[key["kid"]] = rsa_key
return keys
def get_key(self, key_id):
def get_key(self, key_id: str) -> RSAPublicKey:
"""Obtains the JWK associated with the given key id.

@@ -251,7 +260,12 @@

def __init__(self, jwks_url, algorithm="RS256", cache_ttl=JwksFetcher.CACHE_TTL):
def __init__(
self,
jwks_url: str,
algorithm: str = "RS256",
cache_ttl: int = JwksFetcher.CACHE_TTL,
) -> None:
super().__init__(algorithm)
self._fetcher = JwksFetcher(jwks_url, cache_ttl)
def _fetch_key(self, key_id=None):
def _fetch_key(self, key_id: str) -> RSAPublicKey:
return self._fetcher.get_key(key_id)

@@ -272,3 +286,9 @@

def __init__(self, signature_verifier, issuer, audience, leeway=0):
def __init__(
self,
signature_verifier: SignatureVerifier,
issuer: str,
audience: str,
leeway: int = 0,
) -> None:
if not signature_verifier or not isinstance(

@@ -287,3 +307,9 @@ signature_verifier, SignatureVerifier

def verify(self, token, nonce=None, max_age=None, organization=None):
def verify(
self,
token: str,
nonce: str | None = None,
max_age: int | None = None,
organization: str | None = None,
) -> dict[str, Any]:
"""Attempts to verify the given ID token, following the steps defined in the OpenID Connect spec.

@@ -318,3 +344,9 @@

def _verify_payload(self, payload, nonce=None, max_age=None, organization=None):
def _verify_payload(
self,
payload: dict[str, Any],
nonce: str | None = None,
max_age: int | None = None,
organization: str | None = None,
) -> None:
# Issuer

@@ -321,0 +353,0 @@ if "iss" not in payload or not isinstance(payload["iss"], str):

@@ -0,2 +1,7 @@

from __future__ import annotations
from typing import Any
from auth0.rest import RestClient, RestClientOptions
from auth0.types import TimeoutType

@@ -16,7 +21,7 @@

self,
domain,
telemetry=True,
timeout=5.0,
protocol="https",
):
domain: str,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
protocol: str = "https",
) -> None:
self.domain = domain

@@ -35,3 +40,3 @@ self.protocol = protocol

def userinfo(self, access_token):
def userinfo(self, access_token: str) -> dict[str, Any]:
"""Returns the user information based on the Auth0 access token.

@@ -47,5 +52,6 @@ This endpoint will work only if openid was granted as a scope for the access_token.

return self.client.get(
data: dict[str, Any] = self.client.get(
url=f"{self.protocol}://{self.domain}/userinfo",
headers={"Authorization": f"Bearer {access_token}"},
)
return data

@@ -0,3 +1,14 @@

from __future__ import annotations
from typing import Any
class Auth0Error(Exception):
def __init__(self, status_code, error_code, message, content=None):
def __init__(
self,
status_code: int,
error_code: str,
message: str,
content: Any | None = None,
) -> None:
self.status_code = status_code

@@ -8,3 +19,3 @@ self.error_code = error_code

def __str__(self):
def __str__(self) -> str:
return f"{self.status_code}: {self.message}"

@@ -14,3 +25,3 @@

class RateLimitError(Auth0Error):
def __init__(self, error_code, message, reset_at):
def __init__(self, error_code: str, message: str, reset_at: int) -> None:
super().__init__(status_code=429, error_code=error_code, message=message)

@@ -17,0 +28,0 @@ self.reset_at = reset_at

@@ -94,3 +94,3 @@ from ..rest import RestClient

self._url("templates", "universal-login"),
body={"template": body},
data={"template": body},
)

@@ -97,0 +97,0 @@

@@ -55,2 +55,3 @@ from ..rest import RestClient

extra_params=None,
name=None,
):

@@ -80,2 +81,4 @@ """Retrieves all connections.

name (str): Provide the name of the connection to retrieve.
See: https://auth0.com/docs/api/management/v2#!/Connections/get_connections

@@ -93,2 +96,3 @@

params["per_page"] = per_page
params["name"] = name

@@ -95,0 +99,0 @@ return self.client.get(self._url(), params=params)

@@ -0,2 +1,5 @@

from __future__ import annotations
import asyncio
from typing import Any

@@ -6,7 +9,8 @@ import aiohttp

from auth0.exceptions import RateLimitError
from auth0.types import RequestData
from .rest import EmptyResponse, JsonResponse, PlainResponse, RestClient
from .rest import EmptyResponse, JsonResponse, PlainResponse, Response, RestClient
def _clean_params(params):
def _clean_params(params: dict[Any, Any] | None) -> dict[Any, Any] | None:
if params is None:

@@ -34,5 +38,5 @@ return params

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._session = None
self._session: aiohttp.ClientSession | None = None
sock_connect, sock_read = (

@@ -47,3 +51,3 @@ self.timeout

def set_session(self, session):
def set_session(self, session: aiohttp.ClientSession) -> None:
"""Set Client Session to improve performance by reusing session.

@@ -54,3 +58,3 @@ Session should be closed manually or within context manager.

async def _request(self, *args, **kwargs):
async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["headers"] = kwargs.get("headers", self.base_headers)

@@ -68,3 +72,8 @@ kwargs["timeout"] = self.timeout

async def get(self, url, params=None, headers=None):
async def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()

@@ -100,3 +109,8 @@ request_headers.update(headers or {})

async def post(self, url, data=None, headers=None):
async def post(
self,
url: str,
data: RequestData | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()

@@ -106,3 +120,8 @@ request_headers.update(headers or {})

async def file_post(self, url, data=None, files=None):
async def file_post(
self,
url: str,
data: dict[str, Any],
files: dict[str, Any],
) -> Any:
headers = self.base_headers.copy()

@@ -112,9 +131,14 @@ headers.pop("Content-Type", None)

async def patch(self, url, data=None):
async def patch(self, url: str, data: RequestData | None = None) -> Any:
return await self._request("patch", url, json=data)
async def put(self, url, data=None):
async def put(self, url: str, data: RequestData | None = None) -> Any:
return await self._request("put", url, json=data)
async def delete(self, url, params=None, data=None):
async def delete(
self,
url: str,
params: dict[str, Any] | None = None,
data: RequestData | None = None,
) -> Any:
return await self._request(

@@ -124,7 +148,7 @@ "delete", url, json=data, params=_clean_params(params) or {}

async def _process_response(self, response):
async def _process_response(self, response: aiohttp.ClientResponse) -> Any:
parsed_response = await self._parse(response)
return parsed_response.content()
async def _parse(self, response):
async def _parse(self, response: aiohttp.ClientResponse) -> Response:
text = await response.text()

@@ -141,5 +165,5 @@ requests_response = RequestsResponse(response, text)

class RequestsResponse:
def __init__(self, response, text):
def __init__(self, response: aiohttp.ClientResponse, text: str) -> None:
self.status_code = response.status
self.headers = response.headers
self.text = text

@@ -0,1 +1,3 @@

from __future__ import annotations
import base64

@@ -7,2 +9,3 @@ import json

from time import sleep
from typing import TYPE_CHECKING, Any, Mapping

@@ -12,3 +15,7 @@ import requests

from auth0.exceptions import Auth0Error, RateLimitError
from auth0.types import RequestData, TimeoutType
if TYPE_CHECKING:
from auth0.rest_async import RequestsResponse
UNKNOWN_ERROR = "a0.sdk.internal.unknown"

@@ -37,17 +44,13 @@

def __init__(self, telemetry=None, timeout=None, retries=None):
self.telemetry = True
self.timeout = 5.0
self.retries = 3
def __init__(
self,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
retries: int = 3,
) -> None:
self.telemetry = telemetry
self.timeout = timeout
self.retries = retries
if telemetry is not None:
self.telemetry = telemetry
if timeout is not None:
self.timeout = timeout
if retries is not None:
self.retries = retries
class RestClient:

@@ -57,2 +60,3 @@ """Provides simple methods for handling all RESTful api endpoints.

Args:
jwt (str, optional): The JWT to be used with the RestClient.
telemetry (bool, optional): Enable or disable Telemetry

@@ -71,3 +75,9 @@ (defaults to True)

def __init__(self, jwt, telemetry=True, timeout=5.0, options=None):
def __init__(
self,
jwt: str | None,
telemetry: bool = True,
timeout: TimeoutType = 5.0,
options: RestClientOptions | None = None,
) -> None:
if options is None:

@@ -119,18 +129,23 @@ options = RestClientOptions(telemetry=telemetry, timeout=timeout)

# Returns a hard cap for the maximum number of retries allowed (10)
def MAX_REQUEST_RETRIES(self):
def MAX_REQUEST_RETRIES(self) -> int:
return 10
# Returns the maximum amount of jitter to introduce in milliseconds (100ms)
def MAX_REQUEST_RETRY_JITTER(self):
def MAX_REQUEST_RETRY_JITTER(self) -> int:
return 100
# Returns the maximum delay window allowed (1000ms)
def MAX_REQUEST_RETRY_DELAY(self):
def MAX_REQUEST_RETRY_DELAY(self) -> int:
return 1000
# Returns the minimum delay window allowed (100ms)
def MIN_REQUEST_RETRY_DELAY(self):
def MIN_REQUEST_RETRY_DELAY(self) -> int:
return 100
def get(self, url, params=None, headers=None):
def get(
self,
url: str,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()

@@ -171,3 +186,8 @@ request_headers.update(headers or {})

def post(self, url, data=None, headers=None):
def post(
self,
url: str,
data: RequestData | None = None,
headers: dict[str, str] | None = None,
) -> Any:
request_headers = self.base_headers.copy()

@@ -181,3 +201,8 @@ request_headers.update(headers or {})

def file_post(self, url, data=None, files=None):
def file_post(
self,
url: str,
data: RequestData | None = None,
files: dict[str, Any] | None = None,
) -> Any:
headers = self.base_headers.copy()

@@ -191,3 +216,3 @@ headers.pop("Content-Type", None)

def patch(self, url, data=None):
def patch(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

@@ -200,3 +225,3 @@

def put(self, url, data=None):
def put(self, url: str, data: RequestData | None = None) -> Any:
headers = self.base_headers.copy()

@@ -209,3 +234,8 @@

def delete(self, url, params=None, data=None):
def delete(
self,
url: str,
params: dict[str, Any] | None = None,
data: RequestData | None = None,
) -> Any:
headers = self.base_headers.copy()

@@ -222,3 +252,3 @@

def _calculate_wait(self, attempt):
def _calculate_wait(self, attempt: int) -> int:
# Retry the request. Apply a exponential backoff for subsequent attempts, using this formula:

@@ -240,10 +270,10 @@ # max(MIN_REQUEST_RETRY_DELAY, min(MAX_REQUEST_RETRY_DELAY, (100ms * (2 ** attempt - 1)) + random_between(1, MAX_REQUEST_RETRY_JITTER)))

self._metrics["retries"] = attempt
self._metrics["backoff"].append(wait)
self._metrics["backoff"].append(wait) # type: ignore[attr-defined]
return wait
def _process_response(self, response):
def _process_response(self, response: requests.Response) -> Any:
return self._parse(response).content()
def _parse(self, response):
def _parse(self, response: requests.Response) -> Response:
if not response.text:

@@ -258,3 +288,5 @@ return EmptyResponse(response.status_code)

class Response:
def __init__(self, status_code, content, headers):
def __init__(
self, status_code: int, content: Any, headers: Mapping[str, str]
) -> None:
self._status_code = status_code

@@ -264,3 +296,3 @@ self._content = content

def content(self):
def content(self) -> Any:
if self._is_error():

@@ -290,3 +322,3 @@ if self._status_code == 429:

def _is_error(self):
def _is_error(self) -> bool:
return self._status_code is None or self._status_code >= 400

@@ -303,7 +335,7 @@

class JsonResponse(Response):
def __init__(self, response):
def __init__(self, response: requests.Response | RequestsResponse) -> None:
content = json.loads(response.text)
super().__init__(response.status_code, content, response.headers)
def _error_code(self):
def _error_code(self) -> str:
if "errorCode" in self._content:

@@ -318,3 +350,3 @@ return self._content.get("errorCode")

def _error_message(self):
def _error_message(self) -> str:
if "error_description" in self._content:

@@ -329,9 +361,9 @@ return self._content.get("error_description")

class PlainResponse(Response):
def __init__(self, response):
def __init__(self, response: requests.Response | RequestsResponse) -> None:
super().__init__(response.status_code, response.text, response.headers)
def _error_code(self):
def _error_code(self) -> str:
return UNKNOWN_ERROR
def _error_message(self):
def _error_message(self) -> str:
return self._content

@@ -341,9 +373,9 @@

class EmptyResponse(Response):
def __init__(self, status_code):
def __init__(self, status_code: int) -> None:
super().__init__(status_code, "", {})
def _error_code(self):
def _error_code(self) -> str:
return UNKNOWN_ERROR
def _error_message(self):
def _error_message(self) -> str:
return ""

@@ -15,5 +15,7 @@ import base64

from auth0.asyncify import asyncify
from auth0.authentication import GetToken
from auth0.management import Clients, Guardian, Jobs
clients = re.compile(r"^https://example\.com/api/v2/clients.*")
token = re.compile(r"^https://example\.com/oauth/token.*")
factors = re.compile(r"^https://example\.com/api/v2/guardian/factors.*")

@@ -89,2 +91,27 @@ users_imports = re.compile(r"^https://example\.com/api/v2/jobs/users-imports.*")

@aioresponses()
async def test_post_auth(self, mocked):
callback, mock = get_callback()
mocked.post(token, callback=callback)
c = asyncify(GetToken)("example.com", "cid", client_secret="clsec")
self.assertEqual(
await c.login_async(username="usrnm", password="pswd"), payload
)
mock.assert_called_with(
Attrs(path="/oauth/token"),
allow_redirects=True,
json={
"client_id": "cid",
"username": "usrnm",
"password": "pswd",
"realm": None,
"scope": None,
"audience": None,
"grant_type": "http://auth0.com/oauth/grant-type/password-realm",
"client_secret": "clsec",
},
headers={i: headers[i] for i in headers if i != "Authorization"},
timeout=ANY,
)
@aioresponses()
async def test_file_post(self, mocked):

@@ -91,0 +118,0 @@ callback, mock = get_callback()

@@ -193,2 +193,18 @@ import unittest

@mock.patch("auth0.rest.RestClient.post")
def test_login_with_forwarded_for(self, mock_post):
g = GetToken("my.domain.com", "cid", client_secret="clsec")
g.login(username="usrnm", password="pswd", forwarded_for="192.168.0.1")
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "https://my.domain.com/oauth/token")
self.assertEqual(
kwargs["headers"],
{
"auth0-forwarded-for": "192.168.0.1",
},
)
@mock.patch("auth0.rest.RestClient.post")
def test_refresh_token(self, mock_post):

@@ -195,0 +211,0 @@ g = GetToken("my.domain.com", "cid", client_secret="clsec")

@@ -62,6 +62,6 @@ import unittest

@mock.patch("auth0.management.branding.RestClient")
@mock.patch("auth0.rest.requests.put")
def test_update_template_universal_login(self, mock_rc):
api = mock_rc.return_value
api.put.return_value = {}
mock_rc.return_value.status_code = 200
mock_rc.return_value.text = "{}"

@@ -71,5 +71,7 @@ branding = Branding(domain="domain", token="jwttoken")

api.put.assert_called_with(
mock_rc.assert_called_with(
"https://domain/api/v2/branding/templates/universal-login",
body={"template": {"a": "b", "c": "d"}},
json={"template": {"a": "b", "c": "d"}},
headers=mock.ANY,
timeout=5.0,
)

@@ -76,0 +78,0 @@

@@ -36,2 +36,3 @@ import unittest

"include_fields": "true",
"name": None,
},

@@ -54,2 +55,3 @@ )

"include_fields": "false",
"name": None,
},

@@ -72,2 +74,3 @@ )

"include_fields": "true",
"name": None,
},

@@ -90,2 +93,3 @@ )

"include_fields": "true",
"name": None,
},

@@ -109,5 +113,24 @@ )

"some_key": "some_value",
"name": None,
},
)
# Name
c.all(name="foo")
args, kwargs = mock_instance.get.call_args
self.assertEqual("https://domain/api/v2/connections", args[0])
self.assertEqual(
kwargs["params"],
{
"fields": None,
"strategy": None,
"page": None,
"per_page": None,
"include_fields": "true",
"name": "foo",
},
)
@mock.patch("auth0.management.connections.RestClient")

@@ -114,0 +137,0 @@ def test_get(self, mock_rc):

@@ -1,2 +0,2 @@

def is_async_available():
def is_async_available() -> bool:
try:

@@ -3,0 +3,0 @@ import asyncio

Metadata-Version: 2.1
Name: auth0-python
Version: 4.2.0
Version: 4.3.0
Summary: Auth0 Python SDK

@@ -5,0 +5,0 @@ Home-page: https://github.com/auth0/auth0-python