item-matching
Advanced tools
+2
-1
| Metadata-Version: 2.4 | ||
| Name: item_matching | ||
| Version: 0.0.104 | ||
| Version: 0.0.105 | ||
| Summary: A name matching package | ||
@@ -47,2 +47,3 @@ Project-URL: Homepage, https://github.com/kevinkhang2909/item_matching | ||
| Requires-Dist: torch | ||
| Requires-Dist: torchvision | ||
| Requires-Dist: transformers | ||
@@ -49,0 +50,0 @@ Description-Content-Type: text/markdown |
+2
-1
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "item_matching" | ||
| version = "0.0.104" | ||
| version = "0.0.105" | ||
| authors = [ | ||
@@ -30,2 +30,3 @@ { name="Kevin Khang", email="kevinkhang2909@gmail.com" }, | ||
| 'torch', | ||
| 'torchvision', | ||
| 'autofaiss', | ||
@@ -32,0 +33,0 @@ 'datasets', |
@@ -6,3 +6,2 @@ from PIL import Image | ||
| from rich import print | ||
| from numpy.lib.format import open_memmap | ||
| import torch | ||
@@ -16,3 +15,3 @@ import torch.nn.functional as F | ||
| from FlagEmbedding import BGEM3FlagModel | ||
| from transformers import Dinov2WithRegistersModel, SiglipVisionModel, SiglipConfig | ||
| from transformers import Dinov2WithRegistersModel, Siglip2VisionModel | ||
| from .func import _create_folder | ||
@@ -73,5 +72,5 @@ | ||
| def get_img_model(): | ||
| pretrain_name = "google/siglip-base-patch16-224" | ||
| pretrain_name = "google/siglip2-base-patch16-224" | ||
| img_model = ( | ||
| SiglipVisionModel.from_pretrained( | ||
| Siglip2VisionModel.from_pretrained( | ||
| pretrain_name, | ||
@@ -83,3 +82,2 @@ torch_dtype=torch.bfloat16, | ||
| ) | ||
| config = SiglipConfig.from_pretrained(pretrain_name) | ||
@@ -96,3 +94,3 @@ # pretrain_name = "facebook/dinov2-with-registers-base" | ||
| # return torch.compile(img_model) | ||
| return img_model, config | ||
| return img_model | ||
@@ -102,3 +100,2 @@ | ||
| img_model, | ||
| config, | ||
| save_file_path: Path, | ||
@@ -169,3 +166,3 @@ iterable_list: list[str], | ||
| self.col_embedding = f"{self.MATCH_BY}_embed" | ||
| self.img_model, self.config = get_img_model() | ||
| self.img_model = get_img_model() | ||
@@ -205,3 +202,2 @@ def load(self, data: pl.DataFrame): | ||
| img_model=self.img_model, | ||
| config=self.config, | ||
| save_file_path=array_name, | ||
@@ -208,0 +204,0 @@ iterable_list=dataset_chunk[self.col_input].to_list(), |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
3597941
01269
-0.31%