Latest Threat Research: SANDWORM_MODE — Shai-Hulud-Style npm Worm Hijacks CI Workflows and Poisons AI Toolchains. Details
Socket
Book a Demo · Install · Sign in
Socket

neptune

Package Overview
Dependencies
Maintainers
2
Versions
73
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

neptune - npm Package Compare versions

Comparing version
2.0.0a8
to
1.12.0
+33
src/neptune/api/dtos.py
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["FileEntry"]

import datetime
from dataclasses import dataclass
from typing import Any


@dataclass
class FileEntry:
    """A single file listed by the backend, together with its basic metadata."""

    # Field names are snake_case locally; the backend DTO uses camelCase ``fileType``.
    name: str
    size: int
    mtime: datetime.datetime
    file_type: str

    @classmethod
    def from_dto(cls, file_dto: Any) -> "FileEntry":
        """Build a ``FileEntry`` from a backend DTO (plain attribute access)."""
        dto_fields = (file_dto.name, file_dto.size, file_dto.mtime, file_dto.fileType)
        return cls(*dto_fields)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass
@dataclass(frozen=True)
class MultipartConfig:
    """Immutable size/count limits governing multipart file uploads."""

    min_chunk_size: int
    max_chunk_size: int
    max_chunk_count: int
    max_single_part_size: int

    @staticmethod
    def get_default() -> "MultipartConfig":
        """Return the stock limits: 5 MiB chunks, 1 GiB chunk ceiling, 1000 parts."""
        five_mebibytes = 5 * 1024 * 1024
        return MultipartConfig(
            min_chunk_size=five_mebibytes,
            max_chunk_size=1024 * 1024 * 1024,
            max_chunk_count=1000,
            max_single_part_size=five_mebibytes,
        )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["with_api_exceptions_handler", "get_retry_from_headers_or_default"]
import itertools
import os
import time
import requests
from bravado.exception import (
BravadoConnectionError,
BravadoTimeoutError,
HTTPBadGateway,
HTTPClientError,
HTTPForbidden,
HTTPGatewayTimeout,
HTTPInternalServerError,
HTTPRequestTimeout,
HTTPServiceUnavailable,
HTTPTooManyRequests,
HTTPUnauthorized,
)
from bravado_core.util import RecursiveCallException
from requests.exceptions import ChunkedEncodingError
from urllib3.exceptions import NewConnectionError
from neptune.common.envs import NEPTUNE_RETRIES_TIMEOUT_ENV
from neptune.common.exceptions import (
ClientHttpError,
Forbidden,
NeptuneAuthTokenExpired,
NeptuneConnectionLostException,
NeptuneInvalidApiTokenException,
NeptuneSSLVerificationError,
Unauthorized,
)
from neptune.common.utils import reset_internal_ssl_state
from neptune.internal.utils.logger import get_logger
# Module-level logger for this HTTP retry/translation helper.
_logger = get_logger()

# Upper bound (seconds) on a single backoff sleep between retries.
MAX_RETRY_TIME = 30
# Cap on the backoff exponent: sleeps grow as 2 ** min(MAX_RETRY_MULTIPLIER, retry).
MAX_RETRY_MULTIPLIER = 10
# Total time budget (seconds) for retrying a call, overridable via environment.
retries_timeout = int(os.getenv(NEPTUNE_RETRIES_TIMEOUT_ENV, "60"))
def get_retry_from_headers_or_default(headers, retry_count):
    """Pick how many seconds to wait before the next retry.

    Honours a server-provided ``retry-after`` header when present; otherwise
    uses exponential backoff. Any failure while reading the header degrades to
    backoff capped at ``MAX_RETRY_TIME``.
    """
    backoff = 2 ** min(MAX_RETRY_MULTIPLIER, retry_count)
    try:
        if "retry-after" in headers:
            # NOTE(review): indexes ``[0]`` into the header value — assumes the
            # header maps to a list of strings; confirm against the HTTP client.
            return int(headers["retry-after"][0])
        return backoff
    except Exception:
        return min(backoff, MAX_RETRY_TIME)
def with_api_exceptions_handler(func):
    """Decorate ``func`` with retry logic and translation of transport/HTTP
    failures into Neptune exceptions.

    Transient errors are retried with exponential backoff until the
    module-level ``retries_timeout`` budget is exhausted; then the last error
    is wrapped in ``NeptuneConnectionLostException``. Non-retryable HTTP
    statuses are translated into the matching Neptune exception immediately.
    """

    def wrapper(*args, **kwargs):
        ssl_error_occurred = False
        last_exception = None
        start_time = time.monotonic()
        for retry in itertools.count(0):
            # Stop retrying once the overall time budget is spent.
            if time.monotonic() - start_time > retries_timeout:
                break
            try:
                return func(*args, **kwargs)
            except requests.exceptions.InvalidHeader as e:
                # A malformed auth header means the API token itself is invalid.
                if "X-Neptune-Api-Token" in e.args[0]:
                    raise NeptuneInvalidApiTokenException()
                raise
            except requests.exceptions.SSLError as e:
                """
                OpenSSL's internal random number generator does not properly handle forked processes.
                Applications must change the PRNG state of the parent process
                if they use any SSL feature with os.fork().
                Any successful call of RAND_add(), RAND_bytes() or RAND_pseudo_bytes() is sufficient.
                https://docs.python.org/3/library/ssl.html#multi-processing
                On Linux it looks like it does not help much but does not break anything either.
                But single retry seems to solve the issue.
                """
                # First SSL failure: reset OpenSSL's internal state once and
                # retry immediately (see the fork/PRNG note above).
                if not ssl_error_occurred:
                    ssl_error_occurred = True
                    reset_internal_ssl_state()
                    continue
                if "CertificateError" in str(e.__context__):
                    # Certificate verification failures will not fix themselves.
                    raise NeptuneSSLVerificationError() from e
                else:
                    # Other SSL errors are treated as transient: back off, retry.
                    time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                    last_exception = e
                    continue
            except (
                BravadoConnectionError,
                BravadoTimeoutError,
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
                HTTPRequestTimeout,
                HTTPServiceUnavailable,
                HTTPGatewayTimeout,
                HTTPBadGateway,
                HTTPInternalServerError,
                NewConnectionError,
                ChunkedEncodingError,
                RecursiveCallException,
            ) as e:
                # Transient transport / 5xx-style failures: exponential backoff.
                time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                last_exception = e
                continue
            except HTTPTooManyRequests as e:
                # Rate-limited: honour the server's retry-after hint if present.
                wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                time.sleep(wait_time)
                last_exception = e
                continue
            except NeptuneAuthTokenExpired as e:
                # Retry without sleeping; presumably a fresh token is obtained
                # on the next attempt — TODO confirm against the auth layer.
                last_exception = e
                continue
            except HTTPUnauthorized:
                raise Unauthorized()
            except HTTPForbidden:
                raise Forbidden()
            except HTTPClientError as e:
                # Remaining 4xx responses are surfaced to the caller unretried.
                raise ClientHttpError(e.status_code, e.response.text) from e
            except requests.exceptions.RequestException as e:
                # Plain `requests` errors: apply the same status-code policy
                # as the bravado-specific handlers above.
                if e.response is None:
                    raise
                status_code = e.response.status_code
                if status_code in (
                    HTTPRequestTimeout.status_code,
                    HTTPBadGateway.status_code,
                    HTTPServiceUnavailable.status_code,
                    HTTPGatewayTimeout.status_code,
                    HTTPInternalServerError.status_code,
                ):
                    time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                    last_exception = e
                    continue
                elif status_code == HTTPTooManyRequests.status_code:
                    wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                    time.sleep(wait_time)
                    last_exception = e
                    continue
                elif status_code == HTTPUnauthorized.status_code:
                    raise Unauthorized()
                elif status_code == HTTPForbidden.status_code:
                    raise Forbidden()
                elif 400 <= status_code < 500:
                    raise ClientHttpError(status_code, e.response.text) from e
                else:
                    raise
        # Retry budget exhausted without a successful call.
        raise NeptuneConnectionLostException(last_exception) from last_exception

    return wrapper
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Names of environment variables read by the Neptune client.
API_TOKEN_ENV_NAME = "NEPTUNE_API_TOKEN"
# Overrides the overall retry time budget (seconds) for API calls.
NEPTUNE_RETRIES_TIMEOUT_ENV = "NEPTUNE_RETRIES_TIMEOUT"
PROJECT_ENV_NAME = "NEPTUNE_PROJECT"
NOTEBOOK_ID_ENV_NAME = "NEPTUNE_NOTEBOOK_ID"
NOTEBOOK_PATH_ENV_NAME = "NEPTUNE_NOTEBOOK_PATH"
BACKEND = "NEPTUNE_BACKEND"
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import platform
from neptune.common.envs import (
API_TOKEN_ENV_NAME,
PROJECT_ENV_NAME,
)
# ANSI escape sequences used to colorize the styled error-message templates
# on Unix-like terminals.
UNIX_STYLES = {
    "h1": "\033[95m",
    "h2": "\033[94m",
    "blue": "\033[94m",
    "python": "\033[96m",
    "bash": "\033[95m",
    "warning": "\033[93m",
    "correct": "\033[92m",
    "fail": "\033[91m",
    "bold": "\033[1m",
    "underline": "\033[4m",
    "end": "\033[0m",
}
# Windows terminals get no ANSI styling: every placeholder expands to "".
# NOTE(review): the "blue" key exists only in UNIX_STYLES — formatting a
# template that uses {blue} would raise KeyError on Windows; confirm no
# template uses it.
WINDOWS_STYLES = {
    "h1": "",
    "h2": "",
    "python": "",
    "bash": "",
    "warning": "",
    "correct": "",
    "fail": "",
    "bold": "",
    "underline": "",
    "end": "",
}
# Fallback for unrecognized platforms: plain text.
EMPTY_STYLES = {
    "h1": "",
    "h2": "",
    "python": "",
    "bash": "",
    "warning": "",
    "correct": "",
    "fail": "",
    "bold": "",
    "underline": "",
    "end": "",
}
# Select the style table once, at import time, based on the current platform.
if platform.system() in ["Linux", "Darwin"]:
    STYLES = UNIX_STYLES
elif platform.system() == "Windows":
    STYLES = WINDOWS_STYLES
else:
    STYLES = EMPTY_STYLES
class NeptuneException(Exception):
    """Base class for all exceptions raised by the Neptune client."""

    def __eq__(self, other):
        # Equal only to an exception of exactly the same type with the same
        # string representation.
        if type(other) is type(self):
            # NOTE(review): Exception inherits object.__eq__, which returns
            # NotImplemented for distinct instances; truth-testing that value
            # is deprecated/raising on newer Pythons — confirm the intended
            # semantics before relying on equality of distinct instances.
            return super().__eq__(other) and str(self).__eq__(str(other))
        else:
            return False

    def __hash__(self):
        # Mix the message text into the hash to stay consistent with __eq__.
        return hash((super().__hash__(), str(self)))
class NeptuneInvalidApiTokenException(NeptuneException):
    """Raised when the provided Neptune API token cannot be accepted."""

    def __init__(self):
        # Styled, user-facing help text; placeholders are filled from STYLES
        # and the API-token environment-variable name.
        message = """
{h1}
----NeptuneInvalidApiTokenException------------------------------------------------
{end}
The provided API token is invalid.
Make sure you copied and provided your API token correctly.
You can get it or check if it is correct here:
- https://app.neptune.ai/get_my_api_token
There are two options to add it:
- specify it in your code
- set it as an environment variable in your operating system.
{h2}CODE{end}
Pass the token to the {bold}init_run(){end} function via the {bold}api_token{end} argument:
{python}neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
{h2}ENVIRONMENT VARIABLE{end} {correct}(Recommended option){end}
or export or set an environment variable depending on your operating system:
{correct}Linux/Unix{end}
In your terminal run:
{bash}export {env_api_token}="YOUR_API_TOKEN"{end}
{correct}Windows{end}
In your CMD run:
{bash}set {env_api_token}="YOUR_API_TOKEN"{end}
and skip the {bold}api_token{end} argument of the {bold}init_run(){end} function:
{python}neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME'){end}
You may also want to check the following docs page:
- https://docs.neptune.ai/setup/setting_api_token/
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(env_api_token=API_TOKEN_ENV_NAME, **STYLES))
class UploadedFileChanged(NeptuneException):
    """Raised when a file's content changes mid-upload, forcing a restart."""

    def __init__(self, filename: str):
        details = "File {} changed during upload, restarting upload.".format(filename)
        super().__init__(details)
class InternalClientError(NeptuneException):
    """Raised for unexpected internal errors inside the client library itself."""

    def __init__(self, msg: str):
        # Styled, user-facing template; ``msg`` carries the specific details.
        message = """
{h1}
----InternalClientError-----------------------------------------------------------------------
{end}
The Neptune client library encountered an unexpected internal error:
{msg}
Please contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(msg=msg, **STYLES))
class ClientHttpError(NeptuneException):
    """Raised when the server answers with a client-error HTTP status."""

    def __init__(self, status, response):
        # Keep the raw status code and response body available to callers.
        self.status = status
        self.response = response
        message = """
{h1}
----ClientHttpError-----------------------------------------------------------------------
{end}
The Neptune server returned the status {fail}{status}{end}.
The server response was:
{fail}{response}{end}
Verify the correctness of your call or contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(status=status, response=response, **STYLES))
class NeptuneApiException(NeptuneException):
    """Base class for exceptions tied to Neptune API responses."""

    pass
class Forbidden(NeptuneApiException):
    """Raised when the server refuses access to a resource (HTTP 403)."""

    def __init__(self):
        message = """
{h1}
----Forbidden-----------------------------------------------------------------------
{end}
You don't have permission to access the given resource.
- Verify that your API token is correct.
See: https://app.neptune.ai/get_my_api_token
- Verify that the provided project name is correct.
The correct project name should look like this: {correct}WORKSPACE_NAME/PROJECT_NAME{end}
It has two parts:
- {correct}WORKSPACE_NAME{end}: can be your username or your organization name
- {correct}PROJECT_NAME{end}: the name specified for the project
- Ask your organization administrator to grant you the necessary privileges to the project.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(**STYLES))
class Unauthorized(NeptuneApiException):
    """Raised on authorization failure (HTTP 401); ``msg`` can replace the default text."""

    def __init__(self, msg=None):
        default_message = """
{h1}
----Unauthorized-----------------------------------------------------------------------
{end}
You don't have permission to access the given resource.
- Verify that your API token is correct.
See: https://app.neptune.ai/get_my_api_token
- Verify that the provided project name is correct.
The correct project name should look like this: {correct}WORKSPACE_NAME/PROJECT_NAME{end}
It has two parts:
- {correct}WORKSPACE_NAME{end}: can be your username or your organization name
- {correct}PROJECT_NAME{end}: the name specified for the project
- Ask your organization administrator to grant you the necessary privileges to the project.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        # Caller-supplied messages are also run through str.format with the
        # style table, so they must not contain stray '{' or '}' characters.
        message = msg if msg is not None else default_message
        super().__init__(message.format(**STYLES))
class NeptuneAuthTokenExpired(Unauthorized):
    """Raised when the authorization token has expired."""

    def __init__(self):
        super().__init__(msg="Authorization token expired")
class InternalServerError(NeptuneApiException):
    """Raised when the server reports an unexpected internal error (HTTP 5xx)."""

    def __init__(self, response):
        message = """
{h1}
----InternalServerError-----------------------------------------------------------------------
{end}
The Neptune client library encountered an unexpected internal server error.
The server response was:
{fail}{response}{end}
Please try again later or contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(response=response, **STYLES))
class NeptuneConnectionLostException(NeptuneException):
    """Raised after the retry budget is exhausted without reaching the server."""

    def __init__(self, cause: Exception):
        # Preserve the final underlying error for callers to inspect.
        self.cause = cause
        message = """
{h1}
----NeptuneConnectionLostException---------------------------------------------------------
{end}
The connection to the Neptune server was lost.
If you are using the asynchronous (default) connection mode, Neptune continues to locally track your metadata and continuously tries to re-establish a connection to the Neptune servers.
If the connection is not re-established, you can upload your data later with the Neptune Command Line Interface tool:
{bash}neptune sync -p workspace_name/project_name{end}
What should I do?
- Check if your computer is connected to the internet.
- If your connection is unstable, consider working in offline mode:
{python}run = neptune.init_run(mode="offline"){end}
You can find detailed instructions on the following doc pages:
- https://docs.neptune.ai/api/connection_modes/#offline-mode
- https://docs.neptune.ai/api/neptune_sync/
You may also want to check the following docs page:
- https://docs.neptune.ai/api/connection_modes/#connectivity-issues
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
""" # noqa: E501
        super().__init__(message.format(**STYLES))
class NeptuneSSLVerificationError(NeptuneException):
    """Raised when the server's SSL certificate cannot be verified."""

    def __init__(self):
        message = """
{h1}
----NeptuneSSLVerificationError-----------------------------------------------------------------------
{end}
The Neptune client was unable to verify your SSL Certificate.
{bold}What could have gone wrong?{end}
- You are behind a proxy that inspects traffic to Neptune servers.
- Contact your network administrator
- The SSL/TLS certificate of your on-premises installation is not recognized due to a custom Certificate Authority (CA).
- To check, run the following command in your terminal:
{bash}curl https://<your_domain>/api/backend/echo {end}
- Where <your_domain> is the address that you use to access Neptune app, such as abc.com
- Contact your network administrator if you get the following output:
{fail}"curl: (60) server certificate verification failed..."{end}
- Your machine software is outdated.
- Minimal OS requirements:
- Windows >= XP SP3
- macOS >= 10.12.1
- Ubuntu >= 12.04
- Debian >= 8
{bold}What can I do?{end}
You can manually configure Neptune to skip all SSL checks. To do that,
set the NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE environment variable to 'TRUE'.
{bold}Note: This might mean that your connection is less secure{end}.
Linux/Unix
In your terminal run:
{bash}export NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
Windows
In your terminal run:
{bash}set NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
Jupyter notebook
In your code cell:
{bash}%env NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
You may also want to check the following docs page:
- https://docs.neptune.ai/api/environment_variables/#neptune_allow_self_signed_certificate
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
""" # noqa: E501
        super().__init__(message.format(**STYLES))
class FileNotFound(NeptuneException):
    """Raised when a path expected to point at an existing file does not exist."""

    def __init__(self, path):
        super().__init__("File {} doesn't exist.".format(path))
class InvalidNotebookPath(NeptuneException):
    """Raised when a path does not look like a Jupyter notebook (no .ipynb suffix)."""

    def __init__(self, path):
        details = "File {} is not a valid notebook. Should end with .ipynb.".format(path)
        super().__init__(details)
class NeptuneIncorrectProjectQualifiedNameException(NeptuneException):
    """Raised when the given project qualified name is not WORKSPACE/PROJECT_NAME shaped."""

    def __init__(self, project_qualified_name):
        message = """
{h1}
----NeptuneIncorrectProjectQualifiedNameException-----------------------------------------------------------------------
{end}
Project qualified name {fail}"{project_qualified_name}"{end} you specified was incorrect.
The correct project qualified name should look like this {correct}WORKSPACE/PROJECT_NAME{end}.
It has two parts:
- {correct}WORKSPACE{end}: which can be your username or your organization name
- {correct}PROJECT_NAME{end}: which is the actual project name you chose
For example, a project {correct}neptune-ai/credit-default-prediction{end} parts are:
- {correct}neptune-ai{end}: {underline}WORKSPACE{end} our company organization name
- {correct}credit-default-prediction{end}: {underline}PROJECT_NAME{end} a project name
The URL to this project looks like this: https://app.neptune.ai/neptune-ai/credit-default-prediction
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/workspace-project-and-user-management/index.html
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(NeptuneIncorrectProjectQualifiedNameException, self).__init__(
            message.format(project_qualified_name=project_qualified_name, **STYLES)
        )
class NeptuneMissingProjectQualifiedNameException(NeptuneException):
    """Raised when no project qualified name was provided via code or environment."""

    def __init__(self):
        # Fixed user-facing typo: "two options two add it" -> "two options to add it".
        message = """
{h1}
----NeptuneMissingProjectQualifiedNameException-------------------------------------------------------------------------
{end}
Neptune client couldn't find your project name.
There are two options to add it:
- specify it in your code
- set an environment variable in your operating system.
{h2}CODE{end}
Pass it to {bold}neptune.init(){end} via {bold}project_qualified_name{end} argument:
{python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
{h2}ENVIRONMENT VARIABLE{end}
or export or set an environment variable depending on your operating system:
{correct}Linux/Unix{end}
In your terminal run:
{bash}export {env_project}=WORKSPACE_NAME/PROJECT_NAME{end}
{correct}Windows{end}
In your CMD run:
{bash}set {env_project}=WORKSPACE_NAME/PROJECT_NAME{end}
and skip the {bold}project_qualified_name{end} argument of {bold}neptune.init(){end}:
{python}neptune.init(api_token='YOUR_API_TOKEN'){end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/workspace-project-and-user-management/index.html
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super().__init__(message.format(env_project=PROJECT_ENV_NAME, **STYLES))
class NotAFile(NeptuneException):
    """Raised when a given path exists but is not a regular file."""

    def __init__(self, path):
        super().__init__("Path {} is not a file.".format(path))
class NotADirectory(NeptuneException):
    """Raised when a given path exists but is not a directory."""

    def __init__(self, path):
        super().__init__("Path {} is not a directory.".format(path))
class WritingToArchivedProjectException(NeptuneException):
    """Raised when attempting to write metadata to an archived project."""

    def __init__(self):
        message = """
{h1}
----WritingToArchivedProjectException-----------------------------------------------------------------------
{end}
You're trying to write to a project that was archived.
Set the project as active again or use mode="read-only" at initialization to fetch metadata from it.
{correct}Need help?{end}-> https://docs.neptune.ai/help/error_writing_to_archived_project/
"""
        super(WritingToArchivedProjectException, self).__init__(message.format(**STYLES))
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class LegacyExperiment:
    """Empty placeholder type.

    NOTE(review): carries no behavior here; presumably referenced by the
    legacy experiment API elsewhere — confirm before removing.
    """

    pass
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class GitInfo(object):
    """Snapshot of a git repository's state recorded with an experiment.

    Shown in the experiment details tab of the Neptune web application.

    Args:
        commit_id (str): commit SHA; must not be ``None``.
        message (str): commit message.
        author_name (str): commit author username.
        author_email (str): commit author email.
        commit_date (datetime.datetime): commit datetime.
        repository_dirty (bool): ``True`` if there are uncommitted changes.
        active_branch (str): name of the checked-out branch.
        remote_urls (list): configured remote URLs; defaults to an empty list.

    Raises:
        TypeError: if ``commit_id`` is ``None``.
    """

    def __init__(
        self,
        commit_id,
        message="",
        author_name="",
        author_email="",
        commit_date="",
        repository_dirty=True,
        active_branch="",
        remote_urls=None,
    ):
        if commit_id is None:
            raise TypeError("commit_id must not be None")
        self.commit_id = commit_id
        self.message = message
        self.author_name = author_name
        self.author_email = author_email
        self.commit_date = commit_date
        self.repository_dirty = repository_dirty
        self.active_branch = active_branch
        # A fresh list per instance avoids the shared-mutable-default pitfall.
        self.remote_urls = [] if remote_urls is None else remote_urls

    def __eq__(self, o):
        if o is None:
            return False
        return self.__dict__ == o.__dict__

    def __ne__(self, o):
        return not self.__eq__(o)

    def __str__(self):
        return "GitInfo({})".format(self.commit_id)

    def __repr__(self):
        return str(self)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
class CGroupFilesystemReader(object):
    """Reads raw cgroup (v1-style) resource counters from the mounted cgroup filesystem."""

    def __init__(self):
        # Resolve the mount points once; file contents are read lazily per call.
        cgroup_memory_dir = self.__cgroup_mount_dir(subsystem="memory")
        self.__memory_usage_file = os.path.join(cgroup_memory_dir, "memory.usage_in_bytes")
        self.__memory_limit_file = os.path.join(cgroup_memory_dir, "memory.limit_in_bytes")
        cgroup_cpu_dir = self.__cgroup_mount_dir(subsystem="cpu")
        self.__cpu_period_file = os.path.join(cgroup_cpu_dir, "cpu.cfs_period_us")
        self.__cpu_quota_file = os.path.join(cgroup_cpu_dir, "cpu.cfs_quota_us")
        cgroup_cpuacct_dir = self.__cgroup_mount_dir(subsystem="cpuacct")
        self.__cpuacct_usage_file = os.path.join(cgroup_cpuacct_dir, "cpuacct.usage")

    def get_memory_usage_in_bytes(self):
        # Current memory consumption of this cgroup.
        return self.__read_int_file(self.__memory_usage_file)

    def get_memory_limit_in_bytes(self):
        # Configured memory cap; CGroupMonitor clamps this against total
        # system memory before reporting it.
        return self.__read_int_file(self.__memory_limit_file)

    def get_cpu_quota_micros(self):
        # CFS quota per period; -1 is treated by callers as "no limit".
        return self.__read_int_file(self.__cpu_quota_file)

    def get_cpu_period_micros(self):
        return self.__read_int_file(self.__cpu_period_file)

    def get_cpuacct_usage_nanos(self):
        # Cumulative CPU time consumed by the cgroup, in nanoseconds.
        return self.__read_int_file(self.__cpuacct_usage_file)

    def __read_int_file(self, filename):
        # Each of the cgroup files above holds a single integer value.
        with open(filename) as f:
            return int(f.read())

    def __cgroup_mount_dir(self, subsystem):
        """
        :param subsystem: cgroup subsystem like memory, cpu
        :return: directory where given subsystem is mounted
        """
        # Scan /proc/mounts for a cgroup mount whose directory name lists the
        # requested subsystem (a mount may combine several, e.g. "cpu,cpuacct").
        with open("/proc/mounts", "r") as f:
            for line in f.readlines():
                split_line = re.split(r"\s+", line)
                mount_dir = split_line[1]
                if "cgroup" in mount_dir:
                    dirname = mount_dir.split("/")[-1]
                    subsystems = dirname.split(",")
                    if subsystem in subsystems:
                        return mount_dir
        # NOTE(review): flow-control via assert — stripped under ``python -O``,
        # in which case this method silently returns None; consider raising an
        # explicit exception instead.
        assert False, 'Mount directory for "{}" subsystem not found'.format(subsystem)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from neptune.common.hardware.cgroup.cgroup_filesystem_reader import CGroupFilesystemReader
from neptune.common.hardware.system.system_monitor import SystemMonitor
class CGroupMonitor(object):
    """Reports container resource usage by combining cgroup counters with
    system-wide measurements supplied by a system monitor."""

    def __init__(self, cgroup_filesystem_reader, system_monitor):
        self.__cgroup_filesystem_reader = cgroup_filesystem_reader
        self.__system_monitor = system_monitor
        # Previous CPU sample; used to turn the cumulative counter into a rate.
        self.__last_cpu_usage_measurement_timestamp_nanos = None
        self.__last_cpu_cumulative_usage_nanos = None

    @staticmethod
    def create():
        """Build a monitor wired to the real cgroup filesystem and system monitor."""
        return CGroupMonitor(CGroupFilesystemReader(), SystemMonitor())

    def get_memory_usage_in_bytes(self):
        return self.__cgroup_filesystem_reader.get_memory_usage_in_bytes()

    def get_memory_limit_in_bytes(self):
        # Never report more than the machine's total virtual memory, since the
        # cgroup limit may be an effectively-unbounded sentinel.
        cgroup_limit = self.__cgroup_filesystem_reader.get_memory_limit_in_bytes()
        machine_total = self.__system_monitor.virtual_memory().total
        return min(cgroup_limit, machine_total)

    def get_cpu_usage_limit_in_cores(self):
        quota_micros = self.__cgroup_filesystem_reader.get_cpu_quota_micros()
        if quota_micros == -1:
            # -1 means no CFS quota: the limit is every available core.
            return float(self.__system_monitor.cpu_count())
        period_micros = self.__cgroup_filesystem_reader.get_cpu_period_micros()
        return float(quota_micros) / float(period_micros)

    def get_cpu_usage_percentage(self):
        """Return CPU usage since the previous call, as a 0-100 percentage."""
        now_nanos = time.time() * 10**9
        cumulative_nanos = self.__cgroup_filesystem_reader.get_cpuacct_usage_nanos()
        if self.__first_measurement():
            # No baseline yet — report zero and remember this sample.
            usage_percent = 0.0
        else:
            usage_delta = cumulative_nanos - self.__last_cpu_cumulative_usage_nanos
            elapsed = now_nanos - self.__last_cpu_usage_measurement_timestamp_nanos
            usage_percent = float(usage_delta) / float(elapsed) / self.get_cpu_usage_limit_in_cores() * 100.0
        self.__last_cpu_usage_measurement_timestamp_nanos = now_nanos
        self.__last_cpu_cumulative_usage_nanos = cumulative_nanos
        # cgroup accounting can momentarily exceed the quota; clamp for display.
        return self.__clamp(usage_percent, lower_limit=0.0, upper_limit=100.0)

    def __first_measurement(self):
        no_timestamp = self.__last_cpu_usage_measurement_timestamp_nanos is None
        no_usage = self.__last_cpu_cumulative_usage_nanos is None
        return no_timestamp or no_usage

    @staticmethod
    def __clamp(value, lower_limit, upper_limit):
        return min(max(value, lower_limit), upper_limit)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["BYTES_IN_ONE_MB", "BYTES_IN_ONE_GB", "MILLIWATTS_IN_ONE_WATT"]

# Unit-conversion constants shared by the hardware metric gauges.
BYTES_IN_ONE_MB = 1 << 20
BYTES_IN_ONE_GB = 1 << 30
MILLIWATTS_IN_ONE_WATT = 1000
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.common.hardware.gauges.gauge import Gauge
from neptune.common.hardware.system.system_monitor import SystemMonitor
class SystemCpuUsageGauge(Gauge):
    """Gauge reporting whole-system CPU utilization, in percent."""

    def __init__(self):
        self.__system_monitor = SystemMonitor()

    def name(self):
        return "cpu"

    def value(self):
        return self.__system_monitor.cpu_percent()

    def __eq__(self, other):
        # Stateless gauge: any two instances of the same class are equal.
        return type(self) == type(other)

    def __repr__(self):
        return "SystemCpuUsageGauge"
class CGroupCpuUsageGauge(Gauge):
    """Gauge reporting CPU utilization relative to the cgroup's CPU quota."""

    def __init__(self):
        self.__cgroup_monitor = CGroupMonitor.create()

    def name(self):
        return "cpu"

    def value(self):
        return self.__cgroup_monitor.get_cpu_usage_percentage()

    def __eq__(self, other):
        # Stateless gauge: any two instances of the same class are equal.
        return type(self) == type(other)

    def __repr__(self):
        return "CGroupCpuUsageGauge"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.gauges.cpu import (
CGroupCpuUsageGauge,
SystemCpuUsageGauge,
)
from neptune.common.hardware.gauges.gauge_mode import GaugeMode
from neptune.common.hardware.gauges.gpu import (
GpuMemoryGauge,
GpuPowerGauge,
GpuUsageGauge,
)
from neptune.common.hardware.gauges.memory import (
CGroupMemoryUsageGauge,
SystemMemoryUsageGauge,
)
class GaugeFactory(object):
    """Creates gauges matching the configured gauge mode (system vs. cgroup)."""

    def __init__(self, gauge_mode):
        self.__gauge_mode = gauge_mode

    def create_cpu_usage_gauge(self):
        """CPU gauge for the current mode; raises ValueError on an unknown mode."""
        if self.__gauge_mode == GaugeMode.SYSTEM:
            return SystemCpuUsageGauge()
        if self.__gauge_mode == GaugeMode.CGROUP:
            return CGroupCpuUsageGauge()
        raise self.__invalid_gauge_mode_exception()

    def create_memory_usage_gauge(self):
        """RAM gauge for the current mode; raises ValueError on an unknown mode."""
        if self.__gauge_mode == GaugeMode.SYSTEM:
            return SystemMemoryUsageGauge()
        if self.__gauge_mode == GaugeMode.CGROUP:
            return CGroupMemoryUsageGauge()
        raise self.__invalid_gauge_mode_exception()

    @staticmethod
    def create_gpu_usage_gauge(card_index):
        """Utilization gauge for one GPU card (mode-independent)."""
        return GpuUsageGauge(card_index=card_index)

    @staticmethod
    def create_gpu_memory_gauge(card_index):
        """Memory gauge for one GPU card (mode-independent)."""
        return GpuMemoryGauge(card_index=card_index)

    @staticmethod
    def create_gpu_power_gauge(card_index):
        """Power gauge for one GPU card (mode-independent)."""
        return GpuPowerGauge(card_index=card_index)

    def __invalid_gauge_mode_exception(self):
        return ValueError(f"Invalid gauge mode: {self.__gauge_mode}")
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class GaugeMode(object):
    """Names the two metric-collection modes supported by the gauge factory."""

    # Measure resources of the whole machine.
    SYSTEM = "system"
    # Measure resources within the current cgroup's limits (e.g. a container).
    CGROUP = "cgroup"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import (
ABCMeta,
abstractmethod,
)
class Gauge(object, metaclass=ABCMeta):
    """Abstract base class for a single hardware metric gauge.

    Fix: the original declared ``__metaclass__ = ABCMeta``, which is Python 2
    syntax and has no effect under Python 3, so ``@abstractmethod`` was never
    enforced. Declaring the metaclass with the Python 3 ``metaclass=`` keyword
    restores enforcement: subclasses must implement ``name`` and ``value``
    before they can be instantiated.
    """

    @abstractmethod
    def name(self):
        """
        :return: Gauge name (str).
        """
        raise NotImplementedError()

    @abstractmethod
    def value(self):
        """
        :return: Current value (float).
        """
        raise NotImplementedError()
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.constants import (
BYTES_IN_ONE_GB,
MILLIWATTS_IN_ONE_WATT,
)
from neptune.common.hardware.gauges.gauge import Gauge
from neptune.common.hardware.gpu.gpu_monitor import GPUMonitor
class GpuUsageGauge(Gauge):
    """Gauge reporting utilization of a single GPU card, in percent."""

    def __init__(self, card_index):
        # Index of the card this gauge reads.
        self.card_index = card_index
        self.__gpu_monitor = GPUMonitor()

    def name(self):
        # GPU gauges are labelled by their card index.
        return str(self.card_index)

    def value(self):
        return self.__gpu_monitor.get_card_usage_percent(self.card_index)

    def __eq__(self, other):
        return type(self) == type(other) and self.card_index == other.card_index

    def __repr__(self):
        return "GpuUsageGauge"
class GpuMemoryGauge(Gauge):
    """Gauge reporting used memory of a single GPU card, in gigabytes.

    Fix: ``GPUMonitor.get_card_used_memory_in_bytes`` returns ``None`` when
    NVML fails (its error default); the original divided that result
    unconditionally, raising TypeError on an NVML failure. Propagate ``None``
    instead, matching ``GpuPowerGauge.value``.
    """

    def __init__(self, card_index):
        # Index of the card this gauge reads.
        self.card_index = card_index
        self.__gpu_monitor = GPUMonitor()

    def name(self):
        # GPU gauges are labelled by their card index.
        return str(self.card_index)

    def value(self):
        used_bytes = self.__gpu_monitor.get_card_used_memory_in_bytes(self.card_index)
        return None if used_bytes is None else used_bytes / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.card_index == other.card_index

    def __repr__(self):
        return "GpuMemoryGauge"
class GpuPowerGauge(Gauge):
    """Gauge reporting power draw of a single GPU card, in watts."""

    def __init__(self, card_index):
        # Index of the card this gauge reads.
        self.card_index = card_index
        self.__gpu_monitor = GPUMonitor()

    def name(self):
        # GPU gauges are labelled by their card index.
        return str(self.card_index)

    def value(self):
        # Monitor reports milliwatts; convert, propagating None on read failure.
        milliwatts = self.__gpu_monitor.get_card_power_usage(self.card_index)
        if milliwatts is None:
            return None
        return milliwatts / MILLIWATTS_IN_ONE_WATT

    def __eq__(self, other):
        return type(self) == type(other) and self.card_index == other.card_index

    def __repr__(self):
        return "GpuPowerGauge"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.common.hardware.constants import BYTES_IN_ONE_GB
from neptune.common.hardware.gauges.gauge import Gauge
from neptune.common.hardware.system.system_monitor import SystemMonitor
class SystemMemoryUsageGauge(Gauge):
    """Gauge reporting whole-system RAM usage, in gigabytes."""

    def __init__(self):
        self.__system_monitor = SystemMonitor()

    def name(self):
        return "ram"

    def value(self):
        # Used memory = total minus still-available, converted to GB.
        memory = self.__system_monitor.virtual_memory()
        used_bytes = memory.total - memory.available
        return used_bytes / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        # Stateless gauge: any two instances of the same class are equal.
        return type(self) == type(other)

    def __repr__(self):
        return "SystemMemoryUsageGauge"
class CGroupMemoryUsageGauge(Gauge):
    """Gauge reporting the current cgroup's RAM usage, in gigabytes."""

    def __init__(self):
        self.__cgroup_monitor = CGroupMonitor.create()

    def name(self):
        return "ram"

    def value(self):
        usage_bytes = self.__cgroup_monitor.get_memory_usage_in_bytes()
        return usage_bytes / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        # Stateless gauge: any two instances of the same class are equal.
        return type(self) == type(other)

    def __repr__(self):
        return "CGroupMemoryUsageGauge"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.constants import MILLIWATTS_IN_ONE_WATT
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.vendor.pynvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetEnforcedPowerLimit,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetPowerUsage,
nvmlDeviceGetUtilizationRates,
nvmlInit,
)
class GPUMonitor(object):
    """Reads GPU statistics via NVML, degrading gracefully when NVML fails."""

    # Set after the first NVML failure so the warning is emitted only once.
    nvml_error_printed = False

    def get_card_count(self):
        """Number of GPU cards visible to NVML (0 when NVML is unavailable)."""
        return self.__nvml_get_or_else(nvmlDeviceGetCount, default=0)

    def get_card_usage_percent(self, card_index):
        """Utilization of one card, in percent (None on NVML failure)."""
        def read():
            handle = nvmlDeviceGetHandleByIndex(card_index)
            return float(nvmlDeviceGetUtilizationRates(handle).gpu)

        return self.__nvml_get_or_else(read)

    def get_card_used_memory_in_bytes(self, card_index):
        """Memory currently used on one card, in bytes (None on NVML failure)."""
        def read():
            return nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(card_index)).used

        return self.__nvml_get_or_else(read)

    def get_card_power_usage(self, card_index):
        """Power draw of one card as reported by NVML (None on NVML failure)."""
        def read():
            return nvmlDeviceGetPowerUsage(nvmlDeviceGetHandleByIndex(card_index))

        return self.__nvml_get_or_else(read)

    def get_card_max_power_rating(self):
        """Highest enforced power limit across all cards, in watts (0 if none)."""
        ratings = self.__nvml_get_or_else(
            lambda: [
                nvmlDeviceGetEnforcedPowerLimit(nvmlDeviceGetHandleByIndex(index)) // MILLIWATTS_IN_ONE_WATT
                for index in range(nvmlDeviceGetCount())
            ],
            default=0,
        )
        return max(ratings) if ratings else 0

    def get_top_card_memory_in_bytes(self):
        """Largest total memory across all cards, in bytes (0 if none)."""
        totals = self.__nvml_get_or_else(
            lambda: [
                nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(index)).total
                for index in range(nvmlDeviceGetCount())
            ],
            default=0,
        )
        return max(totals) if totals else 0

    def __nvml_get_or_else(self, getter, default=None):
        # Initialize NVML and run the getter; on NVMLError warn once (per
        # process, via the class attribute) and fall back to `default`.
        try:
            nvmlInit()
            return getter()
        except NVMLError as e:
            if not GPUMonitor.nvml_error_printed:
                warning = (
                    f"Info (NVML): {e}. GPU usage metrics may not be reported. For more information, "
                    "see https://docs.neptune.ai/help/nvml_error/"
                )
                warn_once(message=warning, exception=NeptuneWarning)
                GPUMonitor.nvml_error_printed = True
            return default
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class Metric(object):
    """Describes one hardware metric series (e.g. CPU usage) and its gauges."""

    def __init__(
        self,
        name,
        description,
        resource_type,
        unit,
        min_value,
        max_value,
        gauges,
        internal_id=None,
    ):
        # Backend-assigned identifier; filled in after the metric is registered.
        self.__internal_id = internal_id
        self.__name = name
        self.__description = description
        self.__resource_type = resource_type
        self.__unit = unit
        self.__min_value = min_value
        self.__max_value = max_value
        self.__gauges = gauges

    @property
    def internal_id(self):
        return self.__internal_id

    @internal_id.setter
    def internal_id(self, value):
        self.__internal_id = value

    @property
    def name(self):
        return self.__name

    @property
    def description(self):
        return self.__description

    @property
    def resource_type(self):
        return self.__resource_type

    @property
    def unit(self):
        return self.__unit

    @property
    def min_value(self):
        return self.__min_value

    @property
    def max_value(self):
        return self.__max_value

    @property
    def gauges(self):
        return self.__gauges

    def __repr__(self):
        # NOTE: __eq__ compares reprs, so this format is part of the contract.
        return (
            f"Metric(internal_id={self.internal_id}, name={self.name}, "
            f"description={self.description}, resource_type={self.resource_type}, "
            f"unit={self.unit}, min_value={self.min_value}, "
            f"max_value={self.max_value}, gauges={self.gauges})"
        )

    def __eq__(self, other):
        # Metrics are value objects: equal when all repr'd fields match.
        return type(self) == type(other) and repr(self) == repr(other)
class MetricResourceType(object):
    """Backend identifiers for the kind of resource a metric describes."""

    CPU = "CPU"
    RAM = "MEMORY"
    GPU = "GPU"
    GPU_RAM = "GPU_MEMORY"
    GPU_POWER = "GPU_POWER"
    OTHER = "OTHER"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class MetricsContainer(object):
    """Holds the (possibly absent) metric for each monitored resource kind."""

    def __init__(self, cpu_usage_metric, memory_metric, gpu_usage_metric, gpu_memory_metric, gpu_power_usage_metric):
        self.cpu_usage_metric = cpu_usage_metric
        self.memory_metric = memory_metric
        self.gpu_usage_metric = gpu_usage_metric
        self.gpu_memory_metric = gpu_memory_metric
        self.gpu_power_usage_metric = gpu_power_usage_metric

    def metrics(self):
        """Return all configured metrics, skipping resources without one (None)."""
        candidates = (
            self.cpu_usage_metric,
            self.memory_metric,
            self.gpu_usage_metric,
            self.gpu_memory_metric,
            self.gpu_power_usage_metric,
        )
        return [metric for metric in candidates if metric is not None]
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.constants import BYTES_IN_ONE_GB
from neptune.common.hardware.metrics.metric import (
Metric,
MetricResourceType,
)
from neptune.common.hardware.metrics.metrics_container import MetricsContainer
class MetricsFactory(object):
    """Assembles the Metric definitions for CPU, RAM and (when present) GPU."""

    def __init__(self, gauge_factory, system_resource_info):
        self.__gauge_factory = gauge_factory
        self.__system_resource_info = system_resource_info

    def create_metrics_container(self):
        """Build a MetricsContainer; GPU metrics are created only when a GPU exists."""
        cpu_metric = self.__create_cpu_usage_metric()
        ram_metric = self.__create_memory_metric()
        gpu_present = self.__system_resource_info.has_gpu()
        return MetricsContainer(
            cpu_usage_metric=cpu_metric,
            memory_metric=ram_metric,
            gpu_usage_metric=self.__create_gpu_usage_metric() if gpu_present else None,
            gpu_memory_metric=self.__create_gpu_memory_metric() if gpu_present else None,
            gpu_power_usage_metric=self.__create_gpu_power_usage_metric() if gpu_present else None,
        )

    def __create_cpu_usage_metric(self):
        # CPU utilization averaged over all cores, bounded to 0-100%.
        return Metric(
            name="CPU - usage",
            description="average of all cores",
            resource_type=MetricResourceType.CPU,
            unit="%",
            min_value=0.0,
            max_value=100.0,
            gauges=[self.__gauge_factory.create_cpu_usage_gauge()],
        )

    def __create_memory_metric(self):
        # RAM usage in GB, bounded by the machine's/cgroup's memory amount.
        return Metric(
            name="RAM",
            description="",
            resource_type=MetricResourceType.RAM,
            unit="GB",
            min_value=0.0,
            max_value=self.__system_resource_info.memory_amount_bytes / float(BYTES_IN_ONE_GB),
            gauges=[self.__gauge_factory.create_memory_usage_gauge()],
        )

    def __create_gpu_usage_metric(self):
        return Metric(
            name="GPU - usage",
            description=self.__gpu_card_description(),
            resource_type=MetricResourceType.GPU,
            unit="%",
            min_value=0.0,
            max_value=100.0,
            gauges=self.__gpu_gauges(self.__gauge_factory.create_gpu_usage_gauge),
        )

    def __create_gpu_memory_metric(self):
        return Metric(
            name="GPU - memory",
            description=self.__gpu_card_description(),
            resource_type=MetricResourceType.GPU_RAM,
            unit="GB",
            min_value=0.0,
            max_value=self.__system_resource_info.gpu_memory_amount_bytes / float(BYTES_IN_ONE_GB),
            gauges=self.__gpu_gauges(self.__gauge_factory.create_gpu_memory_gauge),
        )

    def __create_gpu_power_usage_metric(self):
        return Metric(
            name="GPU - power usage",
            description=self.__gpu_card_description(),
            resource_type=MetricResourceType.GPU_POWER,
            unit="W",
            min_value=0.0,
            max_value=self.__system_resource_info.gpu_max_power_watts,
            gauges=self.__gpu_gauges(self.__gauge_factory.create_gpu_power_gauge),
        )

    def __gpu_card_description(self):
        # All GPU metrics share the same "<N> cards" description.
        return f"{self.__system_resource_info.gpu_card_count} cards"

    def __gpu_gauges(self, make_gauge):
        # One gauge per monitored GPU card.
        return [make_gauge(card_index=card_index) for card_index in self.__system_resource_info.gpu_card_indices]
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple

# One metric together with every value collected for it in a report cycle.
MetricReport = namedtuple("MetricReport", "metric values")

# A single gauge reading: absolute timestamp, time elapsed since the run's
# reference timestamp, the reporting gauge's name, and the measured value.
MetricValue = namedtuple("MetricValue", "timestamp running_time gauge_name value")
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.metrics.reports.metric_reporter import MetricReporter
class MetricReporterFactory(object):
    """Builds MetricReporter instances that share one reference timestamp."""

    def __init__(self, reference_timestamp):
        # Timestamp every produced reporter measures running time against.
        self.__reference_timestamp = reference_timestamp

    def create(self, metrics):
        """Return a MetricReporter tracking the given metrics."""
        return MetricReporter(metrics=metrics, reference_timestamp=self.__reference_timestamp)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.metrics.reports.metric_report import (
MetricReport,
MetricValue,
)
class MetricReporter(object):
    """Collects the current value of every gauge of the tracked metrics.

    Fix: the original filtered gauge readings with bare truthiness
    (``if value`` / ``if x``), which silently discarded legitimate readings
    equal to 0 (e.g. an idle CPU at 0.0%), not just the ``None`` a gauge
    returns on failure. Only ``None`` readings are skipped now.
    """

    def __init__(self, metrics, reference_timestamp):
        self.__metrics = metrics
        # Start of the run; used to compute each value's relative running time.
        self.__reference_timestamp = reference_timestamp

    def report(self, timestamp):
        """
        :param timestamp: Time of measurement (float, seconds since Epoch).
        :return: list[MetricReport]
        """
        return [
            MetricReport(
                metric=metric,
                values=[
                    value
                    for value in (self.__metric_value_for_gauge(gauge, timestamp) for gauge in metric.gauges)
                    if value is not None
                ],
            )
            for metric in self.__metrics
        ]

    def __metric_value_for_gauge(self, gauge, timestamp):
        # A None reading means the gauge failed (e.g. NVML error): skip it.
        value = gauge.value()
        if value is None:
            return None
        return MetricValue(
            timestamp=timestamp,
            running_time=timestamp - self.__reference_timestamp,
            gauge_name=gauge.name(),
            value=value,
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.gauges.gauge_factory import GaugeFactory
from neptune.common.hardware.gpu.gpu_monitor import GPUMonitor
from neptune.common.hardware.metrics.metrics_factory import MetricsFactory
from neptune.common.hardware.metrics.reports.metric_reporter_factory import MetricReporterFactory
from neptune.common.hardware.metrics.service.metric_service import MetricService
from neptune.common.hardware.resources.system_resource_info_factory import SystemResourceInfoFactory
from neptune.common.hardware.system.system_monitor import SystemMonitor
class MetricServiceFactory(object):
    """Wires together everything needed to collect and send hardware metrics."""

    def __init__(self, backend, os_environ):
        self.__backend = backend
        self.__os_environ = os_environ

    def create(self, gauge_mode, experiment, reference_timestamp):
        """Register metrics for `experiment` and return a ready MetricService."""
        resource_info = SystemResourceInfoFactory(
            system_monitor=SystemMonitor(),
            gpu_monitor=GPUMonitor(),
            os_environ=self.__os_environ,
        ).create(gauge_mode=gauge_mode)
        container = MetricsFactory(
            gauge_factory=GaugeFactory(gauge_mode=gauge_mode),
            system_resource_info=resource_info,
        ).create_metrics_container()
        # Registering each metric with the backend yields its server-side id.
        for metric in container.metrics():
            metric.internal_id = self.__backend.create_hardware_metric(experiment, metric)
        reporter = MetricReporterFactory(reference_timestamp).create(metrics=container.metrics())
        return MetricService(
            backend=self.__backend,
            metric_reporter=reporter,
            experiment=experiment,
            metrics_container=container,
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class MetricService(object):
    """Gathers hardware metric reports and ships them to the backend."""

    def __init__(self, backend, metric_reporter, experiment, metrics_container):
        self.__backend = backend
        self.__metric_reporter = metric_reporter
        self.experiment = experiment
        self.metrics_container = metrics_container

    def report_and_send(self, timestamp):
        """Measure all gauges at `timestamp` and push the reports to the backend."""
        reports = self.__metric_reporter.report(timestamp)
        metrics = self.metrics_container.metrics()
        self.__backend.send_hardware_metric_reports(self.experiment, metrics, reports)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
class GPUCardIndicesProvider(object):
    """Resolves which GPU card indices to monitor, honoring CUDA_VISIBLE_DEVICES."""

    def __init__(self, cuda_visible_devices, gpu_card_count):
        self.__cuda_visible_devices = cuda_visible_devices
        self.__gpu_card_count = gpu_card_count
        # Comma-separated (possibly negative) integers, e.g. "0,2" or "-1".
        self.__cuda_visible_devices_regex = r"^-?\d+(,-?\d+)*$"

    def get(self):
        """Return the list of card indices that should be monitored."""
        if not self.__is_cuda_visible_devices_correct():
            # No (syntactically valid) restriction set: monitor every card.
            return list(range(self.__gpu_card_count))
        return self.__gpu_card_indices_from_cuda_visible_devices()

    def __is_cuda_visible_devices_correct(self):
        if self.__cuda_visible_devices is None:
            return False
        return re.match(self.__cuda_visible_devices_regex, self.__cuda_visible_devices)

    def __gpu_card_indices_from_cuda_visible_devices(self):
        # According to CUDA Toolkit specification.
        # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
        # An out-of-range entry makes it and every later entry invisible.
        visible = []
        for token in self.__cuda_visible_devices.split(","):
            index = int(token)
            if not (0 <= index < self.__gpu_card_count):
                break
            visible.append(index)
        return list(set(visible))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.common.hardware.gauges.gauge_mode import GaugeMode
from neptune.common.hardware.resources.gpu_card_indices_provider import GPUCardIndicesProvider
from neptune.common.hardware.resources.system_resource_info import SystemResourceInfo
class SystemResourceInfoFactory(object):
    """Builds SystemResourceInfo snapshots for whole-system or cgroup mode."""

    def __init__(self, system_monitor, gpu_monitor, os_environ):
        self.__system_monitor = system_monitor
        self.__gpu_monitor = gpu_monitor
        # Respect CUDA_VISIBLE_DEVICES so only usable cards are reported.
        self.__gpu_card_indices_provider = GPUCardIndicesProvider(
            cuda_visible_devices=os_environ.get("CUDA_VISIBLE_DEVICES"),
            gpu_card_count=self.__gpu_monitor.get_card_count(),
        )

    def create(self, gauge_mode):
        """Return resource info for `gauge_mode`; raise ValueError for others."""
        if gauge_mode == GaugeMode.SYSTEM:
            return self.__create_whole_system_resource_info()
        if gauge_mode == GaugeMode.CGROUP:
            return self.__create_cgroup_resource_info()
        raise ValueError(f"Unknown gauge mode: {gauge_mode}")

    def __create_whole_system_resource_info(self):
        # Whole-machine limits as reported by the OS.
        return SystemResourceInfo(
            cpu_core_count=float(self.__system_monitor.cpu_count()),
            memory_amount_bytes=self.__system_monitor.virtual_memory().total,
            **self.__gpu_info(),
        )

    def __create_cgroup_resource_info(self):
        # Limits imposed on the current cgroup (e.g. inside a container).
        cgroup_monitor = CGroupMonitor.create()
        return SystemResourceInfo(
            cpu_core_count=cgroup_monitor.get_cpu_usage_limit_in_cores(),
            memory_amount_bytes=cgroup_monitor.get_memory_limit_in_bytes(),
            **self.__gpu_info(),
        )

    def __gpu_info(self):
        # The GPU figures are identical in both modes.
        return {
            "gpu_card_indices": self.__gpu_card_indices_provider.get(),
            "gpu_memory_amount_bytes": self.__gpu_monitor.get_top_card_memory_in_bytes(),
            "gpu_max_power_watts": self.__gpu_monitor.get_card_max_power_rating(),
        }
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class SystemResourceInfo(object):
    """Immutable description of the CPU/memory/GPU resources available to a run."""

    def __init__(
        self,
        cpu_core_count,
        memory_amount_bytes,
        gpu_card_indices,
        gpu_memory_amount_bytes,
        gpu_max_power_watts,
    ):
        # All values are exposed read-only through the properties below.
        self.__cpu_core_count = cpu_core_count
        self.__memory_amount_bytes = memory_amount_bytes
        self.__gpu_card_indices = gpu_card_indices
        self.__gpu_memory_amount_bytes = gpu_memory_amount_bytes
        self.__gpu_max_power_watts = gpu_max_power_watts

    @property
    def cpu_core_count(self):
        return self.__cpu_core_count

    @property
    def memory_amount_bytes(self):
        return self.__memory_amount_bytes

    @property
    def gpu_card_count(self):
        # Derived value: number of visible GPU cards.
        return len(self.__gpu_card_indices)

    @property
    def gpu_card_indices(self):
        return self.__gpu_card_indices

    @property
    def gpu_memory_amount_bytes(self):
        return self.__gpu_memory_amount_bytes

    @property
    def gpu_max_power_watts(self):
        return self.__gpu_max_power_watts

    def has_gpu(self):
        """Return True when at least one GPU card index is visible."""
        return self.gpu_card_count > 0

    def __repr__(self):
        # NOTE: keys appear name-mangled (e.g. _SystemResourceInfo__cpu_core_count).
        return str(self.__dict__)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# psutil is an optional dependency: record whether it is importable instead of
# failing at module import time.
try:
    import psutil

    PSUTIL_INSTALLED = True
except ImportError:
    PSUTIL_INSTALLED = False
class SystemMonitor(object):
    """Thin psutil facade for host-wide CPU and memory readings.

    NOTE(review): if psutil failed to import (see PSUTIL_INSTALLED above) these
    methods raise NameError -- callers appear expected to check PSUTIL_INSTALLED
    first; confirm at call sites.
    """

    @staticmethod
    def cpu_count():
        return psutil.cpu_count()

    @staticmethod
    def cpu_percent():
        return psutil.cpu_percent()

    @staticmethod
    def virtual_memory():
        return psutil.virtual_memory()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import time
import jwt
from bravado.exception import HTTPUnauthorized
from bravado.requests_client import Authenticator
from oauthlib.oauth2 import (
OAuth2Error,
TokenExpiredError,
)
from requests.auth import AuthBase
from requests_oauthlib import OAuth2Session
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.exceptions import NeptuneInvalidApiTokenException
from neptune.common.utils import update_session_proxies
# PyJWT decode options that switch off every signature and claim verification:
# tokens are decoded locally only to read claims ("exp", "azp", "iss"), not to
# validate them.
_decoding_options = {
    "verify_signature": False,
    "verify_exp": False,
    "verify_nbf": False,
    "verify_iat": False,
    "verify_aud": False,
    "verify_iss": False,
}
class NeptuneAuth(AuthBase):
    """requests auth hook that attaches the current OAuth2 access token to each request."""

    # Class-level lock: token refresh must not run concurrently across threads.
    __LOCK = threading.RLock()

    def __init__(self, session_factory):
        # The factory is kept so a brand-new session can be created when refresh fails.
        self.session_factory = session_factory
        self.session = session_factory()
        # Epoch seconds; 0 forces a refresh before the first authenticated call.
        self.token_expires_at = 0

    def __call__(self, r):
        # Invoked by requests for every prepared request.
        try:
            return self._add_token(r)
        except TokenExpiredError:
            # Token went stale between proactive checks -- refresh and retry once.
            self._refresh_token()
            return self._add_token(r)

    def _add_token(self, r):
        # Delegates to the oauthlib client, which injects the Authorization header.
        r.url, r.headers, r.body = self.session._client.add_token(
            r.url, http_method=r.method, body=r.body, headers=r.headers
        )
        return r

    @with_api_exceptions_handler
    def refresh_token_if_needed(self):
        # Refresh proactively when fewer than 30 seconds of validity remain.
        if self.token_expires_at - time.time() < 30:
            self._refresh_token()

    def _refresh_token(self):
        with self.__LOCK:
            try:
                self._refresh_session_token()
            except OAuth2Error:
                # for some reason oauth session is no longer valid. Retry by creating new fresh session
                # we can safely ignore this error, as it will be thrown again if it's persistent
                try:
                    self.session.close()
                except Exception:
                    pass
                self.session = self.session_factory()
                self._refresh_session_token()

    def _refresh_session_token(self):
        self.session.refresh_token(self.session.auto_refresh_url, verify=self.session.verify)
        if self.session.token is not None and self.session.token.get("access_token") is not None:
            # Decoded without verification (_decoding_options) solely to read "exp"
            # for proactive refresh scheduling.
            decoded_json_token = jwt.decode(self.session.token.get("access_token"), options=_decoding_options)
            self.token_expires_at = decoded_json_token.get("exp")
class NeptuneAuthenticator(Authenticator):
    """bravado Authenticator that exchanges a Neptune API token for OAuth2 credentials."""

    def __init__(self, api_token, backend_client, ssl_verify, proxies):
        super(NeptuneAuthenticator, self).__init__(host="")

        # We need to pass a lambda to be able to re-create fresh session at any time when needed
        def session_factory():
            try:
                auth_tokens = backend_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result
            except HTTPUnauthorized:
                raise NeptuneInvalidApiTokenException()

            # Claims are read without verification (_decoding_options): "exp" for
            # expiry, "azp" for the OAuth client id, "iss" for the token endpoint base.
            decoded_json_token = jwt.decode(auth_tokens.accessToken, options=_decoding_options)
            expires_at = decoded_json_token.get("exp")
            client_name = decoded_json_token.get("azp")
            refresh_url = "{realm_url}/protocol/openid-connect/token".format(realm_url=decoded_json_token.get("iss"))
            token = {
                "access_token": auth_tokens.accessToken,
                "refresh_token": auth_tokens.refreshToken,
                "expires_in": expires_at - time.time(),
            }
            session = OAuth2Session(
                client_id=client_name,
                token=token,
                auto_refresh_url=refresh_url,
                auto_refresh_kwargs={"client_id": client_name},
                token_updater=_no_token_updater,
            )
            session.verify = ssl_verify
            update_session_proxies(session, proxies)
            return session

        self.auth = NeptuneAuth(session_factory)

    def matches(self, url):
        # Sign every outgoing request regardless of host.
        return True

    def apply(self, request):
        # Refresh proactively, then let NeptuneAuth sign the request.
        self.auth.refresh_token_if_needed()
        request.auth = self.auth
        return request
def _no_token_updater():
# For unit tests.
return None
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["apply_patches"]
from neptune.common.patches.bravado import patch as bravado_patch
# Registry of monkey-patches; each entry is a zero-argument callable that patches
# a third-party library in place.
patches = [bravado_patch]


# Apply patches when importing a patching module
# Should be called before usages of patched objects
def apply_patches():
    """Invoke every registered patch callable."""
    for patch in patches:
        patch()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import re
import bravado_core.model
from bravado_core.model import (
_bless_models,
_collect_models,
_get_unprocessed_uri,
_post_process_spec,
_tag_models,
)
def _run_post_processing(spec):
    """Drop-in replacement for bravado_core.model._run_post_processing.

    Mirrors the upstream implementation, except json-schema.org URIs are marked as
    already processed up front so bravado never tries to fetch them over the
    network (see `patch` below and https://github.com/Yelp/bravado-core/issues/388).
    """
    visited_models = {}

    def _call_post_process_spec(spec_dict):
        # Discover all the models in spec_dict
        _post_process_spec(
            spec_dict=spec_dict,
            spec_resolver=spec.resolver,
            on_container_callbacks=[
                functools.partial(
                    _tag_models,
                    visited_models=visited_models,
                    swagger_spec=spec,
                ),
                functools.partial(
                    _bless_models,
                    visited_models=visited_models,
                    swagger_spec=spec,
                ),
                functools.partial(
                    _collect_models,
                    models=spec.definitions,
                    swagger_spec=spec,
                ),
            ],
        )

    # Post process specs to identify models
    _call_post_process_spec(spec.spec_dict)

    # Pre-mark the origin spec and any json-schema.org metaschema URIs as processed
    # so the loop below never attempts to resolve (download) them.
    processed_uris = {
        uri
        for uri in spec.resolver.store
        if uri == spec.origin_url or re.match(r"http(s)?://json-schema\.org/draft(/\d{4})?-\d+/(schema|meta/.*)", uri)
    }

    additional_uri = _get_unprocessed_uri(spec, processed_uris)
    while additional_uri is not None:
        # Post process each referenced specs to identify models in definitions of linked files
        with spec.resolver.in_scope(additional_uri):
            _call_post_process_spec(
                spec.resolver.store[additional_uri],
            )
        processed_uris.add(additional_uri)
        additional_uri = _get_unprocessed_uri(spec, processed_uris)
# Issue: https://github.com/Yelp/bravado-core/issues/388
# Bravado currently makes additional requests to `json-schema.org` in order to gather missing schemas
# This makes `neptune` unable to run without an internet connection or under many security policies
def patch():
    """Install the offline-safe _run_post_processing into bravado_core.model."""
    setattr(bravado_core.model, "_run_post_processing", _run_post_processing)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Matches "workspace/project" or a bare "project"; named groups expose both parts.
PROJECT_QUALIFIED_NAME_PATTERN = "^((?P<workspace>[^/]+)/){0,1}(?P<project>[^/]+)$"
__all__ = ["PROJECT_QUALIFIED_NAME_PATTERN"]
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import io
import math
import os
import tarfile
from typing import (
Any,
Generator,
Optional,
)
from neptune.common.backends.api_model import MultipartConfig
from neptune.common.exceptions import (
InternalClientError,
UploadedFileChanged,
)
@dataclasses.dataclass
class FileChunk:
    """A contiguous piece of an uploaded file: `data` covers byte offsets [start, end)."""

    # raw chunk payload
    data: bytes
    # inclusive start offset within the file
    start: int
    # exclusive end offset within the file
    end: int
class FileChunker:
    """Splits a file or in-memory stream into FileChunks sized for multipart upload limits."""

    def __init__(self, filename: Optional[str], fobj, total_size, multipart_config: MultipartConfig):
        # `filename` is None when chunking an in-memory stream (no mtime checks then).
        self._filename: Optional[str] = filename
        self._fobj = fobj
        self._total_size = total_size
        self._min_chunk_size = multipart_config.min_chunk_size
        self._max_chunk_size = multipart_config.max_chunk_size
        self._max_chunk_count = multipart_config.max_chunk_count

    def _get_chunk_size(self) -> int:
        """Pick a chunk size that fits the payload into at most `max_chunk_count` chunks.

        Raises:
            InternalClientError: if the payload exceeds max_chunk_count * max_chunk_size.
        """
        if self._total_size > self._max_chunk_count * self._max_chunk_size:
            # can't fit it
            max_size = self._max_chunk_count * self._max_chunk_size
            raise InternalClientError(
                f"File {self._filename or 'stream'} is too big to upload:"
                f" {self._total_size} bytes exceeds max size {max_size}"
            )
        if self._total_size <= self._max_chunk_count * self._min_chunk_size:
            # can be done as minimal size chunks -- go for it!
            return self._min_chunk_size
        else:
            # need larger chunks -- split more or less equally
            return math.ceil(self._total_size / self._max_chunk_count)

    def generate(self) -> Generator[FileChunk, Any, None]:
        """Yield consecutive FileChunks until `total_size` bytes are consumed or EOF.

        Raises:
            UploadedFileChanged: if the source file's mtime changes mid-read.
        """
        chunk_size = self._get_chunk_size()
        last_offset = 0
        # was a bare `Optional` annotation, which is not a valid type
        last_change: Optional[float] = os.stat(self._filename).st_mtime if self._filename else None
        while last_offset < self._total_size:
            chunk = self._fobj.read(chunk_size)
            if not chunk:
                # Source ended earlier than `total_size` promised: stop instead of
                # looping forever on empty reads (the previous version spun here).
                break
            if last_change and last_change < os.stat(self._filename).st_mtime:
                raise UploadedFileChanged(self._filename)
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8")
            new_offset = last_offset + len(chunk)
            yield FileChunk(data=chunk, start=last_offset, end=new_offset)
            last_offset = new_offset
def compress_to_tar_gz_in_memory(upload_entries) -> bytes:
    """Pack the given upload entries into an in-memory gzip-compressed tar archive."""
    buffer = io.BytesIO()
    # dereference=True stores symlink targets instead of the links themselves.
    with tarfile.TarFile.open(fileobj=buffer, mode="w|gz", dereference=True) as archive:
        for upload_entry in upload_entries:
            archive.add(name=upload_entry.source, arcname=upload_entry.target_path, recursive=True)
    return buffer.getvalue()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import os
import stat
import time
from abc import (
ABCMeta,
abstractmethod,
)
from dataclasses import dataclass
from io import BytesIO
from pprint import pformat
from typing import (
BinaryIO,
Generator,
List,
Set,
Union,
)
import six
from neptune.internal.utils.logger import get_logger
_logger = get_logger()
@dataclass
class AttributeUploadConfiguration:
    """Upload tuning for a single attribute."""

    # byte budget used by split_upload_files to bound one upload package
    chunk_size: int
class UploadEntry(object):
    """One item scheduled for upload: a local path or an in-memory stream, plus its target path."""

    # Octal digit -> rwx triplet; hoisted from permissions_to_unix_string so the
    # mapping is built once instead of on every call.
    _PERMISSION_TRIPLETS = {
        "7": "rwx",
        "6": "rw-",
        "5": "r-x",
        "4": "r--",
        "3": "-wx",
        "2": "-w-",
        "1": "--x",
        "0": "---",
    }

    def __init__(self, source: Union[str, BytesIO], target_path: str):
        self.source = source
        self.target_path = target_path

    def length(self) -> int:
        """Return the payload size in bytes (buffer size for streams, file size for paths)."""
        if self.is_stream():
            return self.source.getbuffer().nbytes
        else:
            return os.path.getsize(self.source)

    def get_stream(self) -> Union[BinaryIO, io.BytesIO]:
        """Return a readable binary stream; file handles are the caller's to close."""
        if self.is_stream():
            return self.source
        else:
            return io.open(self.source, "rb")

    def get_permissions(self) -> str:
        """Return an ls-style permission string; streams have no permissions."""
        if self.is_stream():
            return "----------"
        else:
            return self.permissions_to_unix_string(self.source)

    @classmethod
    def permissions_to_unix_string(cls, path):
        """Render path's mode like 'drwxr-xr-x'; a missing path renders as '----------'."""
        st = 0
        if os.path.exists(path):
            st = os.lstat(path).st_mode
        is_dir = "d" if stat.S_ISDIR(st) else "-"
        perm = ("%03o" % st)[-3:]
        return is_dir + "".join(cls._PERMISSION_TRIPLETS.get(x, x) for x in perm)

    def __eq__(self, other):
        """Entries are equal when source and target path match.

        Returns NotImplemented for foreign types instead of raising AttributeError
        (the previous version crashed on objects without a __dict__).
        """
        if not isinstance(other, UploadEntry):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        """Inverse of __eq__."""
        return not self == other

    def __hash__(self):
        """Hash consistent with __eq__: based on (source, target_path)."""
        return hash((self.source, self.target_path))

    def to_str(self):
        """Return the pretty-printed attribute dict."""
        return pformat(self.__dict__)

    def __repr__(self):
        """For `print` and `pprint`."""
        return self.to_str()

    def is_stream(self):
        """True when source is a file-like object rather than a filesystem path."""
        return hasattr(self.source, "read")
class UploadPackage(object):
    """Mutable batch of upload entries accumulated up to a size/count budget."""

    def __init__(self):
        # Quoted forward reference: UploadEntry is defined earlier in this module.
        self.items: List["UploadEntry"] = []
        self.size: int = 0  # total bytes accounted for all items
        self.len: int = 0  # number of items

    def reset(self):
        """Empty the package so it can be reused for the next batch."""
        self.items = []
        self.size = 0
        self.len = 0

    def update(self, entry: "UploadEntry", size: int):
        """Add `entry`, accounting `size` bytes toward the package total."""
        self.items.append(entry)
        self.size += size
        self.len += 1

    def is_empty(self):
        """True when nothing has been added since construction or the last reset."""
        return self.len == 0

    def __eq__(self, other):
        """Packages are equal when items, size, and len all match.

        Returns NotImplemented for foreign types instead of raising AttributeError
        (the previous version crashed on objects without a __dict__).
        """
        if not isinstance(other, UploadPackage):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        """Inverse of __eq__."""
        return not self == other

    def to_str(self):
        """Return the pretty-printed attribute dict."""
        return pformat(self.__dict__)

    def __repr__(self):
        """For `print` and `pprint`."""
        return self.to_str()
@six.add_metaclass(ABCMeta)
class ProgressIndicator(object):
    """Interface for reporting upload progress (six-based ABC)."""

    @abstractmethod
    def progress(self, steps):
        # Report that `steps` more units of work completed.
        pass

    @abstractmethod
    def complete(self):
        # Report that all work finished.
        pass
class LoggingProgressIndicator(ProgressIndicator):
    """Progress reporter that warns via the Neptune logger at most once per `frequency` seconds."""

    def __init__(self, total, frequency=10):
        self.current = 0
        self.total = total
        self.last_warning = time.time()
        self.frequency = frequency
        _logger.warning(
            "You are sending %dMB of source code to Neptune. "
            "It is pretty uncommon - please make sure it's what you wanted.",
            self.total / (1024 * 1024),
        )

    def progress(self, steps):
        self.current += steps
        # Guard clause: stay quiet until the throttle window has elapsed.
        if time.time() - self.last_warning <= self.frequency:
            return
        _logger.warning(
            "%d MB / %d MB (%d%%) of source code was sent to Neptune.",
            self.current / (1024 * 1024),
            self.total / (1024 * 1024),
            100 * self.current / self.total,
        )
        self.last_warning = time.time()

    def complete(self):
        _logger.warning(
            "%d MB (100%%) of source code was sent to Neptune.",
            self.total / (1024 * 1024),
        )
class SilentProgressIndicator(ProgressIndicator):
    """No-op progress reporter used when logging progress is not desired."""

    def __init__(self):
        pass

    def progress(self, steps):
        # Intentionally ignore progress updates.
        pass

    def complete(self):
        # Nothing to finalize.
        pass
def scan_unique_upload_entries(upload_entries):
    """
    Expand each directory entry into one UploadEntry per contained file (recursively)
    and deduplicate everything via a set (UploadEntry hashes on source/target path).
    """
    unique_entries = set()
    for upload_entry in upload_entries:
        if upload_entry.is_stream() or not os.path.isdir(upload_entry.source):
            unique_entries.add(upload_entry)
            continue
        for directory, _, file_names in os.walk(upload_entry.source):
            relative_dir = os.path.relpath(directory, upload_entry.source)
            destination_dir = os.path.normpath(os.path.join(upload_entry.target_path, relative_dir))
            unique_entries.update(
                UploadEntry(
                    os.path.join(directory, file_name),
                    os.path.join(destination_dir, file_name),
                )
                for file_name in file_names
            )
    return unique_entries
def split_upload_files(
    upload_entries: Set[UploadEntry],
    upload_configuration: AttributeUploadConfiguration,
    max_files=500,
) -> Generator[UploadPackage, None, None]:
    """Group entries into UploadPackages bounded by chunk_size bytes and max_files count.

    Stream entries are always emitted alone in their own package; the final package
    (possibly empty) is emitted at the end.
    """
    package = UploadPackage()
    for entry in upload_entries:
        if entry.is_stream():
            # Flush whatever was accumulated, then ship the stream by itself.
            if package.len > 0:
                yield package
                package.reset()
            package.update(entry, 0)
            yield package
            package.reset()
            continue
        entry_size = os.path.getsize(entry.source)
        over_budget = entry_size + package.size > upload_configuration.chunk_size or package.len > max_files
        if over_budget and not package.is_empty():
            yield package
            package.reset()
        package.update(entry, entry_size)
    yield package
def normalize_file_name(name):
    """Convert a platform-specific path to the forward-slash form used server-side."""
    return "/".join(name.split(os.sep))
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import glob as globlib
import math
import os
import re
import ssl
import sys
import numpy as np
import pandas as pd
from neptune.common import envs
from neptune.common.exceptions import (
FileNotFound,
InvalidNotebookPath,
NeptuneIncorrectProjectQualifiedNameException,
NeptuneMissingProjectQualifiedNameException,
NotADirectory,
NotAFile,
)
from neptune.common.git_info import GitInfo
from neptune.common.patterns import PROJECT_QUALIFIED_NAME_PATTERN
from neptune.internal.utils.logger import get_logger
_logger = get_logger()
# Platform flags for OS-specific behavior elsewhere in the package.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
def reset_internal_ssl_state():
    """Re-seed OpenSSL's PRNG after a fork.

    OpenSSL's internal random number generator does not properly handle forked
    processes: the PRNG state inherited from the parent must be changed before any
    SSL feature is used after os.fork(). Any successful RAND_bytes() call suffices.
    https://docs.python.org/3/library/ssl.html#multi-processing
    """
    ssl.RAND_bytes(100)
def map_values(f_value, dictionary):
    """Return a new dict with the same keys and f_value applied to every value."""
    return {key: f_value(value) for key, value in dictionary.items()}
def map_keys(f_key, dictionary):
    """Return a new dict with f_key applied to every key and values unchanged."""
    return {f_key(key): value for key, value in dictionary.items()}
def as_list(value):
    """Wrap a scalar in a single-element list; pass lists and None through unchanged."""
    if isinstance(value, list) or value is None:
        return value
    return [value]
def validate_notebook_path(path):
    """Validate that `path` points to an existing .ipynb notebook file.

    Raises:
        InvalidNotebookPath: when the path does not end with ".ipynb".
        FileNotFound: when nothing exists at the path.
        NotAFile: when the path exists but is not a regular file.
    """
    if not path.endswith(".ipynb"):
        raise InvalidNotebookPath(path)
    if not os.path.exists(path):
        raise FileNotFound(path)
    if not os.path.isfile(path):
        raise NotAFile(path)
def assure_directory_exists(destination_dir):
    """Ensure `destination_dir` exists as a directory and return its path.

    A falsy value falls back to the current working directory.

    Raises:
        NotADirectory: if the path exists but is not a directory.
    """
    if not destination_dir:
        destination_dir = os.getcwd()
    if os.path.exists(destination_dir) and not os.path.isdir(destination_dir):
        raise NotADirectory(destination_dir)
    # exist_ok avoids the check-then-create race the exists()/makedirs() pair had.
    os.makedirs(destination_dir, exist_ok=True)
    return destination_dir
def align_channels_on_x(dataframe):
    """Outer-join all per-channel frames on their shared 'x' column."""
    channel_frames, shared_x = _split_df_by_stems(dataframe)
    return merge_dataframes([shared_x] + channel_frames, on="x", how="outer")
def get_channel_name_stems(columns):
    """Strip the two-character 'x_'/'y_' prefix from each column name and deduplicate."""
    return list({column[2:] for column in columns})
def merge_dataframes(dataframes, on, how="outer"):
    """Sequentially pd.merge all frames on the given key column(s), left to right."""
    merged, *remaining = dataframes
    for frame in remaining:
        merged = pd.merge(merged, frame, on=on, how=how)
    return merged
def is_float(value):
    """True when float() accepts `value`; note TypeError (e.g. for None) still propagates."""
    try:
        float(value)
    except ValueError:
        return False
    return True
def is_nan_or_inf(value):
    """True for NaN and +/-infinity; False for any finite number."""
    return not math.isfinite(value)
def is_notebook():
    """Detect an IPython/Jupyter session by probing for the injected `get_ipython` name."""
    try:
        get_ipython  # noqa: F821
    except Exception:
        return False
    return True
def _split_df_by_stems(df):
    """Split a wide x_*/y_* frame into per-channel frames plus the union of all x values."""
    per_channel, all_x = [], []
    for stem in get_channel_name_stems(df.columns):
        channel = df[["x_{}".format(stem), "y_{}".format(stem)]]
        channel.columns = ["x", stem]
        channel = channel.dropna()
        per_channel.append(channel)
        all_x.extend(channel["x"].tolist())
    shared_x = pd.DataFrame({"x": np.unique(all_x)}, dtype=float)
    return per_channel, shared_x
def discover_git_repo_location():
import __main__
if hasattr(__main__, "__file__"):
return os.path.dirname(os.path.abspath(__main__.__file__))
return None
def update_session_proxies(session, proxies):
    """Merge `proxies` into the session's proxy mapping; None leaves the session untouched.

    Raises:
        ValueError: when `proxies` is not a valid mapping.
    """
    if proxies is None:
        return
    try:
        session.proxies.update(proxies)
    except (TypeError, ValueError):
        raise ValueError("Wrong proxies format: {}".format(proxies))
def get_git_info(repo_path=None):
    """Retrieve information about git repository.

    If the attempt fails, ``None`` will be returned.

    Args:
        repo_path (:obj:`str`, optional, default is ``None``):
            | Path to the repository from which extract information about git.
            | If ``None`` is passed, calling ``get_git_info`` is equivalent to calling
              ``git.Repo(search_parent_directories=True)``.
              Check `GitPython <https://gitpython.readthedocs.io/en/stable/reference.html#git.repo.base.Repo>`_
              docs for more information.

    Returns:
        :class:`~neptune.git_info.GitInfo` - An object representing information about git repository.

    Examples:
        .. code:: python3

            # Get git info from the current directory
            git_info = get_git_info('.')
    """
    try:
        import git

        repo = git.Repo(repo_path, search_parent_directories=True)
        commit = repo.head.commit
        active_branch = ""
        try:
            active_branch = repo.active_branch.name
        except TypeError as e:
            # GitPython signals a detached HEAD checkout with this TypeError message.
            if str(e.args[0]).startswith("HEAD is a detached symbolic reference as it points to"):
                active_branch = "Detached HEAD"
        remote_urls = [remote.url for remote in repo.remotes]
        return GitInfo(
            commit_id=commit.hexsha,
            message=commit.message,
            author_name=commit.author.name,
            author_email=commit.author.email,
            commit_date=commit.committed_datetime,
            repository_dirty=repo.is_dirty(index=False, untracked_files=True),
            active_branch=active_branch,
            remote_urls=remote_urls,
        )
    except Exception:
        # Best-effort: any failure (GitPython missing, not a repo, no git binary)
        # yields None. Narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return None
def file_contains(filename, text):
    """Return True if any line of `filename` contains `text` as a substring."""
    with open(filename) as handle:
        return any(text in line for line in handle)
def in_docker():
    """Best-effort detection of running inside a Docker container.

    Checks for the ``/.dockerenv`` marker Docker creates at the container's
    filesystem root (the previous version tested the relative path
    ``"./dockerenv"``, which depends on the working directory and never matches
    the marker), then falls back to scanning the cgroup file for "docker".
    """
    cgroup_file = "/proc/self/cgroup"
    return os.path.exists("/.dockerenv") or (os.path.exists(cgroup_file) and file_contains(cgroup_file, text="docker"))
def is_ipython():
    """True when IPython is importable and an interactive shell is currently active."""
    try:
        import IPython
    except ImportError:
        return False
    return IPython.core.getipython.get_ipython() is not None
def glob(pathname):
    """Recursive glob: '**' in `pathname` matches directories at any depth.

    The runtime branch for Python < 3.5 was dead code -- this codebase already
    uses f-strings (3.6+) -- so only the recursive form remains.
    """
    return globlib.glob(pathname, recursive=True)
def assure_project_qualified_name(project_qualified_name):
    """Resolve and validate the project qualified name, falling back to the env variable.

    Raises:
        NeptuneMissingProjectQualifiedNameException: when no name is given or configured.
        NeptuneIncorrectProjectQualifiedNameException: when the name fails the pattern check.
    """
    resolved_name = project_qualified_name or os.getenv(envs.PROJECT_ENV_NAME)
    if not resolved_name:
        raise NeptuneMissingProjectQualifiedNameException()
    if re.match(PROJECT_QUALIFIED_NAME_PATTERN, resolved_name) is None:
        raise NeptuneIncorrectProjectQualifiedNameException(resolved_name)
    return resolved_name
class NoopObject(object):
    """Black-hole stand-in: attribute access, indexing, calls, and context-manager
    use all succeed and yield the object itself."""

    def __getattr__(self, name):
        return self

    def __getitem__(self, key):
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"warn_once",
"warn_about_unsupported_type",
"NeptuneDeprecationWarning",
"NeptuneWarning",
"NeptuneUnsupportedType",
"NeptuneUnsupportedValue",
]
import os
import traceback
import warnings
import neptune
from neptune.internal.utils.logger import NEPTUNE_LOGGER_NAME
from neptune.internal.utils.runningmode import in_interactive
# Warning layout for plain scripts (includes file:line) vs. interactive sessions.
DEFAULT_FORMAT = "[%(name)s] [warning] %(filename)s:%(lineno)d: %(category)s: %(message)s\n"
INTERACTIVE_FORMAT = "[%(name)s] [warning] %(category)s: %(message)s\n"
class NeptuneDeprecationWarning(DeprecationWarning):
    """Neptune-specific deprecation warning category (always shown, see simplefilter below)."""

    pass
class NeptuneUnsupportedValue(Warning):
    """Warning category for values Neptune cannot accept."""

    pass
class NeptuneWarning(Warning):
    """Generic Neptune warning category."""

    pass
class NeptuneUnsupportedType(Warning):
    """Warning category emitted by warn_about_unsupported_type below."""

    pass
# Deprecation warnings are always shown, even under user-level warning filters.
warnings.simplefilter("always", category=NeptuneDeprecationWarning)

# warn_once deduplicates by message hash; the cap bounds memory growth.
MAX_WARNED_ONCE_CAPACITY = 1_000
warned_once = set()
# Install prefix of the neptune package, used to locate the first user-code stack frame.
path_to_root_module = os.path.dirname(os.path.realpath(neptune.__file__))
def get_user_code_stack_level():
    """Return the stacklevel of the innermost frame living outside the neptune package."""
    frames = traceback.extract_stack()
    for depth, frame in enumerate(reversed(frames)):
        if path_to_root_module not in frame.filename:
            return depth
    # Entire stack is inside neptune; fallback value kept from the original.
    return 2
def format_message(message, category, filename, lineno, line=None) -> str:
    """warnings.formatwarning replacement that renders Neptune's warning layout."""
    template = INTERACTIVE_FORMAT if in_interactive() else DEFAULT_FORMAT
    return template % {
        "message": message,
        "category": category.__name__,
        "filename": filename,
        "lineno": lineno,
        "name": NEPTUNE_LOGGER_NAME,
    }
def warn_once(message: str, *, exception: type = None):
    """Emit `message` as a warning at most once per process.

    Args:
        message: warning text; deduplicated by its hash.
        exception: warning category; defaults to NeptuneDeprecationWarning.
          (Annotation fixed: the original `type(Exception)` just evaluates to `type`.)
    """
    if len(warned_once) >= MAX_WARNED_ONCE_CAPACITY:
        # Dedup set is full: silently drop further warnings (original behavior).
        return
    if exception is None:
        exception = NeptuneDeprecationWarning
    message_hash = hash(message)
    if message_hash in warned_once:
        return
    old_formatting = warnings.formatwarning
    warnings.formatwarning = format_message
    try:
        warnings.warn(
            message=message,
            category=exception,
            stacklevel=get_user_code_stack_level(),
        )
    finally:
        # Always restore the global formatter, even if a warnings filter raised
        # (the original left format_message installed in that case).
        warnings.formatwarning = old_formatting
    warned_once.add(message_hash)
def warn_about_unsupported_type(type_str: str):
    """Warn (once per process) that a value of type `type_str` cannot be logged directly."""
    warn_once(
        message=f"""You're attempting to log a type that is not directly supported by Neptune ({type_str}).
        Convert the value to a supported type, such as a string or float, or use stringify_unsupported(obj)
        for dictionaries or collections that contain unsupported values.
        For more, see https://docs.neptune.ai/help/value_of_unsupported_type""",
        exception=NeptuneUnsupportedType,
    )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from websocket import (
WebSocketConnectionClosedException,
WebSocketTimeoutException,
)
from neptune.common.websockets.websocket_client_adapter import (
WebsocketClientAdapter,
WebsocketNotConnectedException,
)
class ReconnectingWebsocket(object):
    """Websocket wrapper that transparently reconnects (with jittered backoff) until shut down."""

    def __init__(self, url, oauth2_session, shutdown_event, proxies=None):
        self.url = url
        self.client = WebsocketClientAdapter()
        self._shutdown_event = shutdown_event
        self._oauth2_session = oauth2_session
        self._reconnect_counter = ReconnectCounter()
        # Access token is refreshed via the OAuth2 session before each (re)connect.
        self._token = oauth2_session.token
        self._proxies = proxies

    def shutdown(self):
        # Signal the recv loop first so reconnect attempts stop, then tear down the client.
        self._shutdown_event.set()
        self.client.close()
        self.client.abort()
        self.client.shutdown()

    def recv(self):
        """Receive one message, reconnecting on connection loss.

        Returns None if shutdown is requested before a message arrives. Timeouts
        always propagate; connection-closed errors propagate only after shutdown.
        """
        if not self.client.connected:
            self._try_to_establish_connection()
        while self._is_active():
            try:
                data = self.client.recv()
                self._on_successful_connect()
                return data
            except WebSocketTimeoutException:
                raise
            except WebSocketConnectionClosedException:
                if self._is_active():
                    self._handle_lost_connection()
                else:
                    raise
            except WebsocketNotConnectedException:
                if self._is_active():
                    self._handle_lost_connection()
            except Exception:
                # Any other failure is treated as a lost connection and retried.
                if self._is_active():
                    self._handle_lost_connection()

    def _is_active(self):
        return not self._shutdown_event.is_set()

    def _on_successful_connect(self):
        # A successful read resets the backoff schedule.
        self._reconnect_counter.clear()

    def _try_to_establish_connection(self):
        try:
            self._request_token_refresh()
            if self.client.connected:
                self.client.shutdown()
            self.client.connect(url=self.url, token=self._token, proxies=self._proxies)
        except Exception:
            # Failed attempt: wait (interruptibly via the shutdown event) for the
            # backoff delay before the caller retries.
            self._shutdown_event.wait(self._reconnect_counter.calculate_delay())

    def _handle_lost_connection(self):
        self._reconnect_counter.increment()
        self._try_to_establish_connection()

    def _request_token_refresh(self):
        self._token = self._oauth2_session.refresh_token(token_url=self._oauth2_session.auto_refresh_url)
class ReconnectCounter(object):
    """Tracks consecutive reconnect attempts and derives a backoff delay.

    Uses the "full jitter" strategy: a uniformly random delay between zero
    and an exponentially growing, capped upper bound.
    """

    MAX_RETRY_DELAY = 128  # cap (seconds) for the backoff window

    def __init__(self):
        self.retries = 0

    def clear(self):
        """Reset the attempt count after a successful connection."""
        self.retries = 0

    def increment(self):
        """Record one more failed attempt."""
        self.retries += 1

    def calculate_delay(self):
        """Return the next backoff delay, in seconds."""
        return self._compute_delay(self.retries, self.MAX_RETRY_DELAY)

    @classmethod
    def _compute_delay(cls, attempt, max_delay):
        return cls._full_jitter_delay(attempt, max_delay)

    @classmethod
    def _full_jitter_delay(cls, attempt, cap):
        upper_bound = min(2 ** (attempt - 1), cap)
        return random.uniform(0, upper_bound)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import ssl
from future.utils import PY3
from six.moves import urllib
from websocket import (
ABNF,
create_connection,
)
class WebsocketClientAdapter(object):
    """Thin wrapper over ``websocket.create_connection``.

    Adds bearer-token authentication, proxy resolution (explicit ``proxies``
    mapping first, then ``HTTP_PROXY``/``HTTPS_PROXY`` environment variables)
    and an opt-in escape hatch for self-signed certificates.
    """

    def __init__(self):
        self._ws_client = None  # set by connect()

    def connect(self, url, token, proxies=None):
        """Open a websocket connection to *url*, authorized with *token*.

        Args:
            url: ws:// or wss:// endpoint.
            token: OAuth2 token dict; its ``access_token`` is sent in the
                Authorization header.
            proxies: optional mapping with ``http``/``https`` keys, taking
                precedence over the corresponding environment variables.
        """
        sslopt = None
        if os.getenv("NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE"):
            # Explicit user opt-out from certificate verification.
            sslopt = {"cert_reqs": ssl.CERT_NONE}
        proto = url.split(":")[0].replace("ws", "http")  # ws -> http, wss -> https
        proxy_host, proxy_port = self._resolve_proxy(proto, proxies)
        self._ws_client = create_connection(
            url,
            header=self._auth_header(token),
            sslopt=sslopt,
            http_proxy_host=proxy_host,
            http_proxy_port=proxy_port,
        )

    @staticmethod
    def _resolve_proxy(proto, proxies):
        """Return (host, port) of the proxy to use for *proto*, or (None, None)."""
        if proxies and proto in proxies:
            proxy = proxies[proto]
        else:
            proxy = os.getenv("{}_PROXY".format(proto.upper()))
        if not proxy:
            return None, None
        netloc = urllib.parse.urlparse(proxy).netloc.split(":")
        host = netloc[0]
        # Default port mirrors the scheme when the proxy URL omits it.
        port = netloc[1] if len(netloc) > 1 else ("80" if proto == "http" else "443")
        return host, port

    def recv(self):
        """Return the next text frame as a string, skipping non-text frames.

        Raises:
            WebsocketNotConnectedException: when connect() was never called.
        """
        if self._ws_client is None:
            raise WebsocketNotConnectedException()
        opcode, data = None, None
        while opcode != ABNF.OPCODE_TEXT:
            opcode, data = self._ws_client.recv_data()
        return data.decode("utf-8") if PY3 else data

    @property
    def connected(self):
        # Fix: always return a bool; previously this expression leaked None
        # (no client yet) or the underlying client's attribute value.
        return self._ws_client is not None and bool(self._ws_client.connected)

    def close(self):
        if self._ws_client:
            return self._ws_client.close()

    def abort(self):
        if self._ws_client:
            return self._ws_client.abort()

    def shutdown(self):
        if self._ws_client:
            return self._ws_client.shutdown()

    @classmethod
    def _auth_header(cls, token):
        return ["Authorization: Bearer " + token["access_token"]]
class WebsocketNotConnectedException(Exception):
    """Raised when a receive is attempted before the websocket is connected."""

    def __init__(self):
        super().__init__("Websocket client is not connected!")
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Public surface of the legacy `neptune` module, re-exported for backwards
# compatibility with pre-1.0 client code.
__all__ = [
    "Project",
    "Session",
    "session",
    "project",
    "ANONYMOUS",
    "ANONYMOUS_API_TOKEN",
    "CURRENT_KWARGS",
    "init",
    "set_project",
    "create_experiment",
    "get_experiment",
    "append_tag",
    "append_tags",
    "remove_tag",
    "set_property",
    "remove_property",
    "send_metric",
    "log_metric",
    "send_text",
    "log_text",
    "send_image",
    "log_image",
    "send_artifact",
    "delete_artifacts",
    "log_artifact",
    "stop",
    "InvalidNeptuneBackend",
    "NeptuneIncorrectImportException",
    "NeptuneUninitializedException",
    "envs",
    "constants",
]
import logging
import os
import threading
from neptune.common.utils import assure_project_qualified_name
from neptune.legacy import (
constants,
envs,
)
from neptune.legacy.exceptions import (
InvalidNeptuneBackend,
NeptuneIncorrectImportException,
NeptuneUninitializedException,
)
from neptune.legacy.internal.api_clients.backend_factory import backend_factory
from neptune.legacy.internal.utils.deprecation import legacy_client_deprecation
from neptune.legacy.projects import Project
from neptune.legacy.sessions import Session
# Module-level state of the legacy API: the active Session and Project,
# populated by init() / set_project().
session = None
project = None
# Guards concurrent initialization from multiple threads.
__lock = threading.RLock()
_logger = logging.getLogger(__name__)
# NOTE(review): the bare triple-quoted strings below precede the constants
# they describe; they are plain expression statements with no runtime effect
# and are not picked up as attribute docstrings.
"""Access Neptune as an anonymous user.
You can pass this value as api_token during init() call, either by an environment variable or passing it directly
"""
ANONYMOUS = constants.ANONYMOUS
"""Anonymous user API token.
You can pass this value as api_token during init() call, either by an environment variable or passing it directly
"""
ANONYMOUS_API_TOKEN = constants.ANONYMOUS_API_TOKEN
CURRENT_KWARGS = (
"project",
"run",
"custom_run_id",
"mode",
"name",
"description",
"tags",
"source_files",
"capture_stdout",
"capture_stderr",
"capture_hardware_metrics",
"fail_on_exception",
"monitoring_namespace",
"flush_period",
)
def _check_for_extra_kwargs(caller_name, kwargs: dict):
for name in CURRENT_KWARGS:
if name in kwargs:
raise NeptuneIncorrectImportException()
if kwargs:
first_key = next(iter(kwargs.keys()))
raise TypeError(f"{caller_name}() got an unexpected keyword argument '{first_key}'")
@legacy_client_deprecation
def init(project_qualified_name=None, api_token=None, proxies=None, backend=None, **kwargs):
    """Authorize the legacy Neptune client and select the working project.

    Sets the module-level ``session`` and ``project`` globals so the other
    legacy helpers (``create_experiment`` etc.) can be used afterwards.

    Args:
        project_qualified_name: ``namespace/project_name``; defaults to the
            ``NEPTUNE_PROJECT`` environment variable.
        api_token: user's API token; defaults to ``NEPTUNE_API_TOKEN``. Prefer
            the environment variable over hard-coding the token in source.
        proxies: proxy configuration forwarded to the Requests library
            (deprecated - pass a configured backend instead).
        backend: explicit ``ApiClient`` instance, e.g.
            ``HostedNeptuneBackendApiClient(...)`` or ``OfflineApiClient()``;
            when ``None`` it is chosen from the ``NEPTUNE_BACKEND`` env var.

    Returns:
        The selected :class:`~neptune.projects.Project`.

    Raises:
        NeptuneMissingApiTokenException: no token given nor in the environment.
        NeptuneMissingProjectQualifiedNameException: no project name available.
        InvalidApiKey: the given token is malformed.
        Unauthorized: the given token is invalid.

    Examples:
        .. code:: python3

            neptune.init()                                  # minimal invoke
            neptune.init('jack/sandbox')                    # explicit project
            neptune.init(backend=neptune.OfflineApiClient())  # offline
    """
    _check_for_extra_kwargs(init.__name__, kwargs)
    qualified_name = assure_project_qualified_name(project_qualified_name)
    with __lock:
        global session, project
        if backend is None:
            backend = backend_factory(
                backend_name=os.getenv(envs.BACKEND),
                api_token=api_token,
                proxies=proxies,
            )
        session = Session(backend=backend)
        project = session.get_project(qualified_name)
        return project
@legacy_client_deprecation
def set_project(project_qualified_name):
    """Point the legacy client at *project_qualified_name* and return the Project.

    On first use initializes the library via :meth:`init`, reading the API
    token from the ``NEPTUNE_API_TOKEN`` environment variable.

    Args:
        project_qualified_name (str): project in ``namespace/project_name`` form.

    Returns:
        :class:`~neptune.projects.Project` for the selected project.

    Raises:
        NeptuneMissingApiTokenException: first call and no token in environment.
    """
    global session, project
    with __lock:
        if session is None:
            init(project_qualified_name=project_qualified_name)
        else:
            project = session.get_project(project_qualified_name)
        return project
@legacy_client_deprecation
def create_experiment(
    name=None,
    description=None,
    params=None,
    properties=None,
    tags=None,
    upload_source_files=None,
    abort_callback=None,
    logger=None,
    upload_stdout=True,
    upload_stderr=True,
    send_hardware_metrics=True,
    run_monitoring_thread=True,
    handle_uncaught_exceptions=True,
    git_info=None,
    hostname=None,
    notebook_id=None,
):
    """Create and start a Neptune experiment in the current project.

    Alias for: :meth:`~neptune.projects.Project.create_experiment`

    Raises:
        NeptuneUninitializedException: when :meth:`init` was not called first.
    """
    global project
    if project is None:
        raise NeptuneUninitializedException()
    # Forward every argument unchanged to the project-level implementation.
    experiment_kwargs = {
        "name": name,
        "description": description,
        "params": params,
        "properties": properties,
        "tags": tags,
        "upload_source_files": upload_source_files,
        "abort_callback": abort_callback,
        "logger": logger,
        "upload_stdout": upload_stdout,
        "upload_stderr": upload_stderr,
        "send_hardware_metrics": send_hardware_metrics,
        "run_monitoring_thread": run_monitoring_thread,
        "handle_uncaught_exceptions": handle_uncaught_exceptions,
        "git_info": git_info,
        "hostname": hostname,
        "notebook_id": notebook_id,
    }
    return project.create_experiment(**experiment_kwargs)
@legacy_client_deprecation
def get_experiment():
    """Return the experiment currently tracked by the active project.

    Raises:
        NeptuneUninitializedException: when :meth:`init` was not called first.
    """
    global project
    if project is None:
        raise NeptuneUninitializedException()
    return project._get_current_experiment()
@legacy_client_deprecation
def append_tag(tag, *tags):
    """Add one or more tags to the current experiment.

    Alias for: :meth:`~neptune.experiments.Experiment.append_tag`
    """
    experiment = get_experiment()
    experiment.append_tag(tag, *tags)
@legacy_client_deprecation
def append_tags(tag, *tags):
    """Add one or more tags to the current experiment.

    Alias for: :meth:`~neptune.experiments.Experiment.append_tags`
    """
    # NOTE(review): delegates to append_tag, exactly as the original did.
    experiment = get_experiment()
    experiment.append_tag(tag, *tags)
@legacy_client_deprecation
def remove_tag(tag):
    """Remove a single tag from the current experiment.

    Alias for: :meth:`~neptune.experiments.Experiment.remove_tag`
    """
    experiment = get_experiment()
    experiment.remove_tag(tag)
@legacy_client_deprecation
def set_property(key, value):
    """Set a `key-value` pair as a property of the current experiment.

    A new property is added when the given ``key`` does not exist yet.

    Alias for: :meth:`~neptune.experiments.Experiment.set_property`
    """
    experiment = get_experiment()
    experiment.set_property(key, value)
@legacy_client_deprecation
def remove_property(key):
    """Remove the property with the given key from the current experiment.

    Alias for: :meth:`~neptune.experiments.Experiment.remove_property`
    """
    experiment = get_experiment()
    experiment.remove_property(key)
@legacy_client_deprecation
def send_metric(channel_name, x, y=None, timestamp=None):
    """Log a metric (numeric value) to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_metric`
    """
    experiment = get_experiment()
    return experiment.send_metric(channel_name, x, y, timestamp)
@legacy_client_deprecation
def log_metric(log_name, x, y=None, timestamp=None):
    """Log a metric (numeric value) to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_metric`
    """
    experiment = get_experiment()
    return experiment.log_metric(log_name, x, y, timestamp)
@legacy_client_deprecation
def send_text(channel_name, x, y=None, timestamp=None):
    """Log text data to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_text`
    """
    experiment = get_experiment()
    return experiment.send_text(channel_name, x, y, timestamp)
@legacy_client_deprecation
def log_text(log_name, x, y=None, timestamp=None):
    """Log text data to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_text`
    """
    # NOTE(review): delegates to Experiment.send_text (not log_text) —
    # preserved exactly as the original implementation did.
    experiment = get_experiment()
    return experiment.send_text(log_name, x, y, timestamp)
@legacy_client_deprecation
def send_image(channel_name, x, y=None, name=None, description=None, timestamp=None):
    """Log image data to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_image`
    """
    experiment = get_experiment()
    return experiment.send_image(channel_name, x, y, name, description, timestamp)
@legacy_client_deprecation
def log_image(log_name, x, y=None, image_name=None, description=None, timestamp=None):
    """Log image data to the current experiment.

    Alias for :meth:`~neptune.experiments.Experiment.log_image`
    """
    # NOTE(review): delegates to Experiment.send_image — preserved as-is.
    experiment = get_experiment()
    return experiment.send_image(log_name, x, y, image_name, description, timestamp)
@legacy_client_deprecation
def send_artifact(artifact, destination=None):
    """Save an artifact (file) in the current experiment's storage.

    Alias for :meth:`~neptune.experiments.Experiment.log_artifact`
    """
    experiment = get_experiment()
    return experiment.log_artifact(artifact, destination)
@legacy_client_deprecation
def delete_artifacts(path):
    """Delete an artifact (file/directory) from the experiment's storage.

    Alias for :meth:`~neptune.experiments.Experiment.delete_artifacts`
    """
    experiment = get_experiment()
    return experiment.delete_artifacts(path)
@legacy_client_deprecation
def log_artifact(artifact, destination=None):
    """Save an artifact (file) in the current experiment's storage.

    Alias for :meth:`~neptune.experiments.Experiment.log_artifact`
    """
    experiment = get_experiment()
    return experiment.log_artifact(artifact, destination)
@legacy_client_deprecation
def stop(traceback=None):
    """Mark the current experiment as finished (succeeded or failed).

    Alias for :meth:`~neptune.experiments.Experiment.stop`
    """
    experiment = get_experiment()
    experiment.stop(traceback)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.exceptions import STYLES
from neptune.legacy import envs
from neptune.legacy.exceptions import NeptuneException
class NeptuneApiException(NeptuneException):
    """Base class for errors raised by the legacy Neptune API layer."""

    pass
class NeptuneSSLVerificationError(NeptuneException):
    """Raised when the client cannot verify the server's SSL certificate."""

    def __init__(self):
        # The {h1}/{end}/{bash}/... placeholders are terminal style markers
        # substituted from STYLES.
        message = """
{h1}
----NeptuneSSLVerificationError-----------------------------------------------------------------------
{end}
Neptune client was unable to verify your SSL Certificate.

{bold}What could go wrong?{end}
    - You are behind a proxy that inspects traffic to Neptune servers.
        - Contact your network administrator
    - Your on-prem installation's SSL/TLS certificate is not recognized due to a custom Certificate Authority (CA).
        - To check run the following command in terminal:
            {bash}curl https://<your_domain>/api/backend/echo {end}
        - Where <your_domain> is the address that you use to access Neptune UI i.e. abc.com
        - Contact your network administrator if you get the following output:
            {fail}"curl: (60) server certificate verification failed..."{end}
    - Your machine software is not up-to-date.
        - Minimal OS requirements:
            - Windows >= XP SP3
            - macOS >= 10.12.1
            - Ubuntu >= 12.04
            - Debian >= 8

{bold}What can I do?{end}
You can manually configure Neptune to skip all SSL checks. To do that
set the NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE environment variable to 'TRUE'.
{bold}Note that might mean your connection is less secure{end}.

Linux/Unix
In your terminal run:
    {bash}export NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}

Windows
In your terminal run:
    {bash}set NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}

Jupyter notebook
In your code cell:
    {bash}%env NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}

You may also want to check the following docs pages:
    - https://docs.neptune.ai/api/environment_variables/#neptune_allow_self_signed_certificate

{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(**STYLES))
class ConnectionLost(NeptuneApiException):
    """Raised when the connection to the Neptune servers is dropped."""

    def __init__(self):
        super().__init__("Connection lost. Please try again.")
class ServerError(NeptuneApiException):
    """Raised on an unexpected error response from the Neptune servers."""

    def __init__(self):
        # {h1}/{end} are terminal style markers substituted from STYLES.
        message = """
{h1}
----ServerError-----------------------------------------------------------------------
{end}
Neptune Client Library encountered an unexpected Server Error.

Please try again later or contact Neptune support.
"""
        super(ServerError, self).__init__(message.format(**STYLES))
class Unauthorized(NeptuneApiException):
    """Raised when the user has no permission to access the given resource."""

    def __init__(self):
        # Placeholders ({h1}, {correct}, {end}, ...) are terminal style
        # markers substituted from STYLES.
        message = """
{h1}
----Unauthorized-----------------------------------------------------------------------
{end}
You have no permission to access given resource.

    - Verify your API token is correct.
      See: https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html

    - Verify if you set your Project qualified name correctly
      The correct project qualified name should look like this {correct}WORKSPACE/PROJECT_NAME{end}.
      It has two parts:
          - {correct}WORKSPACE{end}: which can be your username or your organization name
          - {correct}PROJECT_NAME{end}: which is the actual project name you chose

   - Ask your organization administrator to grant you necessary privileges to the project
"""
        super(Unauthorized, self).__init__(message.format(**STYLES))
class Forbidden(NeptuneApiException):
    """Raised when access to the given resource is forbidden for the user."""

    def __init__(self):
        # Same guidance text as Unauthorized; only the header differs.
        message = """
{h1}
----Forbidden-----------------------------------------------------------------------
{end}
You have no permission to access given resource.

    - Verify your API token is correct.
      See: https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html

    - Verify if you set your Project qualified name correctly
      The correct project qualified name should look like this {correct}WORKSPACE/PROJECT_NAME{end}.
      It has two parts:
          - {correct}WORKSPACE{end}: which can be your username or your organization name
          - {correct}PROJECT_NAME{end}: which is the actual project name you chose

   - Ask your organization administrator to grant you necessary privileges to the project
"""
        super(Forbidden, self).__init__(message.format(**STYLES))
class InvalidApiKey(NeptuneApiException):
    """Raised when the provided API token is malformed."""

    def __init__(self):
        # {env_api_token} is filled with the API-token env var name; the
        # remaining placeholders are terminal style markers from STYLES.
        message = """
{h1}
----InvalidApiKey-----------------------------------------------------------------------
{end}
Your API token is invalid.

Learn how to get it in this docs page:
https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html

There are two options to add it:
    - specify it in your code
    - set an environment variable in your operating system.

{h2}CODE{end}
Pass the token to {bold}neptune.init(){end} via {bold}api_token{end} argument:
    {python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}

{h2}ENVIRONMENT VARIABLE{end} {correct}(Recommended option){end}
or export or set an environment variable depending on your operating system:

    {correct}Linux/Unix{end}
    In your terminal run:
        {bash}export {env_api_token}=YOUR_API_TOKEN{end}

    {correct}Windows{end}
    In your CMD run:
        {bash}set {env_api_token}=YOUR_API_TOKEN{end}

and skip the {bold}api_token{end} argument of {bold}neptune.init(){end}:
    {python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME'){end}

You may also want to check the following docs pages:
    - https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html
    - https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html

{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(InvalidApiKey, self).__init__(message.format(env_api_token=envs.API_TOKEN_ENV_NAME, **STYLES))
class WorkspaceNotFound(NeptuneApiException):
    """Raised when the requested workspace (user or organization) is unknown."""

    def __init__(self, namespace_name):
        message = """
{h1}
----WorkspaceNotFound-------------------------------------------------------------------------
{end}
Workspace {python}{workspace}{end} not found.

Workspace is your username or a name of your team organization.
"""
        super(WorkspaceNotFound, self).__init__(message.format(workspace=namespace_name, **STYLES))
class ProjectNotFound(NeptuneApiException):
    """Raised when the given project identifier cannot be resolved."""

    def __init__(self, project_identifier):
        message = """
{h1}
----ProjectNotFound-------------------------------------------------------------------------
{end}
Project {python}{project}{end} not found.

Verify if your project's name was not misspelled. You can find proper name after logging into Neptune UI.
"""
        super(ProjectNotFound, self).__init__(message.format(project=project_identifier, **STYLES))
class PathInProjectNotFound(NeptuneApiException):
    """Raised when a storage path does not exist within the given project."""

    def __init__(self, path, project_identifier):
        super().__init__("Path {} was not found in project {}.".format(path, project_identifier))
class PathInExperimentNotFound(NeptuneApiException):
    """Raised when a storage path does not exist within the given experiment."""

    def __init__(self, path, exp_identifier):
        super().__init__(f"Path {path} was not found in experiment {exp_identifier}.")
class NotebookNotFound(NeptuneApiException):
    """Raised when a notebook cannot be found, optionally within a project."""

    def __init__(self, notebook_id, project=None):
        if project:
            details = "Notebook '{}' not found in project '{}'.".format(notebook_id, project)
        else:
            details = "Notebook '{}' not found.".format(notebook_id)
        super().__init__(details)
class ExperimentNotFound(NeptuneApiException):
    """Raised when an experiment id cannot be found in the given project."""

    def __init__(self, experiment_short_id, project_qualified_name):
        super().__init__(
            "Experiment '{exp}' not found in '{project}'.".format(
                exp=experiment_short_id, project=project_qualified_name
            )
        )
class ChannelNotFound(NeptuneApiException):
    """Raised when no channel exists with the given id."""

    def __init__(self, channel_id):
        super().__init__("Channel '{id}' not found.".format(id=channel_id))
class ExperimentAlreadyFinished(NeptuneApiException):
    """Raised when operating on an experiment that has already finished."""

    def __init__(self, experiment_short_id):
        super().__init__("Experiment '{}' is already finished.".format(experiment_short_id))
class ExperimentLimitReached(NeptuneApiException):
    """Raised when no more experiments can be created."""

    def __init__(self):
        super().__init__("Experiment limit reached.")
class StorageLimitReached(NeptuneApiException):
    """Raised when no more data can be stored."""

    def __init__(self):
        super().__init__("Storage limit reached.")
class ExperimentValidationError(NeptuneApiException):
    """Raised when experiment data fails validation."""

    pass
class ChannelAlreadyExists(NeptuneApiException):
    """Raised when creating a channel whose name is already taken."""

    def __init__(self, experiment_short_id, channel_name):
        super().__init__(
            "Channel with name '{}' already exists in experiment '{}'.".format(channel_name, experiment_short_id)
        )
class ChannelDoesNotExist(NeptuneApiException):
    """Raised when referencing a channel that does not exist."""

    def __init__(self, experiment_short_id, channel_name):
        super().__init__(
            "Channel with name '{}' does not exist in experiment '{}'.".format(channel_name, experiment_short_id)
        )
class ChannelsValuesSendBatchError(NeptuneApiException):
    """Raised when a batch upload of channel values was (partially) rejected."""

    @staticmethod
    def _format_error(error):
        # Render one server-side batch error in a compact single-line form.
        return "{msg} (metricId: '{channelId}', x: {x})".format(msg=error.error, channelId=error.channelId, x=error.x)

    def __init__(self, experiment_short_id, batch_errors):
        first_error = self._format_error(batch_errors[0]) if batch_errors else "No errors"
        super().__init__(
            "Received batch errors sending channels' values to experiment {}. "
            "Cause: {} "
            "Skipping {} values.".format(experiment_short_id, first_error, len(batch_errors))
        )
class ExperimentOperationErrors(NeptuneApiException):
    """Handles minor errors returned by calling `client.executeOperations`."""

    def __init__(self, errors):
        super().__init__()
        self.errors = errors

    def __str__(self):
        bullet_points = [f"\t* {error}" for error in self.errors]
        return "\n".join(["Caused by:"] + bullet_points)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import (
ABC,
abstractmethod,
)
from typing import Dict
from neptune.legacy.model import ChannelWithLastValue
class ApiClient(ABC):
    """Common interface shared by all Neptune API client implementations."""

    @property
    @abstractmethod
    def api_address(self):
        # Presumably the base address used for API calls — confirm with
        # concrete implementations.
        pass

    @property
    @abstractmethod
    def display_address(self):
        # Presumably the address presented to the user (e.g. in links) —
        # confirm with concrete implementations.
        pass

    @property
    @abstractmethod
    def proxies(self):
        # Proxy configuration exposed by the client.
        pass
class BackendApiClient(ApiClient, ABC):
    """Project-level API: project lookup and leaderboard-client creation."""

    @abstractmethod
    def get_project(self, project_qualified_name):
        pass

    @abstractmethod
    def get_projects(self, namespace):
        pass

    @abstractmethod
    def create_leaderboard_backend(self, project) -> "LeaderboardApiClient":
        # Returns the experiment-level client bound to the given project.
        pass
class LeaderboardApiClient(ApiClient, ABC):
    """Experiment-level API: experiments, channels, notebooks and artifacts.

    Concrete subclasses implement the actual transport; only
    ``websockets_factory`` has a default (no websocket support).
    """

    @abstractmethod
    def get_project_members(self, project_identifier):
        pass

    @abstractmethod
    def get_leaderboard_entries(
        self,
        project,
        entry_types=None,
        ids=None,
        states=None,
        owners=None,
        tags=None,
        min_running_time=None,
    ):
        pass

    def websockets_factory(self, project_id, experiment_id):
        # Optional hook: subclasses may return a websocket factory for the
        # given experiment; the default None disables websocket features.
        return None

    @abstractmethod
    def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
        pass

    @abstractmethod
    def get_metrics_csv(self, experiment):
        pass

    @abstractmethod
    def create_experiment(
        self,
        project,
        name,
        description,
        params,
        properties,
        tags,
        abortable,
        monitored,
        git_info,
        hostname,
        entrypoint,
        notebook_id,
        checkpoint_id,
    ):
        pass

    @abstractmethod
    def upload_source_code(self, experiment, source_target_pairs):
        pass

    @abstractmethod
    def get_notebook(self, project, notebook_id):
        pass

    @abstractmethod
    def get_last_checkpoint(self, project, notebook_id):
        pass

    @abstractmethod
    def create_notebook(self, project):
        pass

    @abstractmethod
    def create_checkpoint(self, notebook_id, jupyter_path, _file=None):
        pass

    @abstractmethod
    def get_experiment(self, experiment_id):
        pass

    @abstractmethod
    def set_property(self, experiment, key, value):
        pass

    @abstractmethod
    def remove_property(self, experiment, key):
        pass

    @abstractmethod
    def update_tags(self, experiment, tags_to_add, tags_to_delete):
        pass

    @abstractmethod
    def create_channel(self, experiment, name, channel_type) -> ChannelWithLastValue:
        pass

    @abstractmethod
    def get_channels(self, experiment) -> Dict[str, object]:
        pass

    @abstractmethod
    def reset_channel(self, experiment, channel_id, channel_name, channel_type):
        pass

    @abstractmethod
    def create_system_channel(self, experiment, name, channel_type) -> ChannelWithLastValue:
        pass

    @abstractmethod
    def get_system_channels(self, experiment) -> Dict[str, object]:
        pass

    @abstractmethod
    def send_channels_values(self, experiment, channels_with_values):
        pass

    @abstractmethod
    def mark_failed(self, experiment, traceback):
        pass

    @abstractmethod
    def ping_experiment(self, experiment):
        pass

    @abstractmethod
    def create_hardware_metric(self, experiment, metric):
        pass

    @abstractmethod
    def send_hardware_metric_reports(self, experiment, metrics, metric_reports):
        pass

    @abstractmethod
    def log_artifact(self, experiment, artifact, destination=None):
        pass

    @abstractmethod
    def delete_artifacts(self, experiment, path):
        pass

    @abstractmethod
    def download_data(self, experiment, path, destination):
        pass

    @abstractmethod
    def download_sources(self, experiment, path=None, destination_dir=None):
        pass

    @abstractmethod
    def download_artifacts(self, experiment, path=None, destination_dir=None):
        pass

    @abstractmethod
    def download_artifact(self, experiment, path=None, destination_dir=None):
        pass
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class Checkpoint(object):
    """Value object describing a single notebook checkpoint."""

    def __init__(self, _id, name, path):
        # The parameter is underscore-prefixed to avoid shadowing the builtin
        # `id` in the signature; the public attribute keeps the plain name.
        self.id = _id
        self.name = name
        self.path = path
#
# Copyright (c) 2020, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants used by Neptune"""
ANONYMOUS = "ANONYMOUS"
ANONYMOUS_API_TOKEN = (
"eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYW"
"kiLCJhcGlfa2V5IjoiYjcwNmJjOGYtNzZmOS00YzJlLTkzOWQtNGJhMDM2ZjkzMmU0In0="
)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Environment-variable names re-exported from neptune.common.envs.
__all__ = ["PROJECT_ENV_NAME", "API_TOKEN_ENV_NAME", "NOTEBOOK_ID_ENV_NAME", "NOTEBOOK_PATH_ENV_NAME", "BACKEND"]
from neptune.common.envs import (
API_TOKEN_ENV_NAME,
BACKEND,
NOTEBOOK_ID_ENV_NAME,
NOTEBOOK_PATH_ENV_NAME,
PROJECT_ENV_NAME,
)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"EMPTY_STYLES",
"STYLES",
"UNIX_STYLES",
"WINDOWS_STYLES",
"FileNotFound",
"InvalidNotebookPath",
"NeptuneIncorrectProjectQualifiedNameException",
"NeptuneMissingProjectQualifiedNameException",
"NotADirectory",
"NotAFile",
]
from neptune.common.exceptions import (
EMPTY_STYLES,
STYLES,
UNIX_STYLES,
WINDOWS_STYLES,
FileNotFound,
InvalidNotebookPath,
NeptuneIncorrectProjectQualifiedNameException,
NeptuneMissingProjectQualifiedNameException,
NotADirectory,
NotAFile,
)
from neptune.legacy import envs
class NeptuneException(Exception):
    """Base class for all Neptune legacy-client exceptions."""

    pass
class NeptuneUninitializedException(NeptuneException):
    """Raised when an experiment is created before ``neptune.init`` was called."""

    def __init__(self):
        # Multi-line help text; the {..} placeholders are terminal-styling
        # markers filled in from the shared STYLES mapping below.
        message = """
{h1}
----NeptuneUninitializedException---------------------------------------------------------------------------------------
{end}
You must initialize neptune-client before you create an experiment.
Looks like you forgot to add:
{python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
before you ran:
{python}neptune.create_experiment(){end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
""".format(
            **STYLES
        )
        super(NeptuneUninitializedException, self).__init__(message)
class NoChannelValue(NeptuneException):
    """Raised when a channel value was expected but none was provided."""

    def __init__(self):
        super().__init__("No channel value provided.")
class NeptuneLibraryNotInstalledException(NeptuneException):
    """Raised when an optional library required by a feature is not installed."""

    def __init__(self, library):
        # `library` is interpolated into the help text alongside the STYLES
        # terminal-styling placeholders.
        message = """
{h1}
----NeptuneLibraryNotInstalledException---------------------------------------------------------------------------------
{end}
Looks like library {library} wasn't installed.
To install run:
{bash}pip install {library}{end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/getting-started/installation/index.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(NeptuneLibraryNotInstalledException, self).__init__(message.format(library=library, **STYLES))
class InvalidChannelValue(NeptuneException):
    """Raised when a channel value has a different type than the channel expects."""

    def __init__(self, expected_type, actual_type):
        details = "Invalid channel value type. Expected: {expected}, actual: {actual}.".format(
            expected=expected_type, actual=actual_type
        )
        super().__init__(details)
class NeptuneNoExperimentContextException(NeptuneException):
    """Raised when an operation requires an active experiment but none exists."""

    def __init__(self):
        # Help text rendered with the STYLES terminal-styling placeholders.
        message = """
{h1}
----NeptuneNoExperimentContextException---------------------------------------------------------------------------------
{end}
Neptune couldn't find an active experiment.
Looks like you forgot to run:
{python}neptune.create_experiment(){end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(NeptuneNoExperimentContextException, self).__init__(message.format(**STYLES))
class NeptuneMissingApiTokenException(NeptuneException):
    """Raised when no API token was supplied via argument or environment variable."""

    def __init__(self):
        # {env_api_token} is filled with the token env-var name from neptune.legacy.envs;
        # the remaining placeholders come from the STYLES terminal-styling mapping.
        message = """
{h1}
----NeptuneMissingApiTokenException-------------------------------------------------------------------------------------
{end}
Neptune client couldn't find your API token.
Learn how to get it in this docs page:
https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html
There are two options to add it:
- specify it in your code
- set an environment variable in your operating system.
{h2}CODE{end}
Pass the token to {bold}neptune.init(){end} via {bold}api_token{end} argument:
{python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
{h2}ENVIRONMENT VARIABLE{end} {correct}(Recommended option){end}
or export or set an environment variable depending on your operating system:
{correct}Linux/Unix{end}
In your terminal run:
{bash}export {env_api_token}=YOUR_API_TOKEN{end}
{correct}Windows{end}
In your CMD run:
{bash}set {env_api_token}=YOUR_API_TOKEN{end}
and skip the {bold}api_token{end} argument of {bold}neptune.init(){end}:
{python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME'){end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/security-and-privacy/api-tokens/how-to-find-and-set-neptune-api-token.html
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(NeptuneMissingApiTokenException, self).__init__(
            message.format(env_api_token=envs.API_TOKEN_ENV_NAME, **STYLES)
        )
class InvalidNeptuneBackend(NeptuneException):
    """Raised when the backend env variable names an unknown backend."""

    def __init__(self, provided_backend_name):
        explanation = (
            'Unknown {} "{}". '
            "Use this environment variable to modify neptune-client behaviour at runtime, "
            "e.g. using {}=offline allows you to run your code without logging anything to Neptune"
        ).format(envs.BACKEND, provided_backend_name, envs.BACKEND)
        super().__init__(explanation)
class DeprecatedApiToken(NeptuneException):
    """Raised when the server reports that the user's API token format is obsolete."""

    def __init__(self, app_url):
        notice = "Your API token is deprecated. Please visit {} to get a new one.".format(app_url)
        super().__init__(notice)
class CannotResolveHostname(NeptuneException):
    """Raised when DNS resolution of the Neptune API host fails."""

    def __init__(self, host):
        # `host` is interpolated into the help text alongside the STYLES placeholders.
        message = """
{h1}
----CannotResolveHostname-----------------------------------------------------------------------
{end}
Neptune Client Library was not able to resolve hostname {host}.
What should I do?
- Check if your computer is connected to the internet.
- Check if your computer should use any proxy to access internet.
If so, you may want to use {python}proxies{end} parameter of {python}neptune.init(){end} function.
See https://docs-legacy.neptune.ai/api-reference/neptune/index.html#neptune.init
and https://requests.readthedocs.io/en/master/user/advanced/#proxies
"""
        super(CannotResolveHostname, self).__init__(message.format(host=host, **STYLES))
class UnsupportedClientVersion(NeptuneException):
    """Raised when the server rejects this client library version.

    Parameter names (``minVersion``/``maxVersion``) are kept as-is for
    backward compatibility with keyword callers.
    """

    def __init__(self, version, minVersion, maxVersion):
        # Suggest a pinned install when an upper bound exists, otherwise a lower bound.
        if maxVersion:
            suggested = "==" + str(maxVersion)
        else:
            suggested = ">=" + str(minVersion)
        super().__init__(
            "This client version ({}) is not supported. Please install neptune-client{}".format(version, suggested)
        )
class UnsupportedInAlphaException(NeptuneException):
    """Base for operations that were available in the old client
    but are not supported in the alpha version."""
class DownloadSourcesException(UnsupportedInAlphaException):
    """Raised when downloading a single source file is unsupported for a migrated project."""

    # Class-level template; rendered with the project's internal id and the
    # STYLES terminal-styling placeholders at raise time.
    message = """
{h1}
----DownloadSourcesException-----------------------------------------------------------------------
{end}
Neptune Client Library was not able to download single file from sources.
Why am I seeing this?
Your project "{project}" has been migrated to new structure.
Old version of `neptune-api` is not supporting downloading particular source files.
We recommend you to use new version of api: `neptune.new`.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
If you don't want to adapt your code to new api yet,
you can use `download_sources` with `path` parameter set to None.
"""
    def __init__(self, experiment):
        # NOTE(review): `assert` guards are stripped under `python -O`;
        # this only documents that subclasses must define `message`.
        assert self.message is not None
        super().__init__(
            self.message.format(
                project=experiment._project.internal_id,
                **STYLES,
            )
        )
class DownloadArtifactsUnsupportedException(UnsupportedInAlphaException):
    """Raised when downloading artifact directories is unsupported for a migrated project."""

    # Class-level template; rendered with the project's internal id and STYLES.
    message = """
{h1}
----DownloadArtifactsUnsupportedException-----------------------------------------------------------------------
{end}
Neptune Client Library was not able to download artifacts.
Function `download_artifacts` is deprecated.
Why am I seeing this?
Your project "{project}" has been migrated to new structure.
Old version of `neptune-api` is not supporting downloading artifact directories.
We recommend you to use new version of api: `neptune.new`.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
If you don't want to adapt your code to new api yet,
you can use `download_artifact` and download files one by one.
"""
    def __init__(self, experiment):
        assert self.message is not None
        super().__init__(
            self.message.format(
                project=experiment._project.internal_id,
                **STYLES,
            )
        )
class DownloadArtifactUnsupportedException(UnsupportedInAlphaException):
    """Raised when a single-artifact download fails for a migrated project
    (missing attribute, or the path refers to a directory)."""

    # Class-level template; rendered with artifact path, experiment id,
    # project internal id, and the STYLES placeholders at raise time.
    message = """
{h1}
----DownloadArtifactUnsupportedException-----------------------------------------------------------------------
{end}
Neptune Client Library was not able to download attribute: "{artifact_path}".
It's not present in experiment {experiment} or is a directory.
Why am I seeing this?
Your project "{project}" has been migrated to new structure.
Old version of `neptune-api` is not supporting downloading whole artifact directories.
We recommend you to use new version of api: `neptune.new`.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
If you don't want to adapt your code to new api yet:
- Make sure that artifact "{artifact_path}" is present in experiment "{experiment}".
- Make sure that you're addressing artifact which is file, not whole directory.
"""
    def __init__(self, artifact_path, experiment):
        assert self.message is not None
        super().__init__(
            self.message.format(
                artifact_path=artifact_path,
                experiment=experiment.id,
                project=experiment._project.internal_id,
                **STYLES,
            )
        )
class DeleteArtifactUnsupportedInAlphaException(UnsupportedInAlphaException):
    """Raised when deleting an artifact fails for a migrated project
    (missing attribute, or the path refers to a directory)."""

    # Class-level template; rendered with artifact path, experiment id,
    # project internal id, and the STYLES placeholders at raise time.
    message = """
{h1}
----DeleteArtifactUnsupportedInAlphaException-----------------------------------------------------------------------
{end}
Neptune Client Library was not able to delete attribute: "{artifact_path}".
It's not present in experiment {experiment} or is a directory.
Why am I seeing this?
Your project "{project}" has been migrated to new structure.
Old version of `neptune-api` is not supporting deleting whole artifact directories.
We recommend you to use new version of api: `neptune.new`.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
If you don't want to adapt your code to new api yet:
- Make sure that artifact "{artifact_path}" is present in experiment "{experiment}".
- Make sure that you're addressing artifact which is file, not whole directory.
"""
    def __init__(self, artifact_path, experiment):
        assert self.message is not None
        super().__init__(
            self.message.format(
                artifact_path=artifact_path,
                experiment=experiment.id,
                project=experiment._project.internal_id,
                **STYLES,
            )
        )
class NeptuneIncorrectImportException(NeptuneException):
    """Raised when new-API style access is attempted on a legacy-API object."""

    def __init__(self):
        # Help text rendered with the STYLES terminal-styling placeholders.
        message = """
{h1}
----NeptuneIncorrectImportException----------------------------------------------------------------
{end}
It seems you are trying to use the new Python API, but imported the legacy API.
Simply update your import statement to:
{python}import neptune{end}
You may also want to check the following docs pages:
- https://docs.neptune.ai/about/legacy/#migrating-to-neptunenew
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(**STYLES))
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import logging
import traceback
import pandas as pd
import six
from pandas.errors import EmptyDataError
from neptune.common.experiments import LegacyExperiment
from neptune.common.utils import (
align_channels_on_x,
is_float,
is_nan_or_inf,
)
from neptune.legacy.api_exceptions import (
ChannelDoesNotExist,
ExperimentAlreadyFinished,
)
from neptune.legacy.exceptions import (
InvalidChannelValue,
NeptuneIncorrectImportException,
NoChannelValue,
)
from neptune.legacy.internal.channels.channels import (
ChannelNamespace,
ChannelType,
ChannelValue,
)
from neptune.legacy.internal.channels.channels_values_sender import ChannelsValuesSender
from neptune.legacy.internal.execution.execution_context import ExecutionContext
from neptune.legacy.internal.utils.deprecation import legacy_client_deprecation
from neptune.legacy.internal.utils.image import get_image_content
_logger = logging.getLogger(__name__)
class Experiment(LegacyExperiment):
"""A class for managing Neptune experiment.
Each time User creates new experiment instance of this class is created.
It lets you manage experiment, :meth:`~neptune.experiments.Experiment.log_metric`,
:meth:`~neptune.experiments.Experiment.log_text`,
:meth:`~neptune.experiments.Experiment.log_image`,
:meth:`~neptune.experiments.Experiment.set_property`,
and much more.
Args:
backend (:obj:`neptune.ApiClient`): A ApiClient object
project (:obj:`neptune.Project`): The project this experiment belongs to
_id (:obj:`str`): Experiment id
internal_id (:obj:`str`): internal ID
Example:
Assuming that `project` is an instance of :class:`~neptune.projects.Project`.
.. code:: python3
experiment = project.create_experiment()
Warning:
User should never create instances of this class manually.
Always use: :meth:`~neptune.projects.Project.create_experiment`.
"""
IMAGE_SIZE_LIMIT_MB = 15
    @legacy_client_deprecation
    def __init__(self, backend, project, _id, internal_id):
        # Backend client used for all server communication.
        self._backend = backend
        # The Project this experiment belongs to.
        self._project = project
        # `_id` is the short id (e.g. "MPI-142"); `internal_id` the internal id.
        self._id = _id
        self._internal_id = internal_id
        # Queue-based sender that batches channel values before upload.
        self._channels_values_sender = ChannelsValuesSender(self)
        # Execution context bound to this experiment run.
        self._execution_context = ExecutionContext(backend, self)
    def _raise_new_client_expected(self):
        # Shared guard for new-API style access (item get/set/del) on this
        # legacy object; always raises with migration instructions.
        raise NeptuneIncorrectImportException()
    def __getitem__(self, item):
        # `experiment[...]` is new-API syntax; direct users to the new client.
        self._raise_new_client_expected()
    def __setitem__(self, key, value):
        # `experiment[...] = value` is new-API syntax; direct users to the new client.
        self._raise_new_client_expected()
    def __delitem__(self, item):
        # `del experiment[...]` is new-API syntax; direct users to the new client.
        self._raise_new_client_expected()
    @property
    def id(self) -> str:
        """Experiment short id.

        Combination of project key and unique experiment number, formatted as
        ``<project_key>-<experiment_number>``, for example: ``MPI-142``.

        Returns:
            :obj:`str` - experiment short id

        Examples:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                exp_id = experiment.id
        """
        return self._id
    @property
    def name(self) -> str:
        """Experiment name.

        Fetched from the backend on every access, so it reflects renames.

        Returns:
            :obj:`str` - experiment name

        Examples:
            Assuming that `project` is an instance of :class:`~neptune.projects.Project`.

            .. code:: python3

                experiment = project.create_experiment('exp_name')
                exp_name = experiment.name
        """
        return self._backend.get_experiment(self._internal_id).name
    @property
    def state(self) -> str:
        """Current experiment state.

        Possible values: `'running'`, `'succeeded'`, `'failed'`, `'aborted'`.
        Fetched from the backend on every access.

        Returns:
            :obj:`str` - current experiment state

        Examples:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                state_str = experiment.state
        """
        return self._backend.get_experiment(self._internal_id).state
    @property
    def internal_id(self):
        """Internal (backend-side) identifier of this experiment."""
        return self._internal_id
    @property
    def limits(self) -> dict:
        """Hard-coded per-experiment channel-count limits by channel type."""
        return {"channels": {"numeric": 1000, "text": 100, "image": 100}}
    def get_system_properties(self):
        """Retrieve experiment properties.

        Experiment properties are for example: `owner`, `created`, `name`, `hostname`.
        The list of experiment properties may change over time.

        Returns:
            :obj:`dict` - dictionary mapping a property name to value.

        Examples:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                sys_properties = experiment.get_system_properties()
        """
        # Single backend fetch; all returned fields come from this one DTO.
        experiment = self._backend.get_experiment(self._internal_id)
        return {
            "id": experiment.shortId,
            "name": experiment.name,
            "created": experiment.timeOfCreation,
            "finished": experiment.timeOfCompletion,
            "running_time": experiment.runningTime,
            "owner": experiment.owner,
            "storage_size": experiment.storageSize,
            "channels_size": experiment.channelsSize,
            # Total footprint = artifact storage + channel data.
            "size": experiment.storageSize + experiment.channelsSize,
            "tags": experiment.tags,
            # Both keys intentionally expose the same backend `description` field.
            "notes": experiment.description,
            "description": experiment.description,
            "hostname": experiment.hostname,
        }
    def get_tags(self):
        """Get tags associated with experiment.

        Returns:
            :obj:`list` of :obj:`str` with all tags for this experiment.

        Example:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                experiment.get_tags()
        """
        return self._backend.get_experiment(self._internal_id).tags
def append_tag(self, tag, *tags):
"""Append tag(s) to the current experiment.
Alias: :meth:`~neptune.experiments.Experiment.append_tags`.
Only ``[a-zA-Z0-9]`` and ``-`` (dash) characters are allowed in tags.
Args:
tag (single :obj:`str` or multiple :obj:`str` or :obj:`list` of :obj:`str`):
Tag(s) to add to the current experiment.
* If :obj:`str` is passed, singe tag is added.
* If multiple - comma separated - :obj:`str` are passed, all of them are added as tags.
* If :obj:`list` of :obj:`str` is passed, all elements of the :obj:`list` are added as tags.
Examples:
.. code:: python3
neptune.append_tag('new-tag') # single tag
neptune.append_tag('first-tag', 'second-tag', 'third-tag') # few str
neptune.append_tag(['first-tag', 'second-tag', 'third-tag']) # list of str
"""
if isinstance(tag, list):
tags_list = tag
else:
tags_list = [tag] + list(tags)
self._backend.update_tags(experiment=self, tags_to_add=tags_list, tags_to_delete=[])
    def append_tags(self, tag, *tags):
        """Append tag(s) to the current experiment.

        Alias for: :meth:`~neptune.experiments.Experiment.append_tag`
        """
        self.append_tag(tag, *tags)
    def remove_tag(self, tag):
        """Removes single tag from the experiment.

        Args:
            tag (:obj:`str`): Tag to be removed

        Example:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                # assuming experiment has tags: `['tag-1', 'tag-2']`.
                experiment.remove_tag('tag-1')

        Note:
            Removing a tag that is not assigned to this experiment is silently ignored.
        """
        self._backend.update_tags(experiment=self, tags_to_add=[], tags_to_delete=[tag])
    def get_channels(self):
        """Alias for :meth:`~neptune.experiments.Experiment.get_logs`"""
        return self.get_logs()
def get_logs(self):
"""Retrieve all log names along with their last values for this experiment.
Returns:
:obj:`dict` - A dictionary mapping a log names to the log's last value.
Example:
Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.
.. code:: python3
exp_logs = experiment.get_logs()
"""
def get_channel_value(ch):
return float(ch.y) if ch.y is not None and ch.channelType == "numeric" else ch.y
return {key: get_channel_value(ch) for key, ch in self._backend.get_channels(self).items()}
    def _get_system_channels(self):
        # System channels (e.g. stdout/stderr/hardware) as exposed by the backend.
        return self._backend.get_system_channels(self)
    def send_metric(self, channel_name, x, y=None, timestamp=None):
        """Log metrics (numeric values) in Neptune.

        Alias for :meth:`~neptune.experiments.Experiment.log_metric`
        """
        return self.log_metric(channel_name, x, y, timestamp)
    def log_metric(self, log_name, x, y=None, timestamp=None):
        """Log a metric (numeric value) in Neptune.

        If a log with the provided ``log_name`` does not exist, it is created
        automatically; otherwise the new value is appended to it.

        Args:
            log_name (:obj:`str`): The name of log, i.e. `mse`, `loss`, `accuracy`.
            x (:obj:`double`): Depending, whether ``y`` parameter is passed:

                * ``y`` not passed: The value of the log (data-point).
                * ``y`` passed: Index of log entry being appended. Must be strictly increasing.

            y (:obj:`double`, optional, default is ``None``): The value of the log (data-point).
            timestamp (:obj:`time`, optional, default is ``None``):
                Timestamp to be associated with log entry. Must be Unix time.
                If ``None`` is passed, ``time.time()`` is invoked to obtain timestamp.

        Raises:
            `InvalidChannelValue`: When the value is not numeric.

        Note:
            For efficiency, logs are uploaded in batches via a queue, so the
            web application may lag slightly behind.

        Note:
            Passing either ``x`` or ``y`` as NaN or +/-inf causes this entry
            to be dropped with a warning logged.
        """
        # Resolves the single-argument form: presumably swaps (x, None) into
        # (None, x) so that x is always the index — defined on the base class,
        # TODO confirm.
        x, y = self._get_valid_x_y(x, y)
        if not is_float(y):
            raise InvalidChannelValue(expected_type="float", actual_type=type(y).__name__)
        # Non-finite values are rejected with a warning instead of an exception,
        # so a single bad data-point does not abort a training run.
        if is_nan_or_inf(y):
            _logger.warning(
                "Invalid metric value: %s for channel %s. "
                "Metrics with nan or +/-inf values will not be sent to server",
                y,
                log_name,
            )
        elif x is not None and is_nan_or_inf(x):
            _logger.warning(
                "Invalid metric x-coordinate: %s for channel %s. "
                "Metrics with nan or +/-inf x-coordinates will not be sent to server",
                x,
                log_name,
            )
        else:
            value = ChannelValue(x, dict(numeric_value=y), timestamp)
            self._channels_values_sender.send(log_name, ChannelType.NUMERIC.value, value)
    def send_text(self, channel_name, x, y=None, timestamp=None):
        """Log text data in Neptune.

        Alias for :meth:`~neptune.experiments.Experiment.log_text`
        """
        return self.log_text(channel_name, x, y, timestamp)
def log_text(self, log_name, x, y=None, timestamp=None):
"""Log text data in Neptune
| If a log with provided ``log_name`` does not exist, it is created automatically.
| If log exists (determined by ``log_name``), then new value is appended to it.
Args:
log_name (:obj:`str`): The name of log, i.e. `mse`, `my_text_data`, `timing_info`.
x (:obj:`double` or :obj:`str`): Depending, whether ``y`` parameter is passed:
* ``y`` not passed: The value of the log (data-point). Must be ``str``.
* ``y`` passed: Index of log entry being appended. Must be strictly increasing.
y (:obj:`str`, optional, default is ``None``): The value of the log (data-point).
timestamp (:obj:`time`, optional, default is ``None``):
Timestamp to be associated with log entry. Must be Unix time.
If ``None`` is passed, `time.time() <https://docs.python.org/3.6/library/time.html#time.time>`_
(Python 3.6 example) is invoked to obtain timestamp.
Example:
Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:
.. code:: python3
# common case, where log name and data are passed
neptune.log_text('my_text_data', str(data_item))
# log_name, x and timestamp are passed
neptune.log_text(log_name='logging_losses_as_text',
x=str(val_loss),
timestamp=1560430912)
Note:
For efficiency, logs are uploaded in batches via a queue.
Hence, if you log a lot of data, you may experience slight delays in Neptune web application.
Note:
Passing ``x`` coordinate as NaN or +/-inf causes this log entry to be ignored.
Warning is printed to ``stdout``.
"""
x, y = self._get_valid_x_y(x, y)
if x is not None and is_nan_or_inf(x):
x = None
if not isinstance(y, six.string_types):
raise InvalidChannelValue(expected_type="str", actual_type=type(y).__name__)
if x is not None and is_nan_or_inf(x):
_logger.warning(
"Invalid metric x-coordinate: %s for channel %s. "
"Metrics with nan or +/-inf x-coordinates will not be sent to server",
x,
log_name,
)
else:
value = ChannelValue(x, dict(text_value=y), timestamp)
self._channels_values_sender.send(log_name, ChannelType.TEXT.value, value)
    def send_image(self, channel_name, x, y=None, name=None, description=None, timestamp=None):
        """Log image data in Neptune.

        Alias for :meth:`~neptune.experiments.Experiment.log_image`
        """
        return self.log_image(channel_name, x, y, name, description, timestamp)
def log_image(self, log_name, x, y=None, image_name=None, description=None, timestamp=None):
"""Log image data in Neptune
| If a log with provided ``log_name`` does not exist, it is created automatically.
| If log exists (determined by ``log_name``), then new value is appended to it.
Args:
log_name (:obj:`str`): The name of log, i.e. `bboxes`, `visualisations`, `sample_images`.
x (:obj:`double`): Depending, whether ``y`` parameter is passed:
* ``y`` not passed: The value of the log (data-point). See ``y`` parameter.
* ``y`` passed: Index of log entry being appended. Must be strictly increasing.
y (multiple types supported, optional, default is ``None``):
The value of the log (data-point). Can be one of the following types:
* :obj:`PIL image`
`Pillow docs <https://pillow.readthedocs.io/en/latest/reference/Image.html#image-module>`_
* :obj:`matplotlib.figure.Figure`
`Matplotlib 3.1.1 docs <https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.figure.Figure.html>`_
* :obj:`str` - path to image file
* 2-dimensional :obj:`numpy.array` with values in the [0, 1] range - interpreted as grayscale image
* 3-dimensional :obj:`numpy.array` with values in the [0, 1] range - behavior depends on last dimension
* if last dimension is 1 - interpreted as grayscale image
* if last dimension is 3 - interpreted as RGB image
* if last dimension is 4 - interpreted as RGBA image
* :obj:`torch.tensor` with values in the [0, 1] range.
:obj:`torch.tensor` is converted to :obj:`numpy.array` via `.numpy()` method and logged.
* :obj:`tensorflow.tensor` with values in [0, 1] range.
:obj:`tensorflow.tensor` is converted to :obj:`numpy.array` via `.numpy()` method and logged.
image_name (:obj:`str`, optional, default is ``None``): Image name
description (:obj:`str`, optional, default is ``None``): Image description
timestamp (:obj:`time`, optional, default is ``None``):
Timestamp to be associated with log entry. Must be Unix time.
If ``None`` is passed, `time.time() <https://docs.python.org/3.6/library/time.html#time.time>`_
(Python 3.6 example) is invoked to obtain timestamp.
Example:
Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:
.. code:: python3
# path to image file
experiment.log_image('bbox_images', 'pictures/image.png')
experiment.log_image('bbox_images', x=5, 'pictures/image.png')
experiment.log_image('bbox_images', 'pictures/image.png', image_name='difficult_case')
# PIL image
img = PIL.Image.new('RGB', (60, 30), color = 'red')
experiment.log_image('fig', img)
# 2d numpy array
array = numpy.random.rand(300, 200)*255
experiment.log_image('fig', array)
# 3d grayscale array
array = numpy.random.rand(300, 200, 1)*255
experiment.log_image('fig', array)
# 3d RGB array
array = numpy.random.rand(300, 200, 3)*255
experiment.log_image('fig', array)
# 3d RGBA array
array = numpy.random.rand(300, 200, 4)*255
experiment.log_image('fig', array)
# torch tensor
tensor = torch.rand(10, 20)
experiment.log_image('fig', tensor)
# tensorflow tensor
tensor = tensorflow.random.uniform(shape=[10, 20])
experiment.log_image('fig', tensor)
# matplotlib figure example 1
from matplotlib import pyplot
pyplot.plot([1, 2, 3, 4])
pyplot.ylabel('some numbers')
experiment.log_image('plots', plt.gcf())
# matplotlib figure example 2
from matplotlib import pyplot
import numpy
numpy.random.seed(19680801)
data = numpy.random.randn(2, 100)
figure, axs = pyplot.subplots(2, 2, figsize=(5, 5))
axs[0, 0].hist(data[0])
axs[1, 0].scatter(data[0], data[1])
axs[0, 1].plot(data[0], data[1])
axs[1, 1].hist2d(data[0], data[1])
experiment.log_image('diagrams', figure)
Note:
For efficiency, logs are uploaded in batches via a queue.
Hence, if you log a lot of data, you may experience slight delays in Neptune web application.
Note:
Passing ``x`` coordinate as NaN or +/-inf causes this log entry to be ignored.
Warning is printed to ``stdout``.
Warning:
Only images up to 15MB are supported. Larger files will not be logged to Neptune.
"""
x, y = self._get_valid_x_y(x, y)
if x is not None and is_nan_or_inf(x):
x = None
image_content = get_image_content(y)
if len(image_content) > self.IMAGE_SIZE_LIMIT_MB * 1024 * 1024:
_logger.warning(
"Your image is larger than %dMB. Neptune supports logging images smaller than %dMB. "
"Resize or increase compression of this image",
self.IMAGE_SIZE_LIMIT_MB,
self.IMAGE_SIZE_LIMIT_MB,
)
image_content = None
input_image = dict(name=image_name, description=description)
if image_content:
input_image["data"] = base64.b64encode(image_content).decode("utf-8")
if x is not None and is_nan_or_inf(x):
_logger.warning(
"Invalid metric x-coordinate: %s for channel %s. "
"Metrics with nan or +/-inf x-coordinates will not be sent to server",
x,
log_name,
)
else:
value = ChannelValue(x, dict(image_value=input_image), timestamp)
self._channels_values_sender.send(log_name, ChannelType.IMAGE.value, value)
    def send_artifact(self, artifact, destination=None):
        """Save an artifact (file) in experiment storage.

        Alias for :meth:`~neptune.experiments.Experiment.log_artifact`
        """
        return self.log_artifact(artifact, destination)
    def log_artifact(self, artifact, destination=None):
        """Save an artifact (file) in experiment storage.

        Args:
            artifact (:obj:`str` or :obj:`IO object`):
                A path to the file in local filesystem or IO object. It can be open
                file descriptor or in-memory buffer like `io.StringIO` or `io.BytesIO`.
            destination (:obj:`str`, optional, default is ``None``):
                A destination path.
                If ``None`` is passed, an artifact file name will be used.

        Note:
            If you use in-memory buffers like `io.StringIO` or `io.BytesIO`, remember that in typical case when you
            write to such a buffer, its current position is set to the end of the stream, so in order to read its
            content, you need to move its position back to the beginning.
            We recommend calling seek(0) on the in-memory buffers before passing them to Neptune.
            Additionally, if you provide `io.StringIO`, it will be encoded in 'utf-8' before being sent to Neptune.

        Raises:
            `FileNotFound`: When ``artifact`` file was not found.
            `StorageLimitReached`: When storage limit in the project has been reached.

        Example:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

            .. code:: python3

                # simple use
                experiment.log_artifact('images/wrong_prediction_1.png')

                # save file in other directory
                experiment.log_artifact('images/wrong_prediction_1.png', 'validation/images/wrong_prediction_1.png')

                # save file under different name
                experiment.log_artifact('images/wrong_prediction_1.png', 'images/my_image_1.png')
        """
        self._backend.log_artifact(self, artifact, destination)
    def delete_artifacts(self, path):
        """Removes an artifact(s) (file/directory) from the experiment storage.

        Args:
            path (:obj:`list` or :obj:`str`): Path or list of paths to remove from the experiment's output

        Raises:
            `FileNotFound`: If a path in experiment artifacts does not exist.

        Examples:
            Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`.

            .. code:: python3

                experiment.delete_artifacts('forest_results.pkl')
                experiment.delete_artifacts(['forest_results.pkl', 'directory'])
                experiment.delete_artifacts('')
        """
        self._backend.delete_artifacts(self, path)
def download_artifact(self, path, destination_dir=None):
    """Download a single artifact file from the experiment storage.

    Args:
        path (:obj:`str`): Path of the file inside the experiment artifacts.
        destination_dir (:obj:`str`, optional): Directory to save the file in.
            Defaults to the current working directory when ``None``.

    Raises:
        `NotADirectory`: When ``destination_dir`` is not a directory.
        `FileNotFound`: If ``path`` does not exist in the experiment artifacts.

    Examples:
        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            experiment.download_artifact('forest_results.pkl', '/home/user/files/')
    """
    return self._backend.download_artifact(self, path, destination_dir)
def download_sources(self, path=None, destination_dir=None):
    """Download experiment sources (a subdirectory or single file) as a ZIP archive.

    The archive is named after the downloaded directory or file, with a
    ``.zip`` extension appended.

    Args:
        path (:obj:`str`, optional): Directory or file within the experiment
            sources to download. When ``None``, all source files are downloaded.
        destination_dir (:obj:`str`, optional): Directory the archive is saved
            to. Defaults to the current working directory when ``None``.

    Raises:
        `NotADirectory`: When ``destination_dir`` is not a directory.
        `FileNotFound`: If ``path`` does not exist in the experiment sources.

    Examples:
        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            experiment.download_sources()
            experiment.download_sources('src/my-module')
            experiment.download_sources(destination_dir='/tmp/sources/')
            experiment.download_sources('src/my-module', 'sources/')
    """
    return self._backend.download_sources(self, path, destination_dir)
def download_artifacts(self, path=None, destination_dir=None):
    """Download experiment artifacts (a subdirectory or single file) as a ZIP archive.

    The archive is named after the downloaded directory or file, with a
    ``.zip`` extension appended.

    Args:
        path (:obj:`str`, optional): Directory or file within the experiment
            artifacts to download. When ``None``, all artifacts are downloaded.
        destination_dir (:obj:`str`, optional): Directory the archive is saved
            to. Defaults to the current working directory when ``None``.

    Raises:
        `NotADirectory`: When ``destination_dir`` is not a directory.
        `FileNotFound`: If ``path`` does not exist in the experiment artifacts.

    Examples:
        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            experiment.download_artifacts()
            experiment.download_artifacts('data/images')
            experiment.download_artifacts(destination_dir='/tmp/artifacts/')
            experiment.download_artifacts('data/images', 'artifacts/')
    """
    return self._backend.download_artifacts(self, path, destination_dir)
def reset_log(self, log_name):
    """Clear all values from the named log so it can be reused from scratch.

    Args:
        log_name (:obj:`str`): The name of the log to reset.

    Raises:
        `ChannelDoesNotExist`: When no log named ``log_name`` exists on the server.

    Example:
        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            experiment.reset_log('my_metric')
    """
    user_channel = self._find_channel(log_name, ChannelNamespace.USER)
    if user_channel is None:
        raise ChannelDoesNotExist(self.id, log_name)
    self._backend.reset_channel(self, user_channel.id, log_name, user_channel.channelType)
def get_parameters(self):
    """Return this experiment's parameters as a name-to-value dict.

    Values declared as ``double`` on the backend are converted to ``float``.

    Returns:
        :obj:`dict` - dictionary mapping a parameter name to its value.
    """
    experiment = self._backend.get_experiment(self.internal_id)
    return {
        parameter.name: self._convert_parameter_value(parameter.value, parameter.parameterType)
        for parameter in experiment.parameters
    }
def get_properties(self):
    """Return the user-defined experiment properties.

    Returns:
        :obj:`dict` - dictionary mapping a property key to its value.
    """
    experiment = self._backend.get_experiment(self.internal_id)
    return {prop.key: prop.value for prop in experiment.properties}
def set_property(self, key, value):
    """Create or overwrite the experiment property ``key`` with ``value``.

    Args:
        key (:obj:`str`): Property key.
        value: New value of the property.
    """
    return self._backend.set_property(experiment=self, key=key, value=value)
def remove_property(self, key):
    """Delete the experiment property stored under ``key``.

    Args:
        key (:obj:`str`): Key of the property to remove.
    """
    return self._backend.remove_property(experiment=self, key=key)
def get_hardware_utilization(self):
    """Return hardware (RAM/CPU/GPU) utilization metrics as a DataFrame.

    Columns come in ``(x_*, y_*)`` pairs where ``x`` is milliseconds since the
    experiment start: ``x_ram``/``y_ram`` (GB), ``x_cpu``/``y_cpu`` (percent
    0-100), and per-GPU ``x_gpu_util_<i>``/``y_gpu_util_<i>`` (percent) plus
    ``x_gpu_mem_<i>``/``y_gpu_mem_<i>`` (GB), one set per GPU index.
    The frame may contain ``NaN`` where one metric has more samples than
    another.

    Returns:
        :obj:`pandas.DataFrame` - the utilization metrics, or an empty
        DataFrame when no metrics were recorded.
    """
    raw_csv = self._backend.get_metrics_csv(self)
    try:
        return pd.read_csv(raw_csv)
    except EmptyDataError:
        # No hardware metrics were collected for this experiment.
        return pd.DataFrame()
def get_numeric_channels_values(self, *channel_names):
    """Retrieve values of specified metrics (numeric logs).

    The returned
    `pandas.DataFrame <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html>`_
    contains 1 additional column `x` along with the requested metrics.

    Args:
        *channel_names (one or more :obj:`str`): comma-separated metric names.

    Returns:
        :obj:`pandas.DataFrame` - DataFrame containing values for the requested metrics.

        | The returned DataFrame may contain ``NaN`` s if one of the metrics has more values than others.

    Example:
        Invoking ``get_numeric_channels_values('loss', 'auc')`` returns DataFrame with columns
        `x`, `loss`, `auc`.

        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            batch_channels = experiment.get_numeric_channels_values('batch-1-loss', 'batch-2-metric')
            epoch_channels = experiment.get_numeric_channels_values('epoch-1-loss', 'epoch-2-metric')

    Note:
        It's good idea to get metrics with common temporal pattern (like iteration or batch/epoch number).
        Thanks to this each row of returned DataFrame has metrics from the same moment in experiment.
        For example, combine epoch metrics to one DataFrame and batch metrics to the other.
    """
    channels_data = {}
    channels_by_name = self._backend.get_channels(self)
    for channel_name in channel_names:
        # A KeyError here means the requested channel does not exist on the server.
        channel_id = channels_by_name[channel_name].id
        try:
            # Each channel is fetched as a headerless two-column CSV of (x, y) points.
            channels_data[channel_name] = pd.read_csv(
                self._backend.get_channel_points_csv(self, channel_id, channel_name),
                header=None,
                names=["x_{}".format(channel_name), "y_{}".format(channel_name)],
                dtype=float,
            )
        except EmptyDataError:
            # The channel exists but has no points yet - represent it as an empty frame
            # with the same column layout so concatenation below stays uniform.
            channels_data[channel_name] = pd.DataFrame(
                columns=["x_{}".format(channel_name), "y_{}".format(channel_name)],
                dtype=float,
            )
    # align_channels_on_x presumably merges the per-channel frames on their
    # shared x axis into the single `x` column - confirm against its definition.
    return align_channels_on_x(pd.concat(channels_data.values(), axis=1, sort=False))
def _start(
    self,
    abort_callback=None,
    logger=None,
    upload_stdout=True,
    upload_stderr=True,
    send_hardware_metrics=True,
    run_monitoring_thread=True,
    handle_uncaught_exceptions=True,
):
    """Start the execution context, forwarding every monitoring switch unchanged."""
    options = {
        "abort_callback": abort_callback,
        "logger": logger,
        "upload_stdout": upload_stdout,
        "upload_stderr": upload_stderr,
        "send_hardware_metrics": send_hardware_metrics,
        "run_monitoring_thread": run_monitoring_thread,
        "handle_uncaught_exceptions": handle_uncaught_exceptions,
    }
    self._execution_context.start(**options)
def stop(self, exc_tb=None):
    """Marks experiment as finished (succeeded or failed).

    Args:
        exc_tb (:obj:`str`, optional, default is ``None``): Additional traceback information
            to be stored in experiment details in case of failure (stacktrace, etc).
            If this argument is ``None`` the experiment will be marked as succeeded.
            Otherwise, experiment will be marked as failed.

    Examples:
        Assuming that `experiment` is an instance of :class:`~neptune.experiments.Experiment`:

        .. code:: python3

            # Marks experiment as succeeded
            experiment.stop()

            # Assuming 'ex' is some exception,
            # it marks experiment as failed with exception info in experiment details.
            experiment.stop(str(ex))
    """
    # Flush any channel values still queued for sending before finishing.
    self._channels_values_sender.join()
    try:
        if exc_tb is not None:
            self._backend.mark_failed(self, exc_tb)
    except ExperimentAlreadyFinished:
        # The experiment was already finished (e.g. stopped twice) - ignore.
        pass
    self._execution_context.stop()
    # Deregister from the owning project - presumably so the experiment is not
    # stopped again during interpreter shutdown; confirm against Project.
    self._project._remove_stopped_experiment(self)
def __enter__(self):
    """Enter the context manager, returning the experiment itself."""
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    # Stop the experiment on context exit; if an exception escaped the
    # with-block, mark it failed with the formatted traceback and exception repr.
    if exc_tb is None:
        self.stop()
    else:
        self.stop("\n".join(traceback.format_tb(exc_tb)) + "\n" + repr(exc_val))
def __str__(self):
    """Return a short human-readable form, e.g. ``Experiment(SAN-1)``."""
    return f"Experiment({self.id})"
def __repr__(self):
    # Reuse __str__ so both representations stay consistent.
    return str(self)
def __eq__(self, o):
    """Compare experiments by short id, internal id and owning project.

    Returns ``NotImplemented`` for objects of other types so Python can fall
    back to the other operand's comparison (the previous implementation
    raised ``AttributeError`` when compared to a non-Experiment).
    """
    if not isinstance(o, self.__class__):
        return NotImplemented
    return self._id == o._id and self._internal_id == o._internal_id and self._project == o._project
def __ne__(self, o):
    """Inverse of ``__eq__``, propagating ``NotImplemented`` instead of negating it."""
    result = self.__eq__(o)
    if result is NotImplemented:
        return result
    return not result
@staticmethod
def _convert_parameter_value(value, parameter_type):
    """Cast a raw parameter value based on its backend-declared type."""
    # Only "double" parameters need conversion; everything else is passed through.
    return float(value) if parameter_type == "double" else value
@staticmethod
def _get_valid_x_y(x, y):
    """Normalize the ``(x, y)`` arguments accepted by the ``log_*`` methods.

    Callers may pass only a y value (positionally or by name), or both x and
    y. When both are given, x must be convertible to float.

    Raises:
        `NoChannelValue`: When neither x nor y was provided.
        `InvalidChannelValue`: When x is given but is not float-like.
    """
    if x is None and y is None:
        raise NoChannelValue()
    if x is None:
        return None, y
    if y is None:
        # A single positional argument is treated as the y value.
        return None, x
    if not is_float(x):
        raise InvalidChannelValue(expected_type="float", actual_type=type(x).__name__)
    return x, y
def _send_channels_values(self, channels_with_values):
    # Forward a batch of channel values straight to the backend client.
    self._backend.send_channels_values(self, channels_with_values)
def _get_channels(self, channels_names_with_types):
    """Resolve (name, type) pairs to channel objects, creating any that are missing."""
    existing = self._backend.get_channels(self)
    resolved = {}
    for name, channel_type in channels_names_with_types:
        channel = existing.get(name)
        if channel is None:
            channel = self._create_channel(name, channel_type)
        resolved[channel.name] = channel
    return resolved
def _get_channel(self, channel_name, channel_type, channel_namespace=ChannelNamespace.USER):
    """Return the named channel, creating it when it does not exist yet."""
    found = self._find_channel(channel_name, channel_namespace)
    if found is not None:
        return found
    return self._create_channel(channel_name, channel_type, channel_namespace)
def _find_channel(self, channel_name, channel_namespace):
    """Look up a channel by name in the given namespace; return None when absent."""
    if channel_namespace == ChannelNamespace.USER:
        channels = self._backend.get_channels(self)
    elif channel_namespace == ChannelNamespace.SYSTEM:
        channels = self._get_system_channels()
    else:
        raise RuntimeError("Unknown channel namespace {}".format(channel_namespace))
    return channels.get(channel_name, None)
def _create_channel(self, channel_name, channel_type, channel_namespace=ChannelNamespace.USER):
    """Create a channel in the requested namespace via the backend."""
    if channel_namespace == ChannelNamespace.USER:
        factory = self._backend.create_channel
    elif channel_namespace == ChannelNamespace.SYSTEM:
        factory = self._backend.create_system_channel
    else:
        raise RuntimeError("Unknown channel namespace {}".format(channel_namespace))
    return factory(self, channel_name, channel_type)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["GitInfo"]
from neptune.common.git_info import GitInfo
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# psutil is an optional dependency: process-tree abort support below degrades
# gracefully (requirements_installed() returns False) when it is missing.
try:
    import psutil

    PSUTIL_INSTALLED = True
except ImportError:
    PSUTIL_INSTALLED = False
class CustomAbortImpl(object):
    """Abort strategy that simply invokes a user-supplied callable."""

    def __init__(self, runnable):
        # Name-mangled to keep the callback private to this implementation.
        self.__runnable = runnable

    def abort(self):
        """Run the registered abort callback."""
        self.__runnable()
class DefaultAbortImpl(object):
    """Abort strategy that terminates the experiment's process tree via psutil."""

    # Seconds to wait for graceful termination before escalating to kill().
    KILL_TIMEOUT = 5

    def __init__(self, pid):
        self._pid = pid

    @staticmethod
    def requirements_installed():
        """Report whether the optional psutil dependency is available."""
        return PSUTIL_INSTALLED

    def abort(self):
        """Terminate the process and all its children, killing survivors after a timeout."""
        try:
            victims = self._get_process_with_children(psutil.Process(self._pid))
        except psutil.NoSuchProcess:
            victims = []
        for proc in victims:
            self._abort(proc)
        _, still_alive = psutil.wait_procs(victims, timeout=self.KILL_TIMEOUT)
        for proc in still_alive:
            self._kill(proc)

    @staticmethod
    def _get_process_with_children(process):
        # Include all (recursive) children so the whole tree is torn down.
        try:
            return [process] + process.children(recursive=True)
        except psutil.NoSuchProcess:
            return []

    @staticmethod
    def _abort(process):
        try:
            process.terminate()
        except psutil.NoSuchProcess:
            pass

    def _kill(self, process):
        # Re-enumerate the tree: children may have been spawned since terminate().
        for child in self._get_process_with_children(process):
            try:
                if child.is_running():
                    child.kill()
            except psutil.NoSuchProcess:
                pass
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["HostedNeptuneBackendApiClient", "OfflineBackendApiClient"]
from neptune.legacy.internal.api_clients.hosted_api_clients.hosted_backend_api_client import (
HostedNeptuneBackendApiClient,
)
from neptune.legacy.internal.api_clients.offline_backend import OfflineBackendApiClient
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.legacy.backend import BackendApiClient
from neptune.legacy.exceptions import InvalidNeptuneBackend
from neptune.legacy.internal.api_clients import (
HostedNeptuneBackendApiClient,
OfflineBackendApiClient,
)
def backend_factory(*, backend_name, api_token=None, proxies=None) -> BackendApiClient:
    """Instantiate the backend API client matching *backend_name*.

    Args:
        backend_name: ``"offline"`` selects the no-op offline client;
            ``None`` selects the default hosted backend.
        api_token: Optional API token forwarded to the hosted client.
        proxies: Optional proxy configuration forwarded to the hosted client.

    Raises:
        InvalidNeptuneBackend: For any other backend name.
    """
    if backend_name is None:
        return HostedNeptuneBackendApiClient(api_token, proxies)
    if backend_name == "offline":
        return OfflineBackendApiClient()
    raise InvalidNeptuneBackend(backend_name)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
@dataclasses.dataclass
class MultipartConfig:
    """Server-advertised limits for multipart file uploads.

    Sizes are presumably expressed in bytes - confirm against the backend API.
    """

    min_chunk_size: int
    max_chunk_size: int
    max_chunk_count: int
    max_single_part_size: int
class ClientConfig(object):
    """Immutable view of the server-provided client configuration.

    Bundles the API endpoints, the client-version compatibility window and
    the multipart-upload limits, exposing each through a read-only property.
    """

    def __init__(
        self,
        api_url,
        display_url,
        min_recommended_version,
        min_compatible_version,
        max_compatible_version,
        multipart_config,
    ):
        self._api_url = api_url
        self._display_url = display_url
        self._min_recommended_version = min_recommended_version
        self._min_compatible_version = min_compatible_version
        self._max_compatible_version = max_compatible_version
        self._multipart_config = multipart_config

    @property
    def api_url(self):
        """Base URL of the backend API."""
        return self._api_url

    @property
    def display_url(self):
        """URL used when presenting links to the user."""
        return self._display_url

    @property
    def min_recommended_version(self):
        """Lowest client version the server recommends."""
        return self._min_recommended_version

    @property
    def min_compatible_version(self):
        """Lowest client version the server still accepts."""
        return self._min_compatible_version

    @property
    def max_compatible_version(self):
        """Highest client version the server accepts."""
        return self._max_compatible_version

    @property
    def multipart_config(self):
        """Multipart upload limits, see :class:`MultipartConfig`."""
        return self._multipart_config
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import logging
import os
from neptune.legacy import envs
from neptune.legacy.api_exceptions import InvalidApiKey
from neptune.legacy.constants import (
ANONYMOUS,
ANONYMOUS_API_TOKEN,
)
from neptune.legacy.exceptions import NeptuneMissingApiTokenException
_logger = logging.getLogger(__name__)
class Credentials(object):
    """Parse and validate a Neptune API token.

    The token is taken from the ``api_token`` argument or, when that is
    ``None``, from the ``NEPTUNE_API_TOKEN`` environment variable. The
    anonymous alias is translated into the real anonymous token. The token is
    base64-encoded JSON from which the API addresses are extracted.

    Args:
        api_token (str): The secret Neptune API token. For security reasons,
            prefer supplying it via the environment variable; a token passed
            explicitly is overridden by nothing, but ``None`` falls back to
            the environment.

    Raises:
        NeptuneMissingApiTokenException: When no token could be resolved.
        InvalidApiKey: When the token cannot be decoded.
    """

    def __init__(self, api_token=None):
        if api_token is None:
            api_token = os.getenv(envs.API_TOKEN_ENV_NAME)
        if api_token == ANONYMOUS:
            # Replace the user-friendly alias with the actual anonymous token.
            api_token = ANONYMOUS_API_TOKEN
        self._api_token = api_token
        if self._api_token is None:
            raise NeptuneMissingApiTokenException()
        decoded = self._api_token_to_dict(self._api_token)
        # "api_address" is mandatory inside the token; "api_url" is optional.
        self._token_origin_address = decoded["api_address"]
        self._api_url = decoded.get("api_url")

    @property
    def api_token(self):
        return self._api_token

    @property
    def token_origin_address(self):
        return self._token_origin_address

    @property
    def api_url_opt(self):
        return self._api_url

    @staticmethod
    def _api_token_to_dict(api_token):
        # The token is base64-encoded JSON; any decoding failure means the key
        # is malformed.
        try:
            return json.loads(base64.b64decode(api_token.encode()).decode("utf-8"))
        except Exception:
            raise InvalidApiKey()
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import math
import os
import re
import sys
import time
from collections import namedtuple
from http.client import NOT_FOUND
from io import StringIO
from itertools import groupby
from pathlib import Path
from typing import (
TYPE_CHECKING,
Dict,
List,
)
import requests
import six
from bravado.exception import HTTPNotFound
import neptune.exceptions as alpha_exceptions
from neptune.attributes import constants as alpha_consts
from neptune.attributes.constants import (
MONITORING_TRACEBACK_ATTRIBUTE_PATH,
SYSTEM_FAILED_ATTRIBUTE_PATH,
)
from neptune.common import exceptions as common_exceptions
from neptune.common.exceptions import ClientHttpError
from neptune.common.storage.storage_utils import normalize_file_name
from neptune.common.utils import (
NoopObject,
assure_directory_exists,
)
from neptune.internal import operation as alpha_operation
from neptune.internal.backends import hosted_file_operations as alpha_hosted_file_operations
from neptune.internal.backends.api_model import AttributeType
from neptune.internal.backends.operation_api_name_visitor import OperationApiNameVisitor as AlphaOperationApiNameVisitor
from neptune.internal.backends.operation_api_object_converter import (
OperationApiObjectConverter as AlphaOperationApiObjectConverter,
)
from neptune.internal.backends.utils import handle_server_raw_response_messages
from neptune.internal.operation import (
AssignBool,
AssignString,
ConfigFloatSeries,
LogFloats,
LogStrings,
)
from neptune.internal.utils import (
base64_decode,
base64_encode,
)
from neptune.internal.utils import paths as alpha_path_utils
from neptune.internal.utils.paths import parse_path
from neptune.legacy.api_exceptions import (
ExperimentNotFound,
ExperimentOperationErrors,
NotebookNotFound,
PathInExperimentNotFound,
ProjectNotFound,
)
from neptune.legacy.backend import LeaderboardApiClient
from neptune.legacy.checkpoint import Checkpoint
from neptune.legacy.exceptions import (
DeleteArtifactUnsupportedInAlphaException,
DownloadArtifactsUnsupportedException,
DownloadArtifactUnsupportedException,
DownloadSourcesException,
FileNotFound,
NeptuneException,
)
from neptune.legacy.experiments import Experiment
from neptune.legacy.internal.api_clients.hosted_api_clients.mixins import HostedNeptuneMixin
from neptune.legacy.internal.api_clients.hosted_api_clients.utils import legacy_with_api_exceptions_handler
from neptune.legacy.internal.channels.channels import (
ChannelNamespace,
ChannelType,
ChannelValueType,
)
from neptune.legacy.internal.utils.alpha_integration import (
AlphaChannelDTO,
AlphaChannelWithValueDTO,
AlphaParameterDTO,
AlphaPropertyDTO,
channel_type_to_clear_operation,
channel_type_to_operation,
channel_value_type_to_operation,
deprecated_img_to_alpha_image,
)
from neptune.legacy.internal.websockets.reconnecting_websocket_factory import ReconnectingWebsocketFactory
from neptune.legacy.model import (
ChannelWithLastValue,
LeaderboardEntry,
)
from neptune.legacy.notebook import Notebook
_logger = logging.getLogger(__name__)
# Plain record mirroring the legacy API's experiment payload.
LegacyExperiment = namedtuple(
    "LegacyExperiment",
    [
        "shortId",
        "name",
        "timeOfCreation",
        "timeOfCompletion",
        "runningTime",
        "owner",
        "storageSize",
        "channelsSize",
        "tags",
        "description",
        "hostname",
        "state",
        "properties",
        "parameters",
    ],
)
# Plain record mirroring the legacy API's leaderboard entry payload.
# NOTE: the typename previously read "LegacyExperiment" (copy-paste slip),
# which produced a misleading repr and broke pickling (pickle resolves
# namedtuples by typename); it now matches the bound name.
LegacyLeaderboardEntry = namedtuple(
    "LegacyLeaderboardEntry",
    [
        "id",
        "organizationName",
        "projectName",
        "shortId",
        "name",
        "state",
        "timeOfCreation",
        "timeOfCompletion",
        "runningTime",
        "owner",
        "size",
        "tags",
        "description",
        "channelsLastValues",
        "parameters",
        "properties",
    ],
)
if TYPE_CHECKING:
from neptune.legacy.internal.api_clients import HostedNeptuneBackendApiClient
class HostedAlphaLeaderboardApiClient(HostedNeptuneMixin, LeaderboardApiClient):
@legacy_with_api_exceptions_handler
def __init__(self, backend_api_client: "HostedNeptuneBackendApiClient"):
    """Build a leaderboard API client on top of an authenticated backend client."""
    self._backend_api_client = backend_api_client
    self._client_config = self._create_client_config(
        api_token=self.credentials.api_token, backend_client=self.backend_client
    )
    # The leaderboard API publishes its own swagger definition per deployment.
    self.leaderboard_swagger_client = self._get_swagger_client(
        "{}/api/leaderboard/swagger.json".format(self._client_config.api_url),
        self._backend_api_client.http_client,
    )
    if sys.version_info >= (3, 7):
        try:
            # Swagger/HTTP clients must not be reused across fork(); swap in a
            # no-op object in forked children (see _handle_fork_in_child).
            os.register_at_fork(after_in_child=self._handle_fork_in_child)
        except AttributeError:
            # register_at_fork is unavailable on this platform (e.g. Windows).
            pass
def _handle_fork_in_child(self):
    # Replace the swagger client with a no-op stand-in after fork(); the
    # inherited HTTP machinery is not safe to use in the child process.
    self.leaderboard_swagger_client = NoopObject()
@property
def http_client(self):
    """HTTP client shared with the underlying backend API client."""
    return self._backend_api_client.http_client

@property
def backend_client(self):
    """Low-level backend client this leaderboard client delegates to."""
    return self._backend_api_client.backend_client

@property
def authenticator(self):
    """Authenticator of the underlying backend API client."""
    return self._backend_api_client.authenticator

@property
def credentials(self):
    """API credentials of the underlying backend API client."""
    return self._backend_api_client.credentials

@property
def backend_swagger_client(self):
    """Swagger client for the backend API (distinct from the leaderboard one)."""
    return self._backend_api_client.backend_swagger_client

@property
def client_lib_version(self):
    """Version string of this client library, as reported by the backend client."""
    return self._backend_api_client.client_lib_version

@property
def api_address(self):
    """Base API URL resolved from the fetched client configuration."""
    return self._client_config.api_url

@property
def display_address(self):
    """User-facing URL of the underlying backend API client."""
    return self._backend_api_client.display_address

@property
def proxies(self):
    """Proxy configuration of the underlying backend API client."""
    return self._backend_api_client.proxies
@legacy_with_api_exceptions_handler
def get_project_members(self, project_identifier):
    """Return the member list of the given project.

    Raises:
        ProjectNotFound: When the project does not exist.
    """
    try:
        response = self.backend_swagger_client.api.listProjectMembers(
            projectIdentifier=project_identifier
        ).response()
        return response.result
    except HTTPNotFound:
        raise ProjectNotFound(project_identifier)
@legacy_with_api_exceptions_handler
def create_experiment(
    self,
    project,
    name,
    description,
    params,
    properties,
    tags,
    abortable,  # deprecated in alpha
    monitored,  # deprecated in alpha
    git_info,
    hostname,
    entrypoint,
    notebook_id,
    checkpoint_id,
):
    """Create a new experiment in *project* and push its initial attributes.

    Validates argument types, translates ``git_info`` into the API payload
    shape, creates the experiment via the leaderboard API, then sends the
    initial attributes (name, description, params, properties, tags,
    hostname, entrypoint) as a batch of operations.

    Raises:
        ValueError: When an argument has an unexpected type.
        ProjectNotFound: When the target project does not exist.
    """
    # Fail fast on obviously wrong argument types before hitting the API.
    if not isinstance(name, six.string_types):
        raise ValueError("Invalid name {}, should be a string.".format(name))
    if not isinstance(description, six.string_types):
        raise ValueError("Invalid description {}, should be a string.".format(description))
    if not isinstance(params, dict):
        raise ValueError("Invalid params {}, should be a dict.".format(params))
    if not isinstance(properties, dict):
        raise ValueError("Invalid properties {}, should be a dict.".format(properties))
    if hostname is not None and not isinstance(hostname, six.string_types):
        raise ValueError("Invalid hostname {}, should be a string.".format(hostname))
    if entrypoint is not None and not isinstance(entrypoint, six.string_types):
        raise ValueError("Invalid entrypoint {}, should be a string.".format(entrypoint))
    # Reshape the GitInfo object into the nested dict the API expects.
    git_info = (
        {
            "commit": {
                "commitId": git_info.commit_id,
                "message": git_info.message,
                "authorName": git_info.author_name,
                "authorEmail": git_info.author_email,
                "commitDate": git_info.commit_date,
            },
            "repositoryDirty": git_info.repository_dirty,
            "currentBranch": git_info.active_branch,
            "remotes": git_info.remote_urls,
        }
        if git_info
        else None
    )
    api_params = {
        "notebookId": notebook_id,
        "checkpointId": checkpoint_id,
        "projectIdentifier": str(project.internal_id),
        "cliVersion": self.client_lib_version,
        "gitInfo": git_info,
        "customId": None,
    }
    # The legacy-client header lets the server distinguish this client family.
    kwargs = {
        "experimentCreationParams": api_params,
        "X-Neptune-CliVersion": self.client_lib_version,
        "_request_options": {"headers": {"X-Neptune-LegacyClient": "true"}},
    }
    try:
        api_experiment = self.leaderboard_swagger_client.api.createExperiment(**kwargs).response().result
    except HTTPNotFound:
        raise ProjectNotFound(project_identifier=project.full_id)
    experiment = self._convert_to_experiment(api_experiment, project)
    # Initialize new experiment
    init_experiment_operations = self._get_init_experiment_operations(
        name, description, params, properties, tags, hostname, entrypoint
    )
    self._execute_operations(
        experiment=experiment,
        operations=init_experiment_operations,
    )
    return experiment
def upload_source_code(self, experiment, source_target_pairs):
    """Upload experiment source files as a fresh file-set under the sources attribute."""
    attribute_path = alpha_path_utils.parse_path(alpha_consts.SOURCE_CODE_FILES_ATTRIBUTE_PATH)
    # Only the source paths matter here; the target halves of the pairs are ignored.
    globs = [source for source, _target in source_target_pairs]
    operation = alpha_operation.UploadFileSet(
        path=attribute_path,
        file_globs=globs,
        reset=True,
    )
    self._execute_upload_operations_with_400_retry(experiment, operation)
    @legacy_with_api_exceptions_handler
    def get_notebook(self, project, notebook_id):
        """Fetch a single notebook by id within a project.

        :raises NotebookNotFound: on a 404 response or when the listing is empty.
        """
        try:
            api_notebook_list = (
                self.leaderboard_swagger_client.api.listNotebooks(
                    projectIdentifier=project.internal_id, id=[notebook_id]
                )
                .response()
                .result
            )
            if not api_notebook_list.entries:
                raise NotebookNotFound(notebook_id=notebook_id, project=project.full_id)
            api_notebook = api_notebook_list.entries[0]
            return Notebook(
                backend=self,
                project=project,
                _id=api_notebook.id,
                owner=api_notebook.owner,
            )
        except HTTPNotFound:
            raise NotebookNotFound(notebook_id=notebook_id, project=project.full_id)
    @legacy_with_api_exceptions_handler
    def get_last_checkpoint(self, project, notebook_id):
        """Return the most recent checkpoint of a notebook.

        Only one entry is requested (offset=0, limit=1). NotebookNotFound is
        raised both on 404 and when the notebook has no checkpoints at all.
        """
        try:
            api_checkpoint_list = (
                self.leaderboard_swagger_client.api.listCheckpoints(notebookId=notebook_id, offset=0, limit=1)
                .response()
                .result
            )
            if not api_checkpoint_list.entries:
                raise NotebookNotFound(notebook_id=notebook_id, project=project.full_id)
            checkpoint = api_checkpoint_list.entries[0]
            return Checkpoint(checkpoint.id, checkpoint.name, checkpoint.path)
        except HTTPNotFound:
            raise NotebookNotFound(notebook_id=notebook_id, project=project.full_id)
    @legacy_with_api_exceptions_handler
    def create_notebook(self, project):
        """Create a new notebook in the project.

        :raises ProjectNotFound: when the backend returns 404 for the project.
        """
        try:
            api_notebook = (
                self.leaderboard_swagger_client.api.createNotebook(projectIdentifier=project.internal_id)
                .response()
                .result
            )
            return Notebook(
                backend=self,
                project=project,
                _id=api_notebook.id,
                owner=api_notebook.owner,
            )
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)
    @legacy_with_api_exceptions_handler
    def create_checkpoint(self, notebook_id, jupyter_path, _file=None):
        """Create a notebook checkpoint, from raw file content or as an empty one.

        With ``_file`` the content is streamed as application/octet-stream and the
        unmarshalled CheckpointDTO is returned. Without it an empty checkpoint is
        created; ``None`` is returned when the notebook does not exist.
        """
        if _file is not None:
            with self._upload_raw_data(
                api_method=self.leaderboard_swagger_client.api.createCheckpoint,
                data=_file,
                headers={"Content-Type": "application/octet-stream"},
                path_params={"notebookId": notebook_id},
                query_params={"jupyterPath": jupyter_path},
            ) as response:
                if response.status_code == NOT_FOUND:
                    raise NotebookNotFound(notebook_id=notebook_id)
                else:
                    response.raise_for_status()
                    CheckpointDTO = self.leaderboard_swagger_client.get_model("CheckpointDTO")
                    return CheckpointDTO.unmarshal(response.json())
        else:
            try:
                NewCheckpointDTO = self.leaderboard_swagger_client.get_model("NewCheckpointDTO")
                return (
                    self.leaderboard_swagger_client.api.createEmptyCheckpoint(
                        notebookId=notebook_id,
                        checkpoint=NewCheckpointDTO(path=jupyter_path),
                    )
                    .response()
                    .result
                )
            except HTTPNotFound:
                return None
    @legacy_with_api_exceptions_handler
    def get_experiment(self, experiment_id):
        """Assemble a legacy Experiment view from the alpha attribute listing.

        Properties and parameters are filtered out of the flat attribute list;
        channel sizes are not tracked by the alpha API, hence ``channelsSize=0``.
        """
        api_attributes = self._get_api_experiment_attributes(experiment_id)
        attributes = api_attributes.attributes
        system_attributes = api_attributes.systemAttributes
        return LegacyExperiment(
            shortId=system_attributes.shortId.value,
            name=system_attributes.name.value,
            timeOfCreation=system_attributes.creationTime.value,
            timeOfCompletion=None,
            runningTime=system_attributes.runningTime.value,
            owner=system_attributes.owner.value,
            storageSize=system_attributes.size.value,
            channelsSize=0,
            tags=system_attributes.tags.values,
            description=system_attributes.description.value,
            hostname=system_attributes.hostname.value if system_attributes.hostname else None,
            # alpha only distinguishes running / not-running at this level
            state="running" if system_attributes.state.value == "running" else "succeeded",
            properties=[AlphaPropertyDTO(attr) for attr in attributes if AlphaPropertyDTO.is_valid_attribute(attr)],
            parameters=[AlphaParameterDTO(attr) for attr in attributes if AlphaParameterDTO.is_valid_attribute(attr)],
        )
@legacy_with_api_exceptions_handler
def set_property(self, experiment, key, value):
"""Save attribute casted to string under `alpha_consts.PROPERTIES_ATTRIBUTE_SPACE` namespace"""
self._execute_operations(
experiment=experiment,
operations=[
alpha_operation.AssignString(
path=alpha_path_utils.parse_path(f"{alpha_consts.PROPERTIES_ATTRIBUTE_SPACE}{key}"),
value=str(value),
)
],
)
    @legacy_with_api_exceptions_handler
    def remove_property(self, experiment, key):
        """Delete the property attribute stored under the properties namespace."""
        self._remove_attribute(experiment, str_path=f"{alpha_consts.PROPERTIES_ATTRIBUTE_SPACE}{key}")
@legacy_with_api_exceptions_handler
def update_tags(self, experiment, tags_to_add, tags_to_delete):
operations = [
alpha_operation.AddStrings(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_TAGS_ATTRIBUTE_PATH),
values=tags_to_add,
),
alpha_operation.RemoveStrings(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_TAGS_ATTRIBUTE_PATH),
values=tags_to_delete,
),
]
self._execute_operations(
experiment=experiment,
operations=operations,
)
@staticmethod
def _get_channel_attribute_path(channel_name: str, channel_namespace: ChannelNamespace) -> str:
if channel_namespace == ChannelNamespace.USER:
prefix = alpha_consts.LOG_ATTRIBUTE_SPACE
else:
prefix = alpha_consts.MONITORING_ATTRIBUTE_SPACE
return f"{prefix}{channel_name}"
    def _create_channel(
        self,
        experiment: Experiment,
        channel_id: str,
        channel_name: str,
        channel_type: ChannelType,
        channel_namespace: ChannelNamespace,
    ):
        """Create a 'fake' channel in an alpha project.

        Since channels are abandoned in the alpha API, we mock them with an empty
        logging operation that only creates the backing attribute; the returned
        wrapper starts with no last value (x/y are None).
        """
        operation = channel_type_to_operation(channel_type)
        log_empty_operation = operation(
            path=alpha_path_utils.parse_path(self._get_channel_attribute_path(channel_name, channel_namespace)),
            values=[],
        )  # this operation is used to create an empty attribute
        self._execute_operations(
            experiment=experiment,
            operations=[log_empty_operation],
        )
        return ChannelWithLastValue(
            AlphaChannelWithValueDTO(
                channelId=channel_id,
                channelName=channel_name,
                channelType=channel_type.value,
                x=None,
                y=None,
            )
        )
    @legacy_with_api_exceptions_handler
    def create_channel(self, experiment, name, channel_type) -> ChannelWithLastValue:
        """Create a user (log-namespace) channel; its id is the attribute path."""
        channel_id = f"{alpha_consts.LOG_ATTRIBUTE_SPACE}{name}"
        return self._create_channel(
            experiment,
            channel_id,
            channel_name=name,
            channel_type=ChannelType(channel_type),
            channel_namespace=ChannelNamespace.USER,
        )
    def _get_channels(self, experiment) -> List[AlphaChannelDTO]:
        """Return every attribute of the experiment that is viewable as a channel DTO.

        :raises ExperimentNotFound: on a 404 for the experiment.
        """
        try:
            return [
                AlphaChannelDTO(attr)
                for attr in self._get_attributes(experiment.internal_id)
                if AlphaChannelDTO.is_valid_attribute(attr)
            ]
        except HTTPNotFound:
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
@legacy_with_api_exceptions_handler
def get_channels(self, experiment) -> Dict[str, AlphaChannelDTO]:
api_channels = [
channel
for channel in self._get_channels(experiment)
# return channels from LOG_ATTRIBUTE_SPACE namespace only
if channel.id.startswith(alpha_consts.LOG_ATTRIBUTE_SPACE)
]
return {ch.name: ch for ch in api_channels}
    @legacy_with_api_exceptions_handler
    def create_system_channel(self, experiment, name, channel_type) -> ChannelWithLastValue:
        """Create a system (monitoring-namespace) channel; its id is the attribute path."""
        channel_id = f"{alpha_consts.MONITORING_ATTRIBUTE_SPACE}{name}"
        return self._create_channel(
            experiment,
            channel_id,
            channel_name=name,
            channel_type=ChannelType(channel_type),
            channel_namespace=ChannelNamespace.SYSTEM,
        )
@legacy_with_api_exceptions_handler
def get_system_channels(self, experiment) -> Dict[str, AlphaChannelDTO]:
return {
channel.name: channel
for channel in self._get_channels(experiment)
if (
channel.channelType == ChannelType.TEXT.value
and channel.id.startswith(alpha_consts.MONITORING_ATTRIBUTE_SPACE)
)
}
@legacy_with_api_exceptions_handler
def send_channels_values(self, experiment, channels_with_values):
send_operations = []
for channel_with_values in channels_with_values:
channel_value_type = channel_with_values.channel_type
operation = channel_value_type_to_operation(channel_value_type)
if channel_value_type == ChannelValueType.IMAGE_VALUE:
# IMAGE_VALUE requires minor data modification before it's sent
data_transformer = deprecated_img_to_alpha_image
else:
# otherwise use identity function as transformer
data_transformer = lambda e: e # noqa: E731
ch_values = [
alpha_operation.LogSeriesValue(
value=data_transformer(ch_value.value),
step=ch_value.x,
ts=ch_value.ts,
)
for ch_value in channel_with_values.channel_values
]
send_operations.append(
operation(
path=alpha_path_utils.parse_path(
self._get_channel_attribute_path(
channel_with_values.channel_name,
channel_with_values.channel_namespace,
)
),
values=ch_values,
)
)
self._execute_operations(experiment, send_operations)
def mark_failed(self, experiment, traceback):
operations = []
path = parse_path(SYSTEM_FAILED_ATTRIBUTE_PATH)
traceback_values = [LogStrings.ValueType(val, step=None, ts=time.time()) for val in traceback.split("\n")]
operations.append(AssignBool(path=path, value=True))
operations.append(
LogStrings(
values=traceback_values,
path=parse_path(MONITORING_TRACEBACK_ATTRIBUTE_PATH),
)
)
self._execute_operations(experiment, operations)
    def ping_experiment(self, experiment):
        """Send a keep-alive ping for the experiment.

        :raises ExperimentNotFound: on a 404 for the experiment.
        """
        try:
            self.leaderboard_swagger_client.api.ping(experimentId=str(experiment.internal_id)).response().result
        except HTTPNotFound:
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
@staticmethod
def _get_attribute_name_for_metric(resource_type, gauge_name, gauges_count) -> str:
if gauges_count > 1:
return "monitoring/{}_{}".format(resource_type, gauge_name).lower()
return "monitoring/{}".format(resource_type).lower()
    @legacy_with_api_exceptions_handler
    def create_hardware_metric(self, experiment, metric):
        """Configure a float series (min/max/unit) for every gauge of a hardware metric."""
        operations = []
        gauges_count = len(metric.gauges)
        for gauge in metric.gauges:
            path = parse_path(self._get_attribute_name_for_metric(metric.resource_type, gauge.name(), gauges_count))
            operations.append(ConfigFloatSeries(path, min=metric.min_value, max=metric.max_value, unit=metric.unit))
        self._execute_operations(experiment, operations)
    @legacy_with_api_exceptions_handler
    def send_hardware_metric_reports(self, experiment, metrics, metric_reports):
        """Convert gauge reports into per-attribute float-series log operations and send them.

        NOTE(review): ``metrics_by_name.get`` returns None for a report whose metric
        is not in ``metrics``, which would raise AttributeError on ``metric.gauges``
        below — presumably reports always reference known metrics; confirm.
        """
        operations = []
        metrics_by_name = {metric.name: metric for metric in metrics}
        for report in metric_reports:
            metric = metrics_by_name.get(report.metric.name)
            gauges_count = len(metric.gauges)
            # group consecutive values by gauge so each gauge becomes one LogFloats batch
            for gauge_name, metric_values in groupby(report.values, lambda value: value.gauge_name):
                metric_values = list(metric_values)
                path = parse_path(self._get_attribute_name_for_metric(metric.resource_type, gauge_name, gauges_count))
                operations.append(
                    LogFloats(
                        path,
                        [LogFloats.ValueType(value.value, step=None, ts=value.timestamp) for value in metric_values],
                    )
                )
        self._execute_operations(experiment, operations)
    def log_artifact(self, experiment, artifact, destination=None):
        """Upload an artifact (local file, directory, or readable stream) to the experiment.

        Directories are walked recursively and each file is uploaded with the
        directory name (or ``destination``) as a prefix; streams require an
        explicit ``destination``.

        :raises FileNotFound: when a string path is neither a file nor a directory.
        :raises ValueError: for a stream without destination, or an unsupported type.
        """
        if isinstance(artifact, str):
            if os.path.isfile(artifact):
                target_name = os.path.basename(artifact) if destination is None else destination
                dest_path = self._get_dest_and_ext(target_name)
                operation = alpha_operation.UploadFile(
                    path=dest_path,
                    ext="",
                    file_path=os.path.abspath(artifact),
                )
            elif os.path.isdir(artifact):
                # recurse per-file; each file is uploaded through this same method
                for path, file_destination in self._log_dir_artifacts(artifact, destination):
                    self.log_artifact(experiment, path, file_destination)
                return
            else:
                raise FileNotFound(artifact)
        elif hasattr(artifact, "read"):
            if not destination:
                raise ValueError("destination is required for IO streams")
            dest_path = self._get_dest_and_ext(destination)
            data = artifact.read()
            # text streams are encoded to bytes before base64-encoding
            content = data.encode("utf-8") if isinstance(data, str) else data
            operation = alpha_operation.UploadFileContent(path=dest_path, ext="", file_content=base64_encode(content))
        else:
            raise ValueError("Artifact must be a local path or an IO object")
        self._execute_upload_operations_with_400_retry(experiment, operation)
    @staticmethod
    def _get_dest_and_ext(target_name):
        """Return the parsed artifact attribute path for ``target_name``."""
        qualified_target_name = f"{alpha_consts.ARTIFACT_ATTRIBUTE_SPACE}{target_name}"
        return alpha_path_utils.parse_path(normalize_file_name(qualified_target_name))
def _log_dir_artifacts(self, directory_path, destination):
directory_path = Path(directory_path)
prefix = directory_path.name if destination is None else destination
for path in directory_path.glob("**/*"):
if path.is_file():
relative_path = path.relative_to(directory_path)
file_destination = prefix + "/" + str(relative_path)
yield str(path), file_destination
    def delete_artifacts(self, experiment, path):
        """Remove an artifact attribute from the experiment.

        The alpha backend reports unsupported deletes as MetadataInconsistency
        errors; those are translated into DeleteArtifactUnsupportedInAlphaException.
        """
        try:
            self._remove_attribute(experiment, str_path=f"{alpha_consts.ARTIFACT_ATTRIBUTE_SPACE}{path}")
        except ExperimentOperationErrors as e:
            if all(isinstance(err, alpha_exceptions.MetadataInconsistency) for err in e.errors):
                raise DeleteArtifactUnsupportedInAlphaException(path, experiment) from None
            raise
    @legacy_with_api_exceptions_handler
    def download_data(self, experiment: Experiment, path: str, destination):
        """Stream an artifact attribute to a local file in 10 MB chunks.

        :raises PathInExperimentNotFound: when the backend returns 404 for the path.
        """
        project_storage_path = f"artifacts/{path}"
        with self._download_raw_data(
            api_method=self.leaderboard_swagger_client.api.downloadAttribute,
            headers={"Accept": "application/octet-stream"},
            path_params={},
            query_params={
                "experimentId": experiment.internal_id,
                "attribute": project_storage_path,
            },
        ) as response:
            if response.status_code == NOT_FOUND:
                raise PathInExperimentNotFound(path=path, exp_identifier=experiment.internal_id)
            else:
                response.raise_for_status()
                with open(destination, "wb") as f:
                    for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
                        # skip keep-alive (empty) chunks
                        if chunk:
                            f.write(chunk)
    def download_sources(self, experiment: Experiment, path=None, destination_dir=None):
        """Download the experiment's whole source-code FileSet into ``destination_dir``.

        :raises DownloadSourcesException: when ``path`` is given — alpha stores all
            sources as a single FileSet that can only be fetched in one piece.
        """
        if path is not None:
            # in alpha all source files stored as single FileSet must be downloaded at once
            raise DownloadSourcesException(experiment)
        path = alpha_consts.SOURCE_CODE_FILES_ATTRIBUTE_PATH
        destination_dir = assure_directory_exists(destination_dir)
        download_request = self._get_file_set_download_request(experiment.internal_id, path)
        alpha_hosted_file_operations.download_file_set_attribute(
            swagger_client=self.leaderboard_swagger_client,
            download_id=download_request.id,
            destination=destination_dir,
        )
@legacy_with_api_exceptions_handler
def _get_file_set_download_request(self, run_id: str, path: str):
params = {
"experimentId": run_id,
"attribute": path,
}
return self.leaderboard_swagger_client.api.prepareForDownloadFileSetAttributeZip(**params).response().result
    def download_artifacts(self, experiment: Experiment, path=None, destination_dir=None):
        """Bulk artifact download is not supported by the alpha backend — always raises."""
        raise DownloadArtifactsUnsupportedException(experiment)
    def download_artifact(self, experiment: Experiment, path=None, destination_dir=None):
        """Download one artifact into ``destination_dir``, named after its basename.

        NOTE(review): ``path=None`` would fail inside ``os.path.basename`` —
        presumably callers always supply a path; confirm.
        """
        destination_dir = assure_directory_exists(destination_dir)
        destination_path = os.path.join(destination_dir, os.path.basename(path))
        try:
            self.download_data(experiment, path, destination_path)
        except PathInExperimentNotFound:
            raise DownloadArtifactUnsupportedException(path, experiment) from None
    def _get_attributes(self, experiment_id) -> list:
        """Return only the (non-system) attribute list of the experiment."""
        return self._get_api_experiment_attributes(experiment_id).attributes
def _get_api_experiment_attributes(self, experiment_id):
params = {
"experimentId": experiment_id,
}
return self.leaderboard_swagger_client.api.getExperimentAttributes(**params).response().result
def _remove_attribute(self, experiment, str_path: str):
"""Removes given attribute"""
self._execute_operations(
experiment=experiment,
operations=[
alpha_operation.DeleteAttribute(
path=alpha_path_utils.parse_path(str_path),
)
],
)
@staticmethod
def _get_client_config_args(api_token):
return dict(
X_Neptune_Api_Token=api_token,
alpha="true",
)
def _execute_upload_operation(self, experiment: Experiment, upload_operation: alpha_operation.Operation):
experiment_id = experiment.internal_id
try:
if isinstance(upload_operation, alpha_operation.UploadFile):
alpha_hosted_file_operations.upload_file_attribute(
swagger_client=self.leaderboard_swagger_client,
container_id=experiment_id,
attribute=alpha_path_utils.path_to_str(upload_operation.path),
source=upload_operation.file_path,
ext=upload_operation.ext,
multipart_config=self._client_config.multipart_config,
)
elif isinstance(upload_operation, alpha_operation.UploadFileContent):
alpha_hosted_file_operations.upload_file_attribute(
swagger_client=self.leaderboard_swagger_client,
container_id=experiment_id,
attribute=alpha_path_utils.path_to_str(upload_operation.path),
source=base64_decode(upload_operation.file_content),
ext=upload_operation.ext,
multipart_config=self._client_config.multipart_config,
)
elif isinstance(upload_operation, alpha_operation.UploadFileSet):
alpha_hosted_file_operations.upload_file_set_attribute(
swagger_client=self.leaderboard_swagger_client,
container_id=experiment_id,
attribute=alpha_path_utils.path_to_str(upload_operation.path),
file_globs=upload_operation.file_globs,
reset=upload_operation.reset,
multipart_config=self._client_config.multipart_config,
)
else:
raise NeptuneException("Upload operation in neither File or FileSet")
except common_exceptions.NeptuneException as e:
raise NeptuneException(e) from e
return None
    def _execute_upload_operations_with_400_retry(
        self, experiment: Experiment, upload_operation: alpha_operation.Operation
    ):
        """Retry the upload while the backend reports the transient
        "Length of stream does not match given range" 400 error.

        NOTE(review): retries are unbounded — presumably the condition clears
        quickly, but a retry cap/backoff might be worth adding.
        """
        while True:
            try:
                return self._execute_upload_operation(experiment, upload_operation)
            except ClientHttpError as ex:
                if "Length of stream does not match given range" not in ex.response:
                    raise ex
    @legacy_with_api_exceptions_handler
    def _execute_operations(self, experiment: Experiment, operations: List[alpha_operation.Operation]):
        """Send a batch of non-upload operations to the backend in one API call.

        :raises NeptuneException: when an upload operation is passed in.
        :raises ExperimentOperationErrors: when the backend reports per-operation errors.
        :raises ExperimentNotFound: on a 404 for the experiment.
        """
        experiment_id = experiment.internal_id
        file_operations = (
            alpha_operation.UploadFile,
            alpha_operation.UploadFileContent,
            alpha_operation.UploadFileSet,
        )
        if any(isinstance(op, file_operations) for op in operations):
            raise NeptuneException(
                "File operations must be handled directly by `_execute_upload_operation`,"
                " not by `_execute_operations` function call."
            )
        # each operation serializes to {"path": ..., <api_name>: <api_payload>}
        kwargs = {
            "experimentId": experiment_id,
            "operations": [
                {
                    "path": alpha_path_utils.path_to_str(op.path),
                    AlphaOperationApiNameVisitor().visit(op): AlphaOperationApiObjectConverter().convert(op),
                }
                for op in operations
            ],
        }
        try:
            result = self.leaderboard_swagger_client.api.executeOperations(**kwargs).response().result
            errors = [alpha_exceptions.MetadataInconsistency(err.errorDescription) for err in result]
            if errors:
                raise ExperimentOperationErrors(errors=errors)
            return None
        except HTTPNotFound as e:
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            ) from e
def _get_init_experiment_operations(
self, name, description, params, properties, tags, hostname, entrypoint
) -> List[alpha_operation.Operation]:
"""Returns operations required to initialize newly created experiment"""
init_operations = list()
# Assign experiment name
init_operations.append(
alpha_operation.AssignString(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_NAME_ATTRIBUTE_PATH),
value=name,
)
)
# Assign experiment description
init_operations.append(
alpha_operation.AssignString(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_DESCRIPTION_ATTRIBUTE_PATH),
value=description,
)
)
# Assign experiment parameters
for p_name, p_val in params.items():
parameter_type, string_value = self._get_parameter_with_type(p_val)
operation_cls = alpha_operation.AssignFloat if parameter_type == "double" else alpha_operation.AssignString
init_operations.append(
operation_cls(
path=alpha_path_utils.parse_path(f"{alpha_consts.PARAMETERS_ATTRIBUTE_SPACE}{p_name}"),
value=string_value,
)
)
# Assign experiment properties
for p_key, p_val in properties.items():
init_operations.append(
AssignString(
path=alpha_path_utils.parse_path(f"{alpha_consts.PROPERTIES_ATTRIBUTE_SPACE}{p_key}"),
value=str(p_val),
)
)
# Assign tags
if tags:
init_operations.append(
alpha_operation.AddStrings(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_TAGS_ATTRIBUTE_PATH),
values=set(tags),
)
)
# Assign source hostname
if hostname:
init_operations.append(
alpha_operation.AssignString(
path=alpha_path_utils.parse_path(alpha_consts.SYSTEM_HOSTNAME_ATTRIBUTE_PATH),
value=hostname,
)
)
# Assign source entrypoint
if entrypoint:
init_operations.append(
alpha_operation.AssignString(
path=alpha_path_utils.parse_path(alpha_consts.SOURCE_CODE_ENTRYPOINT_ATTRIBUTE_PATH),
value=entrypoint,
)
)
return init_operations
    @legacy_with_api_exceptions_handler
    def reset_channel(self, experiment, channel_id, channel_name, channel_type):
        """Clear all values of a user channel (``channel_id`` is unused in alpha)."""
        op = channel_type_to_clear_operation(ChannelType(channel_type))
        attr_path = self._get_channel_attribute_path(channel_name, ChannelNamespace.USER)
        self._execute_operations(
            experiment=experiment,
            operations=[op(path=alpha_path_utils.parse_path(attr_path))],
        )
    @legacy_with_api_exceptions_handler
    def _get_channel_tuples_from_csv(self, experiment, channel_attribute_path):
        """Fetch a float-series attribute as CSV text and split it into per-row field lists.

        The trailing empty element produced by the final newline is dropped.

        :raises ExperimentNotFound: on a 404 for the experiment.
        """
        try:
            csv = (
                self.leaderboard_swagger_client.api.getFloatSeriesValuesCSV(
                    experimentId=experiment.internal_id,
                    attribute=channel_attribute_path,
                )
                .response()
                .incoming_response.text
            )
            lines = csv.split("\n")[:-1]
            return [line.split(",") for line in lines]
        except HTTPNotFound:
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
    @legacy_with_api_exceptions_handler
    def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
        """Return a StringIO of ``step,value`` lines for a user channel.

        ``channel_internal_id`` is unused here; the channel is addressed by name.
        """
        try:
            channel_attr_path = self._get_channel_attribute_path(channel_name, ChannelNamespace.USER)
            values = self._get_channel_tuples_from_csv(experiment, channel_attr_path)
            # val[0] is the step, val[2] the value (val[1] is the timestamp)
            step_and_value = [val[0] + "," + val[2] for val in values]
            csv = StringIO()
            for line in step_and_value:
                csv.write(line + "\n")
            csv.seek(0)
            return csv
        except HTTPNotFound:
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
@legacy_with_api_exceptions_handler
def get_metrics_csv(self, experiment):
metric_channels = [
channel
for channel in self._get_channels(experiment)
if (
channel.channelType == ChannelType.NUMERIC.value
and channel.id.startswith(alpha_consts.MONITORING_ATTRIBUTE_SPACE)
)
]
data = {
# val[1] + ',' + val[2] is timestamp,value
ch.name: [val[1] + "," + val[2] for val in self._get_channel_tuples_from_csv(experiment, ch.id)]
for ch in metric_channels
}
values_count = max(len(values) for values in data.values())
csv = StringIO()
csv.write(",".join(["x_{name},y_{name}".format(name=ch.name) for ch in metric_channels]))
csv.write("\n")
for i in range(0, values_count):
csv.write(",".join([data[ch.name][i] if i < len(data[ch.name]) else "," for ch in metric_channels]))
csv.write("\n")
csv.seek(0)
return csv
    @legacy_with_api_exceptions_handler
    def get_leaderboard_entries(
        self,
        project,
        entry_types=None,  # deprecated
        ids=None,
        states=None,
        owners=None,
        tags=None,
        min_running_time=None,
    ):
        """Page through the project leaderboard and return legacy entry objects.

        Legacy states are mapped onto alpha states ("running" stays, everything
        else becomes "idle"); results are fetched in pages of 100.

        :raises ProjectNotFound: on a 404 for the project.
        """
        if states is not None:
            states = [state if state == "running" else "idle" for state in states]
        try:
            def get_portion(limit, offset):
                # one leaderboard page, sorted by sys/id for stable pagination
                return (
                    self.leaderboard_swagger_client.api.getLeaderboard(
                        projectIdentifier=project.full_id,
                        shortId=ids,
                        state=states,
                        owner=owners,
                        tags=tags,
                        tagsMode="and",
                        minRunningTimeSeconds=min_running_time,
                        sortBy=["sys/id"],
                        sortFieldType=["string"],
                        sortDirection=["ascending"],
                        limit=limit,
                        offset=offset,
                    )
                    .response()
                    .result.entries
                )
            return [
                LeaderboardEntry(self._to_leaderboard_entry_dto(e)) for e in self._get_all_items(get_portion, step=100)
            ]
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)
    def websockets_factory(self, project_id, experiment_id):
        """Build a websocket factory pointed at this run's notification signal endpoint."""
        # swap the http(s) scheme for ws(s), keeping the rest of the address
        base_url = re.sub(r"^http", "ws", self.api_address) + "/api/notifications/v1"
        return ReconnectingWebsocketFactory(backend=self, url=base_url + f"/runs/{project_id}/{experiment_id}/signal")
    @staticmethod
    def _to_leaderboard_entry_dto(experiment_attributes):
        """Convert an alpha attribute listing into a legacy leaderboard entry DTO.

        Channels are series attributes in the log/monitoring namespaces that hold
        at least one value; the state collapses to "running" or "succeeded".
        """
        attributes = experiment_attributes.attributes
        system_attributes = experiment_attributes.systemAttributes
        def is_channel_namespace(name):
            # channels live under the log (user) or monitoring (system) namespaces
            return name.startswith(alpha_consts.LOG_ATTRIBUTE_SPACE) or name.startswith(
                alpha_consts.MONITORING_ATTRIBUTE_SPACE
            )
        numeric_channels = [
            HostedAlphaLeaderboardApiClient._float_series_to_channel_last_value_dto(attr)
            for attr in attributes
            if (
                attr.type == AttributeType.FLOAT_SERIES.value
                and is_channel_namespace(attr.name)
                and attr.floatSeriesProperties.last is not None
            )
        ]
        text_channels = [
            HostedAlphaLeaderboardApiClient._string_series_to_channel_last_value_dto(attr)
            for attr in attributes
            if (
                attr.type == AttributeType.STRING_SERIES.value
                and is_channel_namespace(attr.name)
                and attr.stringSeriesProperties.last is not None
            )
        ]
        return LegacyLeaderboardEntry(
            id=experiment_attributes.id,
            organizationName=experiment_attributes.organizationName,
            projectName=experiment_attributes.projectName,
            shortId=system_attributes.shortId.value,
            name=system_attributes.name.value,
            state="running" if system_attributes.state.value == "running" else "succeeded",
            timeOfCreation=system_attributes.creationTime.value,
            timeOfCompletion=None,
            runningTime=system_attributes.runningTime.value,
            owner=system_attributes.owner.value,
            size=system_attributes.size.value,
            tags=system_attributes.tags.values,
            description=system_attributes.description.value,
            channelsLastValues=numeric_channels + text_channels,
            parameters=[AlphaParameterDTO(attr) for attr in attributes if AlphaParameterDTO.is_valid_attribute(attr)],
            properties=[AlphaPropertyDTO(attr) for attr in attributes if AlphaPropertyDTO.is_valid_attribute(attr)],
        )
    @staticmethod
    def _float_series_to_channel_last_value_dto(attribute):
        """Wrap a float-series attribute as a numeric channel DTO holding its last point."""
        return AlphaChannelWithValueDTO(
            channelId=attribute.name,
            # channel name is the attribute path with its namespace prefix removed
            channelName=attribute.name.split("/", 1)[-1],
            channelType="numeric",
            x=attribute.floatSeriesProperties.lastStep,
            y=attribute.floatSeriesProperties.last,
        )
    @staticmethod
    def _string_series_to_channel_last_value_dto(attribute):
        """Wrap a string-series attribute as a text channel DTO holding its last point."""
        return AlphaChannelWithValueDTO(
            channelId=attribute.name,
            # channel name is the attribute path with its namespace prefix removed
            channelName=attribute.name.split("/", 1)[-1],
            channelType="text",
            x=attribute.stringSeriesProperties.lastStep,
            y=attribute.stringSeriesProperties.last,
        )
@staticmethod
def _get_all_items(get_portion, step):
items = []
previous_items = None
while previous_items is None or len(previous_items) >= step:
previous_items = get_portion(limit=step, offset=len(items))
items += previous_items
return items
    def _upload_raw_data(self, api_method, data, headers, path_params, query_params):
        """POST raw bytes to a swagger-described endpoint, bypassing bravado marshalling.

        NOTE(review): path/query values are interpolated without URL-encoding —
        presumably safe for the identifiers used here, but worth confirming.
        """
        url = self.api_address + api_method.operation.path_name + "?"
        for key, val in path_params.items():
            url = url.replace("{" + key + "}", val)
        for key, val in query_params.items():
            url = url + key + "=" + val + "&"
        session = self.http_client.session
        request = self.authenticator.apply(requests.Request(method="POST", url=url, data=data, headers=headers))
        return handle_server_raw_response_messages(session.send(session.prepare_request(request)))
def _get_parameter_with_type(self, parameter):
string_type = "string"
double_type = "double"
if isinstance(parameter, bool):
return (string_type, str(parameter))
elif isinstance(parameter, float) or isinstance(parameter, int):
if math.isinf(parameter) or math.isnan(parameter):
return (string_type, json.dumps(parameter))
else:
return (double_type, str(parameter))
else:
return (string_type, str(parameter))
    def _convert_to_experiment(self, api_experiment, project):
        """Wrap an API experiment payload in the legacy Experiment object."""
        return Experiment(
            backend=project._backend,
            project=project,
            _id=api_experiment.shortId,
            internal_id=api_experiment.id,
        )
    def _download_raw_data(self, api_method, headers, path_params, query_params):
        """GET raw bytes from a swagger-described endpoint as a streaming response.

        NOTE(review): path/query values are interpolated without URL-encoding —
        presumably safe for the identifiers used here, but worth confirming.
        """
        url = self.api_address + api_method.operation.path_name + "?"
        for key, val in path_params.items():
            url = url.replace("{" + key + "}", val)
        for key, val in query_params.items():
            url = url + key + "=" + val + "&"
        session = self.http_client.session
        request = self.authenticator.apply(requests.Request(method="GET", url=url, headers=headers))
        # stream=True so callers can iterate response content in chunks
        return handle_server_raw_response_messages(session.send(session.prepare_request(request), stream=True))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import platform
import sys
import click
import urllib3
from bravado.exception import HTTPNotFound
from bravado.requests_client import RequestsClient
from packaging import version
from neptune.common.exceptions import STYLES
from neptune.common.oauth import NeptuneAuthenticator
from neptune.common.utils import (
NoopObject,
update_session_proxies,
)
from neptune.internal.backends.hosted_client import NeptuneResponseAdapter
from neptune.legacy.api_exceptions import (
ProjectNotFound,
WorkspaceNotFound,
)
from neptune.legacy.backend import (
BackendApiClient,
LeaderboardApiClient,
)
from neptune.legacy.exceptions import UnsupportedClientVersion
from neptune.legacy.internal.api_clients.credentials import Credentials
from neptune.legacy.internal.api_clients.hosted_api_clients.hosted_alpha_leaderboard_api_client import (
HostedAlphaLeaderboardApiClient,
)
from neptune.legacy.internal.api_clients.hosted_api_clients.mixins import HostedNeptuneMixin
from neptune.legacy.internal.api_clients.hosted_api_clients.utils import legacy_with_api_exceptions_handler
from neptune.legacy.projects import Project
# Module-level logger following the standard ``__name__`` convention.
_logger = logging.getLogger(__name__)
class HostedNeptuneBackendApiClient(HostedNeptuneMixin, BackendApiClient):
    @legacy_with_api_exceptions_handler
    def __init__(self, api_token=None, proxies=None):
        """Set up swagger clients, configuration and OAuth for the legacy backend API.

        :param api_token: Neptune API token (resolution details handled by Credentials).
        :param proxies: optional proxy mapping applied to both HTTP sessions.
        """
        self._old_leaderboard_client = None
        self._new_leaderboard_client = None
        self._api_token = api_token
        self._proxies = proxies
        # This is not a top-level import because of circular dependencies
        from neptune import __version__
        self.client_lib_version = __version__
        self.credentials = Credentials(api_token)
        ssl_verify = True
        # opt-out of TLS verification for self-signed server certificates
        if os.getenv("NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE"):
            urllib3.disable_warnings()
            ssl_verify = False
        self._http_client = RequestsClient(ssl_verify=ssl_verify, response_adapter_class=NeptuneResponseAdapter)
        # for session re-creation we need to keep an authenticator-free version of http client
        self._http_client_for_token = RequestsClient(
            ssl_verify=ssl_verify, response_adapter_class=NeptuneResponseAdapter
        )
        user_agent = "neptune-client/{lib_version} ({system}, python {python_version})".format(
            lib_version=self.client_lib_version,
            system=platform.platform(),
            python_version=platform.python_version(),
        )
        self.http_client.session.headers.update({"User-Agent": user_agent})
        self._http_client_for_token.session.headers.update({"User-Agent": user_agent})
        update_session_proxies(self.http_client.session, proxies)
        update_session_proxies(self._http_client_for_token.session, proxies)
        config_api_url = self.credentials.api_url_opt or self.credentials.token_origin_address
        # We don't need to be able to resolve Neptune host if we use proxy
        if proxies is None:
            self._verify_host_resolution(config_api_url, self.credentials.token_origin_address)
        # this backend client is used only for initial configuration and session re-creation
        self.backend_client = self._get_swagger_client(
            "{}/api/backend/swagger.json".format(config_api_url),
            self._http_client_for_token,
        )
        self._client_config = self._create_client_config(
            api_token=self.credentials.api_token, backend_client=self.backend_client
        )
        self._verify_version()
        self.backend_swagger_client = self._get_swagger_client(
            "{}/api/backend/swagger.json".format(self._client_config.api_url),
            self.http_client,
        )
        self.authenticator = self._create_authenticator(
            api_token=self.credentials.api_token,
            ssl_verify=ssl_verify,
            proxies=proxies,
            backend_client=self.backend_client,
        )
        self.http_client.authenticator = self.authenticator
        # register_at_fork exists on CPython 3.7+ (POSIX only); the AttributeError
        # guard covers platforms/implementations that lack it
        if sys.version_info >= (3, 7):
            try:
                os.register_at_fork(after_in_child=self._handle_fork_in_child)
            except AttributeError:
                pass
    def _handle_fork_in_child(self):
        """After a fork, replace the swagger client with a no-op stub in the child
        process — the parent's client is presumably not fork-safe to reuse."""
        self.backend_swagger_client = NoopObject()
    @property
    def api_address(self):
        """API base URL taken from the fetched client configuration."""
        return self._client_config.api_url
    @property
    def http_client(self):
        """Bravado HTTP client used for API calls (authenticator attached in __init__)."""
        return self._http_client
    @property
    def display_address(self):
        """User-facing URL from the client configuration (for links shown to users)."""
        return self._client_config.display_url
    @property
    def proxies(self):
        """Proxy mapping this client was constructed with (may be None)."""
        return self._proxies
    @legacy_with_api_exceptions_handler
    def get_project(self, project_qualified_name):
        """Resolve a project by its qualified name, echoing any server warning header.

        :raises ProjectNotFound: on a 404 for the project.
        """
        try:
            response = self.backend_swagger_client.api.getProject(projectIdentifier=project_qualified_name).response()
            warning = response.metadata.headers.get("X-Server-Warning")
            if warning:
                # STYLES presumably supplies the 'warning' and 'end' format keys — confirm
                click.echo("{warning}{content}{end}".format(content=warning, **STYLES))
            project = response.result
            return Project(
                backend=self.create_leaderboard_backend(project=project),
                internal_id=project.id,
                namespace=project.organizationName,
                name=project.name,
            )
        except HTTPNotFound:
            raise ProjectNotFound(project_qualified_name)
@legacy_with_api_exceptions_handler
def get_projects(self, namespace):
try:
r = self.backend_swagger_client.api.listProjects(organizationIdentifier=namespace).response()
return r.result.entries
except HTTPNotFound:
raise WorkspaceNotFound(namespace_name=namespace)
    def create_leaderboard_backend(self, project) -> LeaderboardApiClient:
        """Return the leaderboard client; `project` is accepted for interface
        compatibility but not used by this implementation."""
        return self.get_new_leaderboard_client()
    def get_new_leaderboard_client(self) -> HostedAlphaLeaderboardApiClient:
        """Lazily create and cache the leaderboard client for this backend."""
        if self._new_leaderboard_client is None:
            self._new_leaderboard_client = HostedAlphaLeaderboardApiClient(backend_api_client=self)
        return self._new_leaderboard_client
    @legacy_with_api_exceptions_handler
    def _create_authenticator(self, api_token, ssl_verify, proxies, backend_client):
        """Build the request authenticator attached to the main HTTP client."""
        return NeptuneAuthenticator(api_token, backend_client, ssl_verify, proxies)
def _verify_version(self):
parsed_version = version.parse(self.client_lib_version)
if self._client_config.min_compatible_version and self._client_config.min_compatible_version > parsed_version:
click.echo(
"ERROR: Minimal supported client version is {} (installed: {}). Please upgrade neptune-client".format(
self._client_config.min_compatible_version, self.client_lib_version
),
sys.stderr,
)
raise UnsupportedClientVersion(
self.client_lib_version,
self._client_config.min_compatible_version,
self._client_config.max_compatible_version,
)
if self._client_config.max_compatible_version and self._client_config.max_compatible_version < parsed_version:
click.echo(
"ERROR: Maximal supported client version is {} (installed: {}). Please downgrade neptune-client".format(
self._client_config.max_compatible_version, self.client_lib_version
),
sys.stderr,
)
raise UnsupportedClientVersion(
self.client_lib_version,
self._client_config.min_compatible_version,
self._client_config.max_compatible_version,
)
if self._client_config.min_recommended_version and self._client_config.min_recommended_version > parsed_version:
click.echo(
"WARNING: We recommend an upgrade to a new version of neptune-client - {} (installed - {}).".format(
self._client_config.min_recommended_version, self.client_lib_version
),
sys.stderr,
)
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import socket
import sys
import click
from bravado.client import SwaggerClient
from bravado_core.formatter import SwaggerFormat
from packaging import version
from six.moves import urllib
from neptune.legacy.exceptions import (
CannotResolveHostname,
DeprecatedApiToken,
UnsupportedClientVersion,
)
from neptune.legacy.internal.api_clients.client_config import (
ClientConfig,
MultipartConfig,
)
from neptune.legacy.internal.api_clients.hosted_api_clients.utils import legacy_with_api_exceptions_handler
_logger = logging.getLogger(__name__)
# Pass-through swagger "uuid" format: uuid-typed fields are kept as plain
# strings in both directions, with no validation.
uuid_format = SwaggerFormat(
    format="uuid",
    to_python=lambda x: x,
    to_wire=lambda x: x,
    validate=lambda x: None,
    description="",
)
class HostedNeptuneMixin:
    """Mixin containing operation common for both backend and leaderboard api clients"""

    @legacy_with_api_exceptions_handler
    def _get_swagger_client(self, url, http_client):
        """Build a bravado SwaggerClient for the spec at *url*.

        Spec/request/response validation is disabled; only the pass-through
        uuid format is registered.
        """
        return SwaggerClient.from_url(
            url,
            config=dict(
                validate_swagger_spec=False,
                validate_requests=False,
                validate_responses=False,
                formats=[uuid_format],
            ),
            http_client=http_client,
        )

    @staticmethod
    def _get_client_config_args(api_token):
        # getClientConfig takes the token as the X-Neptune-Api-Token parameter.
        return dict(X_Neptune_Api_Token=api_token)

    @legacy_with_api_exceptions_handler
    def _create_client_config(self, api_token, backend_client):
        """Fetch the server-side client configuration and turn it into a ClientConfig.

        Raises:
            UnsupportedClientVersion: when the server response carries no
                `pyLibVersions`, i.e. this client generation is unknown to it.
        """
        client_config_args = self._get_client_config_args(api_token)
        config = backend_client.api.getClientConfig(**client_config_args).response().result
        if hasattr(config, "pyLibVersions"):
            min_recommended = getattr(config.pyLibVersions, "minRecommendedVersion", None)
            min_compatible = getattr(config.pyLibVersions, "minCompatibleVersion", None)
            max_compatible = getattr(config.pyLibVersions, "maxCompatibleVersion", None)
        else:
            # Typo fix: "contant" -> "contact".
            click.echo(
                "ERROR: This client version is not supported by your Neptune instance. Please contact Neptune support.",
                sys.stderr,
            )
            raise UnsupportedClientVersion(self.client_lib_version, None, "0.4.111")
        multipart_upload_config_obj = getattr(config, "multiPartUpload", None)
        has_multipart_upload = getattr(multipart_upload_config_obj, "enabled", False)
        if not has_multipart_upload:
            multipart_upload_config = None
        else:
            # No defaults here on purpose: a server that reports multipart
            # enabled must provide the complete set of limits.
            min_chunk_size = getattr(multipart_upload_config_obj, "minChunkSize")
            max_chunk_size = getattr(multipart_upload_config_obj, "maxChunkSize")
            max_chunk_count = getattr(multipart_upload_config_obj, "maxChunkCount")
            max_single_part_size = getattr(multipart_upload_config_obj, "maxSinglePartSize")
            multipart_upload_config = MultipartConfig(
                min_chunk_size, max_chunk_size, max_chunk_count, max_single_part_size
            )
        return ClientConfig(
            api_url=config.apiUrl,
            display_url=config.applicationUrl,
            min_recommended_version=version.parse(min_recommended) if min_recommended else None,
            min_compatible_version=version.parse(min_compatible) if min_compatible else None,
            max_compatible_version=version.parse(max_compatible) if max_compatible else None,
            multipart_config=multipart_upload_config,
        )

    def _verify_host_resolution(self, api_url, app_url):
        """Check that the API host resolves in DNS before any request is made.

        Raises:
            DeprecatedApiToken: when resolution fails and the token carried no
                explicit API URL (old-style token pointing at a gone host).
            CannotResolveHostname: when resolution fails for an explicit URL.
        """
        host = urllib.parse.urlparse(api_url).netloc.split(":")[0]
        try:
            socket.gethostbyname(host)
        except socket.gaierror:
            if self.credentials.api_url_opt is None:
                raise DeprecatedApiToken(urllib.parse.urlparse(app_url).netloc)
            else:
                raise CannotResolveHostname(host)
#
# Copyright (c) 2023, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
import requests
from bravado.exception import (
BravadoConnectionError,
BravadoTimeoutError,
HTTPBadGateway,
HTTPForbidden,
HTTPGatewayTimeout,
HTTPInternalServerError,
HTTPRequestTimeout,
HTTPServiceUnavailable,
HTTPTooManyRequests,
HTTPUnauthorized,
)
from urllib3.exceptions import NewConnectionError
from neptune.common.backends.utils import get_retry_from_headers_or_default
from neptune.legacy.api_exceptions import (
ConnectionLost,
Forbidden,
NeptuneSSLVerificationError,
ServerError,
Unauthorized,
)
# Module-level logger for this file.
_logger = logging.getLogger(__name__)
def legacy_with_api_exceptions_handler(func):
    """Decorator: retry transient API failures and translate HTTP errors into
    legacy Neptune exceptions.

    Up to 11 attempts with exponential backoff (2**retry seconds) on
    connection/timeout/5xx-style failures; honours Retry-After (via
    get_retry_from_headers_or_default) on 429; raises typed exceptions for
    SSL, auth and server errors; raises ConnectionLost when retries are
    exhausted.
    """
    from functools import wraps  # local import keeps the module's import block unchanged

    @wraps(func)  # preserve func's name/docstring for introspection and debugging
    def wrapper(*args, **kwargs):
        retries = 11
        for retry in range(retries):
            try:
                return func(*args, **kwargs)
            except requests.exceptions.SSLError:
                raise NeptuneSSLVerificationError()
            except (
                BravadoConnectionError,
                BravadoTimeoutError,
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
                HTTPRequestTimeout,
                HTTPServiceUnavailable,
                HTTPGatewayTimeout,
                HTTPBadGateway,
                HTTPInternalServerError,
                NewConnectionError,
            ):
                # Transient failure: exponential backoff. (HTTPServiceUnavailable
                # previously had its own branch with identical handling; merged.)
                if retry >= 6:
                    _logger.warning("Experiencing connection interruptions. Reestablishing communication with Neptune.")
                time.sleep(2**retry)
            except HTTPTooManyRequests as e:
                # Rate limited: wait as instructed by the response headers.
                wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                time.sleep(wait_time)
            except HTTPUnauthorized:
                raise Unauthorized()
            except HTTPForbidden:
                raise Forbidden()
            except requests.exceptions.RequestException as e:
                # Raw requests errors that bravado did not wrap: classify by status code.
                if e.response is None:
                    raise
                status_code = e.response.status_code
                if status_code in (
                    HTTPRequestTimeout.status_code,
                    HTTPBadGateway.status_code,
                    HTTPServiceUnavailable.status_code,
                    HTTPGatewayTimeout.status_code,
                    HTTPInternalServerError.status_code,
                ):
                    if retry >= 6:
                        _logger.warning(
                            "Experiencing connection interruptions. Reestablishing communication with Neptune."
                        )
                    time.sleep(2**retry)
                elif status_code == HTTPTooManyRequests.status_code:
                    wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                    time.sleep(wait_time)
                elif status_code >= HTTPInternalServerError.status_code:
                    raise ServerError()
                elif status_code == HTTPUnauthorized.status_code:
                    raise Unauthorized()
                elif status_code == HTTPForbidden.status_code:
                    raise Forbidden()
                else:
                    raise
        raise ConnectionLost()

    return wrapper
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from io import StringIO
from neptune.common.utils import NoopObject
from neptune.legacy.backend import (
BackendApiClient,
LeaderboardApiClient,
)
# Module-level logger for this file.
_logger = logging.getLogger(__name__)
class OfflineBackendApiClient(BackendApiClient):
    """Backend stub used in offline mode: accepts all calls, persists nothing."""

    def __init__(self):
        # Warn once, at construction time, that nothing will be logged.
        _logger.warning("Neptune is running in offline mode. No data is being logged to Neptune.")
        _logger.warning("Disable offline mode to log your experiments.")

    @property
    def api_address(self):
        """Placeholder address - there is no server in offline mode."""
        return "OFFLINE"

    @property
    def display_address(self):
        """Placeholder display address."""
        return "OFFLINE"

    @property
    def proxies(self):
        """No proxy configuration in offline mode."""
        return None

    def get_project(self, project_qualified_name):
        """Return a no-op stand-in instead of a real project."""
        return NoopObject()

    def get_projects(self, namespace):
        """Offline mode knows about no projects."""
        return []

    def create_leaderboard_backend(self, project) -> "OfflineLeaderboardApiClient":
        """Pair this backend with its offline leaderboard counterpart."""
        return OfflineLeaderboardApiClient()
class OfflineLeaderboardApiClient(LeaderboardApiClient):
    """Leaderboard stub for offline mode: every operation is a harmless no-op.

    Query methods return empty containers, factory methods return NoopObject
    stand-ins, and mutating methods do nothing at all.
    """

    @property
    def api_address(self):
        """Placeholder address - no server is contacted offline."""
        return "OFFLINE"

    @property
    def display_address(self):
        """Placeholder display address."""
        return "OFFLINE"

    @property
    def proxies(self):
        """No proxy configuration in offline mode."""
        return None

    def get_project_members(self, project_identifier):
        """No members are known offline."""
        return []

    def get_leaderboard_entries(
        self,
        project,
        entry_types=None,
        ids=None,
        states=None,
        owners=None,
        tags=None,
        min_running_time=None,
    ):
        """No experiments exist offline, regardless of the filters."""
        return []

    def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
        """Return an empty CSV stream."""
        return StringIO()

    def get_metrics_csv(self, experiment):
        """Return an empty CSV stream."""
        return StringIO()

    def create_experiment(
        self,
        project,
        name,
        description,
        params,
        properties,
        tags,
        abortable,
        monitored,
        git_info,
        hostname,
        entrypoint,
        notebook_id,
        checkpoint_id,
    ):
        """Pretend to create an experiment; nothing is persisted."""
        return NoopObject()

    def upload_source_code(self, experiment, source_target_pairs):
        """Discard source code uploads."""

    def get_notebook(self, project, notebook_id):
        """Return a no-op notebook stand-in."""
        return NoopObject()

    def get_last_checkpoint(self, project, notebook_id):
        """Return a no-op checkpoint stand-in."""
        return NoopObject()

    def create_notebook(self, project):
        """Pretend to create a notebook."""
        return NoopObject()

    def create_checkpoint(self, notebook_id, jupyter_path, _file=None):
        """Discard checkpoint creation."""

    def get_experiment(self, experiment_id):
        """Return a no-op experiment stand-in."""
        return NoopObject()

    def set_property(self, experiment, key, value):
        """Discard property updates."""

    def remove_property(self, experiment, key):
        """Discard property removal."""

    def update_tags(self, experiment, tags_to_add, tags_to_delete):
        """Discard tag updates."""

    def create_channel(self, experiment, name, channel_type):
        """Pretend to create a channel."""
        return NoopObject()

    def reset_channel(self, experiment, channel_id, channel_name, channel_type):
        """Discard channel resets."""

    def get_channels(self, experiment):
        """No channels exist offline."""
        return {}

    def create_system_channel(self, experiment, name, channel_type):
        """Pretend to create a system channel."""
        return NoopObject()

    def get_system_channels(self, experiment):
        """No system channels exist offline."""
        return {}

    def send_channels_values(self, experiment, channels_with_values):
        """Discard channel values."""

    def mark_failed(self, experiment, traceback):
        """Ignore failure marking."""

    def ping_experiment(self, experiment):
        """Ignore keep-alive pings."""

    def create_hardware_metric(self, experiment, metric):
        """Pretend to create a hardware metric."""
        return NoopObject()

    def send_hardware_metric_reports(self, experiment, metrics, metric_reports):
        """Discard hardware metric reports."""

    def log_artifact(self, experiment, artifact, destination=None):
        """Discard artifact uploads."""

    def delete_artifacts(self, experiment, path):
        """Ignore artifact deletion."""

    def download_data(self, experiment, path, destination):
        """Nothing to download offline."""

    def download_sources(self, experiment, path=None, destination_dir=None):
        """Nothing to download offline."""

    def download_artifacts(self, experiment, path=None, destination_dir=None):
        """Nothing to download offline."""

    def download_artifact(self, experiment, path=None, destination_dir=None):
        """Nothing to download offline."""
# Backward-compatibility alias: the class was renamed to
# OfflineBackendApiClient; keep the deprecated old name importable.
OfflineBackend = OfflineBackendApiClient
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.legacy.internal.api_clients import HostedNeptuneBackendApiClient
# Backward-compatibility alias: the class was renamed to
# HostedNeptuneBackendApiClient; keep the deprecated old name importable.
HostedNeptuneBackend = HostedNeptuneBackendApiClient
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import threading
import time
from collections import namedtuple
from itertools import groupby
from queue import (
Empty,
Queue,
)
from bravado.exception import HTTPUnprocessableEntity
from neptune.legacy.exceptions import NeptuneException
from neptune.legacy.internal.channels.channels import (
ChannelIdWithValues,
ChannelNamespace,
ChannelType,
ChannelValue,
)
from neptune.legacy.internal.threads.neptune_thread import NeptuneThread
# Module-level logger for this file.
_logger = logging.getLogger(__name__)
class ChannelsValuesSender(object):
    """Buffers channel values for an experiment and forwards them to a lazily
    started background sending thread (ChannelsValuesSendingThread)."""

    # Internal record describing one queued value plus the channel it targets.
    _QUEUED_CHANNEL_VALUE = namedtuple(
        "QueuedChannelValue",
        [
            "channel_id",
            "channel_name",
            "channel_type",
            "channel_value",
            "channel_namespace",
        ],
    )
    # Class-level (shared by all instances) reentrant lock guarding
    # start/stop of the sending thread.
    __LOCK = threading.RLock()

    def __init__(self, experiment):
        self._experiment = experiment
        self._values_queue = None
        self._sending_thread = None
        # name -> channel-id caches, one per namespace, so each channel is
        # created on the backend at most once.
        self._user_channel_name_to_id_map = dict()
        self._system_channel_name_to_id_map = dict()

    def send(
        self,
        channel_name,
        channel_type,
        channel_value,
        channel_namespace=ChannelNamespace.USER,
    ):
        """Queue one value; starts the sending thread and/or creates the
        channel remotely on first use."""
        # Before taking the lock here, we need to check if the sending thread is not running yet.
        # Otherwise, the sending thread could call send() while being join()-ed, which would result
        # in a deadlock.
        if not self._is_running():
            with self.__LOCK:
                # Double-checked: another thread may have started it meanwhile.
                if not self._is_running():
                    self._start()
        if channel_namespace == ChannelNamespace.USER:
            namespaced_channel_map = self._user_channel_name_to_id_map
        else:
            namespaced_channel_map = self._system_channel_name_to_id_map
        if channel_name in namespaced_channel_map:
            channel_id = namespaced_channel_map[channel_name]
        else:
            # First value for this channel: create it remotely and cache the id.
            response = self._experiment._create_channel(channel_name, channel_type, channel_namespace)
            channel_id = response.id
            namespaced_channel_map[channel_name] = channel_id
        self._values_queue.put(
            self._QUEUED_CHANNEL_VALUE(
                channel_id=channel_id,
                channel_name=channel_name,
                channel_type=channel_type,
                channel_value=channel_value,
                channel_namespace=channel_namespace,
            )
        )

    def join(self):
        """Stop the sending thread and wait until it drains; safe to call when
        it is not running."""
        with self.__LOCK:
            if self._is_running():
                self._sending_thread.interrupt()
                self._sending_thread.join()
                self._sending_thread = None
                self._values_queue = None

    def _is_running(self):
        # Running means: queue allocated AND thread object present AND alive.
        return self._values_queue is not None and self._sending_thread is not None and self._sending_thread.is_alive()

    def _start(self):
        # Caller must hold __LOCK.
        self._values_queue = Queue()
        self._sending_thread = ChannelsValuesSendingThread(self._experiment, self._values_queue)
        self._sending_thread.start()
class ChannelsValuesSendingThread(NeptuneThread):
    """Background thread that batches queued channel values and ships them to
    the backend.

    Values accumulate for up to _SLEEP_TIME seconds, or until the batch holds
    _MAX_VALUES_BATCH_LENGTH entries or _MAX_IMAGE_VALUES_BATCH_SIZE bytes of
    image data, whichever comes first.
    """

    _SLEEP_TIME = 5  # seconds to accumulate a batch before flushing
    _MAX_VALUES_BATCH_LENGTH = 100
    _MAX_IMAGE_VALUES_BATCH_SIZE = 10485760  # 10 MB

    def __init__(self, experiment, values_queue):
        super(ChannelsValuesSendingThread, self).__init__(is_daemon=False)
        self._values_queue = values_queue
        self._experiment = experiment
        self._sleep_time = self._SLEEP_TIME
        self._values_batch = []

    def run(self):
        """Main loop: drain the queue into batches and flush when a limit hits."""
        while self.should_continue_running() or not self._values_queue.empty():
            try:
                sleep_start = time.time()
                self._values_batch.append(self._values_queue.get(timeout=max(self._sleep_time, 0)))
                self._values_queue.task_done()
                self._sleep_time -= time.time() - sleep_start
            except Empty:
                self._sleep_time = 0
            # Bytes of image payload collected so far in this batch.
            # NOTE(review): assumes IMAGE values carry y["image_value"]["data"];
            # confirm against the producer side.
            image_values_batch_size = sum(
                [
                    len(v.channel_value.y["image_value"]["data"] or [])
                    for v in self._values_batch
                    if v.channel_type == ChannelType.IMAGE.value
                ]
            )
            if (
                self._sleep_time <= 0
                or len(self._values_batch) >= self._MAX_VALUES_BATCH_LENGTH
                or image_values_batch_size >= self._MAX_IMAGE_VALUES_BATCH_SIZE
            ):
                self._process_batch()
        # Final flush of whatever is left after the thread was asked to stop.
        self._process_batch()

    def _process_batch(self):
        """Send the accumulated batch (if any) and reset the flush timer."""
        send_start = time.time()
        if self._values_batch:
            try:
                self._send_values(self._values_batch)
                self._values_batch = []
            except (NeptuneException, IOError):
                # Keep the batch; it will be retried on the next flush.
                _logger.exception("Failed to send channel value.")
        self._sleep_time = self._SLEEP_TIME - (time.time() - send_start)

    def _send_values(self, queued_channels_values):
        """Group queued values per channel and push them to the backend."""

        def get_channel_metadata(value):
            return (
                value.channel_id,
                value.channel_name,
                value.channel_type,
                value.channel_namespace,
            )

        # groupby requires its input sorted by the same key.
        queued_grouped_by_channel = {
            channel_metadata: list(values)
            for channel_metadata, values in groupby(
                sorted(queued_channels_values, key=get_channel_metadata),
                get_channel_metadata,
            )
        }
        channels_with_values = []
        for channel_metadata in queued_grouped_by_channel:
            channel_values = []
            for queued_value in queued_grouped_by_channel[channel_metadata]:
                channel_values.append(
                    ChannelValue(
                        ts=queued_value.channel_value.ts,
                        x=queued_value.channel_value.x,
                        y=queued_value.channel_value.y,
                    )
                )
            channels_with_values.append(ChannelIdWithValues(*channel_metadata, channel_values))
        try:
            self._experiment._send_channels_values(channels_with_values)
        except HTTPUnprocessableEntity as e:
            # 422 usually means the storage quota was hit. Try to surface the
            # server-provided message, but never let message extraction break
            # the handler: the previous try/finally (no except) re-raised JSON
            # parsing errors after logging instead of falling back.
            message = "Maximum storage limit reached"
            try:
                message = e.response.json()["message"]
            except Exception:
                pass  # best effort - keep the default message
            _logger.warning("Failed to send channel value: %s", message)
        except (NeptuneException, IOError):
            _logger.exception("Failed to send channel value.")
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from collections import namedtuple
from enum import Enum
from typing import List
from neptune.legacy.exceptions import NeptuneException
# Metadata identifying a channel (id, name, wire type, namespace).
# NOTE(review): the namedtuple's typename "ChannelNameWithType" does not match
# the variable name; kept as-is because the typename is baked into reprs and
# pickles of existing data.
ChannelNameWithTypeAndNamespace = namedtuple(
    "ChannelNameWithType",
    ["channel_id", "channel_name", "channel_type", "channel_namespace"],
)
class ChannelType(Enum):
    """Wire-level type of a channel, as stored in channel metadata."""
    TEXT = "text"
    NUMERIC = "numeric"
    IMAGE = "image"
class ChannelValueType(Enum):
    """Key under which a value's payload is stored in ChannelValue.y."""
    TEXT_VALUE = "text_value"
    NUMERIC_VALUE = "numeric_value"
    IMAGE_VALUE = "image_value"
class ChannelNamespace(Enum):
    """Whether a channel belongs to the user or to Neptune's system channels."""
    USER = "user"
    SYSTEM = "system"
class ChannelValue(object):
    """A single point logged to a channel.

    `y` is a dict expected to hold exactly one non-None payload, keyed by a
    ChannelValueType value ("text_value", "numeric_value" or "image_value").
    """

    def __init__(self, x, y, ts):
        self._x = x
        self._y = y
        if ts is None:
            ts = time.time()  # default the timestamp to "now"
        self._ts = ts

    @property
    def ts(self):
        """Timestamp (seconds since the epoch) of this point."""
        return self._ts

    @property
    def x(self):
        """The point's x coordinate (e.g. step), as supplied by the caller."""
        return self._x

    @property
    def y(self):
        """Raw payload dict keyed by value-type name."""
        return self._y

    @property
    def value(self):
        """The payload stored under the key matching this value's type."""
        return self.y.get(self.value_type.value)

    @property
    def value_type(self) -> "ChannelValueType":
        """We expect that exactly one of `y` values is not None, and according to that we try to determine value type"""
        present_types = [ch_value_type for ch_value_type in ChannelValueType if self.y.get(ch_value_type.value) is not None]
        if len(present_types) > 1:
            raise NeptuneException(f"There are mixed value types in {self}")
        if not present_types:
            raise NeptuneException(f"Can't determine type of {self}")
        return present_types[0]

    def __str__(self):
        return "ChannelValue(x={},y={},ts={})".format(self.x, self.y, self.ts)

    def __repr__(self):
        return str(self)

    def __eq__(self, o):
        # Return NotImplemented for foreign types instead of raising
        # AttributeError on objects without __dict__ (e.g. ints).
        if not isinstance(o, ChannelValue):
            return NotImplemented
        return self.__dict__ == o.__dict__
class ChannelIdWithValues:
    """A channel's identity plus the batch of values to be sent to it."""

    def __init__(self, channel_id, channel_name, channel_type, channel_namespace, channel_values):
        self._channel_id = channel_id
        self._channel_name = channel_name
        self._channel_type = channel_type
        self._channel_namespace = channel_namespace
        self._channel_values = channel_values

    @property
    def channel_id(self) -> str:
        return self._channel_id

    @property
    def channel_name(self) -> str:
        return self._channel_name

    @property
    def channel_values(self) -> "List[ChannelValue]":
        return self._channel_values

    @property
    def channel_type(self) -> "ChannelValueType":
        """Map the wire-level channel type string to its ChannelValueType.

        Raises:
            NeptuneException: for an unrecognized type string.
        """
        if self._channel_type == ChannelType.NUMERIC.value:
            return ChannelValueType.NUMERIC_VALUE
        elif self._channel_type == ChannelType.TEXT.value:
            return ChannelValueType.TEXT_VALUE
        elif self._channel_type == ChannelType.IMAGE.value:
            return ChannelValueType.IMAGE_VALUE
        else:
            raise NeptuneException(f"Unknown channel type: {self._channel_type}")

    @property
    def channel_namespace(self) -> "ChannelNamespace":
        return self._channel_namespace

    def __eq__(self, other):
        # Equality is the channel id plus the exact batch of values.
        return self.channel_id == other.channel_id and self.channel_values == other.channel_values

    def __gt__(self, other):
        # Arbitrary-but-total ordering by channel-id hash, so batches sort
        # deterministically within one process. The original compared with `<`,
        # inverting the declared operator; fixed to match `>` semantics.
        return hash(self.channel_id) > hash(other.channel_id)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import sys
import time
import traceback
from logging import StreamHandler
from neptune.common.hardware.gauges.gauge_mode import GaugeMode
from neptune.common.hardware.metrics.service.metric_service_factory import MetricServiceFactory
from neptune.common.utils import (
in_docker,
is_ipython,
is_notebook,
)
from neptune.legacy.internal.abort import (
CustomAbortImpl,
DefaultAbortImpl,
)
from neptune.legacy.internal.channels.channels import ChannelNamespace
from neptune.legacy.internal.streams.channel_writer import ChannelWriter
from neptune.legacy.internal.streams.stdstream_uploader import (
StdErrWithUpload,
StdOutWithUpload,
)
from neptune.legacy.internal.threads.aborting_thread import AbortingThread
from neptune.legacy.internal.threads.hardware_metric_reporting_thread import HardwareMetricReportingThread
from neptune.legacy.internal.threads.ping_thread import PingThread
# Module-level logger for this file.
_logger = logging.getLogger(__name__)
class ExecutionContext(object):
    """Runs the side machinery of a live experiment: stdout/stderr capture,
    logger forwarding, hardware metrics, keep-alive pings, remote aborts and
    an uncaught-exception hook that marks the experiment failed.
    """

    def __init__(self, backend, experiment):
        self._backend = backend
        self._experiment = experiment
        self._ping_thread = None
        self._hardware_metric_thread = None
        self._aborting_thread = None
        self._logger = None
        self._logger_handler = None
        self._stdout_uploader = None
        self._stderr_uploader = None
        # Default to the interpreter's pristine hook until start() installs ours.
        self._uncaught_exception_handler = sys.__excepthook__
        self._previous_uncaught_exception_handler = None

    def start(
        self,
        abort_callback=None,
        logger=None,
        upload_stdout=True,
        upload_stderr=True,
        send_hardware_metrics=True,
        run_monitoring_thread=True,
        handle_uncaught_exceptions=True,
    ):
        """Start the requested monitoring facilities; each flag independently
        enables one feature."""
        if handle_uncaught_exceptions:
            self._set_uncaught_exception_handler()
        if logger:
            # Mirror records of the user's logger into the system "logger" channel.
            channel = self._experiment._get_channel("logger", "text", ChannelNamespace.SYSTEM)
            channel_writer = ChannelWriter(self._experiment, channel.name, ChannelNamespace.SYSTEM)
            self._logger_handler = StreamHandler(channel_writer)
            self._logger = logger
            logger.addHandler(self._logger_handler)
        # Std-stream capture is skipped in notebooks, where replacing the
        # streams would interfere with the front-end.
        if upload_stdout and not is_notebook():
            self._stdout_uploader = StdOutWithUpload(self._experiment)
        if upload_stderr and not is_notebook():
            self._stderr_uploader = StdErrWithUpload(self._experiment)
        abortable = abort_callback is not None or DefaultAbortImpl.requirements_installed()
        if abortable:
            self._run_aborting_thread(abort_callback)
        else:
            _logger.warning("psutil is not installed. You will not be able to abort this experiment from the UI.")
        if run_monitoring_thread:
            self._run_monitoring_thread()
        if send_hardware_metrics:
            self._run_hardware_metrics_reporting_thread()

    def stop(self):
        """Tear down everything start() set up, restoring global state."""
        if self._ping_thread:
            self._ping_thread.interrupt()
            self._ping_thread = None
        if self._hardware_metric_thread:
            self._hardware_metric_thread.interrupt()
            self._hardware_metric_thread = None
        if self._aborting_thread:
            self._aborting_thread.shutdown()
            self._aborting_thread = None
        if self._stdout_uploader:
            self._stdout_uploader.close()
        if self._stderr_uploader:
            self._stderr_uploader.close()
        if self._logger and self._logger_handler:
            self._logger.removeHandler(self._logger_handler)
        # Restore the hook only if we actually installed one: the previous
        # code assigned unconditionally, which set sys.excepthook to None
        # whenever start() ran with handle_uncaught_exceptions=False.
        if self._previous_uncaught_exception_handler is not None:
            sys.excepthook = self._previous_uncaught_exception_handler

    def _set_uncaught_exception_handler(self):
        # Mark the experiment failed with the traceback, then delegate to the
        # interpreter's default hook so the error is still printed.
        def exception_handler(exc_type, exc_val, exc_tb):
            self._experiment.stop("\n".join(traceback.format_tb(exc_tb)) + "\n" + repr(exc_val))
            sys.__excepthook__(exc_type, exc_val, exc_tb)

        self._uncaught_exception_handler = exception_handler
        self._previous_uncaught_exception_handler = sys.excepthook
        sys.excepthook = exception_handler

    def _run_aborting_thread(self, abort_callback):
        # A custom callback wins; otherwise abort the current process, except
        # under IPython where killing the kernel process would be harmful.
        if abort_callback is not None:
            abort_impl = CustomAbortImpl(abort_callback)
        elif not is_ipython():
            abort_impl = DefaultAbortImpl(pid=os.getpid())
        else:
            return
        websocket_factory = self._backend.websockets_factory(
            project_id=self._experiment._project.internal_id,
            experiment_id=self._experiment.internal_id,
        )
        if not websocket_factory:
            return
        self._aborting_thread = AbortingThread(
            websocket_factory=websocket_factory,
            abort_impl=abort_impl,
            experiment=self._experiment,
        )
        self._aborting_thread.start()

    def _run_monitoring_thread(self):
        # Periodic keep-alive so the server can detect dead experiments.
        self._ping_thread = PingThread(backend=self._backend, experiment=self._experiment)
        self._ping_thread.start()

    def _run_hardware_metrics_reporting_thread(self):
        gauge_mode = GaugeMode.CGROUP if in_docker() else GaugeMode.SYSTEM
        metric_service = MetricServiceFactory(self._backend, os.environ).create(
            gauge_mode=gauge_mode,
            experiment=self._experiment,
            reference_timestamp=time.time(),
        )
        self._hardware_metric_thread = HardwareMetricReportingThread(
            metric_service=metric_service, metric_sending_interval_seconds=10
        )
        self._hardware_metric_thread.start()
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
# Module-level logger for this file.
_logger = logging.getLogger(__name__)
class MessageType(object):
    """Names of messages sent to the Jupyter front-end over the neptune comm."""
    CHECKPOINT_CREATED = "CHECKPOINT_CREATED"
def send_checkpoint_created(notebook_id, notebook_path, checkpoint_id):
    """Notify the notebook frontend that a checkpoint has been created.

    Args:
        notebook_id (:obj:`str`): The notebook's id.
        notebook_path (:obj:`str`): The notebook's path.
        checkpoint_id (:obj:`str`): The checkpoint's path.
    Raises:
        `ImportError`: If ipykernel is not available.
    """
    payload = {
        "message_type": MessageType.CHECKPOINT_CREATED,
        "data": {
            "checkpoint_id": checkpoint_id,
            "notebook_id": notebook_id,
            "notebook_path": notebook_path,
        },
    }
    _get_comm().send(data=payload)
def _get_comm():
    # Imported lazily so importing this module does not require ipykernel;
    # raises ImportError here when ipykernel is unavailable.
    from ipykernel.comm import Comm

    return Comm(target_name="neptune_comm")
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import threading
from neptune.common.utils import is_ipython
from neptune.legacy.internal.notebooks.comm import send_checkpoint_created
_logger = logging.getLogger(__name__)

# Cache of checkpoints already created in this kernel session, keyed by the
# IPython execution count; guarded by _checkpoints_lock for thread safety.
_checkpoints_lock = threading.Lock()
_checkpoints = dict()
def create_checkpoint(backend, notebook_id, notebook_path):
    """Create a notebook checkpoint via ``backend``, deduplicated per IPython cell.

    Outside of IPython this is a no-op and returns ``None``. Within IPython,
    at most one checkpoint is created per execution count: repeated calls in
    the same cell return the cached checkpoint.

    Args:
        backend: Object exposing ``create_checkpoint(notebook_id, notebook_path)``.
        notebook_id (str): The notebook's id.
        notebook_path (str): The notebook's path.

    Returns:
        The created (or cached) checkpoint, or ``None`` when not in IPython.
    """
    if is_ipython():
        import IPython

        ipython = IPython.core.getipython.get_ipython()
        execution_count = -1
        # get_ipython() may return None; the original only guarded the second
        # dereference (below) and could raise AttributeError here.
        if ipython is not None and ipython.kernel is not None:
            execution_count = ipython.kernel.execution_count
        with _checkpoints_lock:
            if execution_count in _checkpoints:
                return _checkpoints[execution_count]
            checkpoint = backend.create_checkpoint(notebook_id, notebook_path)
            if ipython is not None and ipython.kernel is not None:
                send_checkpoint_created(
                    notebook_id=notebook_id,
                    notebook_path=notebook_path,
                    checkpoint_id=checkpoint.id,
                )
            _checkpoints[execution_count] = checkpoint
            return checkpoint
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import unicode_literals
import re
from datetime import datetime
from neptune.legacy.internal.channels.channels import (
ChannelNamespace,
ChannelType,
ChannelValue,
)
class ChannelWriter(object):
    """File-like writer that forwards each complete line of text to a channel."""

    __SPLIT_PATTERN = re.compile(r"[\n\r]{1,2}")

    def __init__(self, experiment, channel_name, channel_namespace=ChannelNamespace.USER):
        self._experiment = experiment
        self._channel_name = channel_name
        self._channel_namespace = channel_namespace
        self._data = None
        self._x_offset = TimeOffsetGenerator(self._experiment.get_system_properties()["created"])

    def write(self, data):
        """Buffer ``data``; emit every complete line as a TEXT channel value."""
        self._data = data if self._data is None else self._data + data
        *complete_lines, remainder = self.__SPLIT_PATTERN.split(self._data)
        for complete_line in complete_lines:
            channel_value = ChannelValue(
                x=self._x_offset.next(),
                y=dict(text_value=str(complete_line)),
                ts=None,
            )
            self._experiment._channels_values_sender.send(
                channel_name=self._channel_name,
                channel_type=ChannelType.TEXT.value,
                channel_value=channel_value,
                channel_namespace=self._channel_namespace,
            )
        # Keep the trailing (incomplete) fragment for the next write() call.
        self._data = remainder
class TimeOffsetGenerator(object):
    """Generates strictly increasing millisecond offsets from a start datetime."""

    def __init__(self, start):
        self._start = start
        self._previous_millis_from_start = None

    def next(self):
        """Return milliseconds elapsed since start, as a float.

        Guaranteed strictly increasing: on Windows, datetime.now() has only
        millisecond granularity, so when a call would not advance past the
        previously returned value, a single microsecond (0.001 ms) is added
        to the previous value instead.
        """
        elapsed = datetime.now(tz=self._start.tzinfo) - self._start
        millis = elapsed.total_seconds() * 1000
        previous = self._previous_millis_from_start
        if previous is not None and previous >= millis:
            self._previous_millis_from_start = previous + 0.001
        else:
            self._previous_millis_from_start = millis
        return self._previous_millis_from_start
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from neptune.legacy.internal.channels.channels import ChannelNamespace
from neptune.legacy.internal.streams.channel_writer import ChannelWriter
class StdStreamWithUpload(object):
    """Wraps a standard stream, mirroring writes to a Neptune TEXT channel.

    Writes always reach the wrapped stream first; channel upload is
    best-effort and never propagates errors into the caller's output path.
    """

    def __init__(self, experiment, channel_name, stream):
        self._channel = experiment._get_channel(channel_name, "text", ChannelNamespace.SYSTEM)
        self._channel_writer = ChannelWriter(experiment, channel_name, ChannelNamespace.SYSTEM)
        self._stream = stream

    def write(self, data):
        self._stream.write(data)
        try:
            self._channel_writer.write(data)
        except Exception:
            # Best-effort upload: swallow errors. Narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit now propagate.
            pass

    def isatty(self):
        return hasattr(self._stream, "isatty") and self._stream.isatty()

    def flush(self):
        self._stream.flush()

    def fileno(self):
        return self._stream.fileno()
class StdOutWithUpload(StdStreamWithUpload):
def __init__(self, experiment):
super(StdOutWithUpload, self).__init__(experiment, "stdout", sys.__stdout__)
sys.stdout = self
def close(self):
sys.stdout = sys.__stdout__
class StdErrWithUpload(StdStreamWithUpload):
def __init__(self, experiment):
super(StdErrWithUpload, self).__init__(experiment, "stderr", sys.__stderr__)
sys.stderr = self
def close(self):
sys.stderr = sys.__stderr__
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
from websocket import WebSocketConnectionClosedException
from neptune.legacy.internal.threads.neptune_thread import NeptuneThread
from neptune.legacy.internal.websockets.message import MessageType
from neptune.legacy.internal.websockets.websocket_message_processor import WebsocketMessageProcessor
class AbortingThread(NeptuneThread):
    """Daemon thread listening on a websocket for remote stop/abort messages."""

    def __init__(self, websocket_factory, abort_impl, experiment):
        super().__init__(is_daemon=True)
        self._abort_message_processor = AbortMessageProcessor(abort_impl, experiment)
        self._ws_client = websocket_factory.create(shutdown_condition=threading.Event())

    def run(self):
        try:
            while self.should_continue_running():
                raw_message = self._ws_client.recv()
                self._abort_message_processor.run(raw_message)
        except WebSocketConnectionClosedException:
            # Connection closed by the server; simply stop listening.
            pass

    def shutdown(self):
        self.interrupt()
        self._ws_client.shutdown()

    @staticmethod
    def _is_heartbeat(message):
        return message.strip() == ""
class AbortMessageProcessor(WebsocketMessageProcessor):
    """Handles STOP/ABORT websocket messages by stopping the experiment."""

    def __init__(self, abort_impl, experiment):
        super().__init__()
        self._abort_impl = abort_impl
        self._experiment = experiment
        self.received_abort_message = False

    def _process_message(self, message):
        message_type = message.get_type()
        if message_type == MessageType.STOP:
            self._experiment.stop()
            self._abort()
        elif message_type == MessageType.ABORT:
            self._experiment.stop("Remotely aborted")
            self._abort()

    def _abort(self):
        self.received_abort_message = True
        self._abort_impl.abort()
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
from bravado.exception import HTTPError
from neptune.legacy.exceptions import NeptuneException
from neptune.legacy.internal.threads.neptune_thread import NeptuneThread
_logger = logging.getLogger(__name__)
class HardwareMetricReportingThread(NeptuneThread):
    """Daemon thread that periodically reports and sends hardware metrics."""

    def __init__(self, metric_service, metric_sending_interval_seconds):
        super().__init__(is_daemon=True)
        self.__metric_service = metric_service
        self.__metric_sending_interval_seconds = metric_sending_interval_seconds

    def run(self):
        try:
            while self.should_continue_running():
                started_at = time.time()
                try:
                    self.__metric_service.report_and_send(timestamp=time.time())
                except (NeptuneException, HTTPError):
                    _logger.exception("Unexpected HTTP error in hardware metric reporting thread.")
                # Sleep only for what remains of the interval after reporting.
                elapsed = time.time() - started_at
                time.sleep(max(0, self.__metric_sending_interval_seconds - elapsed))
        except Exception as e:
            _logger.debug("Unexpected error in hardware metric reporting thread: %s", e)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import six
class NeptuneThread(threading.Thread):
    """Base class for Neptune background threads with cooperative interruption.

    Subclasses implement ``run`` and should poll ``should_continue_running()``
    so they stop both on ``interrupt()`` and when the main thread exits.
    """

    def __init__(self, is_daemon):
        super().__init__(target=self.run)
        self.daemon = is_daemon
        self._interrupted = threading.Event()

    def should_continue_running(self):
        """Return True until interrupted or the main thread has exited."""
        # The former `six.PY2` branch (scanning threading.enumerate() for
        # _MainThread) was dead code: this codebase already requires Python 3
        # (f-strings are used elsewhere), so the six dependency is dropped.
        return not self._interrupted.is_set() and threading.main_thread().is_alive()

    def interrupt(self):
        """Request the thread to stop at its next should_continue_running() check."""
        self._interrupted.set()

    def run(self):
        raise NotImplementedError()
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from bravado.exception import HTTPUnprocessableEntity
from neptune.legacy.internal.threads.neptune_thread import NeptuneThread
_logger = logging.getLogger(__name__)
class PingThread(NeptuneThread):
    """Daemon thread that periodically pings the experiment to keep it alive."""

    PING_INTERVAL_SECS = 5

    def __init__(self, backend, experiment):
        super().__init__(is_daemon=True)
        self.__backend = backend
        self.__experiment = experiment

    def run(self):
        while self.should_continue_running():
            try:
                self.__backend.ping_experiment(self.__experiment)
            except HTTPUnprocessableEntity:
                # A 422 means the experiment was already marked as completed,
                # so there is nothing left to keep alive.
                break
            except Exception:
                _logger.exception("Unexpected error in ping thread.")
            # Wait for the next ping, waking up immediately when interrupted.
            self._interrupted.wait(self.PING_INTERVAL_SECS)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from collections import namedtuple
from neptune.attributes import constants as alpha_consts
from neptune.internal import operation as alpha_operation
from neptune.internal.backends.api_model import AttributeType as AlphaAttributeType
# Alpha equivalent of old api's `KeyValueProperty` used in `Experiment.properties`
from neptune.internal.operation import ImageValue
from neptune.legacy.exceptions import NeptuneException
from neptune.legacy.internal.channels.channels import (
ChannelType,
ChannelValueType,
)
# Plain (key, value) record standing in for the old API's `KeyValueProperty`.
AlphaKeyValueProperty = namedtuple("AlphaKeyValueProperty", ["key", "value"])
class AlphaAttributeWrapper(abc.ABC):
    """Thin wrapper around an `AttributeDTO`."""

    # Subclasses override this with the attribute types they support.
    _allowed_atribute_types = []

    def __init__(self, attribute):
        """Wrap `attribute` (an `AttributeDTO`); reject unsupported types."""
        assert self._allowed_atribute_types is not None
        if not self.is_valid_attribute(attribute):
            raise NeptuneException(f"Invalid channel attribute type: {attribute.type}")
        self._attribute = attribute

    @classmethod
    def is_valid_attribute(cls, attribute):
        """Return True when this wrapper class supports `attribute`'s type."""
        return attribute.type in cls._allowed_atribute_types

    @property
    def _properties(self):
        """The type-specific properties object of the wrapped attribute."""
        return getattr(self._attribute, f"{self._attribute.type}Properties")
class AlphaPropertyDTO(AlphaAttributeWrapper):
    """Fakes the old `KeyValueProperty` on top of alpha string attributes.

    The alpha leaderboard has no properties at all, so string attributes
    stored in the properties namespace are exposed through the legacy
    property interface for backward compatibility.
    """

    _allowed_atribute_types = [
        AlphaAttributeType.STRING.value,
    ]

    @classmethod
    def is_valid_attribute(cls, attribute):
        """A property is a string attribute inside the properties namespace."""
        if not super().is_valid_attribute(attribute):
            return False
        return attribute.name.startswith(alpha_consts.PROPERTIES_ATTRIBUTE_SPACE)

    @property
    def key(self):
        return self._properties.attributeName.split("/", 1)[-1]

    @property
    def value(self):
        return self._properties.value
class AlphaParameterDTO(AlphaAttributeWrapper):
    """Fakes the old parameter DTO on top of alpha attributes.

    The alpha leaderboard stores parameters as plain attributes in the
    parameters namespace; this wrapper exposes them through the legacy
    parameter interface for backward compatibility.
    """

    _allowed_atribute_types = [
        AlphaAttributeType.FLOAT.value,
        AlphaAttributeType.STRING.value,
        AlphaAttributeType.DATETIME.value,
    ]

    @classmethod
    def is_valid_attribute(cls, attribute):
        """A parameter has a supported type and lives in the parameters namespace."""
        if not super().is_valid_attribute(attribute):
            return False
        return attribute.name.startswith(alpha_consts.PARAMETERS_ATTRIBUTE_SPACE)

    @property
    def name(self):
        return self._properties.attributeName.split("/", 1)[-1]

    @property
    def value(self):
        return self._properties.value

    @property
    def parameterType(self):
        if self._properties.attributeType == AlphaAttributeType.FLOAT.value:
            return "double"
        return "string"
class AlphaChannelDTO(AlphaAttributeWrapper):
    """Fakes the old `ChannelDTO` on top of alpha series attributes.

    The alpha leaderboard has no channels at all; series attributes are
    presented through the legacy channel interface for backward
    compatibility with old client code.
    """

    _allowed_atribute_types = [
        AlphaAttributeType.FLOAT_SERIES.value,
        AlphaAttributeType.STRING_SERIES.value,
        AlphaAttributeType.IMAGE_SERIES.value,
    ]

    @property
    def id(self):
        return self._properties.attributeName

    @property
    def name(self):
        return self._properties.attributeName.split("/", 1)[-1]

    @property
    def channelType(self):
        attribute_type = self._properties.attributeType
        if attribute_type == AlphaAttributeType.FLOAT_SERIES.value:
            return ChannelType.NUMERIC.value
        if attribute_type == AlphaAttributeType.STRING_SERIES.value:
            return ChannelType.TEXT.value
        if attribute_type == AlphaAttributeType.IMAGE_SERIES.value:
            return ChannelType.IMAGE.value

    @property
    def x(self):
        return self._properties.lastStep

    @property
    def y(self):
        # No last value is stored for image series.
        if self.channelType == ChannelType.IMAGE.value:
            return None
        return self._properties.last
class AlphaChannelWithValueDTO:
    """Fake of the old `ChannelWithValueDTO` for backward compatibility.

    The alpha leaderboard has no channels, so this plain value object
    stands in for the DTO expected by old client code. `x` and `y` are
    mutable; the identifying fields are read-only.
    """

    def __init__(self, channelId: str, channelName: str, channelType: str, x, y):
        self._channel_id = channelId
        self._channel_name = channelName
        self._channel_type = channelType
        self._x = x
        self._y = y

    @property
    def channelId(self):
        return self._channel_id

    @property
    def channelName(self):
        return self._channel_name

    @property
    def channelType(self):
        return self._channel_type

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, x):
        self._x = x

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, y):
        self._y = y
def _map_using_dict(el, el_name, source_dict) -> alpha_operation.Operation:
    """Return ``source_dict[el]``, raising NeptuneException for unsupported keys."""
    try:
        return source_dict[el]
    except KeyError as missing:
        raise NeptuneException(f"We're not supporting {el} {el_name}.") from missing
def channel_type_to_operation(channel_type: ChannelType) -> alpha_operation.Operation:
    """Map a legacy channel type to the corresponding alpha Log* operation."""
    operation_by_channel_type = {
        ChannelType.TEXT: alpha_operation.LogStrings,
        ChannelType.NUMERIC: alpha_operation.LogFloats,
        ChannelType.IMAGE: alpha_operation.LogImages,
    }
    return _map_using_dict(channel_type, "channel type", operation_by_channel_type)
def channel_type_to_clear_operation(
    channel_type: ChannelType,
) -> alpha_operation.Operation:
    """Map a legacy channel type to the corresponding alpha Clear* operation."""
    clear_operation_by_channel_type = {
        ChannelType.TEXT: alpha_operation.ClearStringLog,
        ChannelType.NUMERIC: alpha_operation.ClearFloatLog,
        ChannelType.IMAGE: alpha_operation.ClearImageLog,
    }
    return _map_using_dict(channel_type, "channel type", clear_operation_by_channel_type)
def channel_value_type_to_operation(
    channel_value_type: ChannelValueType,
) -> alpha_operation.Operation:
    """Map a legacy channel value type to the corresponding alpha Log* operation."""
    operation_by_value_type = {
        ChannelValueType.TEXT_VALUE: alpha_operation.LogStrings,
        ChannelValueType.NUMERIC_VALUE: alpha_operation.LogFloats,
        ChannelValueType.IMAGE_VALUE: alpha_operation.LogImages,
    }
    return _map_using_dict(channel_value_type, "channel value type", operation_by_value_type)
def deprecated_img_to_alpha_image(img: dict) -> ImageValue:
    """Convert a legacy image dict (data/name/description keys) to an ImageValue."""
    return ImageValue(
        data=img["data"],
        name=img["name"],
        description=img["description"],
    )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import wraps
from neptune.common.warnings import warn_once
def legacy_client_deprecation(func):
    """Decorator that emits a one-time deprecation warning for the legacy client."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        warn_once(
            message="You're using a legacy version of Neptune client."
            " It will be moved to `neptune.legacy` as of `neptune-client==1.0.0`."
        )
        return func(*args, **kwargs)

    return wrapper
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from functools import wraps
from http.client import (
NOT_FOUND,
UNPROCESSABLE_ENTITY,
)
from requests.exceptions import HTTPError
from neptune.legacy.api_exceptions import (
ExperimentNotFound,
StorageLimitReached,
)
from neptune.legacy.exceptions import NeptuneException
_logger = logging.getLogger(__name__)
def extract_response_field(response, field_name):
    """Safely read `field_name` from a JSON response body.

    Returns None when the response is missing, its body is not valid JSON,
    or the parsed JSON is not a dict.
    """
    if response is None:
        return None
    try:
        parsed = response.json()
    except ValueError as parse_error:
        _logger.debug("Failed to parse HTTP response: %s", parse_error)
        return None
    if isinstance(parsed, dict):
        return parsed.get(field_name)
    _logger.debug("HTTP response is not a dict: %s", str(parsed))
    return None
def handle_quota_limits(f):
    """Wrap a backend call, translating HTTP errors for an experiment.

    Limitations:
        The decorated function must be called with the experiment passed by
        keyword, i.e. fun(..., experiment=<experiment>, ...).

    Raises (from the wrapper):
        NeptuneException: when `experiment` is not passed by keyword.
        ExperimentNotFound: on HTTP 404.
        StorageLimitReached: on HTTP 422 caused by the storage quota.
    """

    @wraps(f)
    def handler(*args, **kwargs):
        experiment = kwargs.get("experiment")
        if experiment is None:
            raise NeptuneException(
                "This function must be called with experiment passed by name,"
                " like this fun(..., experiment=<experiment>, ...)"
            )
        try:
            return f(*args, **kwargs)
        except HTTPError as e:
            if e.response.status_code == NOT_FOUND:
                raise ExperimentNotFound(
                    experiment_short_id=experiment.id,
                    project_qualified_name=experiment._project.full_id,
                )
            # extract_response_field may return None (unparsable body or
            # non-dict JSON); the original called .startswith on it directly
            # and could raise AttributeError instead of re-raising the error.
            title = extract_response_field(e.response, "title") or ""
            if e.response.status_code == UNPROCESSABLE_ENTITY and title.startswith(
                "Storage limit reached in organization: "
            ):
                raise StorageLimitReached()
            raise

    return handler
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from functools import wraps
from http.client import (
NOT_FOUND,
UNPROCESSABLE_ENTITY,
)
from requests.exceptions import HTTPError
from neptune.legacy.api_exceptions import (
ExperimentNotFound,
StorageLimitReached,
)
from neptune.legacy.exceptions import NeptuneException
_logger = logging.getLogger(__name__)
def extract_response_field(response, field_name):
    """Safely read `field_name` from a JSON response body.

    Returns None when the response is missing, its body is not valid JSON,
    or the parsed JSON is not a dict.
    """
    if response is None:
        return None
    try:
        parsed = response.json()
    except ValueError as parse_error:
        _logger.debug("Failed to parse HTTP response: %s", parse_error)
        return None
    if isinstance(parsed, dict):
        return parsed.get(field_name)
    _logger.debug("HTTP response is not a dict: %s", str(parsed))
    return None
def handle_quota_limits(f):
    """Wrap a backend call, translating HTTP errors for an experiment.

    Limitations:
        The decorated function must be called with the experiment passed by
        keyword, i.e. fun(..., experiment=<experiment>, ...).

    Raises (from the wrapper):
        NeptuneException: when `experiment` is not passed by keyword.
        ExperimentNotFound: on HTTP 404.
        StorageLimitReached: on HTTP 422 caused by the storage quota.
    """

    @wraps(f)
    def handler(*args, **kwargs):
        experiment = kwargs.get("experiment")
        if experiment is None:
            raise NeptuneException(
                "This function must be called with experiment passed by name,"
                " like this fun(..., experiment=<experiment>, ...)"
            )
        try:
            return f(*args, **kwargs)
        except HTTPError as e:
            if e.response.status_code == NOT_FOUND:
                raise ExperimentNotFound(
                    experiment_short_id=experiment.id,
                    project_qualified_name=experiment._project.full_id,
                )
            # extract_response_field may return None (unparsable body or
            # non-dict JSON); the original called .startswith on it directly
            # and could raise AttributeError instead of re-raising the error.
            title = extract_response_field(e.response, "title") or ""
            if e.response.status_code == UNPROCESSABLE_ENTITY and title.startswith(
                "Storage limit reached in organization: "
            ):
                raise StorageLimitReached()
            raise

    return handler
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import os
import numpy
import six
from PIL import Image
from neptune.legacy.exceptions import (
FileNotFound,
InvalidChannelValue,
)
def get_image_content(image):
    """Return raw image bytes for any of the supported `image` inputs.

    Supported inputs: a file path (read verbatim), a numpy array (values
    assumed in [0, 1] -- TODO confirm with callers), a PIL image, a
    matplotlib figure, a torch tensor or a tensorflow tensor.

    Raises:
        FileNotFound: when a path is given but does not exist.
        InvalidChannelValue: for unsupported input types.
    """
    # `six.string_types` replaced with `str`: this codebase is Python 3 only
    # (f-strings are used elsewhere), so the six indirection was dead weight.
    if isinstance(image, str):
        if not os.path.exists(image):
            raise FileNotFound(image)
        with open(image, "rb") as image_file:
            return image_file.read()
    elif isinstance(image, numpy.ndarray):
        return _get_numpy_as_image(image)
    elif isinstance(image, Image.Image):
        return _get_pil_image_data(image)
    else:
        # Optional backends are imported lazily so none of them is a hard
        # dependency; a missing package simply skips that branch.
        try:
            from matplotlib import figure

            if isinstance(image, figure.Figure):
                return _get_figure_as_image(image)
        except ImportError:
            pass
        try:
            from torch import Tensor as TorchTensor

            if isinstance(image, TorchTensor):
                return _get_numpy_as_image(image.detach().numpy())
        except ImportError:
            pass
        try:
            from tensorflow import Tensor as TensorflowTensor

            if isinstance(image, TensorflowTensor):
                return _get_numpy_as_image(image.numpy())
        except ImportError:
            pass
        raise InvalidChannelValue(expected_type="image", actual_type=type(image).__name__)
def _get_figure_as_image(figure):
if figure.__class__.__name__ == "Axes":
figure = figure.figure
with io.BytesIO() as image_buffer:
figure.savefig(image_buffer, format="png", bbox_inches="tight")
return image_buffer.getvalue()
def _get_pil_image_data(image):
with io.BytesIO() as image_buffer:
image.save(image_buffer, format="PNG")
return image_buffer.getvalue()
def _get_numpy_as_image(array):
array = array.copy() # prevent original array from modifying
array *= 255
shape = array.shape
if len(shape) == 2:
return _get_pil_image_data(Image.fromarray(array.astype(numpy.uint8)))
if len(shape) == 3:
if shape[2] == 1:
array2d = numpy.array([[col[0] for col in row] for row in array])
return _get_pil_image_data(Image.fromarray(array2d.astype(numpy.uint8)))
if shape[2] in (3, 4):
return _get_pil_image_data(Image.fromarray(array.astype(numpy.uint8)))
raise ValueError(
"Incorrect size of numpy.ndarray. Should be 2-dimensional or"
" 3-dimensional with 3rd dimension of size 1, 3 or 4."
)
#
# Copyright (c) 2021, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import os.path
import sys
from typing import (
List,
Optional,
Tuple,
)
from neptune.common.storage.storage_utils import normalize_file_name
from neptune.common.utils import (
glob,
is_ipython,
)
def get_source_code_to_upload(
    upload_source_files: Optional[List[str]],
) -> Tuple[str, List[Tuple[str, str]]]:
    """Resolve the entrypoint and (absolute, upload-name) source file pairs.

    Args:
        upload_source_files: Glob patterns of files to upload, or None to
            upload just the script being executed (if any).

    Returns:
        Tuple of (entrypoint, [(absolute_path, upload_target_name), ...]);
        entrypoint is None when running under IPython or without a script.
    """
    source_target_pairs = []
    if is_ipython():
        # No script file exists when running inside a notebook.
        main_file = None
        entrypoint = None
    else:
        main_file = sys.argv[0]
        entrypoint = main_file or None
    if upload_source_files is None:
        # Default behaviour: upload only the executed script.
        if main_file is not None and os.path.isfile(main_file):
            entrypoint = normalize_file_name(os.path.basename(main_file))
            source_target_pairs = [
                (
                    os.path.abspath(main_file),
                    normalize_file_name(os.path.basename(main_file)),
                )
            ]
    else:
        expanded_source_files = set()
        for filepath in upload_source_files:
            expanded_source_files |= set(glob(filepath))
        # NOTE: the former Python < 3.5 fallback (which rejected paths
        # outside the current directory) was removed as dead code -- this
        # codebase already requires Python 3 features such as f-strings,
        # and os.path.commonpath is available since 3.4.
        absolute_paths = [os.path.abspath(filepath) for filepath in expanded_source_files]
        try:
            common_source_root = os.path.commonpath(absolute_paths)
        except ValueError:
            # Paths on different drives (or an empty list): upload each file
            # under its full normalized absolute path.
            for absolute_path in absolute_paths:
                source_target_pairs.append((absolute_path, normalize_file_name(absolute_path)))
        else:
            if os.path.isfile(common_source_root):
                common_source_root = os.path.dirname(common_source_root)
            if common_source_root.startswith(os.getcwd() + os.sep):
                common_source_root = os.getcwd()
            for absolute_path in absolute_paths:
                source_target_pairs.append(
                    (
                        absolute_path,
                        normalize_file_name(os.path.relpath(absolute_path, common_source_root)),
                    )
                )
    return entrypoint, source_target_pairs
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.attributes.constants import (
SIGNAL_TYPE_ABORT,
SIGNAL_TYPE_STOP,
)
class Message(object):
    """Base class for control messages received over the Neptune websocket."""

    # Legacy and current JSON key names for the message type and body.
    MESSAGE_TYPE = "messageType"
    MESSAGE_NEW_TYPE = "type"
    MESSAGE_BODY = "messageBody"
    MESSAGE_NEW_BODY = "body"

    def __init__(self):
        pass

    @classmethod
    def from_json(cls, json_value):
        """Deserialize a concrete Message subclass from a decoded JSON dict."""
        raw_type = json_value.get(Message.MESSAGE_TYPE) or json_value.get(Message.MESSAGE_NEW_TYPE)
        raw_body = json_value.get(Message.MESSAGE_BODY) or json_value.get(Message.MESSAGE_NEW_BODY)
        # Map transport-level signal identifiers onto internal message types.
        if raw_type == SIGNAL_TYPE_STOP:
            raw_type = MessageType.STOP
        elif raw_type == SIGNAL_TYPE_ABORT:
            raw_type = MessageType.ABORT
        message_class = MessageClassRegistry.MESSAGE_CLASSES.get(raw_type)
        if message_class is None:
            raise ValueError("Unknown message type '{}'!".format(raw_type))
        return message_class.from_json(raw_body)

    @classmethod
    def get_type(cls):
        raise NotImplementedError()

    def body_to_json(self):
        raise NotImplementedError()
class AbortMessage(Message):
    """Message requesting that the running experiment be aborted."""

    @classmethod
    def get_type(cls):
        return MessageType.ABORT

    @classmethod
    def from_json(cls, json_value):
        # Abort messages carry no payload.
        return cls()

    def body_to_json(self):
        return None
class StopMessage(Message):
    """Message requesting a graceful stop of the running experiment."""

    @classmethod
    def get_type(cls):
        return MessageType.STOP

    @classmethod
    def from_json(cls, json_value):
        # Stop messages carry no payload.
        return cls()

    def body_to_json(self):
        return None
class ActionInvocationMessage(Message):
    """Message that triggers invocation of a registered action with an argument."""

    _ACTION_ID_JSON_KEY = "actionId"
    _ACTION_INVOCATION_ID_JSON_KEY = "actionInvocationId"
    _ARGUMENT_JSON_KEY = "argument"

    def __init__(self, action_id, action_invocation_id, argument):
        super(ActionInvocationMessage, self).__init__()
        self.action_id = action_id
        self.action_invocation_id = action_invocation_id
        self.argument = argument

    @classmethod
    def get_type(cls):
        return MessageType.ACTION_INVOCATION

    @classmethod
    def from_json(cls, json_value):
        # Fields are read in declaration order; a missing key raises KeyError.
        return cls(
            json_value[cls._ACTION_ID_JSON_KEY],
            json_value[cls._ACTION_INVOCATION_ID_JSON_KEY],
            json_value[cls._ARGUMENT_JSON_KEY],
        )

    def body_to_json(self):
        return {
            self._ACTION_ID_JSON_KEY: self.action_id,
            self._ACTION_INVOCATION_ID_JSON_KEY: self.action_invocation_id,
            self._ARGUMENT_JSON_KEY: self.argument,
        }
class MessageType(object):
    """String identifiers of the websocket message kinds."""

    ABORT = "Abort"
    STOP = "Stop"
    ACTION_INVOCATION = "InvokeAction"
    NEW_CHANNEL_VALUES = "NewChannelValues"
class MessageClassRegistry(object):
    """Maps message-type identifiers to their Message subclasses.

    The registry is built once, at class-creation time, from the direct
    subclasses of ``Message`` defined above.
    """

    MESSAGE_CLASSES = {message_cls.get_type(): message_cls for message_cls in Message.__subclasses__()}

    def __init__(self):
        pass
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.common.websockets.reconnecting_websocket import ReconnectingWebsocket
class ReconnectingWebsocketFactory(object):
    """Builds ReconnectingWebsocket instances bound to a backend and a URL."""

    def __init__(self, backend, url):
        self._backend = backend
        self._url = url

    def create(self, shutdown_condition):
        """Return a new websocket that signals shutdown via *shutdown_condition*."""
        oauth2_session = self._backend.authenticator.auth.session
        return ReconnectingWebsocket(
            url=self._url,
            oauth2_session=oauth2_session,
            shutdown_event=shutdown_condition,
            proxies=self._backend.proxies,
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from neptune.legacy.internal.websockets.message import Message
class WebsocketMessageProcessor(object):
    """Parses raw websocket frames and dispatches the decoded messages."""

    def __init__(self):
        pass

    def run(self, raw_message):
        """Decode *raw_message* and hand it to ``_process_message``.

        The Atmosphere framework sends whitespace-only heartbeat frames every
        minute; those (and ``None``) are silently ignored.
        """
        if raw_message is None or self._is_heartbeat(raw_message):
            return
        self._process_message(Message.from_json(json.loads(raw_message)))

    def _process_message(self, message):
        # Subclasses define what to do with a decoded message.
        raise NotImplementedError()

    @staticmethod
    def _is_heartbeat(message):
        return not message.strip()
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ChannelWithLastValue:
    """Thin wrapper exposing a channel-with-last-value DTO via properties."""

    def __init__(self, channel_with_value_dto):
        self.channel_with_value_dto = channel_with_value_dto

    @property
    def id(self):
        return self.channel_with_value_dto.channelId

    @property
    def name(self):
        return self.channel_with_value_dto.channelName

    @property
    def type(self):
        return self.channel_with_value_dto.channelType

    @property
    def x(self):
        return self.channel_with_value_dto.x

    @x.setter
    def x(self, x):
        self.channel_with_value_dto.x = x

    @property
    def y(self):
        return self.channel_with_value_dto.y

    @y.setter
    def y(self, y):
        self.channel_with_value_dto.y = y

    @property
    def trimmed_y(self):
        # Text channel values are capped at 255 characters; other types pass through.
        value = self.y
        return value[:255] if self.type == "text" else value
class LeaderboardEntry(object):
    """Wraps a project-leaderboard-entry DTO with convenient accessors."""

    def __init__(self, project_leaderboard_entry_dto):
        self.project_leaderboard_entry_dto = project_leaderboard_entry_dto

    @property
    def id(self):
        return self.project_leaderboard_entry_dto.shortId

    @property
    def name(self):
        return self.project_leaderboard_entry_dto.name

    @property
    def state(self):
        return self.project_leaderboard_entry_dto.state

    @property
    def internal_id(self):
        return self.project_leaderboard_entry_dto.id

    @property
    def project_full_id(self):
        dto = self.project_leaderboard_entry_dto
        return "{org_name}/{project_name}".format(
            org_name=dto.organizationName,
            project_name=dto.projectName,
        )

    @property
    def system_properties(self):
        dto = self.project_leaderboard_entry_dto
        return {
            "id": dto.shortId,
            "name": dto.name,
            "created": dto.timeOfCreation,
            "finished": dto.timeOfCompletion,
            "running_time": dto.runningTime,
            "owner": dto.owner,
            "size": dto.size,
            "tags": dto.tags,
            "notes": dto.description,
        }

    @property
    def channels(self):
        return [ChannelWithLastValue(dto) for dto in self.project_leaderboard_entry_dto.channelsLastValues]

    def add_channel(self, channel):
        # Stores the underlying DTO, not the wrapper.
        self.project_leaderboard_entry_dto.channelsLastValues.append(channel.channel_with_value_dto)

    @property
    def channels_dict_by_name(self):
        return {ch.name: ch for ch in self.channels}

    @property
    def parameters(self):
        return {p.name: p.value for p in self.project_leaderboard_entry_dto.parameters}

    @property
    def properties(self):
        return {p.key: p.value for p in self.project_leaderboard_entry_dto.properties}

    @property
    def tags(self):
        return self.project_leaderboard_entry_dto.tags
class Point(object):
    """Read-only view of a single chart-point DTO."""

    def __init__(self, point_dto):
        self.point_dto = point_dto

    @property
    def x(self):
        """X coordinate of the point."""
        return self.point_dto.x

    @property
    def numeric_y(self):
        """Numeric Y value of the point."""
        return self.point_dto.y.numericValue
class Points(object):
    """Read-only view of a sequence of chart-point DTOs."""

    def __init__(self, point_dtos):
        self.point_dtos = point_dtos

    @property
    def xs(self):
        """List of the X coordinates, in DTO order."""
        return [dto.x for dto in self.point_dtos]

    @property
    def numeric_ys(self):
        """List of the numeric Y values, in DTO order."""
        return [dto.y.numericValue for dto in self.point_dtos]
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from neptune.common.utils import validate_notebook_path
from neptune.legacy.internal.utils.deprecation import legacy_client_deprecation
class Notebook(object):
    """It contains all the information about a Neptune Notebook

    Args:
        backend (:class:`~neptune.ApiClient`): A ApiClient object
        project (:class:`~neptune.projects.Project`): Project object
        _id (:obj:`str`): Notebook id
        owner (:obj:`str`): Creator of the notebook is the Notebook owner

    Examples:
        .. code:: python3

            # Create a notebook in Neptune.
            notebook = project.create_notebook('data_exploration.ipynb')
    """

    @legacy_client_deprecation
    def __init__(self, backend, project, _id, owner):
        self._backend = backend
        self._project = project
        self._id = _id
        self._owner = owner

    @property
    def id(self):
        return self._id

    @property
    def owner(self):
        return self._owner

    def add_checkpoint(self, file_path):
        """Uploads new checkpoint of the notebook to Neptune

        Args:
            file_path (:obj:`str`): File path containing notebook contents

        Example:
            .. code:: python3

                # Create a notebook.
                notebook = project.create_notebook('file.ipynb')

                # Change content in your notebook & save it

                # Upload new checkpoint
                notebook.add_checkpoint('file.ipynb')
        """
        validate_notebook_path(file_path)
        with open(file_path) as source:
            return self._backend.create_checkpoint(self.id, os.path.abspath(file_path), source)

    def _last_checkpoint(self):
        # Fetch the most recent checkpoint DTO from the backend.
        return self._backend.get_last_checkpoint(self._project, self._id)

    def get_path(self):
        """Returns the path used to upload the current checkpoint of this notebook

        Returns:
            :obj:`str`: path of the current checkpoint
        """
        return self._last_checkpoint().path

    def get_name(self):
        """Returns the name used to upload the current checkpoint of this notebook

        Returns:
            :obj:`str`: the name of current checkpoint
        """
        return self._last_checkpoint().name
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# flake8: noqa
from neptune.common.oauth import (
NeptuneAuth,
NeptuneAuthenticator,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# flake8: noqa
from neptune.common.patterns import PROJECT_QUALIFIED_NAME_PATTERN
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import atexit
import logging
import os
import os.path
import threading
from platform import node as get_hostname
import click
import pandas as pd
import six
from neptune.common.utils import (
as_list,
discover_git_repo_location,
get_git_info,
map_keys,
)
from neptune.legacy.envs import (
NOTEBOOK_ID_ENV_NAME,
NOTEBOOK_PATH_ENV_NAME,
)
from neptune.legacy.exceptions import NeptuneNoExperimentContextException
from neptune.legacy.experiments import Experiment
from neptune.legacy.internal.abort import DefaultAbortImpl
from neptune.legacy.internal.notebooks.notebooks import create_checkpoint
from neptune.legacy.internal.utils.deprecation import legacy_client_deprecation
from neptune.legacy.internal.utils.source_code import get_source_code_to_upload
_logger = logging.getLogger(__name__)
class Project(object):
    """A class for storing information and managing Neptune project.

    Args:
        backend (:class:`~neptune.ApiClient`, required): A ApiClient object.
        internal_id (:obj:`str`, required): ID of the project.
        namespace (:obj:`str`, required): It can either be your workspace or user name.
        name (:obj:`str`, required): project name.

    Note:
        ``namespace`` and ``name`` joined together with ``/`` form ``project_qualified_name``.
    """

    @legacy_client_deprecation
    def __init__(self, backend, internal_id, namespace, name):
        self._backend = backend
        self.internal_id = internal_id
        self.namespace = namespace
        self.name = name
        # Stack of currently running experiments; the top entry is the "current" one.
        self._experiments_stack = []
        # Re-entrant lock guarding the experiments stack.
        self.__lock = threading.RLock()
        # Ensure experiments started from this project are stopped at interpreter exit.
        atexit.register(self._shutdown_hook)

    def get_members(self):
        """Retrieve a list of project members.

        Returns:
            :obj:`list` of :obj:`str` - A list of usernames of project members.

        Examples:

            .. code:: python3

                project = session.get_projects('neptune-ai')['neptune-ai/Salt-Detection']
                project.get_members()
        """
        project_members = self._backend.get_project_members(self.internal_id)
        # Invited-but-unregistered members have no registeredMemberInfo; skip them.
        return [member.registeredMemberInfo.username for member in project_members if member.registeredMemberInfo]

    def get_experiments(self, id=None, state=None, owner=None, tag=None, min_running_time=None):
        """Retrieve list of experiments matching the specified criteria.

        All parameters are optional, each of them specifies a single criterion.
        Only experiments matching all of the criteria will be returned.

        Args:
            id (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment id like ``'SAN-1'`` or list of ids like ``['SAN-1', 'SAN-2']``.
                | Matching any element of the list is sufficient to pass criterion.
            state (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment state like ``'succeeded'`` or list of states like ``['succeeded', 'running']``.
                | Possible values: ``'running'``, ``'succeeded'``, ``'failed'``, ``'aborted'``.
                | Matching any element of the list is sufficient to pass criterion.
            owner (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | *Username* of the experiment owner (User who created experiment is an owner) like ``'josh'``
                  or list of owners like ``['frederic', 'josh']``.
                | Matching any element of the list is sufficient to pass criterion.
            tag (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment tag like ``'lightGBM'`` or list of tags like ``['pytorch', 'cycleLR']``.
                | Only experiments that have all specified tags will match this criterion.
            min_running_time (:obj:`int`, optional, default is ``None``):
                Minimum running time of an experiment in seconds, like ``2000``.

        Returns:
            :obj:`list` of :class:`~neptune.experiments.Experiment` objects.

        Examples:

            .. code:: python3

                # Fetch a project
                project = session.get_projects('neptune-ai')['neptune-ai/Salt-Detection']

                # Get list of experiments
                project.get_experiments(state=['aborted'], owner=['neyo'], min_running_time=100000)

                # Example output:
                # [Experiment(SAL-1609),
                #  Experiment(SAL-1765),
                #  Experiment(SAL-1941),
                #  Experiment(SAL-1960),
                #  Experiment(SAL-2025)]
        """
        leaderboard_entries = self._fetch_leaderboard(id, state, owner, tag, min_running_time)
        return [Experiment(self._backend, self, entry.id, entry.internal_id) for entry in leaderboard_entries]

    def get_leaderboard(self, id=None, state=None, owner=None, tag=None, min_running_time=None):
        """Fetch Neptune experiments view as pandas ``DataFrame``.

        **returned DataFrame**

        | In the returned ``DataFrame`` each *row* is an experiment and *columns* represent all system properties,
          numeric and text logs, parameters and properties in these experiments.
        | Note that, returned ``DataFrame`` does not contain all columns across the entire project.
        | Some columns may be empty, since experiments may define various logs, properties, etc.
        | For each log at most one (the last one) value is returned per experiment.
        | Text values are trimmed to 255 characters.

        **about parameters**

        All parameters are optional, each of them specifies a single criterion.
        Only experiments matching all of the criteria will be returned.

        Args:
            id (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment id like ``'SAN-1'`` or list of ids like ``['SAN-1', 'SAN-2']``.
                | Matching any element of the list is sufficient to pass criterion.
            state (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment state like ``'succeeded'`` or list of states like ``['succeeded', 'running']``.
                | Possible values: ``'running'``, ``'succeeded'``, ``'failed'``, ``'aborted'``.
                | Matching any element of the list is sufficient to pass criterion.
            owner (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | *Username* of the experiment owner (User who created experiment is an owner) like ``'josh'``
                  or list of owners like ``['frederic', 'josh']``.
                | Matching any element of the list is sufficient to pass criterion.
            tag (:obj:`str` or :obj:`list` of :obj:`str`, optional, default is ``None``):
                | An experiment tag like ``'lightGBM'`` or list of tags like ``['pytorch', 'cycleLR']``.
                | Only experiments that have all specified tags will match this criterion.
            min_running_time (:obj:`int`, optional, default is ``None``):
                Minimum running time of an experiment in seconds, like ``2000``.

        Returns:
            :obj:`pandas.DataFrame` - Fetched Neptune experiments view.

        Examples:

            .. code:: python3

                # Fetch a project.
                project = session.get_projects('neptune-ai')['neptune-ai/Salt-Detection']

                # Get DataFrame that resembles experiment view.
                project.get_leaderboard(state=['aborted'], owner=['neyo'], min_running_time=100000)
        """
        leaderboard_entries = self._fetch_leaderboard(id, state, owner, tag, min_running_time)

        def make_row(entry):
            # Flatten one leaderboard entry into a single prefixed dict:
            # channel_*, parameter_* and property_* columns plus system properties.
            channels = dict(("channel_{}".format(ch.name), ch.trimmed_y) for ch in entry.channels)
            parameters = map_keys("parameter_{}".format, entry.parameters)
            properties = map_keys("property_{}".format, entry.properties)
            r = {}
            r.update(entry.system_properties)
            r.update(channels)
            r.update(parameters)
            r.update(properties)
            return r

        rows = ((n, make_row(e)) for (n, e) in enumerate(leaderboard_entries))
        df = pd.DataFrame.from_dict(data=dict(rows), orient="index")
        # Order columns: system properties first, then channels/parameters/properties.
        df = df.reindex(self._sort_leaderboard_columns(df.columns), axis="columns")
        return df

    def create_experiment(
        self,
        name=None,
        description=None,
        params=None,
        properties=None,
        tags=None,
        upload_source_files=None,
        abort_callback=None,
        logger=None,
        upload_stdout=True,
        upload_stderr=True,
        send_hardware_metrics=True,
        run_monitoring_thread=True,
        handle_uncaught_exceptions=True,
        git_info=None,
        hostname=None,
        notebook_id=None,
        notebook_path=None,
    ):
        """Create and start Neptune experiment.

        Create experiment, set its status to `running` and append it to the top of the experiments view.
        All parameters are optional, hence minimal invocation: ``neptune.create_experiment()``.

        Args:
            name (:obj:`str`, optional, default is ``'Untitled'``):
                Editable name of the experiment.
                Name is displayed in the experiment's `Details` (`Metadata` section)
                and in `experiments view` as a column.
            description (:obj:`str`, optional, default is ``''``):
                Editable description of the experiment.
                Description is displayed in the experiment's `Details` (`Metadata` section)
                and can be displayed in the `experiments view` as a column.
            params (:obj:`dict`, optional, default is ``{}``):
                Parameters of the experiment.
                After experiment creation ``params`` are read-only
                (see: :meth:`~neptune.experiments.Experiment.get_parameters`).
                Parameters are displayed in the experiment's `Details` (`Parameters` section)
                and each key-value pair can be viewed in `experiments view` as a column.
            properties (:obj:`dict`, optional, default is ``{}``):
                Properties of the experiment.
                They are editable after experiment is created.
                Properties are displayed in the experiment's `Details` (`Properties` section)
                and each key-value pair can be viewed in `experiments view` as a column.
            tags (:obj:`list`, optional, default is ``[]``):
                Must be list of :obj:`str`. Tags of the experiment.
                They are editable after experiment is created
                (see: :meth:`~neptune.experiments.Experiment.append_tag`
                and :meth:`~neptune.experiments.Experiment.remove_tag`).
                Tags are displayed in the experiment's `Details` (`Metadata` section)
                and can be viewed in `experiments view` as a column.
            upload_source_files (:obj:`list` or :obj:`str`, optional, default is ``None``):
                List of source files to be uploaded. Must be list of :obj:`str` or single :obj:`str`.
                Uploaded sources are displayed in the experiment's `Source code` tab.

                | If ``None`` is passed, Python file from which experiment was created will be uploaded.
                | Pass empty list (``[]``) to upload no files.
                | Unix style pathname pattern expansion is supported. For example, you can pass ``'*.py'`` to upload
                  all python source files from the current directory.

                For Python 3.5 or later, paths of uploaded files on server are resolved as relative to the
                | calculated common root of all uploaded source files. For older Python versions, paths on server are
                | resolved always as relative to the current directory.

                For recursion lookup use ``'**/*.py'`` (for Python 3.5 and later).
                For more information see `glob library <https://docs.python.org/3/library/glob.html>`_.
            abort_callback (:obj:`callable`, optional, default is ``None``):
                Callback that defines how `abort experiment` action in the Web application should work.
                Actual behavior depends on your setup:

                    * (default) If ``abort_callback=None`` and `psutil <https://psutil.readthedocs.io/en/latest/>`_
                      is installed, then current process and it's children are aborted by sending `SIGTERM`.
                      If, after grace period, processes are not terminated, `SIGKILL` is sent.
                    * If ``abort_callback=None`` and `psutil <https://psutil.readthedocs.io/en/latest/>`_
                      is **not** installed, then `abort experiment` action just marks experiment as *aborted*
                      in the Web application. No action is performed on the current process.
                    * If ``abort_callback=callable``, then ``callable`` is executed when `abort experiment` action
                      in the Web application is triggered.
            logger (:obj:`logging.Logger` or `None`, optional, default is ``None``):
                If Python's `Logger <https://docs.python.org/3/library/logging.html#logging.Logger>`_
                is passed, new experiment's `text log`
                (see: :meth:`~neptune.experiments.Experiment.log_text`) with name `"logger"` is created.
                Each time `Python logger` logs new data, it is automatically sent to the `"logger"` in experiment.
                As a results all data from `Python logger` are in the `Logs` tab in the experiment.
            upload_stdout (:obj:`Boolean`, optional, default is ``True``):
                Whether to send stdout to experiment's *Monitoring*.
            upload_stderr (:obj:`Boolean`, optional, default is ``True``):
                Whether to send stderr to experiment's *Monitoring*.
            send_hardware_metrics (:obj:`Boolean`, optional, default is ``True``):
                Whether to send hardware monitoring logs (CPU, GPU, Memory utilization) to experiment's *Monitoring*.
            run_monitoring_thread (:obj:`Boolean`, optional, default is ``True``):
                Whether to run thread that pings Neptune server in order to determine if experiment is responsive.
            handle_uncaught_exceptions (:obj:`Boolean`, optional, default is ``True``):
                Two options ``True`` and ``False`` are possible:

                    * If set to ``True`` and uncaught exception occurs, then Neptune automatically place
                      `Traceback` in the experiment's `Details` and change experiment status to `Failed`.
                    * If set to ``False`` and uncaught exception occurs, then no action is performed
                      in the Web application. As a consequence, experiment's status is `running` or `not responding`.
            git_info (:class:`~neptune.git_info.GitInfo`, optional, default is ``None``):

                | Instance of the class :class:`~neptune.git_info.GitInfo` that provides information about
                  the git repository from which experiment was started.
                | If ``None`` is passed,
                  system attempts to automatically extract information about git repository in the following way:

                      * System looks for `.git` file in the current directory and, if not found,
                        goes up recursively until `.git` file will be found
                        (see: :meth:`~neptune.utils.get_git_info`).
                      * If there is no git repository,
                        then no information about git is displayed in experiment details in Neptune web application.
            hostname (:obj:`str`, optional, default is ``None``):
                If ``None``, neptune.legacy automatically get `hostname` information.
                User can also set `hostname` directly by passing :obj:`str`.

        Returns:
            :class:`~neptune.experiments.Experiment` object that is used to manage experiment and log data to it.

        Raises:
            `ExperimentValidationError`: When provided arguments are invalid.
            `ExperimentLimitReached`: When experiment limit in the project has been reached.

        Examples:

            .. code:: python3

                # minimal invoke
                neptune.create_experiment()

                # explicitly return experiment object
                experiment = neptune.create_experiment()

                # create experiment with name and two parameters
                neptune.create_experiment(name='first-pytorch-ever',
                                          params={'lr': 0.0005,
                                                  'dropout': 0.2})

                # create experiment with name and description, and no sources files uploaded
                neptune.create_experiment(name='neural-net-mnist',
                                          description='neural net trained on MNIST',
                                          upload_source_files=[])

                # Send all py files in cwd (excluding hidden files with names beginning with a dot)
                neptune.create_experiment(upload_source_files='*.py')

                # Send all py files from all subdirectories (excluding hidden files with names beginning with a dot)
                # Supported on Python 3.5 and later.
                neptune.create_experiment(upload_source_files='**/*.py')

                # Send all files and directories in cwd (excluding hidden files with names beginning with a dot)
                neptune.create_experiment(upload_source_files='*')

                # Send all files and directories in cwd including hidden files
                neptune.create_experiment(upload_source_files=['*', '.*'])

                # Send files with names being a single character followed by '.py' extension.
                neptune.create_experiment(upload_source_files='?.py')

                # larger example
                neptune.create_experiment(name='first-pytorch-ever',
                                          params={'lr': 0.0005,
                                                  'dropout': 0.2},
                                          properties={'key1': 'value1',
                                                      'key2': 17,
                                                      'key3': 'other-value'},
                                          description='write longer description here',
                                          tags=['list-of', 'tags', 'goes-here', 'as-list-of-strings'],
                                          upload_source_files=['training_with_pytorch.py', 'net.py'])
        """
        # Fill in defaults for all optional arguments.
        if name is None:
            name = "Untitled"
        if description is None:
            description = ""
        if params is None:
            params = {}
        if properties is None:
            properties = {}
        if tags is None:
            tags = []
        if git_info is None:
            git_info = get_git_info(discover_git_repo_location())
        if hostname is None:
            hostname = get_hostname()
        # When running inside a notebook, its id/path may be provided via env vars.
        if notebook_id is None and os.getenv(NOTEBOOK_ID_ENV_NAME, None) is not None:
            notebook_id = os.environ[NOTEBOOK_ID_ENV_NAME]
        # A single glob string is normalized to a one-element list.
        if isinstance(upload_source_files, six.string_types):
            upload_source_files = [upload_source_files]
        entrypoint, source_target_pairs = get_source_code_to_upload(upload_source_files=upload_source_files)
        if notebook_path is None and os.getenv(NOTEBOOK_PATH_ENV_NAME, None) is not None:
            notebook_path = os.environ[NOTEBOOK_PATH_ENV_NAME]
        abortable = abort_callback is not None or DefaultAbortImpl.requirements_installed()
        checkpoint_id = None
        # Snapshot the notebook contents so the experiment links to an exact checkpoint.
        if notebook_id is not None and notebook_path is not None:
            checkpoint = create_checkpoint(
                backend=self._backend,
                notebook_id=notebook_id,
                notebook_path=notebook_path,
            )
            if checkpoint is not None:
                checkpoint_id = checkpoint.id
        experiment = self._backend.create_experiment(
            project=self,
            name=name,
            description=description,
            params=params,
            properties=properties,
            tags=tags,
            abortable=abortable,
            monitored=run_monitoring_thread,
            git_info=git_info,
            hostname=hostname,
            entrypoint=entrypoint,
            notebook_id=notebook_id,
            checkpoint_id=checkpoint_id,
        )
        self._backend.upload_source_code(experiment, source_target_pairs)
        # Start background machinery (monitoring, std stream capture, abort handling).
        experiment._start(
            abort_callback=abort_callback,
            logger=logger,
            upload_stdout=upload_stdout,
            upload_stderr=upload_stderr,
            send_hardware_metrics=send_hardware_metrics,
            run_monitoring_thread=run_monitoring_thread,
            handle_uncaught_exceptions=handle_uncaught_exceptions,
        )
        self._push_new_experiment(experiment)
        click.echo(self._get_experiment_link(experiment))
        return experiment

    def _get_experiment_link(self, experiment):
        # Web-app URL of the experiment, e.g. https://host/org/proj/e/SAN-1.
        return "{base_url}/{namespace}/{project}/e/{exp_id}".format(
            base_url=self._backend.display_address,
            namespace=self.namespace,
            project=self.name,
            exp_id=experiment.id,
        )

    def create_notebook(self):
        """Create a new notebook object and return corresponding :class:`~neptune.notebook.Notebook` instance.

        Returns:
            :class:`~neptune.notebook.Notebook` object.

        Examples:

            .. code:: python3

                # Instantiate a session and fetch a project
                project = neptune.init()

                # Create a notebook in Neptune
                notebook = project.create_notebook()
        """
        return self._backend.create_notebook(self)

    def get_notebook(self, notebook_id):
        """Get a :class:`~neptune.notebook.Notebook` object with given ``notebook_id``.

        Returns:
            :class:`~neptune.notebook.Notebook` object.

        Examples:

            .. code:: python3

                # Instantiate a session and fetch a project
                project = neptune.init()

                # Get a notebook object
                notebook = project.get_notebook('d1c1b494-0620-4e54-93d5-29f4e848a51a')
        """
        return self._backend.get_notebook(project=self, notebook_id=notebook_id)

    @property
    def full_id(self):
        """Project qualified name as :obj:`str`, for example `john/sandbox`."""
        return "{}/{}".format(self.namespace, self.name)

    def __str__(self):
        return "Project({})".format(self.full_id)

    def __repr__(self):
        return str(self)

    def __eq__(self, o):
        # NOTE(review): assumes `o` has a __dict__; comparing against an
        # arbitrary object of another type still works (dicts simply differ).
        return self.__dict__ == o.__dict__

    def __ne__(self, o):
        return not self.__eq__(o)

    def _fetch_leaderboard(self, id, state, owner, tag, min_running_time):
        # Normalize scalar filters to lists before querying the backend.
        return self._backend.get_leaderboard_entries(
            project=self,
            ids=as_list(id),
            states=as_list(state),
            owners=as_list(owner),
            tags=as_list(tag),
            min_running_time=min_running_time,
        )

    @staticmethod
    def _sort_leaderboard_columns(column_names):
        # Lower weight sorts first; 99 is the fallback for unknown names.
        user_defined_weights = {"channel": 1, "parameter": 2, "property": 3}

        system_properties_weights = {
            "id": 0,
            "name": 1,
            "created": 2,
            "finished": 3,
            "owner": 4,
            "worker_type": 5,
            "environment": 6,
        }

        def key(c):
            """A sorting key for a column name.

            Sorts by the system properties first, then channels, parameters, user-defined properties.
            Within each group columns are sorted alphabetically, except for system properties,
            where order is custom.
            """
            parts = c.split("_", 1)
            if parts[0] in user_defined_weights.keys():
                name = parts[1]
                weight = user_defined_weights.get(parts[0], 99)
                system_property_weight = None
            else:
                name = c
                weight = 0
                system_property_weight = system_properties_weights.get(name, 99)
            return weight, system_property_weight, name

        return sorted(column_names, key=key)

    def _get_current_experiment(self):
        # Return the most recently started, still-running experiment.
        with self.__lock:
            if self._experiments_stack:
                return self._experiments_stack[-1]
            else:
                raise NeptuneNoExperimentContextException()

    def _push_new_experiment(self, new_experiment):
        with self.__lock:
            self._experiments_stack.append(new_experiment)
            return new_experiment

    def _remove_stopped_experiment(self, experiment):
        with self.__lock:
            if self._experiments_stack:
                self._experiments_stack = [exp for exp in self._experiments_stack if exp != experiment]

    def _shutdown_hook(self):
        # Invoked via atexit: stop any experiments still running.
        if self._experiments_stack:
            # stopping an experiment removes it from the list, so we iterate over a copy
            copied_experiment_list = [exp for exp in self._experiments_stack]
            for exp in copied_experiment_list:
                exp.stop()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from collections import OrderedDict
from neptune.common.utils import assure_project_qualified_name
from neptune.legacy.internal.api_clients import HostedNeptuneBackendApiClient
from neptune.legacy.internal.utils.deprecation import legacy_client_deprecation
from neptune.legacy.projects import Project
_logger = logging.getLogger(__name__)
class Session(object):
"""A class for running communication with Neptune.
In order to query Neptune experiments you need to instantiate this object first.
Args:
backend (:class:`~neptune.backend.ApiClient`, optional, default is ``None``):
By default, Neptune client library sends logs, metrics, images, etc to Neptune servers:
either publicly available SaaS, or an on-premises installation.
You can pass the default backend instance explicitly to specify its parameters:
.. code :: python3
from neptune.legacy import Session, HostedNeptuneBackendApiClient
session = Session(backend=HostedNeptuneBackendApiClient(...))
Passing an instance of :class:`~neptune.OfflineApiClient` makes your code run without communicating
with Neptune servers.
.. code :: python3
from neptune.legacy import Session, OfflineApiClient
session = Session(backend=OfflineApiClient())
api_token (:obj:`str`, optional, default is ``None``):
User's API token. If ``None``, the value of ``NEPTUNE_API_TOKEN`` environment variable will be taken.
Parameter is ignored if ``backend`` is passed.
.. deprecated :: 0.4.4
Instead, use:
.. code :: python3
from neptune.legacy import Session
session = Session.with_default_backend(api_token='...')
proxies (:obj:`str`, optional, default is ``None``):
Argument passed to HTTP calls made via the `Requests <https://2.python-requests.org/en/master/>`_ library.
For more information see their proxies
`section <https://2.python-requests.org/en/master/user/advanced/#proxies>`_.
Parameter is ignored if ``backend`` is passed.
.. deprecated :: 0.4.4
Instead, use:
.. code :: python3
from neptune.legacy import Session, HostedNeptuneBackendApiClient
session = Session(backend=HostedNeptuneBackendApiClient(proxies=...))
Examples:
Create session, assuming you have created an environment variable ``NEPTUNE_API_TOKEN``
.. code:: python3
from neptune.legacy import Session
session = Session.with_default_backend()
Create session and pass ``api_token``
.. code:: python3
from neptune.legacy import Session
session = Session.with_default_backend(api_token='...')
Create an offline session
.. code:: python3
from neptune.legacy import Session, OfflineApiClient
session = Session(backend=OfflineApiClient())
"""
@legacy_client_deprecation
def __init__(self, api_token=None, proxies=None, backend=None):
self._backend = backend
if self._backend is None:
_logger.warning(
"WARNING: Instantiating Session without specifying a backend is deprecated "
"and will be removed in future versions. For current behaviour "
"use `neptune.init(...)` or `Session.with_default_backend(...)"
)
self._backend = HostedNeptuneBackendApiClient(api_token, proxies)
@classmethod
def with_default_backend(cls, api_token=None, proxies=None):
"""The simplest way to instantiate a ``Session``.
Args:
api_token (:obj:`str`):
User's API token.
If ``None``, the value of ``NEPTUNE_API_TOKEN`` environment variable will be taken.
proxies (:obj:`str`, optional, default is ``None``):
Argument passed to HTTP calls made via the `Requests <https://2.python-requests.org/en/master/>`_
library.
For more information see their proxies
`section <https://2.python-requests.org/en/master/user/advanced/#proxies>`_.
Examples:
.. code :: python3
from neptune.legacy import Session
session = Session.with_default_backend()
"""
return cls(backend=HostedNeptuneBackendApiClient(api_token=api_token, proxies=proxies))
def get_project(self, project_qualified_name):
"""Get a project with given ``project_qualified_name``.
In order to access experiments data one needs to get a :class:`~neptune.projects.Project` object first.
This method gives you the ability to do that.
Args:
project_qualified_name (:obj:`str`):
Qualified name of a project in a form of ``namespace/project_name``.
If ``None``, the value of ``NEPTUNE_PROJECT`` environment variable will be taken.
Returns:
:class:`~neptune.projects.Project` object.
Raise:
:class:`~neptune.api_exceptions.ProjectNotFound`: When a project with given name does not exist.
Examples:
.. code:: python3
# Create a Session instance
from neptune.sessions import Session
session = Session()
# Get a project by it's ``project_qualified_name``:
my_project = session.get_project('namespace/project_name')
"""
project_qualified_name = assure_project_qualified_name(project_qualified_name)
return self._backend.get_project(project_qualified_name)
def get_projects(self, namespace):
"""Get all projects that you have permissions to see in given workspace.
| This method gets you all available projects names and their
corresponding :class:`~neptune.projects.Project` objects.
| Both private and public projects may be returned for the workspace.
If you have role in private project, it is included.
| You can retrieve all the public projects that belong to any user or workspace,
as long as you know their username or workspace name.
Args:
namespace (:obj:`str`): It can either be name of the workspace or username.
Returns:
:obj:`OrderedDict`
| **keys** are ``project_qualified_name`` that is: *'workspace/project_name'*
| **values** are corresponding :class:`~neptune.projects.Project` objects.
Raises:
`WorkspaceNotFound`: When the given workspace does not exist.
Examples:
.. code:: python3
# create Session
from neptune.sessions import Session
session = Session()
# Now, you can list all the projects available for a selected namespace.
# You can use `YOUR_NAMESPACE` which is your workspace or user name.
# You can also list public projects created in other workspaces.
# For example you can use the `neptune-ai` namespace.
session.get_projects('neptune-ai')
# Example output:
# OrderedDict([('neptune-ai/credit-default-prediction',
# Project(neptune-ai/credit-default-prediction)),
# ('neptune-ai/GStore-Customer-Revenue-Prediction',
# Project(neptune-ai/GStore-Customer-Revenue-Prediction)),
# ('neptune-ai/human-protein-atlas',
# Project(neptune-ai/human-protein-atlas)),
# ('neptune-ai/Ships',
# Project(neptune-ai/Ships)),
# ('neptune-ai/Mapping-Challenge',
# Project(neptune-ai/Mapping-Challenge))
# ])
"""
projects = [
Project(self._backend.create_leaderboard_backend(p), p.id, namespace, p.name)
for p in self._backend.get_projects(namespace)
]
return OrderedDict((p.full_id, p) for p in projects)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# flake8: noqa
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.utils import (
IS_MACOS,
IS_WINDOWS,
NoopObject,
align_channels_on_x,
as_list,
assure_directory_exists,
assure_project_qualified_name,
discover_git_repo_location,
file_contains,
get_channel_name_stems,
get_git_info,
glob,
in_docker,
is_float,
is_ipython,
is_nan_or_inf,
is_notebook,
map_keys,
map_values,
merge_dataframes,
update_session_proxies,
validate_notebook_path,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Logger"]
from neptune.logging.logger import Logger
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# backwards compatibility
__all__ = [
"Logger",
]
from neptune.metadata_containers import MetadataContainer
class Logger(object):
def __init__(self, container: MetadataContainer, attribute_name: str):
self._container = container
self._attribute_name = attribute_name
def log(self, msg: str):
self._container[self._attribute_name].log(msg)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"MetadataContainer",
"Model",
"ModelVersion",
"Project",
"Run",
]
from neptune.metadata_containers.metadata_container import MetadataContainer
from neptune.metadata_containers.model import Model
from neptune.metadata_containers.model_version import ModelVersion
from neptune.metadata_containers.project import Project
from neptune.metadata_containers.run import Run
#
# Copyright (c) 2023, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback"]
from abc import (
ABC,
abstractmethod,
)
from typing import (
TYPE_CHECKING,
Callable,
Optional,
Union,
)
if TYPE_CHECKING:
from neptune.handler import Handler
class SupportsNamespaces(ABC):
"""
Interface for Neptune objects that supports subscripting (selecting namespaces)
It could be a Run, Model, ModelVersion, Project or already selected namespace (Handler).
Example:
>>> from neptune import init_run
>>> from neptune.typing import SupportsNamespaces
>>> class NeptuneCallback:
... # Proper type hinting of `start_from` parameter.
... def __init__(self, start_from: SupportsNamespaces):
... self._start_from = start_from
...
... def log_accuracy(self, accuracy: float) -> None:
... self._start_from["train/acc"] = accuracy
...
>>> run = init_run()
>>> callback = NeptuneCallback(start_from=run)
>>> callback.log_accuracy(0.8)
>>> # or
... callback = NeptuneCallback(start_from=run["some/random/path"])
>>> callback.log_accuracy(0.8)
"""
@abstractmethod
def __getitem__(self, path: str) -> "Handler": ...
@abstractmethod
def __setitem__(self, key: str, value) -> None: ...
@abstractmethod
def __delitem__(self, path) -> None: ...
@abstractmethod
def get_root_object(self) -> "SupportsNamespaces": ...
class NeptuneObject(SupportsNamespaces, ABC):
@abstractmethod
def stop(self, *, seconds: Optional[Union[float, int]] = None) -> None: ...
NeptuneObjectCallback = Callable[[NeptuneObject], None]
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["MetadataContainer"]
import abc
import atexit
import itertools
import logging
import os
import threading
import time
import traceback
from contextlib import AbstractContextManager
from functools import (
partial,
wraps,
)
from queue import Queue
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Union,
)
from neptune.attributes import create_attribute_from_type
from neptune.attributes.attribute import Attribute
from neptune.attributes.namespace import Namespace as NamespaceAttr
from neptune.attributes.namespace import NamespaceBuilder
from neptune.common.exceptions import UNIX_STYLES
from neptune.common.utils import reset_internal_ssl_state
from neptune.common.warnings import warn_about_unsupported_type
from neptune.envs import (
NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK,
NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK,
)
from neptune.exceptions import (
MetadataInconsistency,
NeptunePossibleLegacyUsageException,
)
from neptune.handler import Handler
from neptune.internal.backends.api_model import (
ApiExperiment,
AttributeType,
Project,
)
from neptune.internal.backends.factory import get_backend
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.backends.nql import NQLQuery
from neptune.internal.backends.project_name_lookup import project_name_lookup
from neptune.internal.backgroud_job_list import BackgroundJobList
from neptune.internal.background_job import BackgroundJob
from neptune.internal.container_structure import ContainerStructure
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import (
QualifiedName,
SysId,
UniqueId,
conform_optional,
)
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
)
from neptune.internal.operation import DeleteAttribute
from neptune.internal.operation_processors.factory import get_operation_processor
from neptune.internal.operation_processors.lazy_operation_processor_wrapper import LazyOperationProcessorWrapper
from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.signals_processing.background_job import CallbacksMonitor
from neptune.internal.state import ContainerState
from neptune.internal.utils import (
verify_optional_callable,
verify_type,
)
from neptune.internal.utils.logger import (
get_disabled_logger,
get_logger,
)
from neptune.internal.utils.paths import parse_path
from neptune.internal.utils.uncaught_exception_handler import instance as uncaught_exception_handler
from neptune.internal.value_to_attribute_visitor import ValueToAttributeVisitor
from neptune.metadata_containers.abstract import (
NeptuneObject,
NeptuneObjectCallback,
)
from neptune.metadata_containers.utils import parse_dates
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.types.type_casting import cast_value
from neptune.typing import ProgressBarType
from neptune.utils import stop_synchronization_callback
if TYPE_CHECKING:
from neptune.internal.signals_processing.signals import Signal
def ensure_not_stopped(fun):
@wraps(fun)
def inner_fun(self: "MetadataContainer", *args, **kwargs):
self._raise_if_stopped()
return fun(self, *args, **kwargs)
return inner_fun
class MetadataContainer(AbstractContextManager, NeptuneObject):
container_type: ContainerType
LEGACY_METHODS = set()
def __init__(
self,
*,
project: Optional[str] = None,
api_token: Optional[str] = None,
mode: Mode = Mode.ASYNC,
flush_period: float = DEFAULT_FLUSH_PERIOD,
proxies: Optional[dict] = None,
async_lag_callback: Optional[NeptuneObjectCallback] = None,
async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
):
verify_type("project", project, (str, type(None)))
verify_type("api_token", api_token, (str, type(None)))
verify_type("mode", mode, Mode)
verify_type("flush_period", flush_period, (int, float))
verify_type("proxies", proxies, (dict, type(None)))
verify_type("async_lag_threshold", async_lag_threshold, (int, float))
verify_optional_callable("async_lag_callback", async_lag_callback)
verify_type("async_no_progress_threshold", async_no_progress_threshold, (int, float))
verify_optional_callable("async_no_progress_callback", async_no_progress_callback)
self._mode: Mode = mode
self._flush_period = flush_period
self._lock: threading.RLock = threading.RLock()
self._forking_cond: threading.Condition = threading.Condition()
self._forking_state: bool = False
self._state: ContainerState = ContainerState.CREATED
self._signals_queue: "Queue[Signal]" = Queue()
self._logger: logging.Logger = get_logger()
self._backend: NeptuneBackend = get_backend(mode=mode, api_token=api_token, proxies=proxies)
self._project_qualified_name: Optional[str] = conform_optional(project, QualifiedName)
self._project_api_object: Project = project_name_lookup(
backend=self._backend, name=self._project_qualified_name
)
self._project_id: UniqueId = self._project_api_object.id
self._api_object: ApiExperiment = self._get_or_create_api_object()
self._id: UniqueId = self._api_object.id
self._sys_id: SysId = self._api_object.sys_id
self._workspace: str = self._api_object.workspace
self._project_name: str = self._api_object.project_name
self._async_lag_threshold = async_lag_threshold
self._async_lag_callback = MetadataContainer._get_callback(
provided=async_lag_callback,
env_name=NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK,
)
self._async_no_progress_threshold = async_no_progress_threshold
self._async_no_progress_callback = MetadataContainer._get_callback(
provided=async_no_progress_callback,
env_name=NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK,
)
self._op_processor: OperationProcessor = get_operation_processor(
mode=mode,
container_id=self._id,
container_type=self.container_type,
backend=self._backend,
lock=self._lock,
flush_period=flush_period,
queue=self._signals_queue,
)
self._bg_job: BackgroundJobList = self._prepare_background_jobs_if_non_read_only()
self._structure: ContainerStructure[Attribute, NamespaceAttr] = ContainerStructure(NamespaceBuilder(self))
if self._mode != Mode.OFFLINE:
self.sync(wait=False)
if self._mode != Mode.READ_ONLY:
self._write_initial_attributes()
self._startup(debug_mode=mode == Mode.DEBUG)
try:
os.register_at_fork(
before=self._before_fork,
after_in_child=self._handle_fork_in_child,
after_in_parent=self._handle_fork_in_parent,
)
except AttributeError:
pass
"""
OpenSSL's internal random number generator does not properly handle forked processes.
Applications must change the PRNG state of the parent process if they use any SSL feature with os.fork().
Any successful call of RAND_add(), RAND_bytes() or RAND_pseudo_bytes() is sufficient.
https://docs.python.org/3/library/ssl.html#multi-processing
On Linux it looks like it does not help much but does not break anything either.
"""
@staticmethod
def _get_callback(provided: Optional[NeptuneObjectCallback], env_name: str) -> Optional[NeptuneObjectCallback]:
if provided is not None:
return provided
if os.getenv(env_name, "") == "TRUE":
return stop_synchronization_callback
return None
def _handle_fork_in_parent(self):
reset_internal_ssl_state()
if self._state == ContainerState.STARTED:
self._op_processor.resume()
self._bg_job.resume()
with self._forking_cond:
self._forking_state = False
self._forking_cond.notify_all()
def _handle_fork_in_child(self):
reset_internal_ssl_state()
self._logger = get_disabled_logger()
if self._state == ContainerState.STARTED:
self._op_processor.close()
self._signals_queue = Queue()
self._op_processor = LazyOperationProcessorWrapper(
operation_processor_getter=partial(
get_operation_processor,
mode=self._mode,
container_id=self._id,
container_type=self.container_type,
backend=self._backend,
lock=self._lock,
flush_period=self._flush_period,
queue=self._signals_queue,
),
)
# TODO: Every implementation of background job should handle fork by itself.
jobs = []
if self._mode == Mode.ASYNC:
jobs.append(
CallbacksMonitor(
queue=self._signals_queue,
async_lag_threshold=self._async_lag_threshold,
async_no_progress_threshold=self._async_no_progress_threshold,
async_lag_callback=self._async_lag_callback,
async_no_progress_callback=self._async_no_progress_callback,
)
)
self._bg_job = BackgroundJobList(jobs)
with self._forking_cond:
self._forking_state = False
self._forking_cond.notify_all()
def _before_fork(self):
with self._forking_cond:
self._forking_cond.wait_for(lambda: self._state != ContainerState.STOPPING)
self._forking_state = True
if self._state == ContainerState.STARTED:
self._bg_job.pause()
self._op_processor.pause()
def _prepare_background_jobs_if_non_read_only(self) -> BackgroundJobList:
jobs = []
if self._mode != Mode.READ_ONLY:
jobs.extend(self._get_background_jobs())
if self._mode == Mode.ASYNC:
jobs.append(
CallbacksMonitor(
queue=self._signals_queue,
async_lag_threshold=self._async_lag_threshold,
async_no_progress_threshold=self._async_no_progress_threshold,
async_lag_callback=self._async_lag_callback,
async_no_progress_callback=self._async_no_progress_callback,
)
)
return BackgroundJobList(jobs)
@abc.abstractmethod
def _get_or_create_api_object(self) -> ApiExperiment:
raise NotImplementedError
def _get_background_jobs(self) -> List["BackgroundJob"]:
return []
def _write_initial_attributes(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_tb is not None:
traceback.print_exception(exc_type, exc_val, exc_tb)
uncaught_exception_handler.trigger(exc_type, exc_val, exc_tb)
self.stop()
def __getattr__(self, item):
if item in self.LEGACY_METHODS:
raise NeptunePossibleLegacyUsageException()
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
@abc.abstractmethod
def _raise_if_stopped(self):
raise NotImplementedError
def _get_subpath_suggestions(self, path_prefix: str = None, limit: int = 1000) -> List[str]:
parsed_path = parse_path(path_prefix or "")
return list(itertools.islice(self._structure.iterate_subpaths(parsed_path), limit))
def _ipython_key_completions_(self):
return self._get_subpath_suggestions()
@ensure_not_stopped
def __getitem__(self, path: str) -> "Handler":
return Handler(self, path)
@ensure_not_stopped
def __setitem__(self, key: str, value) -> None:
self.__getitem__(key).assign(value)
@ensure_not_stopped
def __delitem__(self, path) -> None:
self.pop(path)
@ensure_not_stopped
def assign(self, value, *, wait: bool = False) -> None:
"""Assigns values to multiple fields from a dictionary.
You can use this method to quickly log all parameters at once.
Args:
value (dict): A dictionary with values to assign, where keys become paths of the fields.
The dictionary can be nested, in which case the path will be a combination of all the keys.
wait: If `True`, Neptune waits to send all tracked metadata to the server before executing the call.
Examples:
>>> import neptune
>>> run = neptune.init_run()
>>> # Assign a single value with the Python "=" operator
>>> run["parameters/learning_rate"] = 0.8
>>> # or the assign() method
>>> run["parameters/learning_rate"].assign(0.8)
>>> # Assign a dictionary with the Python "=" operator
>>> run["parameters"] = {"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8}
>>> # or the assign() method
>>> run.assign({"parameters": {"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8}})
When operating on a handler object, you can use assign() to circumvent normal Python variable assignment.
>>> params = run["params"]
>>> params.assign({"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8})
See also the API reference:
https://docs.neptune.ai/api/universal/#assign
"""
self._get_root_handler().assign(value, wait=wait)
@ensure_not_stopped
def fetch(self) -> dict:
"""Fetch values of all non-File Atom fields as a dictionary.
You can use this method to retrieve metadata from a started or resumed run.
The result preserves the hierarchical structure of the run's metadata, but only contains Atom fields.
This means fields that contain single values, as opposed to series, files, or sets.
Returns:
`dict` containing the values of all non-File Atom fields.
Examples:
Resuming an existing run and fetching metadata from it:
>>> import neptune
>>> resumed_run = neptune.init_run(with_id="CLS-3")
>>> params = resumed_run["model/parameters"].fetch()
>>> run_data = resumed_run.fetch()
>>> print(run_data)
>>> # prints all Atom attributes stored in run as a dict
Fetching metadata from an existing model version:
>>> model_version = neptune.init_model_version(with_id="CLS-TREE-45")
>>> optimizer = model["parameters/optimizer"].fetch()
See also the API reference:
https://docs.neptune.ai/api/universal#fetch
"""
return self._get_root_handler().fetch()
def ping(self):
self._backend.ping(self._id, self.container_type)
def start(self):
atexit.register(self._shutdown_hook)
self._op_processor.start()
self._bg_job.start(self)
self._state = ContainerState.STARTED
def stop(self, *, seconds: Optional[Union[float, int]] = None) -> None:
"""Stops the connection and ends the synchronization thread.
You should stop any initialized runs or other objects when the connection to them is no longer needed.
This method is automatically called:
- when the script that created the run or other object finishes execution.
- if using a context manager, on destruction of the Neptune context.
Note: In interactive sessions, such as Jupyter Notebook, objects are stopped automatically only when
the Python kernel stops. However, background monitoring of system metrics and standard streams is disabled
unless explicitly enabled when initializing Neptune.
Args:
seconds: Seconds to wait for all metadata tracking calls to finish before stopping the object.
If `None`, waits for all tracking calls to finish.
Example:
>>> import neptune
>>> run = neptune.init_run()
>>> # Your training or monitoring code
>>> run.stop()
See also the docs:
Best practices - Stopping objects
https://docs.neptune.ai/usage/best_practices/#stopping-runs-and-other-objects
API reference:
https://docs.neptune.ai/api/universal/#stop
"""
verify_type("seconds", seconds, (float, int, type(None)))
if self._state != ContainerState.STARTED:
return
with self._forking_cond:
self._forking_cond.wait_for(lambda: not self._forking_state)
self._state = ContainerState.STOPPING
ts = time.time()
self._logger.info("Shutting down background jobs, please wait a moment...")
self._bg_job.stop()
self._bg_job.join(seconds)
self._logger.info("Done!")
sec_left = None if seconds is None else seconds - (time.time() - ts)
self._op_processor.stop(sec_left)
if self._mode not in {Mode.OFFLINE, Mode.DEBUG}:
metadata_url = self.get_url().rstrip("/") + "/metadata"
self._logger.info(f"Explore the metadata in the Neptune app: {metadata_url}")
self._backend.close()
with self._forking_cond:
self._state = ContainerState.STOPPED
self._forking_cond.notify_all()
def get_state(self) -> str:
"""Returns the current state of the container as a string.
Examples:
>>> from neptune import init_run
>>> run = init_run()
>>> run.get_state()
'started'
>>> run.stop()
>>> run.get_state()
'stopped'
"""
return self._state.value
def get_structure(self) -> Dict[str, Any]:
"""Returns the object's metadata structure as a dictionary.
This method can be used to programmatically traverse the metadata structure of a run, model,
or project object when using Neptune in automated workflows.
Note: The returned object is a deep copy of the structure of the internal object.
See also the API reference:
https://docs.neptune.ai/api/universal/#get_structure
"""
return self._structure.get_structure().to_dict()
def print_structure(self) -> None:
"""Pretty-prints the structure of the object's metadata.
Paths are ordered lexicographically and the whole structure is neatly colored.
See also: https://docs.neptune.ai/api/universal/#print_structure
"""
self._print_structure_impl(self.get_structure(), indent=0)
def _print_structure_impl(self, struct: dict, indent: int) -> None:
for key in sorted(struct.keys()):
print(" " * indent, end="")
if isinstance(struct[key], dict):
print("{blue}'{key}'{end}:".format(blue=UNIX_STYLES["blue"], key=key, end=UNIX_STYLES["end"]))
self._print_structure_impl(struct[key], indent=indent + 1)
else:
print(
"{blue}'{key}'{end}: {type}".format(
blue=UNIX_STYLES["blue"],
key=key,
end=UNIX_STYLES["end"],
type=type(struct[key]).__name__,
)
)
def define(
self,
path: str,
value: Any,
*,
wait: bool = False,
) -> Optional[Attribute]:
with self._lock:
old_attr = self.get_attribute(path)
if old_attr is not None:
raise MetadataInconsistency("Attribute or namespace {} is already defined".format(path))
neptune_value = cast_value(value)
if neptune_value is None:
warn_about_unsupported_type(type_str=str(type(value)))
return None
attr = ValueToAttributeVisitor(self, parse_path(path)).visit(neptune_value)
self.set_attribute(path, attr)
attr.process_assignment(neptune_value, wait=wait)
return attr
def get_attribute(self, path: str) -> Optional[Attribute]:
with self._lock:
return self._structure.get(parse_path(path))
def set_attribute(self, path: str, attribute: Attribute) -> Optional[Attribute]:
with self._lock:
return self._structure.set(parse_path(path), attribute)
def exists(self, path: str) -> bool:
"""Checks if there is a field or namespace under the specified path."""
verify_type("path", path, str)
return self.get_attribute(path) is not None
@ensure_not_stopped
def pop(self, path: str, *, wait: bool = False) -> None:
"""Removes the field stored under the path and all data associated with it.
Args:
path: Path of the field to be removed.
wait: If `True`, Neptune waits to send all tracked metadata to the server before executing the call.
Examples:
>>> import neptune
>>> run = neptune.init_run()
>>> run["parameters/learninggg_rata"] = 0.3
>>> # Let's delete that misspelled field along with its data
... run.pop("parameters/learninggg_rata")
>>> run["parameters/learning_rate"] = 0.3
>>> # Training finished
... run["trained_model"].upload("model.pt")
>>> # "model_checkpoint" is a File field
... run.pop("model_checkpoint")
See also the API reference:
https://docs.neptune.ai/api/universal/#pop
"""
verify_type("path", path, str)
self._get_root_handler().pop(path, wait=wait)
def _pop_impl(self, parsed_path: List[str], *, wait: bool):
self._structure.pop(parsed_path)
self._op_processor.enqueue_operation(DeleteAttribute(parsed_path), wait=wait)
def lock(self) -> threading.RLock:
    """Expose the container's re-entrant lock so callers can group several operations atomically."""
    return self._lock
def wait(self, *, disk_only=False) -> None:
    """Block until the queued metadata tracking calls are handled.

    Args:
        disk_only: If `True`, only wait until the queued data has been saved
            locally (flushed to disk), without waiting for it to reach the
            Neptune servers.

    See also the API reference:
        https://docs.neptune.ai/api/universal/#wait
    """
    with self._lock:
        # Flush only to disk, or wait for full server synchronization.
        action = self._op_processor.flush if disk_only else self._op_processor.wait
        action()
def sync(self, *, wait: bool = True) -> None:
    """Synchronize the local representation of the object with the Neptune servers.

    Args:
        wait: If `True`, first wait for all queued operations to be processed
            so the freshly fetched state reflects them.

    Example:
        >>> import neptune
        >>> # Connect to a run from Worker #3
        ... worker_id = 3
        >>> run = neptune.init_run(with_id="DIST-43", monitoring_namespace=f"monitoring/{worker_id}")
        >>> run.sync()  # Local representation now matches the servers
        >>> worker_2_status = run["status/2"].fetch()

    See also the API reference:
        https://docs.neptune.ai/api/universal/#sync
    """
    with self._lock:
        if wait:
            self._op_processor.wait()
        # Replace the whole local structure with what the backend reports.
        remote_attributes = self._backend.get_attributes(self._id, self.container_type)
        self._structure.clear()
        for remote_attr in remote_attributes:
            self._define_attribute(parse_path(remote_attr.path), remote_attr.type)
def _define_attribute(self, _path: List[str], _type: AttributeType):
    """Create an attribute object of *_type* and register it at *_path* in the local structure."""
    self._structure.set(_path, create_attribute_from_type(_type, self, _path))
def _get_root_handler(self):
    """Return a Handler anchored at this container's top-level (empty-path) namespace."""
    return Handler(self, "")
# Implemented by each concrete container (e.g. Model and ModelVersion below).
@abc.abstractmethod
def get_url(self) -> str:
    """Returns a link to the object in the Neptune app.

    The same link is printed in the console once the object has been initialized.

    API reference: https://docs.neptune.ai/api/universal/#get_url
    """
    ...
def _startup(self, debug_mode):
    """Finish initialization: announce the app URL (unless in debug mode), start
    background processing, and activate the global uncaught-exception handler."""
    if not debug_mode:
        self._logger.info(f"Neptune initialized. Open in the app: {self.get_url()}")
    self.start()
    uncaught_exception_handler.activate()
def _shutdown_hook(self):
    """Stop the container.

    NOTE(review): presumably registered as a shutdown callback — confirm at the
    registration site (not visible in this file section).
    """
    self.stop()
def _fetch_entries(
    self,
    child_type: ContainerType,
    query: NQLQuery,
    columns: Optional[Iterable[str]],
    limit: Optional[int],
    sort_by: str,
    ascending: bool,
    progress_bar: Optional[ProgressBarType],
) -> Table:
    """Run an NQL query for child containers of *child_type* and wrap the hits in a Table."""
    if columns is not None:
        # Whatever the caller selected, 'sys/id' and the sort column must always be included.
        columns = set(columns) | {"sys/id", sort_by}

    entries = self._backend.search_leaderboard_entries(
        project_id=self._project_id,
        types=[child_type],
        query=query,
        columns=columns,
        limit=limit,
        sort_by=sort_by,
        ascending=ascending,
        progress_bar=progress_bar,
    )

    return Table(
        backend=self._backend,
        container_type=child_type,
        entries=parse_dates(entries),
    )
def get_root_object(self) -> "MetadataContainer":
    """Return this very object: the root of a top-level container is the container itself."""
    return self
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["ModelVersion"]
import os
from typing import (
TYPE_CHECKING,
List,
Optional,
)
from typing_extensions import Literal
from neptune.attributes.constants import (
SYSTEM_NAME_ATTRIBUTE_PATH,
SYSTEM_STAGE_ATTRIBUTE_PATH,
)
from neptune.common.exceptions import NeptuneException
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import (
InactiveModelVersionException,
NeedExistingModelVersionForReadOnlyMode,
NeptuneMissingRequiredInitParameter,
NeptuneOfflineModeChangeStageException,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.operation_processors.offline_operation_processor import OfflineOperationProcessor
from neptune.internal.state import ContainerState
from neptune.internal.utils import verify_type
from neptune.internal.utils.deprecation import model_registry_deprecation
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.abstract import NeptuneObjectCallback
from neptune.types.mode import Mode
from neptune.types.model_version_stage import ModelVersionStage
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
class ModelVersion(MetadataContainer):
    """Initializes a ModelVersion object from an existing or new model version.

    Before creating model versions, you must first register a model by creating a Model object.

    A ModelVersion object is suitable for storing model metadata that is version-specific. It does not track
    background metrics or logs automatically, but you can assign metadata to the model version just like you can
    for runs. You can use the parent Model object to store metadata that is common to all versions of the model.

    To learn more about model registry, see the docs: https://docs.neptune.ai/model_registry/overview/

    To manage the stage of a model version, use its `change_stage()` method or use the menu in the web app.

    You can also use the ModelVersion object as a context manager (see examples).

    Args:
        with_id: The Neptune identifier of an existing model version to resume, such as "CLS-PRE-3".
            The identifier is stored in the model version's "sys/id" field.
            If left empty, a new model version is created.
        name: Custom name for the model version. You can add it as a column in the model versions table
            ("sys/name"). You can also edit the name in the app, in the information view.
        model: Identifier of the model for which the new version should be created.
            Required when creating a new model version.
            You can find the model ID in the leftmost column of the models table, or in a model's "sys/id" field.
        project: Name of a project in the form `workspace-name/project-name`.
            If None, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        mode: Connection mode in which the tracking will work.
            If None (default), the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `read-only`, and `debug`.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
            (in seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a ModelVersion object as the argument and can contain any custom code, such as calling
            `stop()` on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback should take a
            ModelVersion object as the argument and can contain any custom code, such as calling `stop()` on the
            object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        ModelVersion object that is used to manage the model version and log metadata to it.

    Examples:
        >>> import neptune

        Creating a new model version:
        >>> # Create a new model version for a model with identifier "CLS-PRE"
        ... model_version = neptune.init_model_version(model="CLS-PRE")
        >>> model_version["your/structure"] = some_metadata
        >>> # You can provide the project parameter as an environment variable
        ... # or directly in the init_model_version() function:
        ... model_version = neptune.init_model_version(
        ...     model="CLS-PRE",
        ...     project="ml-team/classification",
        ... )
        >>> # Or initialize with the constructor:
        ... model_version = ModelVersion(model="CLS-PRE")

        Connecting to an existing model version:
        >>> # Initialize an existing model version with identifier "CLS-PRE-12"
        ... model_version = neptune.init_model_version(with_id="CLS-PRE-12")
        >>> # To prevent modifications when connecting to an existing model version,
        ... # you can connect in read-only mode:
        ... model_version = neptune.init_model_version(with_id="CLS-PRE-12", mode="read-only")

        Using the ModelVersion object as context manager:
        >>> with ModelVersion(model="CLS-PRE") as model_version:
        ...     model_version["metadata"] = some_metadata

    For more, see the docs:
        Initializing a model version:
            https://docs.neptune.ai/api/neptune#init_model_version
        ModelVersion class reference:
            https://docs.neptune.ai/api/model_version/
    """

    container_type = ContainerType.MODEL_VERSION

    @model_registry_deprecation
    def __init__(
        self,
        with_id: Optional[str] = None,
        *,
        name: Optional[str] = None,
        model: Optional[str] = None,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ) -> None:
        verify_type("with_id", with_id, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("model", model, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("mode", mode, (str, type(None)))

        self._model: Optional[str] = model
        self._with_id: Optional[str] = with_id
        # NOTE(review): the default name kicks in when *model* is None (i.e. when resuming
        # via with_id), whereas the sibling Model class applies its default when *with_id*
        # is None. Confirm this asymmetry is intentional.
        self._name: Optional[str] = DEFAULT_NAME if model is None and name is None else name

        # make mode proper Enum instead of string
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)

        # Offline mode is rejected outright; debug mode substitutes a placeholder project.
        if mode == Mode.OFFLINE:
            raise NeptuneException("ModelVersion can't be initialized in OFFLINE mode")
        if mode == Mode.DEBUG:
            project = OFFLINE_PROJECT_QUALIFIED_NAME

        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resume the version identified by ``with_id``, or create a new one under ``model``."""
        project_workspace = self._project_api_object.workspace
        project_name = self._project_api_object.name
        project_qualified_name = f"{project_workspace}/{project_name}"
        if self._with_id is not None:
            # with_id (resume existing model_version) has priority over model (creating a new model_version)
            return self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._with_id),
                expected_container_type=self.container_type,
            )
        elif self._model is not None:
            if self._mode == Mode.READ_ONLY:
                raise NeedExistingModelVersionForReadOnlyMode()
            # Resolve the parent model first so the new version can be attached to its id.
            api_model = self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._model),
                expected_container_type=ContainerType.MODEL,
            )
            return self._backend.create_model_version(project_id=self._project_api_object.id, model_id=api_model.id)
        else:
            raise NeptuneMissingRequiredInitParameter(
                parameter_name="model",
                called_function="init_model_version",
            )

    def _get_background_jobs(self) -> List["BackgroundJob"]:
        # Only a keep-alive ping; per the class docstring, no metrics/logs are tracked automatically.
        return [PingBackgroundJob()]

    def _write_initial_attributes(self):
        """Write the attributes every freshly created model version starts with."""
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name

    def _raise_if_stopped(self):
        # Guard invoked before operations that are illegal on a stopped container.
        if self._state == ContainerState.STOPPED:
            raise InactiveModelVersionException(label=self._sys_id)

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_model_version_url(
            model_version_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
            sys_id=self._sys_id,
            # The parent model id is stored as a regular field and fetched on demand.
            model_id=self["sys/model_id"].fetch(),
        )

    def change_stage(self, stage: str) -> None:
        """Changes the stage of the model version.

        This method is always synchronous, which means that Neptune will wait for all other calls to reach the
        Neptune servers before executing it.

        Args:
            stage: The new stage of the model version.
                Possible values are `none`, `staging`, `production`, and `archived`.

        Examples:
            >>> import neptune
            >>> model_version = neptune.init_model_version(with_id="CLS-TREE-3")
            >>> # If the model is good enough, promote it to the staging
            ... val_acc = model_version["validation/metrics/acc"].fetch()
            >>> if val_acc >= ACC_THRESHOLD:
            ...     model_version.change_stage("staging")

        Learn more about stage management in the docs:
            https://docs.neptune.ai/model_registry/managing_stage/

        API reference:
            https://docs.neptune.ai/api/model_version/#change_stage
        """
        # Validate the stage name before doing any work; raises ValueError on unknown values.
        mapped_stage = ModelVersionStage(stage)

        if isinstance(self._op_processor, OfflineOperationProcessor):
            raise NeptuneOfflineModeChangeStageException()

        self.wait()

        with self.lock():
            attr = self.get_attribute(SYSTEM_STAGE_ATTRIBUTE_PATH)
            # We are sure that such attribute exists, because
            # SYSTEM_STAGE_ATTRIBUTE_PATH is set by default on ModelVersion creation
            assert attr is not None, f"No {SYSTEM_STAGE_ATTRIBUTE_PATH} found in model version"
            attr.process_assignment(
                value=mapped_stage.value,
                wait=True,
            )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Model"]
import os
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
)
from typing_extensions import Literal
from neptune.attributes.constants import SYSTEM_NAME_ATTRIBUTE_PATH
from neptune.common.exceptions import NeptuneException
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import (
InactiveModelException,
NeedExistingModelForReadOnlyMode,
NeptuneMissingRequiredInitParameter,
NeptuneModelKeyAlreadyExistsError,
NeptuneObjectCreationConflict,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQueryAggregate,
NQLQueryAttribute,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.state import ContainerState
from neptune.internal.utils import verify_type
from neptune.internal.utils.deprecation import model_registry_deprecation
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.abstract import NeptuneObjectCallback
from neptune.metadata_containers.utils import build_raw_query
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.typing import (
ProgressBarCallback,
ProgressBarType,
)
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
class Model(MetadataContainer):
    """Initializes a Model object from an existing or new model.

    You can use this to create a new model from code or to perform actions on existing models.

    A Model object is suitable for storing model metadata that is common to all versions (you can use ModelVersion
    objects to track version-specific metadata). It does not track background metrics or logs automatically,
    but you can assign metadata to the Model object just like you can for runs.

    To learn more about model registry, see the docs: https://docs.neptune.ai/model_registry/overview/

    You can also use the Model object as a context manager (see examples).

    Args:
        with_id: The Neptune identifier of an existing model to resume, such as "CLS-PRE".
            The identifier is stored in the model's "sys/id" field.
            If left empty, a new model is created.
        name: Custom name for the model. You can add it as a column in the models table ("sys/name").
            You can also edit the name in the app, in the information view.
        key: Key for the model. Required when creating a new model.
            Used together with the project key to form the model identifier.
            Must be uppercase and unique within the project.
        project: Name of a project in the form `workspace-name/project-name`.
            If None, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        mode: Connection mode in which the tracking will work.
            If `None` (default), the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `read-only`, and `debug`.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
            (in seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a Model object as the argument and can contain any custom code, such as calling `stop()` on
            the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback should take a Model
            object as the argument and can contain any custom code, such as calling `stop()` on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        Model object that is used to manage the model and log metadata to it.

    Examples:
        >>> import neptune

        Creating a new model:
        >>> model = neptune.init_model(key="PRE")
        >>> model["metadata"] = some_metadata
        >>> # Or initialize with the constructor
        ... model = Model(key="PRE")
        >>> # You can provide the project parameter as an environment variable
        ... # or as an argument to the init_model() function:
        ... model = neptune.init_model(key="PRE", project="workspace-name/project-name")
        >>> # When creating a model, you can give it a name:
        ... model = neptune.init_model(key="PRE", name="Pre-trained model")

        Connecting to an existing model:
        >>> # Initialize existing model with identifier "CLS-PRE"
        ... model = neptune.init_model(with_id="CLS-PRE")
        >>> # To prevent modifications when connecting to an existing model, you can connect in read-only mode
        ... model = neptune.init_model(with_id="CLS-PRE", mode="read-only")

        Using the Model object as context manager:
        >>> with Model(key="PRE") as model:
        ...     model["metadata"] = some_metadata

    For details, see the docs:
        Initializing a model:
            https://docs.neptune.ai/api/neptune#init_model
        Model class reference:
            https://docs.neptune.ai/api/model
    """

    container_type = ContainerType.MODEL

    @model_registry_deprecation
    def __init__(
        self,
        with_id: Optional[str] = None,
        *,
        name: Optional[str] = None,
        key: Optional[str] = None,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ):
        verify_type("with_id", with_id, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("key", key, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("mode", mode, (str, type(None)))

        self._key: Optional[str] = key
        self._with_id: Optional[str] = with_id
        # A freshly created model (no with_id) without an explicit name gets the default name.
        self._name: Optional[str] = DEFAULT_NAME if with_id is None and name is None else name

        # make mode proper Enum instead of string
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)

        # Offline mode is rejected outright; debug mode substitutes a placeholder project.
        if mode == Mode.OFFLINE:
            raise NeptuneException("Model can't be initialized in OFFLINE mode")
        if mode == Mode.DEBUG:
            project = OFFLINE_PROJECT_QUALIFIED_NAME

        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resume the model identified by ``with_id``, or create a new one from ``key``."""
        project_workspace = self._project_api_object.workspace
        project_name = self._project_api_object.name
        project_qualified_name = f"{project_workspace}/{project_name}"
        if self._with_id is not None:
            # with_id (resume existing model) has priority over key (creating a new model)
            # additional creation parameters (e.g. name) are simply ignored in this scenario
            return self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._with_id),
                expected_container_type=self.container_type,
            )
        elif self._key is not None:
            if self._mode == Mode.READ_ONLY:
                raise NeedExistingModelForReadOnlyMode()
            try:
                return self._backend.create_model(project_id=self._project_api_object.id, key=self._key)
            except NeptuneObjectCreationConflict as e:
                # A model with this key already exists — point the user at the models tab.
                base_url = self._backend.get_display_address()
                raise NeptuneModelKeyAlreadyExistsError(
                    model_key=self._key,
                    models_tab_url=f"{base_url}/{project_workspace}/{project_name}/models",
                ) from e
        else:
            raise NeptuneMissingRequiredInitParameter(
                parameter_name="key",
                called_function="init_model",
            )

    def _get_background_jobs(self) -> List["BackgroundJob"]:
        # Only a keep-alive ping; per the class docstring, no metrics/logs are tracked automatically.
        return [PingBackgroundJob()]

    def _write_initial_attributes(self):
        """Write the attributes every freshly created model starts with."""
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name

    def _raise_if_stopped(self):
        # Guard invoked before operations that are illegal on a stopped container.
        if self._state == ContainerState.STOPPED:
            raise InactiveModelException(label=self._sys_id)

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_model_url(
            model_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
            sys_id=self._sys_id,
        )

    @model_registry_deprecation
    def fetch_model_versions_table(
        self,
        *,
        query: Optional[str] = None,
        columns: Optional[Iterable[str]] = None,
        limit: Optional[int] = None,
        sort_by: str = "sys/creation_time",
        ascending: bool = False,
        progress_bar: Optional[ProgressBarType] = None,
    ) -> Table:
        """Retrieve all versions of the given model.

        Args:
            query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
                Example: `"(model_size: float > 100) AND (backbone: string = VGG)"`.
            columns: Names of columns to include in the table, as a list of field names.
                The Neptune ID ("sys/id") is included automatically.
                If `None` (default), all the columns of the model versions table are included,
                up to a maximum of 10 000 columns.
            limit: How many entries to return at most. If `None`, all entries are returned.
            sort_by: Name of the field to sort the results by.
                The field must represent a simple type (string, float, datetime, integer, or Boolean).
            ascending: Whether to sort the entries in ascending order of the sorting column values.
            progress_bar: Set to `False` to disable the download progress bar,
                or pass a `ProgressBarCallback` class to use your own progress bar callback.

        Returns:
            `Table` object containing `ModelVersion` objects that match the specified criteria.
            Use `to_pandas()` to convert it to a pandas DataFrame.

        Examples:
            >>> import neptune
            ... # Initialize model with the ID "CLS-FOREST"
            ... model = neptune.init_model(with_id="CLS-FOREST")
            ... # Fetch the metadata of all the model's versions as a pandas DataFrame
            ... model_versions_df = model.fetch_model_versions_table().to_pandas()
            >>> # Include only the fields "params/lr" and "val/loss" as columns:
            ... model_versions_df = model.fetch_model_versions_table(columns=["params/lr", "val/loss"]).to_pandas()
            >>> # Sort model versions by size (space they take up in Neptune)
            ... model_versions_df = model.fetch_model_versions_table(sort_by="sys/size").to_pandas()
            ... # Extract the ID of the largest model version object
            ... largest_model_version_id = model_versions_df["sys/id"].values[0]
            >>> # Fetch model versions with VGG backbone
            ... models_table_df = model.fetch_model_versions_table(
            ...     query="(backbone: string = VGG)"
            ... ).to_pandas()

        See also the API reference:
            https://docs.neptune.ai/api/model/#fetch_model_versions_table
        """
        verify_type("query", query, (str, type(None)))
        verify_type("limit", limit, (int, type(None)))
        verify_type("sort_by", sort_by, str)
        verify_type("ascending", ascending, bool)
        verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))

        if isinstance(limit, int) and limit <= 0:
            raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

        query = query if query is not None else ""
        # Restrict the caller's query to non-trashed versions belonging to this model.
        nql = build_raw_query(query=query, trashed=False)
        nql = NQLQueryAggregate(
            items=[
                nql,
                NQLQueryAttribute(
                    name="sys/model_id",
                    value=self._sys_id,
                    operator=NQLAttributeOperator.EQUALS,
                    type=NQLAttributeType.STRING,
                ),
            ],
            aggregator=NQLAggregator.AND,
        )
        # Calls the base implementation explicitly on MetadataContainer.
        return MetadataContainer._fetch_entries(
            self,
            child_type=ContainerType.MODEL_VERSION,
            query=nql,
            columns=columns,
            limit=limit,
            sort_by=sort_by,
            ascending=ascending,
            progress_bar=progress_bar,
        )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Project"]
import os
from typing import (
Iterable,
Optional,
Union,
)
from typing_extensions import Literal
from neptune.common.exceptions import NeptuneException
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import InactiveProjectException
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.container_type import ContainerType
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
)
from neptune.internal.state import ContainerState
from neptune.internal.utils import (
as_list,
verify_collection_type,
verify_type,
verify_value,
)
from neptune.internal.utils.deprecation import model_registry_deprecation
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.abstract import NeptuneObjectCallback
from neptune.metadata_containers.utils import (
build_raw_query,
prepare_nql_query,
)
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.typing import (
ProgressBarCallback,
ProgressBarType,
)
class Project(MetadataContainer):
    """Starts a connection to an existing Neptune project.

    You can use the Project object to retrieve information about runs, models, and model versions
    within the project.

    You can also log (and fetch) metadata common to the whole project, such as information about datasets,
    links to documents, or key project metrics.

    Note: If you want to instead create a project, use the
    [`management.create_project()`](https://docs.neptune.ai/api/management/#create_project) function.

    You can also use the Project object as a context manager (see examples).

    Args:
        project: Name of a project in the form `workspace-name/project-name`.
            If left empty, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        mode: Connection mode in which the tracking will work.
            If left empty, the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `read-only`, and `debug`.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered.
            Defaults to 5 (every 5 seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a Project object as the argument and can contain any custom code, such as calling `stop()`
            on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback
            should take a Project object as the argument and can contain any custom code, such as calling `stop()`
            on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        Project object that can be used to interact with the project as a whole,
        like logging or fetching project-level metadata.

    Examples:

        >>> import neptune

        >>> # Connect to the project "classification" in the workspace "ml-team":
        ... project = neptune.init_project(project="ml-team/classification")

        >>> # Or initialize with the constructor
        ... project = Project(project="ml-team/classification")

        >>> # Connect to a project in read-only mode:
        ... project = neptune.init_project(
        ...     project="ml-team/classification",
        ...     mode="read-only",
        ... )

        Using the Project object as context manager:

        >>> with Project(project="ml-team/classification") as project:
        ...     project["metadata"] = some_metadata

    For more, see the docs:
        Initializing a project:
            https://docs.neptune.ai/api/neptune#init_project
        Project class reference:
            https://docs.neptune.ai/api/project/
    """

    container_type = ContainerType.PROJECT

    def __init__(
        self,
        project: Optional[str] = None,
        *,
        api_token: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ):
        verify_type("mode", mode, (str, type(None)))

        # Resolve the connection mode: explicit argument > NEPTUNE_MODE env var > "async" default,
        # and normalize it into a proper Mode enum member.
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)

        # A Project has no offline representation: there is nothing to sync later.
        if mode == Mode.OFFLINE:
            raise NeptuneException("Project can't be initialized in OFFLINE mode")

        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        # A project is never created lazily here: it must already exist on the backend,
        # so this only wraps the already-fetched project data in an ApiExperiment.
        return ApiExperiment(
            id=self._project_api_object.id,
            type=ContainerType.PROJECT,
            sys_id=self._project_api_object.sys_id,
            workspace=self._project_api_object.workspace,
            project_name=self._project_api_object.name,
        )

    def _raise_if_stopped(self):
        # Guard used by metadata operations: once stop() has run, the object is unusable.
        if self._state == ContainerState.STOPPED:
            raise InactiveProjectException(label=f"{self._workspace}/{self._project_name}")

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_project_url(
            project_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
        )

    def fetch_runs_table(
        self,
        *,
        query: Optional[str] = None,
        id: Optional[Union[str, Iterable[str]]] = None,
        state: Optional[Union[Literal["inactive", "active"], Iterable[Literal["inactive", "active"]]]] = None,
        owner: Optional[Union[str, Iterable[str]]] = None,
        tag: Optional[Union[str, Iterable[str]]] = None,
        columns: Optional[Iterable[str]] = None,
        trashed: Optional[bool] = False,
        limit: Optional[int] = None,
        sort_by: str = "sys/creation_time",
        ascending: bool = False,
        progress_bar: Optional[ProgressBarType] = None,
    ) -> Table:
        """Retrieve runs matching the specified criteria.

        All parameters are optional. Each of them specifies a single criterion.
        Only runs matching all of the criteria will be returned.

        Args:
            query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
                Example: `"(accuracy: float > 0.88) AND (loss: float < 0.2)"`.
                Exclusive with the `id`, `state`, `owner`, and `tag` parameters.
            id: Neptune ID of a run, or list of several IDs.
                Example: `"SAN-1"` or `["SAN-1", "SAN-2"]`.
                Matching any element of the list is sufficient to pass the criterion.
            state: Run state, or list of states.
                Example: `"active"`.
                Possible values: `"inactive"`, `"active"`.
                Matching any element of the list is sufficient to pass the criterion.
            owner: Username of the run owner, or a list of owners.
                Example: `"josh"` or `["frederic", "josh"]`.
                The owner is the user who created the run.
                Matching any element of the list is sufficient to pass the criterion.
            tag: A tag or list of tags applied to the run.
                Example: `"lightGBM"` or `["pytorch", "cycleLR"]`.
                Only runs that have all specified tags will match this criterion.
            columns: Names of columns to include in the table, as a list of field names.
                The Neptune ID ("sys/id") is included automatically.
                If `None` (default), all the columns of the runs table are included, up to a maximum of 10 000 columns.
            trashed: Whether to retrieve trashed runs.
                If `True`, only trashed runs are retrieved.
                If `False` (default), only not-trashed runs are retrieved.
                If `None`, both trashed and not-trashed runs are retrieved.
            limit: How many entries to return at most. If `None`, all entries are returned.
            sort_by: Name of the field to sort the results by.
                The field must represent a simple type (string, float, datetime, integer, or Boolean).
            ascending: Whether to sort the entries in ascending order of the sorting column values.
            progress_bar: Set to `False` to disable the download progress bar,
                or pass a `ProgressBarCallback` class to use your own progress bar callback.

        Returns:
            `Table` object containing `Run` objects matching the specified criteria.
            Use `to_pandas()` to convert the table to a pandas DataFrame.

        Examples:
            >>> import neptune
            ... # Fetch project "jackie/sandbox"
            ... project = neptune.init_project(mode="read-only", project="jackie/sandbox")

            >>> # Fetch the metadata of all runs as a pandas DataFrame
            ... runs_table_df = project.fetch_runs_table().to_pandas()
            ... # Extract the ID of the last run
            ... last_run_id = runs_table_df["sys/id"].values[0]

            >>> # Fetch the 100 oldest runs
            ... runs_table_df = project.fetch_runs_table(
            ...     sort_by="sys/creation_time", ascending=True, limit=100
            ... ).to_pandas()

            >>> # Fetch the 100 largest runs (space they take up in Neptune)
            ... runs_table_df = project.fetch_runs_table(sort_by="sys/size", limit=100).to_pandas()

            >>> # Include only the fields "train/loss" and "params/lr" as columns:
            ... runs_table_df = project.fetch_runs_table(columns=["params/lr", "train/loss"]).to_pandas()

            >>> # Pass a custom progress bar callback
            ... runs_table_df = project.fetch_runs_table(progress_bar=MyProgressBar).to_pandas()
            ... # The class MyProgressBar(ProgressBarCallback) must be defined

            You can also filter the runs table by state, owner, tag, or a combination of these:

            >>> # Fetch only inactive runs
            ... runs_table_df = project.fetch_runs_table(state="inactive").to_pandas()

            >>> # Fetch only runs created by CI service
            ... runs_table_df = project.fetch_runs_table(owner="my_company_ci_service").to_pandas()

            >>> # Fetch only runs that have both "Exploration" and "Optuna" tags
            ... runs_table_df = project.fetch_runs_table(tag=["Exploration", "Optuna"]).to_pandas()

            >>> # You can combine conditions. Runs satisfying all conditions will be fetched
            ... runs_table_df = project.fetch_runs_table(state="inactive", tag="Exploration").to_pandas()

        See also the API reference in the docs:
            https://docs.neptune.ai/api/project#fetch_runs_table
        """
        # A raw NQL query and the structured filters are mutually exclusive.
        if any((id, state, owner, tag)) and query is not None:
            raise ValueError(
                "You can't use the 'query' parameter together with the 'id', 'state', 'owner', or 'tag' parameters."
            )

        # Normalize scalar-or-iterable filters into lists.
        ids = as_list("id", id)
        states = as_list("state", state)
        owners = as_list("owner", owner)
        tags = as_list("tag", tag)

        verify_type("query", query, (str, type(None)))
        verify_type("trashed", trashed, (bool, type(None)))
        verify_type("limit", limit, (int, type(None)))
        verify_type("sort_by", sort_by, str)
        verify_type("ascending", ascending, bool)
        verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))
        verify_collection_type("state", states, str)

        if isinstance(limit, int) and limit <= 0:
            raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

        for state in states:
            verify_value("state", state.lower(), ("inactive", "active"))

        if query is not None:
            nql_query = build_raw_query(query, trashed=trashed)
        else:
            nql_query = prepare_nql_query(ids, states, owners, tags, trashed)

        return MetadataContainer._fetch_entries(
            self,
            child_type=ContainerType.RUN,
            query=nql_query,
            columns=columns,
            limit=limit,
            sort_by=sort_by,
            ascending=ascending,
            progress_bar=progress_bar,
        )

    @model_registry_deprecation
    def fetch_models_table(
        self,
        *,
        query: Optional[str] = None,
        columns: Optional[Iterable[str]] = None,
        trashed: Optional[bool] = False,
        limit: Optional[int] = None,
        sort_by: str = "sys/creation_time",
        ascending: bool = False,
        progress_bar: Optional[ProgressBarType] = None,
    ) -> Table:
        """Retrieve models stored in the project.

        Args:
            query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
                Example: `"(model_size: float > 100) AND (backbone: string = VGG)"`.
            trashed: Whether to retrieve trashed models.
                If `True`, only trashed models are retrieved.
                If `False`, only not-trashed models are retrieved.
                If `None`, both trashed and not-trashed models are retrieved.
            columns: Names of columns to include in the table, as a list of field names.
                The Neptune ID ("sys/id") is included automatically.
                If `None`, all the columns of the models table are included, up to a maximum of 10 000 columns.
            limit: How many entries to return at most. If `None`, all entries are returned.
            sort_by: Name of the field to sort the results by.
                The field must represent a simple type (string, float, datetime, integer, or Boolean).
            ascending: Whether to sort the entries in ascending order of the sorting column values.
            progress_bar: Set to `False` to disable the download progress bar,
                or pass a `ProgressBarCallback` class to use your own progress bar callback.

        Returns:
            `Table` object containing `Model` objects.
            Use `to_pandas()` to convert the table to a pandas DataFrame.

        Examples:
            >>> import neptune
            ... # Fetch project "jackie/sandbox"
            ... project = neptune.init_project(mode="read-only", project="jackie/sandbox")

            >>> # Fetch the metadata of all models as a pandas DataFrame
            ... models_table_df = project.fetch_models_table().to_pandas()

            >>> # Include only the fields "params/lr" and "info/size" as columns:
            ... models_table_df = project.fetch_models_table(columns=["params/lr", "info/size"]).to_pandas()

            >>> # Fetch 10 oldest model objects
            ... models_table_df = project.fetch_models_table(
            ...     sort_by="sys/creation_time", ascending=True, limit=10
            ... ).to_pandas()
            ... # Extract the ID of the first listed (oldest) model object
            ... last_model_id = models_table_df["sys/id"].values[0]

            >>> # Fetch models with VGG backbone
            ... models_table_df = project.fetch_models_table(
            ...     query="(backbone: string = VGG)"
            ... ).to_pandas()

        See also the API reference in the docs:
            https://docs.neptune.ai/api/project#fetch_models_table
        """
        verify_type("query", query, (str, type(None)))
        # Validate `trashed` the same way fetch_runs_table does; previously an invalid
        # value (e.g. a string) was silently forwarded into the query builder.
        verify_type("trashed", trashed, (bool, type(None)))
        verify_type("limit", limit, (int, type(None)))
        verify_type("sort_by", sort_by, str)
        verify_type("ascending", ascending, bool)
        verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))

        if isinstance(limit, int) and limit <= 0:
            raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

        # No query means "all models"; build_raw_query still applies the trashed filter.
        query = query if query is not None else ""
        nql = build_raw_query(query=query, trashed=trashed)
        return MetadataContainer._fetch_entries(
            self,
            child_type=ContainerType.MODEL,
            query=nql,
            columns=columns,
            limit=limit,
            sort_by=sort_by,
            ascending=ascending,
            progress_bar=progress_bar,
        )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Run"]
import os
import threading
from platform import node as get_hostname
from typing import (
TYPE_CHECKING,
List,
Optional,
Tuple,
Union,
)
from typing_extensions import Literal
from neptune.attributes.constants import (
SYSTEM_DESCRIPTION_ATTRIBUTE_PATH,
SYSTEM_FAILED_ATTRIBUTE_PATH,
SYSTEM_HOSTNAME_ATTRIBUTE_PATH,
SYSTEM_NAME_ATTRIBUTE_PATH,
SYSTEM_TAGS_ATTRIBUTE_PATH,
)
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.envs import (
CONNECTION_MODE,
CUSTOM_RUN_ID_ENV_NAME,
MONITORING_NAMESPACE,
NEPTUNE_NOTEBOOK_ID,
NEPTUNE_NOTEBOOK_PATH,
)
from neptune.exceptions import (
InactiveRunException,
NeedExistingRunForReadOnlyMode,
NeptunePossibleLegacyUsageException,
NeptuneRunResumeAndCustomIdCollision,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.container_type import ContainerType
from neptune.internal.hardware.hardware_metric_reporting_job import HardwareMetricReportingJob
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.notebooks.notebooks import create_checkpoint
from neptune.internal.state import ContainerState
from neptune.internal.streams.std_capture_background_job import (
StderrCaptureBackgroundJob,
StdoutCaptureBackgroundJob,
)
from neptune.internal.utils import (
verify_collection_type,
verify_type,
)
from neptune.internal.utils.dependency_tracking import (
FileDependenciesStrategy,
InferDependenciesStrategy,
)
from neptune.internal.utils.git import (
to_git_info,
track_uncommitted_changes,
)
from neptune.internal.utils.hashing import generate_hash
from neptune.internal.utils.limits import custom_run_id_exceeds_length
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.internal.utils.runningmode import (
in_interactive,
in_notebook,
)
from neptune.internal.utils.source_code import upload_source_code
from neptune.internal.utils.traceback_job import TracebackJob
from neptune.internal.websockets.websocket_signals_background_job import WebsocketSignalsBackgroundJob
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.abstract import NeptuneObjectCallback
from neptune.types import (
GitRef,
StringSeries,
)
from neptune.types.atoms.git_ref import GitRefDisabled
from neptune.types.mode import Mode
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
class Run(MetadataContainer):
"""Starts a new tracked run that logs ML model-building metadata to neptune.ai.
You can log metadata by assigning it to the initialized Run object:
```
run = neptune.init_run()
run["your/structure"] = some_metadata
```
Examples of metadata you can log: metrics, losses, scores, artifact versions, images, predictions,
model weights, parameters, checkpoints, and interactive visualizations.
By default, the run automatically tracks hardware consumption, stdout/stderr, source code, and Git information.
If you're using Neptune in an interactive session, however, some background monitoring needs to be enabled
explicitly.
If you provide the ID of an existing run, that run is resumed and no new run is created. You may resume a run
either to log more metadata or to fetch metadata from it.
The run ends either when its `stop()` method is called or when the script finishes execution.
You can also use the Run object as a context manager (see examples).
Args:
project: Name of the project where the run should go, in the form `workspace-name/project_name`.
If left empty, the value of the NEPTUNE_PROJECT environment variable is used.
api_token: User's API token.
If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
with_id: If you want to resume a run, pass the identifier of an existing run. For example, "SAN-1".
If left empty, a new run is created.
custom_run_id: A unique identifier to be used when running Neptune in distributed training jobs.
Make sure to use the same identifier throughout the whole pipeline execution.
mode: Connection mode in which the tracking will work.
If left empty, the value of the NEPTUNE_MODE environment variable is used.
If no value was set for the environment variable, "async" is used by default.
Possible values are `async`, `sync`, `offline`, `read-only`, and `debug`.
name: Custom name for the run. You can add it as a column in the runs table ("sys/name").
You can also edit the name in the app: Open the run menu and access the run information.
description: Custom description of the run. You can add it as a column in the runs table
("sys/description").
You can also edit the description in the app: Open the run menu and access the run information.
tags: Tags of the run as a list of strings.
You can edit the tags through the "sys/tags" field or in the app (run menu -> information).
You can also select multiple runs and manage their tags as a single action.
source_files: List of source files to be uploaded.
Uploaded source files are displayed in the "Source code" dashboard.
To not upload anything, pass an empty list (`[]`).
Unix style pathname pattern expansion is supported. For example, you can pass `*.py` to upload
all Python files from the current directory.
If None is passed, the Python file from which the run was created will be uploaded.
capture_stdout: Whether to log the stdout of the run.
Defaults to `False` in interactive sessions and `True` otherwise.
The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
capture_stderr: Whether to log the stderr of the run.
Defaults to `False` in interactive sessions and `True` otherwise.
The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
capture_hardware_metrics: Whether to send hardware monitoring logs (CPU, GPU, and memory utilization).
Defaults to `False` in interactive sessions and `True` otherwise.
The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
fail_on_exception: Whether to register an uncaught exception handler to this process and,
in case of an exception, set the "sys/failed" field of the run to `True`.
An exception is always logged.
monitoring_namespace: Namespace inside which all hardware monitoring logs are stored.
Defaults to "monitoring/<hash>", where the hash is generated based on environment information,
to ensure that it's unique for each process.
flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
(in seconds).
proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
For more information about proxies, see the Requests documentation.
capture_traceback: Whether to log the traceback of the run in case of an exception.
The tracked metadata is stored in the "<monitoring_namespace>/traceback" namespace (see the
`monitoring_namespace` parameter).
git_ref: GitRef object containing information about the Git repository path.
If None, Neptune looks for a repository in the path of the script that is executed.
To specify a different location, set to GitRef(repository_path="path/to/repo").
To turn off Git tracking for the run, set to False or GitRef.DISABLED.
dependencies: If you pass `"infer"`, Neptune logs dependencies installed in the current environment.
You can also pass a path to your dependency file directly.
If left empty, no dependencies are tracked.
async_lag_callback: Custom callback which is called if the lag between a queued operation and its
synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
should take a Run object as the argument and can contain any custom code, such as calling `stop()` on
the object.
Note: Instead of using this argument, you can use Neptune's default callback by setting the
`NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
If a lag callback (default callback enabled via environment variable or custom callback passed to the
`async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
whatsoever for the duration defined by `async_no_progress_threshold`. The callback
should take a Run object as the argument and can contain any custom code, such as calling `stop()` on
the object.
Note: Instead of using this argument, you can use Neptune's default callback by setting the
`NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
object was initialized. If a no-progress callback (default callback enabled via environment variable or
custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
when this duration is exceeded.
Returns:
Run object that is used to manage the tracked run and log metadata to it.
Examples:
Creating a new run:
>>> import neptune
>>> # Minimal invoke
... # (creates a run in the project specified by the NEPTUNE_PROJECT environment variable)
... run = neptune.init_run()
>>> # Or initialize with the constructor
... run = Run(project="ml-team/classification")
>>> # Create a run with a name and description, with no sources files or Git info tracked:
>>> run = neptune.init_run(
... name="neural-net-mnist",
... description="neural net trained on MNIST",
... source_files=[],
... git_ref=False,
... )
>>> # Log all .py files from all subdirectories, excluding hidden files
... run = neptune.init_run(source_files="**/*.py")
>>> # Log all files and directories in the current working directory, excluding hidden files
... run = neptune.init_run(source_files="*")
>>> # Larger example
... run = neptune.init_run(
... project="ml-team/classification",
... name="first-pytorch-ever",
... description="Longer description of the run goes here",
... tags=["tags", "go-here", "as-list-of-strings"],
... source_files=["training_with_pytorch.py", "net.py"],
... dependencies="infer",
... capture_stderr=False,
... git_ref=GitRef(repository_path="/Users/Jackie/repos/cls_project"),
... )
Connecting to an existing run:
>>> # Resume logging to an existing run with the ID "SAN-3"
... run = neptune.init_run(with_id="SAN-3")
... run["parameters/lr"] = 0.1 # modify or add metadata
>>> # Initialize an existing run in read-only mode (logging new data is not possible, only fetching)
... run = neptune.init_run(with_id="SAN-4", mode="read-only")
... learning_rate = run["parameters/lr"].fetch()
Using the Run object as context manager:
>>> with Run() as run:
... run["metric"].append(value)
For more, see the docs:
Initializing a run:
https://docs.neptune.ai/api/neptune#init_run
Run class reference:
https://docs.neptune.ai/api/run/
Essential logging methods:
https://docs.neptune.ai/logging/methods/
Resuming a run:
https://docs.neptune.ai/logging/to_existing_object/
Setting a custom run ID:
https://docs.neptune.ai/logging/custom_run_id/
Logging to multiple runs at once:
https://docs.neptune.ai/logging/to_multiple_objects/
Accessing the run from multiple places:
https://docs.neptune.ai/logging/from_multiple_places/
"""
# Container kind this class logs to on the backend.
container_type = ContainerType.RUN

# Method names from the legacy (pre-1.0) Experiment API. NOTE(review): these are
# presumably used elsewhere in this module to detect old-API usage and raise
# NeptunePossibleLegacyUsageException (imported above) — confirm against the
# enclosing class/module before relying on this.
LEGACY_METHODS = (
    "create_experiment",
    "send_metric",
    "log_metric",
    "send_text",
    "log_text",
    "send_image",
    "log_image",
    "send_artifact",
    "log_artifact",
    "delete_artifacts",
    "download_artifact",
    "download_sources",
    "download_artifacts",
    "reset_log",
    "get_parameters",
    "get_properties",
    "set_property",
    "remove_property",
    "get_hardware_utilization",
    "get_numeric_channels_values",
)
def __init__(
    self,
    with_id: Optional[str] = None,
    *,
    project: Optional[str] = None,
    api_token: Optional[str] = None,
    custom_run_id: Optional[str] = None,
    mode: Optional[Literal["async", "sync", "offline", "read-only", "debug"]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    tags: Optional[Union[List[str], str]] = None,
    source_files: Optional[Union[List[str], str]] = None,
    capture_stdout: Optional[bool] = None,
    capture_stderr: Optional[bool] = None,
    capture_hardware_metrics: Optional[bool] = None,
    fail_on_exception: bool = True,
    monitoring_namespace: Optional[str] = None,
    flush_period: float = DEFAULT_FLUSH_PERIOD,
    proxies: Optional[dict] = None,
    capture_traceback: bool = True,
    git_ref: Optional[Union[GitRef, GitRefDisabled, bool]] = None,
    dependencies: Optional[Union[str, os.PathLike]] = None,
    async_lag_callback: Optional[NeptuneObjectCallback] = None,
    async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
    async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
    async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    **kwargs,
):
    """Validate the arguments, resolve defaults from the environment, and start the run.

    See the class docstring for the meaning of each parameter.

    Raises:
        NeptuneRunResumeAndCustomIdCollision: If both `with_id` and `custom_run_id` are passed.
    """
    check_for_extra_kwargs("Run", kwargs)

    # --- argument validation ---
    verify_type("with_id", with_id, (str, type(None)))
    verify_type("project", project, (str, type(None)))
    verify_type("custom_run_id", custom_run_id, (str, type(None)))
    verify_type("mode", mode, (str, type(None)))
    verify_type("name", name, (str, type(None)))
    verify_type("description", description, (str, type(None)))
    verify_type("capture_stdout", capture_stdout, (bool, type(None)))
    verify_type("capture_stderr", capture_stderr, (bool, type(None)))
    verify_type("capture_hardware_metrics", capture_hardware_metrics, (bool, type(None)))
    verify_type("fail_on_exception", fail_on_exception, bool)
    verify_type("monitoring_namespace", monitoring_namespace, (str, type(None)))
    verify_type("capture_traceback", capture_traceback, bool)
    verify_type("git_ref", git_ref, (GitRef, str, bool, type(None)))
    verify_type("dependencies", dependencies, (str, os.PathLike, type(None)))

    # Accept a single string as shorthand for a one-element list.
    if tags is not None:
        if isinstance(tags, str):
            tags = [tags]
        else:
            verify_collection_type("tags", tags, str)
    if source_files is not None:
        if isinstance(source_files, str):
            source_files = [source_files]
        else:
            verify_collection_type("source_files", source_files, str)

    self._with_id: Optional[str] = with_id
    self._name: Optional[str] = name
    # A brand-new run with no explicit description gets an empty one; a resumed
    # run (with_id set) keeps None so the existing description is not overwritten.
    self._description: Optional[str] = "" if with_id is None and description is None else description
    self._custom_run_id: Optional[str] = custom_run_id or os.getenv(CUSTOM_RUN_ID_ENV_NAME)
    self._hostname: str = get_hostname()
    self._pid: int = os.getpid()
    self._tid: int = threading.get_ident()
    self._tags: Optional[List[str]] = tags
    self._source_files: Optional[List[str]] = source_files
    self._fail_on_exception: bool = fail_on_exception
    self._capture_traceback: bool = capture_traceback

    # Plain booleans are shorthand for GitRef() / GitRef.DISABLED.
    if isinstance(git_ref, bool):
        git_ref = GitRef() if git_ref else GitRef.DISABLED
    # Fixed annotations: `Optional[X, Y]` is invalid typing syntax; the correct
    # form is `Optional[Union[X, Y]]`.
    self._git_ref: Optional[Union[GitRef, GitRefDisabled]] = git_ref or GitRef()
    self._dependencies: Optional[Union[str, os.PathLike]] = dependencies

    self._monitoring_namespace: str = (
        monitoring_namespace
        or os.getenv(MONITORING_NAMESPACE)
        or generate_monitoring_namespace(self._hostname, self._pid, self._tid)
    )

    # Resolve the connection mode: explicit argument > NEPTUNE_MODE env var > "async" default.
    mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)

    # Stream/hardware capture defaults depend on whether we run interactively;
    # each flag is resolved once so its annotation is honestly `bool`.
    self._stdout_path: str = "{}/stdout".format(self._monitoring_namespace)
    self._capture_stdout: bool = (
        capture_only_if_non_interactive(mode=mode) if capture_stdout is None else capture_stdout
    )
    self._stderr_path: str = "{}/stderr".format(self._monitoring_namespace)
    self._capture_stderr: bool = (
        capture_only_if_non_interactive(mode=mode) if capture_stderr is None else capture_stderr
    )
    self._capture_hardware_metrics: bool = (
        capture_only_if_non_interactive(mode=mode)
        if capture_hardware_metrics is None
        else capture_hardware_metrics
    )

    # Resuming an existing run and forcing a custom ID are mutually exclusive.
    if with_id and custom_run_id:
        raise NeptuneRunResumeAndCustomIdCollision()

    # Offline/debug modes never talk to the server, so a placeholder project name is used.
    if mode == Mode.OFFLINE or mode == Mode.DEBUG:
        project = OFFLINE_PROJECT_QUALIFIED_NAME

    super().__init__(
        project=project,
        api_token=api_token,
        mode=mode,
        flush_period=flush_period,
        proxies=proxies,
        async_lag_callback=async_lag_callback,
        async_lag_threshold=async_lag_threshold,
        async_no_progress_callback=async_no_progress_callback,
        async_no_progress_threshold=async_no_progress_threshold,
    )
def _get_or_create_api_object(self) -> ApiExperiment:
    """Resume the run named by ``self._with_id``, or create a fresh run on the backend."""
    api_project = self._project_api_object
    qualified_project = f"{api_project.workspace}/{api_project.name}"

    if self._with_id:
        # Resuming: fetch the existing run's metadata instead of creating a new one.
        return self._backend.get_metadata_container(
            container_id=QualifiedName(f"{qualified_project}/{self._with_id}"),
            expected_container_type=Run.container_type,
        )

    # A brand-new run cannot be created in read-only mode.
    if self._mode == Mode.READ_ONLY:
        raise NeedExistingRunForReadOnlyMode()

    # Drop the custom run ID when it exceeds the backend's length limit.
    run_id = None if custom_run_id_exceeds_length(self._custom_run_id) else self._custom_run_id
    notebook_id, checkpoint_id = create_notebook_checkpoint(backend=self._backend)
    return self._backend.create_run(
        project_id=api_project.id,
        git_info=to_git_info(git_ref=self._git_ref),
        custom_run_id=run_id,
        notebook_id=notebook_id,
        checkpoint_id=checkpoint_id,
    )
def _get_background_jobs(self) -> List["BackgroundJob"]:
    """Assemble the background jobs to run alongside this run, driven by the capture flags."""
    jobs: List["BackgroundJob"] = [PingBackgroundJob()]

    websockets_factory = self._backend.websockets_factory(self._project_api_object.id, self._id)
    if websockets_factory:
        jobs.append(WebsocketSignalsBackgroundJob(websockets_factory))

    if self._capture_stdout:
        jobs.append(StdoutCaptureBackgroundJob(attribute_name=self._stdout_path))
    if self._capture_stderr:
        jobs.append(StderrCaptureBackgroundJob(attribute_name=self._stderr_path))
    if self._capture_hardware_metrics:
        jobs.append(HardwareMetricReportingJob(attribute_namespace=self._monitoring_namespace))
    if self._capture_traceback:
        jobs.append(
            TracebackJob(path=f"{self._monitoring_namespace}/traceback", fail_on_exception=self._fail_on_exception)
        )
    return jobs
def _write_initial_monitoring_attributes(self) -> None:
    """Record host and process identifiers under the monitoring namespace."""
    namespace = self._monitoring_namespace
    if self._hostname is not None:
        self[f"{namespace}/hostname"] = self._hostname
        if self._with_id is None:
            # Only a newly created run gets the system-level hostname attribute.
            self[SYSTEM_HOSTNAME_ATTRIBUTE_PATH] = self._hostname
    if self._pid is not None:
        self[f"{namespace}/pid"] = str(self._pid)
    if self._tid is not None:
        self[f"{namespace}/tid"] = str(self._tid)
def _write_initial_attributes(self):
    """Populate the run's initial system, monitoring, and source-tracking attributes.

    Best-effort steps (dependency tracking, uncommitted-changes tracking) emit a
    warning instead of failing the run on error.
    """
    if not getattr(self._backend, "sys_name_set_by_backend", False):
        # Backend did not name the run - fall back to the user-provided name or the default.
        self._name = self._name if self._name is not None else DEFAULT_NAME
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name

    if self._description is not None:
        self[SYSTEM_DESCRIPTION_ATTRIBUTE_PATH] = self._description

    if any((self._capture_stderr, self._capture_stdout, self._capture_traceback, self._capture_hardware_metrics)):
        self._write_initial_monitoring_attributes()

    if self._tags is not None:
        self[SYSTEM_TAGS_ATTRIBUTE_PATH].add(self._tags)
    if self._with_id is None:
        # A newly created run starts in the non-failed state.
        self[SYSTEM_FAILED_ATTRIBUTE_PATH] = False

    if self._capture_stdout and not self.exists(self._stdout_path):
        self.define(self._stdout_path, StringSeries([]))
    if self._capture_stderr and not self.exists(self._stderr_path):
        self.define(self._stderr_path, StringSeries([]))

    if self._with_id is None or self._source_files is not None:
        # Upload sources when creating a new run, or when the user explicitly
        # passed source files while resuming an existing one.
        upload_source_code(source_files=self._source_files, run=self)

    if self._dependencies:
        try:
            if self._dependencies == "infer":
                dependency_strategy = InferDependenciesStrategy()
            else:
                dependency_strategy = FileDependenciesStrategy(path=self._dependencies)
            dependency_strategy.log_dependencies(run=self)
        except Exception as e:
            # Bug fix: the adjacent string literals previously concatenated with no
            # separating spaces, producing an unreadable warning message.
            warn_once(
                "An exception occurred in automatic dependency tracking."
                " Skipping upload of requirement files."
                " Exception: " + str(e),
                exception=NeptuneWarning,
            )

    try:
        track_uncommitted_changes(
            git_ref=self._git_ref,
            run=self,
        )
    except Exception as e:
        warn_once(
            "An exception occurred in tracking uncommitted changes."
            " Skipping upload of patch files."
            " Exception: " + str(e),
            exception=NeptuneWarning,
        )
@property
def monitoring_namespace(self) -> str:
    """Attribute-path prefix under which monitoring data (stdout, pid, metrics) is logged."""
    return self._monitoring_namespace
def _raise_if_stopped(self):
    """Raise ``InactiveRunException`` when this run has already been stopped."""
    if self._state != ContainerState.STOPPED:
        return
    raise InactiveRunException(label=self._sys_id)
def get_url(self) -> str:
    """Returns the URL that can be accessed within the browser"""
    backend = self._backend
    return backend.get_run_url(
        run_id=self._id,
        workspace=self._workspace,
        project_name=self._project_name,
        sys_id=self._sys_id,
    )
def capture_only_if_non_interactive(mode) -> bool:
    """Default for the monitoring-capture flags: True in scripts, False in interactive sessions."""
    if not (in_interactive() or in_notebook()):
        return True
    # Interactive session: capture is off by default; tell the user once how to enable it
    # (only for modes that would actually send the data).
    if mode in {Mode.OFFLINE, Mode.SYNC, Mode.ASYNC}:
        warn_once(
            "By default, these monitoring options are disabled in interactive sessions:"
            " 'capture_stdout', 'capture_stderr', 'capture_traceback', 'capture_hardware_metrics'."
            " You can set them to 'True' when initializing the run and the monitoring will"
            " continue until you call run.stop() or the kernel stops."
            " NOTE: To track the source files, pass their paths to the 'source_code'"
            " argument. For help, see: https://docs.neptune.ai/logging/source_code/",
            exception=NeptuneWarning,
        )
    return False
def generate_monitoring_namespace(*descriptors) -> str:
    """Build a monitoring namespace path from a short hash of the given descriptors."""
    suffix = generate_hash(*descriptors, length=8)
    return f"monitoring/{suffix}"
def check_for_extra_kwargs(caller_name: str, kwargs: dict):
    """Reject unexpected keyword arguments, with a dedicated error for legacy-client ones."""
    for legacy_name in ("project_qualified_name", "backend"):
        # These parameters belonged to the legacy client API.
        if legacy_name in kwargs:
            raise NeptunePossibleLegacyUsageException()
    if kwargs:
        first_key = next(iter(kwargs))
        raise TypeError(f"{caller_name}() got an unexpected keyword argument '{first_key}'")
def create_notebook_checkpoint(backend: NeptuneBackend) -> Tuple[Optional[str], Optional[str]]:
    """Create a notebook checkpoint when notebook info is present in the environment.

    Returns a ``(notebook_id, checkpoint_id)`` pair; either element may be ``None``.
    """
    notebook_id = os.getenv(NEPTUNE_NOTEBOOK_ID, None)
    notebook_path = os.getenv(NEPTUNE_NOTEBOOK_PATH, None)
    if notebook_id is None or notebook_path is None:
        # Not running from a tracked notebook - nothing to checkpoint.
        return notebook_id, None
    checkpoint_id = create_checkpoint(backend=backend, notebook_id=notebook_id, notebook_path=notebook_path)
    return notebook_id, checkpoint_id
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["StructureVersion"]
from enum import Enum
class StructureVersion(Enum):
    """Version of the on-disk layout of the local ``.neptune`` data directory."""

    # .neptune/async/<uuid>/exec-<num><timestamp>/ containing a
    # `container_type` file plus the data-*.log files.
    LEGACY = 1

    # .neptune/async/run__<uuid>/exec-<timestamp>-<date>-<pid>/ holding the
    # data-*.log files of each execution.
    CHILD_EXECUTION_DIRECTORIES = 2

    # .neptune/async/run__<uuid>__<pid>__<random_key>/ with the data-*.log
    # files directly inside (no per-execution subdirectory).
    DIRECT_DIRECTORY = 3
#
# Copyright (c) 2023, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"parse_dates",
"prepare_nql_query",
]
from typing import (
Generator,
Iterable,
List,
Optional,
Union,
)
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
LeaderboardEntry,
)
from neptune.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQuery,
NQLQueryAggregate,
NQLQueryAttribute,
RawNQLQuery,
)
from neptune.internal.utils.iso_dates import parse_iso_date
from neptune.internal.utils.run_state import RunState
def prepare_nql_query(
    ids: Optional[Iterable[str]],
    states: Optional[Iterable[str]],
    owners: Optional[Iterable[str]],
    tags: Optional[Iterable[str]],
    trashed: Optional[bool],
) -> NQLQueryAggregate:
    """Translate the simple filter arguments into a single conjunctive NQL query."""

    def _any_of(attr_name, attr_type, operator, values):
        # Sub-query matching entries whose `attr_name` satisfies `operator` for ANY value.
        return NQLQueryAggregate(
            items=[
                NQLQueryAttribute(name=attr_name, type=attr_type, operator=operator, value=value)
                for value in values
            ],
            aggregator=NQLAggregator.OR,
        )

    conditions: List[Union[NQLQueryAttribute, NQLQueryAggregate]] = []

    if trashed is not None:
        conditions.append(
            NQLQueryAttribute(
                name="sys/trashed",
                type=NQLAttributeType.BOOLEAN,
                operator=NQLAttributeOperator.EQUALS,
                value=trashed,
            )
        )

    if ids:
        conditions.append(_any_of("sys/id", NQLAttributeType.STRING, NQLAttributeOperator.EQUALS, ids))

    if states:
        conditions.append(
            _any_of(
                "sys/state",
                NQLAttributeType.EXPERIMENT_STATE,
                NQLAttributeOperator.EQUALS,
                [RunState.from_string(state).to_api() for state in states],
            )
        )

    if owners:
        conditions.append(_any_of("sys/owner", NQLAttributeType.STRING, NQLAttributeOperator.EQUALS, owners))

    if tags:
        # Unlike the other multi-value filters, ALL requested tags must be present (AND).
        conditions.append(
            NQLQueryAggregate(
                items=[
                    NQLQueryAttribute(
                        name="sys/tags",
                        type=NQLAttributeType.STRING_SET,
                        operator=NQLAttributeOperator.CONTAINS,
                        value=tag,
                    )
                    for tag in tags
                ],
                aggregator=NQLAggregator.AND,
            )
        )

    return NQLQueryAggregate(items=conditions, aggregator=NQLAggregator.AND)
def parse_dates(leaderboard_entries: Iterable[LeaderboardEntry]) -> Generator[LeaderboardEntry, None, None]:
    """Lazily yield entries with their datetime attribute values parsed from ISO strings.

    Fix: the previous implementation used ``yield from [list comprehension]``,
    which materialized and parsed the entire input before yielding anything,
    defeating the purpose of returning a generator. Entries are now processed
    one at a time.
    """
    for entry in leaderboard_entries:
        yield _parse_entry(entry)
def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry:
    """Return a copy of `entry` with datetime attribute values parsed from ISO strings.

    Falls back to the original entry (values kept as strings) when parsing fails.
    """

    def _convert(attribute):
        # Only datetime-typed attributes need their "value" property parsed.
        if attribute.type != AttributeType.DATETIME:
            return attribute
        properties = dict(attribute.properties)
        properties["value"] = parse_iso_date(properties["value"])
        return AttributeWithProperties(attribute.path, attribute.type, properties)

    try:
        return LeaderboardEntry(
            entry.id,
            attributes=[_convert(attribute) for attribute in entry.attributes],
        )
    except ValueError:
        # The value did not match the expected ISO format - keep the string form.
        warn_once(
            "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.",
            exception=NeptuneWarning,
        )
        return entry
def build_raw_query(query: str, trashed: Optional[bool]) -> NQLQuery:
    """Wrap a raw NQL string, optionally AND-ing it with a ``sys/trashed`` filter."""
    raw_nql = RawNQLQuery(query)
    if trashed is None:
        return raw_nql
    trashed_filter = NQLQueryAttribute(
        name="sys/trashed", type=NQLAttributeType.BOOLEAN, operator=NQLAttributeOperator.EQUALS, value=trashed
    )
    return NQLQueryAggregate(items=[raw_nql, trashed_filter], aggregator=NQLAggregator.AND)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# flake8: noqa
__all__ = [
"ANONYMOUS_API_TOKEN",
"NeptunePossibleLegacyUsageException",
"init_model",
"init_model_version",
"init_project",
"init_run",
"Run",
"__version__",
"create_experiment",
"get_experiment",
"append_tag",
"append_tags",
"remove_tag",
"set_property",
"remove_property",
"send_metric",
"log_metric",
"send_text",
"log_text",
"send_image",
"log_image",
"send_artifact",
"delete_artifacts",
"log_artifact",
"stop",
]
import sys
from neptune.new._compatibility import CompatibilityImporter
# Register the finder that redirects whitelisted `neptune.new.*` imports to their
# `neptune.*` counterparts; it must be installed before the star-imports below run.
sys.meta_path.insert(0, CompatibilityImporter())
from neptune import (
ANONYMOUS_API_TOKEN,
Run,
__version__,
init_model,
init_model_version,
init_project,
init_run,
)
from neptune.common.warnings import warn_once
from neptune.new.attributes import *
from neptune.new.cli import *
from neptune.new.constants import *
from neptune.new.envs import *
from neptune.new.exceptions import *
from neptune.new.handler import *
from neptune.new.integrations import *
from neptune.new.logging import *
from neptune.new.metadata_containers import *
from neptune.new.project import *
from neptune.new.run import *
from neptune.new.runs_table import *
from neptune.new.types import *
from neptune.new.utils import *
def _raise_legacy_client_expected(*args, **kwargs):
    """Stub bound to every legacy-client function name: accepts any arguments, always raises."""
    raise NeptunePossibleLegacyUsageException()
# Rebind every public function of the legacy (pre-1.0) client API to a stub that
# raises NeptunePossibleLegacyUsageException when called.
create_experiment = get_experiment = append_tag = append_tags = remove_tag = set_property = remove_property = (
    send_metric
) = log_metric = send_text = log_text = send_image = log_image = send_artifact = delete_artifacts = log_artifact = (
    stop
) = _raise_legacy_client_expected
# Importing via `neptune.new` is deprecated - emit the notice once at import time.
warn_once(
    message="You're importing the Neptune client library via the deprecated"
    " `neptune.new` module, which will be removed in a future release."
    " Import directly from `neptune` instead."
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["CompatibilityImporter"]
import sys
from importlib import import_module
from importlib.abc import (
Loader,
MetaPathFinder,
)
from importlib.machinery import ModuleSpec
class CompatibilityModuleLoader(Loader):
    """Loads a ``neptune.new.<submodule>`` module by aliasing it to ``neptune.<submodule>``."""

    def exec_module(self, module):
        fullname = module.__name__
        # Map e.g. "neptune.new.attributes" -> "neptune.attributes"
        # (assumes fullname has exactly the three-part "neptune.new.<name>" shape).
        submodule = fullname.split(".")[2]
        target = import_module(f"neptune.{submodule}")
        # Alias the old name to the freshly imported module in sys.modules...
        sys.modules[fullname] = target
        # ...and mirror its namespace into the module object importlib created.
        module.__dict__.update(target.__dict__)
# The only `neptune.new.*` submodules served by the compatibility importer;
# any other name falls through to the regular import machinery.
modules = [
    "neptune.new.attributes",
    "neptune.new.cli",
    "neptune.new.integrations",
    "neptune.new.logging",
    "neptune.new.metadata_containers",
    "neptune.new.types",
]
class CompatibilityImporter(MetaPathFinder):
    """Meta-path finder that serves the whitelisted ``neptune.new.*`` module aliases."""

    def find_spec(self, fullname, path=None, target=None):
        if fullname not in modules:
            # Not one of ours - defer to the regular import machinery.
            return None
        return ModuleSpec(fullname, CompatibilityModuleLoader(), is_package=False)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"ANONYMOUS_API_TOKEN",
"NEPTUNE_DATA_DIRECTORY",
"OFFLINE_DIRECTORY",
"ASYNC_DIRECTORY",
"SYNC_DIRECTORY",
"OFFLINE_NAME_PREFIX",
]
"""Constants used by Neptune"""
from neptune.constants import (
ANONYMOUS_API_TOKEN,
ASYNC_DIRECTORY,
NEPTUNE_DATA_DIRECTORY,
OFFLINE_DIRECTORY,
OFFLINE_NAME_PREFIX,
SYNC_DIRECTORY,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"API_TOKEN_ENV_NAME",
"CONNECTION_MODE",
"PROJECT_ENV_NAME",
"CUSTOM_RUN_ID_ENV_NAME",
"MONITORING_NAMESPACE",
"NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE",
"NEPTUNE_NOTEBOOK_ID",
"NEPTUNE_NOTEBOOK_PATH",
"NEPTUNE_RETRIES_TIMEOUT_ENV",
"NEPTUNE_SYNC_BATCH_TIMEOUT_ENV",
"NEPTUNE_SUBPROCESS_KILL_TIMEOUT",
"NEPTUNE_FETCH_TABLE_STEP_SIZE",
]
from neptune.envs import (
API_TOKEN_ENV_NAME,
CONNECTION_MODE,
CUSTOM_RUN_ID_ENV_NAME,
MONITORING_NAMESPACE,
NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE,
NEPTUNE_FETCH_TABLE_STEP_SIZE,
NEPTUNE_NOTEBOOK_ID,
NEPTUNE_NOTEBOOK_PATH,
NEPTUNE_RETRIES_TIMEOUT_ENV,
NEPTUNE_SUBPROCESS_KILL_TIMEOUT,
NEPTUNE_SYNC_BATCH_TIMEOUT_ENV,
PROJECT_ENV_NAME,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"InternalClientError",
"NeptuneException",
"NeptuneInvalidApiTokenException",
"NeptuneApiException",
"MetadataInconsistency",
"MissingFieldException",
"TypeDoesNotSupportAttributeException",
"MalformedOperation",
"FileNotFound",
"FileUploadError",
"FileSetUploadError",
"ClientHttpError",
"MetadataContainerNotFound",
"ProjectNotFound",
"RunNotFound",
"ModelNotFound",
"ModelVersionNotFound",
"ExceptionWithProjectsWorkspacesListing",
"ContainerUUIDNotFound",
"RunUUIDNotFound",
"ProjectNotFoundWithSuggestions",
"AmbiguousProjectName",
"NeptuneMissingProjectNameException",
"InactiveContainerException",
"InactiveRunException",
"InactiveModelException",
"InactiveModelVersionException",
"InactiveProjectException",
"NeptuneMissingApiTokenException",
"CannotSynchronizeOfflineRunsWithoutProject",
"NeedExistingExperimentForReadOnlyMode",
"NeedExistingRunForReadOnlyMode",
"NeedExistingModelForReadOnlyMode",
"NeedExistingModelVersionForReadOnlyMode",
"NeptuneParametersCollision",
"NeptuneWrongInitParametersException",
"NeptuneRunResumeAndCustomIdCollision",
"NeptuneClientUpgradeRequiredError",
"NeptuneMissingRequiredInitParameter",
"CannotResolveHostname",
"NeptuneSSLVerificationError",
"NeptuneConnectionLostException",
"InternalServerError",
"Unauthorized",
"Forbidden",
"NeptuneOfflineModeException",
"NeptuneOfflineModeFetchException",
"NeptuneOfflineModeChangeStageException",
"NeptuneProtectedPathException",
"NeptuneCannotChangeStageManually",
"OperationNotSupported",
"NeptuneLegacyProjectException",
"NeptuneMissingRequirementException",
"NeptuneLimitExceedException",
"NeptuneFieldCountLimitExceedException",
"NeptuneStorageLimitException",
"FetchAttributeNotFoundException",
"ArtifactNotFoundException",
"PlotlyIncompatibilityException",
"NeptunePossibleLegacyUsageException",
"NeptuneLegacyIncompatibilityException",
"NeptuneUnhandledArtifactSchemeException",
"NeptuneUnhandledArtifactTypeException",
"NeptuneLocalStorageAccessException",
"NeptuneRemoteStorageCredentialsException",
"NeptuneRemoteStorageAccessException",
"ArtifactUploadingError",
"NeptuneUnsupportedArtifactFunctionalityException",
"NeptuneEmptyLocationException",
"NeptuneFeatureNotAvailableException",
"NeptuneObjectCreationConflict",
"NeptuneModelKeyAlreadyExistsError",
"NeptuneSynchronizationAlreadyStoppedException",
"StreamAlreadyUsedException",
"NeptuneUserApiInputException",
]
from neptune.exceptions import (
AmbiguousProjectName,
ArtifactNotFoundException,
ArtifactUploadingError,
CannotResolveHostname,
CannotSynchronizeOfflineRunsWithoutProject,
ClientHttpError,
ContainerUUIDNotFound,
ExceptionWithProjectsWorkspacesListing,
FetchAttributeNotFoundException,
FileNotFound,
FileSetUploadError,
FileUploadError,
Forbidden,
InactiveContainerException,
InactiveModelException,
InactiveModelVersionException,
InactiveProjectException,
InactiveRunException,
InternalClientError,
InternalServerError,
MalformedOperation,
MetadataContainerNotFound,
MetadataInconsistency,
MissingFieldException,
ModelNotFound,
ModelVersionNotFound,
NeedExistingExperimentForReadOnlyMode,
NeedExistingModelForReadOnlyMode,
NeedExistingModelVersionForReadOnlyMode,
NeedExistingRunForReadOnlyMode,
NeptuneApiException,
NeptuneCannotChangeStageManually,
NeptuneClientUpgradeRequiredError,
NeptuneConnectionLostException,
NeptuneEmptyLocationException,
NeptuneException,
NeptuneFeatureNotAvailableException,
NeptuneFieldCountLimitExceedException,
NeptuneInvalidApiTokenException,
NeptuneLegacyIncompatibilityException,
NeptuneLegacyProjectException,
NeptuneLimitExceedException,
NeptuneLocalStorageAccessException,
NeptuneMissingApiTokenException,
NeptuneMissingProjectNameException,
NeptuneMissingRequiredInitParameter,
NeptuneMissingRequirementException,
NeptuneModelKeyAlreadyExistsError,
NeptuneObjectCreationConflict,
NeptuneOfflineModeChangeStageException,
NeptuneOfflineModeException,
NeptuneOfflineModeFetchException,
NeptuneParametersCollision,
NeptunePossibleLegacyUsageException,
NeptuneProtectedPathException,
NeptuneRemoteStorageAccessException,
NeptuneRemoteStorageCredentialsException,
NeptuneRunResumeAndCustomIdCollision,
NeptuneSSLVerificationError,
NeptuneStorageLimitException,
NeptuneSynchronizationAlreadyStoppedException,
NeptuneUnhandledArtifactSchemeException,
NeptuneUnhandledArtifactTypeException,
NeptuneUnsupportedArtifactFunctionalityException,
NeptuneUserApiInputException,
NeptuneWrongInitParametersException,
OperationNotSupported,
PlotlyIncompatibilityException,
ProjectNotFound,
ProjectNotFoundWithSuggestions,
RunNotFound,
RunUUIDNotFound,
StreamAlreadyUsedException,
TypeDoesNotSupportAttributeException,
Unauthorized,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Handler"]
from neptune.handler import Handler
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# backwards compatibility
__all__ = ["Project"]
from neptune.metadata_containers import Project
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"Attribute",
"NamespaceAttr",
"NamespaceBuilder",
"InactiveRunException",
"MetadataInconsistency",
"NeptunePossibleLegacyUsageException",
"Handler",
"RunState",
"Run",
"Boolean",
"Integer",
"Datetime",
"Float",
"String",
"Namespace",
"Value",
]
# backwards compatibility
from neptune.attributes.attribute import Attribute
from neptune.attributes.namespace import Namespace as NamespaceAttr
from neptune.attributes.namespace import NamespaceBuilder
from neptune.exceptions import (
InactiveRunException,
MetadataInconsistency,
NeptunePossibleLegacyUsageException,
)
from neptune.handler import Handler
from neptune.internal.state import ContainerState as RunState
from neptune.metadata_containers import Run
from neptune.types import (
Boolean,
Integer,
)
from neptune.types.atoms.datetime import Datetime
from neptune.types.atoms.float import Float
from neptune.types.atoms.string import String
from neptune.types.namespace import Namespace
from neptune.types.value import Value
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"MetadataInconsistency",
"AttributeType",
"AttributeWithProperties",
"NeptuneBackend",
"ContainerType",
"LeaderboardEntry",
"LeaderboardHandler",
"RunsTable",
"RunsTableEntry",
]
# backwards compatibility
from neptune.exceptions import MetadataInconsistency
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
)
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.container_type import ContainerType
from neptune.table import (
LeaderboardEntry,
LeaderboardHandler,
)
from neptune.table import Table as RunsTable
from neptune.table import TableEntry as RunsTableEntry
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["stringify_unsupported"]
from neptune.utils import stringify_unsupported
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["version", "__version__"]
from neptune.version import (
__version__,
version,
)
+49
-37

@@ -1,45 +0,57 @@

## [UNRELEASED] neptune 2.0.0
## neptune 1.12.0
### Breaking changes
- Deleted `neptune.new` package ([#1684](https://github.com/neptune-ai/neptune-client/pull/1684))
- Deleted `neptune.legacy` package ([#1685](https://github.com/neptune-ai/neptune-client/pull/1685))
- Deleted `neptune.common` package ([#1693](https://github.com/neptune-ai/neptune-client/pull/1693))
([#1690](https://github.com/neptune-ai/neptune-client/pull/1690))
- Renamed `metadata_containers` to `objects` ([#1696](https://github.com/neptune-ai/neptune-client/pull/1696))
- Removed `neptune-client` ([#1699](https://github.com/neptune-ai/neptune-client/pull/1699))
- Deleted `neptune.logging` package ([#1698](https://github.com/neptune-ai/neptune-client/pull/1698))
- Disabled `Model` ([#1701](https://github.com/neptune-ai/neptune-client/pull/1701))
- Disabled `ModelVersion` ([#1708](https://github.com/neptune-ai/neptune-client/pull/1708))
- Disabled `Project` ([#1709](https://github.com/neptune-ai/neptune-client/pull/1709))
- Disabled `neptune command-line tool` ([#1718](https://github.com/neptune-ai/neptune-client/pull/1718))
- Disabled deleting fields in `Handler` ([#1729](https://github.com/neptune-ai/neptune-client/pull/1729))
- Removed `AbstractNeptuneObject` ([#1725](https://github.com/neptune-ai/neptune-client/pull/1725))
- Disabled artifact-related methods in `Handler` ([#1734](https://github.com/neptune-ai/neptune-client/pull/1734))
- Removed `boto3` from requirements ([#1743](https://github.com/neptune-ai/neptune-client/pull/1743))
- Disabled `StringSet` `remove` and `clear` methods ([#1732](https://github.com/neptune-ai/neptune-client/pull/1732))
- Disabled `fetch_last` and `download_last` ([#1731](https://github.com/neptune-ai/neptune-client/pull/1731))
- Removed `pillow` from requirements ([#1745](https://github.com/neptune-ai/neptune-client/pull/1745))
- Disabled file related functionality ([#1726](https://github.com/neptune-ai/neptune-client/pull/1726))
- Disabled file logging ([#1733](https://github.com/neptune-ai/neptune-client/pull/1733))
### Changes
- Added deprecation notice to model/model_version endpoints ([#1876](https://github.com/neptune-ai/neptune-client/pull/1876))
- Dropped support for Python 3.7 ([#1864](https://github.com/neptune-ai/neptune-client/pull/1864))
### Fixes
- Fixed support for additional Seaborn figure types ([#1864](https://github.com/neptune-ai/neptune-client/pull/1864))
## neptune 1.11.1
### Fixes
- Fixed GPU power consumption monitoring on certain devices ([#1861](https://github.com/neptune-ai/neptune-client/pull/1861))
## neptune 1.11.0
### Features
- Added GPU power consumption monitoring ([#1854](https://github.com/neptune-ai/neptune-client/pull/1854))
### Changes
- Stop sending `X-Neptune-LegacyClient` header ([#1715](https://github.com/neptune-ai/neptune-client/pull/1715))
- Use `tqdm.auto` ([#1717](https://github.com/neptune-ai/neptune-client/pull/1717))
- Fields DTO conversion reworked ([#1722](https://github.com/neptune-ai/neptune-client/pull/1722))
- Added support for Protocol Buffers ([#1728](https://github.com/neptune-ai/neptune-client/pull/1728))
- Series values DTO conversion reworked with protocol buffer support ([#1738](https://github.com/neptune-ai/neptune-client/pull/1738))
- Series values fetching reworked with protocol buffer support ([#1744](https://github.com/neptune-ai/neptune-client/pull/1744))
- Added support for enhanced field definitions querying ([#1751](https://github.com/neptune-ai/neptune-client/pull/1751))
- Added support for `NQL` `MATCHES` operator ([#1863](https://github.com/neptune-ai/neptune-client/pull/1863))
- Pagination respecting `limit` parameter and page size ([#1866](https://github.com/neptune-ai/neptune-client/pull/1866))
- Added support for Protocol Buffers in `query_fields_within_project` ([#1872](https://github.com/neptune-ai/neptune-client/pull/1872))
- Added docstring for the `pop()` function ([#1781](https://github.com/neptune-ai/neptune-client/pull/1781))
## neptune 1.10.4
### Fixes
- Fixed `tqdm.notebook` import only in Notebook environment ([#1716](https://github.com/neptune-ai/neptune-client/pull/1716))
- Added `setuptools` to dependencies for `python >= 3.12` ([#1721](https://github.com/neptune-ai/neptune-client/pull/1721))
- Fixed compatibility checks with pre-release versions ([#1730](https://github.com/neptune-ai/neptune-client/pull/1730))
- Added `Accept` and `Accept-Encoding` headers to protocol buffer requests ([#1736](https://github.com/neptune-ai/neptune-client/pull/1736))
- Fixed default params for getAttributesWithPathsFilter ([#1740](https://github.com/neptune-ai/neptune-client/pull/1740))
- Fixed run not failing in case of an exception if context manager was used ([#1755](https://github.com/neptune-ai/neptune-client/pull/1755))
### Changes
- Make handling of `sys/name` dependent on the client config flag ([#1777](https://github.com/neptune-ai/neptune-client/pull/1777))
## neptune 1.10.3
### Fixes
- Clarified the `autoscale` parameter description in the `as_image()` docstring ([#1742](https://github.com/neptune-ai/neptune-client/pull/1742))
### Changes
- Neptune now shows a warning instead of an error when the dependency tracking file is not found ([#1739](https://github.com/neptune-ai/neptune-client/pull/1739))
## neptune 1.10.2
### Fixes
- Added `setuptools` to dependencies for `python >= 3.12` ([#1723](https://github.com/neptune-ai/neptune-client/pull/1723))
- Fixed `PTL` integration requirement check ([#1719](https://github.com/neptune-ai/neptune-client/pull/1719))
## neptune 1.10.1
### Fixes
- Fixed requirement checking for integrations ([#1711](https://github.com/neptune-ai/neptune-client/pull/1711))
### Changes
- Stop initializing `sys/name` as `"Untitled"` by default ([#1720](https://github.com/neptune-ai/neptune-client/pull/1720))
## neptune 1.10.0

@@ -46,0 +58,0 @@

Metadata-Version: 2.1
Name: neptune
Version: 2.0.0a8
Version: 1.12.0
Summary: Neptune Client

@@ -10,3 +10,3 @@ Home-page: https://neptune.ai/

Author-email: contact@neptune.ai
Requires-Python: >=3.7,<4.0
Requires-Python: >=3.8,<4.0
Classifier: Development Status :: 5 - Production/Stable

@@ -23,3 +23,2 @@ Classifier: Environment :: Console

Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8

@@ -54,6 +53,8 @@ Classifier: Programming Language :: Python :: 3.9

Requires-Dist: GitPython (>=2.0.8)
Requires-Dist: Pillow (>=1.1.6)
Requires-Dist: PyJWT
Requires-Dist: boto3 (>=1.28.0)
Requires-Dist: bravado (>=11.0.0,<12.0.0)
Requires-Dist: click (>=7.0)
Requires-Dist: importlib-metadata ; python_version < "3.8"
Requires-Dist: future (>=0.17.1)
Requires-Dist: kedro-neptune ; (python_version >= "3.9" and python_version < "3.12") and (extra == "kedro")

@@ -63,3 +64,3 @@ Requires-Dist: mosaicml ; extra == "mosaicml"

Requires-Dist: neptune-aws ; extra == "aws"
Requires-Dist: neptune-detectron2 ; (python_version >= "3.7") and (extra == "detectron2")
Requires-Dist: neptune-detectron2 ; (python_version >= "3.8") and (extra == "detectron2")
Requires-Dist: neptune-fastai ; extra == "fastai"

@@ -79,3 +80,2 @@ Requires-Dist: neptune-lightgbm ; extra == "lightgbm"

Requires-Dist: pandas
Requires-Dist: protobuf (>=4.0.0,<5.0.0)
Requires-Dist: psutil

@@ -86,2 +86,3 @@ Requires-Dist: pytorch-lightning ; extra == "pytorch-lightning"

Requires-Dist: setuptools ; python_version >= "3.12"
Requires-Dist: six (>=1.12.0)
Requires-Dist: swagger-spec-validator (>=2.7.4)

@@ -88,0 +89,0 @@ Requires-Dist: transformers ; extra == "transformers"

@@ -12,12 +12,15 @@ [build-system]

[tool.poetry.dependencies]
python = "^3.7"
python = "^3.8"
# Python lack of functionalities from future versions
importlib-metadata = { version = "*", python = "<3.8" }
typing-extensions = ">=3.10.0"
# Missing compatibility layer between Python 2 and Python 3
six = ">=1.12.0"
future = ">=0.17.1"
# Utility
packaging = "*"
click = ">=7.0"
setuptools = { version = "*", python = ">=3.12" }
setuptools = { version= "*", python = ">=3.12" }

@@ -33,5 +36,6 @@ # Networking

swagger-spec-validator = ">=2.7.4"
protobuf = "^4.0.0"
# Built-in integrations
boto3 = ">=1.28.0"
Pillow = ">=1.1.6"
GitPython = ">=2.0.8"

@@ -43,3 +47,3 @@ psutil = "*"

kedro-neptune = { version = "*", optional = true, python = ">=3.9,<3.12" }
neptune-detectron2 = { version = "*", optional = true, python = ">=3.7"}
neptune-detectron2 = { version = "*", optional = true, python = ">=3.8"}
neptune-fastai = { version = "*", optional = true }

@@ -95,3 +99,3 @@ neptune-lightgbm = { version = "*", optional = true }

readme = "README.md"
version = "2.0.0-alpha.8"
version = "1.12.0"
classifiers = [

@@ -133,3 +137,3 @@ "Development Status :: 5 - Production/Stable",

target-version = ['py37', 'py38', 'py39', 'py310', 'py311', 'py312']
include = '\.pyi?$,\_pb2\.py$'
include = '\.pyi?$'
exclude = '''

@@ -180,3 +184,3 @@ /(

"neptune.api.exceptions_utils",
"neptune.objects.abstract",
"neptune.metadata_containers.abstract",
"neptune.types.value_copy",

@@ -217,31 +221,32 @@ "neptune.types.namespace",

"neptune.attributes.utils",
"neptune.internal.exceptions",
"neptune.internal.utils.git_info",
"neptune.internal.hardware.cgroup.cgroup_filesystem_reader",
"neptune.internal.hardware.cgroup.cgroup_monitor",
"neptune.internal.hardware.gauges.cpu",
"neptune.internal.hardware.gauges.gauge",
"neptune.internal.hardware.gauges.gauge_factory",
"neptune.internal.hardware.gauges.gpu",
"neptune.internal.hardware.gauges.memory",
"neptune.internal.hardware.gpu.gpu_monitor",
"neptune.internal.hardware.metrics.metric",
"neptune.internal.hardware.metrics.metrics_container",
"neptune.internal.hardware.metrics.metrics_factory",
"neptune.internal.hardware.metrics.reports.metric_reporter",
"neptune.internal.hardware.metrics.reports.metric_reporter_factory",
"neptune.internal.hardware.metrics.service.metric_service",
"neptune.internal.hardware.metrics.service.metric_service_factory",
"neptune.internal.hardware.resources.gpu_card_indices_provider",
"neptune.internal.hardware.resources.system_resource_info",
"neptune.internal.hardware.resources.system_resource_info_factory",
"neptune.internal.hardware.system.system_monitor",
"neptune.internal.oauth",
"neptune.internal.patches.bravado",
"neptune.internal.storage.datastream",
"neptune.internal.storage.storage_utils",
"neptune.internal.utils.utils",
"neptune.internal.warnings",
"neptune.internal.websockets.reconnecting_websocket",
"neptune.internal.websockets.websocket_client_adapter",
"neptune.common.backends.utils",
"neptune.common.exceptions",
"neptune.common.git_info",
"neptune.common.hardware.cgroup.cgroup_filesystem_reader",
"neptune.common.hardware.cgroup.cgroup_monitor",
"neptune.common.hardware.gauges.cpu",
"neptune.common.hardware.gauges.gauge",
"neptune.common.hardware.gauges.gauge_factory",
"neptune.common.hardware.gauges.gpu",
"neptune.common.hardware.gauges.memory",
"neptune.common.hardware.gpu.gpu_monitor",
"neptune.common.hardware.metrics.metric",
"neptune.common.hardware.metrics.metrics_container",
"neptune.common.hardware.metrics.metrics_factory",
"neptune.common.hardware.metrics.reports.metric_reporter",
"neptune.common.hardware.metrics.reports.metric_reporter_factory",
"neptune.common.hardware.metrics.service.metric_service",
"neptune.common.hardware.metrics.service.metric_service_factory",
"neptune.common.hardware.resources.gpu_card_indices_provider",
"neptune.common.hardware.resources.system_resource_info",
"neptune.common.hardware.resources.system_resource_info_factory",
"neptune.common.hardware.system.system_monitor",
"neptune.common.oauth",
"neptune.common.patches.bravado",
"neptune.common.storage.datastream",
"neptune.common.storage.storage_utils",
"neptune.common.utils",
"neptune.common.warnings",
"neptune.common.websockets.reconnecting_websocket",
"neptune.common.websockets.websocket_client_adapter",
"neptune.exceptions",

@@ -276,3 +281,2 @@ "neptune.handler",

"neptune.internal.credentials",
"neptune.internal.hardware.gpu.gpu_monitor",
"neptune.internal.hardware.hardware_metric_reporting_job",

@@ -304,5 +308,44 @@ "neptune.internal.id_formats",

"neptune.internal.utils.traceback_job",
"neptune.internal.utils.uncaught_exception_handler",
"neptune.internal.websockets.websocket_signals_background_job",
"neptune.internal.websockets.websockets_factory",
"neptune.legacy",
"neptune.legacy.api_exceptions",
"neptune.legacy.backend",
"neptune.legacy.checkpoint",
"neptune.legacy.exceptions",
"neptune.legacy.experiments",
"neptune.legacy.internal.abort",
"neptune.legacy.internal.api_clients.backend_factory",
"neptune.legacy.internal.api_clients.client_config",
"neptune.legacy.internal.api_clients.credentials",
"neptune.legacy.internal.api_clients.hosted_api_clients.hosted_alpha_leaderboard_api_client",
"neptune.legacy.internal.api_clients.hosted_api_clients.hosted_backend_api_client",
"neptune.legacy.internal.api_clients.hosted_api_clients.mixins",
"neptune.legacy.internal.api_clients.hosted_api_clients.utils",
"neptune.legacy.internal.api_clients.offline_backend",
"neptune.legacy.internal.channels.channels",
"neptune.legacy.internal.channels.channels_values_sender",
"neptune.legacy.internal.execution.execution_context",
"neptune.legacy.internal.notebooks.comm",
"neptune.legacy.internal.notebooks.notebooks",
"neptune.legacy.internal.streams.channel_writer",
"neptune.legacy.internal.streams.stdstream_uploader",
"neptune.legacy.internal.threads.aborting_thread",
"neptune.legacy.internal.threads.hardware_metric_reporting_thread",
"neptune.legacy.internal.threads.neptune_thread",
"neptune.legacy.internal.threads.ping_thread",
"neptune.legacy.internal.utils.alpha_integration",
"neptune.legacy.internal.utils.deprecation",
"neptune.legacy.internal.utils.http",
"neptune.legacy.internal.utils.http_utils",
"neptune.legacy.internal.utils.image",
"neptune.legacy.internal.utils.source_code",
"neptune.legacy.internal.websockets.message",
"neptune.legacy.internal.websockets.reconnecting_websocket_factory",
"neptune.legacy.internal.websockets.websocket_message_processor",
"neptune.legacy.model",
"neptune.legacy.notebook",
"neptune.legacy.projects",
"neptune.legacy.sessions",
"neptune.logging.logger",
"neptune.management.exceptions",

@@ -312,7 +355,8 @@ "neptune.management.internal.api",

"neptune.management.internal.utils",
"neptune.objects.neptune_object",
"neptune.objects.model",
"neptune.objects.model_version",
"neptune.objects.project",
"neptune.objects.run",
"neptune.metadata_containers.metadata_container",
"neptune.metadata_containers.model",
"neptune.metadata_containers.model_version",
"neptune.metadata_containers.project",
"neptune.metadata_containers.run",
"neptune.new._compatibility",
"neptune.types.type_casting",

@@ -323,5 +367,6 @@ "neptune.vendor.pynvml",

"neptune.internal.container_type",
"neptune.internal.patches",
"neptune.new",
"neptune.common.patches",
"neptune.internal.utils",
]
ignore_errors = "True"

@@ -105,6 +105,6 @@ #

from neptune.common.patches import apply_patches
from neptune.constants import ANONYMOUS_API_TOKEN
from neptune.internal.extensions import load_extensions
from neptune.internal.patches import apply_patches
from neptune.objects import (
from neptune.metadata_containers import (
Model,

@@ -111,0 +111,0 @@ ModelVersion,

@@ -16,3 +16,3 @@ #

#
__all__ = ["get_single_page", "iter_over_pages", "find_attribute"]
__all__ = ["get_single_page", "iter_over_pages"]

@@ -25,5 +25,8 @@ from typing import (

Iterable,
List,
Optional,
)
from bravado.client import construct_request # type: ignore
from bravado.config import RequestConfig # type: ignore
from bravado.exception import HTTPBadRequest # type: ignore

@@ -35,15 +38,9 @@ from typing_extensions import (

from neptune.api.field_visitor import FieldToValueVisitor
from neptune.api.models import (
Field,
FieldType,
LeaderboardEntriesSearchResult,
from neptune.exceptions import NeptuneInvalidQueryException
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
LeaderboardEntry,
)
from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ProtoLeaderboardEntriesSearchResultDTO
from neptune.exceptions import NeptuneInvalidQueryException
from neptune.internal.backends.hosted_client import (
DEFAULT_PROTO_REQUEST_KWARGS,
DEFAULT_REQUEST_KWARGS,
)
from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS
from neptune.internal.backends.nql import (

@@ -67,3 +64,3 @@ NQLAggregator,

SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in FieldType}
SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in AttributeType}

@@ -106,6 +103,5 @@ SORT_BY_COLUMN_TYPE: TypeAlias = Literal["string", "datetime", "integer", "boolean", "float"]

searching_after: Optional[str],
use_proto: Optional[bool] = None,
) -> LeaderboardEntriesSearchResult:
) -> Any:
normalized_query = query or NQLEmptyQuery()
sort_by_column_type = sort_by_column_type if sort_by_column_type else FieldType.STRING.value
sort_by_column_type = sort_by_column_type if sort_by_column_type else AttributeType.STRING.value
if sort_by and searching_after:

@@ -115,3 +111,3 @@ sort_by_as_nql = NQLQueryAttribute(

type=NQLAttributeType(sort_by_column_type),
operator=NQLAttributeOperator.GREATER_THAN if ascending else NQLAttributeOperator.LESS_THAN,
operator=NQLAttributeOperator.GREATER_THAN,
value=searching_after,

@@ -132,3 +128,3 @@ )

"name": sort_by,
"type": sort_by_column_type if sort_by_column_type else FieldType.STRING.value,
"type": sort_by_column_type if sort_by_column_type else AttributeType.STRING.value,
},

@@ -152,12 +148,14 @@ }

request_options = DEFAULT_REQUEST_KWARGS.get("_request_options", {})
request_config = RequestConfig(request_options, True)
request_params = construct_request(client.api.searchLeaderboardEntries, request_options, **params)
http_client = client.swagger_spec.http_client
try:
if use_proto:
result = (
client.api.searchLeaderboardEntriesProto(**params, **DEFAULT_PROTO_REQUEST_KWARGS).response().result
)
proto_data = ProtoLeaderboardEntriesSearchResultDTO.FromString(result)
return LeaderboardEntriesSearchResult.from_proto(proto_data)
else:
model_data = client.api.searchLeaderboardEntries(**params, **DEFAULT_REQUEST_KWARGS).response().result
return LeaderboardEntriesSearchResult.from_model(model_data)
return (
http_client.request(request_params, operation=None, request_config=request_config)
.response()
.incoming_response.json()
)
except HTTPBadRequest as e:

@@ -170,6 +168,21 @@ title = e.response.json().get("title")

def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]:
return next((attr for attr in entry.fields if attr.path == path), None)
def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry:
return LeaderboardEntry(
id=entry["experimentId"],
attributes=[
AttributeWithProperties(
path=attr["name"],
type=AttributeType(attr["type"]),
properties=attr.__getitem__(f"{attr['type']}Properties"),
)
for attr in entry["attributes"]
if attr["type"] in SUPPORTED_ATTRIBUTE_TYPES
],
)
def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[AttributeWithProperties]:
return next((attr for attr in entry.attributes if attr.path == path), None)
def iter_over_pages(

@@ -189,3 +202,3 @@ *,

data = get_single_page(
total = get_single_page(
limit=0,

@@ -198,4 +211,3 @@ offset=0,

**kwargs,
)
total = data.matching_item_count
).get("matchingItemCount", 0)

@@ -210,4 +222,2 @@ limit = limit if limit is not None else NoLimit()

field_to_value_visitor = FieldToValueVisitor()
with construct_progress_bar(progress_bar, "Fetching table...") as bar:

@@ -222,7 +232,9 @@ # beginning of the first page

if last_page:
searching_after_field = find_attribute(entry=last_page[-1], path=sort_by)
if not searching_after_field:
page_attribute = find_attribute(entry=last_page[-1], path=sort_by)
if not page_attribute:
raise ValueError(f"Cannot find attribute {sort_by} in last page")
searching_after = field_to_value_visitor.visit(searching_after_field)
searching_after = page_attribute.properties["value"]
for offset in range(0, max_offset, step_size):

@@ -232,3 +244,2 @@ local_limit = min(step_size, max_offset - offset)

local_limit = limit - extracted_records
result = get_single_page(

@@ -246,7 +257,7 @@ limit=local_limit,

if offset == 0 and last_page is not None:
total += result.matching_item_count
total += result.get("matchingItemCount", 0)
total = min(total, limit)
page = result.entries
page = _entries_from_page(result)
extracted_records += len(page)

@@ -264,1 +275,5 @@ bar.update(by=len(page), total=total)

last_page = page
def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]:
return list(map(to_leaderboard_entry, single_page.get("entries", [])))

@@ -21,3 +21,2 @@ #

from neptune.attributes.atoms.atom import Atom
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.operation import UploadFile

@@ -31,3 +30,2 @@ from neptune.internal.utils import verify_type

def assign(self, value: FileVal, *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
verify_type("value", value, FileVal)

@@ -45,3 +43,2 @@

def upload(self, value, *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
self.assign(FileVal.create_from(value), wait=wait)

@@ -54,3 +51,2 @@

) -> None:
raise NeptuneUnsupportedFunctionalityException
verify_type("destination", destination, (str, type(None)))

@@ -60,4 +56,3 @@ self._backend.download_file(self._container_id, self._container_type, self._path, destination, progress_bar)

def fetch_extension(self) -> str:
raise NeptuneUnsupportedFunctionalityException
val = self._backend.get_file_attribute(self._container_id, self._container_type, self._path)
return val.ext

@@ -21,9 +21,9 @@ #

from neptune.attributes.atoms.copiable_atom import CopiableAtom
from neptune.common.warnings import (
NeptuneUnsupportedValue,
warn_once,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.operation import AssignFloat
from neptune.internal.types.utils import is_unsupported_float
from neptune.internal.warnings import (
NeptuneUnsupportedValue,
warn_once,
)
from neptune.types.atoms.float import Float as FloatVal

@@ -30,0 +30,0 @@

@@ -29,3 +29,3 @@ #

from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -39,3 +39,3 @@ logger = get_logger()

def __init__(self, container: "NeptuneObject", path: typing.List[str]):
def __init__(self, container: "MetadataContainer", path: typing.List[str]):
super().__init__(container, path)

@@ -42,0 +42,0 @@ self._value_truncation_occurred = False

@@ -30,3 +30,3 @@ #

from neptune.internal.container_type import ContainerType
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -37,3 +37,3 @@

def __init__(self, container: "NeptuneObject", path: List[str]):
def __init__(self, container: "MetadataContainer", path: List[str]):
super().__init__()

@@ -40,0 +40,0 @@ self._container = container

@@ -26,5 +26,4 @@ #

from neptune.api.models import FileEntry
from neptune.api.dtos import FileEntry
from neptune.attributes.attribute import Attribute
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.operation import (

@@ -44,3 +43,2 @@ DeleteFiles,

def assign(self, value: Union[FileSetVal, str, Iterable[str]], *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
verify_type("value", value, (FileSetVal, str, Iterable))

@@ -56,3 +54,2 @@ if isinstance(value, FileSetVal):

def upload_files(self, globs: Union[str, Iterable[str]], *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
if isinstance(globs, str):

@@ -65,3 +62,2 @@ globs = [globs]

def delete_files(self, paths: Union[str, Iterable[str]], *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
if isinstance(paths, str):

@@ -75,3 +71,2 @@ paths = [paths]

def _enqueue_upload_operation(self, globs: Iterable[str], *, reset: bool, wait: bool):
raise NeptuneUnsupportedFunctionalityException
with self._container.lock():

@@ -86,3 +81,2 @@ abs_file_globs = list(os.path.abspath(file_glob) for file_glob in globs)

) -> None:
raise NeptuneUnsupportedFunctionalityException
verify_type("destination", destination, (str, type(None)))

@@ -92,4 +86,3 @@ self._backend.download_file_set(self._container_id, self._container_type, self._path, destination, progress_bar)

def list_fileset_files(self, path: Optional[str] = None) -> List[FileEntry]:
raise NeptuneUnsupportedFunctionalityException
path = path or ""
return self._backend.list_fileset_files(self._path, self._container_id, path)

@@ -47,3 +47,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -55,3 +55,3 @@ logger = get_logger()

class Namespace(Attribute, MutableMapping):
def __init__(self, container: "NeptuneObject", path: List[str]):
def __init__(self, container: "MetadataContainer", path: List[str]):
Attribute.__init__(self, container, path)

@@ -135,3 +135,3 @@ self._attributes = {}

class NamespaceBuilder:
def __init__(self, container: "NeptuneObject"):
def __init__(self, container: "MetadataContainer"):
self._run = container

@@ -138,0 +138,0 @@

@@ -20,3 +20,2 @@ #

from datetime import datetime
from functools import partial
from typing import (

@@ -30,50 +29,48 @@ Dict,

from neptune.api.fetching_series_values import fetch_series_values
from neptune.api.models import (
FloatPointValue,
StringPointValue,
from neptune.internal.backends.api_model import (
FloatSeriesValues,
StringSeriesValues,
)
from neptune.internal.backends.utils import construct_progress_bar
from neptune.internal.utils.paths import path_to_str
from neptune.typing import ProgressBarType
Row = TypeVar("Row", StringPointValue, FloatPointValue)
Row = TypeVar("Row", StringSeriesValues, FloatSeriesValues)
def make_row(entry: Row, include_timestamp: bool = True) -> Dict[str, Union[str, float, datetime]]:
row: Dict[str, Union[str, float, datetime]] = {
"step": entry.step,
"value": entry.value,
}
class FetchableSeries(Generic[Row]):
@abc.abstractmethod
def _fetch_values_from_backend(self, offset, limit) -> Row:
pass
if include_timestamp:
row["timestamp"] = entry.timestamp
def fetch_values(self, *, include_timestamp: bool = True, progress_bar: Optional[ProgressBarType] = None):
import pandas as pd
return row
limit = 1000
val = self._fetch_values_from_backend(0, limit)
data = val.values
offset = limit
def make_row(entry: Row) -> Dict[str, Union[str, float, datetime]]:
row: Dict[str, Union[str, float, datetime]] = dict()
row["step"] = entry.step
row["value"] = entry.value
if include_timestamp:
row["timestamp"] = datetime.fromtimestamp(entry.timestampMillis / 1000)
return row
class FetchableSeries(Generic[Row]):
@abc.abstractmethod
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> Row: ...
progress_bar = False if len(data) < limit else progress_bar
def fetch_values(
self,
*,
include_timestamp: bool = True,
progress_bar: Optional[ProgressBarType] = None,
include_inherited: bool = True,
):
import pandas as pd
path = path_to_str(self._path) if hasattr(self, "_path") else ""
data = fetch_series_values(
getter=partial(self._fetch_values_from_backend, include_inherited=include_inherited),
path=path,
progress_bar=progress_bar,
)
with construct_progress_bar(progress_bar, f"Fetching {path} values") as bar:
bar.update(by=len(data), total=val.totalItemCount) # first fetch before the loop
while offset < val.totalItemCount:
batch = self._fetch_values_from_backend(offset, limit)
data.extend(batch.values)
offset += limit
bar.update(by=len(batch.values), total=val.totalItemCount)
rows = dict((n, make_row(entry=entry, include_timestamp=include_timestamp)) for (n, entry) in enumerate(data))
rows = dict((n, make_row(entry)) for (n, entry) in enumerate(data))
df = pd.DataFrame.from_dict(data=rows, orient="index")
return df

@@ -27,6 +27,10 @@ #

from PIL import (
Image,
UnidentifiedImageError,
)
from neptune.attributes.series.series import Series
from neptune.exceptions import (
FileNotFound,
NeptuneUnsupportedFunctionalityException,
OperationNotSupported,

@@ -83,7 +87,2 @@ )

from PIL import (
Image,
UnidentifiedImageError,
)
try:

@@ -106,3 +105,3 @@ with Image.open(io.BytesIO(file_content)):

self._container_id, self._container_type, self._path, 0, 1
).total
).totalItemCount
for i in range(0, item_count):

@@ -114,7 +113,6 @@ self._backend.download_file_series_by_index(

def download_last(self, destination: Optional[str]):
raise NeptuneUnsupportedFunctionalityException
target_dir = self._get_destination(destination)
item_count = self._backend.get_image_series_values(
self._container_id, self._container_type, self._path, 0, 1
).total
).totalItemCount
if item_count > 0:

@@ -121,0 +119,0 @@ self._backend.download_file_series_by_index(

@@ -24,6 +24,5 @@ #

from neptune.api.models import FloatSeriesValues
from neptune.attributes.series.fetchable_series import FetchableSeries
from neptune.attributes.series.series import Series
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.backends.api_model import FloatSeriesValues
from neptune.internal.operation import (

@@ -72,16 +71,8 @@ ClearFloatLog,

def fetch_last(self) -> float:
raise NeptuneUnsupportedFunctionalityException
val = self._backend.get_float_series_attribute(self._container_id, self._container_type, self._path)
return val.last
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> FloatSeriesValues:
def _fetch_values_from_backend(self, offset, limit) -> FloatSeriesValues:
return self._backend.get_float_series_values(
container_id=self._container_id,
container_type=self._container_type,
path=self._path,
from_step=from_step,
limit=limit,
include_inherited=include_inherited,
self._container_id, self._container_type, self._path, offset, limit
)

@@ -22,10 +22,8 @@ #

List,
Optional,
Union,
)
from neptune.api.models import StringSeriesValues
from neptune.attributes.series.fetchable_series import FetchableSeries
from neptune.attributes.series.series import Series
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.backends.api_model import StringSeriesValues
from neptune.internal.operation import (

@@ -43,3 +41,3 @@ ClearStringLog,

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -56,3 +54,3 @@ Val = StringSeriesVal

):
def __init__(self, container: "NeptuneObject", path: List[str]):
def __init__(self, container: "MetadataContainer", path: List[str]):
super().__init__(container, path)

@@ -99,15 +97,8 @@ self._value_truncation_occurred = False

def fetch_last(self) -> str:
raise NeptuneUnsupportedFunctionalityException
val = self._backend.get_string_series_attribute(self._container_id, self._container_type, self._path)
return val.last
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> StringSeriesValues:
def _fetch_values_from_backend(self, offset, limit) -> StringSeriesValues:
return self._backend.get_string_series_values(
container_id=self._container_id,
container_type=self._container_type,
path=self._path,
from_step=from_step,
limit=limit,
self._container_id, self._container_type, self._path, offset, limit
)

@@ -25,3 +25,2 @@ #

from neptune.attributes.sets.set import Set
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.operation import (

@@ -56,3 +55,2 @@ AddStrings,

def remove(self, values: Union[str, Iterable[str]], *, wait: bool = False):
raise NeptuneUnsupportedFunctionalityException
values = self._to_proper_value_type(values)

@@ -63,3 +61,2 @@ with self._container.lock():

def clear(self, *, wait: bool = False):
raise NeptuneUnsupportedFunctionalityException
with self._container.lock():

@@ -66,0 +63,0 @@ self._enqueue_operation(ClearStringSet(self._path), wait=wait)

@@ -23,3 +23,2 @@ #

from neptune.api.models import FieldType
from neptune.attributes import (

@@ -42,24 +41,25 @@ Artifact,

)
from neptune.internal.exceptions import InternalClientError
from neptune.common.exceptions import InternalClientError
from neptune.internal.backends.api_model import AttributeType
if TYPE_CHECKING:
from neptune.attributes.attribute import Attribute
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer
_attribute_type_to_attr_class_map = {
FieldType.FLOAT: Float,
FieldType.INT: Integer,
FieldType.BOOL: Boolean,
FieldType.STRING: String,
FieldType.DATETIME: Datetime,
FieldType.FILE: File,
FieldType.FILE_SET: FileSet,
FieldType.FLOAT_SERIES: FloatSeries,
FieldType.STRING_SERIES: StringSeries,
FieldType.IMAGE_SERIES: FileSeries,
FieldType.STRING_SET: StringSet,
FieldType.GIT_REF: GitRef,
FieldType.OBJECT_STATE: RunState,
FieldType.NOTEBOOK_REF: NotebookRef,
FieldType.ARTIFACT: Artifact,
AttributeType.FLOAT: Float,
AttributeType.INT: Integer,
AttributeType.BOOL: Boolean,
AttributeType.STRING: String,
AttributeType.DATETIME: Datetime,
AttributeType.FILE: File,
AttributeType.FILE_SET: FileSet,
AttributeType.FLOAT_SERIES: FloatSeries,
AttributeType.STRING_SERIES: StringSeries,
AttributeType.IMAGE_SERIES: FileSeries,
AttributeType.STRING_SET: StringSet,
AttributeType.GIT_REF: GitRef,
AttributeType.RUN_STATE: RunState,
AttributeType.NOTEBOOK_REF: NotebookRef,
AttributeType.ARTIFACT: Artifact,
}

@@ -69,4 +69,4 @@

def create_attribute_from_type(
attribute_type: FieldType,
container: "NeptuneObject",
attribute_type: AttributeType,
container: "MetadataContainer",
path: List[str],

@@ -73,0 +73,0 @@ ) -> "Attribute":

@@ -45,3 +45,3 @@ #

from neptune.internal.id_formats import UniqueId
from neptune.objects.structure_version import StructureVersion
from neptune.metadata_containers.structure_version import StructureVersion

@@ -48,0 +48,0 @@ if TYPE_CHECKING:

@@ -31,3 +31,2 @@ #

from neptune.cli.sync import SyncRunner
from neptune.exceptions import NeptuneUnsupportedFunctionalityException
from neptune.internal.backends.hosted_neptune_backend import HostedNeptuneBackend

@@ -57,3 +56,2 @@ from neptune.internal.credentials import Credentials

raise NeptuneUnsupportedFunctionalityException
backend = HostedNeptuneBackend(Credentials.from_token())

@@ -134,3 +132,2 @@

raise NeptuneUnsupportedFunctionalityException
backend = HostedNeptuneBackend(Credentials.from_token())

@@ -168,6 +165,4 @@

"""
raise NeptuneUnsupportedFunctionalityException
backend = HostedNeptuneBackend(Credentials.from_token())
ClearRunner.clear(backend=backend, path=path)

@@ -42,8 +42,11 @@ #

from neptune.cli.utils import get_qualified_name
from neptune.common.exceptions import NeptuneConnectionLostException
from neptune.constants import ASYNC_DIRECTORY
from neptune.core.components.operation_storage import OperationStorage
from neptune.core.components.queue.disk_queue import DiskQueue
from neptune.envs import NEPTUNE_SYNC_BATCH_TIMEOUT_ENV
from neptune.envs import (
NEPTUNE_ASYNC_BATCH_SIZE,
NEPTUNE_SYNC_BATCH_TIMEOUT_ENV,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneConnectionLostException
from neptune.internal.id_formats import UniqueId

@@ -53,3 +56,3 @@ from neptune.internal.operation import Operation

from neptune.internal.utils.logger import get_logger
from neptune.objects.structure_version import StructureVersion
from neptune.metadata_containers.structure_version import StructureVersion

@@ -99,2 +102,4 @@ if TYPE_CHECKING:

batch_size = int(os.getenv(NEPTUNE_ASYNC_BATCH_SIZE, "1000"))
with DiskQueue(

@@ -107,3 +112,3 @@ data_path=self.path,

while True:
raw_batch = disk_queue.get_batch(1000)
raw_batch = disk_queue.get_batch(batch_size)
if not raw_batch:

@@ -110,0 +115,0 @@ break

@@ -39,2 +39,3 @@ #

from neptune.common.exceptions import NeptuneException
from neptune.core.components.queue.disk_queue import DiskQueue

@@ -52,3 +53,2 @@ from neptune.envs import PROJECT_ENV_NAME

from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneException
from neptune.internal.id_formats import (

@@ -60,3 +60,3 @@ QualifiedName,

from neptune.internal.utils.logger import get_logger
from neptune.objects.structure_version import StructureVersion
from neptune.metadata_containers.structure_version import StructureVersion

@@ -63,0 +63,0 @@ logger = get_logger(with_prefix=False)

@@ -35,8 +35,6 @@ #

"NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK",
"NEPTUNE_USE_PROTOCOL_BUFFERS",
"NEPTUNE_ASYNC_BATCH_SIZE",
"S3_ENDPOINT_URL",
]
from neptune.internal.envs import (
from neptune.common.envs import (
API_TOKEN_ENV_NAME,

@@ -80,4 +78,2 @@ NEPTUNE_RETRIES_TIMEOUT_ENV,

NEPTUNE_USE_PROTOCOL_BUFFERS = "NEPTUNE_USE_PROTOCOL_BUFFERS"
S3_ENDPOINT_URL = "S3_ENDPOINT_URL"

@@ -69,2 +69,3 @@ #

"OperationNotSupported",
"NeptuneLegacyProjectException",
"NeptuneMissingRequirementException",

@@ -77,2 +78,4 @@ "NeptuneLimitExceedException",

"PlotlyIncompatibilityException",
"NeptunePossibleLegacyUsageException",
"NeptuneLegacyIncompatibilityException",
"NeptuneUnhandledArtifactSchemeException",

@@ -94,3 +97,2 @@ "NeptuneUnhandledArtifactTypeException",

"NeptuneInvalidQueryException",
"NeptuneUnsupportedFunctionalityException",
]

@@ -107,13 +109,6 @@

from neptune.envs import (
CUSTOM_RUN_ID_ENV_NAME,
PROJECT_ENV_NAME,
)
from neptune.internal.backends.api_model import (
Project,
Workspace,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.envs import API_TOKEN_ENV_NAME
from neptune.internal.exceptions import (
from neptune.common.envs import API_TOKEN_ENV_NAME
# Backward compatibility import
from neptune.common.exceptions import (
STYLES,

@@ -131,2 +126,11 @@ ClientHttpError,

)
from neptune.envs import (
CUSTOM_RUN_ID_ENV_NAME,
PROJECT_ENV_NAME,
)
from neptune.internal.backends.api_model import (
Project,
Workspace,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import QualifiedName

@@ -647,3 +651,3 @@ from neptune.internal.utils import replace_patch_version

Or if you are using Conda, run the following instead:
{bash}conda update -c conda-forge neptune{end}
{bash}conda update -c conda-forge neptune-client{end}

@@ -777,2 +781,21 @@ {correct}Need help?{end}-> https://docs.neptune.ai/getting_help

class NeptuneLegacyProjectException(NeptuneException):
def __init__(self, project: QualifiedName):
message = """
{h1}
----NeptuneLegacyProjectException---------------------------------------------------------
{end}
Your project "{project}" has not been migrated to the new structure yet.
Unfortunately the neptune.new Python API is incompatible with projects using the old structure,
so please use legacy neptune Python API.
Don't worry - we are working hard on migrating all the projects and you will be able to use the neptune.new API soon.
You can find documentation for the legacy neptune Python API here:
- https://docs-legacy.neptune.ai/index.html
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
super().__init__(message.format(project=project, **STYLES))
class NeptuneMissingRequirementException(NeptuneException):

@@ -942,2 +965,49 @@ def __init__(self, package_name: str, framework_name: Optional[str]):

class NeptunePossibleLegacyUsageException(NeptuneWrongInitParametersException):
def __init__(self):
message = """
{h1}
----NeptunePossibleLegacyUsageException----------------------------------------------------------------
{end}
It seems you are trying to use the legacy API, but you imported the new one.
Simply update your import statement to:
{python}import neptune{end}
You may want to check the legacy API docs:
- https://docs-legacy.neptune.ai
If you want to update your code with the new API, we prepared a handy migration guide:
- https://docs.neptune.ai/about/legacy/#migrating-to-neptunenew
You can read more about neptune.new in the release blog post:
- https://neptune.ai/blog/neptune-new
You may also want to check the following docs page:
- https://docs-legacy.neptune.ai/getting-started/integrate-neptune-into-your-codebase.html
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
super().__init__(message.format(**STYLES))
class NeptuneLegacyIncompatibilityException(NeptuneException):
def __init__(self):
message = """
{h1}
----NeptuneLegacyIncompatibilityException----------------------------------------
{end}
It seems you are passing the legacy Experiment object, when a Run object is expected.
What can I do?
- Updating your code to the new Python API requires few changes, but to help you with this process we prepared a handy migration guide:
https://docs.neptune.ai/about/legacy/#migrating-to-neptunenew
- You can read more about neptune.new in the release blog post:
https://neptune.ai/blog/neptune-new
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
""" # noqa: E501
super().__init__(message.format(**STYLES))
class NeptuneUnhandledArtifactSchemeException(NeptuneException):

@@ -1175,14 +1245,1 @@ def __init__(self, path: str):

super().__init__(message)
class NeptuneUnsupportedFunctionalityException(NeptuneException):
def __init__(self):
message = """
{h1}
----NeptuneUnsupportedFunctionalityException----------------------------
{end}
You're using neptune 2.0, which is in Beta.
Some functionality that you tried to use is not supported in the installed version.
We will gradually add missing features to the Beta. Check that you're on the latest version.
"""
super().__init__(message)

@@ -22,3 +22,2 @@ #

Any,
Callable,
Collection,

@@ -29,2 +28,3 @@ Dict,

List,
NewType,
Optional,

@@ -34,3 +34,3 @@ Union,

from neptune.api.models import FileEntry
from neptune.api.dtos import FileEntry
from neptune.attributes import File

@@ -45,6 +45,6 @@ from neptune.attributes.atoms.artifact import Artifact

from neptune.attributes.sets.string_set import StringSet
from neptune.common.warnings import warn_about_unsupported_type
from neptune.exceptions import (
MissingFieldException,
NeptuneCannotChangeStageManually,
NeptuneUnsupportedFunctionalityException,
NeptuneUserApiInputException,

@@ -69,4 +69,3 @@ )

from neptune.internal.value_to_attribute_visitor import ValueToAttributeVisitor
from neptune.internal.warnings import warn_about_unsupported_type
from neptune.objects.abstract import SupportsNamespaces
from neptune.metadata_containers.abstract import SupportsNamespaces
from neptune.types.atoms.file import File as FileVal

@@ -79,12 +78,7 @@ from neptune.types.type_casting import cast_value_for_extend

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer
NeptuneObject = NewType("NeptuneObject", MetadataContainer)
def feature_temporarily_unavailable(_: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*_, **__):
raise NeptuneUnsupportedFunctionalityException()
return wrapper
def validate_path_not_protected(target_path: str, handler: "Handler"):

@@ -114,3 +108,3 @@ path_protection_exception = handler._PROTECTED_PATHS.get(target_path)

def __init__(self, container: "NeptuneObject", path: str):
def __init__(self, container: "MetadataContainer", path: str):
super().__init__()

@@ -160,3 +154,3 @@ self._container = container

@property
def container(self) -> "NeptuneObject":
def container(self) -> "MetadataContainer":
"""Returns the container that the attribute is attached to."""

@@ -263,3 +257,2 @@ return self._container

"""
raise NeptuneUnsupportedFunctionalityException
value = FileVal.create_from(value)

@@ -276,3 +269,2 @@

def upload_files(self, value: Union[str, Iterable[str]], *, wait: bool = False) -> None:
raise NeptuneUnsupportedFunctionalityException
if is_collection(value):

@@ -348,3 +340,2 @@ verify_collection_type("value", value, str)

elif FileVal.is_convertable(first_value):
raise NeptuneUnsupportedFunctionalityException
attr = FileSeries(self._container, parse_path(self._path))

@@ -494,2 +485,36 @@ elif is_float_like(first_value):

@check_protected_paths
def pop(self, path: str = None, *, wait: bool = False) -> None:
"""Completely removes the namespace and all associated metadata stored under the path.
Args:
path: Path of the namespace to be removed.
wait: By default, logged metadata is sent to the server in the background.
With this option set to `True`, Neptune first sends all data, then executes the call.
Example:
>>> import neptune
>>> run = neptune.init_run(with_id="RUN-100")
>>> run["large_dataset"].pop()
See also the API reference: https://docs.neptune.ai/api/client_index/#pop
"""
with self._container.lock():
handler = self
if path:
verify_type("path", path, str)
handler = self[path]
path = join_paths(self._path, path)
# extra check: check_protected_paths decorator does not catch flow with non-null path
validate_path_not_protected(path, self)
else:
path = self._path
attribute = self._container.get_attribute(path)
if isinstance(attribute, Namespace):
for child_path in list(attribute):
handler.pop(child_path, wait=wait)
else:
self._container._pop_impl(parse_path(path), wait=wait)
@check_protected_paths
def remove(self, values: Union[str, Iterable[str]], *, wait: bool = False) -> None:

@@ -606,3 +631,2 @@ """Removes the provided tags from the set.

"""
raise NeptuneUnsupportedFunctionalityException
return self._pass_call_to_attr(function_name="delete_files", paths=paths, wait=wait)

@@ -641,3 +665,2 @@

"""
raise NeptuneUnsupportedFunctionalityException
return self._pass_call_to_attr(function_name="download", destination=destination, progress_bar=progress_bar)

@@ -661,3 +684,2 @@

@feature_temporarily_unavailable
def fetch_hash(self) -> str:

@@ -677,6 +699,4 @@ """Fetches the hash of an artifact.

"""
raise NeptuneUnsupportedFunctionalityException
return self._pass_call_to_attr(function_name="fetch_extension")
@feature_temporarily_unavailable
def fetch_files_list(self) -> List[ArtifactFileData]:

@@ -688,3 +708,2 @@ """Fetches the list of files in an artifact and their metadata.

"""
raise NeptuneUnsupportedFunctionalityException
return self._pass_call_to_attr(function_name="fetch_files_list")

@@ -727,3 +746,2 @@

"""
raise NeptuneUnsupportedFunctionalityException
return self._pass_call_to_attr(function_name="list_fileset_files", path=path)

@@ -734,3 +752,2 @@

@feature_temporarily_unavailable
@check_protected_paths

@@ -743,3 +760,2 @@ def track_files(self, path: str, *, destination: str = None, wait: bool = False) -> None:

"""
raise NeptuneUnsupportedFunctionalityException
with self._container.lock():

@@ -756,24 +772,3 @@ attr = self._container.get_attribute(self._path)

@feature_temporarily_unavailable
@check_protected_paths
def pop(self, path: str = None, *, wait: bool = False) -> None:
with self._container.lock():
handler = self
if path:
verify_type("path", path, str)
handler = self[path]
path = join_paths(self._path, path)
# extra check: check_protected_paths decorator does not catch flow with non-null path
validate_path_not_protected(path, self)
else:
path = self._path
attribute = self._container.get_attribute(path)
if isinstance(attribute, Namespace):
for child_path in list(attribute):
handler.pop(child_path, wait=wait)
else:
self._container._pop_impl(parse_path(path), wait=wait)
class ExtendUtils:

@@ -780,0 +775,0 @@ @staticmethod

@@ -18,4 +18,4 @@ #

require_installed("neptune-aws", suggestion="aws")
require_installed("neptune_aws", suggestion="aws")
from neptune_aws.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-detectron2", suggestion="detectron2")
require_installed("neptune_detectron2", suggestion="detectron2")
from neptune_detectron2.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-fastai", suggestion="fastai")
require_installed("neptune_fastai", suggestion="fastai")
from neptune_fastai.impl import * # noqa: F401,F403,E402

@@ -16,6 +16,9 @@ #

#
# mypy: ignore-errors
from neptune.internal.utils.requirement_check import require_installed
require_installed("kedro-neptune", suggestion="kedro")
require_installed("kedro_neptune", suggestion="kedro")
from kedro_neptune.impl import * # noqa: F401,F403,E402
from kedro_neptune import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-lightgbm", suggestion="lightgbm")
require_installed("neptune_lightgbm", suggestion="lightgbm")
from neptune_lightgbm.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-optuna", suggestion="optuna")
require_installed("neptune_optuna", suggestion="optuna")
from neptune_optuna.impl import * # noqa: F401,F403,E402

@@ -23,4 +23,4 @@ #

TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,

@@ -32,21 +32,9 @@ Union,

from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
FieldVisitor,
FileField,
FileSetField,
FloatField,
FloatSeriesField,
GitRefField,
ImageSeriesField,
IntField,
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
LeaderboardEntry,
NotebookRefField,
ObjectStateField,
StringField,
StringSeriesField,
StringSetField,
)
from neptune.internal.utils.logger import get_logger
from neptune.internal.utils.run_state import RunState

@@ -56,78 +44,59 @@ if TYPE_CHECKING:

PANDAS_AVAILABLE_TYPES = Union[str, float, int, bool, datetime, None]
logger = get_logger()
class FieldToPandasValueVisitor(FieldVisitor[PANDAS_AVAILABLE_TYPES]):
def visit_float(self, field: FloatField) -> float:
return field.value
def visit_int(self, field: IntField) -> int:
return field.value
def visit_bool(self, field: BoolField) -> bool:
return field.value
def visit_string(self, field: StringField) -> str:
return field.value
def visit_datetime(self, field: DateTimeField) -> datetime:
return field.value
def visit_file(self, field: FileField) -> None:
def to_pandas(table: Table) -> pd.DataFrame:
def make_attribute_value(attribute: AttributeWithProperties) -> Any:
_type = attribute.type
_properties = attribute.properties
if _type == AttributeType.RUN_STATE:
return RunState.from_api(_properties.get("value")).value
if _type in (
AttributeType.FLOAT,
AttributeType.INT,
AttributeType.BOOL,
AttributeType.STRING,
AttributeType.DATETIME,
):
return _properties.get("value")
if _type == AttributeType.FLOAT_SERIES:
return _properties.get("last")
if _type == AttributeType.STRING_SERIES:
return _properties.get("last")
if _type == AttributeType.IMAGE_SERIES:
return None
if _type == AttributeType.FILE or _type == AttributeType.FILE_SET:
return None
if _type == AttributeType.STRING_SET:
return ",".join(_properties.get("values"))
if _type == AttributeType.GIT_REF:
return _properties.get("commit", {}).get("commitId")
if _type == AttributeType.NOTEBOOK_REF:
return _properties.get("notebookName")
if _type == AttributeType.ARTIFACT:
return _properties.get("hash")
logger.error(
"Attribute type %s not supported in this version, yielding None. Recommended client upgrade.",
_type,
)
return None
def visit_string_set(self, field: StringSetField) -> Optional[str]:
return ",".join(field.values)
def make_row(entry: LeaderboardEntry) -> Dict[str, Any]:
row: Dict[str, Union[str, float, datetime]] = dict()
for attr in entry.attributes:
value = make_attribute_value(attr)
if value is not None:
row[attr.path] = value
return row
def visit_float_series(self, field: FloatSeriesField) -> Optional[float]:
return field.last
def sort_key(attr: str) -> Tuple[int, str]:
domain = attr.split("/")[0]
if domain == "sys":
return 0, attr
if domain == "monitoring":
return 2, attr
return 1, attr
def visit_string_series(self, field: StringSeriesField) -> Optional[str]:
return field.last
rows = dict((n, make_row(entry)) for (n, entry) in enumerate(table._entries))
def visit_image_series(self, field: ImageSeriesField) -> None:
return None
def visit_file_set(self, field: FileSetField) -> None:
return None
def visit_git_ref(self, field: GitRefField) -> Optional[str]:
return field.commit.commit_id if field.commit is not None else None
def visit_object_state(self, field: ObjectStateField) -> str:
return field.value
def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]:
return field.notebook_name
def visit_artifact(self, field: ArtifactField) -> str:
return field.hash
def make_row(entry: LeaderboardEntry, to_value_visitor: FieldVisitor) -> Dict[str, PANDAS_AVAILABLE_TYPES]:
row: Dict[str, PANDAS_AVAILABLE_TYPES] = dict()
for field in entry.fields:
value = to_value_visitor.visit(field)
if value is not None:
row[field.path] = value
return row
def sort_key(field: str) -> Tuple[int, str]:
domain = field.split("/")[0]
if domain == "sys":
return 0, field
if domain == "monitoring":
return 2, field
return 1, field
def to_pandas(table: Table) -> pd.DataFrame:
to_value_visitor = FieldToPandasValueVisitor()
rows = dict((n, make_row(entry, to_value_visitor)) for (n, entry) in enumerate(table._entries))
df = pd.DataFrame.from_dict(data=rows, orient="index")

@@ -134,0 +103,0 @@ df = df.reindex(sorted(df.columns, key=sort_key), axis="columns")

@@ -24,3 +24,4 @@ #

from neptune.internal.utils import verify_type
from neptune.version import version as neptune_version
from neptune.logging import Logger
from neptune.version import version as neptune_client_version

@@ -61,10 +62,12 @@ INTEGRATION_VERSION_KEY = "source_code/integrations/neptune-python-logger"

verify_type("level", level, int)
verify_type("path", path, (str, type(None)))
if path is None:
path = f"{run.monitoring_namespace}/python_logger"
verify_type("path", path, str)
super().__init__(level=level)
self._path = path if path else f"{run.monitoring_namespace}/python_logger"
self._run = run
self._logger = Logger(run, path)
self._thread_local = threading.local()
self._run[INTEGRATION_VERSION_KEY] = str(neptune_version)
self._run[INTEGRATION_VERSION_KEY] = str(neptune_client_version)

@@ -79,4 +82,4 @@ def emit(self, record: logging.LogRecord) -> None:

message = self.format(record)
self._run[self._path].append(message)
self._logger.log(message)
finally:
self._thread_local.inside_write = False

@@ -20,5 +20,5 @@ #

require_installed("pytorch-lightning")
require_installed("pytorch_lightning")
from pytorch_lightning.loggers import NeptuneLogger # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-pytorch", suggestion="pytorch")
require_installed("neptune_pytorch", suggestion="pytorch")
from neptune_pytorch.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-sacred", suggestion="sacred")
require_installed("neptune_sacred", suggestion="sacred")
from neptune_sacred.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-sklearn", suggestion="sklearn")
require_installed("neptune_sklearn", suggestion="sklearn")
from neptune_sklearn.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-tensorboard", suggestion="tensorboard")
require_installed("neptune_tensorboard", suggestion="tensorboard")
from neptune_tensorboard.impl import * # noqa: F401,F403,E402

@@ -18,4 +18,4 @@ #

require_installed("neptune-tensorflow-keras", suggestion="tensorflow-keras")
require_installed("neptune_tensorflow_keras", suggestion="tensorflow-keras")
from neptune_tensorflow_keras.impl import * # noqa: F401,F403,E402

@@ -16,3 +16,3 @@ #

#
__all__ = ["join_paths", "verify_type", "RunType"]
__all__ = ["expect_not_an_experiment", "join_paths", "verify_type", "RunType"]

@@ -22,2 +22,4 @@ from typing import Union

from neptune import Run
from neptune.common.experiments import LegacyExperiment as Experiment
from neptune.exceptions import NeptuneLegacyIncompatibilityException
from neptune.handler import Handler

@@ -27,2 +29,8 @@ from neptune.internal.utils import verify_type

def expect_not_an_experiment(run: Run):
if isinstance(run, Experiment):
raise NeptuneLegacyIncompatibilityException()
RunType = Union[Run, Handler]

@@ -18,4 +18,4 @@ #

require_installed("neptune-xgboost", suggestion="xgboost")
require_installed("neptune_xgboost", suggestion="xgboost")
from neptune_xgboost.impl import * # noqa: F401,F403,E402

@@ -23,2 +23,4 @@ #

from botocore.exceptions import NoCredentialsError
from neptune.exceptions import (

@@ -76,4 +78,2 @@ NeptuneRemoteStorageAccessException,

from botocore.exceptions import NoCredentialsError
try:

@@ -122,5 +122,2 @@ for remote_object in remote_storage.objects.filter(Prefix=prefix):

remote_storage = get_boto_s3_client()
from botocore.exceptions import NoCredentialsError
try:

@@ -127,0 +124,0 @@ bucket = remote_storage.Bucket(bucket_name)

@@ -23,10 +23,33 @@ #

"ClientConfig",
"AttributeType",
"Attribute",
"AttributeWithProperties",
"LeaderboardEntry",
"StringPointValue",
"ImageSeriesValues",
"StringSeriesValues",
"FloatPointValue",
"FloatSeriesValues",
"FloatAttribute",
"IntAttribute",
"BoolAttribute",
"FileAttribute",
"StringAttribute",
"DatetimeAttribute",
"ArtifactAttribute",
"ArtifactModel",
"MultipartConfig",
"FloatSeriesAttribute",
"StringSeriesAttribute",
"StringSetAttribute",
]
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import (
Any,
FrozenSet,
List,
Optional,
Set,
)

@@ -36,2 +59,3 @@

from neptune.common.backends.api_model import MultipartConfig
from neptune.internal.container_type import ContainerType

@@ -44,19 +68,2 @@ from neptune.internal.id_formats import (

@dataclass(frozen=True)
class MultipartConfig:
min_chunk_size: int
max_chunk_size: int
max_chunk_count: int
max_single_part_size: int
@staticmethod
def get_default() -> "MultipartConfig":
return MultipartConfig(
min_chunk_size=5242880,
max_chunk_size=1073741824,
max_chunk_count=1000,
max_single_part_size=5242880,
)
@dataclass

@@ -131,2 +138,3 @@ class Project:

multipart_config: MultipartConfig
sys_name_set_by_backend: bool

@@ -173,2 +181,4 @@ def has_feature(self, feature_name: str) -> bool:

sys_name_set_by_backend = getattr(config, "sysNameSetByBackend", False)
return ClientConfig(

@@ -180,6 +190,112 @@ api_url=config.apiUrl,

multipart_config=multipart_upload_config,
sys_name_set_by_backend=sys_name_set_by_backend,
)
class AttributeType(Enum):
FLOAT = "float"
INT = "int"
BOOL = "bool"
STRING = "string"
DATETIME = "datetime"
FILE = "file"
FILE_SET = "fileSet"
FLOAT_SERIES = "floatSeries"
STRING_SERIES = "stringSeries"
IMAGE_SERIES = "imageSeries"
STRING_SET = "stringSet"
GIT_REF = "gitRef"
RUN_STATE = "experimentState"
NOTEBOOK_REF = "notebookRef"
ARTIFACT = "artifact"
@dataclass
class Attribute:
path: str
type: AttributeType
@dataclass
class AttributeWithProperties:
path: str
type: AttributeType
properties: Any
@dataclass
class LeaderboardEntry:
id: str
attributes: List[AttributeWithProperties]
@dataclass
class StringPointValue:
timestampMillis: int
step: float
value: str
@dataclass
class ImageSeriesValues:
totalItemCount: int
@dataclass
class StringSeriesValues:
totalItemCount: int
values: List[StringPointValue]
@dataclass
class FloatPointValue:
timestampMillis: int
step: float
value: float
@dataclass
class FloatSeriesValues:
totalItemCount: int
values: List[FloatPointValue]
@dataclass
class FloatAttribute:
value: float
@dataclass
class IntAttribute:
value: int
@dataclass
class BoolAttribute:
value: bool
@dataclass
class FileAttribute:
name: str
ext: str
size: int
@dataclass
class StringAttribute:
value: str
@dataclass
class DatetimeAttribute:
value: datetime
@dataclass
class ArtifactAttribute:
hash: str
@dataclass
class ArtifactModel:

@@ -189,1 +305,16 @@ received_metadata: bool

size: int
@dataclass
class FloatSeriesAttribute:
last: Optional[float]
@dataclass
class StringSeriesAttribute:
last: Optional[str]
@dataclass
class StringSetAttribute:
values: Set[str]

@@ -33,3 +33,3 @@ #

from neptune.api.models import ArtifactField
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.exceptions import (

@@ -47,5 +47,7 @@ ArtifactNotFoundException,

)
from neptune.internal.backends.api_model import ArtifactModel
from neptune.internal.backends.api_model import (
ArtifactAttribute,
ArtifactModel,
)
from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper
from neptune.internal.backends.utils import with_api_exceptions_handler
from neptune.internal.operation import (

@@ -257,3 +259,3 @@ AssignArtifact,

default_request_params: Dict,
) -> ArtifactField:
) -> ArtifactAttribute:
requests_params = add_artifact_version_to_request_params(default_request_params)

@@ -267,3 +269,3 @@ params = {

result = swagger_client.api.getArtifactAttribute(**params).response().result
return ArtifactField.from_model(result)
return ArtifactAttribute(hash=result.hash)
except HTTPNotFound:

@@ -270,0 +272,0 @@ raise FetchAttributeNotFoundException(path_to_str(path))

@@ -18,3 +18,2 @@ #

"DEFAULT_REQUEST_KWARGS",
"DEFAULT_PROTO_REQUEST_KWARGS",
"create_http_client_with_auth",

@@ -37,2 +36,4 @@ "create_backend_client",

from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.oauth import NeptuneAuthenticator
from neptune.envs import NEPTUNE_REQUEST_TIMEOUT

@@ -50,7 +51,5 @@ from neptune.exceptions import NeptuneClientUpgradeRequiredError

verify_host_resolution,
with_api_exceptions_handler,
)
from neptune.internal.credentials import Credentials
from neptune.internal.oauth import NeptuneAuthenticator
from neptune.version import version as neptune_version
from neptune.version import version as neptune_client_version

@@ -68,18 +67,7 @@ BACKEND_SWAGGER_PATH = "/api/backend/swagger.json"

"timeout": REQUEST_TIMEOUT,
"headers": {},
"headers": {"X-Neptune-LegacyClient": "false"},
}
}
DEFAULT_PROTO_REQUEST_KWARGS = {
"_request_options": {
**DEFAULT_REQUEST_KWARGS["_request_options"],
"headers": {
**DEFAULT_REQUEST_KWARGS["_request_options"]["headers"],
"Accept": "application/x-protobuf,application/json",
"Accept-Encoding": "gzip, deflate, br",
},
}
}
def _close_connections_on_fork(session: requests.Session):

@@ -108,3 +96,3 @@ try:

user_agent = "neptune-client/{lib_version} ({system}, python {python_version})".format(
lib_version=neptune_version,
lib_version=neptune_client_version,
system=platform.platform(),

@@ -156,3 +144,3 @@ python_version=platform.python_version(),

if not client_config.version_info:
raise NeptuneClientUpgradeRequiredError(neptune_version, max_version="0.4.111")
raise NeptuneClientUpgradeRequiredError(neptune_client_version, max_version="0.4.111")
return client_config

@@ -169,3 +157,3 @@

verify_client_version(client_config, neptune_version)
verify_client_version(client_config, neptune_client_version)

@@ -172,0 +160,0 @@ endpoint_url = None

@@ -51,2 +51,22 @@ #

from neptune.common.backends.api_model import MultipartConfig
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.exceptions import (
InternalClientError,
NeptuneException,
UploadedFileChanged,
)
from neptune.common.hardware.constants import BYTES_IN_ONE_MB
from neptune.common.storage.datastream import (
FileChunk,
FileChunker,
compress_to_tar_gz_in_memory,
)
from neptune.common.storage.storage_utils import (
AttributeUploadConfiguration,
UploadEntry,
normalize_file_name,
scan_unique_upload_entries,
split_upload_files,
)
from neptune.exceptions import (

@@ -57,3 +77,2 @@ FileUploadError,

)
from neptune.internal.backends.api_model import MultipartConfig
from neptune.internal.backends.swagger_client_wrapper import (

@@ -67,20 +86,3 @@ ApiMethodWrapper,

handle_server_raw_response_messages,
with_api_exceptions_handler,
)
from neptune.internal.exceptions import (
InternalClientError,
NeptuneException,
UploadedFileChanged,
)
from neptune.internal.hardware.constants import BYTES_IN_ONE_MB
from neptune.internal.storage import (
AttributeUploadConfiguration,
FileChunk,
FileChunker,
UploadEntry,
compress_to_tar_gz_in_memory,
normalize_file_name,
scan_unique_upload_entries,
split_upload_files,
)
from neptune.internal.utils import (

@@ -87,0 +89,0 @@ get_absolute_paths,

@@ -41,37 +41,17 @@ #

from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
Field,
FieldDefinition,
FieldType,
FileEntry,
FileField,
FloatField,
FloatSeriesField,
FloatSeriesValues,
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
StringSetField,
from neptune.api.dtos import FileEntry
from neptune.api.searching_entries import iter_over_pages
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.exceptions import (
ClientHttpError,
InternalClientError,
NeptuneException,
)
from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import (
ProtoAttributesSearchResultDTO,
ProtoQueryAttributesResultDTO,
from neptune.common.patterns import PROJECT_QUALIFIED_NAME_PATTERN
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ProtoAttributesDTO
from neptune.api.proto.neptune_pb.api.model.series_values_pb2 import ProtoFloatSeriesValuesDTO
from neptune.api.searching_entries import iter_over_pages
from neptune.core.components.operation_storage import OperationStorage
from neptune.envs import (
NEPTUNE_FETCH_TABLE_STEP_SIZE,
NEPTUNE_USE_PROTOCOL_BUFFERS,
)
from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE
from neptune.exceptions import (

@@ -85,2 +65,3 @@ AmbiguousProjectName,

NeptuneFeatureNotAvailableException,
NeptuneLegacyProjectException,
NeptuneLimitExceedException,

@@ -94,4 +75,22 @@ NeptuneObjectCreationConflict,

ApiExperiment,
ArtifactAttribute,
Attribute,
AttributeType,
BoolAttribute,
DatetimeAttribute,
FileAttribute,
FloatAttribute,
FloatPointValue,
FloatSeriesAttribute,
FloatSeriesValues,
ImageSeriesValues,
IntAttribute,
LeaderboardEntry,
OptionalFeatures,
Project,
StringAttribute,
StringPointValue,
StringSeriesAttribute,
StringSeriesValues,
StringSetAttribute,
Workspace,

@@ -106,3 +105,2 @@ )

from neptune.internal.backends.hosted_client import (
DEFAULT_PROTO_REQUEST_KWARGS,
DEFAULT_REQUEST_KWARGS,

@@ -131,11 +129,5 @@ create_artifacts_client,

ssl_verify,
with_api_exceptions_handler,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.credentials import Credentials
from neptune.internal.exceptions import (
ClientHttpError,
InternalClientError,
NeptuneException,
)
from neptune.internal.id_formats import (

@@ -158,11 +150,6 @@ QualifiedName,

from neptune.internal.utils.paths import path_to_str
from neptune.internal.utils.patterns import PROJECT_QUALIFIED_NAME_PATTERN
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.internal.websockets.websockets_factory import WebsocketsFactory
from neptune.management.exceptions import ObjectNotFound
from neptune.typing import ProgressBarType
from neptune.version import version as neptune_version
from neptune.version import version as neptune_client_version

@@ -178,11 +165,20 @@ if TYPE_CHECKING:

ATOMIC_ATTRIBUTE_TYPES = {
FieldType.INT.value,
FieldType.FLOAT.value,
FieldType.STRING.value,
FieldType.BOOL.value,
FieldType.DATETIME.value,
FieldType.OBJECT_STATE.value,
AttributeType.INT.value,
AttributeType.FLOAT.value,
AttributeType.STRING.value,
AttributeType.BOOL.value,
AttributeType.DATETIME.value,
AttributeType.RUN_STATE.value,
}
ATOMIC_ATTRIBUTE_TYPES = {
AttributeType.INT.value,
AttributeType.FLOAT.value,
AttributeType.STRING.value,
AttributeType.BOOL.value,
AttributeType.DATETIME.value,
AttributeType.RUN_STATE.value,
}
class HostedNeptuneBackend(NeptuneBackend):

@@ -193,3 +189,2 @@ def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] = None):

self.missing_features = []
self.use_proto = os.getenv(NEPTUNE_USE_PROTOCOL_BUFFERS, "False").lower() in {"true", "1", "y"}

@@ -211,2 +206,4 @@ http_client, client_config = create_http_client_with_auth(

self.sys_name_set_by_backend = self._client_config.sys_name_set_by_backend
def verify_feature_available(self, feature_name: str):

@@ -258,3 +255,5 @@ if not self._client_config.has_feature(feature_name):

project = response.result
project_version = project.version if hasattr(project, "version") else 1
if project_version < 2:
raise NeptuneLegacyProjectException(project_id)
return Project(

@@ -435,3 +434,3 @@ id=project.id,

"type": container_type.to_api(),
"cliVersion": str(neptune_version),
"cliVersion": str(neptune_client_version),
**additional_params,

@@ -442,3 +441,3 @@ }

"experimentCreationParams": params,
"X-Neptune-CliVersion": str(neptune_version),
"X-Neptune-CliVersion": str(neptune_client_version),
**DEFAULT_REQUEST_KWARGS,

@@ -703,3 +702,6 @@ }

@with_api_exceptions_handler
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]:
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]:
def to_attribute(attr) -> Attribute:
return Attribute(attr.name, AttributeType(attr.type))
params = {

@@ -712,3 +714,3 @@ "experimentId": container_id,

attribute_type_names = [at.value for at in FieldType]
attribute_type_names = [at.value for at in AttributeType]
accepted_attributes = [attr for attr in experiment.attributes if attr.type in attribute_type_names]

@@ -726,5 +728,3 @@

return [
FieldDefinition.from_model(field) for field in accepted_attributes if field.type in attribute_type_names
]
return [to_attribute(attr) for attr in accepted_attributes if attr.type in attribute_type_names]
except HTTPNotFound as e:

@@ -805,3 +805,3 @@ raise ContainerUUIDNotFound(

@with_api_exceptions_handler
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField:
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute:
params = {

@@ -814,3 +814,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getFloatAttribute(**params).response().result
return FloatField.from_model(result)
return FloatAttribute(result.value)
except HTTPNotFound:

@@ -820,3 +820,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField:
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute:
params = {

@@ -829,3 +829,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getIntAttribute(**params).response().result
return IntField.from_model(result)
return IntAttribute(result.value)
except HTTPNotFound:

@@ -835,3 +835,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField:
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute:
params = {

@@ -844,3 +844,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getBoolAttribute(**params).response().result
return BoolField.from_model(result)
return BoolAttribute(result.value)
except HTTPNotFound:

@@ -850,3 +850,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField:
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute:
params = {

@@ -859,3 +859,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getFileAttribute(**params).response().result
return FileField.from_model(result)
return FileAttribute(name=result.name, ext=result.ext, size=result.size)
except HTTPNotFound:

@@ -865,3 +865,5 @@ raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField:
def get_string_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringAttribute:
params = {

@@ -874,3 +876,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getStringAttribute(**params).response().result
return StringField.from_model(result)
return StringAttribute(result.value)
except HTTPNotFound:

@@ -882,3 +884,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

self, container_id: str, container_type: ContainerType, path: List[str]
) -> DateTimeField:
) -> DatetimeAttribute:
params = {

@@ -891,3 +893,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getDatetimeAttribute(**params).response().result
return DateTimeField.from_model(result)
return DatetimeAttribute(result.value)
except HTTPNotFound:

@@ -898,3 +900,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

self, container_id: str, container_type: ContainerType, path: List[str]
) -> ArtifactField:
) -> ArtifactAttribute:
return get_artifact_attribute(

@@ -933,3 +935,3 @@ swagger_client=self.leaderboard_client,

self, container_id: str, container_type: ContainerType, path: List[str]
) -> FloatSeriesField:
) -> FloatSeriesAttribute:
params = {

@@ -942,3 +944,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getFloatSeriesAttribute(**params).response().result
return FloatSeriesField.from_model(result)
return FloatSeriesAttribute(result.last)
except HTTPNotFound:

@@ -950,3 +952,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSeriesField:
) -> StringSeriesAttribute:
params = {

@@ -959,3 +961,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getStringSeriesAttribute(**params).response().result
return StringSeriesField.from_model(result)
return StringSeriesAttribute(result.last)
except HTTPNotFound:

@@ -967,3 +969,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSetField:
) -> StringSetAttribute:
params = {

@@ -976,3 +978,3 @@ "experimentId": container_id,

result = self.leaderboard_client.api.getStringSetAttribute(**params).response().result
return StringSetField.from_model(result)
return StringSetAttribute(set(result.values))
except HTTPNotFound:

@@ -999,3 +1001,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

result = self.leaderboard_client.api.getImageSeriesValues(**params).response().result
return ImageSeriesValues.from_model(result)
return ImageSeriesValues(result.totalItemCount)
except HTTPNotFound:

@@ -1010,4 +1012,4 @@ raise FetchAttributeNotFoundException(path_to_str(path))

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
) -> StringSeriesValues:

@@ -1018,3 +1020,3 @@ params = {

"limit": limit,
"skipToStep": from_step,
"offset": offset,
**DEFAULT_REQUEST_KWARGS,

@@ -1024,3 +1026,6 @@ }

result = self.leaderboard_client.api.getStringSeriesValues(**params).response().result
return StringSeriesValues.from_model(result)
return StringSeriesValues(
result.totalItemCount,
[StringPointValue(v.timestampMillis, v.step, v.value) for v in result.values],
)
except HTTPNotFound:

@@ -1035,9 +1040,5 @@ raise FetchAttributeNotFoundException(path_to_str(path))

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:
use_proto = use_proto if use_proto is not None else self.use_proto
params = {

@@ -1047,30 +1048,11 @@ "experimentId": container_id,

"limit": limit,
"skipToStep": from_step,
"offset": offset,
**DEFAULT_REQUEST_KWARGS,
}
if not include_inherited:
params["lineage"] = "NONE"
try:
if use_proto:
result = (
self.leaderboard_client.api.getFloatSeriesValuesProto(
**params,
**DEFAULT_PROTO_REQUEST_KWARGS,
)
.response()
.result
)
data = ProtoFloatSeriesValuesDTO.FromString(result)
return FloatSeriesValues.from_proto(data)
else:
result = (
self.leaderboard_client.api.getFloatSeriesValues(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return FloatSeriesValues.from_model(result)
result = self.leaderboard_client.api.getFloatSeriesValues(**params).response().result
return FloatSeriesValues(
result.totalItemCount,
[FloatPointValue(v.timestampMillis, v.step, v.value) for v in result.values],
)
except HTTPNotFound:

@@ -1082,3 +1064,3 @@ raise FetchAttributeNotFoundException(path_to_str(path))

self, container_id: str, container_type: ContainerType, path: List[str]
) -> List[Tuple[str, FieldType, Any]]:
) -> List[Tuple[str, AttributeType, Any]]:
params = {

@@ -1114,14 +1096,12 @@ "experimentId": container_id,

@with_api_exceptions_handler
def _get_column_types(self, project_id: UniqueId, column: str) -> List[Any]:
def _get_column_types(self, project_id: UniqueId, column: str, types: Optional[Iterable[str]] = None) -> List[Any]:
params = {
"projectIdentifier": project_id,
"query": {
"attributeNameFilter": {"mustMatchRegexes": [column]},
},
"search": column,
"type": types,
"params": {},
**DEFAULT_REQUEST_KWARGS,
}
try:
return (
self.leaderboard_client.api.queryAttributeDefinitionsWithinProject(**params).response().result.entries
)
return self.leaderboard_client.api.searchLeaderboardAttributes(**params).response().result.entries
except HTTPNotFound as e:

@@ -1142,5 +1122,3 @@ raise ProjectNotFound(project_id=project_id) from e

step_size: Optional[int] = None,
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:
use_proto = use_proto if use_proto is not None else self.use_proto
default_step_size = step_size or int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100"))

@@ -1150,4 +1128,2 @@

columns = set(columns) | {sort_by} if columns else {sort_by}
types_filter = list(map(lambda container_type: container_type.to_api(), types)) if types else None

@@ -1157,7 +1133,7 @@ attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {}

if sort_by == "sys/creation_time":
sort_by_column_type = FieldType.DATETIME.value
elif sort_by == "sys/id":
sort_by_column_type = FieldType.STRING.value
sort_by_column_type = AttributeType.DATETIME.value
if sort_by == "sys/id":
sort_by_column_type = AttributeType.STRING.value
else:
sort_by_column_type_candidates = self._get_column_types(project_id, sort_by)
sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter)
sort_by_column_type = _get_column_type_from_entries(sort_by_column_type_candidates, sort_by)

@@ -1178,3 +1154,2 @@

progress_bar=progress_bar,
use_proto=use_proto,
)

@@ -1207,173 +1182,3 @@ except HTTPNotFound:

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"experimentIdsFilter": experiment_ids_filter,
"attributeNameRegex": field_name_regex,
},
}
try:
data = (
self.leaderboard_client.api.queryAttributeDefinitionsWithinProject(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return QueryFieldDefinitionsResult.from_model(data)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)
def query_fields_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
experiment_names_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
use_proto: Optional[bool] = None,
) -> QueryFieldsResult:
use_proto = use_proto if use_proto is not None else self.use_proto
query = {
"experimentIdsFilter": experiment_ids_filter or None,
"experimentNamesFilter": experiment_names_filter or None,
"nextPage": next_page.to_dto() if next_page else None,
}
# If we are provided with both explicit column names, and a regex,
# we need to paste together all of them into a single regex (with OR between terms)
if field_name_regex:
terms = [field_name_regex]
if field_names_filter:
# Make sure we don't pass too broad regex for explicit column names
terms += [f"^{name}$" for name in field_names_filter]
regex = "|".join(terms)
query["attributeNameFilter"] = {"mustMatchRegexes": [regex]}
elif field_names_filter:
query["attributeNamesFilter"] = field_names_filter
params = {"projectIdentifier": project_id, "query": query}
try:
if use_proto:
result = (
self.leaderboard_client.api.queryAttributesWithinProjectProto(
**params,
**DEFAULT_PROTO_REQUEST_KWARGS,
)
.response()
.result
)
data = ProtoQueryAttributesResultDTO.FromString(result)
return QueryFieldsResult.from_proto(data)
else:
data = (
self.leaderboard_client.api.queryAttributesWithinProject(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return QueryFieldsResult.from_model(data)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)
def get_fields_definitions(
self,
container_id: str,
container_type: ContainerType,
use_proto: Optional[bool] = None,
) -> List[FieldDefinition]:
use_proto = use_proto if use_proto is not None else self.use_proto
params = {
"experimentIdentifier": container_id,
}
try:
if use_proto:
result = (
self.leaderboard_client.api.queryAttributeDefinitionsProto(
**params,
**DEFAULT_PROTO_REQUEST_KWARGS,
)
.response()
.result
)
data = ProtoAttributesSearchResultDTO.FromString(result)
return [FieldDefinition.from_proto(field_def) for field_def in data.entries]
else:
data = (
self.leaderboard_client.api.queryAttributeDefinitions(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return [FieldDefinition.from_model(field_def) for field_def in data.entries]
except HTTPNotFound as e:
raise ContainerUUIDNotFound(
container_id=container_id,
container_type=container_type,
) from e
def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]:
use_proto = use_proto if use_proto is not None else self.use_proto
params = {
"holderIdentifier": container_id,
"holderType": "experiment",
"attributeQuery": {
"attributePathsFilter": paths,
},
}
try:
if use_proto:
result = (
self.leaderboard_client.api.getAttributesWithPathsFilterProto(
**params,
**DEFAULT_PROTO_REQUEST_KWARGS,
)
.response()
.result
)
data = ProtoAttributesDTO.FromString(result)
return [Field.from_proto(field) for field in data.attributes]
else:
data = (
self.leaderboard_client.api.getAttributesWithPathsFilter(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return [Field.from_model(field) for field in data.attributes]
except HTTPNotFound as e:
raise ContainerUUIDNotFound(
container_id=container_id,
container_type=container_type,
) from e
def _get_column_type_from_entries(entries: List[Any], column: str) -> str:

@@ -1397,4 +1202,4 @@ if not entries: # column chosen is not present in the table

if types == {FieldType.INT.value, FieldType.FLOAT.value}:
return FieldType.FLOAT.value
if types == {AttributeType.INT.value, AttributeType.FLOAT.value}:
return AttributeType.FLOAT.value

@@ -1405,2 +1210,2 @@ warn_once(

)
return FieldType.STRING.value
return AttributeType.STRING.value

@@ -37,26 +37,6 @@ #

from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
Field,
FieldDefinition,
FieldType,
FileEntry,
FileField,
FloatField,
FloatPointValue,
FloatSeriesField,
FloatSeriesValues,
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringPointValue,
StringSeriesField,
StringSeriesValues,
StringSetField,
from neptune.api.dtos import FileEntry
from neptune.common.exceptions import (
InternalClientError,
NeptuneException,
)

@@ -74,3 +54,21 @@ from neptune.core.components.operation_storage import OperationStorage

ApiExperiment,
ArtifactAttribute,
Attribute,
AttributeType,
BoolAttribute,
DatetimeAttribute,
FileAttribute,
FloatAttribute,
FloatPointValue,
FloatSeriesAttribute,
FloatSeriesValues,
ImageSeriesValues,
IntAttribute,
LeaderboardEntry,
Project,
StringAttribute,
StringPointValue,
StringSeriesAttribute,
StringSeriesValues,
StringSetAttribute,
Workspace,

@@ -83,6 +81,2 @@ )

from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import (
InternalClientError,
NeptuneException,
)
from neptune.internal.id_formats import (

@@ -323,3 +317,3 @@ QualifiedName,

def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]:
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]:
run = self._get_container(container_id, container_type)

@@ -334,3 +328,3 @@ return list(self._generate_attributes(None, run.get_structure()))

else:
yield FieldDefinition(
yield Attribute(
new_path,

@@ -384,18 +378,17 @@ value_or_dict.accept(self._attribute_type_converter_value_visitor),

def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField:
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute:
val = self._get_attribute(container_id, container_type, path, Float)
return FloatField(path=path_to_str(path), value=val.value)
return FloatAttribute(val.value)
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField:
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute:
val = self._get_attribute(container_id, container_type, path, Integer)
return IntField(path=path_to_str(path), value=val.value)
return IntAttribute(val.value)
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField:
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute:
val = self._get_attribute(container_id, container_type, path, Boolean)
return BoolField(path=path_to_str(path), value=val.value)
return BoolAttribute(val.value)
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField:
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute:
val = self._get_attribute(container_id, container_type, path, File)
return FileField(
path=path_to_str(path),
return FileAttribute(
name=os.path.basename(val.path) if val.file_type is FileType.LOCAL_FILE else "",

@@ -406,17 +399,19 @@ ext=val.extension or "",

def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField:
def get_string_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringAttribute:
val = self._get_attribute(container_id, container_type, path, String)
return StringField(path=path_to_str(path), value=val.value)
return StringAttribute(val.value)
def get_datetime_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> DateTimeField:
) -> DatetimeAttribute:
val = self._get_attribute(container_id, container_type, path, Datetime)
return DateTimeField(path=path_to_str(path), value=val.value)
return DatetimeAttribute(val.value)
def get_artifact_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> ArtifactField:
) -> ArtifactAttribute:
val = self._get_attribute(container_id, container_type, path, Artifact)
return ArtifactField(path=path_to_str(path), hash=val.hash)
return ArtifactAttribute(val.hash)

@@ -428,17 +423,17 @@ def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[ArtifactFileData]:

self, container_id: str, container_type: ContainerType, path: List[str]
) -> FloatSeriesField:
) -> FloatSeriesAttribute:
val = self._get_attribute(container_id, container_type, path, FloatSeries)
return FloatSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None)
return FloatSeriesAttribute(val.values[-1] if val.values else None)
def get_string_series_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSeriesField:
) -> StringSeriesAttribute:
val = self._get_attribute(container_id, container_type, path, StringSeries)
return StringSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None)
return StringSeriesAttribute(val.values[-1] if val.values else None)
def get_string_set_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSetField:
) -> StringSetAttribute:
val = self._get_attribute(container_id, container_type, path, StringSet)
return StringSetField(path=path_to_str(path), values=set(val.values))
return StringSetAttribute(set(val.values))

@@ -466,4 +461,4 @@ def _get_attribute(

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
) -> StringSeriesValues:

@@ -473,3 +468,3 @@ val = self._get_attribute(container_id, container_type, path, StringSeries)

len(val.values),
[StringPointValue(timestamp=datetime.now(), step=idx, value=v) for idx, v in enumerate(val.values)],
[StringPointValue(timestampMillis=42342, step=idx, value=v) for idx, v in enumerate(val.values)],
)

@@ -482,6 +477,4 @@

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:

@@ -491,3 +484,3 @@ val = self._get_attribute(container_id, container_type, path, FloatSeries)

len(val.values),
[FloatPointValue(timestamp=datetime.now(), step=idx, value=v) for idx, v in enumerate(val.values)],
[FloatPointValue(timestampMillis=42342, step=idx, value=v) for idx, v in enumerate(val.values)],
)

@@ -535,10 +528,2 @@

def get_fields_definitions(
self,
container_id: str,
container_type: ContainerType,
use_proto: Optional[bool] = None,
) -> List[FieldDefinition]:
return []
def _get_attribute_values(self, value_dict, path_prefix: List[str]):

@@ -559,3 +544,3 @@ assert isinstance(value_dict, dict)

self, container_id: str, container_type: ContainerType, path: List[str]
) -> List[Tuple[str, FieldType, Any]]:
) -> List[Tuple[str, AttributeType, Any]]:
run = self._get_container(container_id, container_type)

@@ -583,50 +568,49 @@ values = self._get_attribute_values(run.get(path), path)

progress_bar: Optional[ProgressBarType] = None,
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:
"""Non relevant for mock"""
class AttributeTypeConverterValueVisitor(ValueVisitor[FieldType]):
def visit_float(self, _: Float) -> FieldType:
return FieldType.FLOAT
class AttributeTypeConverterValueVisitor(ValueVisitor[AttributeType]):
def visit_float(self, _: Float) -> AttributeType:
return AttributeType.FLOAT
def visit_integer(self, _: Integer) -> FieldType:
return FieldType.INT
def visit_integer(self, _: Integer) -> AttributeType:
return AttributeType.INT
def visit_boolean(self, _: Boolean) -> FieldType:
return FieldType.BOOL
def visit_boolean(self, _: Boolean) -> AttributeType:
return AttributeType.BOOL
def visit_string(self, _: String) -> FieldType:
return FieldType.STRING
def visit_string(self, _: String) -> AttributeType:
return AttributeType.STRING
def visit_datetime(self, _: Datetime) -> FieldType:
return FieldType.DATETIME
def visit_datetime(self, _: Datetime) -> AttributeType:
return AttributeType.DATETIME
def visit_file(self, _: File) -> FieldType:
return FieldType.FILE
def visit_file(self, _: File) -> AttributeType:
return AttributeType.FILE
def visit_file_set(self, _: FileSet) -> FieldType:
return FieldType.FILE_SET
def visit_file_set(self, _: FileSet) -> AttributeType:
return AttributeType.FILE_SET
def visit_float_series(self, _: FloatSeries) -> FieldType:
return FieldType.FLOAT_SERIES
def visit_float_series(self, _: FloatSeries) -> AttributeType:
return AttributeType.FLOAT_SERIES
def visit_string_series(self, _: StringSeries) -> FieldType:
return FieldType.STRING_SERIES
def visit_string_series(self, _: StringSeries) -> AttributeType:
return AttributeType.STRING_SERIES
def visit_image_series(self, _: FileSeries) -> FieldType:
return FieldType.IMAGE_SERIES
def visit_image_series(self, _: FileSeries) -> AttributeType:
return AttributeType.IMAGE_SERIES
def visit_string_set(self, _: StringSet) -> FieldType:
return FieldType.STRING_SET
def visit_string_set(self, _: StringSet) -> AttributeType:
return AttributeType.STRING_SET
def visit_git_ref(self, _: GitRef) -> FieldType:
return FieldType.GIT_REF
def visit_git_ref(self, _: GitRef) -> AttributeType:
return AttributeType.GIT_REF
def visit_artifact(self, _: Artifact) -> FieldType:
return FieldType.ARTIFACT
def visit_artifact(self, _: Artifact) -> AttributeType:
return AttributeType.ARTIFACT
def visit_namespace(self, _: Namespace) -> FieldType:
def visit_namespace(self, _: Namespace) -> AttributeType:
raise NotImplementedError
def copy_value(self, source_type: Type[FieldDefinition], source_path: List[str]) -> FieldType:
def copy_value(self, source_type: Type[Attribute], source_path: List[str]) -> AttributeType:
raise NotImplementedError

@@ -820,33 +804,1 @@

]
def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]:
return []
def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)
def query_fields_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
experiment_names_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
use_proto: Optional[bool] = None,
) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)

@@ -28,25 +28,4 @@ #

from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
Field,
FieldDefinition,
FieldType,
FileEntry,
FileField,
FloatField,
FloatSeriesField,
FloatSeriesValues,
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
StringSetField,
)
from neptune.api.dtos import FileEntry
from neptune.common.exceptions import NeptuneException
from neptune.core.components.operation_storage import OperationStorage

@@ -56,3 +35,19 @@ from neptune.internal.artifacts.types import ArtifactFileData

ApiExperiment,
ArtifactAttribute,
Attribute,
AttributeType,
BoolAttribute,
DatetimeAttribute,
FileAttribute,
FloatAttribute,
FloatSeriesAttribute,
FloatSeriesValues,
ImageSeriesValues,
IntAttribute,
LeaderboardEntry,
Project,
StringAttribute,
StringSeriesAttribute,
StringSeriesValues,
StringSetAttribute,
Workspace,

@@ -62,3 +57,2 @@ )

from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneException
from neptune.internal.id_formats import (

@@ -158,3 +152,3 @@ QualifiedName,

@abc.abstractmethod
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]:
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]:
pass

@@ -185,19 +179,21 @@

@abc.abstractmethod
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField:
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute:
pass
@abc.abstractmethod
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField:
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute:
pass
@abc.abstractmethod
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField:
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute:
pass
@abc.abstractmethod
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField:
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute:
pass
@abc.abstractmethod
def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField:
def get_string_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringAttribute:
pass

@@ -208,3 +204,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> DateTimeField:
) -> DatetimeAttribute:
pass

@@ -215,3 +211,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> ArtifactField:
) -> ArtifactAttribute:
pass

@@ -226,3 +222,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> FloatSeriesField:
) -> FloatSeriesAttribute:
pass

@@ -233,3 +229,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSeriesField:
) -> StringSeriesAttribute:
pass

@@ -240,3 +236,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSetField:
) -> StringSetAttribute:
pass

@@ -273,5 +269,6 @@

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
) -> StringSeriesValues: ...
) -> StringSeriesValues:
pass

@@ -284,7 +281,6 @@ @abc.abstractmethod

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues: ...
) -> FloatSeriesValues:
pass

@@ -314,21 +310,6 @@ @abc.abstractmethod

# WARN: Used in Neptune Fetcher
@abc.abstractmethod
def get_fields_definitions(
self,
container_id: str,
container_type: ContainerType,
use_proto: Optional[bool] = None,
) -> List[FieldDefinition]: ...
# WARN: Used in Neptune Fetcher
@abc.abstractmethod
def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]: ...
@abc.abstractmethod
def fetch_atom_attribute_values(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> List[Tuple[str, FieldType, Any]]:
) -> List[Tuple[str, AttributeType, Any]]:
pass

@@ -347,3 +328,2 @@

progress_bar: Optional[ProgressBarType] = None,
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:

@@ -355,22 +335,1 @@ pass

pass
@abc.abstractmethod
def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult: ...
@abc.abstractmethod
def query_fields_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
experiment_names_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
use_proto: Optional[bool] = None,
) -> QueryFieldsResult: ...

@@ -78,4 +78,2 @@ #

GREATER_THAN = ">"
LESS_THAN = "<"
MATCHES = "MATCHES"

@@ -82,0 +80,0 @@

@@ -19,4 +19,2 @@ #

from typing import (
Generator,
Iterable,
List,

@@ -26,33 +24,23 @@ Optional,

from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
Field,
FieldDefinition,
FileEntry,
FileField,
FloatField,
FloatSeriesField,
from neptune.api.dtos import FileEntry
from neptune.exceptions import NeptuneOfflineModeFetchException
from neptune.internal.artifacts.types import ArtifactFileData
from neptune.internal.backends.api_model import (
ArtifactAttribute,
Attribute,
BoolAttribute,
DatetimeAttribute,
FileAttribute,
FloatAttribute,
FloatSeriesAttribute,
FloatSeriesValues,
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
IntAttribute,
StringAttribute,
StringSeriesAttribute,
StringSeriesValues,
StringSetField,
StringSetAttribute,
)
from neptune.exceptions import NeptuneOfflineModeFetchException
from neptune.internal.artifacts.types import ArtifactFileData
from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.internal.backends.nql import NQLQuery
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import (
QualifiedName,
UniqueId,
)
from neptune.typing import ProgressBarType

@@ -64,18 +52,20 @@

def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]:
def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]:
raise NeptuneOfflineModeFetchException
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField:
def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute:
raise NeptuneOfflineModeFetchException
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField:
def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute:
raise NeptuneOfflineModeFetchException
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField:
def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute:
raise NeptuneOfflineModeFetchException
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField:
def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute:
raise NeptuneOfflineModeFetchException
def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField:
def get_string_attribute(
self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringAttribute:
raise NeptuneOfflineModeFetchException

@@ -85,3 +75,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> DateTimeField:
) -> DatetimeAttribute:
raise NeptuneOfflineModeFetchException

@@ -91,3 +81,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> ArtifactField:
) -> ArtifactAttribute:
raise NeptuneOfflineModeFetchException

@@ -100,3 +90,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> FloatSeriesField:
) -> FloatSeriesAttribute:
raise NeptuneOfflineModeFetchException

@@ -106,3 +96,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSeriesField:
) -> StringSeriesAttribute:
raise NeptuneOfflineModeFetchException

@@ -112,3 +102,3 @@

self, container_id: str, container_type: ContainerType, path: List[str]
) -> StringSetField:
) -> StringSetAttribute:
raise NeptuneOfflineModeFetchException

@@ -121,4 +111,4 @@

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
) -> StringSeriesValues:

@@ -132,6 +122,4 @@ raise NeptuneOfflineModeFetchException

path: List[str],
offset: int,
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:

@@ -163,49 +151,1 @@ raise NeptuneOfflineModeFetchException

raise NeptuneOfflineModeFetchException
def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]:
raise NeptuneOfflineModeFetchException
def get_fields_definitions(
self,
container_id: str,
container_type: ContainerType,
use_proto: Optional[bool] = None,
) -> List[FieldDefinition]:
raise NeptuneOfflineModeFetchException
def search_leaderboard_entries(
self,
project_id: UniqueId,
types: Optional[Iterable[ContainerType]] = None,
query: Optional[NQLQuery] = None,
columns: Optional[Iterable[str]] = None,
limit: Optional[int] = None,
sort_by: str = "sys/creation_time",
ascending: bool = False,
progress_bar: Optional[ProgressBarType] = None,
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:
raise NeptuneOfflineModeFetchException
def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
raise NeptuneOfflineModeFetchException
def query_fields_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
experiment_names_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
use_proto: Optional[bool] = None,
) -> QueryFieldsResult:
raise NeptuneOfflineModeFetchException

@@ -18,3 +18,3 @@ #

from neptune.internal.exceptions import InternalClientError
from neptune.common.exceptions import InternalClientError
from neptune.internal.operation import (

@@ -21,0 +21,0 @@ AddStrings,

@@ -18,3 +18,3 @@ #

from neptune.internal.exceptions import InternalClientError
from neptune.common.exceptions import InternalClientError
from neptune.internal.operation import (

@@ -21,0 +21,0 @@ AddStrings,

@@ -27,4 +27,4 @@ #

from neptune.common.exceptions import InternalClientError
from neptune.exceptions import MetadataInconsistency
from neptune.internal.exceptions import InternalClientError
from neptune.internal.operation import (

@@ -31,0 +31,0 @@ AddStrings,

@@ -28,2 +28,6 @@ #

from neptune.api.requests_utils import ensure_json_response
from neptune.common.exceptions import (
NeptuneAuthTokenExpired,
WritingToArchivedProjectException,
)
from neptune.exceptions import (

@@ -33,6 +37,2 @@ NeptuneFieldCountLimitExceedException,

)
from neptune.internal.exceptions import (
NeptuneAuthTokenExpired,
WritingToArchivedProjectException,
)

@@ -39,0 +39,0 @@

@@ -31,10 +31,7 @@ #

"construct_progress_bar",
"with_api_exceptions_handler",
]
import dataclasses
import itertools
import os
import socket
import time
from functools import (

@@ -60,23 +57,8 @@ lru_cache,

import requests
import urllib3
from bravado.client import SwaggerClient
from bravado.exception import (
BravadoConnectionError,
BravadoTimeoutError,
HTTPBadGateway,
HTTPClientError,
HTTPError,
HTTPForbidden,
HTTPGatewayTimeout,
HTTPInternalServerError,
HTTPRequestTimeout,
HTTPServiceUnavailable,
HTTPTooManyRequests,
HTTPUnauthorized,
)
from bravado.exception import HTTPError
from bravado.http_client import HttpClient
from bravado.requests_client import RequestsResponseAdapter
from bravado_core.formatter import SwaggerFormat
from bravado_core.util import RecursiveCallException
from packaging.version import Version

@@ -87,5 +69,8 @@ from requests import (

)
from requests.exceptions import ChunkedEncodingError
from urllib3.exceptions import NewConnectionError
from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.envs import NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE

@@ -100,12 +85,2 @@ from neptune.exceptions import (

from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper
from neptune.internal.envs import NEPTUNE_RETRIES_TIMEOUT_ENV
from neptune.internal.exceptions import (
ClientHttpError,
Forbidden,
NeptuneAuthTokenExpired,
NeptuneConnectionLostException,
NeptuneInvalidApiTokenException,
NeptuneSSLVerificationError,
Unauthorized,
)
from neptune.internal.operation import (

@@ -117,7 +92,2 @@ CopyAttribute,

from neptune.internal.utils.logger import get_logger
from neptune.internal.utils.utils import reset_internal_ssl_state
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.typing import (

@@ -138,115 +108,2 @@ ProgressBarCallback,

MAX_RETRY_TIME = 30
MAX_RETRY_MULTIPLIER = 10
retries_timeout = int(os.getenv(NEPTUNE_RETRIES_TIMEOUT_ENV, "60"))
def get_retry_from_headers_or_default(headers, retry_count):
    """Return the number of seconds to wait before the next retry.

    Prefers the server-provided ``retry-after`` response header (first entry);
    otherwise uses exponential backoff of ``2 ** retry_count`` with the
    exponent capped at MAX_RETRY_MULTIPLIER. If reading/parsing the header
    fails for any reason, the backoff is additionally capped at
    MAX_RETRY_TIME seconds.
    """
    backoff = 2 ** min(MAX_RETRY_MULTIPLIER, retry_count)
    try:
        if "retry-after" in headers:
            return int(headers["retry-after"][0])
        return backoff
    except Exception:
        # Malformed header (or headers object): fall back to capped backoff.
        return min(backoff, MAX_RETRY_TIME)
def with_api_exceptions_handler(func):
    """Decorator that retries transient Neptune API failures.

    Retries with exponential backoff (exponent capped at MAX_RETRY_MULTIPLIER,
    sleep capped at MAX_RETRY_TIME seconds) until `retries_timeout` seconds
    have elapsed overall, then raises NeptuneConnectionLostException carrying
    the last seen error. Non-transient HTTP errors are translated to Neptune
    exceptions immediately.
    """

    def wrapper(*args, **kwargs):
        ssl_error_occurred = False  # ensures the one-time SSL PRNG reset is attempted only once
        last_exception = None
        start_time = time.monotonic()
        for retry in itertools.count(0):
            # Stop retrying once the overall time budget is exhausted.
            if time.monotonic() - start_time > retries_timeout:
                break
            try:
                return func(*args, **kwargs)
            except requests.exceptions.InvalidHeader as e:
                # A rejected X-Neptune-Api-Token header means the token itself is malformed.
                if "X-Neptune-Api-Token" in e.args[0]:
                    raise NeptuneInvalidApiTokenException()
                raise
            except requests.exceptions.SSLError as e:
                """
                OpenSSL's internal random number generator does not properly handle forked processes.
                Applications must change the PRNG state of the parent process
                if they use any SSL feature with os.fork().
                Any successful call of RAND_add(), RAND_bytes() or RAND_pseudo_bytes() is sufficient.
                https://docs.python.org/3/library/ssl.html#multi-processing
                On Linux it looks like it does not help much but does not break anything either.
                But single retry seems to solve the issue.
                """
                if not ssl_error_occurred:
                    ssl_error_occurred = True
                    reset_internal_ssl_state()
                    continue
                if "CertificateError" in str(e.__context__):
                    raise NeptuneSSLVerificationError() from e
                else:
                    # Other SSL errors are treated as transient: back off and retry.
                    time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                    last_exception = e
                    continue
            except (
                BravadoConnectionError,
                BravadoTimeoutError,
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
                HTTPRequestTimeout,
                HTTPServiceUnavailable,
                HTTPGatewayTimeout,
                HTTPBadGateway,
                HTTPInternalServerError,
                NewConnectionError,
                ChunkedEncodingError,
                RecursiveCallException,
            ) as e:
                # Transient transport/server errors: exponential backoff, then retry.
                time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                last_exception = e
                continue
            except HTTPTooManyRequests as e:
                # Rate limited: honor the server's retry-after header when present.
                wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                time.sleep(wait_time)
                last_exception = e
                continue
            except NeptuneAuthTokenExpired as e:
                # Retry; presumably the token is refreshed on the next call — TODO confirm.
                last_exception = e
                continue
            except HTTPUnauthorized:
                raise Unauthorized()
            except HTTPForbidden:
                raise Forbidden()
            except HTTPClientError as e:
                raise ClientHttpError(e.status_code, e.response.text) from e
            except requests.exceptions.RequestException as e:
                # Plain requests errors: mirror the bravado handling above by status code.
                if e.response is None:
                    raise
                status_code = e.response.status_code
                if status_code in (
                    HTTPRequestTimeout.status_code,
                    HTTPBadGateway.status_code,
                    HTTPServiceUnavailable.status_code,
                    HTTPGatewayTimeout.status_code,
                    HTTPInternalServerError.status_code,
                ):
                    time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
                    last_exception = e
                    continue
                elif status_code == HTTPTooManyRequests.status_code:
                    wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
                    time.sleep(wait_time)
                    last_exception = e
                    continue
                elif status_code == HTTPUnauthorized.status_code:
                    raise Unauthorized()
                elif status_code == HTTPForbidden.status_code:
                    raise Forbidden()
                elif 400 <= status_code < 500:
                    raise ClientHttpError(status_code, e.response.text) from e
                else:
                    raise
        # Retry budget exhausted without success.
        raise NeptuneConnectionLostException(last_exception) from last_exception

    return wrapper
@lru_cache(maxsize=None, typed=True)

@@ -285,16 +142,8 @@ def verify_host_resolution(url: str) -> None:

def verify_client_version(client_config: ClientConfig, version: Version):
base_version = Version(f"{version.major}.{version.minor}.{version.micro}")
version_with_patch_0 = Version(replace_patch_version(str(version)))
min_compatible = client_config.version_info.min_compatible
max_compatible = client_config.version_info.max_compatible
min_recommended = client_config.version_info.min_recommended
if min_compatible and min_compatible > base_version:
if client_config.version_info.min_compatible and client_config.version_info.min_compatible > version:
raise NeptuneClientUpgradeRequiredError(version, min_version=client_config.version_info.min_compatible)
if max_compatible and max_compatible < version_with_patch_0:
if client_config.version_info.max_compatible and client_config.version_info.max_compatible < version_with_patch_0:
raise NeptuneClientUpgradeRequiredError(version, max_version=client_config.version_info.max_compatible)
if min_recommended and min_recommended > version:
if client_config.version_info.min_recommended and client_config.version_info.min_recommended > version:
logger.warning(

@@ -305,3 +154,3 @@ "WARNING: Your version of the Neptune client library (%s) is deprecated,"

version,
min_recommended,
client_config.version_info.min_recommended,
)

@@ -308,0 +157,0 @@

@@ -28,3 +28,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -36,3 +36,3 @@

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
for job in self._jobs:

@@ -39,0 +39,0 @@ job.start(container)

@@ -25,3 +25,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -31,3 +31,3 @@

@abc.abstractmethod
def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
pass

@@ -34,0 +34,0 @@

@@ -27,7 +27,7 @@ #

from neptune.common.envs import API_TOKEN_ENV_NAME
from neptune.common.exceptions import NeptuneInvalidApiTokenException
from neptune.constants import ANONYMOUS_API_TOKEN
from neptune.exceptions import NeptuneMissingApiTokenException
from neptune.internal.constants import ANONYMOUS_API_TOKEN_CONTENT
from neptune.internal.envs import API_TOKEN_ENV_NAME
from neptune.internal.exceptions import NeptuneInvalidApiTokenException

@@ -34,0 +34,0 @@

@@ -16,416 +16,6 @@ #

#
import platform
from typing import (
Any,
Optional,
)
__all__ = ["NeptuneInternalException"]
from neptune.internal.envs import (
API_TOKEN_ENV_NAME,
PROJECT_ENV_NAME,
)
# ANSI escape codes used to colorize exception messages on capable terminals.
UNIX_STYLES = {
    "h1": "\033[95m",
    "h2": "\033[94m",
    "blue": "\033[94m",
    "python": "\033[96m",
    "bash": "\033[95m",
    "warning": "\033[93m",
    "correct": "\033[92m",
    "fail": "\033[91m",
    "bold": "\033[1m",
    "underline": "\033[4m",
    "end": "\033[0m",
}

# Keys shared by the no-color style maps (note: no "blue" key here).
_PLAIN_STYLE_KEYS = (
    "h1",
    "h2",
    "python",
    "bash",
    "warning",
    "correct",
    "fail",
    "bold",
    "underline",
    "end",
)

# Windows terminals and unknown platforms get no styling at all.
WINDOWS_STYLES = {key: "" for key in _PLAIN_STYLE_KEYS}
EMPTY_STYLES = {key: "" for key in _PLAIN_STYLE_KEYS}

_SYSTEM = platform.system()
if _SYSTEM == "Windows":
    STYLES = WINDOWS_STYLES
elif _SYSTEM in ("Linux", "Darwin"):
    STYLES = UNIX_STYLES
else:
    STYLES = EMPTY_STYLES
class NeptuneException(Exception):
    """Base Neptune exception with value-based equality.

    Two instances compare equal when they are of exactly the same type and
    render to the same string. Hashing folds the message into the base hash.
    """

    def __eq__(self, other: Any) -> bool:
        # Guard clause: different concrete types are never equal.
        if type(self) is not type(other):
            return False
        return super().__eq__(other) and str(self).__eq__(str(other))

    def __hash__(self) -> int:
        return hash((super().__hash__(), str(self)))
class NeptuneInvalidApiTokenException(NeptuneException):
    """Raised when the configured Neptune API token fails validation."""

    def __init__(self) -> None:
        message = """
{h1}
----NeptuneInvalidApiTokenException------------------------------------------------
{end}
The provided API token is invalid.
Make sure you copied and provided your API token correctly.
You can get it or check if it is correct here:
- https://app.neptune.ai/get_my_api_token
There are two options to add it:
- specify it in your code
- set it as an environment variable in your operating system.
{h2}CODE{end}
Pass the token to the {bold}init_run(){end} function via the {bold}api_token{end} argument:
{python}neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
{h2}ENVIRONMENT VARIABLE{end} {correct}(Recommended option){end}
or export or set an environment variable depending on your operating system:
{correct}Linux/Unix{end}
In your terminal run:
{bash}export {env_api_token}="YOUR_API_TOKEN"{end}
{correct}Windows{end}
In your CMD run:
{bash}set {env_api_token}="YOUR_API_TOKEN"{end}
and skip the {bold}api_token{end} argument of the {bold}init_run(){end} function:
{python}neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME'){end}
You may also want to check the following docs page:
- https://docs.neptune.ai/setup/setting_api_token/
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        # Substitute the terminal style codes and the API token env var name.
        super().__init__(message.format(env_api_token=API_TOKEN_ENV_NAME, **STYLES))
class UploadedFileChanged(NeptuneException):
    """Raised when a file's content changes mid-upload, forcing a restart."""

    def __init__(self, filename: str) -> None:
        message = "File {} changed during upload, restarting upload.".format(filename)
        super().__init__(message)
class InternalClientError(NeptuneException):
    """Raised when the client library hits an unexpected internal error."""

    def __init__(self, msg: str) -> None:
        message = """
{h1}
----InternalClientError-----------------------------------------------------------------------
{end}
The Neptune client library encountered an unexpected internal error:
{msg}
Please contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(msg=msg, **STYLES))
class ClientHttpError(NeptuneException):
    """Raised for 4xx-style HTTP responses from the Neptune server."""

    def __init__(self, status: str, response: str) -> None:
        # Keep the raw status and body available to programmatic handlers.
        self.status = status
        self.response = response
        message = """
{h1}
----ClientHttpError-----------------------------------------------------------------------
{end}
The Neptune server returned the status {fail}{status}{end}.
The server response was:
{fail}{response}{end}
Verify the correctness of your call or contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(status=status, response=response, **STYLES))
# NOTE(review): the version diff overlaid two class headers on one `pass` body
# here; both classes are reconstructed so that subclasses below (e.g. Forbidden,
# Unauthorized) and the module's __all__ both resolve — confirm against upstream.
class NeptuneApiException(NeptuneException):
    """Base class for errors reported by the Neptune API server."""


class NeptuneInternalException(Exception):
    """Marker base for internal-only Neptune exceptions."""
class Forbidden(NeptuneApiException):
    """Raised when the server responds with HTTP 403 (insufficient permissions)."""

    def __init__(self) -> None:
        message = """
{h1}
----Forbidden-----------------------------------------------------------------------
{end}
You don't have permission to access the given resource.
- Verify that your API token is correct.
See: https://app.neptune.ai/get_my_api_token
- Verify that the provided project name is correct.
The correct project name should look like this: {correct}WORKSPACE_NAME/PROJECT_NAME{end}
It has two parts:
- {correct}WORKSPACE_NAME{end}: can be your username or your organization name
- {correct}PROJECT_NAME{end}: the name specified for the project
- Ask your organization administrator to grant you the necessary privileges to the project.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(**STYLES))
class Unauthorized(NeptuneApiException):
    """Raised when the server responds with HTTP 401 (authentication failed).

    A custom message may be supplied; otherwise a detailed default is used.
    """

    def __init__(self, msg: Optional[str] = None) -> None:
        default_message = """
{h1}
----Unauthorized-----------------------------------------------------------------------
{end}
You don't have permission to access the given resource.
- Verify that your API token is correct.
See: https://app.neptune.ai/get_my_api_token
- Verify that the provided project name is correct.
The correct project name should look like this: {correct}WORKSPACE_NAME/PROJECT_NAME{end}
It has two parts:
- {correct}WORKSPACE_NAME{end}: can be your username or your organization name
- {correct}PROJECT_NAME{end}: the name specified for the project
- Ask your organization administrator to grant you the necessary privileges to the project.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        message = msg if msg is not None else default_message
        # NOTE(review): a caller-supplied msg is also passed through .format(**STYLES),
        # so literal braces in it would raise — confirm callers never include braces.
        super().__init__(message.format(**STYLES))
class NeptuneAuthTokenExpired(Unauthorized):
    """Unauthorized variant signalling an expired authorization token."""

    def __init__(self) -> None:
        super().__init__(msg="Authorization token expired")
class InternalServerError(NeptuneApiException):
    """Raised when the server responds with HTTP 5xx."""

    def __init__(self, response: str) -> None:
        message = """
{h1}
----InternalServerError-----------------------------------------------------------------------
{end}
The Neptune client library encountered an unexpected internal server error.
The server response was:
{fail}{response}{end}
Please try again later or contact Neptune support.
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
"""
        super().__init__(message.format(response=response, **STYLES))
class NeptuneConnectionLostException(NeptuneException):
    """Raised when retries are exhausted and the server remains unreachable."""

    def __init__(self, cause: Exception) -> None:
        # The last underlying transport/server error, kept for diagnostics.
        self.cause = cause
        message = """
{h1}
----NeptuneConnectionLostException---------------------------------------------------------
{end}
The connection to the Neptune server was lost.
If you are using the asynchronous (default) connection mode, Neptune continues to locally track your metadata and continuously tries to re-establish a connection to the Neptune servers.
If the connection is not re-established, you can upload your data later with the Neptune Command Line Interface tool:
{bash}neptune sync -p workspace_name/project_name{end}
What should I do?
- Check if your computer is connected to the internet.
- If your connection is unstable, consider working in offline mode:
{python}run = neptune.init_run(mode="offline"){end}
You can find detailed instructions on the following doc pages:
- https://docs.neptune.ai/api/connection_modes/#offline-mode
- https://docs.neptune.ai/api/neptune_sync/
You may also want to check the following docs page:
- https://docs.neptune.ai/api/connection_modes/#connectivity-issues
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
""" # noqa: E501
        super().__init__(message.format(**STYLES))
class NeptuneSSLVerificationError(NeptuneException):
    """Raised when the server's SSL certificate cannot be verified."""

    def __init__(self) -> None:
        message = """
{h1}
----NeptuneSSLVerificationError-----------------------------------------------------------------------
{end}
The Neptune client was unable to verify your SSL Certificate.
{bold}What could have gone wrong?{end}
- You are behind a proxy that inspects traffic to Neptune servers.
    - Contact your network administrator
- The SSL/TLS certificate of your on-premises installation is not recognized due to a custom Certificate Authority (CA).
    - To check, run the following command in your terminal:
        {bash}curl https://<your_domain>/api/backend/echo {end}
    - Where <your_domain> is the address that you use to access Neptune app, such as abc.com
    - Contact your network administrator if you get the following output:
        {fail}"curl: (60) server certificate verification failed..."{end}
- Your machine software is outdated.
    - Minimal OS requirements:
        - Windows >= XP SP3
        - macOS >= 10.12.1
        - Ubuntu >= 12.04
        - Debian >= 8
{bold}What can I do?{end}
You can manually configure Neptune to skip all SSL checks. To do that,
set the NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE environment variable to 'TRUE'.
{bold}Note: This might mean that your connection is less secure{end}.
Linux/Unix
In your terminal run:
    {bash}export NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
Windows
In your terminal run:
    {bash}set NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
Jupyter notebook
In your code cell:
    {bash}%env NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE='TRUE'{end}
You may also want to check the following docs page:
- https://docs.neptune.ai/api/environment_variables/#neptune_allow_self_signed_certificate
{correct}Need help?{end}-> https://docs.neptune.ai/getting_help
""" # noqa: E501
        super().__init__(message.format(**STYLES))
class FileNotFound(NeptuneException):
    """Raised when a path given to Neptune does not point to an existing file."""

    def __init__(self, path: str) -> None:
        message = "File {} doesn't exist.".format(path)
        super().__init__(message)
class InvalidNotebookPath(NeptuneException):
    """Raised when a supposed notebook path does not end with ``.ipynb``."""

    def __init__(self, path: str) -> None:
        message = "File {} is not a valid notebook. Should end with .ipynb.".format(path)
        super().__init__(message)
class NeptuneIncorrectProjectQualifiedNameException(NeptuneException):
    """Raised when a project qualified name is not of the WORKSPACE/PROJECT form."""

    def __init__(self, project_qualified_name: str) -> None:
        message = """
{h1}
----NeptuneIncorrectProjectQualifiedNameException-----------------------------------------------------------------------
{end}
Project qualified name {fail}"{project_qualified_name}"{end} you specified was incorrect.
The correct project qualified name should look like this {correct}WORKSPACE/PROJECT_NAME{end}.
It has two parts:
- {correct}WORKSPACE{end}: which can be your username or your organization name
- {correct}PROJECT_NAME{end}: which is the actual project name you chose
For example, a project {correct}neptune-ai/credit-default-prediction{end} parts are:
- {correct}neptune-ai{end}: {underline}WORKSPACE{end} our company organization name
- {correct}credit-default-prediction{end}: {underline}PROJECT_NAME{end} a project name
The URL to this project looks like this: https://app.neptune.ai/neptune-ai/credit-default-prediction
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/workspace-project-and-user-management/index.html
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        super(NeptuneIncorrectProjectQualifiedNameException, self).__init__(
            message.format(project_qualified_name=project_qualified_name, **STYLES)
        )
class NeptuneMissingProjectQualifiedNameException(NeptuneException):
    """Raised when no project qualified name was supplied via code or env var."""

    def __init__(self) -> None:
        message = """
{h1}
----NeptuneMissingProjectQualifiedNameException-------------------------------------------------------------------------
{end}
Neptune client couldn't find your project name.
There are two options two add it:
- specify it in your code
- set an environment variable in your operating system.
{h2}CODE{end}
Pass it to {bold}neptune.init(){end} via {bold}project_qualified_name{end} argument:
{python}neptune.init(project_qualified_name='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN'){end}
{h2}ENVIRONMENT VARIABLE{end}
or export or set an environment variable depending on your operating system:
{correct}Linux/Unix{end}
In your terminal run:
{bash}export {env_project}=WORKSPACE_NAME/PROJECT_NAME{end}
{correct}Windows{end}
In your CMD run:
{bash}set {env_project}=WORKSPACE_NAME/PROJECT_NAME{end}
and skip the {bold}project_qualified_name{end} argument of {bold}neptune.init(){end}:
{python}neptune.init(api_token='YOUR_API_TOKEN'){end}
You may also want to check the following docs pages:
- https://docs-legacy.neptune.ai/workspace-project-and-user-management/index.html
- https://docs-legacy.neptune.ai/getting-started/quick-starts/log_first_experiment.html
{correct}Need help?{end}-> https://docs-legacy.neptune.ai/getting-started/getting-help.html
"""
        # Substitute the terminal style codes and the project env var name.
        super(NeptuneMissingProjectQualifiedNameException, self).__init__(
            message.format(env_project=PROJECT_ENV_NAME, **STYLES)
        )
class NotAFile(NeptuneException):
    """Raised when a path expected to be a regular file is not one."""

    def __init__(self, path: str) -> None:
        message = "Path {} is not a file.".format(path)
        super().__init__(message)
class NotADirectory(NeptuneException):
    """Raised when a path expected to be a directory is not one."""

    def __init__(self, path: str) -> None:
        message = "Path {} is not a directory.".format(path)
        super().__init__(message)
class WritingToArchivedProjectException(NeptuneException):
    """Raised on attempts to write metadata to an archived project."""

    def __init__(self) -> None:
        message = """
{h1}
----WritingToArchivedProjectException-----------------------------------------------------------------------
{end}
You're trying to write to a project that was archived.
Set the project as active again or use mode="read-only" at initialization to fetch metadata from it.
{correct}Need help?{end}-> https://docs.neptune.ai/help/error_writing_to_archived_project/
"""
        super(WritingToArchivedProjectException, self).__init__(message.format(**STYLES))

@@ -30,3 +30,3 @@ #

from neptune.internal.warnings import (
from neptune.common.warnings import (
NeptuneWarning,

@@ -33,0 +33,0 @@ warn_once,

@@ -27,18 +27,18 @@ #

from neptune.common.hardware.gauges.gauge_factory import GaugeFactory
from neptune.common.hardware.gauges.gauge_mode import GaugeMode
from neptune.common.hardware.gpu.gpu_monitor import GPUMonitor
from neptune.common.hardware.metrics.metrics_factory import MetricsFactory
from neptune.common.hardware.metrics.reports.metric_reporter import MetricReporter
from neptune.common.hardware.metrics.reports.metric_reporter_factory import MetricReporterFactory
from neptune.common.hardware.resources.system_resource_info_factory import SystemResourceInfoFactory
from neptune.common.hardware.system.system_monitor import SystemMonitor
from neptune.common.utils import in_docker
from neptune.internal.background_job import BackgroundJob
from neptune.internal.hardware.gauges.gauge_factory import GaugeFactory
from neptune.internal.hardware.gauges.gauge_mode import GaugeMode
from neptune.internal.hardware.gpu.gpu_monitor import GPUMonitor
from neptune.internal.hardware.metrics.metrics_factory import MetricsFactory
from neptune.internal.hardware.metrics.reports.metric_reporter import MetricReporter
from neptune.internal.hardware.metrics.reports.metric_reporter_factory import MetricReporterFactory
from neptune.internal.hardware.resources.system_resource_info_factory import SystemResourceInfoFactory
from neptune.internal.hardware.system.system_monitor import SystemMonitor
from neptune.internal.threading.daemon import Daemon
from neptune.internal.utils.logger import get_logger
from neptune.internal.utils.utils import in_docker
from neptune.types.series import FloatSeries
if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -56,3 +56,3 @@ _logger = get_logger()

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
gauge_mode = GaugeMode.CGROUP if in_docker() else GaugeMode.SYSTEM

@@ -109,3 +109,3 @@ system_resource_info = SystemResourceInfoFactory(

period: float,
container: "NeptuneObject",
container: "MetadataContainer",
metric_reporter: MetricReporter,

@@ -112,0 +112,0 @@ ):

@@ -36,2 +36,7 @@ #

from neptune.common.exceptions import NeptuneException
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.constants import ASYNC_DIRECTORY

@@ -44,3 +49,2 @@ from neptune.core.components.abstract import WithResources

from neptune.exceptions import NeptuneSynchronizationAlreadyStoppedException
from neptune.internal.exceptions import NeptuneException
from neptune.internal.init.parameters import DEFAULT_STOP_TIMEOUT

@@ -62,6 +66,2 @@ from neptune.internal.operation import Operation

from neptune.internal.utils.logger import get_logger
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)

@@ -68,0 +68,0 @@ if TYPE_CHECKING:

@@ -20,7 +20,7 @@ #

from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.warnings import (
from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.internal.operation_processors.operation_processor import OperationProcessor

@@ -27,0 +27,0 @@ if TYPE_CHECKING:

@@ -35,3 +35,3 @@ #

from neptune.constants import NEPTUNE_DATA_DIRECTORY
from neptune.objects.structure_version import StructureVersion
from neptune.metadata_containers.structure_version import StructureVersion

@@ -38,0 +38,0 @@ if TYPE_CHECKING:

@@ -31,9 +31,9 @@ #

from neptune.common.exceptions import (
InternalClientError,
NeptuneException,
)
from neptune.core.components.operation_storage import OperationStorage
from neptune.exceptions import MalformedOperation
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import (
InternalClientError,
NeptuneException,
)
from neptune.internal.types.file_types import FileType

@@ -40,0 +40,0 @@ from neptune.types.atoms.file import File

@@ -30,3 +30,3 @@ #

from neptune.internal.signals_processing.signals import Signal
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -40,4 +40,4 @@

async_no_progress_threshold: float,
async_lag_callback: Optional[Callable[["NeptuneObject"], None]] = None,
async_no_progress_callback: Optional[Callable[["NeptuneObject"], None]] = None,
async_lag_callback: Optional[Callable[["MetadataContainer"], None]] = None,
async_no_progress_callback: Optional[Callable[["MetadataContainer"], None]] = None,
period: float = 10,

@@ -51,6 +51,6 @@ ) -> None:

self._async_no_progress_threshold: float = async_no_progress_threshold
self._async_lag_callback: Optional[Callable[["NeptuneObject"], None]] = async_lag_callback
self._async_no_progress_callback: Optional[Callable[["NeptuneObject"], None]] = async_no_progress_callback
self._async_lag_callback: Optional[Callable[["MetadataContainer"], None]] = async_lag_callback
self._async_no_progress_callback: Optional[Callable[["MetadataContainer"], None]] = async_no_progress_callback
def start(self, container: "NeptuneObject") -> None:
def start(self, container: "MetadataContainer") -> None:
self._thread = SignalsProcessor(

@@ -57,0 +57,0 @@ period=self._period,

@@ -39,3 +39,3 @@ #

from neptune.internal.signals_processing.signals import Signal
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -48,8 +48,8 @@

period: float,
container: "NeptuneObject",
container: "MetadataContainer",
queue: "Queue[Signal]",
async_lag_threshold: float,
async_no_progress_threshold: float,
async_lag_callback: Optional[Callable[["NeptuneObject"], None]] = None,
async_no_progress_callback: Optional[Callable[["NeptuneObject"], None]] = None,
async_lag_callback: Optional[Callable[["MetadataContainer"], None]] = None,
async_no_progress_callback: Optional[Callable[["MetadataContainer"], None]] = None,
callbacks_interval: float = IN_BETWEEN_CALLBACKS_MINIMUM_INTERVAL,

@@ -60,8 +60,8 @@ in_async: bool = True,

self._container: "NeptuneObject" = container
self._container: "MetadataContainer" = container
self._queue: "Queue[Signal]" = queue
self._async_lag_threshold: float = async_lag_threshold
self._async_no_progress_threshold: float = async_no_progress_threshold
self._async_lag_callback: Optional[Callable[["NeptuneObject"], None]] = async_lag_callback
self._async_no_progress_callback: Optional[Callable[["NeptuneObject"], None]] = async_no_progress_callback
self._async_lag_callback: Optional[Callable[["MetadataContainer"], None]] = async_lag_callback
self._async_no_progress_callback: Optional[Callable[["MetadataContainer"], None]] = async_no_progress_callback
self._callbacks_interval: float = callbacks_interval

@@ -125,3 +125,3 @@ self._in_async: bool = in_async

def execute_callback(
*, callback: Callable[["NeptuneObject"], None], container: "NeptuneObject", in_async: bool
*, callback: Callable[["MetadataContainer"], None], container: "MetadataContainer", in_async: bool
) -> None:

@@ -128,0 +128,0 @@ if in_async:

@@ -25,2 +25,6 @@ #

from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.internal.signals_processing.signals import (

@@ -32,6 +36,2 @@ BatchLagSignal,

)
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)

@@ -38,0 +38,0 @@

@@ -30,3 +30,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -39,3 +39,3 @@

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
self._logger = StdoutCaptureLogger(container, self._attribute_name)

@@ -61,3 +61,3 @@

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
self._logger = StderrCaptureLogger(container, self._attribute_name)

@@ -64,0 +64,0 @@

@@ -24,9 +24,9 @@ #

from neptune.internal.threading.daemon import Daemon
from neptune.objects import NeptuneObject
from neptune.logging import Logger as NeptuneLogger
from neptune.metadata_containers import MetadataContainer
class StdStreamCaptureLogger:
def __init__(self, container: NeptuneObject, attribute_name: str, stream: TextIO):
self._container = container
self._attribute_name = attribute_name
def __init__(self, container: MetadataContainer, attribute_name: str, stream: TextIO):
self._logger = NeptuneLogger(container, attribute_name)
self.stream = stream

@@ -39,5 +39,2 @@ self._thread_local = threading.local()

def log_data(self, data):
self._container[self._attribute_name].append(data)
def pause(self):

@@ -75,7 +72,7 @@ self._log_data_queue.put_nowait(None)

break
self._logger.log_data(data)
self._logger._logger.log(data)
class StdoutCaptureLogger(StdStreamCaptureLogger):
def __init__(self, container: NeptuneObject, attribute_name: str):
def __init__(self, container: MetadataContainer, attribute_name: str):
super().__init__(container, attribute_name, sys.stdout)

@@ -90,3 +87,3 @@ sys.stdout = self

class StderrCaptureLogger(StdStreamCaptureLogger):
def __init__(self, container: NeptuneObject, attribute_name: str):
def __init__(self, container: MetadataContainer, attribute_name: str):
super().__init__(container, attribute_name, sys.stderr)

@@ -93,0 +90,0 @@ sys.stderr = self

@@ -23,3 +23,3 @@ #

from neptune.internal.exceptions import NeptuneConnectionLostException
from neptune.common.exceptions import NeptuneConnectionLostException
from neptune.internal.utils.logger import get_logger

@@ -26,0 +26,0 @@

@@ -35,4 +35,4 @@ #

from neptune.common.exceptions import NeptuneException
from neptune.exceptions import StreamAlreadyUsedException
from neptune.internal.exceptions import NeptuneException
from neptune.internal.utils import verify_type

@@ -39,0 +39,0 @@

@@ -85,2 +85,2 @@ #

else:
logger.error("[ERROR] File '%s' does not exist - skipping dependency file upload.", self._path)
logger.warning("File '%s' does not exist - skipping dependency file upload.", self._path)

@@ -19,4 +19,4 @@ #

from neptune.common.warnings import warn_once
from neptune.exceptions import NeptuneParametersCollision
from neptune.internal.warnings import warn_once

@@ -50,3 +50,7 @@ __all__ = ["deprecated", "deprecated_parameter"]

if required_kwarg_name in kwargs:
raise NeptuneParametersCollision(required_kwarg_name, deprecated_kwarg_name, method_name=f.__name__)
raise NeptuneParametersCollision(
required_kwarg_name,
deprecated_kwarg_name,
method_name=f.__name__,
)

@@ -66,1 +70,15 @@ warn_once(

return deco
def model_registry_deprecation(func):
@wraps(func)
def inner(*args, **kwargs):
warn_once(
"Neptune's model registry has been deprecated and will be removed in a future release."
"Use runs to store model metadata instead. For more, see https://docs.neptune.ai/model_registry/."
"If you are already using the model registry, you can migrate existing metadata to runs."
"Learn how: https://docs.neptune.ai/model_registry/migrate_to_runs/."
)
return func(*args, **kwargs)
return inner

@@ -34,2 +34,6 @@ #

from neptune.common.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.constants import NEPTUNE_DATA_DIRECTORY

@@ -41,6 +45,2 @@ from neptune.envs import (

from neptune.exceptions import NeptuneMaxDiskUtilizationExceeded
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)

@@ -47,0 +47,0 @@

@@ -18,3 +18,3 @@ #

from neptune.api.models import FieldType
from neptune.internal.backends.api_model import AttributeType

@@ -31,18 +31,18 @@

atomic_attribute_types_map = {
FieldType.FLOAT.value: "floatProperties",
FieldType.INT.value: "intProperties",
FieldType.BOOL.value: "boolProperties",
FieldType.STRING.value: "stringProperties",
FieldType.DATETIME.value: "datetimeProperties",
FieldType.OBJECT_STATE.value: "experimentStateProperties",
FieldType.NOTEBOOK_REF.value: "notebookRefProperties",
AttributeType.FLOAT.value: "floatProperties",
AttributeType.INT.value: "intProperties",
AttributeType.BOOL.value: "boolProperties",
AttributeType.STRING.value: "stringProperties",
AttributeType.DATETIME.value: "datetimeProperties",
AttributeType.RUN_STATE.value: "experimentStateProperties",
AttributeType.NOTEBOOK_REF.value: "notebookRefProperties",
}
value_series_attribute_types_map = {
FieldType.FLOAT_SERIES.value: "floatSeriesProperties",
FieldType.STRING_SERIES.value: "stringSeriesProperties",
AttributeType.FLOAT_SERIES.value: "floatSeriesProperties",
AttributeType.STRING_SERIES.value: "stringSeriesProperties",
}
value_set_attribute_types_map = {
FieldType.STRING_SET.value: "stringSetProperties",
AttributeType.STRING_SET.value: "stringSetProperties",
}

@@ -52,6 +52,6 @@

_unmapped_attribute_types_map = {
FieldType.FILE_SET.value: "fileSetProperties", # TODO: return size?
FieldType.FILE.value: "fileProperties", # TODO: name? size?
FieldType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step?
FieldType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch?
AttributeType.FILE_SET.value: "fileSetProperties", # TODO: return size?
AttributeType.FILE.value: "fileProperties", # TODO: name? size?
AttributeType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step?
AttributeType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch?
}

@@ -58,0 +58,0 @@

@@ -252,2 +252,4 @@ #

def _get_figure_image_data(figure) -> bytes:
if figure.__class__.__name__ == "Axes":
figure = figure.figure
with io.BytesIO() as image_buffer:

@@ -287,3 +289,3 @@ figure.savefig(image_buffer, format="png", bbox_inches="tight")

def is_matplotlib_figure(image):
return image.__class__.__module__.startswith("matplotlib.") and image.__class__.__name__ == "Figure"
return image.__class__.__module__.startswith("matplotlib.") and image.__class__.__name__ in ["Figure", "Axes"]

@@ -290,0 +292,0 @@

@@ -40,3 +40,3 @@ #

sys.stderr at handler construction time.
This enables Neptune to capture stdout regardless
This enables neptune-client to capture stdout regardless
of logging configuration time.

@@ -43,0 +43,0 @@ Based on logging._StderrHandler from standard library.

@@ -28,3 +28,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -40,3 +40,3 @@ _logger = get_logger()

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
self._thread = self.ReportingThread(self._period, container)

@@ -63,3 +63,3 @@ self._thread.start()

class ReportingThread(Daemon):
def __init__(self, period: float, container: "NeptuneObject"):
def __init__(self, period: float, container: "MetadataContainer"):
super().__init__(sleep_time=period, name="NeptunePing")

@@ -66,0 +66,0 @@ self._container = container

@@ -20,3 +20,3 @@ #

from neptune.internal.exceptions import NeptuneException
from neptune.common.exceptions import NeptuneException

@@ -23,0 +23,0 @@

@@ -20,2 +20,4 @@ #

import boto3
from neptune.envs import S3_ENDPOINT_URL

@@ -34,5 +36,2 @@

endpoint_url = os.getenv(S3_ENDPOINT_URL)
import boto3
return boto3.resource(

@@ -39,0 +38,0 @@ service_name="s3",

@@ -26,3 +26,4 @@ #

from neptune.attributes import constants as attr_consts
from neptune.internal.storage import normalize_file_name
from neptune.common.storage.storage_utils import normalize_file_name
from neptune.common.utils import is_ipython
from neptune.internal.utils import (

@@ -33,3 +34,2 @@ does_paths_share_common_drive,

)
from neptune.internal.utils.utils import is_ipython
from neptune.vendor.lib_programname import (

@@ -36,0 +36,0 @@ empty_path,

@@ -31,3 +31,3 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -44,3 +44,3 @@ _logger = get_logger()

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
if not self._started:

@@ -47,0 +47,0 @@ path = self._path

@@ -23,6 +23,10 @@ #

from platform import node as get_hostname
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Type,
)

@@ -32,45 +36,48 @@

if TYPE_CHECKING:
pass
_logger = get_logger()
SYS_UNCAUGHT_EXCEPTION_HANDLER_TYPE = Callable[[Type[BaseException], BaseException, Optional[TracebackType]], Any]
class UncaughtExceptionHandler:
def __init__(self):
self._previous_uncaught_exception_handler = None
self._handlers = dict()
def __init__(self) -> None:
self._previous_uncaught_exception_handler: Optional[SYS_UNCAUGHT_EXCEPTION_HANDLER_TYPE] = None
self._handlers: Dict[uuid.UUID, Callable[[List[str]], None]] = dict()
self._lock = threading.Lock()
def activate(self):
with self._lock:
this = self
def trigger(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
header_lines = [
f"An uncaught exception occurred while run was active on worker {get_hostname()}.",
"Marking run as failed",
"Traceback:",
]
def exception_handler(exc_type, exc_val, exc_tb):
header_lines = [
f"An uncaught exception occurred while run was active on worker {get_hostname()}.",
"Marking run as failed",
"Traceback:",
]
traceback_lines = header_lines + traceback.format_tb(exc_tb) + str(exc_val).split("\n")
for _, handler in self._handlers.items():
handler(traceback_lines)
traceback_lines = header_lines + traceback.format_tb(exc_tb) + str(exc_val).split("\n")
for _, handler in self._handlers.items():
handler(traceback_lines)
def activate(self) -> None:
with self._lock:
if self._previous_uncaught_exception_handler is not None:
return
self._previous_uncaught_exception_handler = sys.excepthook
sys.excepthook = self.exception_handler
this._previous_uncaught_exception_handler(exc_type, exc_val, exc_tb)
def deactivate(self) -> None:
with self._lock:
if self._previous_uncaught_exception_handler is None:
self._previous_uncaught_exception_handler = sys.excepthook
sys.excepthook = exception_handler
def deactivate(self):
with self._lock:
return
sys.excepthook = self._previous_uncaught_exception_handler
self._previous_uncaught_exception_handler = None
def register(self, uid: uuid.UUID, handler: Callable[[List[str]], None]):
def register(self, uid: uuid.UUID, handler: Callable[[List[str]], None]) -> None:
with self._lock:
self._handlers[uid] = handler
def unregister(self, uid: uuid.UUID):
def unregister(self, uid: uuid.UUID) -> None:
with self._lock:

@@ -80,3 +87,9 @@ if uid in self._handlers:

def exception_handler(self, *args: Any, **kwargs: Any) -> None:
self.trigger(*args, **kwargs)
if self._previous_uncaught_exception_handler is not None:
self._previous_uncaught_exception_handler(*args, **kwargs)
instance = UncaughtExceptionHandler()

@@ -58,7 +58,7 @@ #

if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer
class ValueToAttributeVisitor(ValueVisitor[Attribute]):
def __init__(self, container: "NeptuneObject", path: List[str]):
def __init__(self, container: "MetadataContainer", path: List[str]):
self._container = container

@@ -65,0 +65,0 @@ self._path = path

@@ -33,2 +33,3 @@ #

)
from neptune.common.websockets.reconnecting_websocket import ReconnectingWebsocket
from neptune.internal.background_job import BackgroundJob

@@ -38,7 +39,6 @@ from neptune.internal.threading.daemon import Daemon

from neptune.internal.utils.logger import get_logger
from neptune.internal.websockets.reconnecting_websocket import ReconnectingWebsocket
from neptune.internal.websockets.websockets_factory import WebsocketsFactory
if TYPE_CHECKING:
from neptune.objects import NeptuneObject
from neptune.metadata_containers import MetadataContainer

@@ -54,3 +54,3 @@ logger = get_logger()

def start(self, container: "NeptuneObject"):
def start(self, container: "MetadataContainer"):
self._thread = self._ListenerThread(container, self._ws_factory.create())

@@ -80,3 +80,3 @@ self._thread.start()

class _ListenerThread(Daemon):
def __init__(self, container: "NeptuneObject", ws_client: ReconnectingWebsocket):
def __init__(self, container: "MetadataContainer", ws_client: ReconnectingWebsocket):
super().__init__(sleep_time=0, name="NeptuneWebhooks")

@@ -83,0 +83,0 @@ self._container = container

@@ -23,3 +23,3 @@ #

from neptune.internal.websockets.reconnecting_websocket import ReconnectingWebsocket
from neptune.common.websockets.reconnecting_websocket import ReconnectingWebsocket

@@ -26,0 +26,0 @@

@@ -54,2 +54,4 @@ #

from neptune.common.backends.utils import with_api_exceptions_handler
from neptune.common.envs import API_TOKEN_ENV_NAME
from neptune.internal.backends.hosted_client import (

@@ -65,6 +67,4 @@ DEFAULT_REQUEST_KWARGS,

ssl_verify,
with_api_exceptions_handler,
)
from neptune.internal.credentials import Credentials
from neptune.internal.envs import API_TOKEN_ENV_NAME
from neptune.internal.id_formats import QualifiedName

@@ -71,0 +71,0 @@ from neptune.internal.utils import (

@@ -20,3 +20,3 @@ #

from neptune.internal.utils.patterns import PROJECT_QUALIFIED_NAME_PATTERN
from neptune.common.patterns import PROJECT_QUALIFIED_NAME_PATTERN
from neptune.management.exceptions import (

@@ -23,0 +23,0 @@ ConflictingWorkspaceName,

@@ -26,10 +26,9 @@ #

from neptune.api.field_visitor import FieldToValueVisitor
from neptune.api.models import (
Field,
FieldType,
from neptune.exceptions import MetadataInconsistency
from neptune.integrations.pandas import to_pandas
from neptune.internal.backends.api_model import (
AttributeType,
AttributeWithProperties,
LeaderboardEntry,
)
from neptune.exceptions import MetadataInconsistency
from neptune.integrations.pandas import to_pandas
from neptune.internal.backends.neptune_backend import NeptuneBackend

@@ -42,2 +41,3 @@ from neptune.internal.container_type import ContainerType

)
from neptune.internal.utils.run_state import RunState
from neptune.typing import ProgressBarType

@@ -58,3 +58,3 @@

_id: str,
attributes: List[Field],
attributes: List[AttributeWithProperties],
):

@@ -64,4 +64,3 @@ self._backend = backend

self._id = _id
self._fields = attributes
self._field_to_value_visitor = FieldToValueVisitor()
self._attributes = attributes

@@ -71,13 +70,43 @@ def __getitem__(self, path: str) -> "LeaderboardHandler":

def get_attribute_type(self, path: str) -> FieldType:
for field in self._fields:
if field.path == path:
return field.type
def get_attribute_type(self, path: str) -> AttributeType:
for attr in self._attributes:
if attr.path == path:
return attr.type
raise ValueError("Could not find {} attribute".format(path))
raise ValueError(f"Could not find {path} field")
def get_attribute_value(self, path: str) -> Any:
for field in self._fields:
if field.path == path:
return self._field_to_value_visitor.visit(field)
for attr in self._attributes:
if attr.path == path:
_type = attr.type
if _type == AttributeType.RUN_STATE:
return RunState.from_api(attr.properties.get("value")).value
if _type in (
AttributeType.FLOAT,
AttributeType.INT,
AttributeType.BOOL,
AttributeType.STRING,
AttributeType.DATETIME,
):
return attr.properties.get("value")
if _type == AttributeType.FLOAT_SERIES or _type == AttributeType.STRING_SERIES:
return attr.properties.get("last")
if _type == AttributeType.IMAGE_SERIES:
raise MetadataInconsistency("Cannot get value for image series.")
if _type == AttributeType.FILE:
raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.")
if _type == AttributeType.FILE_SET:
raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.")
if _type == AttributeType.STRING_SET:
return set(attr.properties.get("values"))
if _type == AttributeType.GIT_REF:
return attr.properties.get("commit", {}).get("commitId")
if _type == AttributeType.NOTEBOOK_REF:
return attr.properties.get("notebookName")
if _type == AttributeType.ARTIFACT:
return attr.properties.get("hash")
logger.error(
"Attribute type %s not supported in this version, yielding None. Recommended client upgrade.",
_type,
)
return None
raise ValueError("Could not find {} attribute".format(path))

@@ -91,6 +120,6 @@

) -> None:
for attr in self._fields:
for attr in self._attributes:
if attr.path == path:
_type = attr.type
if _type == FieldType.FILE:
if _type == AttributeType.FILE:
self._backend.download_file(

@@ -113,6 +142,6 @@ container_id=self._id,

) -> None:
for attr in self._fields:
for attr in self._attributes:
if attr.path == path:
_type = attr.type
if _type == FieldType.FILE_SET:
if _type == AttributeType.FILE_SET:
self._backend.download_file_set(

@@ -143,5 +172,5 @@ container_id=self._id,

attr_type = self._table_entry.get_attribute_type(self._path)
if attr_type == FieldType.FILE:
if attr_type == AttributeType.FILE:
return self._table_entry.download_file_attribute(self._path, destination)
elif attr_type == FieldType.FILE_SET:
elif attr_type == AttributeType.FILE_SET:
return self._table_entry.download_file_set_attribute(path=self._path, destination=destination)

@@ -175,4 +204,4 @@ raise MetadataInconsistency("Cannot download file from attribute of type {}".format(attr_type))

container_type=self._container_type,
_id=entry.object_id,
attributes=entry.fields,
_id=entry.id,
attributes=entry.attributes,
)

@@ -179,0 +208,0 @@

@@ -184,4 +184,5 @@ #

Matplotlib figures and Seaborn figures.
autoscale: Whether Neptune should try to detect the pixel range automatically
and scale it to an acceptable format.
autoscale: Whether Neptune should try to scale image pixel values to better render them in the web app.
Scaling can distort images if their pixels lie outside the [0.0, 1.0] or [0, 255] range.
To disable auto-scaling, set the argument to False.

@@ -188,0 +189,0 @@ Returns:

@@ -28,9 +28,9 @@ #

from neptune.common.warnings import (
NeptuneUnsupportedValue,
warn_once,
)
from neptune.internal.types.stringify_value import extract_if_stringify_value
from neptune.internal.types.utils import is_unsupported_float
from neptune.internal.utils import is_collection
from neptune.internal.warnings import (
NeptuneUnsupportedValue,
warn_once,
)
from neptune.types.series.series import Series

@@ -37,0 +37,0 @@

@@ -16,7 +16,3 @@ #

#
__all__ = [
"SupportsNamespaces",
"ProgressBarCallback",
"ProgressBarType",
]
__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback", "ProgressBarCallback", "ProgressBarType"]

@@ -34,3 +30,7 @@ import abc

from neptune.objects.abstract import SupportsNamespaces
from neptune.metadata_containers.abstract import (
NeptuneObject,
NeptuneObjectCallback,
SupportsNamespaces,
)

@@ -37,0 +37,0 @@

@@ -26,3 +26,2 @@ #

from typing import (
TYPE_CHECKING,
Any,

@@ -39,8 +38,11 @@ Mapping,

from neptune.internal.utils.logger import get_logger
from neptune.typing import ProgressBarCallback
from neptune.internal.utils.runningmode import (
in_interactive,
in_notebook,
)
from neptune.typing import (
NeptuneObject,
ProgressBarCallback,
)
if TYPE_CHECKING:
from neptune.objects.neptune_object import NeptuneObject
logger = get_logger()

@@ -63,2 +65,5 @@

>>> run["complex_dict"] = stringify_unsupported(complex_dict)
For more information, see:
https://docs.neptune.ai/setup/neptune-client_1-0_release_changes/#no-more-implicit-casting-to-string
"""

@@ -71,3 +76,3 @@ if isinstance(value, MutableMapping):

def stop_synchronization_callback(neptune_object: "NeptuneObject") -> None:
def stop_synchronization_callback(neptune_object: NeptuneObject) -> None:
"""Default callback function that stops a Neptune object's synchronization with the server.

@@ -120,4 +125,8 @@

super().__init__(*args, **kwargs)
interactive = in_interactive() or in_notebook()
from tqdm.auto import tqdm
if interactive:
from tqdm.notebook import tqdm
else:
from tqdm import tqdm # type: ignore

@@ -124,0 +133,0 @@ unit = unit if unit else ""

@@ -23,2 +23,4 @@ #

from neptune.common.warnings import warn_once
if sys.version_info >= (3, 8):

@@ -42,5 +44,18 @@ from importlib.metadata import PackageNotFoundError

neptune_version = check_version("neptune")
neptune_client_version = check_version("neptune-client")
if neptune_version is not None:
if neptune_version is not None and neptune_client_version is not None:
raise RuntimeError(
"We've detected that the 'neptune' and 'neptune-client' packages are both installed. "
"Uninstall each of them and then install only the new 'neptune' package. For more information, "
"see https://docs.neptune.ai/setup/upgrading/"
)
elif neptune_version is not None:
return neptune_version
elif neptune_client_version is not None:
warn_once(
"The 'neptune-client' package has been deprecated and will be removed in the future. Install "
"the 'neptune' package instead. For more, see https://docs.neptune.ai/setup/upgrading/"
)
return neptune_client_version

@@ -47,0 +62,0 @@ raise PackageNotFoundError("neptune")

#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ("fetch_series_values",)
from typing import (
Any,
Callable,
Iterator,
Optional,
TypeVar,
)
from neptune.api.models import (
FloatPointValue,
StringPointValue,
)
from neptune.internal.backends.utils import construct_progress_bar
from neptune.typing import ProgressBarType
PointValue = TypeVar("PointValue", StringPointValue, FloatPointValue)
def fetch_series_values(
    getter: Callable[..., Any], path: str, step_size: int = 1000, progress_bar: Optional[ProgressBarType] = None
) -> Iterator[PointValue]:
    """Lazily fetch every value of a series field, page by page.

    Args:
        getter: Callable accepting ``from_step`` and ``limit`` keyword arguments and
            returning a batch object with ``total`` (int) and ``values`` (a list of
            points, each carrying a ``step`` attribute).
        path: Field path — used only in the progress-bar description.
        step_size: Maximum number of points requested per ``getter`` call.
        progress_bar: Progress-bar spec forwarded to ``construct_progress_bar``;
            automatically disabled when the series fits in a single page.

    Yields:
        Points of the series, in the order the backend returns them.
    """
    # Probe with a single point to learn the total count and the starting step.
    first_batch = getter(from_step=None, limit=1)
    total = first_batch.total
    # Start one step before the probed point so the first full page re-includes it.
    last_step_value = (first_batch.values[-1].step - 1) if first_batch.values else None

    # No progress bar for series that fit in one page.
    progress_bar = False if total < step_size else progress_bar

    if total <= 1:
        yield from first_batch.values
        return

    data_count = 0
    with construct_progress_bar(progress_bar, f"Fetching {path} values") as bar:
        bar.update(by=data_count, total=total)
        while data_count < total:
            batch = getter(from_step=last_step_value, limit=step_size)
            if not batch.values:
                # Defensive: an empty page would otherwise loop forever, because
                # neither data_count nor last_step_value could advance.
                break
            bar.update(by=len(batch.values), total=total)
            yield from batch.values
            last_step_value = batch.values[-1].step
            data_count += len(batch.values)
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ("FieldToValueVisitor",)
from datetime import datetime
from typing import (
Any,
Optional,
Set,
)
from neptune.api.models import (
ArtifactField,
BoolField,
DateTimeField,
FieldVisitor,
FileField,
FileSetField,
FloatField,
FloatSeriesField,
GitRefField,
ImageSeriesField,
IntField,
NotebookRefField,
ObjectStateField,
StringField,
StringSeriesField,
StringSetField,
)
from neptune.exceptions import MetadataInconsistency
class FieldToValueVisitor(FieldVisitor[Any]):
    """Extract a plain Python value from a field.

    Scalar fields yield their stored value, series fields yield the most
    recently logged value, and downloadable fields (files, file sets, image
    series) raise ``MetadataInconsistency`` because they carry no in-memory
    value.
    """

    def visit_float(self, field: FloatField) -> float:
        """Return the stored float value."""
        return field.value

    def visit_int(self, field: IntField) -> int:
        """Return the stored int value."""
        return field.value

    def visit_bool(self, field: BoolField) -> bool:
        """Return the stored bool value."""
        return field.value

    def visit_string(self, field: StringField) -> str:
        """Return the stored string value."""
        return field.value

    def visit_datetime(self, field: DateTimeField) -> datetime:
        """Return the stored timestamp."""
        return field.value

    def visit_file(self, field: FileField) -> None:
        """Files have no in-memory value — they must be downloaded."""
        raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.")

    def visit_file_set(self, field: FileSetField) -> None:
        """File sets have no in-memory value — they must be downloaded."""
        raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.")

    def visit_float_series(self, field: FloatSeriesField) -> Optional[float]:
        """Return the last logged value of the series, if any."""
        return field.last

    def visit_string_series(self, field: StringSeriesField) -> Optional[str]:
        """Return the last logged value of the series, if any."""
        return field.last

    def visit_image_series(self, field: ImageSeriesField) -> None:
        """Image series have no scalar value to return."""
        raise MetadataInconsistency("Cannot get value for image series.")

    def visit_string_set(self, field: StringSetField) -> Set[str]:
        """Return the set of stored strings."""
        return field.values

    def visit_git_ref(self, field: GitRefField) -> Optional[str]:
        """Return the referenced commit id, or None when no commit is attached."""
        if field.commit is None:
            return None
        return field.commit.commit_id

    def visit_object_state(self, field: ObjectStateField) -> str:
        """Return the object-state string."""
        return field.value

    def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]:
        """Return the referenced notebook name, if any."""
        return field.notebook_name

    def visit_artifact(self, field: ArtifactField) -> str:
        """Return the artifact hash."""
        return field.hash
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
__all__ = (
"FileEntry",
"Field",
"FieldType",
"GitCommit",
"LeaderboardEntry",
"LeaderboardEntriesSearchResult",
"FieldVisitor",
"FloatField",
"IntField",
"BoolField",
"StringField",
"DateTimeField",
"FileField",
"FileSetField",
"FloatSeriesField",
"StringSeriesField",
"ImageSeriesField",
"StringSetField",
"GitRefField",
"ObjectStateField",
"NotebookRefField",
"ArtifactField",
"FieldDefinition",
"FloatSeriesValues",
"FloatPointValue",
"StringSeriesValues",
"StringPointValue",
"ImageSeriesValues",
"QueryFieldDefinitionsResult",
"NextPage",
"QueryFieldsResult",
)
import abc
import re
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from datetime import (
datetime,
timezone,
)
from enum import Enum
from typing import (
Any,
ClassVar,
Dict,
Generic,
List,
Optional,
Set,
Type,
TypeVar,
)
from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import ProtoAttributeDefinitionDTO
from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import (
ProtoAttributeDTO,
ProtoAttributesDTO,
ProtoBoolAttributeDTO,
ProtoDatetimeAttributeDTO,
ProtoFloatAttributeDTO,
ProtoFloatSeriesAttributeDTO,
ProtoIntAttributeDTO,
ProtoLeaderboardEntriesSearchResultDTO,
ProtoStringAttributeDTO,
ProtoStringSetAttributeDTO,
)
from neptune.internal.utils.iso_dates import parse_iso_date
from neptune.internal.utils.run_state import RunState
Ret = TypeVar("Ret")
@dataclass
class FileEntry:
    """Metadata describing a single remote file-browser entry."""

    name: str  # entry name
    size: int  # entry size — presumably bytes; confirm against backend docs
    mtime: datetime  # last-modification timestamp
    file_type: str  # backend-provided type discriminator

    @classmethod
    def from_dto(cls, file_dto: Any) -> "FileEntry":
        """Build a FileEntry from a backend DTO (note the camelCase ``fileType``)."""
        return cls(
            name=file_dto.name,
            size=file_dto.size,
            mtime=file_dto.mtime,
            file_type=file_dto.fileType,
        )
class FieldType(Enum):
    """Wire-level type names of Neptune fields, as used by the backend API."""

    FLOAT = "float"
    INT = "int"
    BOOL = "bool"
    STRING = "string"
    DATETIME = "datetime"
    FILE = "file"
    FILE_SET = "fileSet"
    FLOAT_SERIES = "floatSeries"
    STRING_SERIES = "stringSeries"
    IMAGE_SERIES = "imageSeries"
    STRING_SET = "stringSet"
    GIT_REF = "gitRef"
    # The API string is the legacy "experimentState", not "objectState".
    OBJECT_STATE = "experimentState"
    NOTEBOOK_REF = "notebookRef"
    ARTIFACT = "artifact"
@dataclass
class Field(abc.ABC):
    """Base class for all field (attribute) values returned by the backend.

    Concrete subclasses register themselves in ``_registry`` via
    ``__init_subclass__`` (keyed by the API type string), which lets the
    static ``from_*`` factories dispatch on the wire-level ``type``
    discriminator.
    """

    path: str
    # Class-level field type; set per subclass by __init_subclass__ and
    # excluded from the generated __init__.
    type: ClassVar[FieldType] = dataclass_field(init=False)
    # Maps FieldType.value (API type string) -> concrete Field subclass.
    _registry: ClassVar[Dict[str, Type[Field]]] = {}

    def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> None:
        # Subclasses declare their type via the class keyword, e.g.
        # ``class FloatField(Field, field_type=FieldType.FLOAT)``.
        super().__init_subclass__(*args, **kwargs)
        cls.type = field_type
        cls._registry[field_type.value] = cls

    @classmethod
    def by_type(cls, field_type: FieldType) -> Type[Field]:
        """Return the concrete Field subclass registered for ``field_type``."""
        return cls._registry[field_type.value]

    @abc.abstractmethod
    def accept(self, visitor: FieldVisitor[Ret]) -> Ret: ...

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> Field:
        """Build a concrete field from a REST dict; properties live under ``<type>Properties``."""
        field_type = data["type"]
        return Field._registry[field_type].from_dict(data[f"{field_type}Properties"])

    @staticmethod
    def from_model(model: Any) -> Field:
        """Build a concrete field from an attribute-style model object (presumably a swagger model — confirm)."""
        field_type = str(model.type)
        return Field._registry[field_type].from_model(model.__getattr__(f"{field_type}Properties"))

    @staticmethod
    def from_proto(data: Any) -> Field:
        """Build a concrete field from a protobuf DTO; proto property names are snake_case."""
        field_type = str(data.type)
        return Field._registry[field_type].from_proto(data.__getattribute__(f"{camel_to_snake(field_type)}_properties"))
def camel_to_snake(name: str) -> str:
    """Convert a camelCase identifier to snake_case, e.g. ``floatSeries`` -> ``float_series``."""
    # Pass 1: insert an underscore before an uppercase letter that starts a capitalized word.
    partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: split a lowercase letter or digit followed by an uppercase letter, then lowercase everything.
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", partially_split).lower()
class FieldVisitor(Generic[Ret], abc.ABC):
    """Visitor over the closed set of Field subclasses (classic double dispatch).

    ``visit`` delegates to ``Field.accept``, which calls back the matching
    ``visit_*`` method on this visitor.
    """

    def visit(self, field: Field) -> Ret:
        # Double dispatch: the concrete field picks the visit_* method.
        return field.accept(self)

    @abc.abstractmethod
    def visit_float(self, field: FloatField) -> Ret: ...

    @abc.abstractmethod
    def visit_int(self, field: IntField) -> Ret: ...

    @abc.abstractmethod
    def visit_bool(self, field: BoolField) -> Ret: ...

    @abc.abstractmethod
    def visit_string(self, field: StringField) -> Ret: ...

    @abc.abstractmethod
    def visit_datetime(self, field: DateTimeField) -> Ret: ...

    @abc.abstractmethod
    def visit_file(self, field: FileField) -> Ret: ...

    @abc.abstractmethod
    def visit_file_set(self, field: FileSetField) -> Ret: ...

    @abc.abstractmethod
    def visit_float_series(self, field: FloatSeriesField) -> Ret: ...

    @abc.abstractmethod
    def visit_string_series(self, field: StringSeriesField) -> Ret: ...

    @abc.abstractmethod
    def visit_image_series(self, field: ImageSeriesField) -> Ret: ...

    @abc.abstractmethod
    def visit_string_set(self, field: StringSetField) -> Ret: ...

    @abc.abstractmethod
    def visit_git_ref(self, field: GitRefField) -> Ret: ...

    @abc.abstractmethod
    def visit_object_state(self, field: ObjectStateField) -> Ret: ...

    @abc.abstractmethod
    def visit_notebook_ref(self, field: NotebookRefField) -> Ret: ...

    @abc.abstractmethod
    def visit_artifact(self, field: ArtifactField) -> Ret: ...
@dataclass
class FloatField(Field, field_type=FieldType.FLOAT):
    """A scalar float attribute."""

    value: float

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the float-specific visitor hook."""
        return visitor.visit_float(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FloatField:
        """Build from the backend JSON representation."""
        return FloatField(value=float(data["value"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> FloatField:
        """Build from a swagger model object."""
        return FloatField(value=model.value, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoFloatAttributeDTO) -> FloatField:
        """Build from a protobuf DTO."""
        return FloatField(value=data.value, path=data.attribute_name)
@dataclass
class IntField(Field, field_type=FieldType.INT):
    """A scalar integer attribute."""

    value: int

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the int-specific visitor hook."""
        return visitor.visit_int(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> IntField:
        """Build from the backend JSON representation."""
        return IntField(value=int(data["value"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> IntField:
        """Build from a swagger model object."""
        return IntField(value=model.value, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoIntAttributeDTO) -> IntField:
        """Build from a protobuf DTO."""
        return IntField(value=data.value, path=data.attribute_name)
@dataclass
class BoolField(Field, field_type=FieldType.BOOL):
    """A scalar boolean attribute."""

    value: bool

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the bool-specific visitor hook."""
        return visitor.visit_bool(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> BoolField:
        """Build from the backend JSON representation."""
        return BoolField(value=bool(data["value"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> BoolField:
        """Build from a swagger model object."""
        return BoolField(value=model.value, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoBoolAttributeDTO) -> BoolField:
        """Build from a protobuf DTO."""
        return BoolField(value=data.value, path=data.attribute_name)
@dataclass
class StringField(Field, field_type=FieldType.STRING):
    """A scalar string attribute."""

    value: str

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the string-specific visitor hook."""
        return visitor.visit_string(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> StringField:
        """Build from the backend JSON representation."""
        return StringField(value=str(data["value"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> StringField:
        """Build from a swagger model object."""
        return StringField(value=model.value, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoStringAttributeDTO) -> StringField:
        """Build from a protobuf DTO."""
        return StringField(value=data.value, path=data.attribute_name)
@dataclass
class DateTimeField(Field, field_type=FieldType.DATETIME):
    """A datetime attribute."""

    value: datetime

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the datetime-specific visitor hook."""
        return visitor.visit_datetime(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> DateTimeField:
        """Build from the backend JSON representation (ISO date string)."""
        return DateTimeField(value=parse_iso_date(data["value"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> DateTimeField:
        """Build from a swagger model object."""
        return DateTimeField(value=parse_iso_date(model.value), path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoDatetimeAttributeDTO) -> DateTimeField:
        """Build from a protobuf DTO; epoch milliseconds -> aware UTC datetime."""
        seconds = data.value / 1000.0
        return DateTimeField(value=datetime.fromtimestamp(seconds, tz=timezone.utc), path=data.attribute_name)
@dataclass
class FileField(Field, field_type=FieldType.FILE):
    """A single-file attribute: name, extension, and size."""

    name: str
    ext: str
    size: int

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the file visitor hook."""
        return visitor.visit_file(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FileField:
        """Build from the backend JSON representation."""
        return FileField(
            name=data["name"],
            ext=data["ext"],
            size=int(data["size"]),
            path=data["attributeName"],
        )

    @staticmethod
    def from_model(model: Any) -> FileField:
        """Build from a swagger model object."""
        return FileField(name=model.name, ext=model.ext, size=model.size, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> FileField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class FileSetField(Field, field_type=FieldType.FILE_SET):
    """A file-set attribute; only its total size is exposed."""

    size: int

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the file-set visitor hook."""
        return visitor.visit_file_set(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FileSetField:
        """Build from the backend JSON representation."""
        return FileSetField(size=int(data["size"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> FileSetField:
        """Build from a swagger model object."""
        return FileSetField(size=model.size, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> FileSetField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES):
    """A float series attribute; only the last value is exposed."""

    last: Optional[float]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the float-series visitor hook."""
        return visitor.visit_float_series(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FloatSeriesField:
        """Build from the backend JSON representation; "last" may be absent."""
        newest = float(data["last"]) if "last" in data else None
        return FloatSeriesField(last=newest, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> FloatSeriesField:
        """Build from a swagger model object."""
        return FloatSeriesField(last=model.last, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoFloatSeriesAttributeDTO) -> FloatSeriesField:
        """Build from a protobuf DTO; the optional "last" field may be unset."""
        newest = data.last if data.HasField("last") else None
        return FloatSeriesField(last=newest, path=data.attribute_name)
@dataclass
class StringSeriesField(Field, field_type=FieldType.STRING_SERIES):
    """A string series attribute; only the last value is exposed."""

    last: Optional[str]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the string-series visitor hook."""
        return visitor.visit_string_series(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> StringSeriesField:
        """Build from the backend JSON representation; "last" may be absent."""
        newest = str(data["last"]) if "last" in data else None
        return StringSeriesField(last=newest, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> StringSeriesField:
        """Build from a swagger model object."""
        return StringSeriesField(last=model.last, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> StringSeriesField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES):
    """An image series attribute; only the last step is exposed."""

    last_step: Optional[float]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the image-series visitor hook."""
        return visitor.visit_image_series(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> ImageSeriesField:
        """Build from the backend JSON representation; "lastStep" may be absent."""
        newest_step = float(data["lastStep"]) if "lastStep" in data else None
        return ImageSeriesField(last_step=newest_step, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> ImageSeriesField:
        """Build from a swagger model object."""
        return ImageSeriesField(last_step=model.lastStep, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> ImageSeriesField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class StringSetField(Field, field_type=FieldType.STRING_SET):
    """An attribute holding a set of unique strings."""

    values: Set[str]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the string-set visitor hook."""
        return visitor.visit_string_set(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> StringSetField:
        """Build from the backend JSON representation."""
        return StringSetField(values={str(item) for item in data["values"]}, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> StringSetField:
        """Build from a swagger model object."""
        return StringSetField(values=set(model.values), path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoStringSetAttributeDTO) -> StringSetField:
        """Build from a protobuf DTO (repeated "value" field)."""
        return StringSetField(values=set(data.value), path=data.attribute_name)
@dataclass
class GitCommit:
    """A git commit reference; the id may be unknown."""

    commit_id: Optional[str]

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> GitCommit:
        """Build from the backend JSON representation; "commitId" may be absent."""
        if "commitId" in data:
            return GitCommit(commit_id=str(data["commitId"]))
        return GitCommit(commit_id=None)

    @staticmethod
    def from_model(model: Any) -> GitCommit:
        """Build from a swagger model object."""
        return GitCommit(commit_id=model.commitId)

    @staticmethod
    def from_proto(data: Any) -> GitCommit:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class GitRefField(Field, field_type=FieldType.GIT_REF):
    """A git reference attribute; the commit may be missing."""

    commit: Optional[GitCommit]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the git-ref visitor hook."""
        return visitor.visit_git_ref(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> GitRefField:
        """Build from the backend JSON representation; "commit" may be absent."""
        head = GitCommit.from_dict(data["commit"]) if "commit" in data else None
        return GitRefField(commit=head, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> GitRefField:
        """Build from a swagger model object."""
        head = GitCommit.from_model(model.commit) if model.commit is not None else None
        return GitRefField(commit=head, path=model.attributeName)

    @staticmethod
    def from_proto(data: ProtoAttributeDTO) -> GitRefField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class ObjectStateField(Field, field_type=FieldType.OBJECT_STATE):
    """The object's state, normalized through `RunState`."""

    value: str

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the object-state visitor hook."""
        return visitor.visit_object_state(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> ObjectStateField:
        """Build from the backend JSON representation; the raw API state is
        mapped through RunState.from_api."""
        state = RunState.from_api(str(data["value"]))
        return ObjectStateField(value=state.value, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> ObjectStateField:
        """Build from a swagger model object."""
        state = RunState.from_api(str(model.value))
        return ObjectStateField(value=state.value, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> ObjectStateField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF):
    """A reference to a notebook; the name may be unknown."""

    notebook_name: Optional[str]

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the notebook-ref visitor hook."""
        return visitor.visit_notebook_ref(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> NotebookRefField:
        """Build from the backend JSON representation; "notebookName" may be absent."""
        name = str(data["notebookName"]) if "notebookName" in data else None
        return NotebookRefField(notebook_name=name, path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> NotebookRefField:
        """Build from a swagger model object."""
        return NotebookRefField(notebook_name=model.notebookName, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> NotebookRefField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class ArtifactField(Field, field_type=FieldType.ARTIFACT):
    """An artifact attribute identified by its hash."""

    hash: str

    def accept(self, visitor: FieldVisitor[Ret]) -> Ret:
        """Route to the artifact visitor hook."""
        return visitor.visit_artifact(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> ArtifactField:
        """Build from the backend JSON representation."""
        return ArtifactField(hash=str(data["hash"]), path=data["attributeName"])

    @staticmethod
    def from_model(model: Any) -> ArtifactField:
        """Build from a swagger model object."""
        return ArtifactField(hash=model.hash, path=model.attributeName)

    @staticmethod
    def from_proto(data: Any) -> ArtifactField:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class LeaderboardEntry:
    """A single leaderboard row: an object id plus its deserialized fields."""

    object_id: str
    fields: List[Field]

    # Any field the type of which is not in this set will not be returned to
    # the user. Applies to protobuf calls only. Hoisted to a class-level
    # constant (instead of being rebuilt on every from_proto call) to match
    # the QueryFieldsExperimentResult.PROTO_SUPPORTED_FIELD_TYPES convention.
    PROTO_SUPPORTED_FIELD_TYPES = {
        FieldType.STRING.value,
        FieldType.BOOL.value,
        FieldType.INT.value,
        FieldType.FLOAT.value,
        FieldType.DATETIME.value,
        FieldType.STRING_SET.value,
        FieldType.FLOAT_SERIES.value,
    }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> LeaderboardEntry:
        """Build from the backend JSON representation."""
        return LeaderboardEntry(
            object_id=data["experimentId"], fields=[Field.from_dict(field) for field in data["attributes"]]
        )

    @staticmethod
    def from_model(model: Any) -> LeaderboardEntry:
        """Build from a swagger model object."""
        return LeaderboardEntry(
            object_id=model.experimentId, fields=[Field.from_model(field) for field in model.attributes]
        )

    @staticmethod
    def from_proto(data: ProtoAttributesDTO) -> LeaderboardEntry:
        """Build from a protobuf DTO, silently dropping field types that have
        no protobuf deserialization support."""
        with_proto_support = LeaderboardEntry.PROTO_SUPPORTED_FIELD_TYPES
        return LeaderboardEntry(
            object_id=data.experiment_id,
            fields=[Field.from_proto(field) for field in data.attributes if str(field.type) in with_proto_support],
        )
@dataclass
class LeaderboardEntriesSearchResult:
    """A page of leaderboard entries plus the total number of matches."""

    entries: List[LeaderboardEntry]
    matching_item_count: int

    @staticmethod
    def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult:
        """Build from the backend JSON representation; "entries" may be absent."""
        rows = [LeaderboardEntry.from_dict(item) for item in result.get("entries", [])]
        return LeaderboardEntriesSearchResult(entries=rows, matching_item_count=result["matchingItemCount"])

    @staticmethod
    def from_model(result: Any) -> LeaderboardEntriesSearchResult:
        """Build from a swagger model object."""
        rows = [LeaderboardEntry.from_model(item) for item in result.entries]
        return LeaderboardEntriesSearchResult(entries=rows, matching_item_count=result.matchingItemCount)

    @staticmethod
    def from_proto(data: ProtoLeaderboardEntriesSearchResultDTO) -> LeaderboardEntriesSearchResult:
        """Build from a protobuf DTO."""
        rows = [LeaderboardEntry.from_proto(item) for item in data.entries]
        return LeaderboardEntriesSearchResult(entries=rows, matching_item_count=data.matching_item_count)
@dataclass
class NextPage:
    """Pagination cursor: requested page size and the backend page token."""

    limit: Optional[int]
    next_page_token: Optional[str]

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> NextPage:
        """Build from the backend JSON representation; both keys are optional."""
        return NextPage(
            limit=data.get("limit"),
            next_page_token=data.get("nextPageToken"),
        )

    @staticmethod
    def from_model(model: Any) -> NextPage:
        """Build from a swagger model object."""
        return NextPage(limit=model.limit, next_page_token=model.nextPageToken)

    @staticmethod
    def from_proto(data: Any) -> NextPage:
        """Build from a protobuf DTO."""
        return NextPage(limit=data.limit, next_page_token=data.nextPageToken)

    def to_dto(self) -> Dict[str, Any]:
        """Serialize back to the wire-format dict expected by the backend."""
        return {"limit": self.limit, "nextPageToken": self.next_page_token}
@dataclass
class QueryFieldsExperimentResult:
    """Fields of one experiment returned by the query-fields endpoint."""

    object_id: str
    object_key: str
    fields: List[Field]

    # Any field the type of which is not in this set will not be
    # returned to the user. Applies to protobuf calls only.
    PROTO_SUPPORTED_FIELD_TYPES = {
        FieldType.STRING.value,
        FieldType.BOOL.value,
        FieldType.INT.value,
        FieldType.FLOAT.value,
        FieldType.DATETIME.value,
        FieldType.STRING_SET.value,
        FieldType.FLOAT_SERIES.value,
    }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> QueryFieldsExperimentResult:
        """Build from the backend JSON representation."""
        parsed = [Field.from_dict(attr) for attr in data["attributes"]]
        return QueryFieldsExperimentResult(
            object_id=data["experimentId"],
            object_key=data["experimentShortId"],
            fields=parsed,
        )

    @staticmethod
    def from_model(model: Any) -> QueryFieldsExperimentResult:
        """Build from a swagger model object."""
        parsed = [Field.from_model(attr) for attr in model.attributes]
        return QueryFieldsExperimentResult(
            object_id=model.experimentId,
            object_key=model.experimentShortId,
            fields=parsed,
        )

    @staticmethod
    def from_proto(data: Any) -> QueryFieldsExperimentResult:
        """Build from a protobuf DTO, dropping unsupported field types."""
        supported = QueryFieldsExperimentResult.PROTO_SUPPORTED_FIELD_TYPES
        parsed = [
            Field.from_proto(attr)
            for attr in data.attributes
            if attr.type in supported
        ]
        return QueryFieldsExperimentResult(
            object_id=data.experimentId,
            object_key=data.experimentShortId,
            fields=parsed,
        )
@dataclass
class QueryFieldsResult:
    """One page of query-fields entries plus the next-page cursor."""

    entries: List[QueryFieldsExperimentResult]
    next_page: NextPage

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> QueryFieldsResult:
        """Build from the backend JSON representation."""
        rows = [QueryFieldsExperimentResult.from_dict(item) for item in data["entries"]]
        return QueryFieldsResult(entries=rows, next_page=NextPage.from_dict(data["nextPage"]))

    @staticmethod
    def from_model(model: Any) -> QueryFieldsResult:
        """Build from a swagger model object."""
        rows = [QueryFieldsExperimentResult.from_model(item) for item in model.entries]
        return QueryFieldsResult(entries=rows, next_page=NextPage.from_model(model.nextPage))

    @staticmethod
    def from_proto(data: Any) -> QueryFieldsResult:
        """Build from a protobuf DTO."""
        rows = [QueryFieldsExperimentResult.from_proto(item) for item in data.entries]
        return QueryFieldsResult(entries=rows, next_page=NextPage.from_proto(data.nextPage))
@dataclass
class QueryFieldDefinitionsResult:
    """One page of field definitions plus the next-page cursor."""

    entries: List[FieldDefinition]
    next_page: NextPage

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> QueryFieldDefinitionsResult:
        """Build from the backend JSON representation."""
        rows = [FieldDefinition.from_dict(item) for item in data["entries"]]
        return QueryFieldDefinitionsResult(entries=rows, next_page=NextPage.from_dict(data["nextPage"]))

    @staticmethod
    def from_model(model: Any) -> QueryFieldDefinitionsResult:
        """Build from a swagger model object."""
        rows = [FieldDefinition.from_model(item) for item in model.entries]
        return QueryFieldDefinitionsResult(entries=rows, next_page=NextPage.from_model(model.nextPage))

    @staticmethod
    def from_proto(data: Any) -> QueryFieldDefinitionsResult:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class FieldDefinition:
    """Field metadata: its path and declared type."""

    path: str
    type: FieldType

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FieldDefinition:
        """Build from the backend JSON representation."""
        return FieldDefinition(type=FieldType(data["type"]), path=data["name"])

    @staticmethod
    def from_model(model: Any) -> FieldDefinition:
        """Build from a swagger model object."""
        return FieldDefinition(type=FieldType(model.type), path=model.name)

    @staticmethod
    def from_proto(data: ProtoAttributeDefinitionDTO) -> FieldDefinition:
        """Build from a protobuf DTO."""
        return FieldDefinition(type=FieldType(data.type), path=data.name)
@dataclass
class FloatSeriesValues:
    """A fetched slice of a float series plus the series' total length."""

    total: int
    values: List[FloatPointValue]

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FloatSeriesValues:
        """Build from the backend JSON representation."""
        points = [FloatPointValue.from_dict(item) for item in data["values"]]
        return FloatSeriesValues(total=data["totalItemCount"], values=points)

    @staticmethod
    def from_model(model: Any) -> FloatSeriesValues:
        """Build from a swagger model object."""
        points = [FloatPointValue.from_model(item) for item in model.values]
        return FloatSeriesValues(total=model.totalItemCount, values=points)

    @staticmethod
    def from_proto(data: Any) -> FloatSeriesValues:
        """Build from a protobuf DTO."""
        points = [FloatPointValue.from_proto(item) for item in data.values]
        return FloatSeriesValues(total=data.total_item_count, values=points)
@dataclass
class FloatPointValue:
    """A single float-series point: UTC timestamp, value, and step."""

    timestamp: datetime
    value: float
    step: float

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> FloatPointValue:
        """Build from the backend JSON representation (epoch-ms timestamp)."""
        when = datetime.fromtimestamp(data["timestampMillis"] / 1000.0, tz=timezone.utc)
        return FloatPointValue(timestamp=when, value=float(data["value"]), step=float(data["step"]))

    @staticmethod
    def from_model(model: Any) -> FloatPointValue:
        """Build from a swagger model object."""
        when = datetime.fromtimestamp(model.timestampMillis / 1000.0, tz=timezone.utc)
        return FloatPointValue(timestamp=when, value=model.value, step=model.step)

    @staticmethod
    def from_proto(data: Any) -> FloatPointValue:
        """Build from a protobuf DTO."""
        when = datetime.fromtimestamp(data.timestamp_millis / 1000.0, tz=timezone.utc)
        return FloatPointValue(timestamp=when, value=data.value, step=data.step)
@dataclass
class StringSeriesValues:
    """A fetched slice of a string series plus the series' total length."""

    total: int
    values: List[StringPointValue]

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> StringSeriesValues:
        """Build from the backend JSON representation."""
        points = [StringPointValue.from_dict(item) for item in data["values"]]
        return StringSeriesValues(total=data["totalItemCount"], values=points)

    @staticmethod
    def from_model(model: Any) -> StringSeriesValues:
        """Build from a swagger model object."""
        points = [StringPointValue.from_model(item) for item in model.values]
        return StringSeriesValues(total=model.totalItemCount, values=points)

    @staticmethod
    def from_proto(data: Any) -> StringSeriesValues:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class StringPointValue:
    """A single string-series point: UTC timestamp, step, and value."""

    timestamp: datetime
    step: float
    value: str

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> StringPointValue:
        """Build from the backend JSON representation (epoch-ms timestamp)."""
        when = datetime.fromtimestamp(data["timestampMillis"] / 1000.0, tz=timezone.utc)
        return StringPointValue(timestamp=when, value=str(data["value"]), step=float(data["step"]))

    @staticmethod
    def from_model(model: Any) -> StringPointValue:
        """Build from a swagger model object."""
        when = datetime.fromtimestamp(model.timestampMillis / 1000.0, tz=timezone.utc)
        return StringPointValue(timestamp=when, value=model.value, step=model.step)

    @staticmethod
    def from_proto(data: Any) -> StringPointValue:
        """Not available over protobuf."""
        raise NotImplementedError()
@dataclass
class ImageSeriesValues:
    """Image series metadata: only the total number of items."""

    total: int

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> ImageSeriesValues:
        """Build from the backend JSON representation."""
        return ImageSeriesValues(total=data["totalItemCount"])

    @staticmethod
    def from_model(model: Any) -> ImageSeriesValues:
        """Build from a swagger model object."""
        return ImageSeriesValues(total=model.totalItemCount)

    @staticmethod
    def from_proto(data: Any) -> ImageSeriesValues:
        """Not available over protobuf."""
        raise NotImplementedError()
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ("paginate_over",)
import abc
import itertools
from dataclasses import dataclass
from typing import (
Any,
Callable,
Iterator,
List,
Optional,
TypeVar,
)
from typing_extensions import Protocol
from neptune.api.models import NextPage
@dataclass
class WithPagination(abc.ABC):
    """Base for paginated API responses: carries the cursor to the next page
    (None when there are no further pages)."""

    next_page: Optional[NextPage]
# Paginated response type consumed by `paginate_over`.
T = TypeVar("T", bound=WithPagination)
# Type of a single entry extracted from one paginated response.
Entry = TypeVar("Entry")
class Paginatable(Protocol):
    """Structural type of a page-fetching callable accepted by `paginate_over`."""

    def __call__(self, *, next_page: Optional[NextPage] = None, **kwargs: Any) -> Any: ...
def paginate_over(
    getter: Paginatable,
    extract_entries: Callable[[T], List[Entry]],
    page_size: int = 50,
    limit: Optional[int] = None,
    **kwargs: Any,
) -> Iterator[Entry]:
    """
    Generic approach to pagination via `NextPage`.

    Args:
        getter: Performs one page request; must accept a ``next_page`` keyword
            and return an object exposing a ``next_page`` attribute.
        extract_entries: Extracts the list of entries from one response.
        page_size: Number of entries requested per backend call.
        limit: Optional cap on the total number of entries yielded.
        **kwargs: Forwarded to ``getter`` on every call.

    Yields:
        Entries in backend order, at most ``limit`` in total.
    """
    counter = 0
    data = getter(**kwargs, next_page=NextPage(limit=page_size, next_page_token=None))
    results = extract_entries(data)

    if limit is not None:
        # min() instead of len(results[:limit]): counts without copying a slice.
        counter = min(len(results), limit)
    yield from itertools.islice(results, limit)

    while data.next_page is not None and data.next_page.next_page_token is not None:
        to_fetch = page_size
        if limit is not None:
            if counter >= limit:
                break
            to_fetch = min(page_size, limit - counter)

        data = getter(**kwargs, next_page=NextPage(limit=to_fetch, next_page_token=data.next_page.next_page_token))
        results = extract_entries(data)
        if limit is not None:
            counter += min(len(results), to_fetch)
        # Yield at most `to_fetch` entries even if the backend over-returns.
        yield from itertools.islice(results, to_fetch)
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: neptune_pb/api/model/attributes.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
# NOTE(review): protoc-generated module body — code kept byte-identical;
# comments only.
_sym_db = _symbol_database.Default()
from neptune.api.proto.neptune_pb.api.model import leaderboard_entries_pb2 as neptune__pb_dot_api_dot_model_dot_leaderboard__entries__pb2
# Serialized FileDescriptorProto for neptune_pb/api/model/attributes.proto.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%neptune_pb/api/model/attributes.proto\x12\x11neptune.api.model\x1a.neptune_pb/api/model/leaderboard_entries.proto\"a\n\x1eProtoAttributesSearchResultDTO\x12?\n\x07\x65ntries\x18\x01 \x03(\x0b\x32..neptune.api.model.ProtoAttributeDefinitionDTO\"9\n\x1bProtoAttributeDefinitionDTO\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"\xa3\x01\n\x1dProtoQueryAttributesResultDTO\x12K\n\x07\x65ntries\x18\x01 \x03(\x0b\x32:.neptune.api.model.ProtoQueryAttributesExperimentResultDTO\x12\x35\n\x08nextPage\x18\x02 \x01(\x0b\x32#.neptune.api.model.ProtoNextPageDTO\"^\n\x10ProtoNextPageDTO\x12\x1a\n\rnextPageToken\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x12\n\x05limit\x18\x02 \x01(\rH\x01\x88\x01\x01\x42\x10\n\x0e_nextPageTokenB\x08\n\x06_limit\"\x94\x01\n\'ProtoQueryAttributesExperimentResultDTO\x12\x14\n\x0c\x65xperimentId\x18\x01 \x01(\t\x12\x19\n\x11\x65xperimentShortId\x18\x02 \x01(\t\x12\x38\n\nattributes\x18\x03 \x03(\x0b\x32$.neptune.api.model.ProtoAttributeDTOB4\n0ml.neptune.leaderboard.api.model.proto.generatedP\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'neptune.api.proto.neptune_pb.api.model.attributes_pb2', _globals)
# With pure-Python descriptors, record the byte offsets of each message
# inside the serialized file so the runtime can locate them.
if _descriptor._USE_C_DESCRIPTORS == False:
    _globals['DESCRIPTOR']._options = None
    _globals['DESCRIPTOR']._serialized_options = b'\n0ml.neptune.leaderboard.api.model.proto.generatedP\001'
    _globals['_PROTOATTRIBUTESSEARCHRESULTDTO']._serialized_start=108
    _globals['_PROTOATTRIBUTESSEARCHRESULTDTO']._serialized_end=205
    _globals['_PROTOATTRIBUTEDEFINITIONDTO']._serialized_start=207
    _globals['_PROTOATTRIBUTEDEFINITIONDTO']._serialized_end=264
    _globals['_PROTOQUERYATTRIBUTESRESULTDTO']._serialized_start=267
    _globals['_PROTOQUERYATTRIBUTESRESULTDTO']._serialized_end=430
    _globals['_PROTONEXTPAGEDTO']._serialized_start=432
    _globals['_PROTONEXTPAGEDTO']._serialized_end=526
    _globals['_PROTOQUERYATTRIBUTESEXPERIMENTRESULTDTO']._serialized_start=529
    _globals['_PROTOQUERYATTRIBUTESEXPERIMENTRESULTDTO']._serialized_end=677
# @@protoc_insertion_point(module_scope)
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2
import sys
import typing
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

# NOTE(review): mypy-protobuf generated stub — code kept byte-identical.
@typing_extensions.final
class ProtoAttributesSearchResultDTO(google.protobuf.message.Message):
    """Repeated ProtoAttributeDefinitionDTO entries."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    ENTRIES_FIELD_NUMBER: builtins.int
    @property
    def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributeDefinitionDTO]: ...
    def __init__(
        self,
        *,
        entries: collections.abc.Iterable[global___ProtoAttributeDefinitionDTO] | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["entries", b"entries"]) -> None: ...

global___ProtoAttributesSearchResultDTO = ProtoAttributesSearchResultDTO
# NOTE(review): mypy-protobuf generated stub — code kept byte-identical.
@typing_extensions.final
class ProtoAttributeDefinitionDTO(google.protobuf.message.Message):
    """Attribute definition: name plus type discriminator (both strings)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    NAME_FIELD_NUMBER: builtins.int
    TYPE_FIELD_NUMBER: builtins.int
    name: builtins.str
    type: builtins.str
    def __init__(
        self,
        *,
        name: builtins.str = ...,
        type: builtins.str = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["name", b"name", "type", b"type"]) -> None: ...

global___ProtoAttributeDefinitionDTO = ProtoAttributeDefinitionDTO
# NOTE(review): mypy-protobuf generated stub — code kept byte-identical.
@typing_extensions.final
class ProtoQueryAttributesResultDTO(google.protobuf.message.Message):
    """One page of per-experiment attribute results plus the next-page cursor."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    ENTRIES_FIELD_NUMBER: builtins.int
    NEXTPAGE_FIELD_NUMBER: builtins.int
    @property
    def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoQueryAttributesExperimentResultDTO]: ...
    @property
    def nextPage(self) -> global___ProtoNextPageDTO: ...
    def __init__(
        self,
        *,
        entries: collections.abc.Iterable[global___ProtoQueryAttributesExperimentResultDTO] | None = ...,
        nextPage: global___ProtoNextPageDTO | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing_extensions.Literal["nextPage", b"nextPage"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing_extensions.Literal["entries", b"entries", "nextPage", b"nextPage"]) -> None: ...

global___ProtoQueryAttributesResultDTO = ProtoQueryAttributesResultDTO
# NOTE(review): mypy-protobuf generated stub — code kept byte-identical.
@typing_extensions.final
class ProtoNextPageDTO(google.protobuf.message.Message):
    """Pagination cursor: optional nextPageToken and optional limit."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    NEXTPAGETOKEN_FIELD_NUMBER: builtins.int
    LIMIT_FIELD_NUMBER: builtins.int
    nextPageToken: builtins.str
    limit: builtins.int
    def __init__(
        self,
        *,
        nextPageToken: builtins.str | None = ...,
        limit: builtins.int | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing_extensions.Literal["_limit", b"_limit", "_nextPageToken", b"_nextPageToken", "limit", b"limit", "nextPageToken", b"nextPageToken"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing_extensions.Literal["_limit", b"_limit", "_nextPageToken", b"_nextPageToken", "limit", b"limit", "nextPageToken", b"nextPageToken"]) -> None: ...
    @typing.overload
    def WhichOneof(self, oneof_group: typing_extensions.Literal["_limit", b"_limit"]) -> typing_extensions.Literal["limit"] | None: ...
    @typing.overload
    def WhichOneof(self, oneof_group: typing_extensions.Literal["_nextPageToken", b"_nextPageToken"]) -> typing_extensions.Literal["nextPageToken"] | None: ...

global___ProtoNextPageDTO = ProtoNextPageDTO
# NOTE(review): mypy-protobuf generated stub — code kept byte-identical.
@typing_extensions.final
class ProtoQueryAttributesExperimentResultDTO(google.protobuf.message.Message):
    """Attributes of one experiment: ids plus repeated ProtoAttributeDTO."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    EXPERIMENTID_FIELD_NUMBER: builtins.int
    EXPERIMENTSHORTID_FIELD_NUMBER: builtins.int
    ATTRIBUTES_FIELD_NUMBER: builtins.int
    experimentId: builtins.str
    experimentShortId: builtins.str
    @property
    def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2.ProtoAttributeDTO]: ...
    def __init__(
        self,
        *,
        experimentId: builtins.str = ...,
        experimentShortId: builtins.str = ...,
        attributes: collections.abc.Iterable[neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2.ProtoAttributeDTO] | None = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing_extensions.Literal["attributes", b"attributes", "experimentId", b"experimentId", "experimentShortId", b"experimentShortId"]) -> None: ...

global___ProtoQueryAttributesExperimentResultDTO = ProtoQueryAttributesExperimentResultDTO
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: neptune_pb/api/model/leaderboard_entries.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n.neptune_pb/api/model/leaderboard_entries.proto\x12\x11neptune.api.model\"\xb3\x01\n&ProtoLeaderboardEntriesSearchResultDTO\x12\x1b\n\x13matching_item_count\x18\x01 \x01(\x03\x12\x1e\n\x11total_group_count\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x36\n\x07\x65ntries\x18\x03 \x03(\x0b\x32%.neptune.api.model.ProtoAttributesDTOB\x14\n\x12_total_group_count\"\xd1\x01\n\x12ProtoAttributesDTO\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x12\n\nproject_id\x18\x03 \x01(\t\x12\x17\n\x0forganization_id\x18\x04 \x01(\t\x12\x14\n\x0cproject_name\x18\x05 \x01(\t\x12\x19\n\x11organization_name\x18\x06 \x01(\t\x12\x38\n\nattributes\x18\x07 \x03(\x0b\x32$.neptune.api.model.ProtoAttributeDTO\"\xed\x05\n\x11ProtoAttributeDTO\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x44\n\x0eint_properties\x18\x03 \x01(\x0b\x32\'.neptune.api.model.ProtoIntAttributeDTOH\x00\x88\x01\x01\x12H\n\x10\x66loat_properties\x18\x04 \x01(\x0b\x32).neptune.api.model.ProtoFloatAttributeDTOH\x01\x88\x01\x01\x12J\n\x11string_properties\x18\x05 \x01(\x0b\x32*.neptune.api.model.ProtoStringAttributeDTOH\x02\x88\x01\x01\x12\x46\n\x0f\x62ool_properties\x18\x06 \x01(\x0b\x32(.neptune.api.model.ProtoBoolAttributeDTOH\x03\x88\x01\x01\x12N\n\x13\x64\x61tetime_properties\x18\x07 \x01(\x0b\x32,.neptune.api.model.ProtoDatetimeAttributeDTOH\x04\x88\x01\x01\x12Q\n\x15string_set_properties\x18\x08 \x01(\x0b\x32-.neptune.api.model.ProtoStringSetAttributeDTOH\x05\x88\x01\x01\x12U\n\x17\x66loat_series_properties\x18\t \x01(\x0b\x32/.neptune.api.model.ProtoFloatSeriesAttributeDTOH\x06\x88\x01\x01\x42\x11\n\x0f_int_propertiesB\x13\n\x11_float_propertiesB\x14\n\x12_string_propertiesB\x12\n\x10_bool_propertiesB\x16\n\x14_datetime_propertiesB\x18\n\x16_string_set_propertiesB\x1a\n\x18_float_series_properties\"U\n\x14ProtoIntAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 
\x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03\"W\n\x16ProtoFloatAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x01\"X\n\x17ProtoStringAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\"V\n\x15ProtoBoolAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x08\"Z\n\x19ProtoDatetimeAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03\"[\n\x1aProtoStringSetAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x03(\t\"\xd1\x02\n\x1cProtoFloatSeriesAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\x16\n\tlast_step\x18\x03 \x01(\x01H\x00\x88\x01\x01\x12\x11\n\x04last\x18\x04 \x01(\x01H\x01\x88\x01\x01\x12\x10\n\x03min\x18\x05 \x01(\x01H\x02\x88\x01\x01\x12\x10\n\x03max\x18\x06 \x01(\x01H\x03\x88\x01\x01\x12\x14\n\x07\x61verage\x18\x07 \x01(\x01H\x04\x88\x01\x01\x12\x15\n\x08variance\x18\x08 \x01(\x01H\x05\x88\x01\x01\x12\x45\n\x06\x63onfig\x18\t \x01(\x0b\x32\x35.neptune.api.model.ProtoFloatSeriesAttributeConfigDTOB\x0c\n\n_last_stepB\x07\n\x05_lastB\x06\n\x04_minB\x06\n\x04_maxB\n\n\x08_averageB\x0b\n\t_variance\"t\n\"ProtoFloatSeriesAttributeConfigDTO\x12\x10\n\x03min\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x10\n\x03max\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x11\n\x04unit\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x06\n\x04_minB\x06\n\x04_maxB\x07\n\x05_unitB4\n0ml.neptune.leaderboard.api.model.proto.generatedP\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\n0ml.neptune.leaderboard.api.model.proto.generatedP\001'
_globals['_PROTOLEADERBOARDENTRIESSEARCHRESULTDTO']._serialized_start=70
_globals['_PROTOLEADERBOARDENTRIESSEARCHRESULTDTO']._serialized_end=249
_globals['_PROTOATTRIBUTESDTO']._serialized_start=252
_globals['_PROTOATTRIBUTESDTO']._serialized_end=461
_globals['_PROTOATTRIBUTEDTO']._serialized_start=464
_globals['_PROTOATTRIBUTEDTO']._serialized_end=1213
_globals['_PROTOINTATTRIBUTEDTO']._serialized_start=1215
_globals['_PROTOINTATTRIBUTEDTO']._serialized_end=1300
_globals['_PROTOFLOATATTRIBUTEDTO']._serialized_start=1302
_globals['_PROTOFLOATATTRIBUTEDTO']._serialized_end=1389
_globals['_PROTOSTRINGATTRIBUTEDTO']._serialized_start=1391
_globals['_PROTOSTRINGATTRIBUTEDTO']._serialized_end=1479
_globals['_PROTOBOOLATTRIBUTEDTO']._serialized_start=1481
_globals['_PROTOBOOLATTRIBUTEDTO']._serialized_end=1567
_globals['_PROTODATETIMEATTRIBUTEDTO']._serialized_start=1569
_globals['_PROTODATETIMEATTRIBUTEDTO']._serialized_end=1659
_globals['_PROTOSTRINGSETATTRIBUTEDTO']._serialized_start=1661
_globals['_PROTOSTRINGSETATTRIBUTEDTO']._serialized_end=1752
_globals['_PROTOFLOATSERIESATTRIBUTEDTO']._serialized_start=1755
_globals['_PROTOFLOATSERIESATTRIBUTEDTO']._serialized_end=2092
_globals['_PROTOFLOATSERIESATTRIBUTECONFIGDTO']._serialized_start=2094
_globals['_PROTOFLOATSERIESATTRIBUTECONFIGDTO']._serialized_end=2210
# @@protoc_insertion_point(module_scope)
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing_extensions.final
class ProtoLeaderboardEntriesSearchResultDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MATCHING_ITEM_COUNT_FIELD_NUMBER: builtins.int
TOTAL_GROUP_COUNT_FIELD_NUMBER: builtins.int
ENTRIES_FIELD_NUMBER: builtins.int
matching_item_count: builtins.int
total_group_count: builtins.int
@property
def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributesDTO]: ...
def __init__(
self,
*,
matching_item_count: builtins.int = ...,
total_group_count: builtins.int | None = ...,
entries: collections.abc.Iterable[global___ProtoAttributesDTO] | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_total_group_count", b"_total_group_count", "total_group_count", b"total_group_count"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_total_group_count", b"_total_group_count", "entries", b"entries", "matching_item_count", b"matching_item_count", "total_group_count", b"total_group_count"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions.Literal["_total_group_count", b"_total_group_count"]) -> typing_extensions.Literal["total_group_count"] | None: ...
global___ProtoLeaderboardEntriesSearchResultDTO = ProtoLeaderboardEntriesSearchResultDTO
@typing_extensions.final
class ProtoAttributesDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
EXPERIMENT_ID_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
PROJECT_ID_FIELD_NUMBER: builtins.int
ORGANIZATION_ID_FIELD_NUMBER: builtins.int
PROJECT_NAME_FIELD_NUMBER: builtins.int
ORGANIZATION_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTES_FIELD_NUMBER: builtins.int
experiment_id: builtins.str
type: builtins.str
project_id: builtins.str
organization_id: builtins.str
project_name: builtins.str
organization_name: builtins.str
@property
def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributeDTO]: ...
def __init__(
self,
*,
experiment_id: builtins.str = ...,
type: builtins.str = ...,
project_id: builtins.str = ...,
organization_id: builtins.str = ...,
project_name: builtins.str = ...,
organization_name: builtins.str = ...,
attributes: collections.abc.Iterable[global___ProtoAttributeDTO] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attributes", b"attributes", "experiment_id", b"experiment_id", "organization_id", b"organization_id", "organization_name", b"organization_name", "project_id", b"project_id", "project_name", b"project_name", "type", b"type"]) -> None: ...
global___ProtoAttributesDTO = ProtoAttributesDTO
@typing_extensions.final
class ProtoAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NAME_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
INT_PROPERTIES_FIELD_NUMBER: builtins.int
FLOAT_PROPERTIES_FIELD_NUMBER: builtins.int
STRING_PROPERTIES_FIELD_NUMBER: builtins.int
BOOL_PROPERTIES_FIELD_NUMBER: builtins.int
DATETIME_PROPERTIES_FIELD_NUMBER: builtins.int
STRING_SET_PROPERTIES_FIELD_NUMBER: builtins.int
FLOAT_SERIES_PROPERTIES_FIELD_NUMBER: builtins.int
name: builtins.str
type: builtins.str
@property
def int_properties(self) -> global___ProtoIntAttributeDTO: ...
@property
def float_properties(self) -> global___ProtoFloatAttributeDTO: ...
@property
def string_properties(self) -> global___ProtoStringAttributeDTO: ...
@property
def bool_properties(self) -> global___ProtoBoolAttributeDTO: ...
@property
def datetime_properties(self) -> global___ProtoDatetimeAttributeDTO: ...
@property
def string_set_properties(self) -> global___ProtoStringSetAttributeDTO: ...
@property
def float_series_properties(self) -> global___ProtoFloatSeriesAttributeDTO: ...
def __init__(
self,
*,
name: builtins.str = ...,
type: builtins.str = ...,
int_properties: global___ProtoIntAttributeDTO | None = ...,
float_properties: global___ProtoFloatAttributeDTO | None = ...,
string_properties: global___ProtoStringAttributeDTO | None = ...,
bool_properties: global___ProtoBoolAttributeDTO | None = ...,
datetime_properties: global___ProtoDatetimeAttributeDTO | None = ...,
string_set_properties: global___ProtoStringSetAttributeDTO | None = ...,
float_series_properties: global___ProtoFloatSeriesAttributeDTO | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_bool_properties", b"_bool_properties", "_datetime_properties", b"_datetime_properties", "_float_properties", b"_float_properties", "_float_series_properties", b"_float_series_properties", "_int_properties", b"_int_properties", "_string_properties", b"_string_properties", "_string_set_properties", b"_string_set_properties", "bool_properties", b"bool_properties", "datetime_properties", b"datetime_properties", "float_properties", b"float_properties", "float_series_properties", b"float_series_properties", "int_properties", b"int_properties", "string_properties", b"string_properties", "string_set_properties", b"string_set_properties"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_bool_properties", b"_bool_properties", "_datetime_properties", b"_datetime_properties", "_float_properties", b"_float_properties", "_float_series_properties", b"_float_series_properties", "_int_properties", b"_int_properties", "_string_properties", b"_string_properties", "_string_set_properties", b"_string_set_properties", "bool_properties", b"bool_properties", "datetime_properties", b"datetime_properties", "float_properties", b"float_properties", "float_series_properties", b"float_series_properties", "int_properties", b"int_properties", "name", b"name", "string_properties", b"string_properties", "string_set_properties", b"string_set_properties", "type", b"type"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_bool_properties", b"_bool_properties"]) -> typing_extensions.Literal["bool_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_datetime_properties", b"_datetime_properties"]) -> typing_extensions.Literal["datetime_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_float_properties", b"_float_properties"]) -> typing_extensions.Literal["float_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_float_series_properties", b"_float_series_properties"]) -> typing_extensions.Literal["float_series_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_int_properties", b"_int_properties"]) -> typing_extensions.Literal["int_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_string_properties", b"_string_properties"]) -> typing_extensions.Literal["string_properties"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_string_set_properties", b"_string_set_properties"]) -> typing_extensions.Literal["string_set_properties"] | None: ...
global___ProtoAttributeDTO = ProtoAttributeDTO
@typing_extensions.final
class ProtoIntAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
value: builtins.int
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoIntAttributeDTO = ProtoIntAttributeDTO
@typing_extensions.final
class ProtoFloatAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
value: builtins.float
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: builtins.float = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoFloatAttributeDTO = ProtoFloatAttributeDTO
@typing_extensions.final
class ProtoStringAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
value: builtins.str
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoStringAttributeDTO = ProtoStringAttributeDTO
@typing_extensions.final
class ProtoBoolAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
value: builtins.bool
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: builtins.bool = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoBoolAttributeDTO = ProtoBoolAttributeDTO
@typing_extensions.final
class ProtoDatetimeAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
value: builtins.int
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoDatetimeAttributeDTO = ProtoDatetimeAttributeDTO
@typing_extensions.final
class ProtoStringSetAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
@property
def value(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
value: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ...
global___ProtoStringSetAttributeDTO = ProtoStringSetAttributeDTO
@typing_extensions.final
class ProtoFloatSeriesAttributeDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int
ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int
LAST_STEP_FIELD_NUMBER: builtins.int
LAST_FIELD_NUMBER: builtins.int
MIN_FIELD_NUMBER: builtins.int
MAX_FIELD_NUMBER: builtins.int
AVERAGE_FIELD_NUMBER: builtins.int
VARIANCE_FIELD_NUMBER: builtins.int
CONFIG_FIELD_NUMBER: builtins.int
attribute_name: builtins.str
attribute_type: builtins.str
last_step: builtins.float
last: builtins.float
min: builtins.float
max: builtins.float
average: builtins.float
variance: builtins.float
@property
def config(self) -> global___ProtoFloatSeriesAttributeConfigDTO: ...
def __init__(
self,
*,
attribute_name: builtins.str = ...,
attribute_type: builtins.str = ...,
last_step: builtins.float | None = ...,
last: builtins.float | None = ...,
min: builtins.float | None = ...,
max: builtins.float | None = ...,
average: builtins.float | None = ...,
variance: builtins.float | None = ...,
config: global___ProtoFloatSeriesAttributeConfigDTO | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_average", b"_average", "_last", b"_last", "_last_step", b"_last_step", "_max", b"_max", "_min", b"_min", "_variance", b"_variance", "average", b"average", "config", b"config", "last", b"last", "last_step", b"last_step", "max", b"max", "min", b"min", "variance", b"variance"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_average", b"_average", "_last", b"_last", "_last_step", b"_last_step", "_max", b"_max", "_min", b"_min", "_variance", b"_variance", "attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "average", b"average", "config", b"config", "last", b"last", "last_step", b"last_step", "max", b"max", "min", b"min", "variance", b"variance"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_average", b"_average"]) -> typing_extensions.Literal["average"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_last", b"_last"]) -> typing_extensions.Literal["last"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_last_step", b"_last_step"]) -> typing_extensions.Literal["last_step"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_max", b"_max"]) -> typing_extensions.Literal["max"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_min", b"_min"]) -> typing_extensions.Literal["min"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_variance", b"_variance"]) -> typing_extensions.Literal["variance"] | None: ...
global___ProtoFloatSeriesAttributeDTO = ProtoFloatSeriesAttributeDTO
@typing_extensions.final
class ProtoFloatSeriesAttributeConfigDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MIN_FIELD_NUMBER: builtins.int
MAX_FIELD_NUMBER: builtins.int
UNIT_FIELD_NUMBER: builtins.int
min: builtins.float
max: builtins.float
unit: builtins.str
def __init__(
self,
*,
min: builtins.float | None = ...,
max: builtins.float | None = ...,
unit: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["_max", b"_max", "_min", b"_min", "_unit", b"_unit", "max", b"max", "min", b"min", "unit", b"unit"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["_max", b"_max", "_min", b"_min", "_unit", b"_unit", "max", b"max", "min", b"min", "unit", b"unit"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_max", b"_max"]) -> typing_extensions.Literal["max"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_min", b"_min"]) -> typing_extensions.Literal["min"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing_extensions.Literal["_unit", b"_unit"]) -> typing_extensions.Literal["unit"] | None: ...
global___ProtoFloatSeriesAttributeConfigDTO = ProtoFloatSeriesAttributeConfigDTO
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: neptune_pb/api/model/series_values.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(neptune_pb/api/model/series_values.proto\x12\x11neptune.api.model\"q\n\x19ProtoFloatSeriesValuesDTO\x12\x18\n\x10total_item_count\x18\x01 \x01(\x03\x12:\n\x06values\x18\x02 \x03(\x0b\x32*.neptune.api.model.ProtoFloatPointValueDTO\"P\n\x17ProtoFloatPointValueDTO\x12\x18\n\x10timestamp_millis\x18\x01 \x01(\x03\x12\x0c\n\x04step\x18\x02 \x01(\x01\x12\r\n\x05value\x18\x03 \x01(\x01\x42\x34\n0ml.neptune.leaderboard.api.model.proto.generatedP\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'neptune.api.proto.neptune_pb.api.model.series_values_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\n0ml.neptune.leaderboard.api.model.proto.generatedP\001'
_globals['_PROTOFLOATSERIESVALUESDTO']._serialized_start=63
_globals['_PROTOFLOATSERIESVALUESDTO']._serialized_end=176
_globals['_PROTOFLOATPOINTVALUEDTO']._serialized_start=178
_globals['_PROTOFLOATPOINTVALUEDTO']._serialized_end=258
# @@protoc_insertion_point(module_scope)
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import sys
if sys.version_info >= (3, 8):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing_extensions.final
class ProtoFloatSeriesValuesDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOTAL_ITEM_COUNT_FIELD_NUMBER: builtins.int
VALUES_FIELD_NUMBER: builtins.int
total_item_count: builtins.int
@property
def values(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoFloatPointValueDTO]: ...
def __init__(
self,
*,
total_item_count: builtins.int = ...,
values: collections.abc.Iterable[global___ProtoFloatPointValueDTO] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["total_item_count", b"total_item_count", "values", b"values"]) -> None: ...
global___ProtoFloatSeriesValuesDTO = ProtoFloatSeriesValuesDTO
@typing_extensions.final
class ProtoFloatPointValueDTO(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TIMESTAMP_MILLIS_FIELD_NUMBER: builtins.int
STEP_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
timestamp_millis: builtins.int
step: builtins.float
value: builtins.float
def __init__(
self,
*,
timestamp_millis: builtins.int = ...,
step: builtins.float = ...,
value: builtins.float = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["step", b"step", "timestamp_millis", b"timestamp_millis", "value", b"value"]) -> None: ...
global___ProtoFloatPointValueDTO = ProtoFloatPointValueDTO
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Names of the environment variables the Neptune client reads for its
# runtime configuration. Only the variable *names* are defined here; the
# values are looked up from the process environment elsewhere.
API_TOKEN_ENV_NAME = "NEPTUNE_API_TOKEN"  # user's API token
NEPTUNE_RETRIES_TIMEOUT_ENV = "NEPTUNE_RETRIES_TIMEOUT"  # request-retry timeout override
PROJECT_ENV_NAME = "NEPTUNE_PROJECT"  # default project identifier
NOTEBOOK_ID_ENV_NAME = "NEPTUNE_NOTEBOOK_ID"  # notebook id, if running in a notebook
NOTEBOOK_PATH_ENV_NAME = "NEPTUNE_NOTEBOOK_PATH"  # notebook path, if running in a notebook
BACKEND = "NEPTUNE_BACKEND"  # backend selector override
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
class CGroupFilesystemReader(object):
    """Read memory and CPU control files of the current cgroup (v1).

    On construction, the mount points of the ``memory``, ``cpu`` and
    ``cpuacct`` cgroup subsystems are located via ``/proc/mounts``; the
    getters then read the corresponding control files as integers.

    :raises RuntimeError: if a required cgroup subsystem mount is not found.
    """

    def __init__(self):
        cgroup_memory_dir = self.__cgroup_mount_dir(subsystem="memory")
        self.__memory_usage_file = os.path.join(cgroup_memory_dir, "memory.usage_in_bytes")
        self.__memory_limit_file = os.path.join(cgroup_memory_dir, "memory.limit_in_bytes")

        cgroup_cpu_dir = self.__cgroup_mount_dir(subsystem="cpu")
        self.__cpu_period_file = os.path.join(cgroup_cpu_dir, "cpu.cfs_period_us")
        self.__cpu_quota_file = os.path.join(cgroup_cpu_dir, "cpu.cfs_quota_us")

        cgroup_cpuacct_dir = self.__cgroup_mount_dir(subsystem="cpuacct")
        self.__cpuacct_usage_file = os.path.join(cgroup_cpuacct_dir, "cpuacct.usage")

    def get_memory_usage_in_bytes(self):
        """Return current memory usage of the cgroup, in bytes."""
        return self.__read_int_file(self.__memory_usage_file)

    def get_memory_limit_in_bytes(self):
        """Return the cgroup memory limit, in bytes."""
        return self.__read_int_file(self.__memory_limit_file)

    def get_cpu_quota_micros(self):
        """Return the CFS CPU quota in microseconds (-1 means no quota)."""
        return self.__read_int_file(self.__cpu_quota_file)

    def get_cpu_period_micros(self):
        """Return the CFS CPU period in microseconds."""
        return self.__read_int_file(self.__cpu_period_file)

    def get_cpuacct_usage_nanos(self):
        """Return cumulative CPU usage of the cgroup, in nanoseconds."""
        return self.__read_int_file(self.__cpuacct_usage_file)

    def __read_int_file(self, filename):
        # Control files hold a single integer (possibly with trailing newline).
        with open(filename) as f:
            return int(f.read())

    def __cgroup_mount_dir(self, subsystem):
        """
        :param subsystem: cgroup subsystem like memory, cpu
        :return: directory where given subsystem is mounted
        :raises RuntimeError: if no cgroup mount for the subsystem is found
        """
        with open("/proc/mounts", "r") as f:
            for line in f:
                mount_dir = re.split(r"\s+", line)[1]
                if "cgroup" in mount_dir:
                    # The last path component of a cgroup v1 mount names the
                    # co-mounted subsystems, e.g. ".../cgroup/cpu,cpuacct".
                    subsystems = mount_dir.split("/")[-1].split(",")
                    if subsystem in subsystems:
                        return mount_dir
        # Was `assert False, ...` - asserts are stripped under `python -O`,
        # which made this method silently return None and crash later inside
        # os.path.join with an unrelated TypeError. Raise explicitly instead.
        raise RuntimeError('Mount directory for "{}" subsystem not found'.format(subsystem))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from neptune.internal.hardware.cgroup.cgroup_filesystem_reader import CGroupFilesystemReader
from neptune.internal.hardware.system.system_monitor import SystemMonitor
class CGroupMonitor(object):
    """Expose memory and CPU metrics of the current cgroup.

    CPU usage is reported as a rate between successive calls to
    :meth:`get_cpu_usage_percentage`, so the very first call returns 0.0.
    """

    def __init__(self, cgroup_filesystem_reader, system_monitor):
        self.__cgroup_filesystem_reader = cgroup_filesystem_reader
        self.__system_monitor = system_monitor
        # Snapshot of the previous CPU measurement; None until the first call.
        self.__last_cpu_usage_measurement_timestamp_nanos = None
        self.__last_cpu_cumulative_usage_nanos = None

    @staticmethod
    def create():
        """Build a monitor wired to the real cgroup filesystem and system stats."""
        return CGroupMonitor(CGroupFilesystemReader(), SystemMonitor())

    def get_memory_usage_in_bytes(self):
        """Current memory usage of the cgroup, in bytes."""
        return self.__cgroup_filesystem_reader.get_memory_usage_in_bytes()

    def get_memory_limit_in_bytes(self):
        """Effective memory limit: the cgroup limit capped by total system RAM."""
        candidates = (
            self.__cgroup_filesystem_reader.get_memory_limit_in_bytes(),
            self.__system_monitor.virtual_memory().total,
        )
        return min(candidates)

    def get_cpu_usage_limit_in_cores(self):
        """CPU limit expressed in cores; without a quota, the machine's core count."""
        quota_micros = self.__cgroup_filesystem_reader.get_cpu_quota_micros()
        if quota_micros == -1:
            # -1 means "no quota configured" - the whole machine is available.
            return float(self.__system_monitor.cpu_count())
        period_micros = self.__cgroup_filesystem_reader.get_cpu_period_micros()
        return float(quota_micros) / float(period_micros)

    def get_cpu_usage_percentage(self):
        """CPU usage since the previous call, as a percentage of the CPU limit."""
        now_nanos = time.time() * 10**9
        cumulative_nanos = self.__cgroup_filesystem_reader.get_cpuacct_usage_nanos()

        if self.__first_measurement():
            usage_percent = 0.0
        else:
            usage_delta = cumulative_nanos - self.__last_cpu_cumulative_usage_nanos
            time_delta = now_nanos - self.__last_cpu_usage_measurement_timestamp_nanos
            usage_percent = float(usage_delta) / float(time_delta) / self.get_cpu_usage_limit_in_cores() * 100.0

        self.__last_cpu_usage_measurement_timestamp_nanos = now_nanos
        self.__last_cpu_cumulative_usage_nanos = cumulative_nanos

        # cgroup cpu usage may slightly exceed the given limit, but we don't want to show it
        return self.__clamp(usage_percent, lower_limit=0.0, upper_limit=100.0)

    def __first_measurement(self):
        no_timestamp = self.__last_cpu_usage_measurement_timestamp_nanos is None
        no_usage = self.__last_cpu_cumulative_usage_nanos is None
        return no_timestamp or no_usage

    @staticmethod
    def __clamp(value, lower_limit, upper_limit):
        # Equivalent to max(lower_limit, min(value, upper_limit)).
        return min(max(value, lower_limit), upper_limit)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["BYTES_IN_ONE_MB", "BYTES_IN_ONE_GB"]
# Binary (IEC) unit multipliers used when converting raw byte counts
# for hardware metric reporting.
BYTES_IN_ONE_MB = 2**20  # 1 MiB
BYTES_IN_ONE_GB = 2**30  # 1 GiB
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.internal.hardware.gauges.gauge import Gauge
from neptune.internal.hardware.system.system_monitor import SystemMonitor
class SystemCpuUsageGauge(Gauge):
    """Gauge reporting whole-machine CPU utilization as a percentage."""

    def __init__(self):
        self.__monitor = SystemMonitor()

    def name(self):
        return "cpu"

    def value(self):
        return self.__monitor.cpu_percent()

    def __eq__(self, other):
        # The gauge carries no configuration, so class identity is enough.
        return type(self) == type(other)

    def __repr__(self):
        return "SystemCpuUsageGauge"
class CGroupCpuUsageGauge(Gauge):
    """Gauge reporting CPU utilization of the current cgroup as a percentage."""

    def __init__(self):
        self.__monitor = CGroupMonitor.create()

    def name(self):
        return "cpu"

    def value(self):
        return self.__monitor.get_cpu_usage_percentage()

    def __eq__(self, other):
        # The gauge carries no configuration, so class identity is enough.
        return type(self) == type(other)

    def __repr__(self):
        return "CGroupCpuUsageGauge"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.gauges.cpu import (
CGroupCpuUsageGauge,
SystemCpuUsageGauge,
)
from neptune.internal.hardware.gauges.gauge_mode import GaugeMode
from neptune.internal.hardware.gauges.gpu import (
GpuMemoryGauge,
GpuUsageGauge,
)
from neptune.internal.hardware.gauges.memory import (
CGroupMemoryUsageGauge,
SystemMemoryUsageGauge,
)
class GaugeFactory(object):
    """Creates the CPU/memory/GPU gauges matching the configured gauge mode."""

    def __init__(self, gauge_mode):
        self.__gauge_mode = gauge_mode

    def create_cpu_usage_gauge(self):
        mode = self.__gauge_mode
        if mode == GaugeMode.SYSTEM:
            return SystemCpuUsageGauge()
        if mode == GaugeMode.CGROUP:
            return CGroupCpuUsageGauge()
        raise self.__invalid_gauge_mode_exception()

    def create_memory_usage_gauge(self):
        mode = self.__gauge_mode
        if mode == GaugeMode.SYSTEM:
            return SystemMemoryUsageGauge()
        if mode == GaugeMode.CGROUP:
            return CGroupMemoryUsageGauge()
        raise self.__invalid_gauge_mode_exception()

    @staticmethod
    def create_gpu_usage_gauge(card_index):
        # GPU gauges do not depend on the gauge mode: NVML reports per-card
        # statistics directly.
        return GpuUsageGauge(card_index=card_index)

    @staticmethod
    def create_gpu_memory_gauge(card_index):
        return GpuMemoryGauge(card_index=card_index)

    def __invalid_gauge_mode_exception(self):
        return ValueError("Invalid gauge mode: {}".format(self.__gauge_mode))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class GaugeMode(object):
    """Selects where hardware statistics are read from."""
    # Whole machine (psutil-based readings).
    SYSTEM = "system"
    # Current container only (cgroup filesystem readings).
    CGROUP = "cgroup"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import (
ABCMeta,
abstractmethod,
)
class Gauge(object, metaclass=ABCMeta):
    """Abstract interface for a single hardware metric reading.

    Fix: the original declared ``__metaclass__ = ABCMeta`` inside the class
    body, which is Python 2 syntax and has no effect on Python 3 -- the
    class was never actually abstract and ``@abstractmethod`` was not
    enforced. Declaring the metaclass in the class header restores the
    intended contract; all subclasses in this file implement both methods.
    """

    @abstractmethod
    def name(self):
        """
        :return: Gauge name (str).
        """
        raise NotImplementedError()

    @abstractmethod
    def value(self):
        """
        :return: Current value (float).
        """
        raise NotImplementedError()
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.constants import BYTES_IN_ONE_GB
from neptune.internal.hardware.gauges.gauge import Gauge
from neptune.internal.hardware.gpu.gpu_monitor import GPUMonitor
class GpuUsageGauge(Gauge):
    """Gauge reporting utilization (percent) of one GPU card."""

    def __init__(self, card_index):
        self.card_index = card_index
        self.__monitor = GPUMonitor()

    def name(self):
        # The gauge is identified by the card it tracks.
        return str(self.card_index)

    def value(self):
        return self.__monitor.get_card_usage_percent(self.card_index)

    def __eq__(self, other):
        return type(self) == type(other) and self.card_index == other.card_index

    def __repr__(self):
        return "GpuUsageGauge"
class GpuMemoryGauge(Gauge):
    """Gauge reporting used memory (GB) of one GPU card."""

    def __init__(self, card_index):
        self.card_index = card_index
        self.__monitor = GPUMonitor()

    def name(self):
        # The gauge is identified by the card it tracks.
        return str(self.card_index)

    def value(self):
        used_bytes = self.__monitor.get_card_used_memory_in_bytes(self.card_index)
        return used_bytes / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        return type(self) == type(other) and self.card_index == other.card_index

    def __repr__(self):
        return "GpuMemoryGauge"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.internal.hardware.constants import BYTES_IN_ONE_GB
from neptune.internal.hardware.gauges.gauge import Gauge
from neptune.internal.hardware.system.system_monitor import SystemMonitor
class SystemMemoryUsageGauge(Gauge):
    """Gauge reporting used system RAM in gigabytes."""

    def __init__(self):
        self.__monitor = SystemMonitor()

    def name(self):
        return "ram"

    def value(self):
        mem = self.__monitor.virtual_memory()
        # "Used" is total minus available, matching what users perceive as
        # occupied memory.
        return (mem.total - mem.available) / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        return type(self) == type(other)

    def __repr__(self):
        return "SystemMemoryUsageGauge"
class CGroupMemoryUsageGauge(Gauge):
    """Gauge reporting the current cgroup's memory usage in gigabytes."""

    def __init__(self):
        self.__monitor = CGroupMonitor.create()

    def name(self):
        return "ram"

    def value(self):
        return self.__monitor.get_memory_usage_in_bytes() / float(BYTES_IN_ONE_GB)

    def __eq__(self, other):
        return type(self) == type(other)

    def __repr__(self):
        return "CGroupMemoryUsageGauge"
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["GPUMonitor"]
from neptune.internal.utils.logger import get_logger
from neptune.vendor.pynvml import (
NVMLError,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetUtilizationRates,
nvmlInit,
)
_logger = get_logger()
class GPUMonitor(object):
    """Thin NVML wrapper that degrades gracefully when NVML is unavailable."""

    # Shared across instances so the NVML warning is logged only once per process.
    nvml_error_printed = False

    def get_card_count(self):
        return self.__nvml_get_or_else(nvmlDeviceGetCount, default=0)

    def get_card_usage_percent(self, card_index):
        def read_usage():
            handle = nvmlDeviceGetHandleByIndex(card_index)
            return float(nvmlDeviceGetUtilizationRates(handle).gpu)

        return self.__nvml_get_or_else(read_usage)

    def get_card_used_memory_in_bytes(self, card_index):
        def read_used_memory():
            return nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(card_index)).used

        return self.__nvml_get_or_else(read_used_memory)

    def get_top_card_memory_in_bytes(self):
        """Return the total memory of the largest card, or 0 when none is readable."""

        def read_totals():
            return [
                nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(index)).total
                for index in range(nvmlDeviceGetCount())
            ]

        memory_per_card = self.__nvml_get_or_else(read_totals, default=0)
        if memory_per_card:
            return max(memory_per_card)
        return 0

    def __nvml_get_or_else(self, getter, default=None):
        # Every NVML read re-initializes the library; on failure we log once
        # and fall back to ``default`` so metric collection keeps running.
        try:
            nvmlInit()
            return getter()
        except NVMLError as e:
            if not GPUMonitor.nvml_error_printed:
                warning = (
                    "Info (NVML): %s. GPU usage metrics may not be reported. For more information, "
                    "see https://docs.neptune.ai/help/nvml_error/"
                )
                _logger.warning(warning, e)
                GPUMonitor.nvml_error_printed = True
            return default
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class Metric(object):
    """Definition of one hardware metric series (e.g. CPU usage).

    A metric aggregates one or more gauges that produce its values; the
    ``internal_id`` is assigned by the backend after registration, hence
    its setter.
    """

    def __init__(
        self,
        name,
        description,
        resource_type,
        unit,
        min_value,
        max_value,
        gauges,
        internal_id=None,
    ):
        self.__internal_id = internal_id
        self.__name = name
        self.__description = description
        self.__resource_type = resource_type
        self.__unit = unit
        self.__min_value = min_value
        self.__max_value = max_value
        self.__gauges = gauges

    @property
    def internal_id(self):
        return self.__internal_id

    @internal_id.setter
    def internal_id(self, value):
        self.__internal_id = value

    @property
    def name(self):
        return self.__name

    @property
    def description(self):
        return self.__description

    @property
    def resource_type(self):
        return self.__resource_type

    @property
    def unit(self):
        return self.__unit

    @property
    def min_value(self):
        return self.__min_value

    @property
    def max_value(self):
        return self.__max_value

    @property
    def gauges(self):
        return self.__gauges

    def __repr__(self):
        template = (
            "Metric(internal_id={}, name={}, description={}, resource_type={}, unit={}, min_value={}, "
            "max_value={}, gauges={})"
        )
        return template.format(
            self.internal_id,
            self.name,
            self.description,
            self.resource_type,
            self.unit,
            self.min_value,
            self.max_value,
            self.gauges,
        )

    def __eq__(self, other):
        # Structural comparison: two metrics are equal when all of their
        # fields (as rendered by __repr__) match.
        return self.__class__ == other.__class__ and repr(self) == repr(other)
class MetricResourceType(object):
    """Backend identifiers for the kind of resource a metric measures."""
    CPU = "CPU"
    RAM = "MEMORY"
    GPU = "GPU"
    GPU_RAM = "GPU_MEMORY"
    OTHER = "OTHER"
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class MetricsContainer(object):
    """Holds the hardware metrics tracked for an experiment; GPU slots may be None."""

    def __init__(self, cpu_usage_metric, memory_metric, gpu_usage_metric, gpu_memory_metric):
        self.cpu_usage_metric = cpu_usage_metric
        self.memory_metric = memory_metric
        self.gpu_usage_metric = gpu_usage_metric
        self.gpu_memory_metric = gpu_memory_metric

    def metrics(self):
        """Return the defined metrics, skipping absent (None) ones."""
        candidates = (
            self.cpu_usage_metric,
            self.memory_metric,
            self.gpu_usage_metric,
            self.gpu_memory_metric,
        )
        return [metric for metric in candidates if metric is not None]
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.constants import BYTES_IN_ONE_GB
from neptune.internal.hardware.metrics.metric import (
Metric,
MetricResourceType,
)
from neptune.internal.hardware.metrics.metrics_container import MetricsContainer
class MetricsFactory(object):
    """Builds Metric definitions matching the machine's detected resources."""

    def __init__(self, gauge_factory, system_resource_info):
        self.__gauge_factory = gauge_factory
        self.__system_resource_info = system_resource_info

    def create_metrics_container(self):
        """Assemble a MetricsContainer; GPU metrics are included only when GPUs exist."""
        gpu_available = self.__system_resource_info.has_gpu()
        return MetricsContainer(
            cpu_usage_metric=self.__create_cpu_usage_metric(),
            memory_metric=self.__create_memory_metric(),
            gpu_usage_metric=self.__create_gpu_usage_metric() if gpu_available else None,
            gpu_memory_metric=self.__create_gpu_memory_metric() if gpu_available else None,
        )

    def __create_cpu_usage_metric(self):
        return Metric(
            name="CPU - usage",
            description="average of all cores",
            resource_type=MetricResourceType.CPU,
            unit="%",
            min_value=0.0,
            max_value=100.0,
            gauges=[self.__gauge_factory.create_cpu_usage_gauge()],
        )

    def __create_memory_metric(self):
        # Upper bound is the total RAM available to the process, in GB.
        ram_limit_gb = self.__system_resource_info.memory_amount_bytes / float(BYTES_IN_ONE_GB)
        return Metric(
            name="RAM",
            description="",
            resource_type=MetricResourceType.RAM,
            unit="GB",
            min_value=0.0,
            max_value=ram_limit_gb,
            gauges=[self.__gauge_factory.create_memory_usage_gauge()],
        )

    def __create_gpu_usage_metric(self):
        info = self.__system_resource_info
        return Metric(
            name="GPU - usage",
            description="{} cards".format(info.gpu_card_count),
            resource_type=MetricResourceType.GPU,
            unit="%",
            min_value=0.0,
            max_value=100.0,
            gauges=[
                self.__gauge_factory.create_gpu_usage_gauge(card_index=index)
                for index in info.gpu_card_indices
            ],
        )

    def __create_gpu_memory_metric(self):
        info = self.__system_resource_info
        # Upper bound is the memory of the largest card, in GB.
        return Metric(
            name="GPU - memory",
            description="{} cards".format(info.gpu_card_count),
            resource_type=MetricResourceType.GPU_RAM,
            unit="GB",
            min_value=0.0,
            max_value=info.gpu_memory_amount_bytes / float(BYTES_IN_ONE_GB),
            gauges=[
                self.__gauge_factory.create_gpu_memory_gauge(card_index=index)
                for index in info.gpu_card_indices
            ],
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import namedtuple
# A report couples a metric definition with the values sampled for it.
MetricReport = namedtuple("MetricReport", ["metric", "values"])
# One sampled data point: absolute timestamp, seconds since the reference
# timestamp, the producing gauge's name, and the measured value.
MetricValue = namedtuple("MetricValue", ["timestamp", "running_time", "gauge_name", "value"])
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.metrics.reports.metric_reporter import MetricReporter
class MetricReporterFactory(object):
    """Creates MetricReporter instances that share one reference timestamp."""

    def __init__(self, reference_timestamp):
        # Start-of-experiment time; reporters compute running time from it.
        self.__reference_timestamp = reference_timestamp

    def create(self, metrics):
        return MetricReporter(metrics=metrics, reference_timestamp=self.__reference_timestamp)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.metrics.reports.metric_report import (
MetricReport,
MetricValue,
)
class MetricReporter(object):
    """Samples gauge values for a set of metrics at a given moment."""

    def __init__(self, metrics, reference_timestamp):
        self.__metrics = metrics
        self.__reference_timestamp = reference_timestamp

    def report(self, timestamp):
        """
        :param timestamp: Time of measurement (float, seconds since Epoch).
        :return: list[MetricReport]
        """
        reports = []
        for metric in self.__metrics:
            sampled = []
            for gauge in metric.gauges:
                metric_value = self.__metric_value_for_gauge(gauge, timestamp)
                if metric_value:
                    sampled.append(metric_value)
            reports.append(MetricReport(metric=metric, values=sampled))
        return reports

    def __metric_value_for_gauge(self, gauge, timestamp):
        # NOTE(review): a falsy reading (None, but also a legitimate 0.0) is
        # dropped here, so zero readings are never reported -- preserved as-is,
        # presumably intentional; confirm with backend expectations.
        value = gauge.value()
        if not value:
            return None
        return MetricValue(
            timestamp=timestamp,
            running_time=timestamp - self.__reference_timestamp,
            gauge_name=gauge.name(),
            value=value,
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.gauges.gauge_factory import GaugeFactory
from neptune.internal.hardware.gpu.gpu_monitor import GPUMonitor
from neptune.internal.hardware.metrics.metrics_factory import MetricsFactory
from neptune.internal.hardware.metrics.reports.metric_reporter_factory import MetricReporterFactory
from neptune.internal.hardware.metrics.service.metric_service import MetricService
from neptune.internal.hardware.resources.system_resource_info_factory import SystemResourceInfoFactory
from neptune.internal.hardware.system.system_monitor import SystemMonitor
class MetricServiceFactory(object):
    """Wires together everything needed to report hardware metrics for an experiment."""
    def __init__(self, backend, os_environ):
        # backend: client used to register metrics and ship reports.
        # os_environ: mapping (typically os.environ) read for CUDA_VISIBLE_DEVICES.
        self.__backend = backend
        self.__os_environ = os_environ
    def create(self, gauge_mode, experiment, reference_timestamp):
        """Build a ready-to-use MetricService for ``experiment``.

        :param gauge_mode: GaugeMode.SYSTEM or GaugeMode.CGROUP.
        :param experiment: experiment the metrics are attached to.
        :param reference_timestamp: start time used to compute running time.
        :return: MetricService
        """
        # Detect available resources (CPU, RAM, GPUs) for the chosen mode.
        system_resource_info = SystemResourceInfoFactory(
            system_monitor=SystemMonitor(),
            gpu_monitor=GPUMonitor(),
            os_environ=self.__os_environ,
        ).create(gauge_mode=gauge_mode)
        gauge_factory = GaugeFactory(gauge_mode=gauge_mode)
        metrics_factory = MetricsFactory(gauge_factory=gauge_factory, system_resource_info=system_resource_info)
        metrics_container = metrics_factory.create_metrics_container()
        # Register each metric with the backend and remember the assigned id.
        for metric in metrics_container.metrics():
            metric.internal_id = self.__backend.create_hardware_metric(experiment, metric)
        metric_reporter = MetricReporterFactory(reference_timestamp).create(metrics=metrics_container.metrics())
        return MetricService(
            backend=self.__backend,
            metric_reporter=metric_reporter,
            experiment=experiment,
            metrics_container=metrics_container,
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class MetricService(object):
    """Takes hardware measurements and ships the resulting reports to the backend."""

    def __init__(self, backend, metric_reporter, experiment, metrics_container):
        self.__backend = backend
        self.__metric_reporter = metric_reporter
        self.experiment = experiment
        self.metrics_container = metrics_container

    def report_and_send(self, timestamp):
        """Sample all metrics at ``timestamp`` and push the reports to the backend."""
        reports = self.__metric_reporter.report(timestamp)
        all_metrics = self.metrics_container.metrics()
        self.__backend.send_hardware_metric_reports(self.experiment, all_metrics, reports)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
class GPUCardIndicesProvider(object):
    """Resolves which GPU card indices should be monitored.

    Honors CUDA_VISIBLE_DEVICES when it is a well-formed comma-separated
    list of integers; otherwise every detected card is monitored.
    """

    # Comma-separated (possibly negative) integers, e.g. "0,1" or "-1".
    __CUDA_VISIBLE_DEVICES_PATTERN = r"^-?\d+(,-?\d+)*$"

    def __init__(self, cuda_visible_devices, gpu_card_count):
        self.__cuda_visible_devices = cuda_visible_devices
        self.__gpu_card_count = gpu_card_count

    def get(self):
        if self.__has_valid_cuda_visible_devices():
            return self.__indices_from_cuda_visible_devices()
        return list(range(self.__gpu_card_count))

    def __has_valid_cuda_visible_devices(self):
        if self.__cuda_visible_devices is None:
            return False
        return bool(re.match(self.__CUDA_VISIBLE_DEVICES_PATTERN, self.__cuda_visible_devices))

    def __indices_from_cuda_visible_devices(self):
        # Per the CUDA Toolkit specification, parsing stops at the first
        # out-of-range index; indices listed before it remain visible.
        # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
        visible = []
        for token in self.__cuda_visible_devices.split(","):
            index = int(token)
            if not (0 <= index < self.__gpu_card_count):
                break
            visible.append(index)
        return list(set(visible))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from neptune.internal.hardware.cgroup.cgroup_monitor import CGroupMonitor
from neptune.internal.hardware.gauges.gauge_mode import GaugeMode
from neptune.internal.hardware.resources.gpu_card_indices_provider import GPUCardIndicesProvider
from neptune.internal.hardware.resources.system_resource_info import SystemResourceInfo
class SystemResourceInfoFactory(object):
    """Gathers CPU/memory/GPU capacity information for the requested gauge mode."""

    def __init__(self, system_monitor, gpu_monitor, os_environ):
        self.__system_monitor = system_monitor
        self.__gpu_monitor = gpu_monitor
        # GPU visibility may be narrowed by CUDA_VISIBLE_DEVICES.
        self.__gpu_card_indices_provider = GPUCardIndicesProvider(
            cuda_visible_devices=os_environ.get("CUDA_VISIBLE_DEVICES"),
            gpu_card_count=self.__gpu_monitor.get_card_count(),
        )

    def create(self, gauge_mode):
        if gauge_mode == GaugeMode.SYSTEM:
            return self.__create_whole_system_resource_info()
        if gauge_mode == GaugeMode.CGROUP:
            return self.__create_cgroup_resource_info()
        raise ValueError("Unknown gauge mode: {}".format(gauge_mode))

    def __create_whole_system_resource_info(self):
        # Whole-machine view: host core count and total RAM.
        return SystemResourceInfo(
            cpu_core_count=float(self.__system_monitor.cpu_count()),
            memory_amount_bytes=self.__system_monitor.virtual_memory().total,
            gpu_card_indices=self.__gpu_card_indices_provider.get(),
            gpu_memory_amount_bytes=self.__gpu_monitor.get_top_card_memory_in_bytes(),
        )

    def __create_cgroup_resource_info(self):
        # Container view: CPU and memory limits come from the cgroup, not the host.
        cgroup_monitor = CGroupMonitor.create()
        return SystemResourceInfo(
            cpu_core_count=cgroup_monitor.get_cpu_usage_limit_in_cores(),
            memory_amount_bytes=cgroup_monitor.get_memory_limit_in_bytes(),
            gpu_card_indices=self.__gpu_card_indices_provider.get(),
            gpu_memory_amount_bytes=self.__gpu_monitor.get_top_card_memory_in_bytes(),
        )
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class SystemResourceInfo(object):
    """Read-only snapshot of the hardware capacity visible to the process."""

    def __init__(
        self,
        cpu_core_count,
        memory_amount_bytes,
        gpu_card_indices,
        gpu_memory_amount_bytes,
    ):
        # NOTE: attribute names and assignment order are part of __repr__
        # output (it dumps __dict__), so they are kept as-is.
        self.__cpu_core_count = cpu_core_count
        self.__memory_amount_bytes = memory_amount_bytes
        self.__gpu_card_indices = gpu_card_indices
        self.__gpu_memory_amount_bytes = gpu_memory_amount_bytes

    @property
    def cpu_core_count(self):
        # May be fractional when a cgroup CPU quota is in effect.
        return self.__cpu_core_count

    @property
    def memory_amount_bytes(self):
        return self.__memory_amount_bytes

    @property
    def gpu_card_count(self):
        return len(self.__gpu_card_indices)

    @property
    def gpu_card_indices(self):
        return self.__gpu_card_indices

    @property
    def gpu_memory_amount_bytes(self):
        return self.__gpu_memory_amount_bytes

    def has_gpu(self):
        return self.gpu_card_count > 0

    def __repr__(self):
        return str(self.__dict__)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# psutil is an optional dependency: system metrics are only collected when it
# is importable. PSUTIL_INSTALLED lets callers check availability up front.
try:
    import psutil

    PSUTIL_INSTALLED = True
except ImportError:
    PSUTIL_INSTALLED = False
class SystemMonitor(object):
    """Thin facade over ``psutil`` system-metric calls.

    NOTE(review): assumes ``psutil`` imported successfully
    (``PSUTIL_INSTALLED`` is True) before any method is called.
    """

    @staticmethod
    def cpu_count():
        # Number of CPUs as reported by psutil.
        return psutil.cpu_count()

    @staticmethod
    def cpu_percent():
        # Current system-wide CPU utilisation percentage.
        return psutil.cpu_percent()

    @staticmethod
    def virtual_memory():
        # psutil virtual-memory statistics object (``.total`` is used by callers).
        return psutil.virtual_memory()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
import time
import jwt
from bravado.exception import HTTPUnauthorized
from bravado.requests_client import Authenticator
from oauthlib.oauth2 import (
OAuth2Error,
TokenExpiredError,
)
from requests.auth import AuthBase
from requests_oauthlib import OAuth2Session
from neptune.internal.backends.utils import with_api_exceptions_handler
from neptune.internal.exceptions import NeptuneInvalidApiTokenException
from neptune.internal.utils.utils import update_session_proxies
# Disable every JWT validation step: the token is decoded only to *read*
# claims (expiry, issuer, client id) — authenticity is not checked here.
_decoding_options = dict.fromkeys(
    (
        "verify_signature",
        "verify_exp",
        "verify_nbf",
        "verify_iat",
        "verify_aud",
        "verify_iss",
    ),
    False,
)
class NeptuneAuth(AuthBase):
    """requests auth hook that signs each request with an OAuth2 access token.

    Tokens are refreshed lazily: either when oauthlib reports the token
    expired during signing, or proactively via ``refresh_token_if_needed``.
    """

    # Class-level lock: token refresh is serialised across ALL instances.
    __LOCK = threading.RLock()

    def __init__(self, session_factory):
        # A factory (not a ready session) is stored so a broken session can be
        # replaced with a brand-new one at any time.
        self.session_factory = session_factory
        self.session = session_factory()
        # Epoch seconds; 0 forces a refresh on the first proactive check.
        self.token_expires_at = 0

    def __call__(self, r):
        try:
            return self._add_token(r)
        except TokenExpiredError:
            # Token expired mid-flight: refresh once and retry the signing.
            self._refresh_token()
            return self._add_token(r)

    def _add_token(self, r):
        # Delegate to the oauthlib client to inject the Authorization header.
        r.url, r.headers, r.body = self.session._client.add_token(
            r.url, http_method=r.method, body=r.body, headers=r.headers
        )
        return r

    @with_api_exceptions_handler
    def refresh_token_if_needed(self):
        # Refresh slightly early (30 s margin) so a token never expires
        # mid-request.
        if self.token_expires_at - time.time() < 30:
            self._refresh_token()

    def _refresh_token(self):
        with self.__LOCK:
            try:
                self._refresh_session_token()
            except OAuth2Error:
                # For some reason the oauth session is no longer valid. Retry
                # by creating a fresh session; we can safely ignore this error,
                # as it will be thrown again if it's persistent.
                try:
                    self.session.close()
                except Exception:
                    pass
                self.session = self.session_factory()
                self._refresh_session_token()

    def _refresh_session_token(self):
        self.session.refresh_token(self.session.auto_refresh_url, verify=self.session.verify)
        if self.session.token is not None and self.session.token.get("access_token") is not None:
            # Decoded without verification (_decoding_options) just to read "exp".
            # NOTE(review): if the token carries no "exp" claim this stores None,
            # which would make refresh_token_if_needed raise TypeError — confirm
            # tokens always include "exp".
            decoded_json_token = jwt.decode(self.session.token.get("access_token"), options=_decoding_options)
            self.token_expires_at = decoded_json_token.get("exp")
class NeptuneAuthenticator(Authenticator):
    """bravado Authenticator that attaches a NeptuneAuth instance to every request."""

    def __init__(self, api_token, backend_client, ssl_verify, proxies):
        super(NeptuneAuthenticator, self).__init__(host="")

        # We need to pass a factory to be able to re-create a fresh session at
        # any time when needed.
        def session_factory():
            try:
                # Exchange the long-lived API token for OAuth2 access/refresh tokens.
                auth_tokens = backend_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response().result
            except HTTPUnauthorized:
                raise NeptuneInvalidApiTokenException()

            # Decode (without verification) to read expiry, client id and issuer.
            decoded_json_token = jwt.decode(auth_tokens.accessToken, options=_decoding_options)
            expires_at = decoded_json_token.get("exp")
            client_name = decoded_json_token.get("azp")
            # OpenID Connect token endpoint derived from the issuer URL.
            refresh_url = "{realm_url}/protocol/openid-connect/token".format(realm_url=decoded_json_token.get("iss"))
            token = {
                "access_token": auth_tokens.accessToken,
                "refresh_token": auth_tokens.refreshToken,
                "expires_in": expires_at - time.time(),
            }
            session = OAuth2Session(
                client_id=client_name,
                token=token,
                auto_refresh_url=refresh_url,
                auto_refresh_kwargs={"client_id": client_name},
                token_updater=_no_token_updater,
            )
            session.verify = ssl_verify
            update_session_proxies(session, proxies)
            return session

        self.auth = NeptuneAuth(session_factory)

    def matches(self, url):
        # Authenticate every URL, not just the configured host.
        return True

    def apply(self, request):
        # Proactively refresh before handing the request off to requests.
        self.auth.refresh_token_if_needed()
        request.auth = self.auth
        return request
def _no_token_updater():
    """Token-updater callback that deliberately does nothing (kept for unit tests)."""
    return None
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["apply_patches"]

from neptune.internal.patches.bravado import patch as bravado_patch

# Registry of patch functions; extend this list to register new patches.
patches = [bravado_patch]


# Apply patches when importing a patching module.
# Should be called before usages of patched objects.
def apply_patches():
    """Run every registered monkey-patch exactly once per call."""
    for single_patch in patches:
        single_patch()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import re
import bravado_core.model
from bravado_core.model import (
_bless_models,
_collect_models,
_get_unprocessed_uri,
_post_process_spec,
_tag_models,
)
def _run_post_processing(spec):
    """Replacement for ``bravado_core.model._run_post_processing``.

    Discovers models in the spec and in referenced spec files, but treats
    json-schema.org draft URIs as already processed so they are never fetched
    from the network (see the bravado-core issue referenced below this
    function in the module).
    """
    visited_models = {}

    def _call_post_process_spec(spec_dict):
        # Discover all the models in spec_dict
        _post_process_spec(
            spec_dict=spec_dict,
            spec_resolver=spec.resolver,
            on_container_callbacks=[
                functools.partial(
                    _tag_models,
                    visited_models=visited_models,
                    swagger_spec=spec,
                ),
                functools.partial(
                    _bless_models,
                    visited_models=visited_models,
                    swagger_spec=spec,
                ),
                functools.partial(
                    _collect_models,
                    models=spec.definitions,
                    swagger_spec=spec,
                ),
            ],
        )

    # Post process specs to identify models
    _call_post_process_spec(spec.spec_dict)

    # Mark the origin spec and standard json-schema.org draft URIs as already
    # processed so _get_unprocessed_uri never asks to fetch them.
    processed_uris = {
        uri
        for uri in spec.resolver.store
        if uri == spec.origin_url or re.match(r"http(s)?://json-schema\.org/draft(/\d{4})?-\d+/(schema|meta/.*)", uri)
    }

    additional_uri = _get_unprocessed_uri(spec, processed_uris)
    while additional_uri is not None:
        # Post process each referenced spec to identify models in definitions of linked files
        with spec.resolver.in_scope(additional_uri):
            _call_post_process_spec(
                spec.resolver.store[additional_uri],
            )

        processed_uris.add(additional_uri)
        additional_uri = _get_unprocessed_uri(spec, processed_uris)
# Issue: https://github.com/Yelp/bravado-core/issues/388
# Bravado currently makes additional requests to `json-schema.org` in order to gather missing schemas.
# This makes `neptune` unable to run without an internet connection, or under many security policies.
def patch():
    """Install the offline-friendly model post-processing override.

    Swaps ``bravado_core.model._run_post_processing`` for the local variant
    defined in this module.
    """
    setattr(bravado_core.model, "_run_post_processing", _run_post_processing)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"AttributeUploadConfiguration",
"UploadEntry",
"normalize_file_name",
"scan_unique_upload_entries",
"split_upload_files",
"FileChunk",
"FileChunker",
"compress_to_tar_gz_in_memory",
]
from neptune.internal.storage.datastream import (
FileChunk,
FileChunker,
compress_to_tar_gz_in_memory,
)
from neptune.internal.storage.storage_utils import (
AttributeUploadConfiguration,
UploadEntry,
normalize_file_name,
scan_unique_upload_entries,
split_upload_files,
)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import io
import math
import os
import tarfile
from typing import (
Any,
Generator,
Optional,
)
from neptune.internal.backends.api_model import MultipartConfig
from neptune.internal.exceptions import (
InternalClientError,
UploadedFileChanged,
)
@dataclasses.dataclass
class FileChunk:
    """A contiguous piece of an uploaded file: raw bytes plus [start, end) offsets."""

    data: bytes
    start: int
    end: int
class FileChunker:
    """Splits a file (or stream) into FileChunk objects for multipart upload."""

    def __init__(self, filename: Optional[str], fobj, total_size, multipart_config: MultipartConfig):
        self._filename: Optional[str] = filename
        self._fobj = fobj
        self._total_size = total_size
        self._min_chunk_size = multipart_config.min_chunk_size
        self._max_chunk_size = multipart_config.max_chunk_size
        self._max_chunk_count = multipart_config.max_chunk_count

    def _get_chunk_size(self) -> int:
        """Pick a chunk size that fits the payload in at most max_chunk_count chunks.

        Raises:
            InternalClientError: when the payload cannot fit even with maximal chunks.
        """
        if self._total_size > self._max_chunk_count * self._max_chunk_size:
            # can't fit it
            max_size = self._max_chunk_count * self._max_chunk_size
            raise InternalClientError(
                f"File {self._filename or 'stream'} is too big to upload:"
                f" {self._total_size} bytes exceeds max size {max_size}"
            )
        if self._total_size <= self._max_chunk_count * self._min_chunk_size:
            # can be done as minimal size chunks -- go for it!
            return self._min_chunk_size
        else:
            # need larger chunks -- split more or less equally
            return math.ceil(self._total_size / self._max_chunk_count)

    def generate(self) -> Generator[FileChunk, Any, None]:
        """Yield successive FileChunks until total_size bytes have been read.

        Raises:
            UploadedFileChanged: when the source file's mtime changes mid-read.
        """
        chunk_size = self._get_chunk_size()
        last_offset = 0
        # Snapshot the mtime so concurrent modification can be detected.
        last_change: Optional[float] = os.stat(self._filename).st_mtime if self._filename else None
        while last_offset < self._total_size:
            chunk = self._fobj.read(chunk_size)
            if not chunk:
                # Bug fix: EOF before total_size bytes were produced (e.g. the
                # file was truncated). Previously the loop spun forever because
                # last_offset never advanced; stop iterating instead.
                break
            if last_change and last_change < os.stat(self._filename).st_mtime:
                raise UploadedFileChanged(self._filename)
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8")
            new_offset = last_offset + len(chunk)
            yield FileChunk(data=chunk, start=last_offset, end=new_offset)
            last_offset = new_offset
def compress_to_tar_gz_in_memory(upload_entries) -> bytes:
    """Pack the upload entries into an in-memory .tar.gz archive and return its bytes.

    Each entry's ``source`` path is stored under the archive name
    ``target_path``; directories are added recursively and symlinks are
    followed (``dereference=True``).
    """
    buffer = io.BytesIO()
    with tarfile.TarFile.open(fileobj=buffer, mode="w|gz", dereference=True) as archive:
        for upload_entry in upload_entries:
            archive.add(name=upload_entry.source, arcname=upload_entry.target_path, recursive=True)
    return buffer.getvalue()
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import os
import stat
import time
from abc import (
ABCMeta,
abstractmethod,
)
from dataclasses import dataclass
from io import BytesIO
from pprint import pformat
from typing import (
BinaryIO,
Generator,
List,
Set,
Union,
)
from neptune.internal.utils.logger import get_logger
_logger = get_logger()
@dataclass
class AttributeUploadConfiguration:
    # Size threshold (bytes) used when grouping files into upload packages.
    chunk_size: int
class UploadEntry(object):
    """A single item scheduled for upload: a source (path or stream) plus its target path."""

    def __init__(self, source: Union[str, BytesIO], target_path: str):
        self.source = source
        self.target_path = target_path

    def length(self) -> int:
        """Size of the entry in bytes (buffer size for streams, file size for paths)."""
        if self.is_stream():
            return self.source.getbuffer().nbytes
        else:
            return os.path.getsize(self.source)

    def get_stream(self) -> Union[BinaryIO, io.BytesIO]:
        """Return a readable binary stream for this entry.

        For path sources the caller is responsible for closing the file.
        """
        if self.is_stream():
            return self.source
        else:
            return io.open(self.source, "rb")

    def get_permissions(self) -> str:
        """Unix-style permission string; streams have no permissions."""
        if self.is_stream():
            return "----------"
        else:
            return self.permissions_to_unix_string(self.source)

    @classmethod
    def permissions_to_unix_string(cls, path):
        """Render an ``ls -l``-style permission string for *path*.

        Missing paths render as ``----------``.
        """
        st = 0
        if os.path.exists(path):
            st = os.lstat(path).st_mode
        is_dir = "d" if stat.S_ISDIR(st) else "-"
        dic = {
            "7": "rwx",
            "6": "rw-",
            "5": "r-x",
            "4": "r--",
            "3": "-wx",
            "2": "-w-",
            "1": "--x",
            "0": "---",
        }
        # Last three octal digits of the mode are the owner/group/other bits.
        perm = ("%03o" % st)[-3:]
        return is_dir + "".join(dic.get(x, x) for x in perm)

    def __eq__(self, other):
        """
        Returns true if both objects are equal.

        Bug fix: comparing with an object that has no ``__dict__`` (e.g. an
        int) used to raise AttributeError; such comparisons now follow the
        standard NotImplemented protocol and evaluate as unequal.
        """
        if not isinstance(other, UploadEntry):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        """
        Returns true if both objects are not equal
        """
        return not self == other

    def __hash__(self):
        """
        Returns the hash of source and target path
        """
        return hash((self.source, self.target_path))

    def to_str(self):
        """
        Returns the string representation of the model
        """
        return pformat(self.__dict__)

    def __repr__(self):
        """
        For `print` and `pprint`
        """
        return self.to_str()

    def is_stream(self):
        # Anything with a .read attribute is treated as an already-open stream.
        return hasattr(self.source, "read")
class UploadPackage(object):
    """Mutable batch of UploadEntry items accumulated up to a size/count budget."""

    def __init__(self):
        # Start from a clean state; reset() holds the canonical initial values.
        self.reset()

    def reset(self):
        """Empty the package so it can be reused for the next batch."""
        self.items: List[UploadEntry] = []
        self.size: int = 0
        self.len: int = 0

    def update(self, entry: UploadEntry, size: int):
        """Add *entry* (counted as *size* bytes) to the package."""
        self.items.append(entry)
        self.size += size
        self.len += 1

    def is_empty(self):
        """Return True when no entries were added since the last reset."""
        return not self.len

    def __eq__(self, other):
        """
        Returns true if both objects are equal
        """
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        """
        Returns true if both objects are not equal
        """
        return not self == other

    def to_str(self):
        """
        Returns the string representation of the model
        """
        return pformat(self.__dict__)

    def __repr__(self):
        """
        For `print` and `pprint`
        """
        return self.to_str()
class ProgressIndicator(metaclass=ABCMeta):
    """Interface for reporting upload progress."""

    @abstractmethod
    def progress(self, steps):
        """Record that *steps* more units (bytes, in practice) were processed."""
        pass

    @abstractmethod
    def complete(self):
        """Signal that the whole operation finished."""
        pass
class LoggingProgressIndicator(ProgressIndicator):
    """Progress reporter that periodically logs upload progress as warnings."""

    def __init__(self, total, frequency=10):
        # total: overall payload size in bytes; frequency: minimum seconds
        # between consecutive progress log lines.
        self.current = 0
        self.total = total
        self.last_warning = time.time()
        self.frequency = frequency
        _logger.warning(
            "You are sending %dMB of source code to Neptune. "
            "It is pretty uncommon - please make sure it's what you wanted.",
            self.total / (1024 * 1024),
        )

    def progress(self, steps):
        """Advance by *steps* bytes; log at most once per *frequency* seconds."""
        self.current += steps
        if time.time() - self.last_warning > self.frequency:
            # Bug fix: guard against total == 0 to avoid ZeroDivisionError
            # when an empty payload is reported on.
            percentage = 100 * self.current / self.total if self.total else 0
            _logger.warning(
                "%d MB / %d MB (%d%%) of source code was sent to Neptune.",
                self.current / (1024 * 1024),
                self.total / (1024 * 1024),
                percentage,
            )
            self.last_warning = time.time()

    def complete(self):
        """Log the final 100% message."""
        _logger.warning(
            "%d MB (100%%) of source code was sent to Neptune.",
            self.total / (1024 * 1024),
        )
class SilentProgressIndicator(ProgressIndicator):
    """No-op progress reporter used when progress logging is not desired."""

    def __init__(self):
        pass

    def progress(self, steps):
        """Ignore progress updates."""
        pass

    def complete(self):
        """Ignore completion."""
        pass
def scan_unique_upload_entries(upload_entries):
    """
    Returns upload entries for all files that could be found for given upload entries.
    In case of a directory as upload entry, files will be taken from all subdirectories recursively.
    Any duplicated entries are removed.
    """
    collected = set()
    for entry in upload_entries:
        if entry.is_stream() or not os.path.isdir(entry.source):
            # Streams and plain files are taken as-is.
            collected.add(entry)
            continue
        # Directory: walk it and mirror the relative layout under target_path.
        for root, _, files in os.walk(entry.source):
            relative_root = os.path.relpath(root, entry.source)
            target_root = os.path.normpath(os.path.join(entry.target_path, relative_root))
            for filename in files:
                collected.add(
                    UploadEntry(
                        os.path.join(root, filename),
                        os.path.join(target_root, filename),
                    )
                )
    return collected
def split_upload_files(
    upload_entries: Set[UploadEntry],
    upload_configuration: AttributeUploadConfiguration,
    max_files=500,
) -> Generator[UploadPackage, None, None]:
    """Group entries into packages bounded by chunk_size bytes / max_files entries.

    Stream entries are always shipped alone in their own package; the final
    (possibly empty) package is always yielded.
    """
    package = UploadPackage()
    for entry in upload_entries:
        if entry.is_stream():
            # Flush whatever was accumulated, then emit the stream on its own.
            if package.len > 0:
                yield package
                package.reset()
            package.update(entry, 0)
            yield package
            package.reset()
            continue
        entry_size = os.path.getsize(entry.source)
        overflows = entry_size + package.size > upload_configuration.chunk_size or package.len > max_files
        if overflows and not package.is_empty():
            yield package
            package.reset()
        package.update(entry, entry_size)
    yield package
def normalize_file_name(name):
    """Convert a platform-specific path to use forward slashes."""
    return "/".join(name.split(os.sep))
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class GitInfo(object):
    """Class that keeps information about a git repository in experiment.

    When :meth:`~neptune.projects.Project.create_experiment` is invoked, an instance of this
    class is created to store information about the git repository.
    This information is later presented in the experiment details tab in the Neptune web application.

    Args:
        commit_id (:obj:`str`): commit id sha.
        message (:obj:`str`, optional, default is ``""``): commit message.
        author_name (:obj:`str`, optional, default is ``""``): commit author username.
        author_email (:obj:`str`, optional, default is ``""``): commit author email.
        commit_date (:obj:`datetime.datetime`, optional, default is ``""``): commit datetime.
        repository_dirty (:obj:`bool`, optional, default is ``True``):
            ``True``, if the repository has uncommitted changes, ``False`` otherwise.
        active_branch (:obj:`str`, optional, default is ``""``): name of the checked-out branch.
        remote_urls (:obj:`list` of :obj:`str`, optional, default is ``None``): repository remote URLs.

    Raises:
        TypeError: when ``commit_id`` is None.
    """

    def __init__(
        self,
        commit_id,
        message="",
        author_name="",
        author_email="",
        commit_date="",
        repository_dirty=True,
        active_branch="",
        remote_urls=None,
    ):
        if remote_urls is None:
            # Fresh list per instance (avoid a shared mutable default).
            remote_urls = []
        if commit_id is None:
            raise TypeError("commit_id must not be None")

        self.commit_id = commit_id
        self.message = message
        self.author_name = author_name
        self.author_email = author_email
        self.commit_date = commit_date
        self.repository_dirty = repository_dirty
        self.active_branch = active_branch
        self.remote_urls = remote_urls

    def __eq__(self, o):
        # Bug fix: comparing against objects without __dict__ (e.g. ints)
        # used to raise AttributeError; such comparisons are now just unequal.
        return isinstance(o, GitInfo) and self.__dict__ == o.__dict__

    def __ne__(self, o):
        return not self.__eq__(o)

    def __str__(self):
        return "GitInfo({})".format(self.commit_id)

    def __repr__(self):
        return str(self)
#
# Copyright (c) 2019, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Matches "workspace/project" or a bare "project"; named groups "workspace"
# (optional) and "project" capture the two components.
PROJECT_QUALIFIED_NAME_PATTERN = "^((?P<workspace>[^/]+)/){0,1}(?P<project>[^/]+)$"

__all__ = ["PROJECT_QUALIFIED_NAME_PATTERN"]
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import glob as globlib
import math
import os
import re
import ssl
import sys
import numpy as np
import pandas as pd
from neptune.internal import envs
from neptune.internal.exceptions import (
FileNotFound,
InvalidNotebookPath,
NeptuneIncorrectProjectQualifiedNameException,
NeptuneMissingProjectQualifiedNameException,
NotADirectory,
NotAFile,
)
from neptune.internal.utils.git_info import GitInfo
from neptune.internal.utils.logger import get_logger
from neptune.internal.utils.patterns import PROJECT_QUALIFIED_NAME_PATTERN
_logger = get_logger()

# Platform detection flags.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
def reset_internal_ssl_state():
    """Re-seed OpenSSL's internal PRNG (needed after os.fork()).

    OpenSSL's internal random number generator does not properly handle forked
    processes: applications must change the PRNG state of the parent process if
    they use any SSL feature with os.fork(). Any successful call of RAND_add(),
    RAND_bytes() or RAND_pseudo_bytes() is sufficient.
    https://docs.python.org/3/library/ssl.html#multi-processing
    """
    ssl.RAND_bytes(100)
def map_values(f_value, dictionary):
    """Return a new dict with the same keys and *f_value* applied to each value."""
    return {key: f_value(value) for key, value in dictionary.items()}
def map_keys(f_key, dictionary):
    """Return a new dict with *f_key* applied to each key, values unchanged."""
    return {f_key(key): value for key, value in dictionary.items()}
def as_list(value):
    """Wrap *value* in a single-element list unless it is already a list or None."""
    if isinstance(value, list) or value is None:
        return value
    return [value]
def validate_notebook_path(path):
    """Validate that *path* points at an existing .ipynb file.

    Raises:
        InvalidNotebookPath: when the extension is not ``.ipynb``.
        FileNotFound: when the path does not exist.
        NotAFile: when the path exists but is not a regular file.
    """
    if not path.endswith(".ipynb"):
        raise InvalidNotebookPath(path)

    if not os.path.exists(path):
        raise FileNotFound(path)

    if not os.path.isfile(path):
        raise NotAFile(path)
def assure_directory_exists(destination_dir):
    """Return *destination_dir*, creating it if missing; empty input means the CWD.

    Raises NotADirectory when the path exists but is not a directory.
    """
    destination_dir = destination_dir or os.getcwd()

    if not os.path.exists(destination_dir):
        os.makedirs(destination_dir)
    elif not os.path.isdir(destination_dir):
        raise NotADirectory(destination_dir)

    return destination_dir
def align_channels_on_x(dataframe):
    # Split the frame into per-channel (x, y) frames, then outer-join them on
    # the union of all x values so every channel shares one common x axis.
    channel_dfs, common_x = _split_df_by_stems(dataframe)
    return merge_dataframes([common_x] + channel_dfs, on="x", how="outer")
def get_channel_name_stems(columns):
    """Return the unique channel names: column names with the 2-char prefix (e.g. 'x_'/'y_') dropped."""
    return list({column[2:] for column in columns})
def merge_dataframes(dataframes, on, how="outer"):
    """Fold a list of dataframes into one by repeatedly merging on the *on* column."""

    def _merge_pair(left, right):
        return pd.merge(left, right, on=on, how=how)

    return functools.reduce(_merge_pair, dataframes)
def is_float(value):
    """Return True when ``float(value)`` succeeds (note: "nan"/"inf" count as floats).

    Bug fix: non-numeric, non-string values (e.g. None or a list) make
    ``float()`` raise TypeError, which previously propagated to the caller;
    both failure modes now simply return False.
    """
    try:
        float(value)
    except (ValueError, TypeError):
        return False
    return True
def is_nan_or_inf(value):
    """Return True when *value* is NaN or +/- infinity."""
    return not math.isfinite(value)
def is_notebook():
    """Best-effort check for an IPython/Jupyter environment.

    Relies on the ``get_ipython`` name that IPython injects into builtins;
    any failure (normally NameError outside IPython) means "not a notebook".
    """
    try:
        get_ipython  # noqa: F821
    except Exception:
        return False
    return True
def _split_df_by_stems(df):
    # For each channel stem, pull out its (x_<stem>, y_<stem>) column pair,
    # rename the columns to ("x", <stem>) and drop rows with missing values.
    channel_dfs, x_vals = [], []
    for stem in get_channel_name_stems(df.columns):
        channel_df = df[["x_{}".format(stem), "y_{}".format(stem)]]
        channel_df.columns = ["x", stem]
        channel_df = channel_df.dropna()
        channel_dfs.append(channel_df)
        x_vals.extend(channel_df["x"].tolist())
    # The sorted union of every channel's x values becomes the shared x axis.
    common_x = pd.DataFrame({"x": np.unique(x_vals)}, dtype=float)
    return channel_dfs, common_x
def discover_git_repo_location():
    """Guess the git repo location as the directory of the entry-point script.

    Returns None when there is no ``__main__.__file__`` (e.g. a REPL session).
    """
    import __main__

    if not hasattr(__main__, "__file__"):
        return None
    return os.path.dirname(os.path.abspath(__main__.__file__))
def update_session_proxies(session, proxies):
    """Install *proxies* onto a requests-style session's proxy mapping.

    No-op when *proxies* is None.

    Raises:
        ValueError: when *proxies* is not something ``dict.update`` accepts.
    """
    if proxies is None:
        return
    try:
        session.proxies.update(proxies)
    except (TypeError, ValueError) as e:
        # Chain the original error so the root cause stays visible (fix:
        # previously the underlying exception context was implicit only).
        raise ValueError("Wrong proxies format: {}".format(proxies)) from e
def get_git_info(repo_path=None):
    """Retrieve information about git repository.

    If the attempt fails, ``None`` will be returned.

    Args:
        repo_path (:obj:`str`, optional, default is ``None``):
            | Path to the repository from which extract information about git.
            | If ``None`` is passed, calling ``get_git_info`` is equivalent to calling
              ``git.Repo(search_parent_directories=True)``.
              Check `GitPython <https://gitpython.readthedocs.io/en/stable/reference.html#git.repo.base.Repo>`_
              docs for more information.

    Returns:
        :class:`~neptune.git_info.GitInfo` - An object representing information about git repository.

    Examples:

        .. code:: python3

            # Get git info from the current directory
            git_info = get_git_info('.')
    """
    try:
        # Imported lazily: GitPython is an optional dependency.
        import git

        repo = git.Repo(repo_path, search_parent_directories=True)

        commit = repo.head.commit

        active_branch = ""

        try:
            active_branch = repo.active_branch.name
        except TypeError as e:
            # GitPython raises TypeError for a detached HEAD; recognised by
            # its message prefix.
            if str(e.args[0]).startswith("HEAD is a detached symbolic reference as it points to"):
                active_branch = "Detached HEAD"

        remote_urls = [remote.url for remote in repo.remotes]

        return GitInfo(
            commit_id=commit.hexsha,
            message=commit.message,
            author_name=commit.author.name,
            author_email=commit.author.email,
            commit_date=commit.committed_datetime,
            repository_dirty=repo.is_dirty(index=False, untracked_files=True),
            active_branch=active_branch,
            remote_urls=remote_urls,
        )
    except:  # noqa: E722
        # Deliberate catch-all: git metadata is best-effort and must never
        # break experiment creation.
        return None
def file_contains(filename, text):
    """Return True when *text* occurs in any single line of *filename*."""
    with open(filename) as f:
        return any(text in line for line in f)
def in_docker():
    """Heuristically detect whether the process runs inside a Docker container.

    Checks the well-known ``/.dockerenv`` marker file (bug fix: the code
    previously tested the relative path ``"./dockerenv"``, which never matches
    Docker's marker) and, failing that, looks for "docker" in the process
    cgroup file.
    """
    cgroup_file = "/proc/self/cgroup"
    return os.path.exists("/.dockerenv") or (os.path.exists(cgroup_file) and file_contains(cgroup_file, text="docker"))
def is_ipython():
    """Return True when an IPython interactive shell is currently active."""
    try:
        import IPython
    except ImportError:
        return False
    ipython = IPython.core.getipython.get_ipython()
    return ipython is not None
def glob(pathname):
    """``glob.glob`` with recursive ``**`` matching always enabled.

    The Python < 3.5 fallback (no ``recursive`` parameter) was removed as dead
    code: this file already uses f-strings, which require Python >= 3.6 where
    ``recursive`` always exists.
    """
    return globlib.glob(pathname, recursive=True)
def assure_project_qualified_name(project_qualified_name):
    """Resolve and validate a project qualified name ("workspace/project").

    Falls back to the environment variable named by ``envs.PROJECT_ENV_NAME``
    when the argument is empty.

    Raises:
        NeptuneMissingProjectQualifiedNameException: when nothing was provided.
        NeptuneIncorrectProjectQualifiedNameException: when the value does not
            match PROJECT_QUALIFIED_NAME_PATTERN.
    """
    project_qualified_name = project_qualified_name or os.getenv(envs.PROJECT_ENV_NAME)

    if not project_qualified_name:
        raise NeptuneMissingProjectQualifiedNameException()
    if not re.match(PROJECT_QUALIFIED_NAME_PATTERN, project_qualified_name):
        raise NeptuneIncorrectProjectQualifiedNameException(project_qualified_name)

    return project_qualified_name
class NoopObject(object):
    """Black-hole object: attribute access, indexing and calls all return the
    object itself, and it can be used as a no-op context manager."""

    def __getattr__(self, name):
        return self

    def __getitem__(self, key):
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"warn_once",
"warn_about_unsupported_type",
"NeptuneDeprecationWarning",
"NeptuneWarning",
"NeptuneUnsupportedType",
"NeptuneUnsupportedValue",
]
import os
import traceback
import warnings
import neptune
from neptune.internal.utils.logger import NEPTUNE_LOGGER_NAME
from neptune.internal.utils.runningmode import in_interactive
# Warning message templates; the interactive variant omits file/line details.
DEFAULT_FORMAT = "[%(name)s] [warning] %(filename)s:%(lineno)d: %(category)s: %(message)s\n"
INTERACTIVE_FORMAT = "[%(name)s] [warning] %(category)s: %(message)s\n"
class NeptuneDeprecationWarning(DeprecationWarning):
    """Warning category for deprecated Neptune APIs."""

    pass


class NeptuneUnsupportedValue(Warning):
    """Warning category for values Neptune cannot store."""

    pass


class NeptuneWarning(Warning):
    """Generic Neptune warning category."""

    pass


class NeptuneUnsupportedType(Warning):
    """Warning category for attempts to log an unsupported type."""

    pass
# Deprecation warnings are always emitted, not deduplicated by location.
warnings.simplefilter("always", category=NeptuneDeprecationWarning)

# Cap on the warned-once registry so it cannot grow without bound.
MAX_WARNED_ONCE_CAPACITY = 1_000
# Hashes of messages already emitted by warn_once.
warned_once = set()
# Installation directory of the neptune package; used to distinguish library
# frames from user-code frames when computing warning stack levels.
path_to_root_module = os.path.dirname(os.path.realpath(neptune.__file__))
def get_user_code_stack_level():
    """Stack level (for warnings.warn) of the first frame outside neptune's code.

    Falls back to 2 when every frame lives inside the package directory.
    """
    frames_outward = enumerate(reversed(traceback.extract_stack()))
    return next(
        (depth for depth, frame in frames_outward if path_to_root_module not in frame.filename),
        2,
    )
def format_message(message, category, filename, lineno, line=None) -> str:
    """Render a warning using Neptune's log-style templates.

    Interactive sessions use a shorter template without file/line details.
    """
    template = INTERACTIVE_FORMAT if in_interactive() else DEFAULT_FORMAT
    return template % {
        "message": message,
        "category": category.__name__,
        "filename": filename,
        "lineno": lineno,
        "name": NEPTUNE_LOGGER_NAME,
    }
def warn_once(message: str, *, exception: "type | None" = None):
    """Emit *message* as a warning at most once per process.

    Messages are deduplicated by hash; the registry is capped at
    MAX_WARNED_ONCE_CAPACITY, after which no further warnings are emitted.

    Args:
        message: warning text.
        exception: warning category; defaults to NeptuneDeprecationWarning.
            (Fix: the previous annotation ``type(Exception)`` evaluated to the
            metaclass and did not express "optional category".)
    """
    if len(warned_once) >= MAX_WARNED_ONCE_CAPACITY:
        return

    if exception is None:
        exception = NeptuneDeprecationWarning

    message_hash = hash(message)
    if message_hash in warned_once:
        return

    old_formatting = warnings.formatwarning
    warnings.formatwarning = format_message
    try:
        warnings.warn(
            message=message,
            category=exception,
            stacklevel=get_user_code_stack_level(),
        )
    finally:
        # Bug fix: restore the global formatter even if warnings.warn raises
        # (e.g. when a warnings filter escalates the category to an error);
        # previously the formatter stayed clobbered in that case.
        warnings.formatwarning = old_formatting
    warned_once.add(message_hash)
def warn_about_unsupported_type(type_str: str):
    # One-time warning pointing users at the supported-types documentation.
    warn_once(
        message=f"""You're attempting to log a type that is not directly supported by Neptune ({type_str}).
        Convert the value to a supported type, such as a string or float, or use stringify_unsupported(obj)
        for dictionaries or collections that contain unsupported values.
        For more, see https://docs.neptune.ai/help/value_of_unsupported_type""",
        exception=NeptuneUnsupportedType,
    )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from websocket import (
WebSocketConnectionClosedException,
WebSocketTimeoutException,
)
from neptune.internal.websockets.websocket_client_adapter import (
WebsocketClientAdapter,
WebsocketNotConnectedException,
)
class ReconnectingWebsocket(object):
    """Websocket wrapper that transparently reconnects with jittered backoff
    and refreshes the OAuth token before each connection attempt."""

    def __init__(self, url, oauth2_session, shutdown_event, proxies=None):
        self.url = url
        self.client = WebsocketClientAdapter()
        # Event signalled by the owner to stop all reconnect/read activity.
        self._shutdown_event = shutdown_event
        self._oauth2_session = oauth2_session
        self._reconnect_counter = ReconnectCounter()
        self._token = oauth2_session.token
        self._proxies = proxies

    def shutdown(self):
        # Signal shutdown first so the recv loop stops retrying, then tear
        # down the underlying client.
        self._shutdown_event.set()
        self.client.close()
        self.client.abort()
        self.client.shutdown()

    def recv(self):
        """Receive one message, reconnecting on connection failures.

        Timeouts are propagated to the caller; other errors trigger a
        reconnect attempt for as long as shutdown has not been requested.
        """
        if not self.client.connected:
            self._try_to_establish_connection()
        while self._is_active():
            try:
                data = self.client.recv()
                self._on_successful_connect()
                return data
            except WebSocketTimeoutException:
                # Timeouts are left for callers to handle.
                raise
            except WebSocketConnectionClosedException:
                if self._is_active():
                    self._handle_lost_connection()
                else:
                    raise
            except WebsocketNotConnectedException:
                if self._is_active():
                    self._handle_lost_connection()
            except Exception:
                # NOTE(review): any other error is treated as a lost connection
                # and silently retried while active.
                if self._is_active():
                    self._handle_lost_connection()

    def _is_active(self):
        return not self._shutdown_event.is_set()

    def _on_successful_connect(self):
        # A successful read resets the backoff counter.
        self._reconnect_counter.clear()

    def _try_to_establish_connection(self):
        try:
            self._request_token_refresh()
            if self.client.connected:
                self.client.shutdown()
            self.client.connect(url=self.url, token=self._token, proxies=self._proxies)
        except Exception:
            # Failed attempt: wait with jittered backoff (interruptible by the
            # shutdown event) before the caller retries.
            self._shutdown_event.wait(self._reconnect_counter.calculate_delay())

    def _handle_lost_connection(self):
        self._reconnect_counter.increment()
        self._try_to_establish_connection()

    def _request_token_refresh(self):
        self._token = self._oauth2_session.refresh_token(token_url=self._oauth2_session.auto_refresh_url)
class ReconnectCounter(object):
    """Counts consecutive reconnect attempts and produces a capped, jittered delay.

    Implements "full jitter" backoff: each delay is drawn uniformly from
    ``[0, min(2 ** (attempt - 1), MAX_RETRY_DELAY)]`` seconds.
    """

    # Upper bound of the backoff window, in seconds.
    MAX_RETRY_DELAY = 128

    def __init__(self):
        self.retries = 0

    def clear(self):
        """Reset the counter after a successful connection."""
        self.retries = 0

    def increment(self):
        """Record one more failed attempt."""
        self.retries += 1

    def calculate_delay(self):
        """Return the next backoff delay, in seconds."""
        return self._compute_delay(self.retries, self.MAX_RETRY_DELAY)

    @classmethod
    def _compute_delay(cls, attempt, max_delay):
        return cls._full_jitter_delay(attempt, max_delay)

    @classmethod
    def _full_jitter_delay(cls, attempt, cap):
        # Exponential window, capped, then fully randomized ("full jitter").
        window = min(2 ** (attempt - 1), cap)
        return random.uniform(0, window)
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import ssl
import urllib.parse
from websocket import (
ABNF,
create_connection,
)
class WebsocketClientAdapter(object):
    """Thin wrapper around the ``websocket-client`` library.

    Handles Bearer-token auth headers, proxy resolution (explicit mapping or
    HTTP_PROXY/HTTPS_PROXY environment variables), and optional acceptance of
    self-signed certificates.
    """

    def __init__(self):
        self._ws_client = None

    def connect(self, url, token, proxies=None):
        """Open a websocket connection to ``url``.

        Args:
            url: Target ``ws://`` or ``wss://`` address.
            token: OAuth2 token dict; its ``access_token`` entry is sent as a
                Bearer authorization header.
            proxies: Optional mapping of scheme ("http"/"https") to proxy URL;
                falls back to the matching ``*_PROXY`` environment variable.
        """
        sslopt = None
        if os.getenv("NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE"):
            # Opt-in escape hatch: skip certificate verification entirely.
            sslopt = {"cert_reqs": ssl.CERT_NONE}
        # Map ws -> http / wss -> https to select the matching proxy config.
        proto = url.split(":")[0].replace("ws", "http")
        proxy = proxies[proto] if proxies and proto in proxies else os.getenv("{}_PROXY".format(proto.upper()))
        if proxy:
            # Use urlparse's hostname/port accessors instead of splitting the
            # netloc on ":" - this correctly handles userinfo ("user:pass@host")
            # and IPv6 literals ("[::1]:8080"), which the naive split broke.
            parsed = urllib.parse.urlparse(proxy)
            proxy_host = parsed.hostname
            proxy_port = str(parsed.port) if parsed.port else ("80" if proto == "http" else "443")
        else:
            proxy_host = None
            proxy_port = None
        self._ws_client = create_connection(
            url,
            header=self._auth_header(token),
            sslopt=sslopt,
            http_proxy_host=proxy_host,
            http_proxy_port=proxy_port,
        )

    def recv(self):
        """Return the next text frame, decoded as UTF-8.

        Non-text frames are skipped until a text frame arrives.

        Raises:
            WebsocketNotConnectedException: If ``connect()`` was never called.
        """
        if self._ws_client is None:
            raise WebsocketNotConnectedException()
        opcode, data = None, None
        while opcode != ABNF.OPCODE_TEXT:
            opcode, data = self._ws_client.recv_data()
        return data.decode("utf-8")

    @property
    def connected(self):
        """Whether an underlying client exists and is currently connected."""
        # bool(...) so callers always get False rather than None when no
        # client has been created yet.
        return bool(self._ws_client and self._ws_client.connected)

    def close(self):
        if self._ws_client:
            return self._ws_client.close()

    def abort(self):
        if self._ws_client:
            return self._ws_client.abort()

    def shutdown(self):
        if self._ws_client:
            return self._ws_client.shutdown()

    @classmethod
    def _auth_header(cls, token):
        # websocket-client expects headers as a list of "Name: value" strings.
        return ["Authorization: Bearer " + token["access_token"]]
class WebsocketNotConnectedException(Exception):
    """Raised when an operation needs a websocket client that was never connected."""

    def __init__(self):
        super().__init__("Websocket client is not connected!")
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"NeptuneObject",
"Model",
"ModelVersion",
"Project",
"Run",
]
from neptune.objects.model import Model
from neptune.objects.model_version import ModelVersion
from neptune.objects.neptune_object import NeptuneObject
from neptune.objects.project import Project
from neptune.objects.run import Run
#
# Copyright (c) 2023, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["SupportsNamespaces"]
from abc import (
ABC,
abstractmethod,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from neptune.handler import Handler
class SupportsNamespaces(ABC):
    """
    Interface for Neptune objects that supports subscripting (selecting namespaces).

    It could be a Run, Model, ModelVersion, Project or already selected namespace (Handler).

    Example:
        >>> from neptune import init_run
        >>> from neptune.typing import SupportsNamespaces
        >>> class NeptuneCallback:
        ...     # Proper type hinting of `start_from` parameter.
        ...     def __init__(self, start_from: SupportsNamespaces):
        ...         self._start_from = start_from
        ...
        ...     def log_accuracy(self, accuracy: float) -> None:
        ...         self._start_from["train/acc"] = accuracy
        ...
        >>> run = init_run()
        >>> callback = NeptuneCallback(start_from=run)
        >>> callback.log_accuracy(0.8)
        >>> # or
        ... callback = NeptuneCallback(start_from=run["some/random/path"])
        >>> callback.log_accuracy(0.8)
    """

    @abstractmethod
    def __getitem__(self, path: str) -> "Handler":
        """Select the namespace or field located at ``path``."""

    @abstractmethod
    def __setitem__(self, key: str, value) -> None:
        """Assign ``value`` to the field at ``key``."""

    @abstractmethod
    def __delitem__(self, path) -> None:
        """Remove the field or namespace at ``path``."""

    @abstractmethod
    def get_root_object(self) -> "SupportsNamespaces":
        """Return the root Neptune object this namespace belongs to."""
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["ModelVersion"]
import os
from typing import (
TYPE_CHECKING,
List,
Optional,
)
from typing_extensions import Literal
from neptune.attributes.constants import (
SYSTEM_NAME_ATTRIBUTE_PATH,
SYSTEM_STAGE_ATTRIBUTE_PATH,
)
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import (
InactiveModelVersionException,
NeedExistingModelVersionForReadOnlyMode,
NeptuneMissingRequiredInitParameter,
NeptuneOfflineModeChangeStageException,
NeptuneUnsupportedFunctionalityException,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneException
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.operation_processors.offline_operation_processor import OfflineOperationProcessor
from neptune.internal.state import ContainerState
from neptune.internal.utils import verify_type
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.objects.neptune_object import (
NeptuneObject,
NeptuneObjectCallback,
)
from neptune.types.mode import Mode
from neptune.types.model_version_stage import ModelVersionStage
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
class ModelVersion(NeptuneObject):
    """Initializes a ModelVersion object from an existing or new model version.

    Before creating model versions, you must first register a model by creating a Model object.

    A ModelVersion object is suitable for storing model metadata that is version-specific. It does not track
    background metrics or logs automatically, but you can assign metadata to the model version just like you can
    for runs. You can use the parent Model object to store metadata that is common to all versions of the model.

    To learn more about model registry, see the docs: https://docs.neptune.ai/model_registry/overview/

    To manage the stage of a model version, use its `change_stage()` method or use the menu in the web app.

    You can also use the ModelVersion object as a context manager (see examples).

    Args:
        with_id: The Neptune identifier of an existing model version to resume, such as "CLS-PRE-3".
            The identifier is stored in the model version's "sys/id" field.
            If left empty, a new model version is created.
        name: Custom name for the model version. You can add it as a column in the model versions table
            ("sys/name"). You can also edit the name in the app, in the information view.
        model: Identifier of the model for which the new version should be created.
            Required when creating a new model version.
            You can find the model ID in the leftmost column of the models table, or in a model's "sys/id" field.
        project: Name of a project in the form `workspace-name/project-name`.
            If None, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        mode: Connection mode in which the tracking will work.
            If None (default), the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `read-only`, and `debug`.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
            (in seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a ModelVersion object as the argument and can contain any custom code, such as calling
            `stop()` on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback should take a
            ModelVersion object as the argument and can contain any custom code, such as calling `stop()` on the
            object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        ModelVersion object that is used to manage the model version and log metadata to it.

    Examples:
        >>> import neptune

        Creating a new model version:
        >>> # Create a new model version for a model with identifier "CLS-PRE"
        ... model_version = neptune.init_model_version(model="CLS-PRE")
        >>> model_version["your/structure"] = some_metadata
        >>> # You can provide the project parameter as an environment variable
        ... # or directly in the init_model_version() function:
        ... model_version = neptune.init_model_version(
        ...     model="CLS-PRE",
        ...     project="ml-team/classification",
        ... )
        >>> # Or initialize with the constructor:
        ... model_version = ModelVersion(model="CLS-PRE")

        Connecting to an existing model version:
        >>> # Initialize an existing model version with identifier "CLS-PRE-12"
        ... model_version = neptune.init_model_version(with_id="CLS-PRE-12")
        >>> # To prevent modifications when connecting to an existing model version,
        ... # you can connect in read-only mode:
        ... model_version = neptune.init_model_version(with_id="CLS-PRE-12", mode="read-only")

        Using the ModelVersion object as context manager:
        >>> with ModelVersion(model="CLS-PRE") as model_version:
        ...     model_version["metadata"] = some_metadata

    For more, see the docs:
        Initializing a model version:
            https://docs.neptune.ai/api/neptune#init_model_version
        ModelVersion class reference:
            https://docs.neptune.ai/api/model_version/
    """

    container_type = ContainerType.MODEL_VERSION

    def __init__(
        self,
        with_id: Optional[str] = None,
        *,
        name: Optional[str] = None,
        model: Optional[str] = None,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ) -> None:
        # not yet supported by the backend; the initialization code below is
        # intentionally unreachable and kept for when support returns
        raise NeptuneUnsupportedFunctionalityException
        verify_type("with_id", with_id, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("model", model, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("mode", mode, (str, type(None)))
        self._model: Optional[str] = model
        self._with_id: Optional[str] = with_id
        self._name: Optional[str] = DEFAULT_NAME if model is None and name is None else name
        # make mode proper Enum instead of string
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)
        if mode == Mode.OFFLINE:
            raise NeptuneException("ModelVersion can't be initialized in OFFLINE mode")
        if mode == Mode.DEBUG:
            project = OFFLINE_PROJECT_QUALIFIED_NAME
        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resolve the backend container: resume by id, create from model, or raise."""
        project_workspace = self._project_api_object.workspace
        project_name = self._project_api_object.name
        project_qualified_name = f"{project_workspace}/{project_name}"
        if self._with_id is not None:
            # with_id (resume existing model_version) has priority over model (creating a new model_version)
            return self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._with_id),
                expected_container_type=self.container_type,
            )
        elif self._model is not None:
            if self._mode == Mode.READ_ONLY:
                raise NeedExistingModelVersionForReadOnlyMode()
            # Look up the parent model first so a bad model id fails early.
            api_model = self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._model),
                expected_container_type=ContainerType.MODEL,
            )
            return self._backend.create_model_version(project_id=self._project_api_object.id, model_id=api_model.id)
        else:
            raise NeptuneMissingRequiredInitParameter(
                parameter_name="model",
                called_function="init_model_version",
            )

    def _get_background_jobs(self) -> List["BackgroundJob"]:
        # Only the keep-alive ping job runs for model versions.
        return [PingBackgroundJob()]

    def _write_initial_attributes(self):
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name

    def _raise_if_stopped(self):
        if self._state == ContainerState.STOPPED:
            raise InactiveModelVersionException(label=self._sys_id)

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_model_version_url(
            model_version_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
            sys_id=self._sys_id,
            model_id=self["sys/model_id"].fetch(),
        )

    def change_stage(self, stage: str) -> None:
        """Changes the stage of the model version.

        This method is always synchronous, which means that Neptune will wait for all other calls to reach the Neptune
        servers before executing it.

        Args:
            stage: The new stage of the model version.
                Possible values are `none`, `staging`, `production`, and `archived`.

        Examples:
            >>> import neptune
            >>> model_version = neptune.init_model_version(with_id="CLS-TREE-3")
            >>> # If the model is good enough, promote it to the staging
            ... val_acc = model_version["validation/metrics/acc"].fetch()
            >>> if val_acc >= ACC_THRESHOLD:
            ...     model_version.change_stage("staging")

        Learn more about stage management in the docs:
            https://docs.neptune.ai/model_registry/managing_stage/

        API reference:
            https://docs.neptune.ai/api/model_version/#change_stage
        """
        # Validate the stage name before touching any state (raises ValueError
        # for unknown stages).
        mapped_stage = ModelVersionStage(stage)
        if isinstance(self._op_processor, OfflineOperationProcessor):
            raise NeptuneOfflineModeChangeStageException()
        # Stage changes are synchronous: flush all queued operations first.
        self.wait()
        with self.lock():
            attr = self.get_attribute(SYSTEM_STAGE_ATTRIBUTE_PATH)
            # We are sure that such attribute exists, because
            # SYSTEM_STAGE_ATTRIBUTE_PATH is set by default on ModelVersion creation
            assert attr is not None, f"No {SYSTEM_STAGE_ATTRIBUTE_PATH} found in model version"
            attr.process_assignment(
                value=mapped_stage.value,
                wait=True,
            )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Model"]
import os
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
)
from typing_extensions import Literal
from neptune.attributes.constants import SYSTEM_NAME_ATTRIBUTE_PATH
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import (
InactiveModelException,
NeedExistingModelForReadOnlyMode,
NeptuneMissingRequiredInitParameter,
NeptuneModelKeyAlreadyExistsError,
NeptuneObjectCreationConflict,
NeptuneUnsupportedFunctionalityException,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQueryAggregate,
NQLQueryAttribute,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneException
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.state import ContainerState
from neptune.internal.utils import verify_type
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.objects.neptune_object import (
NeptuneObject,
NeptuneObjectCallback,
)
from neptune.objects.utils import build_raw_query
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.typing import (
ProgressBarCallback,
ProgressBarType,
)
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
class Model(NeptuneObject):
    """Initializes a Model object from an existing or new model.

    You can use this to create a new model from code or to perform actions on existing models.

    A Model object is suitable for storing model metadata that is common to all versions (you can use ModelVersion
    objects to track version-specific metadata). It does not track background metrics or logs automatically,
    but you can assign metadata to the Model object just like you can for runs.

    To learn more about model registry, see the docs: https://docs.neptune.ai/model_registry/overview/

    You can also use the Model object as a context manager (see examples).

    Args:
        with_id: The Neptune identifier of an existing model to resume, such as "CLS-PRE".
            The identifier is stored in the model's "sys/id" field.
            If left empty, a new model is created.
        name: Custom name for the model. You can add it as a column in the models table ("sys/name").
            You can also edit the name in the app, in the information view.
        key: Key for the model. Required when creating a new model.
            Used together with the project key to form the model identifier.
            Must be uppercase and unique within the project.
        project: Name of a project in the form `workspace-name/project-name`.
            If None, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        mode: Connection mode in which the tracking will work.
            If `None` (default), the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `read-only`, and `debug`.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
            (in seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a Model object as the argument and can contain any custom code, such as calling `stop()` on
            the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback should take a Model
            object as the argument and can contain any custom code, such as calling `stop()` on the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        Model object that is used to manage the model and log metadata to it.

    Examples:
        >>> import neptune

        Creating a new model:
        >>> model = neptune.init_model(key="PRE")
        >>> model["metadata"] = some_metadata
        >>> # Or initialize with the constructor
        ... model = Model(key="PRE")
        >>> # You can provide the project parameter as an environment variable
        ... # or as an argument to the init_model() function:
        ... model = neptune.init_model(key="PRE", project="workspace-name/project-name")
        >>> # When creating a model, you can give it a name:
        ... model = neptune.init_model(key="PRE", name="Pre-trained model")

        Connecting to an existing model:
        >>> # Initialize existing model with identifier "CLS-PRE"
        ... model = neptune.init_model(with_id="CLS-PRE")
        >>> # To prevent modifications when connecting to an existing model, you can connect in read-only mode
        ... model = neptune.init_model(with_id="CLS-PRE", mode="read-only")

        Using the Model object as context manager:
        >>> with Model(key="PRE") as model:
        ...     model["metadata"] = some_metadata

    For details, see the docs:
        Initializing a model:
            https://docs.neptune.ai/api/neptune#init_model
        Model class reference:
            https://docs.neptune.ai/api/model
    """

    container_type = ContainerType.MODEL

    def __init__(
        self,
        with_id: Optional[str] = None,
        *,
        name: Optional[str] = None,
        key: Optional[str] = None,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ):
        # not yet supported by the backend; the initialization code below is
        # intentionally unreachable and kept for when support returns
        raise NeptuneUnsupportedFunctionalityException
        verify_type("with_id", with_id, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("key", key, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("mode", mode, (str, type(None)))
        self._key: Optional[str] = key
        self._with_id: Optional[str] = with_id
        self._name: Optional[str] = DEFAULT_NAME if with_id is None and name is None else name
        # make mode proper Enum instead of string
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)
        if mode == Mode.OFFLINE:
            raise NeptuneException("Model can't be initialized in OFFLINE mode")
        if mode == Mode.DEBUG:
            project = OFFLINE_PROJECT_QUALIFIED_NAME
        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resolve the backend container: resume by id, create from key, or raise."""
        project_workspace = self._project_api_object.workspace
        project_name = self._project_api_object.name
        project_qualified_name = f"{project_workspace}/{project_name}"
        if self._with_id is not None:
            # with_id (resume existing model) has priority over key (creating a new model)
            # additional creation parameters (e.g. name) are simply ignored in this scenario
            return self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._with_id),
                expected_container_type=self.container_type,
            )
        elif self._key is not None:
            if self._mode == Mode.READ_ONLY:
                raise NeedExistingModelForReadOnlyMode()
            try:
                return self._backend.create_model(project_id=self._project_api_object.id, key=self._key)
            except NeptuneObjectCreationConflict as e:
                # The key is already taken: point the user at the models table.
                base_url = self._backend.get_display_address()
                raise NeptuneModelKeyAlreadyExistsError(
                    model_key=self._key,
                    models_tab_url=f"{base_url}/{project_workspace}/{project_name}/models",
                ) from e
        else:
            raise NeptuneMissingRequiredInitParameter(
                parameter_name="key",
                called_function="init_model",
            )

    def _get_background_jobs(self) -> List["BackgroundJob"]:
        # Only the keep-alive ping job runs for models.
        return [PingBackgroundJob()]

    def _write_initial_attributes(self):
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name

    def _raise_if_stopped(self):
        if self._state == ContainerState.STOPPED:
            raise InactiveModelException(label=self._sys_id)

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_model_url(
            model_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
            sys_id=self._sys_id,
        )

    def fetch_model_versions_table(
        self,
        *,
        query: Optional[str] = None,
        columns: Optional[Iterable[str]] = None,
        limit: Optional[int] = None,
        sort_by: str = "sys/creation_time",
        ascending: bool = False,
        progress_bar: Optional[ProgressBarType] = None,
    ) -> Table:
        """Retrieve all versions of the given model.

        Args:
            query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
                Example: `"(model_size: float > 100) AND (backbone: string = VGG)"`.
            columns: Names of columns to include in the table, as a list of field names.
                The Neptune ID ("sys/id") is included automatically.
                If `None` (default), all the columns of the model versions table are included,
                up to a maximum of 10 000 columns.
            limit: How many entries to return at most. If `None`, all entries are returned.
            sort_by: Name of the field to sort the results by.
                The field must represent a simple type (string, float, datetime, integer, or Boolean).
            ascending: Whether to sort the entries in ascending order of the sorting column values.
            progress_bar: Set to `False` to disable the download progress bar,
                or pass a `ProgressBarCallback` class to use your own progress bar callback.

        Returns:
            `Table` object containing `ModelVersion` objects that match the specified criteria.
            Use `to_pandas()` to convert it to a pandas DataFrame.

        Examples:
            >>> import neptune
            ... # Initialize model with the ID "CLS-FOREST"
            ... model = neptune.init_model(with_id="CLS-FOREST")
            ... # Fetch the metadata of all the model's versions as a pandas DataFrame
            ... model_versions_df = model.fetch_model_versions_table().to_pandas()
            >>> # Include only the fields "params/lr" and "val/loss" as columns:
            ... model_versions_df = model.fetch_model_versions_table(columns=["params/lr", "val/loss"]).to_pandas()
            >>> # Sort model versions by size (space they take up in Neptune)
            ... model_versions_df = model.fetch_model_versions_table(sort_by="sys/size").to_pandas()
            ... # Extract the ID of the largest model version object
            ... largest_model_version_id = model_versions_df["sys/id"].values[0]
            >>> # Fetch model versions with VGG backbone
            ... models_table_df = model.fetch_model_versions_table(
            ...     query="(backbone: string = VGG)"
            ... ).to_pandas()

        See also the API reference:
            https://docs.neptune.ai/api/model/#fetch_model_versions_table
        """
        verify_type("query", query, (str, type(None)))
        verify_type("limit", limit, (int, type(None)))
        verify_type("sort_by", sort_by, str)
        verify_type("ascending", ascending, bool)
        verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))
        if isinstance(limit, int) and limit <= 0:
            raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")
        query = query if query is not None else ""
        # AND the user's query with a sys/model_id filter so only versions of
        # this particular model are returned.
        nql = build_raw_query(query=query, trashed=False)
        nql = NQLQueryAggregate(
            items=[
                nql,
                NQLQueryAttribute(
                    name="sys/model_id",
                    value=self._sys_id,
                    operator=NQLAttributeOperator.EQUALS,
                    type=NQLAttributeType.STRING,
                ),
            ],
            aggregator=NQLAggregator.AND,
        )
        return NeptuneObject._fetch_entries(
            self,
            child_type=ContainerType.MODEL_VERSION,
            query=nql,
            columns=columns,
            limit=limit,
            sort_by=sort_by,
            ascending=ascending,
            progress_bar=progress_bar,
        )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["NeptuneObject"]
import abc
import atexit
import itertools
import logging
import os
import threading
import time
import traceback
from contextlib import AbstractContextManager
from functools import (
partial,
wraps,
)
from queue import Queue
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
)
from neptune.api.models import FieldType
from neptune.attributes import create_attribute_from_type
from neptune.attributes.attribute import Attribute
from neptune.attributes.namespace import Namespace as NamespaceAttr
from neptune.attributes.namespace import NamespaceBuilder
from neptune.envs import (
NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK,
NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK,
)
from neptune.exceptions import MetadataInconsistency
from neptune.handler import Handler
from neptune.internal.backends.api_model import (
ApiExperiment,
Project,
)
from neptune.internal.backends.factory import get_backend
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.backends.nql import NQLQuery
from neptune.internal.backends.project_name_lookup import project_name_lookup
from neptune.internal.backgroud_job_list import BackgroundJobList
from neptune.internal.background_job import BackgroundJob
from neptune.internal.container_structure import ContainerStructure
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import UNIX_STYLES
from neptune.internal.id_formats import (
QualifiedName,
SysId,
UniqueId,
conform_optional,
)
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
)
from neptune.internal.operation import DeleteAttribute
from neptune.internal.operation_processors.factory import get_operation_processor
from neptune.internal.operation_processors.lazy_operation_processor_wrapper import LazyOperationProcessorWrapper
from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.signals_processing.background_job import CallbacksMonitor
from neptune.internal.state import ContainerState
from neptune.internal.utils import (
verify_optional_callable,
verify_type,
)
from neptune.internal.utils.logger import (
get_disabled_logger,
get_logger,
)
from neptune.internal.utils.paths import parse_path
from neptune.internal.utils.uncaught_exception_handler import instance as uncaught_exception_handler
from neptune.internal.utils.utils import reset_internal_ssl_state
from neptune.internal.value_to_attribute_visitor import ValueToAttributeVisitor
from neptune.internal.warnings import warn_about_unsupported_type
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.types.type_casting import cast_value
from neptune.typing import ProgressBarType
from neptune.utils import stop_synchronization_callback
if TYPE_CHECKING:
from neptune.internal.signals_processing.signals import Signal
NeptuneObjectCallback = Callable[["NeptuneObject"], None]
def ensure_not_stopped(fun):
    """Decorator for ``NeptuneObject`` methods that must not run after stop().

    The wrapper first delegates to ``self._raise_if_stopped()`` (which raises
    the subclass-specific "inactive container" exception when the object has
    been stopped) and only then forwards the call to the original method.
    """

    @wraps(fun)
    def guarded(self: "NeptuneObject", *args, **kwargs):
        self._raise_if_stopped()
        return fun(self, *args, **kwargs)

    return guarded
class NeptuneObject(AbstractContextManager):
    """Common base class for Neptune containers (runs, models, projects).

    Owns the backend connection, the local attribute structure, the operation
    processor that ships metadata to the server, and the background jobs.
    Coordinates the container lifecycle (CREATED -> STARTED -> STOPPING ->
    STOPPED) and keeps that machinery consistent across ``os.fork()``.
    """

    # Set by each concrete subclass to its ContainerType (RUN, MODEL, ...).
    container_type: ContainerType

    def __init__(
        self,
        *,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        mode: Mode = Mode.ASYNC,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
    ):
        # Validate all arguments before any backend communication happens.
        verify_type("project", project, (str, type(None)))
        verify_type("api_token", api_token, (str, type(None)))
        verify_type("mode", mode, Mode)
        verify_type("flush_period", flush_period, (int, float))
        verify_type("proxies", proxies, (dict, type(None)))
        verify_type("async_lag_threshold", async_lag_threshold, (int, float))
        verify_optional_callable("async_lag_callback", async_lag_callback)
        verify_type("async_no_progress_threshold", async_no_progress_threshold, (int, float))
        verify_optional_callable("async_no_progress_callback", async_no_progress_callback)
        self._mode: Mode = mode
        self._flush_period = flush_period
        self._lock: threading.RLock = threading.RLock()
        # Condition + flag used to serialize stop() against an in-progress fork.
        self._forking_cond: threading.Condition = threading.Condition()
        self._forking_state: bool = False
        self._state: ContainerState = ContainerState.CREATED
        self._signals_queue: "Queue[Signal]" = Queue()
        self._logger: logging.Logger = get_logger()
        # Resolve the backend and the target project, then obtain (or create)
        # the API-side object this container represents.
        self._backend: NeptuneBackend = get_backend(mode=mode, api_token=api_token, proxies=proxies)
        self._project_qualified_name: Optional[str] = conform_optional(project, QualifiedName)
        self._project_api_object: Project = project_name_lookup(
            backend=self._backend, name=self._project_qualified_name
        )
        self._project_id: UniqueId = self._project_api_object.id
        self._api_object: ApiExperiment = self._get_or_create_api_object()
        self._id: UniqueId = self._api_object.id
        self._sys_id: SysId = self._api_object.sys_id
        self._workspace: str = self._api_object.workspace
        self._project_name: str = self._api_object.project_name
        self._async_lag_threshold = async_lag_threshold
        self._async_lag_callback = NeptuneObject._get_callback(
            provided=async_lag_callback,
            env_name=NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK,
        )
        self._async_no_progress_threshold = async_no_progress_threshold
        self._async_no_progress_callback = NeptuneObject._get_callback(
            provided=async_no_progress_callback,
            env_name=NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK,
        )
        self._op_processor: OperationProcessor = get_operation_processor(
            mode=mode,
            container_id=self._id,
            container_type=self.container_type,
            backend=self._backend,
            lock=self._lock,
            flush_period=flush_period,
            queue=self._signals_queue,
        )
        self._bg_job: BackgroundJobList = self._prepare_background_jobs_if_non_read_only()
        self._structure: ContainerStructure[Attribute, NamespaceAttr] = ContainerStructure(NamespaceBuilder(self))
        # Populate the local attribute structure from the server unless offline.
        if self._mode != Mode.OFFLINE:
            self.sync(wait=False)
        if self._mode != Mode.READ_ONLY:
            self._write_initial_attributes()
        self._startup(debug_mode=mode == Mode.DEBUG)
        try:
            # Keep the background machinery consistent across os.fork(): pause
            # before the fork, resume in the parent, rebuild in the child.
            os.register_at_fork(
                before=self._before_fork,
                after_in_child=self._handle_fork_in_child,
                after_in_parent=self._handle_fork_in_parent,
            )
        except AttributeError:  # os.register_at_fork is not available on all platforms
            pass
        """
        OpenSSL's internal random number generator does not properly handle forked processes.
        Applications must change the PRNG state of the parent process if they use any SSL feature with os.fork().
        Any successful call of RAND_add(), RAND_bytes() or RAND_pseudo_bytes() is sufficient.
        https://docs.python.org/3/library/ssl.html#multi-processing
        On Linux it looks like it does not help much but does not break anything either.
        """

    @staticmethod
    def _get_callback(provided: Optional[NeptuneObjectCallback], env_name: str) -> Optional[NeptuneObjectCallback]:
        """Return the explicitly provided callback, or the default
        stop-synchronization callback when the env var ``env_name`` is "TRUE",
        otherwise None."""
        if provided is not None:
            return provided
        if os.getenv(env_name, "") == "TRUE":
            return stop_synchronization_callback
        return None

    def _handle_fork_in_parent(self):
        """``os.register_at_fork`` after-in-parent hook: resume the work paused
        by ``_before_fork`` and release anyone blocked on the forking flag."""
        reset_internal_ssl_state()
        if self._state == ContainerState.STARTED:
            self._op_processor.resume()
            self._bg_job.resume()
        with self._forking_cond:
            self._forking_state = False
            self._forking_cond.notify_all()

    def _handle_fork_in_child(self):
        """``os.register_at_fork`` after-in-child hook: discard inherited
        machinery (threads do not survive fork) and rebuild it lazily."""
        reset_internal_ssl_state()
        # Child processes must not log through the parent's logger setup.
        self._logger = get_disabled_logger()
        if self._state == ContainerState.STARTED:
            self._op_processor.close()
            # Fresh queue: the inherited one may hold signals consumed by the
            # parent's (now-dead) consumer thread.
            self._signals_queue = Queue()
            # Lazy wrapper: the real processor is only created on first use,
            # so an idle child never spins up sync threads.
            self._op_processor = LazyOperationProcessorWrapper(
                operation_processor_getter=partial(
                    get_operation_processor,
                    mode=self._mode,
                    container_id=self._id,
                    container_type=self.container_type,
                    backend=self._backend,
                    lock=self._lock,
                    flush_period=self._flush_period,
                    queue=self._signals_queue,
                ),
            )
            # TODO: Every implementation of background job should handle fork by itself.
            jobs = []
            if self._mode == Mode.ASYNC:
                jobs.append(
                    CallbacksMonitor(
                        queue=self._signals_queue,
                        async_lag_threshold=self._async_lag_threshold,
                        async_no_progress_threshold=self._async_no_progress_threshold,
                        async_lag_callback=self._async_lag_callback,
                        async_no_progress_callback=self._async_no_progress_callback,
                    )
                )
            self._bg_job = BackgroundJobList(jobs)
        with self._forking_cond:
            self._forking_state = False
            self._forking_cond.notify_all()

    def _before_fork(self):
        """``os.register_at_fork`` before hook: wait out an in-progress stop(),
        mark the fork as in progress, and pause background work."""
        with self._forking_cond:
            self._forking_cond.wait_for(lambda: self._state != ContainerState.STOPPING)
            self._forking_state = True
        if self._state == ContainerState.STARTED:
            self._bg_job.pause()
            self._op_processor.pause()

    def _prepare_background_jobs_if_non_read_only(self) -> BackgroundJobList:
        """Assemble the background jobs for this container: subclass-specific
        jobs (non-read-only modes) plus the async callbacks monitor."""
        jobs = []
        if self._mode != Mode.READ_ONLY:
            jobs.extend(self._get_background_jobs())
        if self._mode == Mode.ASYNC:
            jobs.append(
                CallbacksMonitor(
                    queue=self._signals_queue,
                    async_lag_threshold=self._async_lag_threshold,
                    async_no_progress_threshold=self._async_no_progress_threshold,
                    async_lag_callback=self._async_lag_callback,
                    async_no_progress_callback=self._async_no_progress_callback,
                )
            )
        return BackgroundJobList(jobs)

    @abc.abstractmethod
    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resolve (or create) the backend-side object this container wraps."""
        raise NotImplementedError

    def _get_background_jobs(self) -> List["BackgroundJob"]:
        # Subclasses override to contribute container-specific background jobs.
        return []

    def _write_initial_attributes(self):
        # Hook for subclasses to populate initial metadata on creation.
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Context-manager exit: surface the exception traceback (if any) and
        # always stop the container. The exception is not suppressed.
        if exc_tb is not None:
            traceback.print_exception(exc_type, exc_val, exc_tb)
        self.stop()

    def __getattr__(self, item):
        # Only called for attributes not found normally; raises with an
        # explicit, class-qualified message.
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")

    @abc.abstractmethod
    def _raise_if_stopped(self):
        """Raise the subclass-specific exception if this container is stopped."""
        raise NotImplementedError

    def _get_subpath_suggestions(self, path_prefix: Optional[str] = None, limit: int = 1000) -> List[str]:
        # Used for interactive completion: list up to `limit` known subpaths
        # under `path_prefix` (the whole structure when no prefix is given).
        parsed_path = parse_path(path_prefix or "")
        return list(itertools.islice(self._structure.iterate_subpaths(parsed_path), limit))

    def _ipython_key_completions_(self):
        # IPython hook powering tab-completion of run["<TAB>"].
        return self._get_subpath_suggestions()

    @ensure_not_stopped
    def __getitem__(self, path: str) -> "Handler":
        return Handler(self, path)

    @ensure_not_stopped
    def __setitem__(self, key: str, value) -> None:
        self.__getitem__(key).assign(value)

    @ensure_not_stopped
    def __delitem__(self, path) -> None:
        self.pop(path)

    @ensure_not_stopped
    def assign(self, value, *, wait: bool = False) -> None:
        """Assigns values to multiple fields from a dictionary.
        You can use this method to quickly log all parameters at once.
        Args:
            value (dict): A dictionary with values to assign, where keys become paths of the fields.
                The dictionary can be nested, in which case the path will be a combination of all the keys.
            wait: If `True`, Neptune waits to send all tracked metadata to the server before executing the call.
        Examples:
            >>> import neptune
            >>> run = neptune.init_run()
            >>> # Assign a single value with the Python "=" operator
            >>> run["parameters/learning_rate"] = 0.8
            >>> # or the assign() method
            >>> run["parameters/learning_rate"].assign(0.8)
            >>> # Assign a dictionary with the Python "=" operator
            >>> run["parameters"] = {"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8}
            >>> # or the assign() method
            >>> run.assign({"parameters": {"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8}})
        When operating on a handler object, you can use assign() to circumvent normal Python variable assignment.
            >>> params = run["params"]
            >>> params.assign({"max_epochs": 10, "optimizer": "Adam", "learning_rate": 0.8})
        See also the API reference:
            https://docs.neptune.ai/api/universal/#assign
        """
        self._get_root_handler().assign(value, wait=wait)

    @ensure_not_stopped
    def fetch(self) -> dict:
        """Fetch values of all non-File Atom fields as a dictionary.
        You can use this method to retrieve metadata from a started or resumed run.
        The result preserves the hierarchical structure of the run's metadata, but only contains Atom fields.
        This means fields that contain single values, as opposed to series, files, or sets.
        Returns:
            `dict` containing the values of all non-File Atom fields.
        Examples:
            Resuming an existing run and fetching metadata from it:
            >>> import neptune
            >>> resumed_run = neptune.init_run(with_id="CLS-3")
            >>> params = resumed_run["model/parameters"].fetch()
            >>> run_data = resumed_run.fetch()
            >>> print(run_data)
            >>> # prints all Atom attributes stored in run as a dict
            Fetching metadata from an existing model version:
            >>> model_version = neptune.init_model_version(with_id="CLS-TREE-45")
            >>> optimizer = model_version["parameters/optimizer"].fetch()
        See also the API reference:
            https://docs.neptune.ai/api/universal#fetch
        """
        return self._get_root_handler().fetch()

    def ping(self):
        """Notify the backend that this container is still alive."""
        self._backend.ping(self._id, self.container_type)

    def start(self):
        """Start the operation processor and background jobs and register the
        atexit shutdown hook; moves the container into the STARTED state."""
        atexit.register(self._shutdown_hook)
        self._op_processor.start()
        self._bg_job.start(self)
        self._state = ContainerState.STARTED

    def stop(self, *, seconds: Optional[Union[float, int]] = None) -> None:
        """Stops the connection and ends the synchronization thread.
        You should stop any initialized runs or other objects when the connection to them is no longer needed.
        This method is automatically called:
        - when the script that created the run or other object finishes execution.
        - if using a context manager, on destruction of the Neptune context.
        Note: In interactive sessions, such as Jupyter Notebook, objects are stopped automatically only when
        the Python kernel stops. However, background monitoring of system metrics and standard streams is disabled
        unless explicitly enabled when initializing Neptune.
        Args:
            seconds: Seconds to wait for all metadata tracking calls to finish before stopping the object.
                If `None`, waits for all tracking calls to finish.
        Example:
            >>> import neptune
            >>> run = neptune.init_run()
            >>> # Your training or monitoring code
            >>> run.stop()
        See also the docs:
            Best practices - Stopping objects
            https://docs.neptune.ai/usage/best_practices/#stopping-runs-and-other-objects
            API reference:
            https://docs.neptune.ai/api/universal/#stop
        """
        verify_type("seconds", seconds, (float, int, type(None)))
        # Idempotent: a second stop() (e.g. atexit after an explicit stop) is a no-op.
        if self._state != ContainerState.STARTED:
            return
        # Do not start stopping in the middle of a fork; the before-fork hook
        # symmetrically waits for STOPPING to finish.
        with self._forking_cond:
            self._forking_cond.wait_for(lambda: not self._forking_state)
            self._state = ContainerState.STOPPING
        ts = time.time()
        self._logger.info("Shutting down background jobs, please wait a moment...")
        self._bg_job.stop()
        self._bg_job.join(seconds)
        self._logger.info("Done!")
        # NOTE(review): sec_left can be negative if joining the background jobs
        # already exceeded `seconds`; presumably the processor treats that as
        # "don't wait" — confirm against OperationProcessor.stop().
        sec_left = None if seconds is None else seconds - (time.time() - ts)
        self._op_processor.stop(sec_left)
        if self._mode not in {Mode.OFFLINE, Mode.DEBUG}:
            metadata_url = self.get_url().rstrip("/") + "/metadata"
            self._logger.info(f"Explore the metadata in the Neptune app: {metadata_url}")
        self._backend.close()
        with self._forking_cond:
            self._state = ContainerState.STOPPED
            self._forking_cond.notify_all()

    def get_state(self) -> str:
        """Returns the current state of the container as a string.
        Examples:
            >>> from neptune import init_run
            >>> run = init_run()
            >>> run.get_state()
            'started'
            >>> run.stop()
            >>> run.get_state()
            'stopped'
        """
        return self._state.value

    def get_structure(self) -> Dict[str, Any]:
        """Returns the object's metadata structure as a dictionary.
        This method can be used to programmatically traverse the metadata structure of a run, model,
        or project object when using Neptune in automated workflows.
        Note: The returned object is a deep copy of the structure of the internal object.
        See also the API reference:
            https://docs.neptune.ai/api/universal/#get_structure
        """
        return self._structure.get_structure().to_dict()

    def print_structure(self) -> None:
        """Pretty-prints the structure of the object's metadata.
        Paths are ordered lexicographically and the whole structure is neatly colored.
        See also: https://docs.neptune.ai/api/universal/#print_structure
        """
        self._print_structure_impl(self.get_structure(), indent=0)

    def _print_structure_impl(self, struct: dict, indent: int) -> None:
        # Recursive helper: namespaces (dicts) print their key and recurse one
        # indent level deeper; leaves print the key and the value's type name.
        for key in sorted(struct.keys()):
            print("    " * indent, end="")
            if isinstance(struct[key], dict):
                print("{blue}'{key}'{end}:".format(blue=UNIX_STYLES["blue"], key=key, end=UNIX_STYLES["end"]))
                self._print_structure_impl(struct[key], indent=indent + 1)
            else:
                print(
                    "{blue}'{key}'{end}: {type}".format(
                        blue=UNIX_STYLES["blue"],
                        key=key,
                        end=UNIX_STYLES["end"],
                        type=type(struct[key]).__name__,
                    )
                )

    def define(
        self,
        path: str,
        value: Any,
        *,
        wait: bool = False,
    ) -> Optional[Attribute]:
        """Create a new attribute under `path` from `value` and enqueue its
        first assignment.

        Raises:
            MetadataInconsistency: if something is already defined at `path`.
        Returns:
            The created Attribute, or None when `value` has an unsupported type
            (in which case only a warning is emitted).
        """
        with self._lock:
            old_attr = self.get_attribute(path)
            if old_attr is not None:
                raise MetadataInconsistency("Attribute or namespace {} is already defined".format(path))
            neptune_value = cast_value(value)
            if neptune_value is None:
                warn_about_unsupported_type(type_str=str(type(value)))
                return None
            attr = ValueToAttributeVisitor(self, parse_path(path)).visit(neptune_value)
            self.set_attribute(path, attr)
            attr.process_assignment(neptune_value, wait=wait)
            return attr

    def get_attribute(self, path: str) -> Optional[Attribute]:
        """Return the attribute stored under `path`, or None if undefined."""
        with self._lock:
            return self._structure.get(parse_path(path))

    def set_attribute(self, path: str, attribute: Attribute) -> Optional[Attribute]:
        """Store `attribute` under `path` in the local structure."""
        with self._lock:
            return self._structure.set(parse_path(path), attribute)

    def exists(self, path: str) -> bool:
        """Checks if there is a field or namespace under the specified path."""
        verify_type("path", path, str)
        return self.get_attribute(path) is not None

    @ensure_not_stopped
    def pop(self, path: str, *, wait: bool = False) -> None:
        """Removes the field stored under the path and all data associated with it.
        Args:
            path: Path of the field to be removed.
            wait: If `True`, Neptune waits to send all tracked metadata to the server before executing the call.
        Examples:
            >>> import neptune
            >>> run = neptune.init_run()
            >>> run["parameters/learninggg_rata"] = 0.3
            >>> # Let's delete that misspelled field along with its data
            ... run.pop("parameters/learninggg_rata")
            >>> run["parameters/learning_rate"] = 0.3
            >>> # Training finished
            ... run["trained_model"].upload("model.pt")
            >>> # "model_checkpoint" is a File field
            ... run.pop("model_checkpoint")
        See also the API reference:
            https://docs.neptune.ai/api/universal/#pop
        """
        verify_type("path", path, str)
        self._get_root_handler().pop(path, wait=wait)

    def _pop_impl(self, parsed_path: List[str], *, wait: bool):
        # Remove locally first, then enqueue the server-side deletion.
        self._structure.pop(parsed_path)
        self._op_processor.enqueue_operation(DeleteAttribute(parsed_path), wait=wait)

    def lock(self) -> threading.RLock:
        """Expose the container-wide re-entrant lock guarding the structure."""
        return self._lock

    def wait(self, *, disk_only=False) -> None:
        """Wait for all the queued metadata tracking calls to reach the Neptune servers.
        Args:
            disk_only: If `True`, the process will only wait for data to be saved
                locally from memory, but will not wait for them to reach Neptune servers.
        See also the API reference:
            https://docs.neptune.ai/api/universal/#wait
        """
        with self._lock:
            if disk_only:
                self._op_processor.flush()
            else:
                self._op_processor.wait()

    def sync(self, *, wait: bool = True) -> None:
        """Synchronizes the local representation of the object with the representation on the Neptune servers.
        Args:
            wait: If `True`, the process will only wait for data to be saved
                locally from memory, but will not wait for them to reach Neptune servers.
        Example:
            >>> import neptune
            >>> # Connect to a run from Worker #3
            ... worker_id = 3
            >>> run = neptune.init_run(with_id="DIST-43", monitoring_namespace=f"monitoring/{worker_id}")
            >>> # Try to access logs that were created in the meantime by Worker #2
            ... worker_2_status = run["status/2"].fetch()
            ... # Error if this field was created after this script starts
            >>> run.sync() # Synchronizes local representation with Neptune servers
            >>> worker_2_status = run["status/2"].fetch()
            ... # No error
        See also the API reference:
            https://docs.neptune.ai/api/universal/#sync
        """
        with self._lock:
            if wait:
                self._op_processor.wait()
            # Rebuild the local structure from the server's attribute list.
            attributes = self._backend.get_attributes(self._id, self.container_type)
            self._structure.clear()
            for attribute in attributes:
                self._define_attribute(parse_path(attribute.path), attribute.type)

    def _define_attribute(self, _path: List[str], _type: FieldType):
        # Instantiate the attribute class matching the server-reported field type.
        attr = create_attribute_from_type(_type, self, _path)
        self._structure.set(_path, attr)

    def _get_root_handler(self):
        # Handler anchored at the container root ("" path).
        return Handler(self, "")

    @abc.abstractmethod
    def get_url(self) -> str:
        """Returns a link to the object in the Neptune app.
        The same link is printed in the console once the object has been initialized.
        API reference: https://docs.neptune.ai/api/universal/#get_url
        """
        ...

    def _startup(self, debug_mode):
        # Announce the app URL (except in debug mode), start the machinery,
        # and hook the uncaught-exception handler.
        if not debug_mode:
            self._logger.info(f"Neptune initialized. Open in the app: {self.get_url()}")
        self.start()
        uncaught_exception_handler.activate()

    def _shutdown_hook(self):
        # Registered with atexit by start(); stop() itself is idempotent.
        self.stop()

    def _fetch_entries(
        self,
        child_type: ContainerType,
        query: NQLQuery,
        columns: Optional[Iterable[str]],
        limit: Optional[int],
        sort_by: str,
        ascending: bool,
        progress_bar: Optional[ProgressBarType],
    ) -> Table:
        """Shared implementation behind the fetch_*_table() methods: run an NQL
        leaderboard search for children of this container and wrap the result
        in a Table."""
        if columns is not None:
            # always return entries with 'sys/id' and the column chosen for sorting when filter applied
            columns = set(columns)
            columns.add("sys/id")
            columns.add(sort_by)
        leaderboard_entries = self._backend.search_leaderboard_entries(
            project_id=self._project_id,
            types=[child_type],
            query=query,
            columns=columns,
            limit=limit,
            sort_by=sort_by,
            ascending=ascending,
            progress_bar=progress_bar,
        )
        return Table(
            backend=self._backend,
            container_type=child_type,
            entries=leaderboard_entries,
        )

    def get_root_object(self) -> "NeptuneObject":
        """Returns the same Neptune object."""
        return self
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Project"]
import os
from typing import (
Iterable,
Optional,
Union,
)
from typing_extensions import Literal
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import (
InactiveProjectException,
NeptuneUnsupportedFunctionalityException,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.container_type import ContainerType
from neptune.internal.exceptions import NeptuneException
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
)
from neptune.internal.state import ContainerState
from neptune.internal.utils import (
as_list,
verify_collection_type,
verify_type,
verify_value,
)
from neptune.objects.neptune_object import (
NeptuneObject,
NeptuneObjectCallback,
)
from neptune.objects.utils import (
build_raw_query,
prepare_nql_query,
)
from neptune.table import Table
from neptune.types.mode import Mode
from neptune.typing import (
ProgressBarCallback,
ProgressBarType,
)
class Project(NeptuneObject):
"""Starts a connection to an existing Neptune project.
You can use the Project object to retrieve information about runs, models, and model versions
within the project.
You can also log (and fetch) metadata common to the whole project, such as information about datasets,
links to documents, or key project metrics.
Note: If you want to instead create a project, use the
[`management.create_project()`](https://docs.neptune.ai/api/management/#create_project) function.
You can also use the Project object as a context manager (see examples).
Args:
project: Name of a project in the form `workspace-name/project-name`.
If left empty, the value of the NEPTUNE_PROJECT environment variable is used.
api_token: User's API token.
If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
mode: Connection mode in which the tracking will work.
If left empty, the value of the NEPTUNE_MODE environment variable is used.
If no value was set for the environment variable, "async" is used by default.
Possible values are `async`, `sync`, `read-only`, and `debug`.
flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered.
Defaults to 5 (every 5 seconds).
proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
For more information about proxies, see the Requests documentation.
async_lag_callback: Custom callback which is called if the lag between a queued operation and its
synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
should take a Project object as the argument and can contain any custom code, such as calling `stop()`
on the object.
Note: Instead of using this argument, you can use Neptune's default callback by setting the
`NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
If a lag callback (default callback enabled via environment variable or custom callback passed to the
`async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
whatsoever for the duration defined by `async_no_progress_threshold`. The callback
should take a Project object as the argument and can contain any custom code, such as calling `stop()`
on the object.
Note: Instead of using this argument, you can use Neptune's default callback by setting the
`NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
object was initialized. If a no-progress callback (default callback enabled via environment variable or
custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
when this duration is exceeded.
Returns:
Project object that can be used to interact with the project as a whole,
like logging or fetching project-level metadata.
Examples:
>>> import neptune
>>> # Connect to the project "classification" in the workspace "ml-team":
... project = neptune.init_project(project="ml-team/classification")
>>> # Or initialize with the constructor
... project = Project(project="ml-team/classification")
>>> # Connect to a project in read-only mode:
... project = neptune.init_project(
... project="ml-team/classification",
... mode="read-only",
... )
Using the Project object as context manager:
>>> with Project(project="ml-team/classification") as project:
... project["metadata"] = some_metadata
For more, see the docs:
Initializing a project:
https://docs.neptune.ai/api/neptune#init_project
Project class reference:
https://docs.neptune.ai/api/project/
"""
container_type = ContainerType.PROJECT
def __init__(
self,
project: Optional[str] = None,
*,
api_token: Optional[str] = None,
mode: Optional[Literal["async", "sync", "read-only", "debug"]] = None,
flush_period: float = DEFAULT_FLUSH_PERIOD,
proxies: Optional[dict] = None,
async_lag_callback: Optional[NeptuneObjectCallback] = None,
async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
):
if mode in {Mode.ASYNC.value, Mode.SYNC.value}:
raise NeptuneUnsupportedFunctionalityException
verify_type("mode", mode, (str, type(None)))
# make mode proper Enum instead of string
mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)
if mode == Mode.OFFLINE:
raise NeptuneException("Project can't be initialized in OFFLINE mode")
super().__init__(
project=project,
api_token=api_token,
mode=mode,
flush_period=flush_period,
proxies=proxies,
async_lag_callback=async_lag_callback,
async_lag_threshold=async_lag_threshold,
async_no_progress_callback=async_no_progress_callback,
async_no_progress_threshold=async_no_progress_threshold,
)
def _get_or_create_api_object(self) -> ApiExperiment:
return ApiExperiment(
id=self._project_api_object.id,
type=ContainerType.PROJECT,
sys_id=self._project_api_object.sys_id,
workspace=self._project_api_object.workspace,
project_name=self._project_api_object.name,
)
def _raise_if_stopped(self):
if self._state == ContainerState.STOPPED:
raise InactiveProjectException(label=f"{self._workspace}/{self._project_name}")
def get_url(self) -> str:
"""Returns the URL that can be accessed within the browser"""
return self._backend.get_project_url(
project_id=self._id,
workspace=self._workspace,
project_name=self._project_name,
)
def fetch_runs_table(
self,
*,
query: Optional[str] = None,
id: Optional[Union[str, Iterable[str]]] = None,
state: Optional[Union[Literal["inactive", "active"], Iterable[Literal["inactive", "active"]]]] = None,
owner: Optional[Union[str, Iterable[str]]] = None,
tag: Optional[Union[str, Iterable[str]]] = None,
columns: Optional[Iterable[str]] = None,
trashed: Optional[bool] = False,
limit: Optional[int] = None,
sort_by: str = "sys/creation_time",
ascending: bool = False,
progress_bar: Optional[ProgressBarType] = None,
) -> Table:
"""Retrieve runs matching the specified criteria.
All parameters are optional. Each of them specifies a single criterion.
Only runs matching all of the criteria will be returned.
Args:
query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
Example: `"(accuracy: float > 0.88) AND (loss: float < 0.2)"`.
Exclusive with the `id`, `state`, `owner`, and `tag` parameters.
id: Neptune ID of a run, or list of several IDs.
Example: `"SAN-1"` or `["SAN-1", "SAN-2"]`.
Matching any element of the list is sufficient to pass the criterion.
state: Run state, or list of states.
Example: `"active"`.
Possible values: `"inactive"`, `"active"`.
Matching any element of the list is sufficient to pass the criterion.
owner: Username of the run owner, or a list of owners.
Example: `"josh"` or `["frederic", "josh"]`.
The owner is the user who created the run.
Matching any element of the list is sufficient to pass the criterion.
tag: A tag or list of tags applied to the run.
Example: `"lightGBM"` or `["pytorch", "cycleLR"]`.
Only runs that have all specified tags will match this criterion.
columns: Names of columns to include in the table, as a list of field names.
The Neptune ID ("sys/id") is included automatically.
If `None` (default), all the columns of the runs table are included, up to a maximum of 10 000 columns.
trashed: Whether to retrieve trashed runs.
If `True`, only trashed runs are retrieved.
If `False` (default), only not-trashed runs are retrieved.
If `None`, both trashed and not-trashed runs are retrieved.
limit: How many entries to return at most. If `None`, all entries are returned.
sort_by: Name of the field to sort the results by.
The field must represent a simple type (string, float, datetime, integer, or Boolean).
ascending: Whether to sort the entries in ascending order of the sorting column values.
progress_bar: Set to `False` to disable the download progress bar,
or pass a `ProgressBarCallback` class to use your own progress bar callback.
Returns:
`Table` object containing `Run` objects matching the specified criteria.
Use `to_pandas()` to convert the table to a pandas DataFrame.
Examples:
>>> import neptune
... # Fetch project "jackie/sandbox"
... project = neptune.init_project(mode="read-only", project="jackie/sandbox")
>>> # Fetch the metadata of all runs as a pandas DataFrame
... runs_table_df = project.fetch_runs_table().to_pandas()
... # Extract the ID of the last run
... last_run_id = runs_table_df["sys/id"].values[0]
>>> # Fetch the 100 oldest runs
... runs_table_df = project.fetch_runs_table(
... sort_by="sys/creation_time", ascending=True, limit=100
... ).to_pandas()
>>> # Fetch the 100 largest runs (space they take up in Neptune)
... runs_table_df = project.fetch_runs_table(sort_by="sys/size", limit=100).to_pandas()
>>> # Include only the fields "train/loss" and "params/lr" as columns:
... runs_table_df = project.fetch_runs_table(columns=["params/lr", "train/loss"]).to_pandas()
>>> # Pass a custom progress bar callback
... runs_table_df = project.fetch_runs_table(progress_bar=MyProgressBar).to_pandas()
... # The class MyProgressBar(ProgressBarCallback) must be defined
You can also filter the runs table by state, owner, tag, or a combination of these:
>>> # Fetch only inactive runs
... runs_table_df = project.fetch_runs_table(state="inactive").to_pandas()
>>> # Fetch only runs created by CI service
... runs_table_df = project.fetch_runs_table(owner="my_company_ci_service").to_pandas()
>>> # Fetch only runs that have both "Exploration" and "Optuna" tags
... runs_table_df = project.fetch_runs_table(tag=["Exploration", "Optuna"]).to_pandas()
>>> # You can combine conditions. Runs satisfying all conditions will be fetched
... runs_table_df = project.fetch_runs_table(state="inactive", tag="Exploration").to_pandas()
See also the API reference in the docs:
https://docs.neptune.ai/api/project#fetch_runs_table
"""
if any((id, state, owner, tag)) and query is not None:
raise ValueError(
"You can't use the 'query' parameter together with the 'id', 'state', 'owner', or 'tag' parameters."
)
ids = as_list("id", id)
states = as_list("state", state)
owners = as_list("owner", owner)
tags = as_list("tag", tag)
verify_type("query", query, (str, type(None)))
verify_type("trashed", trashed, (bool, type(None)))
verify_type("limit", limit, (int, type(None)))
verify_type("sort_by", sort_by, str)
verify_type("ascending", ascending, bool)
verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))
verify_collection_type("state", states, str)
if isinstance(limit, int) and limit <= 0:
raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")
for state in states:
verify_value("state", state.lower(), ("inactive", "active"))
if query is not None:
nql_query = build_raw_query(query, trashed=trashed)
else:
nql_query = prepare_nql_query(ids, states, owners, tags, trashed)
return NeptuneObject._fetch_entries(
self,
child_type=ContainerType.RUN,
query=nql_query,
columns=columns,
limit=limit,
sort_by=sort_by,
ascending=ascending,
progress_bar=progress_bar,
)
def fetch_models_table(
    self,
    *,
    query: Optional[str] = None,
    columns: Optional[Iterable[str]] = None,
    trashed: Optional[bool] = False,
    limit: Optional[int] = None,
    sort_by: str = "sys/creation_time",
    ascending: bool = False,
    progress_bar: Optional[ProgressBarType] = None,
) -> Table:
    """Retrieve models stored in the project.

    Args:
        query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
            Example: `"(model_size: float > 100) AND (backbone: string = VGG)"`.
        trashed: Whether to retrieve trashed models.
            If `True`, only trashed models are retrieved.
            If `False`, only not-trashed models are retrieved.
            If `None`, both trashed and not-trashed models are retrieved.
        columns: Names of columns to include in the table, as a list of field names.
            The Neptune ID ("sys/id") is included automatically.
            If `None`, all the columns of the models table are included, up to a maximum of 10 000 columns.
        limit: How many entries to return at most. If `None`, all entries are returned.
        sort_by: Name of the field to sort the results by.
            The field must represent a simple type (string, float, datetime, integer, or Boolean).
        ascending: Whether to sort the entries in ascending order of the sorting column values.
        progress_bar: Set to `False` to disable the download progress bar,
            or pass a `ProgressBarCallback` class to use your own progress bar callback.

    Returns:
        `Table` object containing `Model` objects.
        Use `to_pandas()` to convert the table to a pandas DataFrame.

    Raises:
        TypeError: If any parameter has an unexpected type.
        ValueError: If `limit` is a non-positive integer.

    Examples:
        >>> import neptune
        ... # Fetch project "jackie/sandbox"
        ... project = neptune.init_project(mode="read-only", project="jackie/sandbox")
        >>> # Fetch the metadata of all models as a pandas DataFrame
        ... models_table_df = project.fetch_models_table().to_pandas()
        >>> # Include only the fields "params/lr" and "info/size" as columns:
        ... models_table_df = project.fetch_models_table(columns=["params/lr", "info/size"]).to_pandas()
        >>> # Fetch 10 oldest model objects
        ... models_table_df = project.fetch_models_table(
        ...     sort_by="sys/creation_time", ascending=True, limit=10
        ... ).to_pandas()
        ... # Extract the ID of the first listed (oldest) model object
        ... last_model_id = models_table_df["sys/id"].values[0]
        >>> # Fetch models with VGG backbone
        ... models_table_df = project.fetch_models_table(
        ...     query="(backbone: string = VGG)"
        ... ).to_pandas()

    See also the API reference in the docs:
        https://docs.neptune.ai/api/project#fetch_models_table
    """
    verify_type("query", query, (str, type(None)))
    # Fix: validate 'trashed' up front, consistent with fetch_runs_table; previously a
    # wrong-typed value flowed silently into build_raw_query.
    verify_type("trashed", trashed, (bool, type(None)))
    verify_type("limit", limit, (int, type(None)))
    verify_type("sort_by", sort_by, str)
    verify_type("ascending", ascending, bool)
    verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback)))

    if isinstance(limit, int) and limit <= 0:
        raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

    # An empty raw query matches everything; build_raw_query then narrows by 'trashed'.
    nql = build_raw_query(query=query if query is not None else "", trashed=trashed)
    return NeptuneObject._fetch_entries(
        self,
        child_type=ContainerType.MODEL,
        query=nql,
        columns=columns,
        limit=limit,
        sort_by=sort_by,
        ascending=ascending,
        progress_bar=progress_bar,
    )
#
# Copyright (c) 2022, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["Run"]
import os
import threading
from platform import node as get_hostname
from typing import (
TYPE_CHECKING,
Callable,
List,
Optional,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Literal
from neptune.attributes.constants import (
SYSTEM_DESCRIPTION_ATTRIBUTE_PATH,
SYSTEM_FAILED_ATTRIBUTE_PATH,
SYSTEM_HOSTNAME_ATTRIBUTE_PATH,
SYSTEM_NAME_ATTRIBUTE_PATH,
SYSTEM_TAGS_ATTRIBUTE_PATH,
)
from neptune.envs import (
CONNECTION_MODE,
CUSTOM_RUN_ID_ENV_NAME,
MONITORING_NAMESPACE,
NEPTUNE_NOTEBOOK_ID,
NEPTUNE_NOTEBOOK_PATH,
)
from neptune.exceptions import (
InactiveRunException,
NeedExistingRunForReadOnlyMode,
NeptuneRunResumeAndCustomIdCollision,
)
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.container_type import ContainerType
from neptune.internal.hardware.hardware_metric_reporting_job import HardwareMetricReportingJob
from neptune.internal.id_formats import QualifiedName
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
DEFAULT_NAME,
OFFLINE_PROJECT_QUALIFIED_NAME,
)
from neptune.internal.notebooks.notebooks import create_checkpoint
from neptune.internal.state import ContainerState
from neptune.internal.streams.std_capture_background_job import (
StderrCaptureBackgroundJob,
StdoutCaptureBackgroundJob,
)
from neptune.internal.utils import (
verify_collection_type,
verify_type,
)
from neptune.internal.utils.dependency_tracking import (
FileDependenciesStrategy,
InferDependenciesStrategy,
)
from neptune.internal.utils.git import (
to_git_info,
track_uncommitted_changes,
)
from neptune.internal.utils.hashing import generate_hash
from neptune.internal.utils.limits import custom_run_id_exceeds_length
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.internal.utils.runningmode import (
in_interactive,
in_notebook,
)
from neptune.internal.utils.source_code import upload_source_code
from neptune.internal.utils.traceback_job import TracebackJob
from neptune.internal.warnings import (
NeptuneWarning,
warn_once,
)
from neptune.internal.websockets.websocket_signals_background_job import WebsocketSignalsBackgroundJob
from neptune.objects.neptune_object import (
NeptuneObject,
NeptuneObjectCallback,
)
from neptune.types import (
GitRef,
StringSeries,
)
from neptune.types.atoms.git_ref import GitRefDisabled
from neptune.types.mode import Mode
if TYPE_CHECKING:
from neptune.internal.background_job import BackgroundJob
T = TypeVar("T")


def temporarily_disabled(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator that replaces selected methods with no-op stubs.

    `_get_background_jobs` is stubbed to return an empty list; the two
    `_write_initial_*` methods are stubbed to return None. Any other decorated
    callable is also suppressed (the wrapper returns None without calling it).
    """
    # Map of known method names to factories producing their stub return values.
    stub_results = {
        "_get_background_jobs": list,
        "_write_initial_attributes": lambda: None,
        "_write_initial_monitoring_attributes": lambda: None,
    }

    def wrapper(*_, **__):
        factory = stub_results.get(func.__name__)
        return factory() if factory is not None else None

    return wrapper
class Run(NeptuneObject):
    """Starts a new tracked run that logs ML model-building metadata to neptune.ai.

    You can log metadata by assigning it to the initialized Run object:

    ```
    run = neptune.init_run()
    run["your/structure"] = some_metadata
    ```

    Examples of metadata you can log: metrics, losses, scores, artifact versions, images, predictions,
    model weights, parameters, checkpoints, and interactive visualizations.

    By default, the run automatically tracks hardware consumption, stdout/stderr, source code, and Git information.
    If you're using Neptune in an interactive session, however, some background monitoring needs to be enabled
    explicitly.

    If you provide the ID of an existing run, that run is resumed and no new run is created. You may resume a run
    either to log more metadata or to fetch metadata from it.
    The run ends either when its `stop()` method is called or when the script finishes execution.
    You can also use the Run object as a context manager (see examples).

    Args:
        project: Name of the project where the run should go, in the form `workspace-name/project_name`.
            If left empty, the value of the NEPTUNE_PROJECT environment variable is used.
        api_token: User's API token.
            If left empty, the value of the NEPTUNE_API_TOKEN environment variable is used (recommended).
        with_id: If you want to resume a run, pass the identifier of an existing run. For example, "SAN-1".
            If left empty, a new run is created.
        custom_run_id: A unique identifier to be used when running Neptune in distributed training jobs.
            Make sure to use the same identifier throughout the whole pipeline execution.
        mode: Connection mode in which the tracking will work.
            If left empty, the value of the NEPTUNE_MODE environment variable is used.
            If no value was set for the environment variable, "async" is used by default.
            Possible values are `async`, `sync`, `offline`, `read-only`, and `debug`.
        name: Custom name for the run. You can add it as a column in the runs table ("sys/name").
            You can also edit the name in the app: Open the run menu and access the run information.
        description: Custom description of the run. You can add it as a column in the runs table
            ("sys/description").
            You can also edit the description in the app: Open the run menu and access the run information.
        tags: Tags of the run as a list of strings.
            You can edit the tags through the "sys/tags" field or in the app (run menu -> information).
            You can also select multiple runs and manage their tags as a single action.
        source_files: List of source files to be uploaded.
            Uploaded source files are displayed in the "Source code" dashboard.
            To not upload anything, pass an empty list (`[]`).
            Unix style pathname pattern expansion is supported. For example, you can pass `*.py` to upload
            all Python files from the current directory.
            If None is passed, the Python file from which the run was created will be uploaded.
        capture_stdout: Whether to log the stdout of the run.
            Defaults to `False` in interactive sessions and `True` otherwise.
            The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
        capture_stderr: Whether to log the stderr of the run.
            Defaults to `False` in interactive sessions and `True` otherwise.
            The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
        capture_hardware_metrics: Whether to send hardware monitoring logs (CPU, GPU, and memory utilization).
            Defaults to `False` in interactive sessions and `True` otherwise.
            The data is logged under the monitoring namespace (see the `monitoring_namespace` parameter).
        fail_on_exception: Whether to register an uncaught exception handler to this process and,
            in case of an exception, set the "sys/failed" field of the run to `True`.
            An exception is always logged.
        monitoring_namespace: Namespace inside which all hardware monitoring logs are stored.
            Defaults to "monitoring/<hash>", where the hash is generated based on environment information,
            to ensure that it's unique for each process.
        flush_period: In the asynchronous (default) connection mode, how often disk flushing is triggered
            (in seconds).
        proxies: Argument passed to HTTP calls made via the Requests library, as dictionary of strings.
            For more information about proxies, see the Requests documentation.
        capture_traceback: Whether to log the traceback of the run in case of an exception.
            The tracked metadata is stored in the "<monitoring_namespace>/traceback" namespace (see the
            `monitoring_namespace` parameter).
        git_ref: GitRef object containing information about the Git repository path.
            If None, Neptune looks for a repository in the path of the script that is executed.
            To specify a different location, set to GitRef(repository_path="path/to/repo").
            To turn off Git tracking for the run, set to False or GitRef.DISABLED.
        dependencies: If you pass `"infer"`, Neptune logs dependencies installed in the current environment.
            You can also pass a path to your dependency file directly.
            If left empty, no dependencies are tracked.
        async_lag_callback: Custom callback which is called if the lag between a queued operation and its
            synchronization with the server exceeds the duration defined by `async_lag_threshold`. The callback
            should take a Run object as the argument and can contain any custom code, such as calling `stop()` on
            the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK` environment variable to `TRUE`.
        async_lag_threshold: In seconds, duration between the queueing and synchronization of an operation.
            If a lag callback (default callback enabled via environment variable or custom callback passed to the
            `async_lag_callback` argument) is enabled, the callback is called when this duration is exceeded.
        async_no_progress_callback: Custom callback which is called if there has been no synchronization progress
            whatsoever for the duration defined by `async_no_progress_threshold`. The callback
            should take a Run object as the argument and can contain any custom code, such as calling `stop()` on
            the object.
            Note: Instead of using this argument, you can use Neptune's default callback by setting the
            `NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK` environment variable to `TRUE`.
        async_no_progress_threshold: In seconds, for how long there has been no synchronization progress since the
            object was initialized. If a no-progress callback (default callback enabled via environment variable or
            custom callback passed to the `async_no_progress_callback` argument) is enabled, the callback is called
            when this duration is exceeded.

    Returns:
        Run object that is used to manage the tracked run and log metadata to it.

    Examples:
        Creating a new run:

        >>> import neptune
        >>> # Minimal invoke
        ... # (creates a run in the project specified by the NEPTUNE_PROJECT environment variable)
        ... run = neptune.init_run()
        >>> # Or initialize with the constructor
        ... run = Run(project="ml-team/classification")
        >>> # Create a run with a name and description, with no sources files or Git info tracked:
        >>> run = neptune.init_run(
        ...     name="neural-net-mnist",
        ...     description="neural net trained on MNIST",
        ...     source_files=[],
        ...     git_ref=False,
        ... )
        >>> # Log all .py files from all subdirectories, excluding hidden files
        ... run = neptune.init_run(source_files="**/*.py")
        >>> # Log all files and directories in the current working directory, excluding hidden files
        ... run = neptune.init_run(source_files="*")
        >>> # Larger example
        ... run = neptune.init_run(
        ...     project="ml-team/classification",
        ...     name="first-pytorch-ever",
        ...     description="Longer description of the run goes here",
        ...     tags=["tags", "go-here", "as-list-of-strings"],
        ...     source_files=["training_with_pytorch.py", "net.py"],
        ...     dependencies="infer",
        ...     capture_stderr=False,
        ...     git_ref=GitRef(repository_path="/Users/Jackie/repos/cls_project"),
        ... )

        Connecting to an existing run:

        >>> # Resume logging to an existing run with the ID "SAN-3"
        ... run = neptune.init_run(with_id="SAN-3")
        ... run["parameters/lr"] = 0.1  # modify or add metadata
        >>> # Initialize an existing run in read-only mode (logging new data is not possible, only fetching)
        ... run = neptune.init_run(with_id="SAN-4", mode="read-only")
        ... learning_rate = run["parameters/lr"].fetch()

        Using the Run object as context manager:

        >>> with Run() as run:
        ...     run["metric"].append(value)

    For more, see the docs:
        Initializing a run:
            https://docs.neptune.ai/api/neptune#init_run
        Run class reference:
            https://docs.neptune.ai/api/run/
        Essential logging methods:
            https://docs.neptune.ai/logging/methods/
        Resuming a run:
            https://docs.neptune.ai/logging/to_existing_object/
        Setting a custom run ID:
            https://docs.neptune.ai/logging/custom_run_id/
        Logging to multiple runs at once:
            https://docs.neptune.ai/logging/to_multiple_objects/
        Accessing the run from multiple places:
            https://docs.neptune.ai/logging/from_multiple_places/
    """

    # Identifies this object as a run; also used when resuming (expected_container_type).
    container_type = ContainerType.RUN

    def __init__(
        self,
        with_id: Optional[str] = None,
        *,
        project: Optional[str] = None,
        api_token: Optional[str] = None,
        custom_run_id: Optional[str] = None,
        mode: Optional[Literal["async", "sync", "offline", "read-only", "debug"]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[Union[List[str], str]] = None,
        source_files: Optional[Union[List[str], str]] = None,
        capture_stdout: Optional[bool] = None,
        capture_stderr: Optional[bool] = None,
        capture_hardware_metrics: Optional[bool] = None,
        fail_on_exception: bool = True,
        monitoring_namespace: Optional[str] = None,
        flush_period: float = DEFAULT_FLUSH_PERIOD,
        proxies: Optional[dict] = None,
        capture_traceback: bool = True,
        git_ref: Optional[Union[GitRef, GitRefDisabled, bool]] = None,
        dependencies: Optional[Union[str, os.PathLike]] = None,
        async_lag_callback: Optional[NeptuneObjectCallback] = None,
        async_lag_threshold: float = ASYNC_LAG_THRESHOLD,
        async_no_progress_callback: Optional[NeptuneObjectCallback] = None,
        async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD,
        **kwargs,
    ):
        # Reject unknown keyword arguments with a TypeError before doing anything else.
        check_for_extra_kwargs("Run", kwargs)
        verify_type("with_id", with_id, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("custom_run_id", custom_run_id, (str, type(None)))
        verify_type("mode", mode, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("description", description, (str, type(None)))
        verify_type("capture_stdout", capture_stdout, (bool, type(None)))
        verify_type("capture_stderr", capture_stderr, (bool, type(None)))
        verify_type("capture_hardware_metrics", capture_hardware_metrics, (bool, type(None)))
        verify_type("fail_on_exception", fail_on_exception, bool)
        verify_type("monitoring_namespace", monitoring_namespace, (str, type(None)))
        verify_type("capture_traceback", capture_traceback, bool)
        verify_type("git_ref", git_ref, (GitRef, str, bool, type(None)))
        verify_type("dependencies", dependencies, (str, os.PathLike, type(None)))
        # Normalize single-string 'tags'/'source_files' into one-element lists.
        if tags is not None:
            if isinstance(tags, str):
                tags = [tags]
            else:
                verify_collection_type("tags", tags, str)
        if source_files is not None:
            if isinstance(source_files, str):
                source_files = [source_files]
            else:
                verify_collection_type("source_files", source_files, str)

        self._with_id: Optional[str] = with_id
        # Fall back to defaults only when creating a NEW run (with_id is None); when
        # resuming, leave name/description as None so the existing values are not overwritten.
        self._name: Optional[str] = DEFAULT_NAME if with_id is None and name is None else name
        self._description: Optional[str] = "" if with_id is None and description is None else description
        # Explicit argument takes precedence over the NEPTUNE_CUSTOM_RUN_ID env variable.
        self._custom_run_id: Optional[str] = custom_run_id or os.getenv(CUSTOM_RUN_ID_ENV_NAME)
        self._hostname: str = get_hostname()
        self._pid: int = os.getpid()
        self._tid: int = threading.get_ident()
        self._tags: Optional[List[str]] = tags
        self._source_files: Optional[List[str]] = source_files
        self._fail_on_exception: bool = fail_on_exception
        self._capture_traceback: bool = capture_traceback
        # Boolean shorthand: True -> track the repo at the default location; False -> disabled.
        if type(git_ref) is bool:
            git_ref = GitRef() if git_ref else GitRef.DISABLED
        # NOTE: annotation corrected from the invalid `Optional[GitRef, GitRefDisabled]`
        # (Optional accepts a single parameter).
        self._git_ref: Union[GitRef, GitRefDisabled] = git_ref or GitRef()
        # NOTE: annotation corrected from the invalid `Optional[str, os.PathLike]`.
        self._dependencies: Optional[Union[str, os.PathLike]] = dependencies
        # Precedence: explicit argument > env variable > generated per-process namespace.
        self._monitoring_namespace: str = (
            monitoring_namespace
            or os.getenv(MONITORING_NAMESPACE)
            or generate_monitoring_namespace(self._hostname, self._pid, self._tid)
        )
        # Resolve the connection mode: explicit argument > NEPTUNE_MODE env variable > async.
        mode = Mode(mode or os.getenv(CONNECTION_MODE) or Mode.ASYNC.value)
        self._stdout_path: str = "{}/stdout".format(self._monitoring_namespace)
        # The three capture_* options default to True only outside interactive sessions.
        self._capture_stdout: bool = capture_stdout
        if capture_stdout is None:
            self._capture_stdout = capture_only_if_non_interactive(mode=mode)
        self._stderr_path: str = "{}/stderr".format(self._monitoring_namespace)
        self._capture_stderr: bool = capture_stderr
        if capture_stderr is None:
            self._capture_stderr = capture_only_if_non_interactive(mode=mode)
        self._capture_hardware_metrics: bool = capture_hardware_metrics
        if capture_hardware_metrics is None:
            self._capture_hardware_metrics = capture_only_if_non_interactive(mode=mode)
        # Resuming by ID and supplying a custom run ID are mutually exclusive.
        if with_id and custom_run_id:
            raise NeptuneRunResumeAndCustomIdCollision()
        # Offline/debug modes never talk to a real project; substitute a placeholder name.
        if mode == Mode.OFFLINE or mode == Mode.DEBUG:
            project = OFFLINE_PROJECT_QUALIFIED_NAME
        super().__init__(
            project=project,
            api_token=api_token,
            mode=mode,
            flush_period=flush_period,
            proxies=proxies,
            async_lag_callback=async_lag_callback,
            async_lag_threshold=async_lag_threshold,
            async_no_progress_callback=async_no_progress_callback,
            async_no_progress_threshold=async_no_progress_threshold,
        )

    def _get_or_create_api_object(self) -> ApiExperiment:
        """Resume the run identified by `with_id`, or create a new run on the backend."""
        project_workspace = self._project_api_object.workspace
        project_name = self._project_api_object.name
        project_qualified_name = f"{project_workspace}/{project_name}"
        if self._with_id:
            # Resume mode: fetch the existing container instead of creating one.
            return self._backend.get_metadata_container(
                container_id=QualifiedName(project_qualified_name + "/" + self._with_id),
                expected_container_type=Run.container_type,
            )
        else:
            if self._mode == Mode.READ_ONLY:
                raise NeedExistingRunForReadOnlyMode()
            git_info = to_git_info(git_ref=self._git_ref)
            custom_run_id = self._custom_run_id
            # Over-long custom IDs are dropped rather than rejected.
            if custom_run_id_exceeds_length(self._custom_run_id):
                custom_run_id = None
            notebook_id, checkpoint_id = create_notebook_checkpoint(backend=self._backend)
            return self._backend.create_run(
                project_id=self._project_api_object.id,
                git_info=git_info,
                custom_run_id=custom_run_id,
                notebook_id=notebook_id,
                checkpoint_id=checkpoint_id,
            )

    # NOTE: stubbed by @temporarily_disabled — currently always returns [].
    @temporarily_disabled
    def _get_background_jobs(self) -> List["BackgroundJob"]:
        """Assemble the background jobs (ping, websockets, std capture, hardware, traceback)
        enabled for this run."""
        background_jobs = [PingBackgroundJob()]
        websockets_factory = self._backend.websockets_factory(self._project_api_object.id, self._id)
        if websockets_factory:
            background_jobs.append(WebsocketSignalsBackgroundJob(websockets_factory))
        if self._capture_stdout:
            background_jobs.append(StdoutCaptureBackgroundJob(attribute_name=self._stdout_path))
        if self._capture_stderr:
            background_jobs.append(StderrCaptureBackgroundJob(attribute_name=self._stderr_path))
        if self._capture_hardware_metrics:
            background_jobs.append(HardwareMetricReportingJob(attribute_namespace=self._monitoring_namespace))
        if self._capture_traceback:
            background_jobs.append(
                TracebackJob(path=f"{self._monitoring_namespace}/traceback", fail_on_exception=self._fail_on_exception)
            )
        return background_jobs

    # NOTE: stubbed by @temporarily_disabled — currently a no-op.
    @temporarily_disabled
    def _write_initial_monitoring_attributes(self) -> None:
        """Record hostname/pid/tid under the monitoring namespace (and sys/hostname for new runs)."""
        if self._hostname is not None:
            self[f"{self._monitoring_namespace}/hostname"] = self._hostname
            if self._with_id is None:
                self[SYSTEM_HOSTNAME_ATTRIBUTE_PATH] = self._hostname
        if self._pid is not None:
            self[f"{self._monitoring_namespace}/pid"] = str(self._pid)
        if self._tid is not None:
            self[f"{self._monitoring_namespace}/tid"] = str(self._tid)

    # NOTE: stubbed by @temporarily_disabled — currently a no-op.
    @temporarily_disabled
    def _write_initial_attributes(self):
        """Populate system attributes, source code, dependencies, and uncommitted changes."""
        if self._name is not None:
            self[SYSTEM_NAME_ATTRIBUTE_PATH] = self._name
        if self._description is not None:
            self[SYSTEM_DESCRIPTION_ATTRIBUTE_PATH] = self._description
        if any((self._capture_stderr, self._capture_stdout, self._capture_traceback, self._capture_hardware_metrics)):
            self._write_initial_monitoring_attributes()
        if self._tags is not None:
            self[SYSTEM_TAGS_ATTRIBUTE_PATH].add(self._tags)
        if self._with_id is None:
            self[SYSTEM_FAILED_ATTRIBUTE_PATH] = False
        if self._capture_stdout and not self.exists(self._stdout_path):
            self.define(self._stdout_path, StringSeries([]))
        if self._capture_stderr and not self.exists(self._stderr_path):
            self.define(self._stderr_path, StringSeries([]))
        if self._with_id is None or self._source_files is not None:
            # upload default sources ONLY if creating a new run
            upload_source_code(source_files=self._source_files, run=self)
        if self._dependencies:
            # Best-effort dependency tracking: failures are reported as warnings, not raised.
            try:
                if self._dependencies == "infer":
                    dependency_strategy = InferDependenciesStrategy()
                else:
                    dependency_strategy = FileDependenciesStrategy(path=self._dependencies)
                dependency_strategy.log_dependencies(run=self)
            except Exception as e:
                warn_once(
                    "An exception occurred in automatic dependency tracking."
                    "Skipping upload of requirement files."
                    "Exception: " + str(e),
                    exception=NeptuneWarning,
                )
        # Best-effort uncommitted-changes tracking: failures are reported as warnings, not raised.
        try:
            track_uncommitted_changes(
                git_ref=self._git_ref,
                run=self,
            )
        except Exception as e:
            warn_once(
                "An exception occurred in tracking uncommitted changes."
                "Skipping upload of patch files."
                "Exception: " + str(e),
                exception=NeptuneWarning,
            )

    @property
    def monitoring_namespace(self) -> str:
        """Namespace under which monitoring data (stdout, stderr, hardware, traceback) is stored."""
        return self._monitoring_namespace

    def _raise_if_stopped(self):
        # Guard used by logging operations: a stopped run must not accept further writes.
        if self._state == ContainerState.STOPPED:
            raise InactiveRunException(label=self._sys_id)

    def get_url(self) -> str:
        """Returns the URL that can be accessed within the browser"""
        return self._backend.get_run_url(
            run_id=self._id,
            workspace=self._workspace,
            project_name=self._project_name,
            sys_id=self._sys_id,
        )
def capture_only_if_non_interactive(mode) -> bool:
    """Return the default value for the capture_* options.

    Outside interactive sessions the default is True. In an interactive session
    (REPL or notebook) the default is False, and — for modes that actually send
    data (offline/sync/async) — a one-time warning explains how to re-enable
    the monitoring options.
    """
    if not (in_interactive() or in_notebook()):
        return True
    if mode in (Mode.OFFLINE, Mode.SYNC, Mode.ASYNC):
        warn_once(
            "The following monitoring options are disabled by default in interactive sessions:"
            " 'capture_stdout', 'capture_stderr', 'capture_traceback', and 'capture_hardware_metrics'."
            " To enable them, set each parameter to 'True' when initializing the run. The monitoring will"
            " continue until you call run.stop() or the kernel stops."
            " Also note: Your source files can only be tracked if you pass the path(s) to the 'source_code'"
            " argument. For help, see the Neptune docs: https://docs.neptune.ai/logging/source_code/",
            exception=NeptuneWarning,
        )
    return False
def generate_monitoring_namespace(*descriptors) -> str:
    """Build a per-process monitoring namespace from the given descriptors
    (e.g. hostname, pid, tid), using an 8-character hash suffix."""
    suffix = generate_hash(*descriptors, length=8)
    return "monitoring/" + suffix
def check_for_extra_kwargs(caller_name: str, kwargs: dict):
    """Raise a TypeError for the first unexpected keyword argument, if any.

    The message mimics Python's own "unexpected keyword argument" error so
    callers of `caller_name` get a familiar diagnostic.
    """
    if not kwargs:
        return
    unexpected = next(iter(kwargs))
    raise TypeError(f"{caller_name}() got an unexpected keyword argument '{unexpected}'")
def create_notebook_checkpoint(backend: NeptuneBackend) -> Tuple[Optional[str], Optional[str]]:
    """Create a notebook checkpoint if the notebook env variables are set.

    Returns a (notebook_id, checkpoint_id) pair; checkpoint_id is None unless
    both the notebook ID and notebook path environment variables are present.
    """
    notebook_id = os.getenv(NEPTUNE_NOTEBOOK_ID, None)
    notebook_path = os.getenv(NEPTUNE_NOTEBOOK_PATH, None)
    if notebook_id is None or notebook_path is None:
        return notebook_id, None
    checkpoint_id = create_checkpoint(backend=backend, notebook_id=notebook_id, notebook_path=notebook_path)
    return notebook_id, checkpoint_id
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["StructureVersion"]
from enum import Enum
class StructureVersion(Enum):
    """Versions of the local `.neptune` directory layout used for queued operation data."""

    # -------------------------------------------------
    # .neptune/
    #     async/
    #         <uuid>/
    #             exec-<num><timestamp>/
    #                 container_type
    #                 data-1.log
    #                 ...
    # -------------------------------------------------
    LEGACY = 1
    # -------------------------------------------------
    # .neptune/
    #     async/
    #         run__<uuid>/
    #             exec-<timestamp>-<date>-<pid>/
    #                 data-1.log
    #                 ...
    # -------------------------------------------------
    CHILD_EXECUTION_DIRECTORIES = 2
    # -------------------------------------------------
    # .neptune/
    #     async/
    #         run__<uuid>__<pid>__<random_key>/
    #             data-1.log
    #             ...
    # -------------------------------------------------
    DIRECT_DIRECTORY = 3
#
# Copyright (c) 2023, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"prepare_nql_query",
]
from typing import (
Iterable,
List,
Optional,
Union,
)
from neptune.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQuery,
NQLQueryAggregate,
NQLQueryAttribute,
RawNQLQuery,
)
from neptune.internal.utils.run_state import RunState
def prepare_nql_query(
    ids: Optional[Iterable[str]],
    states: Optional[Iterable[str]],
    owners: Optional[Iterable[str]],
    tags: Optional[Iterable[str]],
    trashed: Optional[bool],
) -> NQLQueryAggregate:
    """Translate the simple filter arguments into a single conjunctive NQL query.

    Each non-empty filter contributes one condition; ids/states/owners match if
    ANY value matches (OR), while tags require ALL values to be present (AND).
    The conditions are combined with a top-level AND.
    """

    def _any_of(field: str, attr_type: NQLAttributeType, values: Iterable) -> NQLQueryAggregate:
        # OR-aggregate of equality checks against a single attribute.
        return NQLQueryAggregate(
            items=[
                NQLQueryAttribute(
                    name=field,
                    type=attr_type,
                    operator=NQLAttributeOperator.EQUALS,
                    value=value,
                )
                for value in values
            ],
            aggregator=NQLAggregator.OR,
        )

    conditions: List[Union[NQLQueryAttribute, NQLQueryAggregate]] = []

    if trashed is not None:
        conditions.append(
            NQLQueryAttribute(
                name="sys/trashed",
                type=NQLAttributeType.BOOLEAN,
                operator=NQLAttributeOperator.EQUALS,
                value=trashed,
            )
        )

    if ids:
        conditions.append(_any_of("sys/id", NQLAttributeType.STRING, ids))

    if states:
        api_states = [RunState.from_string(state).to_api() for state in states]
        conditions.append(_any_of("sys/state", NQLAttributeType.EXPERIMENT_STATE, api_states))

    if owners:
        conditions.append(_any_of("sys/owner", NQLAttributeType.STRING, owners))

    if tags:
        # Tags use CONTAINS and AND: an entry must carry every requested tag.
        conditions.append(
            NQLQueryAggregate(
                items=[
                    NQLQueryAttribute(
                        name="sys/tags",
                        type=NQLAttributeType.STRING_SET,
                        operator=NQLAttributeOperator.CONTAINS,
                        value=tag,
                    )
                    for tag in tags
                ],
                aggregator=NQLAggregator.AND,
            )
        )

    return NQLQueryAggregate(items=conditions, aggregator=NQLAggregator.AND)
def build_raw_query(query: str, trashed: Optional[bool]) -> NQLQuery:
    """Wrap a raw NQL string; when `trashed` is not None, AND it with a
    sys/trashed equality filter."""
    raw = RawNQLQuery(query)
    if trashed is None:
        return raw
    trashed_condition = NQLQueryAttribute(
        name="sys/trashed",
        type=NQLAttributeType.BOOLEAN,
        operator=NQLAttributeOperator.EQUALS,
        value=trashed,
    )
    return NQLQueryAggregate(items=[raw, trashed_condition], aggregator=NQLAggregator.AND)