fused
Advanced tools
| from typing import Literal | ||
| from fused.models.request import WHITELISTED_INSTANCE_TYPES | ||
| ENGINE_LOCAL = "local" | ||
| ENGINE_REMOTE = "remote" | ||
| Engine = Literal[ENGINE_LOCAL, ENGINE_REMOTE, WHITELISTED_INSTANCE_TYPES] |
+3
-0
@@ -170,1 +170,4 @@ cdk.out | ||
| client/public/fasttortoise/fasttortoise-api.js | ||
| # Database files | ||
| server/dump |
@@ -174,3 +174,3 @@ from __future__ import annotations | ||
| <li>Run Udf (realtime): <pre>fused.run('{slug}')</pre></li> | ||
| <li>Run Udf (long-running): <pre>udf().run_remote()</pre></li> | ||
| <li>Run Udf (long-running): <pre>udf(engine='large')</pre></li> | ||
| <li>Load Udf: <pre>fused.load('{slug}')</pre></li> | ||
@@ -192,2 +192,3 @@ <li>Save Udf: <pre>udf.to_fused()</pre></li> | ||
| <li>Type: {metadata.get("fused:udfType")}</li> | ||
| <li>Collection: {udf.collection_name}</li> | ||
| <br> | ||
@@ -194,0 +195,0 @@ </ul> |
+16
-2
@@ -27,2 +27,3 @@ from pathlib import Path | ||
| import_globals: bool = True, | ||
| collection_name: str | None = None, | ||
| ) -> AnyBaseUdf: | ||
@@ -48,2 +49,4 @@ """ | ||
| This requires executing the code of the UDF. To globally configure this behavior, use `fused.options.never_import`. | ||
| collection_name: Collection name to load the UDF from. If not provided, will attempt to | ||
| use the collection of the currently executing UDF, or default to "default". | ||
@@ -113,2 +116,3 @@ Returns: | ||
| import_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -118,3 +122,6 @@ elif len(parts) == 1: | ||
| return load_udf_from_fused( | ||
| udf_name, cache_key=cache_key, import_globals=import_globals | ||
| udf_name, | ||
| cache_key=cache_key, | ||
| import_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -137,2 +144,3 @@ | ||
| import_globals: bool = True, | ||
| collection_name: str | None = None, | ||
| ) -> AnyBaseUdf: | ||
@@ -150,2 +158,4 @@ """ | ||
| import_globals: Expose the globals defined in the UDF's context as attributes on the UDF object. | ||
| collection_name: Collection name to load the UDF from. If not provided, will attempt to | ||
| use the collection of the currently executing UDF, or default to "default". | ||
@@ -214,2 +224,3 @@ Returns: | ||
| import_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -219,3 +230,6 @@ elif len(parts) == 1: | ||
| return await load_udf_from_fused_async( | ||
| udf_name, cache_key=cache_key, import_globals=import_globals | ||
| udf_name, | ||
| cache_key=cache_key, | ||
| import_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -222,0 +236,0 @@ |
+30
-5
| import os | ||
| import tempfile | ||
| import warnings | ||
| from contextlib import contextmanager | ||
| from pathlib import Path | ||
@@ -146,3 +147,4 @@ from typing import Dict, List, Literal, Optional, Tuple, Union | ||
| def get_writable_dir(storage: StorageStr) -> Path: | ||
| mount_path = Path("/mount") | ||
| # TODO: pass the system path from the server? | ||
| mount_path = Path("/mount/fused-system") | ||
| temp_dir = Path(tempfile.gettempdir()) | ||
@@ -222,6 +224,2 @@ | ||
| default_validate_imports: StrictBool = False | ||
| """Default for whether to validate imports in UDFs before `run_local`, | ||
| `run_batch`.""" | ||
| prompt_to_login: StrictBool = False | ||
@@ -303,2 +301,6 @@ """Automatically prompt the user to login when importing Fused.""" | ||
| cache_decorator_max_age: int = 60 * 60 * 48 # 48 hours | ||
| """Default max age for cache decorated functions. Accepts duration strings like | ||
| "12h", "30m", "7d", or integer seconds.""" | ||
| @model_validator(mode="after") | ||
@@ -512,1 +514,24 @@ def _set_default_directories(self): | ||
| options.request_timeout = 25 | ||
| @contextmanager | ||
| def with_options(new_options: Options): | ||
| import fused | ||
| global options | ||
| current_options: Options = options | ||
| options = new_options | ||
| # ALSO overwrite the fused.options import, since that is how it will be used mostly. | ||
| fused.options = new_options | ||
| try: | ||
| yield options | ||
| finally: | ||
| options = current_options | ||
| fused.options = current_options | ||
| @contextmanager | ||
| def reset_options(): | ||
| current_options: Options = options.model_copy(deep=True) | ||
| with with_options(current_options) as new_options: | ||
| yield new_options |
+9
-3
@@ -281,4 +281,4 @@ import asyncio | ||
| resolved_udf: ResolvedUdf, | ||
| disk_size_gb, | ||
| ) -> str: | ||
| disk_size_gb: int | None, | ||
| ) -> tuple[str, int | None]: | ||
| # if not specified in run(), use the default defined on the UDF | ||
@@ -723,3 +723,9 @@ if instance_type is None and resolved_udf.storage_type == "local_job_step": | ||
| if verbose: | ||
| sys.stdout.write("Cached UDF result returned.\n") | ||
| # TODO: Retrieve when it was cached | ||
| if udf_eval_result.udf and udf_eval_result.udf.name: | ||
| sys.stdout.write( | ||
| f"UDF '{udf_eval_result.udf.name}' returned cached result\n" | ||
| ) | ||
| else: | ||
| sys.stdout.write("UDF returned cached result\n") | ||
| logger.info(f"Cache source: {udf_eval_result.cache_source.value}") | ||
@@ -726,0 +732,0 @@ if udf_eval_result.error_message is not None: |
@@ -69,3 +69,3 @@ /* CSS stylesheet for displaying xarray objects in jupyterlab. | ||
| display: grid; | ||
| grid-template-columns: 150px auto auto 1fr 20px 20px; | ||
| grid-template-columns: 250px auto auto 1fr 20px 20px; | ||
| } | ||
@@ -72,0 +72,0 @@ |
+25
-0
@@ -0,1 +1,2 @@ | ||
| import datetime | ||
| import inspect | ||
@@ -87,2 +88,24 @@ import json | ||
| def _is_float_like_str(val: str) -> bool: | ||
| return all([c.isdigit() or c == "." for c in val]) | ||
| def _parse_to_datetime( | ||
| val: Any, annotation: type[datetime.datetime | datetime.date | datetime.time] | ||
| ) -> Any: | ||
| if (isinstance(val, str) and _is_float_like_str(val)) or isinstance( | ||
| val, (int, float) | ||
| ): | ||
| if annotation is datetime.datetime: | ||
| return datetime.datetime.fromtimestamp(float(val), tz=datetime.timezone.utc) | ||
| elif annotation is datetime.date: | ||
| return datetime.date.fromtimestamp(float(val)) | ||
| elif annotation is datetime.time: | ||
| # time does not support fromtimestamp | ||
| return datetime.datetime.fromtimestamp( | ||
| float(val), tz=datetime.timezone.utc | ||
| ).timetz() | ||
| return annotation.fromisoformat(val) | ||
| def convert_to_tile_gdf(val: Any, add_xyz: bool = True) -> "TileGDF": | ||
@@ -248,2 +271,4 @@ if not HAS_GEOPANDAS: | ||
| return uuid.UUID(val) | ||
| if annotation in [datetime.datetime, datetime.date, datetime.time]: | ||
| return _parse_to_datetime(val, annotation) | ||
| if annotation in [Tile, TileGDF, ViewportGDF]: | ||
@@ -250,0 +275,0 @@ return convert_to_tile_gdf(val, add_xyz=annotation in [Tile, TileGDF]) |
@@ -35,6 +35,8 @@ import ast | ||
| "arg_list", | ||
| "output_table", | ||
| "engine", | ||
| "cache_max_age", | ||
| "cache", | ||
| } | ||
| """Set of UDF parameter names that should not be used because they will cause conflicts | ||
| when instantiating the UDF to a job.""" | ||
| when calling the UDF.""" | ||
@@ -55,7 +57,3 @@ UDF_RUN_KWARGS = {"x", "y", "z"} | ||
| for param_idx, param in enumerate(signature.parameters.values()): | ||
| # Don't check certain | ||
| is_reserved_positional_param = (param_idx == 0 and param.name == "dataset") or ( | ||
| param_idx == 1 and param.name == "right" | ||
| ) | ||
| if param.name in RESERVED_UDF_PARAMETERS and not is_reserved_positional_param: | ||
| if param.name in RESERVED_UDF_PARAMETERS: | ||
| warnings.warn( | ||
@@ -62,0 +60,0 @@ FusedUdfWarning( |
@@ -1,2 +0,1 @@ | ||
| import ast | ||
| import datetime | ||
@@ -74,3 +73,2 @@ import io | ||
| input: list[Any], | ||
| validate_imports: bool | None = None, | ||
| cache_max_age: int | None = None, | ||
@@ -82,5 +80,2 @@ **kwargs, | ||
| # Validate import statements correspond to valid modules | ||
| validate_imports_whitelist(udf, validate_imports=validate_imports) | ||
| if not OPTIONS.local_engine_cache: | ||
@@ -106,3 +101,8 @@ cache_max_age = 0 | ||
| # Set up thread-local stream routing (reuses existing routers if already installed) | ||
| # Set up execution context with the UDF for local execution | ||
| from fused.core._context import set_nested_udf_context | ||
| # Set up thread-local stream routing and nested UDF | ||
| # Note: We use set_nested_udf instead of global_context to preserve | ||
| # the parent execution context (batch/realtime) while setting this_udf | ||
| with ( | ||
@@ -114,2 +114,3 @@ isolate_streams_per_thread(), | ||
| ), | ||
| set_nested_udf_context(udf), | ||
| ): | ||
@@ -203,52 +204,1 @@ original_exception = None | ||
| return _fn | ||
| def validate_imports_whitelist(udf: AnyBaseUdf, validate_imports: bool | None = None): | ||
| # Skip import validation if the option is set | ||
| if not fused.options.default_validate_imports and validate_imports is not True: | ||
| return | ||
| # Skip import validation if not logged in | ||
| if not fused.api.AUTHORIZATION.is_configured(): | ||
| return | ||
| from fused._global_api import get_api | ||
| # Get the dependency whitelist from the cached API endpoint | ||
| api = get_api() | ||
| package_dependencies = api.dependency_whitelist() | ||
| # Initialize a list to store the import statements | ||
| import_statements = [] | ||
| # Parse the source code into an AST | ||
| tree = ast.parse(udf.code) | ||
| # Traverse the AST to find import statements | ||
| for node in ast.walk(tree): | ||
| if isinstance(node, ast.Import): | ||
| for alias in node.names: | ||
| import_statements.append(alias.name) | ||
| elif isinstance(node, ast.ImportFrom): | ||
| module_name = node.module | ||
| import_statements.append(module_name) | ||
| # Check for unavailable modules | ||
| header_modules = [header.module_name for header in udf.headers] | ||
| fused_modules = ["fused"] # assume fused is always available | ||
| available_modules = ( | ||
| list(package_dependencies["dependency_whitelist"].keys()) | ||
| + header_modules | ||
| + fused_modules | ||
| ) | ||
| unavailable_modules = [] | ||
| for import_statement in import_statements: | ||
| if import_statement.split(".", 1)[0] not in available_modules: | ||
| unavailable_modules.append(import_statement) | ||
| if unavailable_modules: | ||
| raise ValueError( | ||
| f"The following imports in the UDF might not be available: {repr(unavailable_modules)}. Please check the UDF headers and imports and try again." | ||
| ) | ||
| # TODO: check major versions for some packages |
@@ -1,1 +0,1 @@ | ||
| __version__ = "1.30.1" | ||
| __version__ = "2.0.0" |
@@ -373,2 +373,3 @@ from __future__ import annotations | ||
| whose: Literal["self", "public", "community", "team"] = "self", | ||
| collection_name: str | None = None, | ||
| ) -> dict: | ||
@@ -384,2 +385,3 @@ """ | ||
| UDFs available publicly or "community" for all community UDFs. Defaults to "self". | ||
| collection_name: Filter UDFs by collection name. If not provided, defaults to "default". | ||
@@ -396,3 +398,5 @@ Returns: | ||
| api = get_api() | ||
| return api.get_udfs(by=by, whose=whose, n=n, skip=skip) | ||
| return api.get_udfs( | ||
| by=by, whose=whose, n=n, skip=skip, collection_name=collection_name | ||
| ) | ||
@@ -399,0 +403,0 @@ |
| from __future__ import annotations | ||
| import json | ||
| import shlex | ||
| import subprocess | ||
| from base64 import b64encode | ||
| from functools import lru_cache | ||
| from io import SEEK_SET, BytesIO | ||
@@ -338,9 +336,3 @@ from pathlib import Path | ||
| @lru_cache | ||
| def dependency_whitelist(self) -> dict[str, str]: | ||
| runnable = self._make_run_command("dependency-whitelist", []) | ||
| content = runnable.run_and_get_bytes() | ||
| return json.loads(content) | ||
| def ssh_command_wrapper(conn_string: str) -> Callable[[str], str]: | ||
@@ -347,0 +339,0 @@ """Creates a command wrapper that connects via SSH and sudo runs the command.""" |
+23
-4
@@ -39,5 +39,16 @@ from __future__ import annotations | ||
| DEFAULT_CACHE_MAX_AGE = "12h" | ||
| def _default_cache_max_age() -> str | int: | ||
| from fused.context import get_header # circular import | ||
| header_value = get_header("cache-decorator-max-age") | ||
| if header_value is not None: | ||
| try: | ||
| return int(header_value) | ||
| except ValueError: | ||
| return header_value | ||
| return OPTIONS.cache_decorator_max_age | ||
| class CacheResultError(Exception): | ||
@@ -771,3 +782,3 @@ pass | ||
| *args, | ||
| cache_max_age: str | int = DEFAULT_CACHE_MAX_AGE, | ||
| cache_max_age: str | int | None = None, | ||
| cache_folder_path: str = "tmp", | ||
@@ -808,2 +819,6 @@ concurrent_lock_timeout: str | int = 120, | ||
| """ | ||
| # Resolve default cache_max_age from options | ||
| if cache_max_age is None: | ||
| cache_max_age = _default_cache_max_age() | ||
| # Calculate expires_on | ||
@@ -940,3 +955,3 @@ if cache_storage is None: | ||
| *args, | ||
| cache_max_age: str | int = DEFAULT_CACHE_MAX_AGE, | ||
| cache_max_age: str | int | None = None, | ||
| cache_folder_path: str = "tmp", | ||
@@ -977,2 +992,6 @@ concurrent_lock_timeout: str | int = 120, | ||
| """ | ||
| # Resolve default cache_max_age from options | ||
| if cache_max_age is None: | ||
| cache_max_age = _default_cache_max_age() | ||
| # Calculate expires_on | ||
@@ -1200,3 +1219,3 @@ if cache_storage is None: | ||
| func: Callable[..., Any] | None = None, | ||
| cache_max_age: str | int = DEFAULT_CACHE_MAX_AGE, | ||
| cache_max_age: str | int | None = None, | ||
| cache_folder_path: str = "tmp", | ||
@@ -1203,0 +1222,0 @@ concurrent_lock_timeout: str | int = 120, |
| from __future__ import annotations | ||
| import threading | ||
| from contextlib import contextmanager | ||
@@ -14,2 +15,8 @@ from typing import TYPE_CHECKING, Dict, Protocol | ||
| # Thread-local storage for this_udf only | ||
| # This allows each thread to have its own UDF context during parallel local execution | ||
| # while still sharing the global context (auth, in_batch, in_realtime, etc.) | ||
| _thread_local = threading.local() | ||
| class ExecutionContextProtocol(Protocol): | ||
@@ -62,6 +69,16 @@ def __enter__(self): ... | ||
| """The UDF that is currently being run.""" | ||
| raise ValueError( | ||
| "No UDF is currently being run, or the UDF is being run locally." | ||
| ) | ||
| raise ValueError("No UDF is currently being run.") | ||
| def get_nested_udf(self) -> "Udf" | None: | ||
| """Get the thread-local nested UDF if one exists. | ||
| This is used when running a local UDF from within a pre-existing context. | ||
| The nested UDF is stored in thread-local storage so each thread can have | ||
| its own nested UDF while sharing the parent execution context. | ||
| Returns: | ||
| The nested UDF for the current thread, or None if not set. | ||
| """ | ||
| return getattr(_thread_local, "nested_udf", None) | ||
| def get_secret(self, key: str, client_id: str | None = None): | ||
@@ -107,2 +124,8 @@ api = get_api() | ||
| class LocalExecutionContext(ExecutionContextProtocol): | ||
| def __enter__(self): | ||
| return self | ||
| def __exit__(self, *exc_details): | ||
| pass | ||
| def auth_header(self, *, missing_ok: bool = False) -> Dict[str, str]: | ||
@@ -148,3 +171,30 @@ from fused._auth import AUTHORIZATION | ||
| @property | ||
| def this_udf(self) -> "Udf": | ||
| nested_udf = self.get_nested_udf() | ||
| if nested_udf is not None: | ||
| return nested_udf | ||
| raise ValueError("No UDF is currently being run.") | ||
| @contextmanager | ||
| def set_nested_udf_context(udf: "Udf"): | ||
| """ | ||
| Set a nested local UDF for the current thread without changing the global context. | ||
| """ | ||
| prev_udf = getattr(_thread_local, "nested_udf", None) | ||
| try: | ||
| _thread_local.nested_udf = udf | ||
| yield | ||
| finally: | ||
| if prev_udf is None: | ||
| # Clean up - no previous nested UDF existed | ||
| if hasattr(_thread_local, "nested_udf"): | ||
| delattr(_thread_local, "nested_udf") | ||
| else: | ||
| # Restore previous nested UDF | ||
| _thread_local.nested_udf = prev_udf | ||
| context: ExecutionContextProtocol = LocalExecutionContext() | ||
@@ -151,0 +201,0 @@ local_context = LocalExecutionContext() |
| from __future__ import annotations | ||
| import asyncio | ||
| import datetime | ||
| import json | ||
@@ -125,2 +126,4 @@ import sys | ||
| return obj.total_seconds() | ||
| elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): | ||
| return obj.isoformat() | ||
| else: | ||
@@ -130,28 +133,30 @@ raise TypeError(f"Cannot serialize type: {type(obj)}") | ||
| def serialize_realtime_param(param_value: Any): | ||
| if (HAS_GEOPANDAS and isinstance(param_value, GPD_GEODATAFRAME)) or (): | ||
| return param_value.to_json(cls=FusedJSONEncoder) | ||
| elif HAS_PANDAS and isinstance(param_value, PD_DATAFRAME): | ||
| return param_value.to_json( | ||
| orient="records", default_handler=fused_json_default_handler | ||
| ) | ||
| elif HAS_SHAPELY and isinstance(param_value, SHAPELY_GEOMETRY): | ||
| import shapely | ||
| return shapely.to_wkt(param_value) | ||
| # For dict and list types, serialize to JSON string | ||
| elif isinstance(param_value, (list, tuple, dict, bool)): | ||
| return json.dumps(param_value, default=fused_json_default_handler) | ||
| else: | ||
| # see if we succeed at serializing the param value (special types like GeoPandas, Pandas, Shapely) | ||
| try: | ||
| return fused_json_default_handler(param_value) | ||
| except TypeError: | ||
| # if not, use the raw value | ||
| return param_value | ||
| def serialize_realtime_params(params: dict[str, Any]): | ||
| result = {} | ||
| for param_name, param_value in params.items(): | ||
| if (HAS_GEOPANDAS and isinstance(param_value, GPD_GEODATAFRAME)) or (): | ||
| result[param_name] = param_value.to_json(cls=FusedJSONEncoder) | ||
| elif HAS_PANDAS and isinstance(param_value, PD_DATAFRAME): | ||
| result[param_name] = param_value.to_json( | ||
| orient="records", default_handler=fused_json_default_handler | ||
| ) | ||
| elif HAS_SHAPELY and isinstance(param_value, SHAPELY_GEOMETRY): | ||
| import shapely | ||
| result[param_name] = serialize_realtime_param(param_value) | ||
| result[param_name] = shapely.to_wkt(param_value) | ||
| # For dict and list types, serialize to JSON string | ||
| elif isinstance(param_value, (list, tuple, dict, bool)): | ||
| result[param_name] = json.dumps( | ||
| param_value, default=fused_json_default_handler | ||
| ) | ||
| else: | ||
| # see if we succeed at serializing the param value (special types like GeoPandas, Pandas, Shapely) | ||
| try: | ||
| result[param_name] = fused_json_default_handler(param_value) | ||
| except TypeError: | ||
| # if not, use the raw value | ||
| result[param_name] = param_value | ||
| return result | ||
@@ -158,0 +163,0 @@ |
@@ -36,2 +36,28 @@ from __future__ import annotations | ||
| def get_collection_name_with_fallback(collection_name: str | None = None) -> str: | ||
| """ | ||
| Get collection_name with fallback logic. | ||
| If collection_name is provided, return it. | ||
| If not provided, try to get it from the current UDF context. | ||
| If no UDF context exists, default to "default". | ||
| Args: | ||
| collection_name: Explicit collection name, or None to use fallback logic | ||
| Returns: | ||
| The resolved collection name | ||
| """ | ||
| if collection_name is not None: | ||
| return collection_name | ||
| try: | ||
| if this_udf := context.this_udf(): | ||
| return this_udf.collection_name or "default" | ||
| else: | ||
| return "default" | ||
| except ValueError: | ||
| return "default" | ||
| def get_step_config_from_server( | ||
@@ -43,2 +69,3 @@ email_or_handle: str | None, | ||
| import_udf_globals: bool = True, | ||
| collection_name: str | None = None, | ||
| ) -> UdfJobStepConfig: | ||
@@ -48,6 +75,17 @@ logger.info(f"Requesting {email_or_handle=} {slug=}") | ||
| api = get_api() | ||
| if email_or_handle in ["team", "public"]: | ||
| # TODO: handle/deprecate this when we have proper support for github collections | ||
| resolved_collection_name = "default" | ||
| else: | ||
| resolved_collection_name = get_collection_name_with_fallback(collection_name) | ||
| if _is_public: | ||
| obj = api._get_public_udf(slug) | ||
| # Note: this is mostly a deprecated code path | ||
| obj = api._get_public_udf(slug, collection_name=resolved_collection_name) | ||
| else: | ||
| obj = api._get_udf(email_or_handle, slug) | ||
| obj = api._get_udf( | ||
| email_or_handle, slug, collection_name=resolved_collection_name | ||
| ) | ||
| udf = load_udf_from_response_data( | ||
@@ -67,2 +105,3 @@ obj, context={"import_globals": import_udf_globals} | ||
| import_globals: bool | None = None, | ||
| collection_name: str | None = None, | ||
| ) -> AnyBaseUdf: | ||
@@ -78,2 +117,3 @@ """ | ||
| cache_key: Additional cache key for busting the UDF cache | ||
| collection_name: Collection name to load the UDF from. If not provided, uses fallback logic. | ||
| """ | ||
@@ -99,2 +139,3 @@ if id is None and not is_uuid(email_or_handle_or_id): | ||
| import_udf_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -111,9 +152,19 @@ | ||
| import_udf_globals: bool = True, | ||
| collection_name: str | None = None, | ||
| ) -> UdfJobStepConfig: | ||
| """Async version of get_step_config_from_server that uses async HTTP requests""" | ||
| api = get_api() | ||
| if email_or_handle in ["team", "public"]: | ||
| # TODO: handle/deprecate this when we have proper support for github collections | ||
| resolved_collection_name = "default" | ||
| else: | ||
| resolved_collection_name = get_collection_name_with_fallback(collection_name) | ||
| if _is_public: | ||
| obj = api._get_public_udf(slug) | ||
| obj = api._get_public_udf(slug, collection_name=resolved_collection_name) | ||
| else: | ||
| obj = await api._get_udf_async(email_or_handle, slug) | ||
| obj = await api._get_udf_async( | ||
| email_or_handle, slug, collection_name=resolved_collection_name | ||
| ) | ||
@@ -132,2 +183,3 @@ udf = load_udf_from_response_data( | ||
| import_globals: bool | None = None, | ||
| collection_name: str | None = None, | ||
| ) -> AnyBaseUdf: | ||
@@ -147,2 +199,3 @@ """ | ||
| import_globals: Whether to import globals from the UDF context | ||
| collection_name: Collection name to load the UDF from. If not provided, uses fallback logic. | ||
| """ | ||
@@ -165,2 +218,3 @@ if id is None and not is_uuid(email_or_handle_or_id): | ||
| import_udf_globals=import_globals, | ||
| collection_name=collection_name, | ||
| ) | ||
@@ -167,0 +221,0 @@ return step_config.udf |
@@ -5,4 +5,6 @@ """Thread-safe utilities for concurrent execution with isolated logging.""" | ||
| import io | ||
| import os | ||
| import sys | ||
| import threading | ||
| from pathlib import Path | ||
| from typing import Generator, TextIO, Type | ||
@@ -15,2 +17,38 @@ | ||
| class SyncingFileWrapper: | ||
| """ | ||
| File wrapper that flushes and fsyncs on newlines. | ||
| This ensures real-time visibility on network filesystems like EFS. | ||
| """ | ||
| def __init__(self, file: TextIO) -> None: | ||
| self._file = file | ||
| def __enter__(self): | ||
| return self | ||
| def __exit__(self, *args): | ||
| self.close() | ||
| def write(self, text: str) -> int: | ||
| res = self._file.write(text) | ||
| if "\n" in text: | ||
| self.flush() | ||
| return res | ||
| def flush(self) -> None: | ||
| self._file.flush() | ||
| try: | ||
| os.fsync(self._file.fileno()) | ||
| except (AttributeError, io.UnsupportedOperation, OSError): | ||
| pass | ||
| def close(self) -> None: | ||
| self._file.close() | ||
| def fileno(self) -> int: | ||
| return self._file.fileno() | ||
| class MultipleTextIO: | ||
@@ -176,3 +214,5 @@ _outputs: list[TextIO] | ||
| def capture_current_thread_output( | ||
| output_to_original_streams: bool = False, impl_class: Type[TextIO] = io.StringIO | ||
| output_to_original_streams: bool = False, | ||
| impl_class: Type[TextIO] = io.StringIO, | ||
| log_dir: Path | None = None, | ||
| ): | ||
@@ -185,2 +225,9 @@ """ | ||
| Args: | ||
| output_to_original_streams: If True, also write to the original stdout/stderr | ||
| impl_class: The class to use for the output buffers (default: io.StringIO) | ||
| log_dir: If provided, logs will be streamed to files in this directory | ||
| (stdout, stderr files will be created). The directory will be | ||
| created if it doesn't exist. | ||
| Usage: | ||
@@ -194,2 +241,6 @@ with capture_current_thread_output() as (out_buf, err_buf): | ||
| # With disk logging: | ||
| with capture_current_thread_output(log_dir=Path("/mount/logs/run123")) as (out_buf, err_buf): | ||
| print("This streams to both memory and disk") | ||
| Returns: | ||
@@ -212,7 +263,42 @@ Tuple of (stdout_buffer, stderr_buffer) as StringIO objects | ||
| with impl_class() as out_buf, impl_class() as err_buf: | ||
| # use MultipleTextIO to capture the output to the original streams | ||
| with contextlib.ExitStack() as stack: | ||
| # Open disk log files if log_dir is provided | ||
| stdout_file = None | ||
| stderr_file = None | ||
| if log_dir is not None: | ||
| # Directory should already exist and be verified writeable by caller, | ||
| # but ensure it exists as a safety measure | ||
| log_dir.mkdir(parents=True, exist_ok=True) | ||
| # Wrap with SyncingFileWrapper for real-time visibility on EFS | ||
| stdout_file = stack.enter_context( | ||
| SyncingFileWrapper(open(log_dir / "stdout", "w", buffering=1)) | ||
| ) | ||
| stderr_file = stack.enter_context( | ||
| SyncingFileWrapper(open(log_dir / "stderr", "w", buffering=1)) | ||
| ) | ||
| out_buf = stack.enter_context(impl_class()) | ||
| err_buf = stack.enter_context(impl_class()) | ||
| # Build list of outputs for MultipleTextIO | ||
| stdout_outputs = [out_buf] | ||
| stderr_outputs = [err_buf] | ||
| # Add disk file outputs if enabled | ||
| if stdout_file is not None: | ||
| stdout_outputs.append(stdout_file) | ||
| if stderr_file is not None: | ||
| stderr_outputs.append(stderr_file) | ||
| # Add original stream outputs if enabled | ||
| if output_to_original_streams: | ||
| out_buf = MultipleTextIO(out_buf, stdout_router.original_stream) | ||
| err_buf = MultipleTextIO(err_buf, stderr_router.original_stream) | ||
| stdout_outputs.append(stdout_router.original_stream) | ||
| stderr_outputs.append(stderr_router.original_stream) | ||
| # Wrap with MultipleTextIO if we have multiple outputs | ||
| if len(stdout_outputs) > 1: | ||
| out_buf = MultipleTextIO(*stdout_outputs) | ||
| if len(stderr_outputs) > 1: | ||
| err_buf = MultipleTextIO(*stderr_outputs) | ||
| # Register buffers for current thread (no thread_id needed!) | ||
@@ -219,0 +305,0 @@ stdout_router.register_buffer_for_current_thread(out_buf) |
@@ -85,3 +85,2 @@ from __future__ import annotations | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| ): | ||
@@ -102,9 +101,3 @@ """ | ||
| # Validate import stamements correspond to valid modules | ||
| if udf is not None: | ||
| from fused._udf.execute_v2 import validate_imports_whitelist | ||
| validate_imports_whitelist(udf, validate_imports=validate_imports) | ||
| def _assert_udf_has_parameters(udf: AnyBaseUdf): | ||
@@ -141,3 +134,2 @@ assert hasattr(udf, "set_parameters"), ( | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -149,3 +141,2 @@ ): | ||
| ignore_no_udf=ignore_no_udf if has_udf else True, | ||
| validate_imports=validate_imports, | ||
| ) | ||
@@ -162,3 +153,2 @@ | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -198,3 +188,2 @@ overwrite: bool | None = None, | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -217,3 +206,2 @@ name=name, | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -227,3 +215,2 @@ name: str | None = None, | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -289,3 +276,3 @@ ) | ||
| # String: Job instantiation | ||
| str_job_inst = f"job = {self.udf.entrypoint}({structure_params(self._generate_job_params())})" | ||
| str_job_inst = f"job = {self.udf.entrypoint}._to_job_step({structure_params(self._generate_job_params())})" | ||
| # String: Job execution | ||
@@ -346,3 +333,2 @@ str_job_exec = "job._run_local()" | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -352,3 +338,2 @@ ): | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -470,3 +455,2 @@ ) | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -498,3 +482,2 @@ overwrite: bool | None = None, | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -569,3 +552,2 @@ send_status_email=send_status_email, | ||
| sample: Any | None = ..., | ||
| validate_imports: bool | None = None, | ||
| cache_max_age: int | None = None, | ||
@@ -596,3 +578,2 @@ _return_response: bool = False, | ||
| sample_list, | ||
| validate_imports=validate_imports, | ||
| cache_max_age=cache_max_age, | ||
@@ -610,3 +591,2 @@ **kwargs, | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -898,3 +878,2 @@ ): | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -905,3 +884,2 @@ ): | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -919,3 +897,2 @@ ) | ||
| ignore_no_udf: bool = False, | ||
| validate_imports: bool | None = None, | ||
| validate_inputs: bool = True, | ||
@@ -943,3 +920,2 @@ send_status_email: bool | None = None, | ||
| ignore_no_udf=ignore_no_udf, | ||
| validate_imports=validate_imports, | ||
| validate_inputs=validate_inputs, | ||
@@ -1028,3 +1004,2 @@ ) | ||
| self, | ||
| validate_imports: bool | None = None, | ||
| *args, | ||
@@ -1039,5 +1014,3 @@ **kwargs, | ||
| } | ||
| _step = step._run_local( | ||
| validate_imports=validate_imports, *args, **filtered_kwargs | ||
| ) | ||
| _step = step._run_local(*args, **filtered_kwargs) | ||
| runs.append(_step) | ||
@@ -1044,0 +1017,0 @@ return MultiUdfEvaluationResult(udf_results=[run for run in runs]) |
@@ -125,2 +125,3 @@ from enum import Enum | ||
| allow_public_list: Optional[bool] = None | ||
| collection_id: Optional[str] = None | ||
@@ -127,0 +128,0 @@ |
@@ -65,2 +65,4 @@ from __future__ import annotations | ||
| METADATA_FUSED_EXPLORER_TAB = "fused:explorerTab" | ||
| METADATA_FUSED_IS_VISIBLE = "fused:isVisible" | ||
| METADATA_FUSED_POSITION = "fused:position" | ||
@@ -137,3 +139,7 @@ | ||
| headers: HeaderSequence | Sequence[Header] = Field(default_factory=list) | ||
| """Deprecated.""" | ||
| metadata: UserMetadataType = None | ||
| # Collection-related fields (used for runtime context) | ||
| collection_id: str | None = Field(default=None) | ||
| collection_name: str | None = Field(default=None) | ||
| _globals: CompiledAttrs | None = None | ||
@@ -234,2 +240,3 @@ import_globals: bool = Field(default=True, exclude=True) | ||
| yield "name", self.name | ||
| yield "collection_name", self.collection_name | ||
| yield "description", metadata.get("fused:description") | ||
@@ -366,2 +373,3 @@ catalog_url = self.catalog_url | ||
| overwrite: bool | None = None, | ||
| collection_name: str | None = None, | ||
| ): | ||
@@ -373,3 +381,7 @@ """ | ||
| overwrite: If True, overwrite existing remote UDF with the UDF object. | ||
| collection_name: The collection name to associate with this UDF. If not provided, | ||
| falls back to the collection of the currently executing UDF, or defaults to "default". | ||
| """ | ||
| from fused.core._udf_ops import get_collection_name_with_fallback | ||
| api = get_api() | ||
@@ -380,2 +392,8 @@ self_id = self._get_metadata_safe(METADATA_FUSED_ID) | ||
| old_metadata = None | ||
| # Resolve collection name with fallback logic | ||
| resolved_collection_name = get_collection_name_with_fallback(collection_name) | ||
| # Determine if we're creating a new UDF or updating an existing one | ||
| collection_id = None | ||
| if remote_udf: | ||
@@ -388,2 +406,17 @@ remote_id = remote_udf._get_metadata_safe(METADATA_FUSED_ID) | ||
| backend_id = remote_id | ||
| # Don't overwrite collection_id if remote UDF exists | ||
| collection_id = None | ||
| else: | ||
| # Not overwriting, so we need to create a new UDF with the specified collection | ||
| # Fetch collection ID from collection name | ||
| try: | ||
| collection_data = api._get_collection_by_name( | ||
| resolved_collection_name | ||
| ) | ||
| collection_id = collection_data.get("id") | ||
| except HTTPError as e: | ||
| # If collection doesn't exist, it will be created on the server with default | ||
| logger.debug( | ||
| f"Collection '{resolved_collection_name}' not found, will use server default: {e}" | ||
| ) | ||
| else: | ||
@@ -409,2 +442,12 @@ # If the UDF does not exist, we need to save it as new | ||
| # Fetch collection ID for new UDF | ||
| try: | ||
| collection_data = api._get_collection_by_name(resolved_collection_name) | ||
| collection_id = collection_data.get("id") | ||
| except HTTPError as e: | ||
| # If collection doesn't exist, it will be created on the server with default | ||
| logger.debug( | ||
| f"Collection '{resolved_collection_name}' not found, will use server default: {e}" | ||
| ) | ||
| # Ensures some metadata values are not serialized as it can lead to stale data upon loading. | ||
@@ -417,3 +460,5 @@ if self.metadata: | ||
| try: | ||
| result = api.save_udf(udf=self, slug=self.name, id=backend_id) | ||
| result = api.save_udf( | ||
| udf=self, slug=self.name, id=backend_id, collection_id=collection_id | ||
| ) | ||
| except HTTPError as e: | ||
@@ -456,3 +501,5 @@ if old_metadata: | ||
| self, | ||
| *, | ||
| overwrite: bool | None = None, | ||
| collection_name: str | None = None, | ||
| **kwargs: dict[str, Any], | ||
@@ -465,2 +512,4 @@ ): | ||
| overwrite: If True, overwrite existing remote UDF with the UDF object. | ||
| collection_name: The collection name to associate with this UDF. If not provided, | ||
| falls back to the collection of the currently executing UDF, or defaults to "default". | ||
| """ | ||
@@ -503,3 +552,3 @@ | ||
| if not any((slug is not Ellipsis, over_id, as_new, inplace is False)): | ||
| return self._to_fused_v2(overwrite) | ||
| return self._to_fused_v2(overwrite, collection_name=collection_name) | ||
@@ -547,4 +596,6 @@ backend_id = ( | ||
| api = get_api() | ||
| # Use collection_name from direct attribute if available | ||
| collection_name = self.collection_name or "default" | ||
| try: | ||
| remote_udf_data = api._get_udf(slug) | ||
| remote_udf_data = api._get_udf(slug, collection_name=collection_name) | ||
| except HTTPError: | ||
@@ -557,2 +608,8 @@ return None | ||
| def delete_saved(self, inplace: bool = True): | ||
| """Delete this UDF from the Fused service. | ||
| Args: | ||
| inplace: If True, modify the UDF metadata in place. (Default True) | ||
| If False, return a new UDF object with the metadata removed. | ||
| """ | ||
| from fused._global_api import get_api | ||
@@ -575,2 +632,3 @@ | ||
| def delete_cache(self): | ||
| """Delete the result cache for this UDF.""" | ||
| backend_id = self._get_metadata_safe(METADATA_FUSED_ID) | ||
@@ -595,2 +653,12 @@ if backend_id is None: | ||
| ) -> UdfAccessToken: | ||
| """Create a UDF access token (share token) for this UDF. | ||
| Args: | ||
| client_id: The client ID to use for the access token. (Default: detect automatically) | ||
| public_read: Whether the access token should have public read access. (Default: off) | ||
| access_scope: The access scope to use for the access token. (Default: world) | ||
| cache: Whether to enable caching on the access token. (Default True) | ||
| metadata_json: Additional metadata to serve as part of the tiles metadata.json. (Default None) | ||
| enabled: Whether the access token is enabled. (Default True) | ||
| """ | ||
| from fused._global_api import get_api | ||
@@ -639,2 +707,11 @@ | ||
| def shared_url(self, format: str | None = None) -> str | None: | ||
| """Get the shared URL for this UDF. | ||
| Args: | ||
| format: The result format (file type) for the URL. (Default None) | ||
| """ | ||
| access_token = self.get_access_token() | ||
| return access_token.get_file_url(format=format) if access_token else None | ||
| def schedule( | ||
@@ -694,3 +771,3 @@ self, | ||
| else: | ||
| return None | ||
| return default | ||
@@ -814,2 +891,3 @@ def _set_metadata_safe(self, key: str, value: Any): | ||
| def utils(self): | ||
| """Deprecated.""" | ||
| return self._cached_utils | ||
@@ -816,0 +894,0 @@ |
+217
-17
@@ -15,3 +15,2 @@ from __future__ import annotations | ||
| Sequence, | ||
| overload, | ||
| ) | ||
@@ -23,2 +22,3 @@ | ||
| from fused.models.udf._engine import ENGINE_LOCAL, ENGINE_REMOTE, Engine | ||
| from fused.models.udf.base_udf import ( | ||
@@ -35,2 +35,3 @@ METADATA_FUSED_ID, | ||
| if TYPE_CHECKING: | ||
| from fused._submit import JobPool | ||
| from fused.models.api.job import UdfJobStepConfig | ||
@@ -71,3 +72,3 @@ from fused.models.udf._eval_result import UdfEvaluationResult | ||
| class Udf(BaseUdf): | ||
| """A user-defined function that operates on [`geopandas.GeoDataFrame`s][geopandas.GeoDataFrame].""" | ||
| """A user-defined function.""" | ||
@@ -87,4 +88,4 @@ type: Literal[UdfType.GEOPANDAS_V2] = UdfType.GEOPANDAS_V2 | ||
| disk_size_gb: int | None = None | ||
| """The size of the disk in GB to use for remote execution | ||
| (only supported for a batch (non-realtime) instance type). | ||
| """The size of the disk in GB to use for remote execution. | ||
| Used in batch jobs. | ||
| """ | ||
@@ -100,2 +101,3 @@ region: str | None = None | ||
| original_headers: str | None = None | ||
| """Deprecated.""" | ||
@@ -233,2 +235,4 @@ _nested_callable = PrivateAttr(None) # TODO : Find out type | ||
| inplace: If True, modify this object. If False, return a new object. Defaults to True. | ||
| Deprecated: Set parameters when calling the UDF or using `UDF.map()` instead. | ||
| """ | ||
@@ -254,2 +258,4 @@ ret = _maybe_inplace(self, inplace) | ||
| inplace: If True, update this UDF object. Otherwise return a new UDF object (default). | ||
| Deprecated: Do not call this. | ||
| """ | ||
@@ -269,3 +275,2 @@ from fused._udf.execute_v2 import execute_for_decorator | ||
| inplace: bool = False, | ||
| validate_imports: bool | None = None, | ||
| **kwargs, | ||
@@ -277,2 +282,4 @@ ) -> UdfEvaluationResult: | ||
| inplace: If True, update this UDF object with schema information. (default) | ||
| Deprecated: Call the UDF instead. | ||
| """ | ||
@@ -285,3 +292,2 @@ from fused._udf.execute_v2 import execute_against_sample | ||
| input=[], | ||
| validate_imports=validate_imports, | ||
| **kwargs, | ||
@@ -302,3 +308,3 @@ ) | ||
| updated_udf = self._with_udf_entrypoint() | ||
| job = updated_udf() | ||
| job = updated_udf._to_job_step() | ||
| job.export(where, how="zip", overwrite=overwrite) | ||
@@ -316,15 +322,7 @@ | ||
| updated_udf = self._with_udf_entrypoint() | ||
| job = updated_udf() | ||
| job = updated_udf._to_job_step() | ||
| where = where or self.name | ||
| job.export(where, how="local", overwrite=overwrite) | ||
| # List of data input is passed - run that | ||
| @overload | ||
| def __call__(self, *, arg_list: Iterable[Any], **kwargs) -> UdfJobStepConfig: ... | ||
| # Nothing is passed - run the UDF once | ||
| @overload | ||
| def __call__(self, *, arg_list: None = None, **kwargs) -> UdfJobStepConfig: ... | ||
| def __call__( | ||
| def _to_job_step( | ||
| self, *, arg_list: Iterable[Any] | None = None, **kwargs | ||
@@ -359,3 +357,200 @@ ) -> UdfJobStepConfig: | ||
| def __call__( | ||
| self, | ||
| *args, | ||
| engine: Engine | None = None, | ||
| cache_max_age: str | None = None, | ||
| cache: bool = True, | ||
| **kwargs, | ||
| ): | ||
| """Call this UDF. | ||
| Args: | ||
| *args: Positional arguments to pass to the UDF. | ||
| engine: The engine to use for execution. | ||
| "remote": Run remotely on a realtime instance. (Default) | ||
| "local": Run locally. | ||
| "small", "medium", "large": Run on a batch instance. | ||
| Other values will be interpreted as a batch instance type. | ||
| cache_max_age: The maximum age when returning a result from the cache. | ||
| Supported units are seconds (s), minutes (m), hours (h), and days (d) (e.g. “48h”, “10s”, etc.). | ||
| Default is `None` so a UDF will follow `cache_max_age` defined in `@fused.udf()` unless this value is changed. | ||
| cache: Set to False as a shortcut for `cache_max_age='0s'` to disable caching. (Default True) | ||
| **kwargs: Keyword arguments to pass to the UDF. | ||
| Returns: | ||
| The result of the UDF execution. | ||
| """ | ||
| from fused import run | ||
| for arg_index, arg in enumerate(args): | ||
| if not self._parameter_list: | ||
| raise TypeError( | ||
| f"UDF {self.name} received {len(args)} positional arguments, but the UDF has no parameters" | ||
| ) | ||
| elif arg_index >= len(self._parameter_list): | ||
| raise TypeError( | ||
| f"UDF {self.name} has {len(self._parameter_list)} parameters, but got {len(args)} positional arguments" | ||
| ) | ||
| elif self._parameter_list[arg_index] in kwargs: | ||
| raise TypeError( | ||
| f"UDF {self.name} got multiple values for argument {self._parameter_list[arg_index]}" | ||
| ) | ||
| else: | ||
| kwargs[self._parameter_list[arg_index]] = arg | ||
| instance_type = None | ||
| if engine not in [ENGINE_LOCAL, ENGINE_REMOTE]: | ||
| instance_type = engine | ||
| engine = None | ||
| return run( | ||
| self, | ||
| engine=engine, | ||
| instance_type=instance_type, | ||
| cache_max_age=cache_max_age, | ||
| cache=cache, | ||
| disk_size_gb=self.disk_size_gb, | ||
| **kwargs, | ||
| ) | ||
| def map( | ||
| self, | ||
| arg_list, | ||
| *, | ||
| engine: Engine | None = None, | ||
| cache_max_age: str | None = None, | ||
| max_workers: int | None = None, | ||
| worker_concurrency: int | None = None, | ||
| cache: bool = True, | ||
| max_retry: int = 2, | ||
| ) -> "JobPool": | ||
| """Submit a job for each element in arg_list. | ||
| Args: | ||
| arg_list: A list of arguments to pass to the UDF. Each element | ||
| in arg_list will become a job and run. | ||
| engine: The engine to use for execution. | ||
| "remote": Run on a realtime instance. (Default) | ||
| "local": Run locally. | ||
| "small", "medium", "large": Run on a batch instance. | ||
| Other values will be interpreted as a batch instance type. | ||
| max_workers: The maximum number of workers to use. | ||
| For running on realtime instances, this is the number of | ||
| instances to use. (Default 32) | ||
| For running locally, this is the number of threads to use. | ||
| (Default 1) | ||
| For running on batch instances, this is the number of worker | ||
| machines to use. (Default 1) | ||
| worker_concurrency: The concurrency level for each worker. | ||
| For running on realtime instances, this is the number of | ||
| arguments to run in each instance at a time. (Default 1) | ||
| For running locally, this cannot be set. | ||
| For running on batch instances, this is the number of processes | ||
| to use in each worker machine. (Default based on the number of | ||
| cores in the machine.) | ||
| cache_max_age: The maximum age when returning a result from the cache. | ||
| Supported units are seconds (s), minutes (m), hours (h), and days (d) (e.g. “48h”, “10s”, etc.). | ||
| Default is `None` so a UDF will follow `cache_max_age` defined in `@fused.udf()` unless this value is changed. | ||
| cache: Set to False as a shortcut for `cache_max_age='0s'` to disable caching. (Default True) | ||
| max_retry: The maximum number of retries for failed jobs. (Default 2) | ||
| Note that retries will only be attempted if the object is waited on, | ||
| e.g. with `pool.wait()`, `pool.tail()`, or `pool.collect()`. | ||
| Returns: | ||
| A JobPool object. Call `.collect()` to get the results. | ||
| Example: | ||
| >>> @fused.udf() | ||
| ... def my_udf(x: int): | ||
| ... return x + 1 | ||
| ... | ||
| >>> pool = my_udf.map([1, 2, 3]) | ||
| >>> results = pool.collect() | ||
| >>> print(results) | ||
| [2, 3, 4] | ||
| """ | ||
| from fused import submit | ||
| instance_type = None | ||
| if engine not in [ENGINE_LOCAL, ENGINE_REMOTE]: | ||
| instance_type = engine | ||
| engine = None | ||
| return submit( | ||
| self, | ||
| arg_list, | ||
| engine=engine, | ||
| instance_type=instance_type, | ||
| cache_max_age=cache_max_age, | ||
| cache=cache, | ||
| max_workers=max_workers, | ||
| n_processes_per_worker=worker_concurrency, | ||
| max_retry=max_retry, | ||
| disk_size_gb=self.disk_size_gb, | ||
| collect=False, | ||
| ) | ||
| async def map_async( | ||
| self, | ||
| arg_list, | ||
| *, | ||
| engine: Engine | None = None, | ||
| max_workers: int | None = None, | ||
| cache_max_age: str | None = None, | ||
| cache: bool = True, | ||
| max_retry: int = 2, | ||
| ) -> "JobPool": | ||
| """Submit a job for each element in arg_list. | ||
| Args: | ||
| arg_list: A list of arguments to pass to the UDF. Each element | ||
| in arg_list will become a job and run. | ||
| engine: The engine to use for execution. | ||
| "remote": Run on a realtime instance. (Default) | ||
| "local": Run locally. | ||
| Note: batch instance types are not supported for async map. | ||
| max_workers: The maximum number of workers to use. | ||
| For running on realtime instances, this is the number of | ||
| instances to use. (Default 32) | ||
| For running locally, this is the number of threads to use. | ||
| (Default 1) | ||
| cache_max_age: The maximum age when returning a result from the cache. | ||
| Supported units are seconds (s), minutes (m), hours (h), and days (d) (e.g. “48h”, “10s”, etc.). | ||
| Default is `None` so a UDF will follow `cache_max_age` defined in `@fused.udf()` unless this value is changed. | ||
| cache: Set to False as a shortcut for `cache_max_age='0s'` to disable caching. (Default True) | ||
| max_retry: The maximum number of retries for failed jobs. (Default 2) | ||
| Note that retries will only be attempted if the object is waited on, | ||
| e.g. with `pool.wait()`, `pool.tail()`, or `pool.collect()`. | ||
| Note worker_concurrency is not supported for async map. | ||
| Returns: | ||
| An AsyncJobPool object. Call `.collect()` to get the results. | ||
| Example: | ||
| >>> @fused.udf() | ||
| ... def my_udf(x: int): | ||
| ... return x + 1 | ||
| ... | ||
| >>> pool = my_udf.map_async([1, 2, 3]) | ||
| >>> results = pool.collect() | ||
| >>> print(results) | ||
| [2, 3, 4] | ||
| """ | ||
| from fused import submit | ||
| return submit( | ||
| self, | ||
| arg_list, | ||
| engine=engine, | ||
| cache_max_age=cache_max_age, | ||
| cache=cache, | ||
| max_workers=max_workers, | ||
| max_retry=max_retry, | ||
| collect=False, | ||
| execution_type="async_loop", | ||
| ) | ||
| EMPTY_UDF = Udf(name="EMPTY_UDF", code="", entrypoint="") | ||
@@ -386,2 +581,7 @@ | ||
| udf._set_metadata_safe(METADATA_FUSED_SLUG, data["slug"]) | ||
| # Set collection-related fields as direct attributes (not in metadata) | ||
| if "collection_id" in data and data["collection_id"] is not None: | ||
| udf.collection_id = data["collection_id"] | ||
| if "collection_name" in data and data["collection_name"] is not None: | ||
| udf.collection_name = data["collection_name"] | ||
| return udf | ||
@@ -388,0 +588,0 @@ |
+1
-1
| Metadata-Version: 2.4 | ||
| Name: fused | ||
| Version: 1.30.1 | ||
| Version: 2.0.0 | ||
| Project-URL: Homepage, https://www.fused.io | ||
@@ -5,0 +5,0 @@ Project-URL: Documentation, https://docs.fused.io |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
938321
2.72%93
1.09%22473
2.53%