tgi
Advanced tools
+4
-1
| Metadata-Version: 2.1 | ||
| Name: tgi | ||
| Version: 2.0.2 | ||
| Version: 2.4.0 | ||
| Summary: Nightly release of Hugging Face Text Generation Python Client | ||
@@ -26,2 +26,5 @@ Home-page: https://github.com/huggingface/text-generation-inference | ||
| # Legacy warning ⚠️ | ||
| The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`. | ||
| # Text Generation | ||
@@ -28,0 +31,0 @@ |
+4
-1
| [tool.poetry] | ||
| name = "tgi" | ||
| version = "2.0.2" | ||
| version = "2.4.0" | ||
| description = "Nightly release of Hugging Face Text Generation Python Client" | ||
@@ -30,1 +30,4 @@ license = "Apache-2.0" | ||
| build-backend = "poetry.core.masonry.api" | ||
| [tool.isort] | ||
| profile = "black" |
+3
-0
@@ -0,1 +1,4 @@ | ||
| # Legacy warning ⚠️ | ||
| The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`. | ||
| # Text Generation | ||
@@ -2,0 +5,0 @@ |
+18
-3
@@ -15,5 +15,20 @@ # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| __version__ = "0.6.0" | ||
| __version__ = "0.7.0" | ||
| from text_generation.client import Client, AsyncClient | ||
| from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient | ||
| DEPRECATION_WARNING = ( | ||
| "`text_generation` clients are deprecated and will be removed in the near future. " | ||
| "Please use the `InferenceClient` from the `huggingface_hub` package instead." | ||
| ) | ||
| from text_generation.client import Client, AsyncClient # noqa E402 | ||
| from text_generation.inference_api import ( # noqa E402 | ||
| InferenceAPIClient, | ||
| InferenceAPIAsyncClient, | ||
| ) | ||
| __all__ = [ | ||
| "Client", | ||
| "AsyncClient", | ||
| "InferenceAPIClient", | ||
| "InferenceAPIAsyncClient", | ||
| ] |
+199
-1
| import json | ||
| import requests | ||
| import warnings | ||
@@ -8,2 +9,3 @@ from aiohttp import ClientSession, ClientTimeout | ||
| from text_generation import DEPRECATION_WARNING | ||
| from text_generation.types import ( | ||
@@ -15,2 +17,5 @@ StreamResponse, | ||
| Grammar, | ||
| CompletionRequest, | ||
| Completion, | ||
| CompletionComplete, | ||
| ChatRequest, | ||
@@ -24,3 +29,6 @@ ChatCompletionChunk, | ||
| # emit deprecation warnings | ||
| warnings.simplefilter("always", DeprecationWarning) | ||
| class Client: | ||
@@ -65,2 +73,3 @@ """Client to make calls to a text-generation-inference instance | ||
| """ | ||
| warnings.warn(DEPRECATION_WARNING, DeprecationWarning) | ||
| self.base_url = base_url | ||
@@ -71,2 +80,90 @@ self.headers = headers | ||
| def completion( | ||
| self, | ||
| prompt: str, | ||
| frequency_penalty: Optional[float] = None, | ||
| max_tokens: Optional[int] = None, | ||
| repetition_penalty: Optional[float] = None, | ||
| seed: Optional[int] = None, | ||
| stream: bool = False, | ||
| temperature: Optional[float] = None, | ||
| top_p: Optional[float] = None, | ||
| stop: Optional[List[str]] = None, | ||
| ): | ||
| """ | ||
| Given a prompt, generate a response synchronously | ||
| Args: | ||
| prompt (`str`): | ||
| Prompt | ||
| frequency_penalty (`float`): | ||
| The parameter for frequency penalty. 0.0 means no penalty | ||
| Penalize new tokens based on their existing frequency in the text so far, | ||
| decreasing the model's likelihood to repeat the same line verbatim. | ||
| max_tokens (`int`): | ||
| Maximum number of generated tokens | ||
| repetition_penalty (`float`): | ||
| The parameter for frequency penalty. 0.0 means no penalty. See [this | ||
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | ||
| seed (`int`): | ||
| Random sampling seed | ||
| stream (`bool`): | ||
| Stream the response | ||
| temperature (`float`): | ||
| The value used to module the logits distribution. | ||
| top_p (`float`): | ||
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | ||
| higher are kept for generation | ||
| stop (`List[str]`): | ||
| Stop generating tokens if a member of `stop` is generated | ||
| """ | ||
| request = CompletionRequest( | ||
| model="tgi", | ||
| prompt=prompt, | ||
| frequency_penalty=frequency_penalty, | ||
| max_tokens=max_tokens, | ||
| repetition_penalty=repetition_penalty, | ||
| seed=seed, | ||
| stream=stream, | ||
| temperature=temperature, | ||
| top_p=top_p, | ||
| stop=stop, | ||
| ) | ||
| if not stream: | ||
| resp = requests.post( | ||
| f"{self.base_url}/v1/completions", | ||
| json=request.dict(), | ||
| headers=self.headers, | ||
| cookies=self.cookies, | ||
| timeout=self.timeout, | ||
| ) | ||
| payload = resp.json() | ||
| if resp.status_code != 200: | ||
| raise parse_error(resp.status_code, payload) | ||
| return Completion(**payload) | ||
| else: | ||
| return self._completion_stream_response(request) | ||
| def _completion_stream_response(self, request): | ||
| resp = requests.post( | ||
| f"{self.base_url}/v1/completions", | ||
| json=request.dict(), | ||
| headers=self.headers, | ||
| cookies=self.cookies, | ||
| timeout=self.timeout, | ||
| stream=True, | ||
| ) | ||
| # iterate and print stream | ||
| for byte_payload in resp.iter_lines(): | ||
| if byte_payload == b"\n": | ||
| continue | ||
| payload = byte_payload.decode("utf-8") | ||
| if payload.startswith("data:"): | ||
| json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) | ||
| try: | ||
| response = CompletionComplete(**json_payload) | ||
| yield response | ||
| except ValidationError: | ||
| raise parse_error(resp.status, json_payload) | ||
| def chat( | ||
@@ -90,2 +187,3 @@ self, | ||
| tool_choice: Optional[str] = None, | ||
| stop: Optional[List[str]] = None, | ||
| ): | ||
@@ -133,2 +231,4 @@ """ | ||
| The tool to use | ||
| stop (`List[str]`): | ||
| Stop generating tokens if a member of `stop` is generated | ||
@@ -154,2 +254,3 @@ """ | ||
| tool_choice=tool_choice, | ||
| stop=stop, | ||
| ) | ||
@@ -460,2 +561,3 @@ if not stream: | ||
| """ | ||
| warnings.warn(DEPRECATION_WARNING, DeprecationWarning) | ||
| self.base_url = base_url | ||
@@ -466,2 +568,89 @@ self.headers = headers | ||
| async def completion( | ||
| self, | ||
| prompt: str, | ||
| frequency_penalty: Optional[float] = None, | ||
| max_tokens: Optional[int] = None, | ||
| repetition_penalty: Optional[float] = None, | ||
| seed: Optional[int] = None, | ||
| stream: bool = False, | ||
| temperature: Optional[float] = None, | ||
| top_p: Optional[float] = None, | ||
| stop: Optional[List[str]] = None, | ||
| ) -> Union[Completion, AsyncIterator[CompletionComplete]]: | ||
| """ | ||
| Given a prompt, generate a response asynchronously | ||
| Args: | ||
| prompt (`str`): | ||
| Prompt | ||
| frequency_penalty (`float`): | ||
| The parameter for frequency penalty. 0.0 means no penalty | ||
| Penalize new tokens based on their existing frequency in the text so far, | ||
| decreasing the model's likelihood to repeat the same line verbatim. | ||
| max_tokens (`int`): | ||
| Maximum number of generated tokens | ||
| repetition_penalty (`float`): | ||
| The parameter for frequency penalty. 0.0 means no penalty. See [this | ||
| paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | ||
| seed (`int`): | ||
| Random sampling seed | ||
| stream (`bool`): | ||
| Stream the response | ||
| temperature (`float`): | ||
| The value used to module the logits distribution. | ||
| top_p (`float`): | ||
| If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or | ||
| higher are kept for generation | ||
| stop (`List[str]`): | ||
| Stop generating tokens if a member of `stop` is generated | ||
| """ | ||
| request = CompletionRequest( | ||
| model="tgi", | ||
| prompt=prompt, | ||
| frequency_penalty=frequency_penalty, | ||
| max_tokens=max_tokens, | ||
| repetition_penalty=repetition_penalty, | ||
| seed=seed, | ||
| stream=stream, | ||
| temperature=temperature, | ||
| top_p=top_p, | ||
| stop=stop, | ||
| ) | ||
| if not stream: | ||
| return await self._completion_single_response(request) | ||
| else: | ||
| return self._completion_stream_response(request) | ||
| async def _completion_single_response(self, request): | ||
| async with ClientSession( | ||
| headers=self.headers, cookies=self.cookies, timeout=self.timeout | ||
| ) as session: | ||
| async with session.post( | ||
| f"{self.base_url}/v1/completions", json=request.dict() | ||
| ) as resp: | ||
| payload = await resp.json() | ||
| if resp.status != 200: | ||
| raise parse_error(resp.status, payload) | ||
| return Completion(**payload) | ||
| async def _completion_stream_response(self, request): | ||
| async with ClientSession( | ||
| headers=self.headers, cookies=self.cookies, timeout=self.timeout | ||
| ) as session: | ||
| async with session.post( | ||
| f"{self.base_url}/v1/completions", json=request.dict() | ||
| ) as resp: | ||
| async for byte_payload in resp.content: | ||
| if byte_payload == b"\n": | ||
| continue | ||
| payload = byte_payload.decode("utf-8") | ||
| if payload.startswith("data:"): | ||
| json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) | ||
| try: | ||
| response = CompletionComplete(**json_payload) | ||
| yield response | ||
| except ValidationError: | ||
| raise parse_error(resp.status, json_payload) | ||
| async def chat( | ||
@@ -485,2 +674,3 @@ self, | ||
| tool_choice: Optional[str] = None, | ||
| stop: Optional[List[str]] = None, | ||
| ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: | ||
@@ -528,2 +718,4 @@ """ | ||
| The tool to use | ||
| stop (`List[str]`): | ||
| Stop generating tokens if a member of `stop` is generated | ||
@@ -549,2 +741,3 @@ """ | ||
| tool_choice=tool_choice, | ||
| stop=stop, | ||
| ) | ||
@@ -580,3 +773,8 @@ if not stream: | ||
| if payload.startswith("data:"): | ||
| json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) | ||
| payload_data = ( | ||
| payload.lstrip("data:").rstrip("\n").removeprefix(" ") | ||
| ) | ||
| if payload_data == "[DONE]": | ||
| break | ||
| json_payload = json.loads(payload_data) | ||
| try: | ||
@@ -583,0 +781,0 @@ response = ChatCompletionChunk(**json_payload) |
@@ -24,3 +24,3 @@ import os | ||
| resp = requests.get( | ||
| f"https://api-inference.huggingface.co/framework/text-generation-inference", | ||
| "https://api-inference.huggingface.co/framework/text-generation-inference", | ||
| headers=headers, | ||
@@ -27,0 +27,0 @@ timeout=5, |
+81
-42
| from enum import Enum | ||
| from pydantic import BaseModel, field_validator | ||
| from pydantic import BaseModel, field_validator, ConfigDict | ||
| from typing import Optional, List, Union, Any | ||
@@ -31,2 +31,8 @@ | ||
| class Chunk(BaseModel): | ||
| type: str | ||
| text: Optional[str] = None | ||
| image_url: Any = None | ||
| class Message(BaseModel): | ||
@@ -36,3 +42,3 @@ # Role of the message sender | ||
| # Content of the message | ||
| content: Optional[str] = None | ||
| content: Optional[Union[str, List[Chunk]]] = None | ||
| # Optional name of the message sender | ||
@@ -51,26 +57,2 @@ name: Optional[str] = None | ||
| class ChatCompletionComplete(BaseModel): | ||
| # Index of the chat completion | ||
| index: int | ||
| # Message associated with the chat completion | ||
| message: Message | ||
| # Log probabilities for the chat completion | ||
| logprobs: Optional[Any] | ||
| # Reason for completion | ||
| finish_reason: str | ||
| # Usage details of the chat completion | ||
| usage: Optional[Any] = None | ||
| class CompletionComplete(BaseModel): | ||
| # Index of the chat completion | ||
| index: int | ||
| # Message associated with the chat completion | ||
| text: str | ||
| # Log probabilities for the chat completion | ||
| logprobs: Optional[Any] | ||
| # Reason for completion | ||
| finish_reason: str | ||
| class Function(BaseModel): | ||
@@ -91,3 +73,3 @@ name: Optional[str] | ||
| content: Optional[str] = None | ||
| tool_calls: Optional[ChoiceDeltaToolCall] | ||
| tool_calls: Optional[ChoiceDeltaToolCall] = None | ||
@@ -102,20 +84,37 @@ | ||
| class ChatCompletionChunk(BaseModel): | ||
| id: str | ||
| object: str | ||
| created: int | ||
| class CompletionRequest(BaseModel): | ||
| # Model identifier | ||
| model: str | ||
| system_fingerprint: str | ||
| choices: List[Choice] | ||
| # Prompt | ||
| prompt: str | ||
| # The parameter for repetition penalty. 1.0 means no penalty. | ||
| # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. | ||
| repetition_penalty: Optional[float] = None | ||
| # The parameter for frequency penalty. 1.0 means no penalty | ||
| # Penalize new tokens based on their existing frequency in the text so far, | ||
| # decreasing the model's likelihood to repeat the same line verbatim. | ||
| frequency_penalty: Optional[float] = None | ||
| # Maximum number of tokens to generate | ||
| max_tokens: Optional[int] = None | ||
| # Flag to indicate streaming response | ||
| stream: bool = False | ||
| # Random sampling seed | ||
| seed: Optional[int] = None | ||
| # Sampling temperature | ||
| temperature: Optional[float] = None | ||
| # Top-p value for nucleus sampling | ||
| top_p: Optional[float] = None | ||
| # Stop generating tokens if a member of `stop` is generated | ||
| stop: Optional[List[str]] = None | ||
| class ChatComplete(BaseModel): | ||
| # Chat completion details | ||
| id: str | ||
| object: str | ||
| created: int | ||
| model: str | ||
| system_fingerprint: str | ||
| choices: List[ChatCompletionComplete] | ||
| usage: Any | ||
| class CompletionComplete(BaseModel): | ||
| # Index of the chat completion | ||
| index: int | ||
| # Message associated with the chat completion | ||
| text: str | ||
| # Log probabilities for the chat completion | ||
| logprobs: Optional[Any] | ||
| # Reason for completion | ||
| finish_reason: str | ||
@@ -171,4 +170,40 @@ | ||
| tool_choice: Optional[str] = None | ||
| # Stop generating tokens if a member of `stop` is generated | ||
| stop: Optional[List[str]] = None | ||
| class ChatCompletionComplete(BaseModel): | ||
| # Index of the chat completion | ||
| index: int | ||
| # Message associated with the chat completion | ||
| message: Message | ||
| # Log probabilities for the chat completion | ||
| logprobs: Optional[Any] | ||
| # Reason for completion | ||
| finish_reason: Optional[str] | ||
| # Usage details of the chat completion | ||
| usage: Optional[Any] = None | ||
| class ChatComplete(BaseModel): | ||
| # Chat completion details | ||
| id: str | ||
| object: str | ||
| created: int | ||
| model: str | ||
| system_fingerprint: str | ||
| choices: List[ChatCompletionComplete] | ||
| usage: Any | ||
| class ChatCompletionChunk(BaseModel): | ||
| id: str | ||
| object: str | ||
| created: int | ||
| model: str | ||
| system_fingerprint: str | ||
| choices: List[Choice] | ||
| usage: Optional[Any] = None | ||
| class Parameters(BaseModel): | ||
@@ -433,3 +468,7 @@ # Activate logits sampling | ||
| class DeployedModel(BaseModel): | ||
| # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members | ||
| # with model_ prefixes, since this disables guardrails for colliding fields: | ||
| # https://github.com/pydantic/pydantic/issues/9177 | ||
| model_config = ConfigDict(protected_namespaces=()) | ||
| model_id: str | ||
| sha: str |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
81441
14.43%1564
17.95%