vxutils
Advanced tools
| { | ||
| "TEST_KEY": "test" | ||
| } |
| from typing import Type, Dict, Any | ||
| import polars as pl | ||
| from enum import Enum | ||
| from datetime import datetime, date, time, timedelta | ||
| from vxutils.datamodel import VXDataModel | ||
| __columns_mapping__: Dict[Any, pl.DataType] = { | ||
| int: pl.Int64, | ||
| float: pl.Float64, | ||
| bool: pl.Boolean, | ||
| bytes: pl.Binary, | ||
| str: pl.Utf8, | ||
| Enum: pl.Utf8, | ||
| datetime: pl.Datetime, | ||
| date: pl.Date, | ||
| time: pl.Time, | ||
| timedelta: pl.Float64, | ||
| } | ||
| class PolarsORM: | ||
| def __init__(self, model_cls: Type[VXDataModel], keys: list[str] = None): | ||
| self._model_cls = model_cls | ||
| self._keys = keys or [] | ||
| self._data: pl.DataFrame = pl.DataFrame( | ||
| data=[{"name": None for name in self._model_cls.model_fields.keys()}], | ||
| schema={ | ||
| name: __columns_mapping__.get(field.annotation, pl.Utf8) | ||
| for name, field in self._model_cls.model_fields.items() | ||
| }, | ||
| ).clear() | ||
| @property | ||
| def data(self) -> pl.DataFrame: | ||
| return self._data | ||
| def save(self, *data: VXDataModel) -> None: | ||
| if not all(isinstance(item, self._model_cls) for item in data): | ||
| raise ValueError(f"Invalid data type: {type(data)}") | ||
| if self._keys: | ||
| self._data = self._data.filter( | ||
| pl.any( | ||
| pl.all(pl.col(key) != item[key] for key in self._keys) | ||
| for item in data | ||
| ).not_() | ||
| ) | ||
| self._data = pl.concat( | ||
| [ | ||
| self._data, | ||
| pl.DataFrame([item.model_dump() for item in data]).select( | ||
| pl.col(self._data.columns) | ||
| ), | ||
| ] | ||
| ) | ||
| if __name__ == "__main__": | ||
| df = pl.DataFrame( | ||
| data=[ | ||
| {"id": 1, "name": "a"}, | ||
| {"id": 2, "name": "b"}, | ||
| ], | ||
| schema={ | ||
| "id": pl.Int64, | ||
| "name": pl.Utf8, | ||
| }, | ||
| ) | ||
| print(df) | ||
| class A(VXDataModel): | ||
| id: int | ||
| name: str | ||
| porm = PolarsORM(A, keys=["id"]) | ||
| porm.save(A(id=1, name="c")) | ||
| print(porm.data) |
| from pathlib import Path | ||
| import json | ||
| import getpass | ||
| from threading import Lock | ||
| from typing import Union | ||
| __all__ = ["APIKeyManager"] | ||
| class APIKeyManager: | ||
| def __init__(self, api_key_file: Union[str, Path] = "./api_keys.json") -> None: | ||
| self._api_key_file = Path(api_key_file) | ||
| self._lock = Lock() | ||
| def get_key(self, name: str) -> str: | ||
| """获取API Key | ||
| Args: | ||
| name (str): API Key的名称 | ||
| Returns: | ||
| str: API Key的值 | ||
| """ | ||
| api_keys: dict[str, str] = {} | ||
| if self._api_key_file.exists(): | ||
| with self._lock: | ||
| with open(self._api_key_file, "r") as f: | ||
| api_keys = json.load(f) | ||
| if name in api_keys: | ||
| return api_keys[name] | ||
| else: | ||
| value = getpass.getpass(f"请输入{name}的API Key:") | ||
| api_keys[name] = value | ||
| with self._lock: | ||
| with open(self._api_key_file, "w") as f: | ||
| json.dump(api_keys, f, indent=4) | ||
| return value |
Sorry, the diff of this file is not supported yet
| import sys | ||
| from pathlib import Path | ||
| SRC_PATH = Path(__file__).resolve().parents[1] / "src" | ||
| sys.path.insert(0, str(SRC_PATH)) | ||
| if "vxutils" in sys.modules: | ||
| del sys.modules["vxutils"] |
+2
-1
| Metadata-Version: 2.4 | ||
| Name: vxutils | ||
| Version: 20251211 | ||
| Version: 20251226 | ||
| Summary: A toolbox for vxquant | ||
| Author-email: libao <libao@vxquant.com> | ||
| Requires-Python: >=3.10 | ||
| Requires-Dist: polars>=1.36.1 | ||
| Requires-Dist: pydantic>=2.10.6 | ||
@@ -8,0 +9,0 @@ Requires-Dist: python-dateutil>=2.9.0.post0 |
+2
-1
| [project] | ||
| name = "vxutils" | ||
| version = "20251211" | ||
| version = "20251226" | ||
| description = "A toolbox for vxquant" | ||
@@ -11,2 +11,3 @@ readme = "README.md" | ||
| dependencies = [ | ||
| "polars>=1.36.1", | ||
| "pydantic>=2.10.6", | ||
@@ -13,0 +14,0 @@ "python-dateutil>=2.9.0.post0", |
@@ -1,2 +0,2 @@ | ||
| from .executor import VXThreadPoolExecutor | ||
| from .executor import VXThreadPoolExecutor, DynamicThreadPoolExecutor | ||
| from .logger import loggerConfig, VXColoredFormatter | ||
@@ -30,6 +30,11 @@ from .convertors import ( | ||
| DataAdapterError, | ||
| VXDBSession, | ||
| VXDataBase, | ||
| ) | ||
| from .tools import APIKeyManager | ||
| __all__ = [ | ||
| "VXThreadPoolExecutor", | ||
| "DynamicThreadPoolExecutor", | ||
| "loggerConfig", | ||
@@ -56,5 +61,8 @@ "VXColoredFormatter", | ||
| "VXColAdapter", | ||
| "VXDataBase", | ||
| "VXDBSession", | ||
| "TransCol", | ||
| "OriginCol", | ||
| "DataAdapterError", | ||
| "APIKeyManager", | ||
| ] |
| from .core import VXDataModel | ||
| from .adapter import VXDataAdapter, VXColAdapter, TransCol, OriginCol, DataAdapterError | ||
| from .dborm import VXDataBase, VXDBSession | ||
@@ -12,2 +13,4 @@ | ||
| "DataAdapterError", | ||
| "VXDataBase", | ||
| "VXDBSession", | ||
| ] |
@@ -46,4 +46,4 @@ """基础模型""" | ||
| def __setattr__(self, name: str, value: Any) -> None: | ||
| field_info = self.model_fields.get(name) | ||
| if field_info and field_info.annotation != type(value) and field_info.metadata: | ||
| field_info = self.__class__.model_fields.get(name) | ||
| if field_info and field_info.annotation != type(value) and field_info.metadata: # noqa: E721 | ||
| value = TypeAdapter(field_info.annotation).validate_python(value) | ||
@@ -50,0 +50,0 @@ |
@@ -20,3 +20,3 @@ """数据库ORM抽象""" | ||
| from contextlib import contextmanager | ||
| from multiprocessing import Lock | ||
| from threading import Lock | ||
| from sqlalchemy import ( # type: ignore[import-untyped] | ||
@@ -59,3 +59,3 @@ create_engine, | ||
| class _VXTable(Table): # type: ignore[misc] | ||
| class _VXTable(Table): | ||
| def __init__( | ||
@@ -152,3 +152,3 @@ self, | ||
| primary_key=(name in primary_keys), | ||
| nullable=(name in primary_keys), | ||
| nullable=(name not in primary_keys), | ||
| ) | ||
@@ -164,3 +164,3 @@ for name, field_info in vxdatacls.model_fields.items() | ||
| primary_key=(name in primary_keys), | ||
| nullable=(name in primary_keys), | ||
| nullable=(name not in primary_keys), | ||
| ) | ||
@@ -177,3 +177,3 @@ for name, field_info in vxdatacls.model_computed_fields.items() | ||
| with self._dbengine.begin() as conn: | ||
| with self._dbengine.begin(): | ||
| tbl.create(bind=self._dbengine, checkfirst=True) | ||
@@ -209,3 +209,3 @@ logging.debug("Create Table: [%s] ==> %s", table_name, vxdatacls) | ||
| with self._dbengine.begin() as conn: | ||
| sql = text(f"delete from '{table_name}';") | ||
| sql = text(f"delete from {table_name};") | ||
| conn.execute(sql) | ||
@@ -261,19 +261,16 @@ logging.warning("Table %s truncated", table_name) | ||
| tbl = self._metadata.tables[table_name] | ||
| insert_stmt = ( | ||
| sqlite_insert(tbl) | ||
| .values( | ||
| [ | ||
| {k: v for k, v in vxdataobj.model_dump().items()} | ||
| for vxdataobj in vxdataobjs | ||
| ] | ||
| ) | ||
| .execution_options() | ||
| ) | ||
| values = [ | ||
| {k: db_normalize(v) for k, v in vxdataobj.model_dump().items()} | ||
| for vxdataobj in vxdataobjs | ||
| ] | ||
| insert_stmt = sqlite_insert(tbl).values(values) | ||
| if tbl.primary_key: | ||
| pk_cols = list(tbl.primary_key.columns) | ||
| pk_names = {c.name for c in pk_cols} | ||
| insert_stmt = insert_stmt.on_conflict_do_update( | ||
| index_elements=tbl.primary_key, | ||
| index_elements=pk_cols, | ||
| set_={ | ||
| k: v | ||
| for k, v in vxdataobjs[0].model_dump().items() | ||
| if k not in tbl.primary_key | ||
| k: insert_stmt.excluded[k] | ||
| for k in values[0].keys() | ||
| if k not in pk_names | ||
| }, | ||
@@ -293,8 +290,9 @@ ) | ||
| tbl = self._metadata.tables[table_name] | ||
| delete_stmt = tbl.delete().where( | ||
| tbl.c[tbl.primary_key.columns.keys()[0]] | ||
| == vxdataobjs[0].model_dump()[tbl.primary_key.columns.keys()[0]] | ||
| ) | ||
| self._conn.execute(delete_stmt) | ||
| logging.debug("Table %s deleted, %s", table_name, delete_stmt) | ||
| pk_name = tbl.primary_key.columns.keys()[0] | ||
| for obj in vxdataobjs: | ||
| delete_stmt = tbl.delete().where( | ||
| tbl.c[pk_name] == obj.model_dump()[pk_name] | ||
| ) | ||
| self._conn.execute(delete_stmt) | ||
| logging.debug("Table %s deleted, %s", table_name, delete_stmt) | ||
| return self | ||
@@ -350,3 +348,3 @@ | ||
| for row in result: | ||
| row_data = dict(zip(row._fields, row)) | ||
| row_data = dict(row._mapping) | ||
| yield ( | ||
@@ -386,3 +384,3 @@ self._datamodel_factory[table_name](**row_data) | ||
| row_data = dict(zip(row._fields, row)) | ||
| row_data = dict(row._mapping) | ||
| return ( | ||
@@ -538,3 +536,3 @@ self._datamodel_factory[table_name](**row_data) | ||
| def __enter__(self) -> Any: | ||
| pass | ||
| return self | ||
@@ -541,0 +539,0 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: |
@@ -8,3 +8,3 @@ import os | ||
| __all__ = ["VXThreadPoolExecutor"] | ||
| __all__ = ["VXThreadPoolExecutor", "DynamicThreadPoolExecutor"] | ||
@@ -199,1 +199,4 @@ | ||
| return v | ||
| # 兼容别名 | ||
| DynamicThreadPoolExecutor = VXThreadPoolExecutor |
Sorry, the diff of this file is too big to display
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
187288
12%28
21.74%2046
6.01%