folktexts
Metadata-Version: 2.1
Name: folktexts
-Version: 0.0.20
+Version: 0.0.21
Summary: Use LLMs to get classification risk scores on tabular tasks.
@@ -111,4 +111,4 @@ Author: Andre Cruz, Ricardo Dominguez-Olmedo, Celestine Mendler-Dunner, Moritz Hardt
- [Example usage](#example-usage)
- [Benchmark features and options](#benchmark-features-and-options)
- [Evaluating feature importance](#evaluating-feature-importance)
- [Benchmark options](#benchmark-options)
- [FAQ](#faq)
@@ -168,2 +168,3 @@ - [Citation](#citation)
```py
# Load transformers model
from folktexts.llm_utils import load_model_tokenizer
@@ -176,4 +177,4 @@ model, tokenizer = load_model_tokenizer("gpt2")  # using tiny model as an example
# Create an object that classifies data using an LLM
-from folktexts import LLMClassifier
-clf = LLMClassifier(
+from folktexts import TransformersLLMClassifier
+clf = TransformersLLMClassifier(
    model=model,
@@ -183,2 +184,3 @@ tokenizer=tokenizer,
)
+# NOTE: You can also use a web-hosted model like GPT4 using the `WebAPILLMClassifier` class
@@ -188,6 +190,2 @@ # Use a dataset or feed in your own data
-# And simply run the benchmark to get a variety of metrics and plots
-from folktexts.benchmark import Benchmark
-benchmark_results = Benchmark(clf, dataset).run(results_root_dir="results")
# You can compute risk score predictions using an sklearn-style interface
@@ -202,28 +200,36 @@ X_test, y_test = dataset.get_test()
test_preds = clf.predict(X_test)

+# If you only care about the overall metrics and not individual predictions,
+# you can simply run the following code:
+from folktexts.benchmark import Benchmark, BenchmarkConfig
+bench = Benchmark.make_benchmark(
+    task=acs_task_name, dataset=dataset,
+    model=model, tokenizer=tokenizer,
+    config=BenchmarkConfig(numeric_risk_prompting=True),  # Add other configs here
+)
+bench_results = bench.run(results_root_dir="results")
```
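For intuition, the sklearn-style contract the snippet above relies on (`predict` / `predict_proba` over rows) can be sketched with a toy stand-in for the LLM scorer. All names below are hypothetical illustrations, not the folktexts API:

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class RiskScoreClassifier:
    """Toy stand-in for an LLM-backed classifier (hypothetical, not folktexts API)."""
    score_fn: Callable[[dict], float]  # maps a feature row to a risk score in [0, 1]
    threshold: float = 0.5

    def predict_proba(self, rows: Sequence[dict]) -> list[float]:
        # One "prompt" (here: a plain function call) per row.
        return [self.score_fn(row) for row in rows]

    def predict(self, rows: Sequence[dict]) -> list[int]:
        # Binarize risk scores with a fixed threshold, as the benchmark does by default.
        return [int(p >= self.threshold) for p in self.predict_proba(rows)]


clf = RiskScoreClassifier(score_fn=lambda row: min(row["age"] / 100, 1.0))
rows = [{"age": 30}, {"age": 80}]
print(clf.predict_proba(rows))  # [0.3, 0.8]
print(clf.predict(rows))        # [0, 1]
```

The real classifier swaps the toy `score_fn` for prompt encoding plus LLM inference, but the outer interface is the same, which is what lets the benchmark treat it like any sklearn estimator.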
<!-- TODO: add code to show-case example functionalities, including the
LLMClassifier (maybe the above code is fine for this), the benchmark, and
creating a custom ACS prediction task -->
## Evaluating feature importance

By evaluating LLMs on tabular classification tasks, we can use standard feature importance methods to assess which features the model uses to compute risk scores.

You can do so yourself by calling `folktexts.cli.eval_feature_importance` (add `--help` for a full list of options).

Here's an example for the Llama3-70B-Instruct model on the ACSIncome task (*warning: takes 24h on an Nvidia H100*):
```
python -m folktexts.cli.eval_feature_importance --model 'meta-llama/Meta-Llama-3-70B-Instruct' --task ACSIncome --subsampling 0.1
```
<div style="text-align: center;">
<img src="docs/_static/feat-imp_meta-llama--Meta-Llama-3-70B-Instruct.png" alt="feature importance on llama3 70b it" width="50%">
</div>

This script uses sklearn's [`permutation_importance`](https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html#sklearn.inspection.permutation_importance) to assess which features contribute the most to the ROC AUC metric (other metrics can be assessed using the `--scorer [scorer]` parameter).

## Benchmark options

Here's a summary list of the most important benchmark options/flags used in conjunction with the `run_acs_benchmark` command line script, or with the `Benchmark` class.

| Option | Description | Examples |
|:---|:---|:---:|
| `--model` | Name of the model on huggingface transformers, or local path to folder with pretrained model and tokenizer. Can also use web-hosted models with `"[provider]/[model-name]"`. | `meta-llama/Meta-Llama-3-8B`, `openai/gpt-4o-mini` |
| `--task` | Name of the ACS task to run benchmark on. | `ACSIncome`, `ACSEmployment` |
| `--results-dir` | Path to directory under which benchmark results will be saved. | `results` |
| `--data-dir` | Root folder to find datasets in (or download ACS data to). | `~/data` |
| `--numeric-risk-prompting` | Whether to use verbalized numeric risk prompting, i.e., directly query model for a probability estimate. **By default** will use standard multiple-choice Q&A, and extract risk scores from internal token probabilities. | Boolean flag (`True` if present, `False` otherwise) |
| `--use-web-api-model` | Whether the given `--model` name corresponds to a web-hosted model or not. **By default** this is False (assumes a huggingface transformers model). If this flag is provided, `--model` must contain a [litellm](https://docs.litellm.ai) model identifier ([examples here](https://docs.litellm.ai/docs/providers/openai#openai-chat-completion-models)). | Boolean flag (`True` if present, `False` otherwise) |
| `--subsampling` | Which fraction of the dataset to use for the benchmark. **By default** will use the whole test set. | `0.01` |
| `--fit-threshold` | Whether to use the given number of samples to fit the binarization threshold. **By default** will use a fixed $t=0.5$ threshold instead of fitting on data. | `100` |
| `--batch-size` | The number of samples to process in each inference batch. Choose according to your available VRAM. | `10`, `32` |

Full list of options:
| ``` | ||
usage: run_acs_benchmark [-h] --model MODEL --results-dir RESULTS_DIR --data-dir DATA_DIR [--task TASK] [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD] [--subsampling SUBSAMPLING] [--seed SEED] [--use-web-api-model] [--dont-correct-order-bias] [--numeric-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset USE_FEATURE_SUBSET]
```
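The permutation-importance mechanism behind this evaluation — shuffle one feature column, measure the drop in the metric — can be sketched in pure Python. A toy model and plain accuracy stand in for the LLM and ROC AUC; all names are illustrative:

```python
import random
from statistics import mean


def accuracy(model, X, y):
    """Fraction of rows the model labels correctly."""
    return mean(1.0 if model(row) == label else 0.0 for row, label in zip(X, y))


def permutation_importance(model, X, y, feature, n_repeats=10, seed=0):
    """Mean drop in accuracy after shuffling one feature column across rows."""
    rng = random.Random(seed)
    base = accuracy(model, X, y)
    drops = []
    for _ in range(n_repeats):
        col = [row[feature] for row in X]
        rng.shuffle(col)
        # Rebuild rows with the shuffled column; other features untouched.
        X_perm = [{**row, feature: v} for row, v in zip(X, col)]
        drops.append(base - accuracy(model, X_perm, y))
    return mean(drops)


# Toy model that only looks at "income"; "age" is ignored entirely.
model = lambda row: int(row["income"] > 50)
X = [{"income": i, "age": a} for i, a in [(10, 1), (60, 2), (80, 3), (20, 4)]]
y = [model(r) for r in X]

print(permutation_importance(model, X, y, "income"))  # non-negative: income matters
print(permutation_importance(model, X, y, "age"))     # 0.0: age is ignored
```

sklearn's `permutation_importance` does the same thing with repeated shuffles and any scorer, which is why the `--scorer` flag maps directly onto it.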
## FAQ
@@ -287,2 +311,3 @@
**A:** **Yes!** We provide compatibility with local LLMs via [🤗 transformers](https://github.com/huggingface/transformers) and compatibility with web-hosted LLMs via [litellm](https://github.com/BerriAI/litellm). For example, you can use `--model='gpt-4o' --use-web-api-model` to use GPT-4o when calling the `run_acs_benchmark` script. [Here's a complete list](https://docs.litellm.ai/docs/providers/openai#openai-chat-completion-models) of compatible OpenAI models. Note that some models are not compatible as they don't enable access to log-probabilities.
+Using models through a web API requires installing extra optional dependencies with `pip install 'folktexts[apis]'`.
from ._version import __version__, __version_info__
from .acs import ACSDataset, ACSTaskMetadata
-from .benchmark import BenchmarkConfig, Benchmark
-from .classifier import LLMClassifier, TransformersLLMClassifier
+from .benchmark import Benchmark, BenchmarkConfig
+from .classifier import LLMClassifier, TransformersLLMClassifier, WebAPILLMClassifier
from .task import TaskMetadata
@@ -16,2 +16,3 @@ """Module to hold ACS column mappings from values to natural text.
    acs_travel_time_threshold,
+    acs_poverty_ratio_threshold,
)
@@ -204,4 +205,2 @@
    ),
-    answer_probability=True,
-    num_forward_passes=2,
)
@@ -226,3 +225,3 @@
-acs_pubcov_og_target_col = ColumnToText(
+acs_pubcov = ColumnToText(
    "PUBCOV",
@@ -247,2 +246,9 @@ short_description="public health coverage status",
+acs_pubcov_numeric_qa = DirectNumericQA(
+    column=acs_public_coverage_threshold.apply_to_column_name("PUBCOV"),
+    text=(
+        "What is the probability that this person is covered by public health insurance?"
+    ),  # NOTE: value=1 for yes, 0 for no
+)
acs_pubcov_target_col = ColumnToText(
@@ -308,10 +314,17 @@ name=acs_public_coverage_threshold.apply_to_column_name("PUBCOV"),
acs_mobility_qa = MultipleChoiceQA(
-    column=acs_mobility_threshold.apply_to_column_name("MIG"),
+    column=acs_mobility_threshold.apply_to_column_name("MIG"),  # NOTE: Thresholded by MIG!=1
    text="Has this person moved in the last year?",
    choices=(
-        Choice("No, person has lived in the same house for the last year", 1),
-        Choice("Yes, person has moved in the last year", 0),
+        Choice("No, person has lived in the same house for the last year", 0),
+        Choice("Yes, person has moved in the last year", 1),
    ),
)
+acs_mobility_numeric_qa = DirectNumericQA(
+    column=acs_mobility_threshold.apply_to_column_name("MIG"),  # NOTE: Thresholded by MIG!=1
+    text=(
+        "What is the probability that this person has moved in the last year?"
+    ),  # NOTE: Question should relate to probability of MIG!=1
+)
acs_mobility_target_col = ColumnToText(
@@ -416,2 +429,10 @@ name=acs_mobility_threshold.apply_to_column_name("MIG"),
+acs_employment_numeric_qa = DirectNumericQA(
+    column=acs_employment_threshold.apply_to_column_name("ESR"),
+    text=(
+        "What is the probability that this person is an employed civilian?"
+        # "What is the probability that this person is currently employed?"
+    ),  # NOTE: Question should relate to probability of ESR==1
+)
acs_employment_target_col = ColumnToText(
@@ -465,2 +486,9 @@ name=acs_employment_threshold.apply_to_column_name("ESR"),
+acs_commute_time_numeric_qa = DirectNumericQA(
+    column=acs_travel_time_threshold.apply_to_column_name("JWMNP"),
+    text=(
+        "What is the probability that this person's commute time is longer than 20 minutes?"
+    ),  # NOTE: Question should relate to probability of JWMNP>20
+)
acs_travel_time_target_col = ColumnToText(
@@ -500,2 +528,22 @@ name=acs_travel_time_threshold.apply_to_column_name("JWMNP"),
+# POVPIP: Income-to-Poverty Ratio (Thresholded)
+acs_poverty_ratio_qa = MultipleChoiceQA(
+    column=acs_poverty_ratio_threshold.apply_to_column_name("POVPIP"),
+    text=(
+        "Is this person's income-to-poverty ratio below 2.5? "
+        "That is, is this person's annual income below 2.5 times the poverty line income?"
+    ),
+    choices=(
+        Choice("Yes, this person earns below 2.5 times the poverty line income", 1),
+        Choice("No, this person earns above 2.5 times the poverty line income", 0),
+    ),
+)
+acs_poverty_ratio_target_col = ColumnToText(
+    name=acs_poverty_ratio_threshold.apply_to_column_name("POVPIP"),
+    short_description="income-to-poverty ratio is below 2.5",
+    question=acs_poverty_ratio_qa,
+    use_value_map_only=True,
+)
# GCL: Grandparent Living with Grandchildren
@@ -556,2 +604,10 @@ acs_gcl_col = ColumnToText(
+acs_health_ins_2_numeric_qa = DirectNumericQA(
+    column=acs_health_insurance_threshold.apply_to_column_name("HINS2"),
+    text=(
+        "What is the probability that this person has purchased health "
+        "insurance directly through a private company?"
+    ),  # NOTE: Question should relate to probability of HINS2==1
+)
acs_health_ins_2_target_col = ColumnToText(
    name=acs_health_insurance_threshold.apply_to_column_name("HINS2"),
@@ -9,4 +9,10 @@ """A collection of instantiated ACS column objects and ACS tasks."""
from . import acs_columns
-from .acs_tasks import acs_columns_map
+# Map of ACS column names to ColumnToText objects
+acs_columns_map: dict[str, object] = {
+    col_mapper.name: col_mapper
+    for col_mapper in acs_columns.__dict__.values()
+    if isinstance(col_mapper, ColumnToText)
+}
+# Map of numeric ACS questions
acs_numeric_qa_map: dict[str, object] = {
@@ -12,9 +12,11 @@ """A collection of ACS prediction tasks based on the folktables package.
from ..col_to_text import ColumnToText as _ColumnToText
+from ..qa_interface import DirectNumericQA, MultipleChoiceQA
from ..task import TaskMetadata
from ..threshold import Threshold
from . import acs_columns
+from . import acs_questions
from .acs_thresholds import (
    acs_employment_threshold,
    acs_health_insurance_threshold,
-    acs_income_poverty_ratio_threshold,
+    acs_poverty_ratio_threshold,
    acs_income_threshold,
@@ -48,5 +50,8 @@ acs_mobility_threshold,
        target: str,
-        sensitive_attribute: str = None,
        target_threshold: Threshold = None,
+        sensitive_attribute: str = None,
-        **kwargs,
+        population_description: str = None,
+        folktables_obj: BasicProblem = None,
+        multiple_choice_qa: MultipleChoiceQA = None,
+        direct_numeric_qa: DirectNumericQA = None,
    ) -> ACSTaskMetadata:
@@ -57,2 +62,13 @@ # Validate columns mappings exist
+        # Resolve target column name
+        target_col_name = (
+            target_threshold.apply_to_column_name(target)
+            if target_threshold is not None else target)
+
+        # Get default Q&A interfaces for this task's target column
+        if multiple_choice_qa is None:
+            multiple_choice_qa = acs_questions.acs_multiple_choice_qa_map.get(target_col_name)
+        if direct_numeric_qa is None:
+            direct_numeric_qa = acs_questions.acs_numeric_qa_map.get(target_col_name)
        return cls(
@@ -64,6 +80,8 @@ name=name,
            cols_to_text=acs_columns_map,
-            sensitive_attribute=sensitive_attribute,
            target_threshold=target_threshold,
+            sensitive_attribute=sensitive_attribute,
-            folktables_obj=None,
-            **kwargs,
+            population_description=population_description,
+            multiple_choice_qa=multiple_choice_qa,
+            direct_numeric_qa=direct_numeric_qa,
+            folktables_obj=folktables_obj,
        )
@@ -77,2 +95,3 @@
        target_threshold: Threshold = None,
+        population_description: str = None,
    ) -> ACSTaskMetadata:
@@ -86,3 +105,3 @@
-        acs_task = ACSTaskMetadata(
+        acs_task = cls.make_task(
            name=name,
@@ -92,5 +111,5 @@ description=description,
            target=folktables_task.target,
-            cols_to_text=acs_columns_map,
            sensitive_attribute=folktables_task.group,
            target_threshold=target_threshold,
+            population_description=population_description,
            folktables_obj=folktables_task,
@@ -142,3 +161,3 @@ )
    description="predict whether an individual's income-to-poverty ratio is below 2.5",
-    target_threshold=acs_income_poverty_ratio_threshold,
+    target_threshold=acs_poverty_ratio_threshold,
)
@@ -12,3 +12,3 @@ """Threshold instances for ACS / folktables tasks.
# ACSMobility task
-acs_mobility_threshold = Threshold(1, "==")
+acs_mobility_threshold = Threshold(1, "!=")
@@ -22,5 +22,5 @@ # ACSEmployment task
# ACSIncomePovertyRatio task
-acs_income_poverty_ratio_threshold = Threshold(250, "<")
+acs_poverty_ratio_threshold = Threshold(250, "<")

# ACSHealthInsurance task
acs_health_insurance_threshold = Threshold(1, "==")
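The mobility threshold flip above (`Threshold(1, "==")` to `Threshold(1, "!=")`) inverts the label polarity, which is why the `Choice` values in `acs_mobility_qa` were swapped in the same change. A minimal sketch of the threshold semantics, with a hypothetical implementation (the real `Threshold` class may differ):

```python
from dataclasses import dataclass
import operator


@dataclass(frozen=True)
class SimpleThreshold:
    """Hypothetical sketch of a value/operator threshold (not the real folktexts class)."""
    value: float
    op: str  # one of "==", "!=", "<", ">", "<=", ">="

    _OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
            ">": operator.gt, "<=": operator.le, ">=": operator.ge}

    def apply(self, column_values):
        # Binarize a raw column into 0/1 target labels.
        fn = self._OPS[self.op]
        return [int(fn(v, self.value)) for v in column_values]


# ACSMobility after the fix: positive class is MIG != 1 ("has moved").
mig = [1, 2, 3, 1]
print(SimpleThreshold(1, "!=").apply(mig))        # [0, 1, 1, 0]

# ACSIncomePovertyRatio: positive class is POVPIP < 250.
print(SimpleThreshold(250, "<").apply([100, 300]))  # [1, 0]
```

Under the old `MIG == 1` threshold, the positive class was "has not moved", so the answer choices had to carry the opposite values; flipping the operator and the choice values together keeps labels and question text consistent.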
@@ -17,3 +17,2 @@ """A benchmark class for measuring and evaluating LLM calibration.
from .acs.acs_dataset import ACSDataset
-from .acs.acs_questions import acs_multiple_choice_qa_map, acs_numeric_qa_map
from .acs.acs_tasks import ACSTaskMetadata
@@ -79,2 +78,18 @@ from .classifier import LLMClassifier, TransformersLLMClassifier, WebAPILLMClassifier
+    def update(self, **changes) -> BenchmarkConfig:
+        """Update the configuration with new values."""
+        possible_keys = dataclasses.asdict(self).keys()
+        valid_changes = {k: v for k, v in changes.items() if k in possible_keys}
+
+        # Log config changes
+        if valid_changes:
+            logging.info(f"Updating benchmark configuration with: {valid_changes}")
+
+        # Log unused kwargs
+        if len(valid_changes) < len(changes):
+            unused_kwargs = {k: v for k, v in changes.items() if k not in possible_keys}
+            logging.warning(f"Unused config arguments: {unused_kwargs}")
+
+        return dataclasses.replace(self, **valid_changes)

    @classmethod
@@ -131,2 +146,17 @@ def load_from_disk(cls, path: str | Path):
    ):
+        """A benchmark object to measure and evaluate risk scores produced by an LLM.
+
+        Parameters
+        ----------
+        llm_clf : LLMClassifier
+            A language model classifier object (can be local or web-hosted).
+        dataset : Dataset
+            The dataset object to use for the benchmark.
+        config : BenchmarkConfig, optional
+            The configuration object used to create the benchmark parameters.
+            **NOTE**: This is used to uniquely identify the benchmark object for
+            reproducibility; it **will not be used to change the benchmark
+            behavior**. To configure the benchmark, pass a configuration object
+            to the `Benchmark.make_benchmark` method.
+        """
        self.llm_clf = llm_clf
@@ -375,2 +405,3 @@ self.dataset = dataset
        data_dir: str | Path = None,
+        max_api_rpm: int = None,
        config: BenchmarkConfig = BenchmarkConfig.default_config(),
@@ -393,2 +424,4 @@ **kwargs,
            Path to the directory to load data from and save data in.
+        max_api_rpm : int, optional
+            The maximum number of API requests per minute for webAPI models.
        config : BenchmarkConfig, optional
@@ -398,4 +431,5 @@ Extra benchmark configurations, by default will use
        **kwargs
-            Additional arguments passed to `ACSDataset`. By default will use a
-            set of standardized dataset configurations for reproducibility.
+            Additional arguments passed to `ACSDataset` and `BenchmarkConfig`.
+            By default will use a set of standardized configurations for
+            reproducibility.
@@ -407,3 +441,2 @@ Returns
        """
-        # Handle non-standard ACS arguments
@@ -419,8 +452,10 @@ acs_dataset_configs = cls.ACS_DATASET_CONFIGS.copy()
-        # Log unused kwargs
-        if kwargs:
-            logging.warning(f"Unused key-word arguments: {kwargs}")
+        # Update config with any additional kwargs
+        config = config.update(**kwargs)

        # Fetch ACS task and dataset
-        acs_task = ACSTaskMetadata.get_task(task_name)
+        acs_task = ACSTaskMetadata.get_task(
+            name=task_name,
+            use_numeric_qa=config.numeric_risk_prompting)
        acs_dataset = ACSDataset.make_from_task(
@@ -436,2 +471,3 @@ task=acs_task,
            tokenizer=tokenizer,
+            max_api_rpm=max_api_rpm,
            config=config,
@@ -448,3 +484,5 @@ )
        tokenizer: AutoTokenizer = None,  # WebAPI models have no local tokenizer
+        max_api_rpm: int = None,
        config: BenchmarkConfig = BenchmarkConfig.default_config(),
+        **kwargs,
    ) -> Benchmark:
@@ -465,5 +503,10 @@ """Create a calibration benchmark from a given configuration.
            model). Not required for webAPI models.
+        max_api_rpm : int, optional
+            The maximum number of API requests per minute for webAPI models.
        config : BenchmarkConfig, optional
            Extra benchmark configurations, by default will use
            `BenchmarkConfig.default_config()`.
+        **kwargs
+            Additional arguments for easier configuration of the benchmark.
+            Will simply use these values to update the `config` object.
@@ -475,14 +518,19 @@ Returns
        """
+        # Update config with any additional kwargs
+        config = config.update(**kwargs)

        # Handle TaskMetadata object
-        task_obj = TaskMetadata.get_task(task) if isinstance(task, str) else task
+        task = TaskMetadata.get_task(task) if isinstance(task, str) else task
+        if config.numeric_risk_prompting:
+            task.use_numeric_qa = True

        if config.feature_subset is not None and len(config.feature_subset) > 0:
-            task_obj = task_obj.create_task_with_feature_subset(config.feature_subset)
-            dataset.task = task_obj
+            task = task.create_task_with_feature_subset(config.feature_subset)
+            dataset.task = task

        # Check dataset is compatible with task
-        if dataset.task is not task_obj and dataset.task.name != task_obj.name:
+        if dataset.task is not task and dataset.task.name != task.name:
            raise ValueError(
                f"Dataset task '{dataset.task.name}' does not match the "
-                f"provided task '{task_obj.name}'.")
+                f"provided task '{task.name}'.")
@@ -493,8 +541,7 @@ if config.population_filter is not None:
        # Get prompting function
-        encode_row_function = partial(encode_row_prompt, task=task_obj)
        if config.few_shot:
            print(f"Using few-shot prompting (n={config.few_shot})!")
            encode_row_function = partial(
                encode_row_prompt_few_shot,
-                task=task_obj,
+                task=task,
                n_shots=config.few_shot,
@@ -505,13 +552,7 @@ dataset=dataset,
-        # Load the QA interface to be used for risk-score prompting
-        if config.numeric_risk_prompting:
-            logging.warning(f"Untested feature: numeric_risk_prompting={config.numeric_risk_prompting}")
-            question = acs_numeric_qa_map[task_obj.get_target()]
-        else:
-            question = acs_multiple_choice_qa_map[task_obj.get_target()]
+        else:
+            print("Using zero-shot prompting.")
+            encode_row_function = partial(encode_row_prompt, task=task)
-        # Set the task's target question
-        task_obj.question = question
-        # Construct the LLMClassifier object
+        # Parse LLMClassifier parameters
        llm_inference_kwargs = {"correct_order_bias": config.correct_order_bias}
@@ -522,13 +563,14 @@ if config.batch_size is not None:
            llm_inference_kwargs["context_size"] = config.context_size
+        if max_api_rpm is not None and isinstance(model, str):
+            llm_inference_kwargs["max_api_rpm"] = max_api_rpm

+        # Create LLMClassifier object
        if isinstance(model, str):
+            logging.info(f"Using webAPI model: {model}")
            llm_clf = WebAPILLMClassifier(
                model_name=model,
-                task=task_obj,
+                task=task,
                encode_row=encode_row_function,
                **llm_inference_kwargs,
            )
-            logging.info(f"Using webAPI model: {model}")
@@ -539,7 +581,6 @@ else:
                tokenizer=tokenizer,
-                task=task_obj,
+                task=task,
                encode_row=encode_row_function,
                **llm_inference_kwargs,
            )
            logging.info(f"Using local transformers model: {llm_clf.model_name}")
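The new `BenchmarkConfig.update` pattern introduced above — filter out unknown kwargs, warn about them, and return a fresh instance via `dataclasses.replace` — can be illustrated with a toy config (field names here are hypothetical):

```python
import dataclasses
import logging


@dataclasses.dataclass(frozen=True)
class Config:
    """Toy config mirroring the update pattern (hypothetical fields)."""
    batch_size: int = 30
    context_size: int = 600

    def update(self, **changes) -> "Config":
        valid_keys = {f.name for f in dataclasses.fields(self)}
        valid = {k: v for k, v in changes.items() if k in valid_keys}
        unused = {k: v for k, v in changes.items() if k not in valid_keys}
        if unused:
            # Unknown keys are reported, not silently dropped and not fatal.
            logging.warning("Unused config arguments: %s", unused)
        # dataclasses.replace returns a new instance; the original is untouched.
        return dataclasses.replace(self, **valid)


cfg = Config().update(batch_size=10, typo_key=1)
print(cfg.batch_size, cfg.context_size)  # 10 600
```

Because `replace` builds a new frozen instance, configs stay hashable and safe to reuse across benchmark runs while still accepting loose keyword overrides.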
@@ -10,3 +10,3 @@ #!/usr/bin/env python
from folktexts._io import save_json, save_pickle
-from folktexts.classifier import LLMClassifier, TransformersLLMClassifier
+from folktexts.classifier import LLMClassifier
from folktexts.dataset import Dataset
@@ -25,3 +25,3 @@ from folktexts.llm_utils import get_model_folder_path, load_model_tokenizer
-DEFAULT_CONTEXT_SIZE = 500
+DEFAULT_CONTEXT_SIZE = 600
DEFAULT_BATCH_SIZE = 30
@@ -41,13 +41,64 @@ DEFAULT_SEED = 42
cli_args = [
-    ("--model", str, "[str] Model name or path to model saved on disk"),
-    ("--task", str, "[str] Name of the ACS task to run the experiment on", False, DEFAULT_TASK_NAME),
-    ("--results-dir", str, "[str] Directory under which this experiment's results will be saved", False, DEFAULT_RESULTS_DIR),
-    ("--data-dir", str, "[str] Root folder to find datasets on", False, DEFAULT_DATA_DIR),
-    ("--models-dir", str, "[str] Root folder to find huggingface models on", False, DEFAULT_MODELS_DIR),
-    ("--scorer", str, "[str] Name of the scorer to use for evaluation", False, "roc_auc"),
-    ("--batch-size", int, "[int] The batch size to use for inference", False, DEFAULT_BATCH_SIZE),
-    ("--context-size", int, "[int] The maximum context size when prompting the LLM", False, DEFAULT_CONTEXT_SIZE),
-    ("--subsampling", float, "[float] Which fraction of the dataset to use (if omitted will use all data)", DEFAULT_SUBSAMPLING),
-    ("--fit-threshold", int, "[int] Whether to fit the prediction threshold, and on how many samples", False),
-    ("--seed", int, "[int] Random seed -- to set for reproducibility", False, DEFAULT_SEED),
+    ("--model",
+     str,
+     "[str] Model name or path to model saved on disk"),
+    ("--task",
+     str,
+     "[str] Name of the ACS task to run the experiment on",
+     False,
+     DEFAULT_TASK_NAME,
+     ),
+    ("--results-dir",
+     str,
+     "[str] Directory under which this experiment's results will be saved",
+     False,
+     DEFAULT_RESULTS_DIR,
+     ),
+    ("--data-dir",
+     str,
+     "[str] Root folder to find datasets on",
+     False,
+     DEFAULT_DATA_DIR,
+     ),
+    ("--models-dir",
+     str,
+     "[str] Root folder to find huggingface models on",
+     False,
+     DEFAULT_MODELS_DIR,
+     ),
+    ("--scorer",
+     str,
+     "[str] Name of the scorer to use for evaluation",
+     False,
+     "roc_auc",
+     ),
+    ("--batch-size",
+     int,
+     "[int] The batch size to use for inference",
+     False,
+     DEFAULT_BATCH_SIZE,
+     ),
+    ("--context-size",
+     int,
+     "[int] The maximum context size when prompting the LLM",
+     False,
+     DEFAULT_CONTEXT_SIZE,
+     ),
+    ("--subsampling",
+     float,
+     "[float] Which fraction of the dataset to use (if omitted will use all data)",
+     False,
+     DEFAULT_SUBSAMPLING,
+     ),
+    ("--fit-threshold",
+     int,
+     "[int] Whether to fit the prediction threshold, and on how many samples",
+     False,
+     ),
+    ("--seed",
+     int,
+     "[int] Random seed -- to set for reproducibility",
+     False,
+     DEFAULT_SEED,
+     ),
]
@@ -27,2 +27,3 @@ """General constants and helper classes to run the main experiments on htcondor.
    executable_path: str
+    env_vars: str = ""
    kwargs: dict = field(default_factory=dict)
@@ -101,2 +102,5 @@
+    # Environment variables
+    "environment": exp.env_vars or "",

    # GPU requirements
    "requirements": (
@@ -5,3 +5,2 @@ #!/usr/bin/env python3
import argparse
-import logging
import math
@@ -95,2 +94,3 @@ from pathlib import Path
    results_dir: str,
+    env_vars_str: str = "",
    **kwargs,
@@ -105,3 +105,4 @@ ) -> Experiment:
    model_path = get_model_folder_path(model_name, root_dir=MODELS_DIR)
-    assert Path(model_path).exists(), f"Model path '{model_path}' does not exist."
+    if not Path(model_path).exists() and "use_web_api_model" not in kwargs:
+        raise FileNotFoundError(f"Model folder not found at '{model_path}'.")
@@ -129,4 +130,5 @@ # Split experiment and job kwargs
        executable_path=executable_path,
+        env_vars=env_vars_str,
        kwargs=dict(
-            model=model_path,
+            model=model_path if "use_web_api_model" not in kwargs else model_name,
            task=task,
@@ -197,2 +199,12 @@ results_dir=results_dir,
+    parser.add_argument(
+        "--environment",
+        type=str,
+        help=(
+            "[string] String defining environment variables to be passed to "
+            "launched jobs, in the form 'VAR1=val1;VAR2=val2;...'."
+        ),
+        required=False,
+    )
    return parser
@@ -234,2 +246,3 @@
        results_dir=args.results_dir,
+        env_vars_str=args.environment,
        **extra_kwargs,
    )
@@ -13,3 +13,3 @@ #!/usr/bin/env python3
DEFAULT_BATCH_SIZE = 30
-DEFAULT_CONTEXT_SIZE = 500
+DEFAULT_CONTEXT_SIZE = 600
DEFAULT_SEED = 42
@@ -98,2 +98,9 @@
+    parser.add_argument(
+        "--max-api-rpm",
+        type=int,
+        help="[int] Maximum number of API requests per minute (if using a web-hosted model)",
+        required=False,
+    )
    parser.add_argument(
        "--logger-level",
@@ -104,2 +111,3 @@ type=str,
        required=False,
+        default="WARNING",
    )
@@ -117,3 +125,3 @@
-    logging.getLogger().setLevel(level=args.logger_level or "INFO")
+    logging.getLogger().setLevel(level=args.logger_level)
    pretty_args_str = json.dumps(vars(args), indent=4, sort_keys=True)
@@ -163,2 +171,3 @@ logging.info(f"Current python executable: '{sys.executable}'")
        subsampling=args.subsampling,
+        max_api_rpm=args.max_api_rpm,
    )
@@ -12,4 +12,3 @@ """Module to map risk-estimates to a variety of evaluation metrics.
import statistics
from functools import partial
-from typing import Optional, Callable
+from typing import Callable, Optional
@@ -40,2 +39,3 @@ import numpy as np
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=(0, 1)).ravel()
+    tn, fp, fn, tp = map(int, (tn, fp, fn, tp))
    total = tn + fp + fn + tp
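The added `map(int, ...)` cast matters because the counts coming out of sklearn's `confusion_matrix` are numpy integers, which the stdlib `json` module refuses to serialize. A minimal demonstration using `np.int64` values directly:

```python
import json

import numpy as np

# Counts as they would come out of confusion_matrix(...).ravel(): numpy integers.
tn, fp, fn, tp = np.int64(50), np.int64(3), np.int64(7), np.int64(40)

try:
    json.dumps({"tn": tn})
except TypeError as err:
    print("numpy ints are not JSON-serializable:", err)

# The cast added in the diff turns them into plain Python ints:
tn, fp, fn, tp = map(int, (tn, fp, fn, tp))
print(json.dumps({"tn": tn, "fp": fp, "fn": fn, "tp": tp}))
```

Casting once at the point where the counts are produced keeps every downstream metrics dict safe to dump with `save_json`.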
@@ -8,7 +8,6 @@ """Common functions to use with transformer LLMs."""
+import numpy as np
import torch
-import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

# Will warn if the sum of digit probabilities is below this threshold
@@ -231,4 +230,6 @@ PROB_WARN_THR = 0.5
        return int(regex.group("size")) * int(regex.group("times") or 1)
    else:
-        logging.error(f"Could not infer model size from name '{model_name}'")
+        logging.warning(
+            f"Could not infer model size from name '{model_name}'; "
+            f"Using default size of {default}B.")
        return default
@@ -241,3 +241,3 @@ """Module to plot evaluation results.
        if len(group_indices) / len(sensitive_attribute) < group_size_threshold:
-            logging.warning(f"Skipping group {group_value_map(s_value)} plot as it's too small.")
+            logging.info(f"Skipping group {group_value_map(s_value)} plot as it's too small.")
            continue
@@ -261,27 +261,2 @@
-    # ###
-    # Plot scores distribution per group
-    # ###
-    # TODO: make a decent score-distribution plot...  # TODO: try score CDFs!
-    # hist_bin_edges = np.histogram_bin_edges(y_pred_scores, bins=10)
-    # for idx, s_value in enumerate(np.unique(sensitive_attribute)):
-    #     group_indices = np.argwhere(sensitive_attribute == s_value).flatten()
-    #     group_y_pred_scores = y_pred_scores[group_indices]
-    #     is_first_group = (idx == 0)
-    #     if is_first_group:
-    #         fig, ax = plt.subplots()
-    #     sns.histplot(
-    #         group_y_pred_scores,
-    #         bins=hist_bin_edges,
-    #         stat="density",
-    #         kde=False,
-    #         color=group_colors[idx],
-    #         label=group_value_map(s_value),
-    #         ax=ax,
-    #     )
-    # plt.legend()
-    # plt.title("Score distribution per sub-group" + model_str)
-    # results["score_distribution_per_subgroup_path"] = save_fig(fig, "score_distribution_per_subgroup", imgs_dir)
    return results
@@ -7,2 +7,6 @@ """Module to map risk-estimation questions to different prompting techniques. | ||
| """ | ||
| from __future__ import annotations | ||
| import logging | ||
| import pandas as pd | ||
@@ -40,4 +44,5 @@ from transformers import AutoTokenizer | ||
| task: TaskMetadata, | ||
| question: QAInterface = None, | ||
| custom_prompt_prefix: str = None, | ||
| add_task_description: bool = True, | ||
| question: QAInterface = None, | ||
| ) -> str: | ||
@@ -49,2 +54,3 @@ """Encode a question regarding a given row.""" | ||
| (ACS_TASK_DESCRIPTION + "\n" if add_task_description else "") | ||
| + (f"\n{custom_prompt_prefix}\n" if custom_prompt_prefix else "") | ||
| + f"""\ | ||
@@ -61,5 +67,6 @@ Information: | ||
| dataset: Dataset, | ||
| n_shots: int = 10, | ||
| n_shots: int, | ||
| question: QAInterface = None, | ||
| reuse_examples: bool = False, | ||
| question: QAInterface = None, | ||
| custom_prompt_prefix: str = None, | ||
| ) -> str: | ||
@@ -98,3 +105,8 @@ """Encode a question regarding a given row using few-shot prompting. | ||
| prompt += ( | ||
| encode_row_prompt(X_examples.iloc[i], task=task, add_task_description=False) | ||
| encode_row_prompt( | ||
| X_examples.iloc[i], | ||
| task=task, | ||
| add_task_description=False, | ||
| custom_prompt_prefix=custom_prompt_prefix, | ||
| ) | ||
| + f" {question.get_answer_key_from_value(y_examples.iloc[i])}" | ||
@@ -109,2 +121,3 @@ + "\n\n" | ||
| add_task_description=False, | ||
| custom_prompt_prefix=custom_prompt_prefix, | ||
| question=question, | ||
@@ -125,2 +138,4 @@ ) | ||
| # - and another for regular models compatible with system prompts | ||
| logging.warning("NOTE :: Untested feature!!") | ||
| return apply_chat_template( | ||
@@ -127,0 +142,0 @@ tokenizer, |
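These hunks forward `custom_prompt_prefix` to every few-shot example as well as the final query row. A toy sketch of that assembly pattern (stand-in encoder and field names, not the real folktexts helpers):

```python
# Toy sketch of few-shot prompt assembly with a custom prefix threaded
# through every example (stand-in encoder, not the real folktexts helper).
def encode_row(row: dict, custom_prompt_prefix: str = None) -> str:
    prefix = f"{custom_prompt_prefix}\n" if custom_prompt_prefix else ""
    info = "\n".join(f"{k}: {v}" for k, v in row.items())
    return f"{prefix}Information:\n{info}\nAnswer:"

def encode_few_shot(examples, query, custom_prompt_prefix=None):
    prompt = ""
    for row, answer_key in examples:  # each example row ends with its answer key
        prompt += encode_row(row, custom_prompt_prefix) + f" {answer_key}\n\n"
    return prompt + encode_row(query, custom_prompt_prefix)  # query left unanswered

prompt = encode_few_shot(
    examples=[({"AGE": 40}, "A."), ({"AGE": 22}, "B.")],
    query={"AGE": 35},
    custom_prompt_prefix="Answer as accurately as possible.",
)
print(prompt)
```

In the real encoder the task description is additionally prepended once, before the first example.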
@@ -9,6 +9,6 @@ """Interface for question-answering with LLMs. | ||
| import dataclasses | ||
| import itertools | ||
| import logging | ||
| import re | ||
| import logging | ||
| import itertools | ||
| import dataclasses | ||
| from abc import ABC | ||
@@ -150,5 +150,5 @@ from dataclasses import dataclass | ||
| if len(last_token_probs) < self.num_forward_passes: | ||
| logging.warning( | ||
| f"Expected {self.num_forward_passes} forward passes, got {len(last_token_probs)}." | ||
| ) | ||
| logging.info( | ||
| f"Expected {self.num_forward_passes} forward passes, got " | ||
| f"{len(last_token_probs)}.") | ||
@@ -170,4 +170,5 @@ answer_text = "" | ||
| # Filter out any non-numeric characters | ||
| numeric_answer_text = re.match(r"[-+]?\d*\.\d+|\d+", answer_text).group() | ||
| assert numeric_answer_text, f"Could not find numeric answer in '{answer_text}'." | ||
| match_ = re.match(r"[-+]?\d*\.\d+|\d+", answer_text) | ||
| assert match_, f"Could not find numeric answer in '{answer_text}'." | ||
| numeric_answer_text = match_.group() | ||
@@ -260,3 +261,3 @@ if self.answer_probability and "." not in numeric_answer_text: | ||
| @property | ||
| def answer_keys(self) -> tuple[str]: | ||
| def answer_keys(self) -> tuple[str, ...]: | ||
| return self._answer_keys_source[:len(self.choices)] | ||
@@ -263,0 +264,0 @@ |
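The fix above replaces a chained `re.match(...).group()` (which raises `AttributeError` when nothing matches) with an explicit check on the match object. The same extraction pattern in isolation:

```python
# Same pattern as the fix above: check the match object before calling
# .group(), instead of chaining .group() onto a possible None.
import re

def extract_numeric(answer_text: str) -> float:
    match_ = re.match(r"[-+]?\d*\.\d+|\d+", answer_text)
    assert match_, f"Could not find numeric answer in '{answer_text}'."
    return float(match_.group())

print(extract_numeric("0.75 (estimated)"))  # 0.75
print(extract_numeric("42"))                # 42.0
```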
@@ -14,4 +14,4 @@ """Definition of a generic TaskMetadata class. | ||
| from .col_to_text import ColumnToText | ||
| from .qa_interface import QAInterface, MultipleChoiceQA, DirectNumericQA | ||
| from .threshold import Threshold | ||
| from .qa_interface import QAInterface | ||
@@ -21,31 +21,37 @@ | ||
| class TaskMetadata: | ||
| """A base class to hold information on a prediction task. | ||
| """A base class to hold information on a prediction task.""" | ||
| Attributes | ||
| ---------- | ||
| name : str | ||
| The name of the task. | ||
| description : str | ||
| A description of the task, including the population to which the task | ||
| pertains. | ||
| features : list[str] | ||
| The names of the features used in the task. | ||
| target : str | ||
| The name of the target column. | ||
| cols_to_text : dict[str, ColumnToText] | ||
| A mapping between column names and their textual descriptions. | ||
| sensitive_attribute : str, optional | ||
| The name of the column used as the sensitive attribute data (if provided). | ||
| target_threshold : Threshold, optional | ||
| The threshold used to binarize the target column (if provided). | ||
| """ | ||
| name: str | ||
| """The name of the task.""" | ||
| description: str | ||
| """A description of the task, including the population to which the task pertains to.""" | ||
| features: list[str] | ||
| """The names of the features used in the task.""" | ||
| target: str | ||
| """The name of the target column.""" | ||
| cols_to_text: dict[str, ColumnToText] | ||
| """A mapping between column names and their textual descriptions.""" | ||
| sensitive_attribute: str = None | ||
| """The name of the column used as the sensitive attribute data (if provided).""" | ||
| target_threshold: Threshold = None | ||
| qa_interface: QAInterface = None # This will override whichever QA interface is set in the ColumnToText object | ||
| """The threshold used to binarize the target column (if provided).""" | ||
| population_description: str = None | ||
| """A description of the population to which the task pertains to.""" | ||
| multiple_choice_qa: MultipleChoiceQA = None | ||
| """The multiple-choice question and answer interface for this task.""" | ||
| direct_numeric_qa: DirectNumericQA = None | ||
| """The direct numeric question and answer interface for this task.""" | ||
| _use_numeric_qa: bool = False | ||
| """Whether to use numeric Q&A instead of multiple-choice Q&A prompts. Default is False.""" | ||
| # Class-level task storage | ||
@@ -62,2 +68,19 @@ _tasks: ClassVar[dict[str, "TaskMetadata"]] = field(default={}, init=False, repr=False) | ||
| # If no question is explicitly provided, use the question from the target column | ||
| if self.multiple_choice_qa is None and self.direct_numeric_qa is None: | ||
| logging.warning( | ||
| f"No question was explicitly provided for task '{self.name}'. " | ||
| f"Inferring from target column's default question ({self.get_target()}).") | ||
| if self.cols_to_text[self.get_target()]._question is not None: | ||
| question = self.cols_to_text[self.get_target()]._question | ||
| self.set_question(question) | ||
| # Make sure Q&A related attributes are consistent | ||
| if ( | ||
| self._use_numeric_qa is True and self.direct_numeric_qa is None | ||
| or self._use_numeric_qa is False and self.multiple_choice_qa is None | ||
| ): | ||
| raise ValueError("Inconsistent Q&A attributes provided.") | ||
| def __hash__(self) -> int: | ||
@@ -76,25 +99,74 @@ hashable_params = dataclasses.asdict(self) | ||
| def set_question(self, question: QAInterface): | ||
| """Sets the Q&A interface for this task.""" | ||
| logging.info(f"Setting question for task '{self.name}' to '{question.text}'.") | ||
| if isinstance(question, MultipleChoiceQA): | ||
| self.multiple_choice_qa = question | ||
| self._use_numeric_qa = False | ||
| elif isinstance(question, DirectNumericQA): | ||
| self.direct_numeric_qa = question | ||
| self._use_numeric_qa = True | ||
| else: | ||
| raise ValueError(f"Invalid question type: {type(question).__name__}") | ||
| @property | ||
| def use_numeric_qa(self) -> bool: | ||
| """Getter for whether to use numeric Q&A instead of multiple-choice Q&A prompts.""" | ||
| return self._use_numeric_qa | ||
| @use_numeric_qa.setter | ||
| def use_numeric_qa(self, use_numeric_qa: bool): | ||
| """Setter for whether to use numeric Q&A instead of multiple-choice Q&A prompts.""" | ||
| logging.info( | ||
| f"Changing Q&A mode for task '{self.name}' to " | ||
| f"{'numeric' if use_numeric_qa else 'multiple-choice'}.") | ||
| self._use_numeric_qa = use_numeric_qa | ||
| @classmethod | ||
| def get_task(cls, name: str) -> "TaskMetadata": | ||
| def get_task(cls, name: str, use_numeric_qa: bool = False) -> TaskMetadata: | ||
| """Fetches a previously created task by its name. | ||
| Parameters | ||
| ---------- | ||
| name : str | ||
| The name of the task to fetch. | ||
| use_numeric_qa : bool, optional | ||
| Whether to set the retrieved task to use verbalized numeric Q&A | ||
| instead of the default multiple-choice Q&A prompts. Default is False. | ||
| Returns | ||
| ------- | ||
| task : TaskMetadata | ||
| The task object with the given name. | ||
| Raises | ||
| ------ | ||
| ValueError | ||
| Raised if the task with the given name has not been created yet. | ||
| """ | ||
| if name not in cls._tasks: | ||
| raise ValueError(f"Task '{name}' has not been created yet.") | ||
| return cls._tasks[name] | ||
| # Retrieve the task object | ||
| task = cls._tasks[name] | ||
| # Set Q&A interface type | ||
| task.use_numeric_qa = use_numeric_qa | ||
| return task | ||
| @property | ||
| def question(self) -> QAInterface: | ||
| if self.qa_interface is not None: | ||
| return self.qa_interface | ||
| """Getter for the Q&A interface for this task.""" | ||
| elif self.cols_to_text[self.get_target()]._question is not None: | ||
| return self.cols_to_text[self.get_target()].question | ||
| # Resolve direct numeric Q&A vs multiple-choice Q&A | ||
| if self._use_numeric_qa: | ||
| q = self.direct_numeric_qa | ||
| else: | ||
| raise ValueError(f"No question provided for the target column '{self.get_target()}'.") | ||
| q = self.multiple_choice_qa | ||
| @question.setter | ||
| def question(self, new_qa: QAInterface): | ||
| if new_qa is not None and new_qa.column == self.get_target(): | ||
| self.qa_interface = new_qa | ||
| else: | ||
| logging.error("Mismatch between task target column and provided question.") | ||
| if q is None: | ||
| raise ValueError(f"Invalid Q&A interface configured for task {self.name}.") | ||
| return q | ||
@@ -101,0 +173,0 @@ def get_row_description(self, row: pd.Series) -> str: |
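A stripped-down sketch of the resolution logic behind the new `question` property (a toy dataclass with string stand-ins for the Q&A interface objects, not the real `TaskMetadata`):

```python
# Toy sketch of the new `question` property's resolution logic: a private
# flag selects between the numeric and multiple-choice interfaces, and a
# missing interface is a hard error.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyTask:
    name: str
    multiple_choice_qa: Optional[str] = None  # stand-ins for QAInterface objects
    direct_numeric_qa: Optional[str] = None
    _use_numeric_qa: bool = False

    @property
    def question(self) -> str:
        q = self.direct_numeric_qa if self._use_numeric_qa else self.multiple_choice_qa
        if q is None:
            raise ValueError(f"Invalid Q&A interface configured for task {self.name}.")
        return q

task = ToyTask("ACSIncome", multiple_choice_qa="multiple-choice Q&A")
print(task.question)          # multiple-choice Q&A
task._use_numeric_qa = True   # now resolves to the (missing) numeric interface
# task.question would raise ValueError here
```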
| config=BenchmarkConfig(numeric_risk_prompting=True), # Add other configs here | ||
| ) | ||
| bench_results = bench.run(results_root_dir="results") | ||
| ``` | ||
| <!-- TODO: add code to show-case example functionalities, including the | ||
| LLMClassifier (maybe the above code is fine for this), the benchmark, and | ||
| creating a custom ACS prediction task --> | ||
| ## Benchmark features and options | ||
| ## Evaluating feature importance | ||
| Here's a summary list of the most important benchmark options/flags used in | ||
| conjunction with the `run_acs_benchmark` command line script, or with the | ||
| `Benchmark` class. | ||
| By evaluating LLMs on tabular classification tasks, we can use standard feature importance methods to assess which features the model uses to compute risk scores. | ||
| | Option | Description | Examples | | ||
| |:---|:---|:---:| | ||
| | `--model` | Name of the model on huggingface transformers, or local path to folder with pretrained model and tokenizer. Can also use web-hosted models with `"[provider]/[model-name]"`. | `meta-llama/Meta-Llama-3-8B`, `openai/gpt-4o-mini` | | ||
| | `--task` | Name of the ACS task to run benchmark on. | `ACSIncome`, `ACSEmployment` | | ||
| | `--results-dir` | Path to directory under which benchmark results will be saved. | `results` | | ||
| | `--data-dir` | Root folder to find datasets in (or download ACS data to). | `~/data` | | ||
| `--numeric-risk-prompting` | Whether to use verbalized numeric risk prompting, i.e., directly query the model for a probability estimate. **By default** will use standard multiple-choice Q&A, and extract risk scores from internal token probabilities. | Boolean flag (`True` if present, `False` otherwise) | | ||
| | `--use-web-api-model` | Whether the given `--model` name corresponds to a web-hosted model or not. **By default** this is False (assumes a huggingface transformers model). If this flag is provided, `--model` must contain a [litellm](https://docs.litellm.ai) model identifier ([examples here](https://docs.litellm.ai/docs/providers/openai#openai-chat-completion-models)). | Boolean flag (`True` if present, `False` otherwise) | | ||
| | `--subsampling` | Which fraction of the dataset to use for the benchmark. **By default** will use the whole test set. | `0.01` | | ||
| | `--fit-threshold` | Whether to use the given number of samples to fit the binarization threshold. **By default** will use a fixed $t=0.5$ threshold instead of fitting on data. | `100` | | ||
| | `--batch-size` | The number of samples to process in each inference batch. Choose according to your available VRAM. | `10`, `32` | | ||
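Combining a few of the flags above into one invocation (the model name and paths below are placeholders):

```shell
# Hypothetical invocation combining flags from the table above;
# model name and paths are placeholders.
run_acs_benchmark \
    --model 'meta-llama/Meta-Llama-3-8B' \
    --task ACSIncome \
    --results-dir results \
    --data-dir ~/data \
    --subsampling 0.01 \
    --numeric-risk-prompting
```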
| You can do so yourself by calling `folktexts.cli.eval_feature_importance` (add `--help` for a full list of options). | ||
| Here's an example for the Llama3-70B-Instruct model on the ACSIncome task (*warning: takes 24h on an Nvidia H100*): | ||
| ``` | ||
| python -m folktexts.cli.eval_feature_importance --model 'meta-llama/Meta-Llama-3-70B-Instruct' --task ACSIncome --subsampling 0.1 | ||
| ``` | ||
| <div style="text-align: center;"> | ||
| <img src="docs/_static/feat-imp_meta-llama--Meta-Llama-3-70B-Instruct.png" alt="feature importance on llama3 70b it" width="50%"> | ||
| </div> | ||
| Full list of options: | ||
| This script uses sklearn's [`permutation_importance`](https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html#sklearn.inspection.permutation_importance) to assess which features contribute the most to the ROC AUC metric (other metrics can be assessed using the `--scorer [scorer]` parameter). | ||
| ## Benchmark options | ||
| ``` | ||
@@ -268,2 +274,20 @@ usage: run_acs_benchmark [-h] --model MODEL --results-dir RESULTS_DIR --data-dir DATA_DIR [--task TASK] [--few-shot FEW_SHOT] [--batch-size BATCH_SIZE] [--context-size CONTEXT_SIZE] [--fit-threshold FIT_THRESHOLD] [--subsampling SUBSAMPLING] [--seed SEED] [--use-web-api-model] [--dont-correct-order-bias] [--numeric-risk-prompting] [--reuse-few-shot-examples] [--use-feature-subset USE_FEATURE_SUBSET] | ||
| ## Evaluating feature importance | ||
| By evaluating LLMs on tabular classification tasks, we can use standard feature importance methods to assess which features the model uses to compute risk scores. | ||
| You can do so yourself by calling `folktexts.cli.eval_feature_importance` (add `--help` for a full list of options). | ||
| Here's an example for the Llama3-70B-Instruct model on the ACSIncome task (*warning: takes 24h on an Nvidia H100*): | ||
| ``` | ||
| python -m folktexts.cli.eval_feature_importance --model 'meta-llama/Meta-Llama-3-70B-Instruct' --task ACSIncome --subsampling 0.1 | ||
| ``` | ||
| <div style="text-align: center;"> | ||
| <img src="docs/_static/feat-imp_meta-llama--Meta-Llama-3-70B-Instruct.png" alt="feature importance on llama3 70b it" width="50%"> | ||
| </div> | ||
| This script uses sklearn's [`permutation_importance`](https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html#sklearn.inspection.permutation_importance) to assess which features contribute the most to the ROC AUC metric (other metrics can be assessed using the `--scorer [scorer]` parameter). | ||
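The procedure can be sketched on a stand-in sklearn classifier (toy data and model below are assumptions; the benchmark applies the same method to an LLM classifier):

```python
# Minimal sketch of the permutation-importance procedure on a stand-in
# sklearn classifier with toy data (the benchmark uses an LLM classifier).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, n_features=5, n_informative=2,
                           random_state=0)
clf = LogisticRegression().fit(X, y)

# Shuffle each feature column in turn and measure the drop in ROC AUC.
result = permutation_importance(clf, X, y, scoring="roc_auc",
                                n_repeats=10, random_state=0)
for i in np.argsort(result.importances_mean)[::-1]:
    print(f"feature {i}: {result.importances_mean[i]:+.3f}")
```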
| ## FAQ | ||
@@ -287,2 +311,3 @@ | ||
| **A:** **Yes!** We provide compatibility with local LLMs via [🤗 transformers](https://github.com/huggingface/transformers) and compatibility with web-hosted LLMs via [litellm](https://github.com/BerriAI/litellm). For example, you can use `--model='gpt-4o' --use-web-api-model` to use GPT-4o when calling the `run_acs_benchmark` script. [Here's a complete list](https://docs.litellm.ai/docs/providers/openai#openai-chat-completion-models) of compatible OpenAI models. Note that some models are not compatible as they don't enable access to log-probabilities. | ||
| Using models through a web API requires installing extra optional dependencies with `pip install 'folktexts[apis]'`. | ||
@@ -289,0 +314,0 @@ |
@@ -35,3 +35,3 @@ [build-system] | ||
| version = "0.0.20" | ||
| version = "0.0.21" | ||
| requires-python = ">=3.8" | ||
@@ -38,0 +38,0 @@ dynamic = [ |