tgi - npm Package: Compare versions

Comparing version 2.0.2 to 2.4.0

PKG-INFO (+4, -1)
Metadata-Version: 2.1
Name: tgi
- Version: 2.0.2
+ Version: 2.4.0
Summary: Nightly release of Hugging Face Text Generation Python Client

@@ -26,2 +26,5 @@ Home-page: https://github.com/huggingface/text-generation-inference

+ # Legacy warning ⚠️
+ The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.
# Text Generation

@@ -28,0 +31,0 @@

pyproject.toml

[tool.poetry]
name = "tgi"
- version = "2.0.2"
+ version = "2.4.0"
description = "Nightly release of Hugging Face Text Generation Python Client"

@@ -30,1 +30,4 @@ license = "Apache-2.0"

build-backend = "poetry.core.masonry.api"
+ [tool.isort]
+ profile = "black"
README.md (+3, -0)

@@ -0,1 +1,4 @@

+ # Legacy warning ⚠️
+ The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.
# Text Generation
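Since the new README points users at `huggingface_hub`, a minimal migration sketch may help. It assumes `huggingface_hub` is installed and a TGI server is reachable at a placeholder URL:

```python
# Sketch of the recommended replacement; the endpoint URL is a placeholder.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")
print(client.text_generation("What is Deep Learning?", max_new_tokens=20))
```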

@@ -2,0 +5,0 @@

text_generation/__init__.py

@@ -15,5 +15,20 @@ # Copyright 2023 The HuggingFace Team. All rights reserved.

__version__ = "0.6.0"
__version__ = "0.7.0"
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
DEPRECATION_WARNING = (
"`text_generation` clients are deprecated and will be removed in the near future. "
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
from text_generation.client import Client, AsyncClient # noqa E402
from text_generation.inference_api import ( # noqa E402
InferenceAPIClient,
InferenceAPIAsyncClient,
)
__all__ = [
"Client",
"AsyncClient",
"InferenceAPIClient",
"InferenceAPIAsyncClient",
]
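Together, `DEPRECATION_WARNING` and the `simplefilter("always", DeprecationWarning)` call in `client.py` below mean that every client construction now emits a warning, not just the first. A minimal sketch of what a caller observes (the server URL is a placeholder; no request is made at construction time):

```python
import warnings

from text_generation import Client

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", DeprecationWarning)
    # __init__ fires the new deprecation warning; no network call happens here.
    Client("http://localhost:8080")  # placeholder TGI endpoint

print(caught[0].message)
# `text_generation` clients are deprecated and will be removed in the near future. ...
```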
text_generation/client.py

import json
import requests
+ import warnings

@@ -8,2 +9,3 @@ from aiohttp import ClientSession, ClientTimeout

+ from text_generation import DEPRECATION_WARNING
from text_generation.types import (

@@ -15,2 +17,5 @@ StreamResponse,

    Grammar,
+   CompletionRequest,
+   Completion,
+   CompletionComplete,
    ChatRequest,

@@ -24,3 +29,6 @@ ChatCompletionChunk,

+ # emit deprecation warnings
+ warnings.simplefilter("always", DeprecationWarning)

class Client:

@@ -65,2 +73,3 @@ """Client to make calls to a text-generation-inference instance

"""
warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
self.base_url = base_url

@@ -71,2 +80,90 @@ self.headers = headers

    def completion(
        self,
        prompt: str,
        frequency_penalty: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ):
        """
        Given a prompt, generate a response synchronously

        Args:
            prompt (`str`):
                Prompt
            frequency_penalty (`float`):
                The parameter for frequency penalty. 0.0 means no penalty.
                Penalize new tokens based on their existing frequency in the text so far,
                decreasing the model's likelihood to repeat the same line verbatim.
            max_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            seed (`int`):
                Random sampling seed
            stream (`bool`):
                Stream the response
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation
            stop (`List[str]`):
                Stop generating tokens if a member of `stop` is generated
        """
        request = CompletionRequest(
            model="tgi",
            prompt=prompt,
            frequency_penalty=frequency_penalty,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        if not stream:
            resp = requests.post(
                f"{self.base_url}/v1/completions",
                json=request.dict(),
                headers=self.headers,
                cookies=self.cookies,
                timeout=self.timeout,
            )
            payload = resp.json()
            if resp.status_code != 200:
                raise parse_error(resp.status_code, payload)
            return Completion(**payload)
        else:
            return self._completion_stream_response(request)

    def _completion_stream_response(self, request):
        resp = requests.post(
            f"{self.base_url}/v1/completions",
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )
        # iterate over the SSE stream and yield each parsed completion chunk
        for byte_payload in resp.iter_lines():
            # skip empty keep-alive lines
            if byte_payload == b"\n":
                continue
            payload = byte_payload.decode("utf-8")
            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                try:
                    response = CompletionComplete(**json_payload)
                    yield response
                except ValidationError:
                    # failed validation means the server sent an error payload
                    raise parse_error(resp.status_code, json_payload)

    def chat(

@@ -90,2 +187,3 @@ self,

        tool_choice: Optional[str] = None,
+       stop: Optional[List[str]] = None,
    ):

@@ -133,2 +231,4 @@ """

                The tool to use
+           stop (`List[str]`):
+               Stop generating tokens if a member of `stop` is generated

@@ -154,2 +254,3 @@ """

            tool_choice=tool_choice,
+           stop=stop,
        )

@@ -460,2 +561,3 @@ if not stream:

"""
warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
self.base_url = base_url

@@ -466,2 +568,89 @@ self.headers = headers

    async def completion(
        self,
        prompt: str,
        frequency_penalty: Optional[float] = None,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:
        """
        Given a prompt, generate a response asynchronously

        Args:
            prompt (`str`):
                Prompt
            frequency_penalty (`float`):
                The parameter for frequency penalty. 0.0 means no penalty.
                Penalize new tokens based on their existing frequency in the text so far,
                decreasing the model's likelihood to repeat the same line verbatim.
            max_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            seed (`int`):
                Random sampling seed
            stream (`bool`):
                Stream the response
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation
            stop (`List[str]`):
                Stop generating tokens if a member of `stop` is generated
        """
        request = CompletionRequest(
            model="tgi",
            prompt=prompt,
            frequency_penalty=frequency_penalty,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
        )
        if not stream:
            return await self._completion_single_response(request)
        else:
            return self._completion_stream_response(request)

    async def _completion_single_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/completions", json=request.dict()
            ) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Completion(**payload)

    async def _completion_stream_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/completions", json=request.dict()
            ) as resp:
                # iterate over the SSE stream and yield each parsed completion chunk
                async for byte_payload in resp.content:
                    # skip empty keep-alive lines
                    if byte_payload == b"\n":
                        continue
                    payload = byte_payload.decode("utf-8")
                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        try:
                            response = CompletionComplete(**json_payload)
                            yield response
                        except ValidationError:
                            # failed validation means the server sent an error payload
                            raise parse_error(resp.status, json_payload)

    async def chat(

@@ -485,2 +674,3 @@ self,

        tool_choice: Optional[str] = None,
+       stop: Optional[List[str]] = None,
    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:

@@ -528,2 +718,4 @@ """

                The tool to use
+           stop (`List[str]`):
+               Stop generating tokens if a member of `stop` is generated

@@ -549,2 +741,3 @@ """

            tool_choice=tool_choice,
+           stop=stop,
        )

@@ -580,3 +773,8 @@ if not stream:

                    if payload.startswith("data:"):
-                       json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                       payload_data = (
+                           payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                       )
+                       if payload_data == "[DONE]":
+                           break
+                       json_payload = json.loads(payload_data)
try:

@@ -583,0 +781,0 @@ response = ChatCompletionChunk(**json_payload)
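This hunk teaches the chat stream parser the OpenAI-style `data: [DONE]` terminator, which is not valid JSON and would previously have raised inside `json.loads`. The same parsing rule in isolation (the sample SSE lines are invented for illustration):

```python
import json

# Invented sample of a server-sent-event body, for illustration only.
lines = ['data: {"text": "Hello"}', 'data: {"text": " world"}', "data: [DONE]"]

for payload in lines:
    if payload.startswith("data:"):
        payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
        if payload_data == "[DONE]":  # end-of-stream sentinel, not JSON
            break
        print(json.loads(payload_data)["text"])
```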

text_generation/inference_api.py

@@ -24,3 +24,3 @@ import os

    resp = requests.get(
-       f"https://api-inference.huggingface.co/framework/text-generation-inference",
+       "https://api-inference.huggingface.co/framework/text-generation-inference",
        headers=headers,

@@ -27,0 +27,0 @@ timeout=5,

text_generation/types.py

from enum import Enum
- from pydantic import BaseModel, field_validator
+ from pydantic import BaseModel, field_validator, ConfigDict
from typing import Optional, List, Union, Any

@@ -31,2 +31,8 @@

+ class Chunk(BaseModel):
+     type: str
+     text: Optional[str] = None
+     image_url: Any = None

class Message(BaseModel):

@@ -36,3 +42,3 @@ # Role of the message sender

    # Content of the message
-   content: Optional[str] = None
+   content: Optional[Union[str, List[Chunk]]] = None
    # Optional name of the message sender
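Widening `content` from `str` to `Union[str, List[Chunk]]` lets a message carry mixed text and image parts. A minimal sketch; the `role` field and the OpenAI-style `image_url` payload shape are assumptions here, and the URL is a placeholder:

```python
from text_generation.types import Chunk, Message

# Plain string content still validates, as before.
text_msg = Message(role="user", content="Describe this image.")

# New: a list of typed chunks mixing text and an image reference.
multimodal_msg = Message(
    role="user",
    content=[
        Chunk(type="text", text="Describe this image."),
        Chunk(type="image_url", image_url={"url": "https://example.com/cat.png"}),
    ],
)
print(multimodal_msg.model_dump())
```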

@@ -51,26 +57,2 @@ name: Optional[str] = None

- class ChatCompletionComplete(BaseModel):
-     # Index of the chat completion
-     index: int
-     # Message associated with the chat completion
-     message: Message
-     # Log probabilities for the chat completion
-     logprobs: Optional[Any]
-     # Reason for completion
-     finish_reason: str
-     # Usage details of the chat completion
-     usage: Optional[Any] = None
- class CompletionComplete(BaseModel):
-     # Index of the completion
-     index: int
-     # Generated text
-     text: str
-     # Log probabilities for the completion
-     logprobs: Optional[Any]
-     # Reason for completion
-     finish_reason: str
class Function(BaseModel):

@@ -91,3 +73,3 @@ name: Optional[str]

    content: Optional[str] = None
-   tool_calls: Optional[ChoiceDeltaToolCall]
+   tool_calls: Optional[ChoiceDeltaToolCall] = None

@@ -102,20 +84,37 @@

- class ChatCompletionChunk(BaseModel):
-     id: str
-     object: str
-     created: int
-     model: str
-     system_fingerprint: str
-     choices: List[Choice]
- class ChatComplete(BaseModel):
-     # Chat completion details
-     id: str
-     object: str
-     created: int
-     model: str
-     system_fingerprint: str
-     choices: List[ChatCompletionComplete]
-     usage: Any
+ class CompletionRequest(BaseModel):
+     # Model identifier
+     model: str
+     # Prompt
+     prompt: str
+     # The parameter for repetition penalty. 1.0 means no penalty.
+     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+     repetition_penalty: Optional[float] = None
+     # The parameter for frequency penalty. 0.0 means no penalty.
+     # Penalize new tokens based on their existing frequency in the text so far,
+     # decreasing the model's likelihood to repeat the same line verbatim.
+     frequency_penalty: Optional[float] = None
+     # Maximum number of tokens to generate
+     max_tokens: Optional[int] = None
+     # Flag to indicate streaming response
+     stream: bool = False
+     # Random sampling seed
+     seed: Optional[int] = None
+     # Sampling temperature
+     temperature: Optional[float] = None
+     # Top-p value for nucleus sampling
+     top_p: Optional[float] = None
+     # Stop generating tokens if a member of `stop` is generated
+     stop: Optional[List[str]] = None
+ class CompletionComplete(BaseModel):
+     # Index of the completion
+     index: int
+     # Generated text
+     text: str
+     # Log probabilities for the completion
+     logprobs: Optional[Any]
+     # Reason the generation finished
+     finish_reason: str
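For reference, the JSON body the client posts to `/v1/completions` serializes directly from this model; a quick sketch with invented field values:

```python
from text_generation.types import CompletionRequest

req = CompletionRequest(
    model="tgi",
    prompt="def hello():",
    max_tokens=16,
    temperature=0.7,
    stop=["\n\n"],
)
# The client sends exactly this dict as the request body.
print(req.dict())
```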

@@ -171,4 +170,40 @@

    tool_choice: Optional[str] = None
+   # Stop generating tokens if a member of `stop` is generated
+   stop: Optional[List[str]] = None

+ class ChatCompletionComplete(BaseModel):
+     # Index of the chat completion
+     index: int
+     # Message associated with the chat completion
+     message: Message
+     # Log probabilities for the chat completion
+     logprobs: Optional[Any]
+     # Reason for completion
+     finish_reason: Optional[str]
+     # Usage details of the chat completion
+     usage: Optional[Any] = None
+ class ChatComplete(BaseModel):
+     # Chat completion details
+     id: str
+     object: str
+     created: int
+     model: str
+     system_fingerprint: str
+     choices: List[ChatCompletionComplete]
+     usage: Any
+ class ChatCompletionChunk(BaseModel):
+     id: str
+     object: str
+     created: int
+     model: str
+     system_fingerprint: str
+     choices: List[Choice]
+     usage: Optional[Any] = None
class Parameters(BaseModel):

@@ -433,3 +468,7 @@ # Activate logits sampling

class DeployedModel(BaseModel):
+   # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
+   # with model_ prefixes, since this disables guardrails for colliding fields:
+   # https://github.com/pydantic/pydantic/issues/9177
+   model_config = ConfigDict(protected_namespaces=())
    model_id: str
    sha: str
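The `protected_namespaces=()` override exists because pydantic v2 reserves the `model_` prefix for its own configuration attributes and warns when a field such as `model_id` shadows it. A standalone reproduction, independent of this package:

```python
from pydantic import BaseModel, ConfigDict


class Warns(BaseModel):
    # pydantic v2 emits a UserWarning at class-definition time:
    # field "model_id" conflicts with the protected "model_" namespace.
    model_id: str


class Silent(BaseModel):
    # Opting out silences the warning, at the cost of pydantic no longer
    # guarding against collisions with future model_* attributes.
    model_config = ConfigDict(protected_namespaces=())
    model_id: str


print(Silent(model_id="tgi").model_id)
```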