item-matching
 import duckdb
 import polars as pl
-from transformers import SiglipVisionModel, Dinov2WithRegistersModel
+from transformers import (
+    SiglipVisionModel,
+    Dinov2WithRegistersModel,
+    Siglip2VisionModel,
+    AutoProcessor,
+    Siglip2VisionConfig,
+)
 from PIL import Image
@@ -11,5 +17,5 @@ from accelerate import Accelerator
 from torchvision import transforms
 from numpy.lib.format import open_memmap
 from core_pro.ultilities import make_sync_folder
 from tqdm.auto import tqdm
 import numpy as np
@@ -25,5 +31,5 @@ device = Accelerator().device
         [
-            transforms.Resize(
-                img_size, interpolation=transforms.InterpolationMode.BICUBIC
-            ),
+            # transforms.Resize(
+            #     img_size, interpolation=transforms.InterpolationMode.BICUBIC
+            # ),
             transforms.CenterCrop(img_size),
@@ -43,3 +49,4 @@ transforms.ConvertImageDtype(torch.float32),  # to [0,1]
         tensor = transforms.ToTensor()(img)  # HWC→CHW float32
-        return self.transform(tensor)
+        tensor = self.transform(tensor)
+        return tensor
@@ -61,5 +68,29 @@
     )
-    return torch.compile(img_model)
+    # return torch.compile(img_model)
+    return img_model
+
+
+def setup_siglip2():
+    pretrain_name = "google/siglip2-base-patch16-224"
+    config = Siglip2VisionConfig(
+        image_size=224,  # 224/16 = 14 patches per side → 196 total
+        patch_size=16,
+        num_channels=3,
+        embed_dim=768,
+        patch_embed_type="conv",  # if SigLIP-2 supports choosing Conv vs Linear
+    )
+    img_model = Siglip2VisionModel(config)
+    img_model = (
+        img_model.from_pretrained(
+            pretrain_name, torch_dtype=torch.bfloat16, ignore_mismatched_sizes=True
+        )
+        .to(device)
+        .eval()
+    )
+    processor = AutoProcessor.from_pretrained(pretrain_name)
+    return torch.compile(img_model), processor
+
+
 def setup_dinov2():
@@ -75,3 +106,4 @@ pretrain_name = "facebook/dinov2-with-registers-base"
     )
-    return torch.compile(img_model)
+    # return torch.compile(img_model)
+    return img_model
@@ -86,3 +118,12 @@
 ):
-    # 1) Prepare DataLoader ---
+    # 1) Load & compile model in mixed precision
+    device = torch.device("cuda")
+    if model == "siglip":
+        img_model = setup_siglip()
+    elif model == "siglip2":
+        img_model, processor = setup_siglip2()
+    else:
+        img_model = setup_dinov2()
+
+    # 2) Prepare DataLoader ---
     ds = ImagePathsDataset(file_paths, img_size=224)
@@ -98,22 +139,5 @@ loader = DataLoader(
-    # 2) Load & compile model in mixed precision
-    device = torch.device("cuda")
-    if model == "siglip":
-        img_model = setup_siglip()
-    else:
-        img_model = setup_dinov2()
-
-    # 3) Pre-allocate a .npy memmap for all embeddings
-    total = len(ds)
-    dim = img_model.config.hidden_size  # e.g. 1024
-    mmap = open_memmap(
-        filename=str(path / f"{model}_embeds.npy"),
-        mode="w+",
-        dtype="float32",
-        shape=(total, dim),
-    )
-
-    # 4) Inference + save loop
+    # 3) Inference + save loop
     start = perf_counter()
-    idx = 0
+    list_embeds = []
     with torch.inference_mode():
@@ -128,8 +152,7 @@ for batch in tqdm(loader):
             emb = normed.cpu().numpy().astype("float32")  # (B, dim)
-            bs = emb.shape[0]
-            mmap[idx : idx + bs] = emb  # write into .npy
-            idx += bs
+            list_embeds.append(emb)
+    embeds = np.concatenate(list_embeds, axis=0)
+    np.save(path / f"{model}_embeds.npy", embeds)
     durations = perf_counter() - start
-    mmap.flush()  # ensure all data is on disk
@@ -147,3 +170,4 @@ result.append((model, durations))
 query = f"""
-    select *
+    select * exclude(file_path)
+    , REPLACE(file_path, 'data_4t', '75b198db-809a-4bd2-a97c-e52daa6b3a2d') AS file_path
     from read_parquet('{file}')
@@ -156,8 +180,9 @@ """
 result = []
+# result = fast_img_inference(result=result, file_paths=file_paths, model="siglip2")
 result = fast_img_inference(result=result, file_paths=file_paths, model="siglip")
 result = fast_img_inference(result=result, file_paths=file_paths, model="dinov2")
-# result
-df_result = pl.DataFrame(result, orient="row", schema=["name", "duration"])
-df_result.write_csv(path / "img_embed_benchmark.csv")
-print(df_result)
+# # result
+# df_result = pl.DataFrame(result, orient="row", schema=["name", "duration"])
+# df_result.write_csv(path / "img_embed_benchmark.csv")
+# print(df_result)
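
The main change in this benchmark script is the export path for embeddings: instead of pre-allocating a float32 memmap with open_memmap and writing each batch into it, the new version appends per-batch arrays to a list and writes a single .npy at the end with np.save. A minimal sketch of that pattern follows; the loader, model, pooling, and normalisation step here are assumptions standing in for the script's own objects, not the package's exact code.

import numpy as np
import torch

def export_embeddings(img_model, loader, out_file):
    # Accumulate per-batch embeddings in memory, then write one .npy file at the end
    # (this replaces the old pre-allocated open_memmap approach).
    list_embeds = []
    with torch.inference_mode():
        for batch in loader:
            out = img_model(pixel_values=batch.to(img_model.device, dtype=torch.bfloat16))
            normed = torch.nn.functional.normalize(out.pooler_output, dim=-1)
            list_embeds.append(normed.float().cpu().numpy())  # (B, dim) float32
    embeds = np.concatenate(list_embeds, axis=0)  # (N, dim)
    np.save(out_file, embeds)
    return embeds

The trade-off versus the old memmap path: all embeddings are held in RAM until the final concatenate, which is fine for benchmark-sized sets but gives up the constant-memory property of writing batches straight into a pre-allocated file.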
 Metadata-Version: 2.4
 Name: item_matching
-Version: 0.0.106
+Version: 0.0.107
 Summary: A name matching package
@@ -5,0 +5,0 @@ Project-URL: Homepage, https://github.com/kevinkhang2909/item_matching
@@ -7,3 +7,3 @@ [build-system]
 name = "item_matching"
-version = "0.0.106"
+version = "0.0.107"
 authors = [
@@ -10,0 +10,0 @@ { name="Kevin Khang", email="kevinkhang2909@gmail.com" },
@@ -14,3 +14,3 @@ from PIL import Image
 from FlagEmbedding import BGEM3FlagModel
-from transformers import Dinov2WithRegistersModel, Siglip2VisionModel
+from transformers import Dinov2WithRegistersModel, SiglipVisionModel
 from .func import _create_folder
@@ -53,3 +53,3 @@
             # ),
-            # transforms.CenterCrop(img_size),
+            transforms.CenterCrop(img_size),
             transforms.ConvertImageDtype(torch.float32),  # to [0,1]
@@ -68,9 +68,10 @@ transforms.Normalize(
         tensor = transforms.ToTensor()(img)  # HWC→CHW float32
-        return self.transform(tensor)
+        tensor = self.transform(tensor)
+        return tensor


-def get_img_model():
-    pretrain_name = "google/siglip2-base-patch16-224"
+def setup_dinov2():
+    pretrain_name = "facebook/dinov2-with-registers-base"
     img_model = (
-        Siglip2VisionModel.from_pretrained(
+        Dinov2WithRegistersModel.from_pretrained(
             pretrain_name,
@@ -82,12 +83,16 @@ torch_dtype=torch.bfloat16,
     )
     # return torch.compile(img_model)
     return img_model
-    # pretrain_name = "facebook/dinov2-with-registers-base"
-    # img_model = (
-    #     Dinov2WithRegistersModel.from_pretrained(
-    #         pretrain_name,
-    #         torch_dtype=torch.bfloat16,
-    #     )
-    #     .to(device)
-    #     .eval()
-    # )
+
+
+def setup_siglip():
+    pretrain_name = "google/siglip-base-patch16-224"
+    img_model = (
+        SiglipVisionModel.from_pretrained(
+            pretrain_name,
+            torch_dtype=torch.bfloat16,
+        )
+        .to(device)
+        .eval()
+    )
+    # return torch.compile(img_model)
@@ -140,2 +145,3 @@ return img_model
        SHARD_SIZE: int = 1_500_000,
+        model_name: str = "dinov2",
    ):
@@ -146,2 +152,3 @@ # Config
         self.SHARD_SIZE = SHARD_SIZE
+        self.model_name = model_name
@@ -166,3 +173,6 @@ # Path
         self.col_embedding = f"{self.MATCH_BY}_embed"
-        self.img_model = get_img_model()
+        if self.model_name == "siglip":
+            self.img_model = setup_siglip()
+        else:
+            self.img_model = setup_dinov2()
@@ -169,0 +179,0 @@ def load(self, data: pl.DataFrame):
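
On the package side, the embedding backbone is now selectable: the pipeline class takes a model_name argument ("siglip" or the default "dinov2") and dispatches to setup_siglip or setup_dinov2. Below is a minimal sketch of that dispatch using the same two checkpoints as the diff; the helper name and return value are illustrative, not the package's public API.

import torch
from transformers import Dinov2WithRegistersModel, SiglipVisionModel

def load_backbone(model_name: str = "dinov2", device: str = "cuda"):
    # Mirrors the setup_siglip / setup_dinov2 split introduced in 0.0.107.
    if model_name == "siglip":
        name, cls = "google/siglip-base-patch16-224", SiglipVisionModel
    else:
        name, cls = "facebook/dinov2-with-registers-base", Dinov2WithRegistersModel
    model = cls.from_pretrained(name, torch_dtype=torch.bfloat16).to(device).eval()
    return model, model.config.hidden_size  # hidden_size sets the embedding width

# model, dim = load_backbone("siglip")  # both base checkpoints report hidden_size=768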