item-matching
Advanced tools
| import duckdb | ||
| import polars as pl | ||
| from transformers import ( | ||
| SiglipVisionModel, | ||
| Dinov2WithRegistersModel, | ||
| Siglip2VisionModel, | ||
@@ -10,3 +7,2 @@ AutoProcessor, | ||
| ) | ||
| from PIL import Image | ||
| from accelerate import Accelerator | ||
@@ -16,3 +12,3 @@ from time import perf_counter | ||
| import torch.nn.functional as F | ||
| from torch.utils.data import Dataset, DataLoader | ||
| from torch.utils.data import DataLoader | ||
| from torchvision import transforms | ||
@@ -22,3 +18,9 @@ from core_pro.ultilities import make_sync_folder | ||
| import numpy as np | ||
| import sys | ||
| from pathlib import Path | ||
| sys.path.extend([str(Path.home() / "PycharmProjects/item_matching")]) | ||
| from src.item_matching.pipeline.data_loading import setup_dinov2, setup_siglip, ImagePathsDataset, collate_batch | ||
| device = Accelerator().device | ||
@@ -28,46 +30,2 @@ torch.backends.cudnn.benchmark = True | ||
class ImagePathsDataset(Dataset):
    """Dataset yielding normalized CHW float32 image tensors loaded from disk.

    Each item is opened with PIL, converted to RGB, turned into a float32
    tensor in [0, 1], center-cropped to ``img_size`` and normalized with
    ImageNet statistics.

    Fixes vs. the original:
    - ``transforms.ToTensor()`` was instantiated on every ``__getitem__``
      call; it is now part of the composed pipeline, built once.
    - ``ConvertImageDtype(torch.float32)`` was redundant (ToTensor already
      yields float32 in [0, 1]) and has been dropped — output is unchanged.
    """

    def __init__(self, file_paths: list, img_size: int = 224):
        # Paths to image files; order defines the dataset index.
        self.file_paths = file_paths
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),  # PIL HWC -> CHW float32 in [0, 1]
                # Resize intentionally left disabled by the author:
                # transforms.Resize(
                #     img_size, interpolation=transforms.InterpolationMode.BICUBIC
                # ),
                transforms.CenterCrop(img_size),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # NOTE(review): assumes every path opens as an image — no error
        # handling for corrupt/missing files, same as the original.
        img = Image.open(self.file_paths[idx]).convert("RGB")
        return self.transform(img)
def collate_batch(batch):
    """Stack a list of equally-shaped image tensors into one batch tensor."""
    stacked = torch.stack(batch)  # default dim=0: (N, C, H, W)
    return stacked
def setup_siglip():
    """Load the SigLIP base vision encoder (bfloat16, eval mode) onto `device`."""
    pretrain_name = "google/siglip-base-patch16-224"
    model = SiglipVisionModel.from_pretrained(
        pretrain_name,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device)
    model.eval()
    # Compilation was tried and left disabled by the author:
    # return torch.compile(model)
    return model
| def setup_siglip2(): | ||
@@ -96,16 +54,2 @@ pretrain_name = "google/siglip2-base-patch16-224" | ||
def setup_dinov2():
    """Load the DINOv2-with-registers base encoder (bfloat16, eval mode) onto `device`."""
    pretrain_name = "facebook/dinov2-with-registers-base"
    model = Dinov2WithRegistersModel.from_pretrained(
        pretrain_name,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device)
    model.eval()
    # Compilation was tried and left disabled by the author:
    # return torch.compile(model)
    return model
| def fast_img_inference( | ||
@@ -112,0 +56,0 @@ result: list, |
+1
-1
| Metadata-Version: 2.4 | ||
| Name: item_matching | ||
| Version: 0.0.107 | ||
| Version: 0.0.108 | ||
| Summary: A name matching package | ||
@@ -5,0 +5,0 @@ Project-URL: Homepage, https://github.com/kevinkhang2909/item_matching |
+1
-1
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "item_matching" | ||
| version = "0.0.107" | ||
| version = "0.0.108" | ||
| authors = [ | ||
@@ -10,0 +10,0 @@ { name="Kevin Khang", email="kevinkhang2909@gmail.com" }, |
@@ -22,4 +22,6 @@ from PIL import Image | ||
def get_text_model():
    """Load the BGE-M3 text embedding model (fp16, normalized embeddings) on `device`."""
    model_name = "BAAI/bge-m3"
    print(model_name)
    model = BGEM3FlagModel(
        model_name, use_fp16=True, device=device, normalize_embeddings=True
    )
    return model
@@ -82,2 +84,3 @@ | ||
| # return torch.compile(img_model) | ||
| print(f"Model Vision: {pretrain_name}") | ||
| return img_model | ||
@@ -97,2 +100,3 @@ | ||
| # return torch.compile(img_model) | ||
| print(f"Model Vision: {pretrain_name}") | ||
| return img_model | ||
@@ -99,0 +103,0 @@ |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
3252955
-0.04%1256
-3.38%