item-matching
Advanced tools
+1
-1
| Metadata-Version: 2.4 | ||
| Name: item_matching | ||
| Version: 0.0.101 | ||
| Version: 0.0.102 | ||
| Summary: A name matching package | ||
@@ -5,0 +5,0 @@ Project-URL: Homepage, https://github.com/kevinkhang2909/item_matching |
+1
-1
@@ -7,3 +7,3 @@ [build-system] | ||
| name = "item_matching" | ||
| version = "0.0.101" | ||
| version = "0.0.102" | ||
| authors = [ | ||
@@ -10,0 +10,0 @@ { name="Kevin Khang", email="kevinkhang2909@gmail.com" }, |
@@ -14,3 +14,2 @@ from pathlib import Path | ||
| mode: str = "", | ||
| ): | ||
@@ -27,7 +26,12 @@ # path | ||
| # sorted files | ||
| files_sorted = sorted([*self.folder_image.glob("*/*.jpg")], key=lambda x: int(x.stem.split("_")[0])) | ||
| files_sorted = sorted( | ||
| [*self.folder_image.glob("*/*.jpg")], | ||
| key=lambda x: int(x.stem.split("_")[0]), | ||
| ) | ||
| # file to df | ||
| lst = [(i.stem, str(i), i.exists()) for i in files_sorted] | ||
| df = pl.DataFrame(lst, orient="row", schema=["path_idx", "file_path", "file_exists"]) | ||
| df = pl.DataFrame( | ||
| lst, orient="row", schema=["path_idx", "file_path", "file_exists"] | ||
| ) | ||
| if df.shape[0] == 0: | ||
@@ -50,3 +54,6 @@ print(f"-> Images Errors {self.mode}: {df.shape}") | ||
| data = data.with_row_index("img_index") | ||
| run = [(i["img_index"], i[col_image_url]) for i in data[["img_index", col_image_url]].to_dicts()] | ||
| run = [ | ||
| (i["img_index"], i[col_image_url]) | ||
| for i in data[["img_index", col_image_url]].to_dicts() | ||
| ] | ||
| print(f"-> Base data {self.mode}: {data.shape}") | ||
@@ -53,0 +60,0 @@ |
@@ -62,3 +62,12 @@ from pathlib import Path | ||
| print(f"[Clean Cache]") | ||
| folder_list = ["index", "result", "db_array", "db_ds", "q_array", "q_ds", "array", "ds"] | ||
| folder_list = [ | ||
| "index", | ||
| "result", | ||
| "db_array", | ||
| "db_ds", | ||
| "q_array", | ||
| "q_ds", | ||
| "array", | ||
| "ds", | ||
| ] | ||
| for name in folder_list: | ||
@@ -89,5 +98,3 @@ rm_all_folder(self.ROOT_PATH / name) | ||
| chunk_q = self._load_data(cat=cat, mode="q", file=self.PATH_Q) | ||
| print( | ||
| f"-> Database shape {chunk_db.shape}, Query shape {chunk_q.shape}" | ||
| ) | ||
| print(f"-> Database shape {chunk_db.shape}, Query shape {chunk_q.shape}") | ||
@@ -103,3 +110,3 @@ if chunk_q.shape[0] < 2 or chunk_db.shape[0] < 2: | ||
| MATCH_BY=self.MATCH_BY, | ||
| SHARD_SIZE=self.SHARD_SIZE | ||
| SHARD_SIZE=self.SHARD_SIZE, | ||
| ).load(data=chunk_db) | ||
@@ -111,3 +118,3 @@ | ||
| MATCH_BY=self.MATCH_BY, | ||
| SHARD_SIZE=self.SHARD_SIZE | ||
| SHARD_SIZE=self.SHARD_SIZE, | ||
| ).load(data=chunk_q) | ||
@@ -114,0 +121,0 @@ |
@@ -31,4 +31,4 @@ from pathlib import Path | ||
| self.path_index = _create_folder(path, "index", one=True) | ||
| self.file_index = self.path_index / f"ip.index" | ||
| self.file_index_json = str(self.path_index / f"index.json") | ||
| self.file_index = self.path_index / "ip.index" | ||
| self.file_index_json = str(self.path_index / "index.json") | ||
@@ -49,7 +49,9 @@ # array | ||
| def _create_folder_result(self): | ||
| self.path_result_query_score = self.path / f"result" | ||
| self.path_result_query_score = self.path / "result" | ||
| self.path_result_final = self.path / f"result_match_{self.MATCH_BY}" | ||
| make_dir(self.path_result_query_score) | ||
| make_dir(self.path_result_final) | ||
| self.file_export_final = self.path_result_final / f"{self.file_export_name}.parquet" | ||
| self.file_export_final = ( | ||
| self.path_result_final / f"{self.file_export_name}.parquet" | ||
| ) | ||
@@ -60,3 +62,3 @@ def build(self): | ||
| if not self.file_index.exists(): | ||
| print(f"[BuildIndex] Start") | ||
| print("[BuildIndex] Start") | ||
| try: | ||
@@ -76,3 +78,3 @@ build_index( | ||
| else: | ||
| print(f"[BuildIndex] Index is existed") | ||
| print("[BuildIndex] Index is existed") | ||
@@ -82,3 +84,5 @@ def load_dataset(self): | ||
| for i in ["db", "q"]: | ||
| files = sorted(self.dataset_dict[f"{i}_ds_path"].glob("*"), key=self.sort_key_ds) | ||
| files = sorted( | ||
| self.dataset_dict[f"{i}_ds_path"].glob("*"), key=self.sort_key_ds | ||
| ) | ||
| df = pl.concat([pl.read_parquet(f) for f in files]) | ||
@@ -127,3 +131,5 @@ dataset[i] = Dataset.from_polars(df) | ||
| dict_ = {f"score_{self.col_embedding}": [_round_score(arr) for arr in score]} | ||
| dict_ = { | ||
| f"score_{self.col_embedding}": [_round_score(arr) for arr in score] | ||
| } | ||
| df_score = pl.DataFrame(dict_) | ||
@@ -140,15 +146,25 @@ df_score.write_parquet(file_name_score) | ||
| # Concat all files | ||
| dataset_q = dataset_q.remove_columns(self.col_embedding) # prevent polars issues | ||
| dataset_q = dataset_q.remove_columns( | ||
| self.col_embedding | ||
| ) # prevent polars issues | ||
| del dataset_db | ||
| # score | ||
| files_score = sorted(self.path_result_query_score.glob("score*.parquet"), key=self.sort_key_result) | ||
| files_score = sorted( | ||
| self.path_result_query_score.glob("score*.parquet"), | ||
| key=self.sort_key_result, | ||
| ) | ||
| df_score = pl.concat([pl.read_parquet(f) for f in files_score]) | ||
| # result | ||
| files_result = sorted(self.path_result_query_score.glob("result*.parquet"), key=self.sort_key_result) | ||
| files_result = sorted( | ||
| self.path_result_query_score.glob("result*.parquet"), | ||
| key=self.sort_key_result, | ||
| ) | ||
| df_result = pl.concat([pl.read_parquet(f) for f in files_result]) | ||
| # combine to data | ||
| df_match = pl.concat([dataset_q.to_polars(), df_result, df_score], how="horizontal") | ||
| df_match = pl.concat( | ||
| [dataset_q.to_polars(), df_result, df_score], how="horizontal" | ||
| ) | ||
@@ -155,0 +171,0 @@ # explode result |
@@ -15,3 +15,3 @@ from PIL import Image | ||
| from FlagEmbedding import BGEM3FlagModel | ||
| from transformers import Dinov2WithRegistersModel | ||
| from transformers import Dinov2WithRegistersModel, AutoModel | ||
| from .func import _create_folder | ||
@@ -22,8 +22,6 @@ | ||
| def get_text_model(): | ||
| return BGEM3FlagModel( | ||
| "BAAI/bge-m3", | ||
| use_fp16=True, | ||
| device=device, | ||
| normalize_embeddings=True | ||
| "BAAI/bge-m3", use_fp16=True, device=device, normalize_embeddings=True | ||
| ) | ||
@@ -75,5 +73,5 @@ | ||
| def get_img_model(): | ||
| pretrain_name = "facebook/dinov2-with-registers-base" | ||
| pretrain_name = "google/siglip-base-patch16-224" | ||
| img_model = ( | ||
| Dinov2WithRegistersModel.from_pretrained( | ||
| AutoModel.from_pretrained( | ||
| pretrain_name, | ||
@@ -85,2 +83,12 @@ torch_dtype=torch.bfloat16, | ||
| ) | ||
| # pretrain_name = "facebook/dinov2-with-registers-base" | ||
| # img_model = ( | ||
| # Dinov2WithRegistersModel.from_pretrained( | ||
| # pretrain_name, | ||
| # torch_dtype=torch.bfloat16, | ||
| # ) | ||
| # .to(device) | ||
| # .eval() | ||
| # ) | ||
| return torch.compile(img_model) | ||
@@ -133,3 +141,5 @@ | ||
| mmap.flush() # ensure all data is on disk | ||
| embeddings = np.memmap(save_file_path, dtype=np.float32, mode='r', shape=(total, dim)) | ||
| embeddings = np.memmap( | ||
| save_file_path, dtype=np.float32, mode="r", shape=(total, dim) | ||
| ) | ||
| return embeddings | ||
@@ -140,7 +150,7 @@ | ||
| def __init__( | ||
| self, | ||
| path: Path, | ||
| MODE: str, | ||
| MATCH_BY: str = "text", | ||
| SHARD_SIZE: int = 1_500_000, | ||
| self, | ||
| path: Path, | ||
| MODE: str, | ||
| MATCH_BY: str = "text", | ||
| SHARD_SIZE: int = 1_500_000, | ||
| ): | ||
@@ -200,3 +210,3 @@ # Config | ||
| save_file_path=array_name, | ||
| iterable_list=dataset_chunk[self.col_input].to_list() | ||
| iterable_list=dataset_chunk[self.col_input].to_list(), | ||
| ) | ||
@@ -207,3 +217,3 @@ else: | ||
| save_file_path=array_name, | ||
| iterable_list=dataset_chunk[self.col_input].to_list() | ||
| iterable_list=dataset_chunk[self.col_input].to_list(), | ||
| ) | ||
@@ -210,0 +220,0 @@ |
@@ -11,7 +11,7 @@ import polars as pl | ||
| def __init__( | ||
| self, | ||
| path: Path, | ||
| db_col_idx: str = "db_item_id", | ||
| q_col_idx: str = "q_item_id", | ||
| col_text: str = "item_name" | ||
| self, | ||
| path: Path, | ||
| db_col_idx: str = "db_item_id", | ||
| q_col_idx: str = "q_item_id", | ||
| col_text: str = "item_name", | ||
| ): | ||
@@ -30,3 +30,5 @@ # path | ||
| # all category | ||
| self.all_category = set([i.stem for i in self.file_text] + [i.stem for i in self.file_image]) | ||
| self.all_category = set( | ||
| [i.stem for i in self.file_text] + [i.stem for i in self.file_image] | ||
| ) | ||
@@ -95,5 +97,7 @@ # col | ||
| df_dict = self._data_check(cat) | ||
| df = self.rerank_score(data_text=df_dict["text"], data_image=df_dict["image"]) | ||
| df = self.rerank_score( | ||
| data_text=df_dict["text"], data_image=df_dict["image"] | ||
| ) | ||
| if not df.is_empty(): | ||
| df.write_parquet(file_name) | ||
| print(f"[RERANK]: Done {total_cat} categories") |
Sorry, the diff of this file is not supported yet
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
1282
3.47%3598288
-0.32%32
-5.88%