Elron committed
Commit 5fdede1 · verified · 1 Parent(s): 077586b

Upload folder using huggingface_hub

Files changed (9)
  1. dataset.py +1 -0
  2. evaluate_cli.py +828 -0
  3. inference.py +233 -112
  4. llm_as_judge.py +4 -4
  5. metric.py +1 -0
  6. metrics.py +106 -95
  7. parsing_utils.py +2 -2
  8. processors.py +70 -16
  9. version.py +1 -1
dataset.py CHANGED
@@ -20,6 +20,7 @@ from .dialog_operators import __file__ as _
20
  from .dict_utils import __file__ as _
21
  from .error_utils import __file__ as _
22
  from .eval_utils import __file__ as _
23
+ from .evaluate_cli import __file__ as _
24
  from .file_utils import __file__ as _
25
  from .formats import __file__ as _
26
  from .fusion import __file__ as _
evaluate_cli.py ADDED
@@ -0,0 +1,828 @@
1
+ # evaluate_cli.py
2
+ import argparse
3
+ import importlib.metadata
4
+ import json
5
+ import logging
6
+ import os
7
+ import platform
8
+ import subprocess
9
+ import sys
10
+ from datetime import datetime
11
+ from functools import partial
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ from datasets import Dataset as HFDataset
15
+
16
+ from . import evaluate, get_logger, load_dataset
17
+ from .artifact import UnitxtArtifactNotFoundError
18
+ from .benchmark import Benchmark
19
+
20
+ # Use HFAutoModelInferenceEngine for local models
21
+ from .inference import (
22
+ CrossProviderInferenceEngine,
23
+ HFAutoModelInferenceEngine,
24
+ InferenceEngine,
25
+ )
26
+ from .metric_utils import EvaluationResults
27
+ from .parsing_utils import parse_key_equals_value_string_to_dict
28
+ from .settings_utils import settings
29
+ from .standard import DatasetRecipe
30
+
31
+ # Define logger early so it can be used in initial error handling
32
+ # Basic config for initial messages, will be reconfigured in main()
33
+ logger = get_logger()
34
+
35
+
36
+ def try_parse_json(value: str) -> Union[str, dict, None]:
37
+ """Attempts to parse a string as JSON or key=value pairs.
38
+
39
+ Returns the original string if parsing fails
40
+ and the string doesn't look like JSON/kv pairs.
41
+ Raises ArgumentTypeError if it looks like JSON but is invalid.
42
+ """
43
+ if value is None:
44
+ return None
45
+ try:
46
+ # Handle simple key-value pairs like "key=value,key2=value2"
47
+ if "=" in value and "{" not in value:
48
+ parsed_dict = parse_key_equals_value_string_to_dict(value)
49
+ if parsed_dict:
50
+ return parsed_dict
51
+
52
+ # Attempt standard JSON parsing
53
+ return json.loads(value)
54
+
55
+ except json.JSONDecodeError as e:
56
+ if value.strip().startswith("{") or value.strip().startswith("["):
57
+ raise argparse.ArgumentTypeError(
58
+ f"Invalid JSON: '{value}'. Hint: Use double quotes for JSON strings and check syntax."
59
+ ) from e
60
+ return value # Return as string if not JSON-like
61
+ except Exception as e:
62
+ logger.error(f"Error parsing argument '{value}': {e}")
63
+ raise argparse.ArgumentTypeError(f"Could not parse argument: '{value}'") from e
64
+
65
+
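
For reference, a minimal sketch of how try_parse_json behaves on the kinds of input it handles, assuming parse_key_equals_value_string_to_dict maps "k=v,k2=v2" strings to a flat dict of strings (the argument values below are placeholders):

# try_parse_json("pretrained=gpt2,torch_dtype=float16")
#   -> {"pretrained": "gpt2", "torch_dtype": "float16"}   (key=value string parsed to a dict)
# try_parse_json('{"model_name": "openai/gpt-4o"}')
#   -> {"model_name": "openai/gpt-4o"}                    (valid JSON parsed as-is)
# try_parse_json("not json at all")
#   -> "not json at all"                                  (returned unchanged)
# try_parse_json('{"broken": ')
#   -> raises argparse.ArgumentTypeError                  (looks like JSON but is invalid)
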
66
+ def setup_parser() -> argparse.ArgumentParser:
67
+ """Sets up the argument parser."""
68
+ parser = argparse.ArgumentParser(
69
+ formatter_class=argparse.RawTextHelpFormatter,
70
+ description="CLI utility for running evaluations with unitxt.",
71
+ )
72
+
73
+ # --- Task/Dataset Arguments ---
74
+ parser.add_argument(
75
+ "--tasks", # Changed to plural to better reflect it holds a list
76
+ "-t",
77
+ dest="tasks", # Explicitly set the attribute name to 'tasks'
78
+ type=partial(str.split, sep="+"),  # Split the '+'-separated string into a list of task strings
79
+ required=True,
80
+ help="Plus-separated (+) list of Unitxt task/dataset identifier strings.\n"
81
+ "Each task format: 'card=<card_ref>,template=<template_ref>,...'\n"
82
+ "Example: 'card=cards.mmlu,t=t.mmlu.all+card=cards.hellaswag,t=t.hellaswag.no'",
83
+ )
84
+
85
+ parser.add_argument(
86
+ "--split",
87
+ type=str,
88
+ default="test",
89
+ help="Dataset split to use (e.g., 'train', 'validation', 'test'). Default: 'test'.",
90
+ )
91
+ parser.add_argument(
92
+ "--num_fewshots",
93
+ type=int,
94
+ default=None,
95
+ help="number of fewshots to use",
96
+ )
97
+ parser.add_argument(
98
+ "--limit",
99
+ "-L",
100
+ type=int,
101
+ default=None,
102
+ metavar="N",
103
+ help="Limit the number of examples per task/dataset.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--batch_size",
108
+ "-b",
109
+ type=int,
110
+ default=1,
111
+ help="Batch size for use in inference when selected model is hf. Default 1",
112
+ )
113
+
114
+ # --- Model Arguments (Explicit Types) ---
115
+ parser.add_argument(
116
+ "--model",
117
+ "-m",
118
+ type=str,
119
+ default="hf",
120
+ choices=["hf", "cross_provider"],
121
+ help="Specifies the model type/engine.\n"
122
+ "- 'hf': Local Hugging Face model via HFAutoModel (default). Requires 'pretrained=...' in --model_args.\n"
123
+ "- 'cross_provider': Remote model via CrossProviderInferenceEngine. Requires 'model_name=...' in --model_args.",
124
+ )
125
+ parser.add_argument(
126
+ "--model_args",
127
+ "-a",
128
+ type=try_parse_json,
129
+ default={},
130
+ help="Comma separated string or JSON formatted arguments for the model/inference engine.\n"
131
+ "Examples:\n"
132
+ "- For --model hf (default): 'pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16,device=cuda'\n"
133
+ " (Note: 'pretrained' key is REQUIRED. Other args like 'torch_dtype', 'device', generation params are passed too)\n"
134
+ "- For --model generic_remote: 'model_name=llama-3-3-70b-instruct,max_tokens=256,temperature=0.7'\n"
135
+ " (Note: 'model_name' key is REQUIRED)\n"
136
+ '- JSON format: \'{"pretrained": "my_model", "torch_dtype": "float32"}\' or \'{"model_name": "openai/gpt-4o"}\'',
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--gen_kwargs",
141
+ type=try_parse_json,
142
+ default=None,
143
+ help=(
144
+ "Comma delimited string for model generation on greedy_until tasks,"
145
+ """ e.g. temperature=0,top_p=0.1."""
146
+ ),
147
+ )
148
+
149
+ parser.add_argument(
150
+ "--chat_template_kwargs",
151
+ type=try_parse_json,
152
+ default=None,
153
+ help=(
154
+ "Comma delimited string for tokenizer kwargs"
155
+ "e.g. thinking=True (https://github.com/huggingface/transformers/blob/9a1c1fe7edaefdb25ab37116a979832df298d6ea/src/transformers/tokenization_utils_base.py#L1542)"
156
+ ),
157
+ )
158
+
159
+ # --- Output and Logging Arguments ---
160
+ parser.add_argument(
161
+ "--output_path",
162
+ "-o",
163
+ type=str,
164
+ default=".",
165
+ help="Directory to save evaluation results and logs. Default: current directory.",
166
+ )
167
+ parser.add_argument(
168
+ "--output_file_prefix",
169
+ type=str,
170
+ default="evaluation_results",
171
+ help="Prefix for the output JSON file names. Default: 'evaluation_results'.",
172
+ )
173
+ parser.add_argument(
174
+ "--log_samples",
175
+ "-s",
176
+ action="store_true",
177
+ default=False,
178
+ help="If True, save individual predictions and scores to a separate JSON file.",
179
+ )
180
+ parser.add_argument(
181
+ "--verbosity",
182
+ "-v",
183
+ type=str.upper,
184
+ default="INFO",
185
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
186
+ help="Controls logging verbosity level. Default: INFO.",
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--apply_chat_template",
191
+ action="store_true",
192
+ default=False,
193
+ )
194
+
195
+ # --- Unitxt Settings ---
196
+ parser.add_argument(
197
+ "--trust_remote_code",
198
+ action="store_true",
199
+ default=False,
200
+ help="Allow execution of unverified code from the HuggingFace Hub (used by datasets/unitxt).",
201
+ )
202
+ parser.add_argument(
203
+ "--disable_hf_cache",
204
+ action="store_true",
205
+ default=False,
206
+ help="Disable HuggingFace datasets caching.",
207
+ )
208
+ parser.add_argument(
209
+ "--cache_dir",
210
+ type=str,
211
+ default=None,
212
+ help="Directory for HuggingFace datasets cache (overrides default).",
213
+ )
214
+
215
+ return parser
216
+
217
+
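
A small, hedged sketch of exercising the parser defined above; the card, template, and model references are placeholders, not verified catalog entries:

parser = setup_parser()
args = parser.parse_args([
    "--tasks", "card=cards.mmlu,template=templates.mmlu.all",
    "--model", "hf",
    "--model_args", "pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16",
    "--limit", "8",
    "--log_samples",
])
# Expected results:
# args.tasks       -> ["card=cards.mmlu,template=templates.mmlu.all"]   (split on '+')
# args.model_args  -> {"pretrained": "...", "torch_dtype": "bfloat16"}  (via try_parse_json)
# args.limit -> 8, args.log_samples -> True, args.split -> "test" (default)
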
218
+ def setup_logging(verbosity: str) -> None:
219
+ """Configures logging based on verbosity level."""
220
+ logging.basicConfig(
221
+ level=verbosity,
222
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
223
+ force=True, # Ensures reconfiguration works if basicConfig was called before
224
+ )
225
+ # Re-get the logger instance after basicConfig is set
226
+ global logger
227
+ logger = get_logger()
228
+ logger.setLevel(verbosity)
229
+
230
+
231
+ def prepare_output_paths(output_path: str, prefix: str) -> Tuple[str, str]:
232
+ """Creates output directory and defines file paths.
233
+
234
+ Args:
235
+ output_path (str): The directory where output files will be saved.
236
+ prefix (str): The prefix for the output file names.
237
+
238
+ Returns:
239
+ Tuple[str, str]: A tuple containing the path for the results summary file
240
+ and the path for the detailed samples file.
241
+ """
242
+ os.makedirs(output_path, exist_ok=True)
243
+ results_file_path = os.path.join(output_path, f"{prefix}.json")
244
+ samples_file_path = os.path.join(output_path, f"{prefix}_samples.json")
245
+ return results_file_path, samples_file_path
246
+
247
+
248
+ def configure_unitxt_settings(args: argparse.Namespace):
249
+ """Configures unitxt settings and returns a context manager.
250
+
251
+ Args:
252
+ args (argparse.Namespace): Parsed command-line arguments.
253
+
254
+ Returns:
255
+ ContextManager: A context manager for applying unitxt settings.
256
+ """
257
+ unitxt_settings_dict = {
258
+ "disable_hf_datasets_cache": args.disable_hf_cache,
259
+ "allow_unverified_code": args.trust_remote_code,
260
+ }
261
+ if args.cache_dir:
262
+ unitxt_settings_dict["hf_cache_dir"] = args.cache_dir
263
+ # Also set environment variable as some HF parts might read it directly
264
+ os.environ["HF_DATASETS_CACHE"] = args.cache_dir
265
+ os.environ["HF_HOME"] = args.cache_dir
266
+ logger.info(f"Set HF_DATASETS_CACHE to: {args.cache_dir}")
267
+
268
+ if args.disable_hf_cache:
269
+ os.environ["UNITXT_DISABLE_HF_DATASETS_CACHE"] = "True"
270
+
271
+ logger.info(f"Applying unitxt settings: {unitxt_settings_dict}")
272
+ return settings.context(**unitxt_settings_dict)
273
+
274
+
275
+ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
276
+ """Loads the dataset based on command line arguments.
277
+
278
+ Args:
279
+ args (argparse.Namespace): Parsed command-line arguments.
280
+
281
+ Returns:
282
+ HFDataset: The loaded dataset.
283
+
284
+ Raises:
285
+ UnitxtArtifactNotFoundError: If the specified card or template artifact is not found.
286
+ FileNotFoundError: If a specified file (e.g., in a local card path) is not found.
287
+ AttributeError: If there's an issue accessing attributes during loading.
288
+ ValueError: If there's a value-related error during loading (e.g., parsing).
289
+ """
290
+ logger.info(
291
+ f"Loading task/dataset using identifier: '{args.tasks}' with split '{args.split}'"
292
+ )
293
+
294
+ benchmark_subsets = {}
295
+ for task_str in args.tasks:
296
+ dataset_args = task_str_to_dataset_args(task_str, args)
297
+
298
+ benchmark_subsets[task_str] = DatasetRecipe(**dataset_args)
299
+
300
+ benchmark = Benchmark(subsets=benchmark_subsets)
301
+
302
+ test_dataset = load_dataset(benchmark, split=args.split)
303
+ logger.info(
304
+ f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
305
+ )
306
+ return test_dataset
307
+
308
+
309
+ def task_str_to_dataset_args(task_str, args):
310
+ dataset_args = parse_key_equals_value_string_to_dict(task_str)
311
+
312
+ if args.limit is not None:
313
+ assert f"max_{args.split}_instances" not in dataset_args, (
314
+ "limit was inputted both as an arg and as a task parameter"
315
+ )
316
+ # Check if limit or loader_limit is already present
317
+ # dataset_args[f"max_{args.split}_instances"] = args.limit
318
+ dataset_args[f"max_{args.split}_instances"] = args.limit
319
+ # max_<split>_instances caps the number of instances loaded for the chosen split
320
+ logger.info(
321
+ f"Applying limit from --limit argument: max_{args.split}_instances={args.limit}"
322
+ )
323
+
324
+ if args.num_fewshots:
325
+ assert "num_demos" not in dataset_args, (
326
+ "num_demos was inputted both as an arg and as a task parameter"
327
+ )
328
+ dataset_args["num_demos"] = args.num_fewshots
329
+ dataset_args.update(
330
+ {
331
+ "demos_taken_from": "train",
332
+ "demos_pool_size": -1,
333
+ "demos_removed_from_data": True,
334
+ }
335
+ )  # Configure the demos pool used for few-shot sampling
336
+ logger.info(
337
+ f"Applying limit from --limit argument: num_demos={args.num_fewshots}"
338
+ )
339
+
340
+ if args.apply_chat_template:
341
+ assert "format" not in dataset_args, (
342
+ "format was inputted as a task parameter, but chat_api was requested"
343
+ )
344
+ dataset_args["format"] = "formats.chat_api"
345
+ logger.info(
346
+ "Applying chat template from --apply_chat_template argument: format=formats.chat_api"
347
+ )
348
+
349
+ return dataset_args
350
+
351
+
352
+ def prepare_kwargs(kwargs: dict) -> Dict[str, Any]:
353
+ """Prepares the model arguments dictionary.
354
+
355
+ Args:
356
+ kwargs (dict): Parsed command-line arguments.
357
+
358
+ Returns:
359
+ Dict[str, Any]: The processed model arguments dictionary.
360
+ """
361
+ # Ensure model_args is a dictionary, handling potential string return from try_parse_json
362
+ kwargs_dict = kwargs if isinstance(kwargs, dict) else {}
363
+ if not isinstance(kwargs, dict) and kwargs is not None:
364
+ logger.warning(
365
+ f"Could not parse kwargs '{kwargs}' as JSON or key-value pairs. Treating as empty."
366
+ )
367
+
368
+ logger.info(f"Using kwargs: {kwargs_dict}")
369
+ return kwargs_dict
370
+
371
+
372
+ def initialize_inference_engine(
373
+ args: argparse.Namespace,
374
+ model_args_dict: Dict[str, Any],
375
+ chat_kwargs_dict: Dict[str, Any],
376
+ ) -> InferenceEngine:
377
+ """Initializes the appropriate inference engine based on arguments.
378
+
379
+ Args:
380
+ args (argparse.Namespace): Parsed command-line arguments.
381
+ model_args_dict (Dict[str, Any]): Processed model arguments.
382
+ chat_kwargs_dict (Dict[str, Any]): Processed chat arguments.
383
+
384
+ Returns:
385
+ InferenceEngine: The initialized inference engine instance.
386
+
387
+ Raises:
388
+ SystemExit: If required dependencies are missing for the selected model type.
389
+ ValueError: If required keys are missing in model_args for the selected model type.
390
+ """
391
+ inference_model = None
392
+ # --- Local Hugging Face Model (using HFAutoModelInferenceEngine) ---
393
+ if args.model.lower() == "hf":
394
+ if "pretrained" not in model_args_dict:
395
+ logger.error(
396
+ "Missing 'pretrained=<model_id_or_path>' in --model_args for '--model hf'."
397
+ )
398
+ raise ValueError(
399
+ "Argument 'pretrained' is required in --model_args when --model is 'hf'"
400
+ )
401
+
402
+ local_model_name = model_args_dict.pop("pretrained")
403
+ logger.info(
404
+ f"Initializing HFAutoModelInferenceEngine for model: {local_model_name}"
405
+ )
406
+
407
+ model_args_dict.update({"batch_size": args.batch_size})
408
+ logger.info(f"HFAutoModelInferenceEngine args: {model_args_dict}")
409
+
410
+ inference_model = HFAutoModelInferenceEngine(
411
+ model_name=local_model_name,
412
+ **model_args_dict,
413
+ chat_kwargs_dict=chat_kwargs_dict,
414
+ )
415
+
416
+ # --- Remote Model (CrossProviderInferenceEngine) ---
417
+ elif args.model.lower() == "cross_provider":
418
+ if "model_name" not in model_args_dict:
419
+ logger.error(
420
+ "Missing 'model_name=<provider/model_id>' in --model_args for '--model cross_provider'."
421
+ )
422
+ raise ValueError(
423
+ "Argument 'model_name' is required in --model_args when --model is 'cross_provider'"
424
+ )
425
+
426
+ remote_model_name = model_args_dict.pop("model_name")
427
+ logger.info(
428
+ f"Initializing CrossProviderInferenceEngine for model: {remote_model_name}"
429
+ )
430
+
431
+ if (
432
+ "max_tokens" not in model_args_dict
433
+ and "max_new_tokens" not in model_args_dict
434
+ ):
435
+ logger.warning(
436
+ f"'max_tokens' or 'max_new_tokens' not found in --model_args, {remote_model_name} might require it."
437
+ )
438
+
439
+ logger.info(f"CrossProviderInferenceEngine args: {model_args_dict}")
440
+
441
+ # Note: CrossProviderInferenceEngine expects 'model' parameter, not 'model_name'
442
+ inference_model = CrossProviderInferenceEngine(
443
+ model=remote_model_name,
444
+ **model_args_dict,
445
+ )
446
+ else:
447
+ # This case should not be reached due to argparse choices
448
+ logger.error(
449
+ f"Invalid --model type specified: {args.model}. Use 'hf' or 'cross_provider'."
450
+ )
451
+ sys.exit(1) # Exit here as it's an invalid configuration
452
+
453
+ return inference_model
454
+
455
+
456
+ def run_inference(engine: InferenceEngine, dataset: HFDataset) -> List[Any]:
457
+ """Runs inference using the initialized engine.
458
+
459
+ Args:
460
+ engine (InferenceEngine): The inference engine instance.
461
+ dataset (HFDataset): The dataset to run inference on.
462
+
463
+ Returns:
464
+ List[Any]: A list of predictions.
465
+
466
+ Raises:
467
+ Exception: If an error occurs during inference.
468
+ """
469
+ logger.info("Starting inference...")
470
+ try:
471
+ predictions = engine.infer(dataset)
472
+ logger.info("Inference completed.")
473
+ if not predictions:
474
+ logger.warning("Inference returned no predictions.")
475
+ return [] # Return empty list if no predictions
476
+ if len(predictions) != len(dataset):
477
+ logger.error(
478
+ f"Inference returned an unexpected number of predictions ({len(predictions)}). Expected {len(dataset)}."
479
+ )
480
+ # Don't exit, but log error. Evaluation might still work partially or fail later.
481
+ return predictions
482
+ except Exception:
483
+ logger.exception("An error occurred during inference") # Use logger.exception
484
+ raise # Re-raise after logging
485
+
486
+
487
+ def run_evaluation(predictions: List[Any], dataset: HFDataset) -> EvaluationResults:
488
+ """Runs evaluation on the predictions.
489
+
490
+ Args:
491
+ predictions (List[Any]): The list of predictions from the model.
492
+ dataset (HFDataset): The dataset containing references and other data.
493
+
494
+ Returns:
495
+ EvaluationResults: The evaluated dataset (list of instances with scores).
496
+
497
+ Raises:
498
+ RuntimeError: If evaluation returns no results or an unexpected type.
499
+ Exception: If any other error occurs during evaluation.
500
+ """
501
+ logger.info("Starting evaluation...")
502
+ if not predictions:
503
+ logger.warning("Skipping evaluation as there are no predictions.")
504
+ return [] # Return empty list if no predictions to evaluate
505
+
506
+ try:
507
+ evaluation_results = evaluate(predictions=predictions, data=dataset)
508
+ logger.info("Evaluation completed.")
509
+ if not evaluation_results:
510
+ logger.error("Evaluation returned no results (empty list/None).")
511
+ # Raise an error as this indicates a problem in the evaluation process
512
+ raise RuntimeError("Evaluation returned no results.")
513
+ if not isinstance(evaluation_results, EvaluationResults):
514
+ logger.error(
515
+ f"Evaluation returned unexpected type: {type(evaluation_results)}. Expected list."
516
+ )
517
+ raise RuntimeError(
518
+ f"Evaluation returned unexpected type: {type(evaluation_results)}"
519
+ )
520
+
521
+ return evaluation_results
522
+ except Exception:
523
+ logger.exception("An error occurred during evaluation") # Use logger.exception
524
+ raise # Re-raise after logging
525
+
526
+
527
+ def _get_unitxt_commit_hash() -> Optional[str]:
528
+ """Tries to get the git commit hash of the installed unitxt package."""
529
+ try:
530
+ # Find the directory of the unitxt package
531
+ # Use inspect to be more robust finding the package path
532
+
533
+ current_script_path = os.path.abspath(__file__)
534
+ package_dir = os.path.dirname(current_script_path)
535
+
536
+ # Check if it's a git repository and get the commit hash
537
+ # Use absolute path for git command
538
+ git_command = ["git", "-C", os.path.abspath(package_dir), "rev-parse", "HEAD"]
539
+ logger.debug(f"Running git command: {' '.join(git_command)}")
540
+ result = subprocess.run(
541
+ git_command,
542
+ capture_output=True,
543
+ text=True,
544
+ check=False, # Don't raise error if git command fails
545
+ encoding="utf-8",
546
+ errors="ignore", # Ignore potential decoding errors
547
+ )
548
+ if result.returncode == 0:
549
+ commit_hash = result.stdout.strip()
550
+ logger.info(f"Found unitxt git commit hash: {commit_hash}")
551
+ # Verify it looks like a hash (e.g., 40 hex chars)
552
+ if len(commit_hash) == 40 and all(
553
+ c in "0123456789abcdef" for c in commit_hash
554
+ ):
555
+ return commit_hash
556
+ logger.warning(
557
+ f"Git command output '{commit_hash}' doesn't look like a valid commit hash."
558
+ )
559
+ return None
560
+ stderr_msg = result.stderr.strip() if result.stderr else "No stderr"
561
+ logger.warning(
562
+ f"Could not get unitxt git commit hash (git command failed with code {result.returncode}): {stderr_msg}"
563
+ )
564
+ return None
565
+ except ImportError:
566
+ logger.warning("unitxt package not found, cannot determine commit hash.")
567
+ return None
568
+ except FileNotFoundError:
569
+ logger.warning(
570
+ "'git' command not found in PATH. Cannot determine unitxt commit hash."
571
+ )
572
+ return None
573
+ except Exception as e:
574
+ logger.warning(
575
+ f"Error getting unitxt commit hash: {e}", exc_info=True
576
+ ) # Log traceback
577
+ return None
578
+
579
+
580
+ def _get_installed_packages() -> Dict[str, str]:
581
+ """Gets a dictionary of installed packages and their versions."""
582
+ packages = {}
583
+ try:
584
+ for dist in importlib.metadata.distributions():
585
+ # Handle potential missing metadata gracefully
586
+ name = dist.metadata.get("Name")
587
+ version = dist.metadata.get("Version")
588
+ if name and version:
589
+ packages[name] = version
590
+ elif name:
591
+ packages[name] = "N/A" # Record package even if version is missing
592
+ logger.debug(f"Could not find version for package: {name}")
593
+
594
+ logger.info(f"Collected versions for {len(packages)} installed packages.")
595
+ except Exception as e:
596
+ logger.warning(f"Could not retrieve installed package list: {e}", exc_info=True)
597
+ return packages
598
+
599
+
600
+ def _get_unitxt_version() -> str:
601
+ """Gets the installed unitxt version using importlib.metadata."""
602
+ try:
603
+ version = importlib.metadata.version("unitxt")
604
+ logger.info(f"Found unitxt version using importlib.metadata: {version}")
605
+ return version
606
+ except importlib.metadata.PackageNotFoundError:
607
+ logger.warning(
608
+ "Could not find 'unitxt' package version using importlib.metadata. Is it installed correctly?"
609
+ )
610
+ return "N/A"
611
+ except Exception as e:
612
+ logger.warning(
613
+ f"Error getting unitxt version using importlib.metadata: {e}", exc_info=True
614
+ )
615
+ return "N/A"
616
+
617
+
618
+ def prepend_timestamp_to_path(original_path, timestamp):
619
+ """Takes a path string and a timestamp string, prepends the timestamp to the filename part of the path, and returns the new path string."""
620
+ directory, filename = os.path.split(original_path)
621
+ # Use an f-string to create the new filename with the timestamp prepended
622
+ new_filename = f"{timestamp}_{filename}"
623
+ # Join the directory and the new filename back together
624
+ return os.path.join(directory, new_filename)
625
+
626
+
627
+ def _save_results_to_disk(
628
+ args: argparse.Namespace,
629
+ global_scores: Dict[str, Any],
630
+ all_samples_data: Dict[str, List[Dict[str, Any]]],
631
+ results_path: str,
632
+ samples_path: str,
633
+ ) -> None:
634
+ """Saves the configuration, environment info, global scores, and samples to JSON files.
635
+
636
+ Args:
637
+ args (argparse.Namespace): Parsed command-line arguments.
638
+ global_scores (Dict[str, Any]): Dictionary of global scores.
639
+ all_samples_data (Dict[str, List[Dict[str, Any]]]): Mapping from subset name to its list of processed samples.
640
+ results_path (str): Path to save the summary results JSON file.
641
+ samples_path (str): Path to save the detailed samples JSON file.
642
+ """
643
+ # --- Gather Configuration ---
644
+ config_to_save = {}
645
+ for k, v in vars(args).items():
646
+ # Ensure complex objects are represented as strings
647
+ if isinstance(v, (str, int, float, bool, list, dict, type(None))):
648
+ config_to_save[k] = v
649
+ else:
650
+ try:
651
+ # Try standard repr first
652
+ config_to_save[k] = repr(v)
653
+ except Exception:
654
+ # Fallback if repr fails
655
+ config_to_save[k] = (
656
+ f"<Object of type {type(v).__name__} could not be represented>"
657
+ )
658
+
659
+ # --- Gather Environment Info ---
660
+ unitxt_commit = _get_unitxt_commit_hash()
661
+ # Get version using the dedicated function
662
+ unitxt_pkg_version = _get_unitxt_version()
663
+
664
+ environment_info = {
665
+ "timestamp_utc": datetime.utcnow().isoformat() + "Z",
666
+ "command_line_invocation": sys.argv,
667
+ "parsed_arguments": config_to_save, # Include parsed args here as well
668
+ "unitxt_version": unitxt_pkg_version, # Use version from importlib.metadata
669
+ "unitxt_commit_hash": unitxt_commit if unitxt_commit else "N/A",
670
+ "python_version": platform.python_version(),
671
+ "system": platform.system(),
672
+ "system_version": platform.version(),
673
+ "installed_packages": _get_installed_packages(),
674
+ }
675
+
676
+ # --- Prepare Final Results Structure ---
677
+ results_summary = {
678
+ "environment_info": environment_info,
679
+ "results": global_scores,
680
+ }
681
+
682
+ # Prepend a timestamp (e.g., 2025-04-04T11:37:32) to the results and samples file names
683
+
684
+ timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
685
+
686
+ results_path = prepend_timestamp_to_path(results_path, timestamp)
687
+ samples_path = prepend_timestamp_to_path(samples_path, timestamp)
688
+
689
+ # --- Save Summary ---
690
+ logger.info(f"Saving global results summary to: {results_path}")
691
+ try:
692
+ with open(results_path, "w", encoding="utf-8") as f:
693
+ json.dump(results_summary, f, indent=4, ensure_ascii=False)
694
+ except OSError as e:
695
+ logger.error(f"Failed to write results summary file {results_path}: {e}")
696
+ except TypeError as e:
697
+ logger.error(
698
+ f"Failed to serialize results summary to JSON: {e}. Check data types."
699
+ )
700
+ # Log the problematic structure if possible (might be large)
701
+ # logger.debug(f"Problematic results_summary structure: {results_summary}")
702
+
703
+ # --- Save Samples (if requested) ---
704
+ if args.log_samples:
705
+ logger.info(f"Saving detailed samples to: {samples_path}")
706
+ # Structure samples file with environment info as well for self-containment
707
+ samples_output = {
708
+ "environment_info": environment_info, # Repeat env info here
709
+ "samples": all_samples_data,
710
+ }
711
+ try:
712
+ with open(samples_path, "w", encoding="utf-8") as f:
713
+ json.dump(samples_output, f, indent=4, ensure_ascii=False)
714
+ except OSError as e:
715
+ logger.error(f"Failed to write samples file {samples_path}: {e}")
716
+ except TypeError as e:
717
+ logger.error(f"Failed to serialize samples to JSON: {e}. Check data types.")
718
+
719
+
720
+ def process_and_save_results(
721
+ args: argparse.Namespace,
722
+ evaluation_results: EvaluationResults,
723
+ results_path: str,
724
+ samples_path: str,
725
+ ) -> None:
726
+ """Processes, prints, and saves the evaluation results.
727
+
728
+ Args:
729
+ args (argparse.Namespace): Parsed command-line arguments.
730
+ evaluation_results (EvaluationResults): The list of evaluated instances.
731
+ results_path (str): Path to save the summary results JSON file.
732
+ samples_path (str): Path to save the detailed samples JSON file.
733
+
734
+ Raises:
735
+ Exception: If an error occurs during result processing or saving (re-raised).
736
+ """
737
+ try:
738
+ # global_scores, all_samples_data = _extract_scores_and_samples(evaluated_dataset)
739
+
740
+ subsets_scores = evaluation_results.subsets_scores
741
+ instances_results = evaluation_results.instance_scores
742
+
743
+ subset_instances = {}
744
+ for instance in instances_results:
745
+ if instance["subset"][0] not in subset_instances:
746
+ subset_instances[instance["subset"][0]] = []
747
+ del instance["postprocessors"]
748
+ subset_instances[instance["subset"][0]].append(instance)
749
+
750
+ logger.info(f"\n{subsets_scores.summary}")
751
+
752
+ # --- Save Results ---
753
+ # Pass all necessary data to the saving function
754
+ _save_results_to_disk(
755
+ args, subsets_scores, subset_instances, results_path, samples_path
756
+ )
757
+
758
+ except Exception:
759
+ logger.exception(
760
+ "An error occurred during result processing or saving"
761
+ ) # Use logger.exception
762
+ raise # Re-raise after logging
763
+
764
+
765
+ def main():
766
+ """Main function to parse arguments and run evaluation."""
767
+ parser = setup_parser()
768
+ args = parser.parse_args()
769
+
770
+ # Setup logging ASAP
771
+ setup_logging(args.verbosity)
772
+
773
+ logger.info("Starting Unitxt Evaluation CLI")
774
+ # Log raw and parsed args at DEBUG level
775
+ logger.debug(f"Raw command line arguments: {sys.argv}")
776
+ logger.debug(f"Parsed arguments: {vars(args)}") # Log the vars(args) dict
777
+ logger.debug(
778
+ f"Parsed model_args type: {type(args.model_args)}, value: {args.model_args}"
779
+ )
780
+
781
+ try:
782
+ results_path, samples_path = prepare_output_paths(
783
+ args.output_path, args.output_file_prefix
784
+ )
785
+
786
+ # Apply unitxt settings within a context manager
787
+ with configure_unitxt_settings(args):
788
+ test_dataset = cli_load_dataset(args)
789
+ model_args_dict = prepare_kwargs(args.model_args)
790
+ gen_kwargs_dict = prepare_kwargs(args.gen_kwargs)
791
+ chat_kwargs_dict = prepare_kwargs(args.chat_template_kwargs)
792
+
793
+ model_args_dict.update(gen_kwargs_dict)
794
+ inference_model = initialize_inference_engine(
795
+ args, model_args_dict, chat_kwargs_dict
796
+ )
797
+ predictions = run_inference(inference_model, test_dataset)
798
+ evaluation_results = run_evaluation(predictions, test_dataset)
799
+ process_and_save_results(
800
+ args, evaluation_results, results_path, samples_path
801
+ )
802
+
803
+ # --- More Specific Error Handling ---
804
+ except (UnitxtArtifactNotFoundError, FileNotFoundError) as e:
805
+ logger.exception(f"Error loading artifact or file: {e}")
806
+ sys.exit(1)
807
+ except (AttributeError, ValueError) as e:
808
+ # Catch issues like missing keys in args, parsing errors, etc.
809
+ logger.exception(f"Configuration or value error: {e}")
810
+ sys.exit(1)
811
+ except ImportError as e:
812
+ # Catch missing optional dependencies
813
+ logger.exception(f"Missing dependency: {e}")
814
+ sys.exit(1)
815
+ except RuntimeError as e:
816
+ # Catch errors explicitly raised during execution (e.g., evaluation failure)
817
+ logger.exception(f"Runtime error during processing: {e}")
818
+ sys.exit(1)
819
+ except Exception as e:
820
+ # Catch any other unexpected errors
821
+ logger.exception(f"An unexpected error occurred: {e}")
822
+ sys.exit(1)
823
+
824
+ logger.info("Unitxt Evaluation CLI finished successfully.")
825
+
826
+
827
+ if __name__ == "__main__":
828
+ main()
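
To make the flow above concrete, here is a hedged sketch of the roughly equivalent programmatic pipeline, with a hypothetical CLI invocation in the leading comment. It assumes unitxt is installed so the relative imports above resolve to the unitxt package; the card, template, and model names are placeholders, and max_new_tokens is an assumed generation parameter.

# Hypothetical invocation (all identifiers are placeholders):
#   python -m unitxt.evaluate_cli --tasks "card=cards.mmlu,template=templates.mmlu.all" \
#       --model hf --model_args "pretrained=meta-llama/Llama-3.1-8B-Instruct" --limit 8 -s
#
# Roughly equivalent flow, mirroring cli_load_dataset / run_inference / run_evaluation:
from unitxt import evaluate, load_dataset
from unitxt.benchmark import Benchmark
from unitxt.inference import HFAutoModelInferenceEngine
from unitxt.standard import DatasetRecipe

benchmark = Benchmark(
    subsets={
        "mmlu": DatasetRecipe(
            card="cards.mmlu",              # placeholder card reference
            template="templates.mmlu.all",  # placeholder template reference
            max_test_instances=8,           # what --limit sets for the chosen split
        )
    }
)
test_dataset = load_dataset(benchmark, split="test")

engine = HFAutoModelInferenceEngine(
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    batch_size=1,
    max_new_tokens=32,                              # assumed generation parameter
)
predictions = engine.infer(test_dataset)
results = evaluate(predictions=predictions, data=test_dataset)
print(results.subsets_scores.summary)
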
inference.py CHANGED
@@ -61,6 +61,7 @@ def batched(lst, n):
61
  while batch := list(islice(it, n)):
62
  yield batch
63
 
 
64
  class StandardAPIParamsMixin(Artifact):
65
  model: str
66
  frequency_penalty: Optional[float] = None
@@ -157,6 +158,7 @@ class ListWithMetadata(List[T]):
157
 
158
  class InferenceEngine(Artifact):
159
  """Abstract base class for inference."""
 
160
  cache_batch_size: int = 100
161
  use_cache: bool = True
162
 
@@ -206,9 +208,9 @@ class InferenceEngine(Artifact):
206
  instance_str = json.dumps(record, sort_keys=True)
207
  return hashlib.md5(instance_str.encode()).hexdigest()
208
 
209
- def verify_infer_inputs(self,
210
- dataset: Union[List[Dict[str, Any]], Dataset],
211
- return_meta_data: bool):
212
  if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
213
  raise Exception(
214
  "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
@@ -238,33 +240,49 @@ class InferenceEngine(Artifact):
238
  if self.use_cache:
239
  number_of_batches = len(dataset) // self.cache_batch_size + 1
240
  result = []
241
- for batch_index, batch in enumerate(batched(dataset, self.cache_batch_size)):
 
 
242
  cached_results = []
243
  missing_examples = []
244
  for i, item in enumerate(batch):
245
  cache_key = self._get_cache_key(item)
246
  cached_value = self._cache.get(cache_key)
247
  if cached_value is not None:
248
- cached_results.append((i, cached_value)) # each element is index in batch, and value
 
 
249
  else:
250
- missing_examples.append((i, item)) # each element is index in batch and example
 
 
251
  # infer on missing examples only, without indices
252
 
253
- logger.info(f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
254
- if (len(missing_examples) > 0):
255
- inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
 
 
 
 
256
  # recombine predictions with their batch indices
257
- inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
 
 
258
  # Add missing examples to cache
259
- for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
 
 
260
  if prediction is None:
261
  continue
262
  cache_key = self._get_cache_key(item)
263
  self._cache[cache_key] = prediction
264
  else:
265
- inferred_results=[]
266
  # Combine cached and inferred results in original order
267
- batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
 
 
268
  result.extend(batch_predictions)
269
  else:
270
  result = self._infer(dataset, return_meta_data)
@@ -414,6 +432,8 @@ class HFInferenceEngineBase(
414
  low_cpu_mem_usage: bool = True
415
  torch_dtype: str = "torch.float16"
416
 
 
 
417
  model: Any = InternalField(default=None, name="Inference object")
418
  processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
419
 
@@ -618,16 +638,52 @@ class HFInferenceEngineBase(
618
  class HFAutoModelInferenceEngine(HFInferenceEngineBase):
619
  label: str = "hf_auto_model"
620
 
 
 
 
 
 
 
 
 
 
 
 
621
  def _init_processor(self):
622
  from transformers import AutoTokenizer
623
 
624
  self.processor = AutoTokenizer.from_pretrained(
625
  pretrained_model_name_or_path=self.model_name,
626
  use_fast=self.use_fast_tokenizer,
627
- padding=True,
628
- truncation=True,
629
  )
630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
  def _init_model(self):
632
  from transformers import (
633
  AutoConfig,
@@ -641,11 +697,12 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
641
  else AutoModelForCausalLM
642
  )
643
 
 
 
644
  self.model = model_class.from_pretrained(
645
  pretrained_model_name_or_path=self.model_name,
646
  trust_remote_code=True,
647
- device_map=self.device_map,
648
- torch_dtype=self._get_torch_dtype(),
649
  )
650
  if self.device_map is None:
651
  self.model.to(self.device)
@@ -653,13 +710,21 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
653
  def prepare_inputs(self, data: Iterable) -> Mapping:
654
  if isinstance(data[0], list):
655
  data = self.processor.apply_chat_template(
656
- data, tokenize=False, add_generation_prompt=True
 
 
 
657
  )
 
 
 
 
658
  return self.processor(
659
  data,
660
- padding=True,
661
- truncation=True,
662
  return_tensors="pt",
 
 
 
663
  ).to(self.device or self.device_map)
664
 
665
  def _infer_fn(
@@ -668,40 +733,81 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
668
  return_meta_data: bool,
669
  return_logprobs: bool,
670
  ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
671
- tokenized_inputs = self.prepare_inputs(
672
- [instance["source"] for instance in dataset]
673
- )
674
- input_length = (
675
- 1
676
- if self.model.config.is_encoder_decoder
677
- else tokenized_inputs.input_ids.shape[1]
678
- )
679
 
680
- predictions = self.make_predictions(tokenized_inputs)
681
- sequences = predictions.sequences
 
 
 
682
 
683
- string_tokens = [
684
- self.decode_tokens(sequence, input_length) for sequence in sequences
685
- ]
 
 
686
 
687
- final_outputs = (
688
- self.get_logprobs(predictions, string_tokens)
689
- if return_logprobs
690
- else [self.create_string_from_tokens(strings) for strings in string_tokens]
691
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
- return [
694
- self.get_return_object(
695
- output=final_outputs[i],
696
- output_tokens=len(string_tokens[i]),
697
- inp=dataset[i]["source"],
698
- inp_tokens=len(tokenized_inputs.encodings[i].tokens)
699
- if tokenized_inputs.encodings is not None
700
- else None,
701
- return_meta_data=return_meta_data,
 
 
 
 
 
 
 
 
702
  )
703
- for i in range(len(sequences))
704
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
 
706
  def _infer(
707
  self,
@@ -885,10 +991,10 @@ class HFPeftInferenceEngine(HFAutoModelInferenceEngine):
885
 
886
  model_class = (
887
  AutoPeftModelForSeq2SeqLM
888
- if AutoConfig.from_pretrained(self.model_name).is_encoder_decoder
889
  else AutoPeftModelForCausalLM
890
  )
891
- path = self.peft_config.base_model_name_or_path
892
  if settings.hf_offline_models_path is not None:
893
  path = os.path.join(settings.hf_offline_models_path, path)
894
 
@@ -899,6 +1005,7 @@ class HFPeftInferenceEngine(HFAutoModelInferenceEngine):
899
  low_cpu_mem_usage=self.low_cpu_mem_usage,
900
  torch_dtype=self._get_torch_dtype(),
901
  )
 
902
  if self.device_map is None:
903
  self.model.to(self.device)
904
 
@@ -949,19 +1056,27 @@ class HFPipelineBasedInferenceEngine(
949
  except Exception:
950
  try:
951
  from peft import PeftConfig
 
952
  # If full model loading fails, try loading as a PEFT adapter
953
  peft_config = PeftConfig.from_pretrained(path)
954
 
955
  if not peft_config.base_model_name_or_path:
956
- raise ValueError(f"Base model name not found in PEFT config for {path}")
 
 
957
 
958
  # Load the base model's config
959
- config = AutoConfig.from_pretrained(peft_config.base_model_name_or_path, trust_remote_code=True)
 
 
960
  except Exception as err2:
961
- raise ValueError(f"Could not determine model type for: {path}") from err2
962
-
 
963
 
964
- self.task = "text2text-generation" if config.is_encoder_decoder else "text-generation"
 
 
965
 
966
  def _get_model_args(self) -> Dict[str, Any]:
967
  import torch
@@ -1306,9 +1421,9 @@ class OptionSelectingByLogProbsInferenceEngine:
1306
  for option in instance["task_data"]["options"]
1307
  ]
1308
 
1309
- dataset_with_options_logprobs: List[
1310
- List[Dict[str, Union[float, str]]]
1311
- ] = self.get_options_log_probs(dataset_with_options)
1312
 
1313
  dataset_iterator = iter(dataset_with_options_logprobs)
1314
 
@@ -1381,7 +1496,7 @@ class IbmGenAiInferenceEngine(
1381
  def _get_credentials():
1382
  from genai import Credentials
1383
 
1384
- api_key_env_var_name = "GENAI_KEY" # pragma: allowlist secret
1385
  api_key = os.environ.get(api_key_env_var_name)
1386
 
1387
  assert api_key is not None, (
@@ -1467,9 +1582,9 @@ class IbmGenAiInferenceEngine(
1467
  predict_results = []
1468
  for prediction in predictions:
1469
  result: TextGenerationResult = prediction.results[0]
1470
- assert isinstance(
1471
- result.generated_tokens, list
1472
- ), "result.generated_tokens should be a list"
1473
 
1474
  predict_result = []
1475
  for base_token in result.generated_tokens:
@@ -1714,6 +1829,7 @@ class OpenAiInferenceEngine(
1714
  @run_with_imap
1715
  def _get_chat_completion(self, instance, return_meta_data):
1716
  import openai
 
1717
  messages = self.to_messages(instance)
1718
  try:
1719
  response = self.client.chat.completions.create(
@@ -1725,13 +1841,17 @@ class OpenAiInferenceEngine(
1725
  return self.get_return_object(prediction, response, return_meta_data)
1726
  # catch in case of content_filtering failure
1727
  except openai.BadRequestError as e:
1728
- logging.error(f"Error predicting instance {messages}:{e}. Returning empty prediction")
1729
- return TextGenerationInferenceOutput(prediction = "-", input_tokens=0, output_tokens=0)
1730
-
 
 
 
1731
 
1732
  @run_with_imap
1733
  def _get_logprobs(self, instance, return_meta_data):
1734
  import openai
 
1735
  messages = self.to_messages(instance)
1736
  try:
1737
  response = self.client.chat.completions.create(
@@ -1752,13 +1872,13 @@ class OpenAiInferenceEngine(
1752
  return self.get_return_object(pred_output, response, return_meta_data)
1753
  # catch in case of content_filtering failure
1754
  except openai.BadRequestError as e:
1755
- logging.error(f"Error predicting instance {messages}:{e}. Returning empty prediction")
1756
- prediction = [{"top_tokens": [
1757
- {"text": "-", "logprob": 0}
1758
- ]
1759
- }]
1760
- return TextGenerationInferenceOutput(prediction=prediction, input_tokens=0, output_tokens=0)
1761
-
1762
 
1763
  def get_return_object(self, predict_result, response, return_meta_data):
1764
  if return_meta_data:
@@ -1792,9 +1912,9 @@ class AzureOpenAIInferenceEngine(OpenAiInferenceEngine):
1792
  api_version = self.credentials.get(
1793
  "api_version", os.environ.get("OPENAI_API_VERSION", None)
1794
  )
1795
- assert (
1796
- api_version and azure_openapi_host
1797
- ), "Error while trying to run AzureOpenAIInferenceEngine: Missing environment variable param AZURE_OPENAI_HOST or OPENAI_API_VERSION"
1798
  api_url = f"{azure_openapi_host}/openai/deployments/{self.model_name}/chat/completions?api-version={api_version}"
1799
 
1800
  return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
@@ -1821,9 +1941,7 @@ class RITSInferenceEngine(
1821
  label: str = "rits"
1822
  data_classification_policy = ["public", "proprietary"]
1823
 
1824
- model_names_dict = {
1825
- "microsoft/phi-4": "microsoft-phi-4"
1826
- }
1827
 
1828
  def get_default_headers(self):
1829
  return {"RITS_API_KEY": self.credentials["api_key"]}
@@ -1891,7 +2009,7 @@ class TogetherAiInferenceEngine(
1891
  from together import Together
1892
  from together.types.models import ModelType
1893
 
1894
- api_key_env_var_name = "TOGETHER_API_KEY" # pragma: allowlist secret
1895
  api_key = os.environ.get(api_key_env_var_name)
1896
  assert api_key is not None, (
1897
  f"Error while trying to run TogetherAiInferenceEngine."
@@ -1906,9 +2024,9 @@ class TogetherAiInferenceEngine(
1906
  together_model.id: together_model.type for together_model in together_models
1907
  }
1908
  model_type = together_model_id_to_type.get(self.model_name)
1909
- assert (
1910
- model_type is not None
1911
- ), f"Could not find model {self.model_name} in Together AI model list"
1912
  assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], (
1913
  f"Together AI model type {model_type} is not supported; "
1914
  "supported types are 'chat', 'language' and 'code'."
@@ -2087,11 +2205,11 @@ class WMLInferenceEngineBase(
2087
  def verify(self):
2088
  super().verify()
2089
 
2090
- assert (
2091
- self.model_name
2092
- or self.deployment_id
2093
- and not (self.model_name and self.deployment_id)
2094
- ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
2095
 
2096
  # def process_data_before_dump(self, data):
2097
  # if "credentials" in data:
@@ -2110,11 +2228,11 @@ class WMLInferenceEngineBase(
2110
  self._verify_wml_credentials(self.credentials)
2111
  return APIClient(
2112
  credentials=Credentials(
2113
- api_key=self.credentials["api_key"],
2114
- url=self.credentials["url"]
2115
  ),
2116
  project_id=self.credentials.get("project_id", None),
2117
- space_id=self.credentials.get("space_id", None))
 
2118
 
2119
  @staticmethod
2120
  def _read_wml_credentials_from_env() -> CredentialsWML:
@@ -2182,9 +2300,9 @@ class WMLInferenceEngineBase(
2182
  "['url', 'api_key', 'username', 'password']."
2183
  )
2184
 
2185
- assert credentials.get(
2186
- "url"
2187
- ), "'url' is a mandatory key for WML credentials dict."
2188
  assert "space_id" in credentials or "project_id" in credentials, (
2189
  "Either 'space_id' or 'project_id' must be provided "
2190
  "as keys for WML credentials dict."
@@ -2585,7 +2703,9 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2585
  return True
2586
 
2587
  def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]:
2588
- if isinstance(instance["source"], str) and self.check_instance_contains_image(instance):
 
 
2589
  return self._create_messages_from_instance(instance)
2590
 
2591
  messages = super().to_messages(instance)
@@ -2909,7 +3029,7 @@ class VLLMParamsMixin(Artifact):
2909
 
2910
 
2911
  class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
2912
- label="vllm"
2913
 
2914
  def get_engine_id(self):
2915
  return get_model_and_label_id(self.model, self.label)
@@ -3011,7 +3131,6 @@ class LiteLLMInferenceEngine(
3011
  self.inference_type = "litellm"
3012
  from litellm import acompletion
3013
 
3014
-
3015
  self._completion = acompletion
3016
  # Initialize a semaphore to limit concurrency
3017
  self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))
@@ -3032,7 +3151,6 @@ class LiteLLMInferenceEngine(
3032
  response = await self._completion(
3033
  messages=messages,
3034
  max_retries=self.max_retries,
3035
- caching=True,
3036
  drop_params=False,
3037
  **self.credentials,
3038
  **kwargs,
@@ -3123,10 +3241,10 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3123
 
3124
  label: str = "cross_provider"
3125
  provider: Optional[_supported_apis] = None
3126
- provider_specific_args: Optional[Dict[str, Dict[str,str]]] = None
3127
 
3128
  provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
3129
- "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
3130
  "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
3131
  "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
3132
  "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
@@ -3153,7 +3271,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3153
  "llama-3-1-70b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
3154
  "llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
3155
  "llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
3156
- "llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
3157
  },
3158
  "aws": {
3159
  "llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
@@ -3167,7 +3285,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3167
  "llama-3-1-405b-instruct": "llama3.1:405b",
3168
  "llama-3-2-1b-instruct": "llama3.2:1b",
3169
  "llama-3-2-3b-instruct": "llama3.2:3b",
3170
- "llama-3-3-70b-instruct": "llama3.3"
3171
  },
3172
  "bam": {
3173
  "granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
@@ -3242,12 +3360,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3242
  "llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
3243
  },
3244
  "replicate": {
3245
- "granite-20b-code-instruct-8k": "replicate/ibm-granite/granite-20b-code-instruct-8k",
3246
- "granite-3-2b-instruct": "replicate/ibm-granite/granite-3.0-2b-instruct",
3247
- "granite-3-8b-instruct": "replicate/ibm-granite/granite-3.0-8b-instruct",
3248
- "granite-3-1-2b-instruct": "replicate/ibm-granite/granite-3.1-2b-instruct",
3249
  "granite-3-1-8b-instruct": "replicate/ibm-granite/granite-3.1-8b-instruct",
 
 
 
3250
  "granite-8b-code-instruct-128k": "replicate/ibm-granite/granite-8b-code-instruct-128k",
 
3251
  "llama-2-13b": "replicate/meta/llama-2-13b",
3252
  "llama-2-13b-chat": "replicate/meta/llama-2-13b-chat",
3253
  "llama-2-70b": "replicate/meta/llama-2-70b",
@@ -3264,7 +3384,9 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3264
  "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
3265
  },
3266
  }
3267
- provider_model_map["watsonx"] = {k: f"watsonx/{v}" for k,v in provider_model_map["watsonx-sdk"].items()}
 
 
3268
 
3269
  _provider_to_base_class = {
3270
  "watsonx": LiteLLMInferenceEngine,
@@ -3307,7 +3429,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3307
  args["model"] = self.provider_model_map[provider].get(self.model, self.model)
3308
 
3309
  if self.provider_specific_args is not None:
3310
- provider_args = self.provider_specific_args.get(provider)
3311
  if provider_args is not None:
3312
  args.update(provider_args)
3313
 
@@ -3342,6 +3464,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
3342
 
3343
  This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
3344
  """
 
3345
  label = "hf_option_selection"
3346
  model_name: str
3347
  batch_size: int
@@ -3368,10 +3491,8 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
3368
  path,
3369
  )
3370
  self.model = AutoModelForCausalLM.from_pretrained(
3371
- path,
3372
- ).to(
3373
- self.device
3374
- )
3375
  # Set pad_token if it doesn't exist
3376
  if self.tokenizer.pad_token is None:
3377
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
61
  while batch := list(islice(it, n)):
62
  yield batch
63
 
64
+
65
  class StandardAPIParamsMixin(Artifact):
66
  model: str
67
  frequency_penalty: Optional[float] = None
 
158
 
159
  class InferenceEngine(Artifact):
160
  """Abstract base class for inference."""
161
+
162
  cache_batch_size: int = 100
163
  use_cache: bool = True
164
 
 
208
  instance_str = json.dumps(record, sort_keys=True)
209
  return hashlib.md5(instance_str.encode()).hexdigest()
210
 
211
+ def verify_infer_inputs(
212
+ self, dataset: Union[List[Dict[str, Any]], Dataset], return_meta_data: bool
213
+ ):
214
  if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
215
  raise Exception(
216
  "Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
 
240
  if self.use_cache:
241
  number_of_batches = len(dataset) // self.cache_batch_size + 1
242
  result = []
243
+ for batch_index, batch in enumerate(
244
+ batched(dataset, self.cache_batch_size)
245
+ ):
246
  cached_results = []
247
  missing_examples = []
248
  for i, item in enumerate(batch):
249
  cache_key = self._get_cache_key(item)
250
  cached_value = self._cache.get(cache_key)
251
  if cached_value is not None:
252
+ cached_results.append(
253
+ (i, cached_value)
254
+ ) # each element is index in batch, and value
255
  else:
256
+ missing_examples.append(
257
+ (i, item)
258
+ ) # each element is index in batch and example
259
  # infer on missing examples only, without indices
260
 
261
+ logger.info(
262
+ f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
263
+ )
264
+ if len(missing_examples) > 0:
265
+ inferred_results = self._infer(
266
+ [e[1] for e in missing_examples], return_meta_data
267
+ )
268
  # recombine predictions with their batch indices
269
+ inferred_results = list(
270
+ zip([e[0] for e in missing_examples], inferred_results)
271
+ )
272
  # Add missing examples to cache
273
+ for (_, item), (_, prediction) in zip(
274
+ missing_examples, inferred_results
275
+ ):
276
  if prediction is None:
277
  continue
278
  cache_key = self._get_cache_key(item)
279
  self._cache[cache_key] = prediction
280
  else:
281
+ inferred_results = []
282
  # Combine cached and inferred results in original order
283
+ batch_predictions = [
284
+ p[1] for p in sorted(cached_results + inferred_results)
285
+ ]
286
  result.extend(batch_predictions)
287
  else:
288
  result = self._infer(dataset, return_meta_data)
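
The cache handling in the hunk above splits each batch into cached and missing items, infers only the missing ones, and then restores the original order by sorting on the batch index. A standalone sketch of that pattern (the names here are illustrative, not part of the library):

from typing import Any, Callable, Dict, List

def infer_with_cache(
    batch: List[Any],
    cache: Dict[str, Any],
    key_fn: Callable[[Any], str],
    compute: Callable[[List[Any]], List[Any]],
) -> List[Any]:
    cached_results = []    # (index_in_batch, cached_value)
    missing_examples = []  # (index_in_batch, item)
    for i, item in enumerate(batch):
        value = cache.get(key_fn(item))
        if value is not None:
            cached_results.append((i, value))
        else:
            missing_examples.append((i, item))

    inferred_results = []
    if missing_examples:
        predictions = compute([item for _, item in missing_examples])
        inferred_results = list(zip([i for i, _ in missing_examples], predictions))
        for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
            if prediction is not None:
                cache[key_fn(item)] = prediction  # only successful predictions are cached

    # Recombine cached and freshly inferred results in the original batch order.
    return [value for _, value in sorted(cached_results + inferred_results)]
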
 
432
  low_cpu_mem_usage: bool = True
433
  torch_dtype: str = "torch.float16"
434
 
435
+ batch_size: int = 1
436
+
437
  model: Any = InternalField(default=None, name="Inference object")
438
  processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
439
 
 
638
  class HFAutoModelInferenceEngine(HFInferenceEngineBase):
639
  label: str = "hf_auto_model"
640
 
641
+ use_fp16: bool = True
642
+ load_in_8bit: bool = False
643
+
644
+ device_map: Any = None
645
+
646
+ padding: bool = True
647
+ truncation: bool = True
648
+ padding_side: str = "left" # for decoder only models
649
+
650
+ chat_kwargs_dict: dict = {}
651
+
652
  def _init_processor(self):
653
  from transformers import AutoTokenizer
654
 
655
  self.processor = AutoTokenizer.from_pretrained(
656
  pretrained_model_name_or_path=self.model_name,
657
  use_fast=self.use_fast_tokenizer,
 
 
658
  )
659
 
660
+ def _get_model_args(self) -> Dict[str, Any]:
661
+ import torch
662
+ from transformers import BitsAndBytesConfig
663
+
664
+ args = {}
665
+
666
+ if self.load_in_8bit:
667
+ quantization_config = BitsAndBytesConfig(load_in_8bit=self.load_in_8bit)
668
+ args["quantization_config"] = quantization_config
669
+ elif self.use_fp16:
670
+ if self.device == torch.device("mps"):
671
+ args["torch_dtype"] = torch.float16
672
+ else:
673
+ args["torch_dtype"] = torch.bfloat16
674
+
675
+ # We do this because, in some cases, device_map="auto" will offload some weights to the CPU
676
+ # (even though the model might *just* fit on a single GPU) even when a GPU is available, and this will
677
+ # cause an error because the input data is always placed on the GPU
678
+ # if torch.cuda.device_count() > 1:
679
+ # assert self.device == torch.device(0)
680
+ args["device_map"] = "auto"
681
+ # else:
682
+ # if not self.load_in_8bit:
683
+ # args["device"] = self.device
684
+
685
+ return args
686
+
687
  def _init_model(self):
688
  from transformers import (
689
  AutoConfig,
 
697
  else AutoModelForCausalLM
698
  )
699
 
700
+ model_args = self._get_model_args()
701
+
702
  self.model = model_class.from_pretrained(
703
  pretrained_model_name_or_path=self.model_name,
704
  trust_remote_code=True,
705
+ **model_args,
 
706
  )
707
  if self.device_map is None:
708
  self.model.to(self.device)
 
710
  def prepare_inputs(self, data: Iterable) -> Mapping:
711
  if isinstance(data[0], list):
712
  data = self.processor.apply_chat_template(
713
+ data,
714
+ tokenize=False,
715
+ add_generation_prompt=True,
716
+ **self.chat_kwargs_dict,
717
  )
718
+
719
+ if self.processor.pad_token is None:
720
+ self.processor.pad_token_id = self.model.config.eos_token_id[0]
721
+
722
  return self.processor(
723
  data,
 
 
724
  return_tensors="pt",
725
+ padding=self.padding,
726
+ truncation=self.truncation,
727
+ padding_side=self.padding_side,
728
  ).to(self.device or self.device_map)
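
For chat-formatted inputs, `prepare_inputs` above leans on standard `transformers` tokenizer APIs. A hedged sketch of the same flow outside the engine (example model; passing `padding_side` directly to the tokenizer call assumes a recent transformers release, and the diff itself falls back to `eos_token_id[0]` for the pad token):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

chats = [
    [{"role": "user", "content": "Write a haiku about rain."}],
    [{"role": "user", "content": "Name three prime numbers."}],
]

# Render each chat to a plain prompt string, appending the generation prompt.
texts = tokenizer.apply_chat_template(chats, tokenize=False, add_generation_prompt=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # simple fallback for this sketch

batch = tokenizer(
    texts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    padding_side="left",  # left padding so generation continues right after each prompt
)
print(batch.input_ids.shape)
```
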
729
 
730
  def _infer_fn(
 
733
  return_meta_data: bool,
734
  return_logprobs: bool,
735
  ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
736
+ """Performs inference on the dataset in batches.
 
 
 
 
 
 
 
737
 
738
+ Args:
739
+ dataset: A list of dictionaries or a Dataset object containing the input data.
740
+ Each item should have a "source" key.
741
+ return_meta_data: Whether to include metadata in the output.
742
+ return_logprobs: Whether to return log probabilities along with the output.
743
 
744
+ Returns:
745
+ A list of outputs, which can be strings, dictionaries (if metadata is returned),
746
+ or TextGenerationInferenceOutput objects (if logprobs are returned).
747
+ """
748
+ all_final_outputs = [] # List to store results from all batches
749
 
750
+ for i in tqdm(
751
+ range(0, len(dataset), self.batch_size),
752
+ desc=f"Running inference in batches of {self.batch_size}",
753
+ ):
754
+ # Get the current batch
755
+ batch_data = dataset[i : i + self.batch_size]
756
+ batch_sources = [instance["source"] for instance in batch_data]
757
+
758
+ # --- Process the current batch ---
759
+ # 1. Tokenize inputs for the batch
760
+ tokenized_inputs = self.prepare_inputs(batch_sources)
761
+
762
+ # 2. Determine input length (handle encoder-decoder models)
763
+ input_length = (
764
+ 1
765
+ if self.model.config.is_encoder_decoder
766
+ else tokenized_inputs.input_ids.shape[1]
767
+ )
768
 
769
+ # 3. Make predictions for the batch
770
+ predictions = self.make_predictions(tokenized_inputs)
771
+ sequences = predictions.sequences # Sequences for the current batch
772
+
773
+ # 4. Decode tokens for the batch
774
+ string_tokens_batch = [
775
+ self.decode_tokens(sequence, input_length) for sequence in sequences
776
+ ]
777
+
778
+ # 5. Calculate logprobs or create strings for the batch
779
+ final_outputs_batch = (
780
+ self.get_logprobs(predictions, string_tokens_batch)
781
+ if return_logprobs
782
+ else [
783
+ self.create_string_from_tokens(strings)
784
+ for strings in string_tokens_batch
785
+ ]
786
  )
787
+
788
+ # 6. Create return objects for the batch
789
+ batch_results = [
790
+ self.get_return_object(
791
+ output=final_outputs_batch[
792
+ j
793
+ ], # Output for the j-th item in the batch
794
+ output_tokens=len(string_tokens_batch[j]),
795
+ inp=batch_data[j]["source"], # Original input for the j-th item
796
+ inp_tokens=len(tokenized_inputs.encodings[j].tokens)
797
+ if tokenized_inputs.encodings is not None
798
+ else None,
799
+ return_meta_data=return_meta_data,
800
+ )
801
+ for j in range(
802
+ len(sequences)
803
+ ) # Iterate through items in the current batch
804
+ ]
805
+
806
+ # Add results from this batch to the overall list
807
+ all_final_outputs.extend(batch_results)
808
+ # --- End of batch processing ---
809
+
810
+ return all_final_outputs
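
With the new `batch_size` field, the batched `_infer_fn` above is what ultimately runs behind the engine's public `infer` call. A hedged usage sketch (model name and `max_new_tokens` are examples of the usual HF generation setup, not taken from this diff; instances only need a "source" key here):

```python
from unitxt.inference import HFAutoModelInferenceEngine

engine = HFAutoModelInferenceEngine(
    model_name="HuggingFaceTB/SmolLM2-135M-Instruct",  # example local/HF-hub model
    max_new_tokens=32,
    batch_size=4,  # new in this diff: inputs are tokenized and generated 4 at a time
)

dataset = [
    {"source": "Translate to French: good morning"},
    {"source": "What is the capital of Italy?"},
]
predictions = engine.infer(dataset)
print(predictions)
```
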
811
 
812
  def _infer(
813
  self,
 
991
 
992
  model_class = (
993
  AutoPeftModelForSeq2SeqLM
994
+ if AutoConfig.from_pretrained(self.peft_config.base_model_name_or_path).is_encoder_decoder
995
  else AutoPeftModelForCausalLM
996
  )
997
+ path = self.model_name
998
  if settings.hf_offline_models_path is not None:
999
  path = os.path.join(settings.hf_offline_models_path, path)
1000
 
 
1005
  low_cpu_mem_usage=self.low_cpu_mem_usage,
1006
  torch_dtype=self._get_torch_dtype(),
1007
  )
1008
+ self.model = self.model.to(dtype=self._get_torch_dtype()) # Make sure that base model and adapter use same dtype
1009
  if self.device_map is None:
1010
  self.model.to(self.device)
1011
 
 
1056
  except Exception:
1057
  try:
1058
  from peft import PeftConfig
1059
+
1060
  # If full model loading fails, try loading as a PEFT adapter
1061
  peft_config = PeftConfig.from_pretrained(path)
1062
 
1063
  if not peft_config.base_model_name_or_path:
1064
+ raise ValueError(
1065
+ f"Base model name not found in PEFT config for {path}"
1066
+ )
1067
 
1068
  # Load the base model's config
1069
+ config = AutoConfig.from_pretrained(
1070
+ peft_config.base_model_name_or_path, trust_remote_code=True
1071
+ )
1072
  except Exception as err2:
1073
+ raise ValueError(
1074
+ f"Could not determine model type for: {path}"
1075
+ ) from err2
1076
 
1077
+ self.task = (
1078
+ "text2text-generation" if config.is_encoder_decoder else "text-generation"
1079
+ )
1080
 
1081
  def _get_model_args(self) -> Dict[str, Any]:
1082
  import torch
 
1421
  for option in instance["task_data"]["options"]
1422
  ]
1423
 
1424
+ dataset_with_options_logprobs: List[List[Dict[str, Union[float, str]]]] = (
1425
+ self.get_options_log_probs(dataset_with_options)
1426
+ )
1427
 
1428
  dataset_iterator = iter(dataset_with_options_logprobs)
1429
 
 
1496
  def _get_credentials():
1497
  from genai import Credentials
1498
 
1499
+ api_key_env_var_name = "GENAI_KEY" # pragma: allowlist secret
1500
  api_key = os.environ.get(api_key_env_var_name)
1501
 
1502
  assert api_key is not None, (
 
1582
  predict_results = []
1583
  for prediction in predictions:
1584
  result: TextGenerationResult = prediction.results[0]
1585
+ assert isinstance(result.generated_tokens, list), (
1586
+ "result.generated_tokens should be a list"
1587
+ )
1588
 
1589
  predict_result = []
1590
  for base_token in result.generated_tokens:
 
1829
  @run_with_imap
1830
  def _get_chat_completion(self, instance, return_meta_data):
1831
  import openai
1832
+
1833
  messages = self.to_messages(instance)
1834
  try:
1835
  response = self.client.chat.completions.create(
 
1841
  return self.get_return_object(prediction, response, return_meta_data)
1842
  # catch in case of content_filtering failure
1843
  except openai.BadRequestError as e:
1844
+ logging.error(
1845
+ f"Error predicting instance {messages}:{e}. Returning empty prediction"
1846
+ )
1847
+ return TextGenerationInferenceOutput(
1848
+ prediction="-", input_tokens=0, output_tokens=0
1849
+ )
1850
 
1851
  @run_with_imap
1852
  def _get_logprobs(self, instance, return_meta_data):
1853
  import openai
1854
+
1855
  messages = self.to_messages(instance)
1856
  try:
1857
  response = self.client.chat.completions.create(
 
1872
  return self.get_return_object(pred_output, response, return_meta_data)
1873
  # catch in case of content_filtering failure
1874
  except openai.BadRequestError as e:
1875
+ logging.error(
1876
+ f"Error predicting instance {messages}:{e}. Returning empty prediction"
1877
+ )
1878
+ prediction = [{"top_tokens": [{"text": "-", "logprob": 0}]}]
1879
+ return TextGenerationInferenceOutput(
1880
+ prediction=prediction, input_tokens=0, output_tokens=0
1881
+ )
1882
 
1883
  def get_return_object(self, predict_result, response, return_meta_data):
1884
  if return_meta_data:
 
1912
  api_version = self.credentials.get(
1913
  "api_version", os.environ.get("OPENAI_API_VERSION", None)
1914
  )
1915
+ assert api_version and azure_openapi_host, (
1916
+ "Error while trying to run AzureOpenAIInferenceEngine: Missing environment variable param AZURE_OPENAI_HOST or OPENAI_API_VERSION"
1917
+ )
1918
  api_url = f"{azure_openapi_host}/openai/deployments/{self.model_name}/chat/completions?api-version={api_version}"
1919
 
1920
  return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
 
1941
  label: str = "rits"
1942
  data_classification_policy = ["public", "proprietary"]
1943
 
1944
+ model_names_dict = {"microsoft/phi-4": "microsoft-phi-4"}
 
 
1945
 
1946
  def get_default_headers(self):
1947
  return {"RITS_API_KEY": self.credentials["api_key"]}
 
2009
  from together import Together
2010
  from together.types.models import ModelType
2011
 
2012
+ api_key_env_var_name = "TOGETHER_API_KEY" # pragma: allowlist secret
2013
  api_key = os.environ.get(api_key_env_var_name)
2014
  assert api_key is not None, (
2015
  f"Error while trying to run TogetherAiInferenceEngine."
 
2024
  together_model.id: together_model.type for together_model in together_models
2025
  }
2026
  model_type = together_model_id_to_type.get(self.model_name)
2027
+ assert model_type is not None, (
2028
+ f"Could not find model {self.model_name} in Together AI model list"
2029
+ )
2030
  assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], (
2031
  f"Together AI model type {model_type} is not supported; "
2032
  "supported types are 'chat', 'language' and 'code'."
 
2205
  def verify(self):
2206
  super().verify()
2207
 
2208
+ assert (self.model_name is not None) != (
2209
+ self.deployment_id is not None
2210
+ ), (
2211
+ "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
2212
+ )
2213
 
2214
  # def process_data_before_dump(self, data):
2215
  # if "credentials" in data:
 
2228
  self._verify_wml_credentials(self.credentials)
2229
  return APIClient(
2230
  credentials=Credentials(
2231
+ api_key=self.credentials["api_key"], url=self.credentials["url"]
 
2232
  ),
2233
  project_id=self.credentials.get("project_id", None),
2234
+ space_id=self.credentials.get("space_id", None),
2235
+ )
2236
 
2237
  @staticmethod
2238
  def _read_wml_credentials_from_env() -> CredentialsWML:
 
2300
  "['url', 'api_key', 'username', 'password']."
2301
  )
2302
 
2303
+ assert credentials.get("url"), (
2304
+ "'url' is a mandatory key for WML credentials dict."
2305
+ )
2306
  assert "space_id" in credentials or "project_id" in credentials, (
2307
  "Either 'space_id' or 'project_id' must be provided "
2308
  "as keys for WML credentials dict."
 
2703
  return True
2704
 
2705
  def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]:
2706
+ if isinstance(instance["source"], str) and self.check_instance_contains_image(
2707
+ instance
2708
+ ):
2709
  return self._create_messages_from_instance(instance)
2710
 
2711
  messages = super().to_messages(instance)
 
3029
 
3030
 
3031
  class VLLMInferenceEngine(InferenceEngine, PackageRequirementsMixin, VLLMParamsMixin):
3032
+ label = "vllm"
3033
 
3034
  def get_engine_id(self):
3035
  return get_model_and_label_id(self.model, self.label)
 
3131
  self.inference_type = "litellm"
3132
  from litellm import acompletion
3133
 
 
3134
  self._completion = acompletion
3135
  # Initialize a semaphore to limit concurrency
3136
  self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))
 
3151
  response = await self._completion(
3152
  messages=messages,
3153
  max_retries=self.max_retries,
 
3154
  drop_params=False,
3155
  **self.credentials,
3156
  **kwargs,
 
3241
 
3242
  label: str = "cross_provider"
3243
  provider: Optional[_supported_apis] = None
3244
+ provider_specific_args: Optional[Dict[str, Dict[str, str]]] = None
3245
 
3246
  provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
3247
+ "watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
3248
  "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
3249
  "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
3250
  "granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
 
3271
  "llama-3-1-70b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
3272
  "llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
3273
  "llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
3274
+ "llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
3275
  },
3276
  "aws": {
3277
  "llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
 
3285
  "llama-3-1-405b-instruct": "llama3.1:405b",
3286
  "llama-3-2-1b-instruct": "llama3.2:1b",
3287
  "llama-3-2-3b-instruct": "llama3.2:3b",
3288
+ "llama-3-3-70b-instruct": "llama3.3",
3289
  },
3290
  "bam": {
3291
  "granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
 
3360
  "llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
3361
  },
3362
  "replicate": {
3363
+ "granite-3-2-8b-instruct": "replicate/ibm-granite/granite-3.2-8b-instruct",
3364
+ "granite-vision-3-2-2b": "replicate/ibm-granite/granite-vision-3.2-2b",
 
 
3365
  "granite-3-1-8b-instruct": "replicate/ibm-granite/granite-3.1-8b-instruct",
3366
+ "granite-3-1-2b-instruct": "replicate/ibm-granite/granite-3.1-2b-instruct",
3367
+ "granite-3-8b-instruct": "replicate/ibm-granite/granite-3.0-8b-instruct",
3368
+ "granite-3-2b-instruct": "replicate/ibm-granite/granite-3.0-2b-instruct",
3369
  "granite-8b-code-instruct-128k": "replicate/ibm-granite/granite-8b-code-instruct-128k",
3370
+ "granite-20b-code-instruct-8k": "replicate/ibm-granite/granite-20b-code-instruct-8k",
3371
  "llama-2-13b": "replicate/meta/llama-2-13b",
3372
  "llama-2-13b-chat": "replicate/meta/llama-2-13b-chat",
3373
  "llama-2-70b": "replicate/meta/llama-2-70b",
 
3384
  "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
3385
  },
3386
  }
3387
+ provider_model_map["watsonx"] = {
3388
+ k: f"watsonx/{v}" for k, v in provider_model_map["watsonx-sdk"].items()
3389
+ }
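
The dict comprehension above derives the LiteLLM-based `watsonx` map from the `watsonx-sdk` entries by prefixing `watsonx/`. A standalone toy version of the same derivation (entries copied from the map above):

```python
watsonx_sdk = {
    "granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
    "granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
}
watsonx = {name: f"watsonx/{path}" for name, path in watsonx_sdk.items()}
print(watsonx["granite-3-2-8b-instruct"])  # -> watsonx/ibm/granite-3-2-8b-instruct
```
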
3390
 
3391
  _provider_to_base_class = {
3392
  "watsonx": LiteLLMInferenceEngine,
 
3429
  args["model"] = self.provider_model_map[provider].get(self.model, self.model)
3430
 
3431
  if self.provider_specific_args is not None:
3432
+ provider_args = self.provider_specific_args.get(provider)
3433
  if provider_args is not None:
3434
  args.update(provider_args)
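
Combined with the model map, the new `provider_specific_args` field lets one engine definition carry per-provider extras that are merged in only when that provider is selected. A hedged usage sketch (model and provider names come from the map above; whether any extra arguments are needed depends on the backend):

```python
from unitxt.inference import CrossProviderInferenceEngine

engine = CrossProviderInferenceEngine(
    model="llama-3-3-70b-instruct",  # short alias resolved via provider_model_map
    provider="watsonx",
    # Optional, per this diff: kwargs forwarded only to the named provider, e.g.
    # provider_specific_args={"watsonx": {...}},
)
```
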
3435
 
 
3464
 
3465
  This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
3466
  """
3467
+
3468
  label = "hf_option_selection"
3469
  model_name: str
3470
  batch_size: int
 
3491
  path,
3492
  )
3493
  self.model = AutoModelForCausalLM.from_pretrained(
3494
+ path,
3495
+ ).to(self.device)
 
 
3496
  # Set pad_token if it doesn't exist
3497
  if self.tokenizer.pad_token is None:
3498
  self.tokenizer.pad_token = self.tokenizer.eos_token
llm_as_judge.py CHANGED
@@ -240,7 +240,7 @@ class LLMJudgeDirect(LLMJudge):
240
  main_score = "llm_as_judge"
241
  """The primary score name used in the results. By default, it will take the value of the criteria name (if only one criteria is being used for evaluation) or "llm_as_judge" otherwise."""
242
  reduction_map = {"mean": ["llm_as_judge"]}
243
- """A mapping used for score aggregation. By default, it will take the value of `{'mean': [<default_main_score_name>]}`."""
244
 
245
  def prepare(self):
246
  super().prepare()
@@ -420,7 +420,7 @@ class LLMJudgeDirect(LLMJudge):
420
  This method evaluates the quality of the predictions by calculating scores for each instance based on a criterion.
421
 
422
  Returns:
423
- -------
424
  List[Dict]
425
  A list of dictionaries containing the evaluation results for each instance. The results include the computed scores for each prediction. Each result will have the `score_name` as a prefix, which may be the criterion name if only one is used, or "llm_as_judge" if several criteria were used.
426
 
@@ -647,7 +647,7 @@ class LLMJudgePairwise(LLMJudge):
647
  main_score = "1_winrate"
648
  """The main score metric for pairwise evaluation. By default, its value is `1_winrate`, and will take the value of the winrate of the first system."""
649
  reduction_map = {"mean": ["score"]}
650
- """A mapping specifying how scores should be reduced. By default, it will be `{'main': ['score']}`"""
651
 
652
  def prepare(self):
653
  """Prepares the pairwise comparison by initializing the necessary templates and tasks. These tasks will be used to assess, summarize, and select options from candidate responses."""
@@ -937,7 +937,7 @@ class LLMJudgePairwise(LLMJudge):
937
  task_data (List[Dict[str, str]]): Task data to be used for evaluation.
938
 
939
  Returns:
940
- -------
941
  List[Dict[str,Dict]]
942
  The results of the evaluation, including winrate, ranking, and other metrics.
943
 
 
240
  main_score = "llm_as_judge"
241
  """The primary score name used in the results. By default, it will take the value of the criteria name (if only one criteria is being used for evaluation) or "llm_as_judge" otherwise."""
242
  reduction_map = {"mean": ["llm_as_judge"]}
243
+ """A mapping used for score aggregation. By default, it will take the value of ``{'mean': [<default_main_score_name>]}`` ."""
244
 
245
  def prepare(self):
246
  super().prepare()
 
420
  This method evaluates the quality of the predictions by calculating scores for each instance based on a criterion.
421
 
422
  Returns:
423
+ --------
424
  List[Dict]
425
  A list of dictionaries containing the evaluation results for each instance. The results include the computed scores for each prediction. Each result will have the `score_name` as a prefix, which may be the criterion name if only one is used, or "llm_as_judge" if several criteria were used.
426
 
 
647
  main_score = "1_winrate"
648
  """The main score metric for pairwise evaluation. By default, its value is `1_winrate`, and will take the value of the winrate of the first system."""
649
  reduction_map = {"mean": ["score"]}
650
+ """A mapping specifying how scores should be reduced. By default, it will be ``{'main': ['score']}`` ."""
651
 
652
  def prepare(self):
653
  """Prepares the pairwise comparison by initializing the necessary templates and tasks. These tasks will be used to assess, summarize, and select options from candidate responses."""
 
937
  task_data (List[Dict[str, str]]): Task data to be used for evaluation.
938
 
939
  Returns:
940
+ --------
941
  List[Dict[str,Dict]]
942
  The results of the evaluation, including winrate, ranking, and other metrics.
943
 
metric.py CHANGED
@@ -18,6 +18,7 @@ from .dialog_operators import __file__ as _
18
  from .dict_utils import __file__ as _
19
  from .error_utils import __file__ as _
20
  from .eval_utils import __file__ as _
 
21
  from .file_utils import __file__ as _
22
  from .formats import __file__ as _
23
  from .fusion import __file__ as _
 
18
  from .dict_utils import __file__ as _
19
  from .error_utils import __file__ as _
20
  from .eval_utils import __file__ as _
21
+ from .evaluate_cli import __file__ as _
22
  from .file_utils import __file__ as _
23
  from .formats import __file__ as _
24
  from .fusion import __file__ as _
metrics.py CHANGED
@@ -71,6 +71,7 @@ settings = get_settings()
71
 
72
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
73
 
 
74
  @retry_connection_with_exponential_backoff(backoff_factor=2)
75
  def hf_evaluate_load(path: str, *args, **kwargs):
76
  if settings.hf_offline_metrics_path is not None:
@@ -792,6 +793,7 @@ class MetricWithConfidenceInterval(Metric):
792
  n_resamples: int = None
793
  confidence_level: float = 0.95
794
  ci_scores: List[str] = None
 
795
 
796
  @staticmethod
797
  def new_random_generator():
@@ -907,6 +909,7 @@ class MetricWithConfidenceInterval(Metric):
907
  n_resamples=self.n_resamples,
908
  confidence_level=self.confidence_level,
909
  random_state=self.new_random_generator(),
 
910
  ).confidence_interval
911
  full_score_name = ci_score_prefix + score_name
912
  result[f"{full_score_name}_ci_low"] = ci.low
@@ -1007,6 +1010,7 @@ class MetricWithConfidenceInterval(Metric):
1007
  n_resamples=self.n_resamples,
1008
  confidence_level=self.confidence_level,
1009
  random_state=random_gen,
 
1010
  ).confidence_interval
1011
  result["score_ci_low"] = float(ci.low)
1012
  result["score_ci_high"] = float(ci.high)
@@ -1193,9 +1197,9 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1193
  )
1194
 
1195
  for reduction, fields in self.reduction_map.items():
1196
- assert (
1197
- reduction in self.implemented_reductions
1198
- ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
1199
 
1200
  if reduction == "mean":
1201
  for field_name in fields:
@@ -1464,12 +1468,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1464
  def _validate_group_mean_task_data(self, instance):
1465
  # instances need to all have task_data field with field group_id
1466
  assert "task_data" in instance, "each instance must have an task_data field"
1467
- assert isinstance(
1468
- instance["task_data"], dict
1469
- ), "each instance must have an task_data field that is a dict"
1470
- assert (
1471
- "group_id" in instance["task_data"]
1472
- ), "each instance task_data dict must have a key group_id"
1473
 
1474
  def _validate_group_mean_reduction(self):
1475
  """Ensure that group_mean reduction_map is properly formatted.
@@ -1522,30 +1526,30 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1522
  2 'Why are ants eating my food?' 'original'
1523
  """
1524
  # validate the reduction_map
1525
- assert (
1526
- "group_mean" in self.reduction_map
1527
- ), "reduction_map must have a 'group_mean' key"
1528
  fields = self.reduction_map["group_mean"]
1529
  # for group_mean, expects a dict
1530
  assert isinstance(fields, dict)
1531
- assert (
1532
- "agg_func" in fields
1533
- ), "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
1534
- assert isinstance(
1535
- fields["agg_func"], list
1536
- ), "fields['agg_func'] should be a list"
1537
- assert (
1538
- len(fields["agg_func"]) == 3
1539
- ), "fields['agg_func'] should be a 3-element list"
1540
- assert isinstance(
1541
- fields["agg_func"][0], str
1542
- ), "first item in fields['agg_func'] should be a string name of a function"
1543
- assert callable(
1544
- fields["agg_func"][1]
1545
- ), "second item in fields['agg_func'] should be a callable function"
1546
- assert isinstance(
1547
- fields["agg_func"][2], bool
1548
- ), "third item in fields['agg_func'] should be a boolean value"
1549
  if "score_fields" in fields:
1550
  assert isinstance(fields["score_fields"], list)
1551
 
@@ -1553,9 +1557,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1553
  instance_scores = self.compute_instance_scores(stream)
1554
  global_score = {"num_of_instances": len(instance_scores)}
1555
  for reduction_type, reduction_params in self.reduction_map.items():
1556
- assert (
1557
- reduction_type in self.implemented_reductions
1558
- ), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
1559
 
1560
  field_name_full_prefix = ""
1561
  # used for passing to the bootstrapping, depends on whether the groups are fixed or not
@@ -1653,7 +1657,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1653
  assert (
1654
  "task_data" in instance
1655
  and self.subgroup_column in instance["task_data"]
1656
- ), f"each instance task_data dict must have a key {self.subgroup_column}"
 
 
1657
 
1658
  task_data = instance["task_data"] if "task_data" in instance else {}
1659
 
@@ -2249,15 +2255,15 @@ class MetricPipeline(MultiStreamOperator, Metric):
2249
 
2250
  def verify(self):
2251
  super().verify()
2252
- assert (
2253
- self.metric is not None
2254
- ), f"'metric' is not set in {self.get_metric_name()}"
2255
- assert (
2256
- self.main_score is not None
2257
- ), f"'main_score' is not set in {self.get_metric_name()}"
2258
- assert isinstance(
2259
- self.metric, Metric
2260
- ), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
2261
  if self.postpreprocess_steps is not None:
2262
  depr_message = "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
2263
  warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
@@ -2278,9 +2284,9 @@ class MetricPipeline(MultiStreamOperator, Metric):
2278
  and isinstance(self.postprocess_steps, list)
2279
  and len(self.postprocess_steps) > 0
2280
  )
2281
- assert not (
2282
- has_postpreprocess and has_postprocess
2283
- ), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
2284
  if has_postpreprocess:
2285
  self.postprocess_steps = self.postpreprocess_steps
2286
  self.prepare_score = SequentialOperator(
@@ -2357,10 +2363,14 @@ class HuggingfaceMetric(GlobalMetric):
2357
 
2358
  assert self.hf_additional_input_fields is None or isoftype(
2359
  self.hf_additional_input_fields, List[str]
2360
- ), f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."
 
 
2361
  assert self.hf_additional_input_fields_pass_one_value is None or isoftype(
2362
  self.hf_additional_input_fields_pass_one_value, List[str]
2363
- ), f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."
 
 
2364
 
2365
  return super().verify()
2366
 
@@ -2377,25 +2387,25 @@ class HuggingfaceMetric(GlobalMetric):
2377
  ) -> dict:
2378
  passed_task_data = {}
2379
  for additional_input_field in self.hf_additional_input_fields:
2380
- assert (
2381
- additional_input_field in task_data[0]
2382
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2383
  passed_task_data[additional_input_field] = [
2384
  additional_input[additional_input_field]
2385
  for additional_input in task_data
2386
  ]
2387
  for additional_input_field in self.hf_additional_input_fields_pass_one_value:
2388
- assert (
2389
- additional_input_field in task_data[0]
2390
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2391
 
2392
  values = {
2393
  additional_input[additional_input_field]
2394
  for additional_input in task_data
2395
  }
2396
- assert (
2397
- len(values) == 1
2398
- ), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
2399
 
2400
  passed_task_data[additional_input_field] = next(iter(values))
2401
 
@@ -2410,22 +2420,22 @@ class HuggingfaceMetric(GlobalMetric):
2410
  result[self.main_score] = float(result[self.hf_main_score])
2411
  del result[self.hf_main_score]
2412
  if self.scale != 1.0:
2413
- assert (
2414
- self.scaled_fields is not None
2415
- ), f"Scaling factor was set to {self.scale}, but no fields specified"
2416
  for key in self.scaled_fields:
2417
- assert (
2418
- key in result
2419
- ), f"Trying to scale field '{key}' which is not in results of metrics: {result}"
2420
  if isinstance(result[key], list):
2421
- assert all(
2422
- isinstance(v, float) for v in result[key]
2423
- ), "Not all scaled field '{key}' values are floats: {result[key]}"
2424
  result[key] = [v / self.scale for v in result[key]]
2425
  else:
2426
- assert isinstance(
2427
- result[key], float
2428
- ), "Scaled field '{key}' is not float: {result[key]}"
2429
  result[key] /= self.scale
2430
  if self.main_score in result:
2431
  result[self.main_score] = float(result[self.main_score])
@@ -2452,9 +2462,9 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
2452
  ) -> List[Dict[str, Any]]:
2453
  passed_task_data = {}
2454
  for additional_input_field in self.hf_additional_input_fields:
2455
- assert (
2456
- additional_input_field in task_data[0]
2457
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2458
  passed_task_data[additional_input_field] = [
2459
  additional_input[additional_input_field]
2460
  for additional_input in task_data
@@ -2791,9 +2801,9 @@ class FinQAEval(InstanceMetric):
2791
  response = requests.get(url)
2792
  response.raise_for_status()
2793
  content = response.content
2794
- assert (
2795
- hashlib.md5(content).hexdigest() == hash_of_script
2796
- ), f'URL ("{url}") is different than expected. Make sure you added the right one.'
2797
 
2798
  with open(local_path, "wb") as file:
2799
  file.write(content)
@@ -2925,9 +2935,9 @@ class F1MultiLabel(GlobalMetric, PackageRequirementsMixin):
2925
  labels=labels_param,
2926
  )
2927
  if isinstance(result[self.metric], numpy.ndarray):
2928
- assert len(result[self.metric]) == len(
2929
- labels
2930
- ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
2931
  final_result = {self.main_score: nan_mean(result[self.metric])}
2932
  for i, label in enumerate(labels):
2933
  final_result[self.metric + "_" + label] = result[self.metric][i]
@@ -3414,7 +3424,6 @@ class CustomF1(GlobalMetric):
3414
 
3415
 
3416
  class KeyValueExtraction(GlobalMetric):
3417
-
3418
  prediction_type = Dict[str, str]
3419
  metric: Metric
3420
  single_reference_per_prediction = True
@@ -3978,9 +3987,9 @@ class LlamaIndexLLMMetric(InstanceMetric):
3978
  prediction_type = str
3979
  reduction_map: Dict[str, List[str]] = None
3980
  openai_models: List[str] = ["gpt-3.5-turbo"]
3981
- anthropic_models: List[str] = (
3982
- []
3983
- ) # this is here for the sake of documentation for future models
3984
  mock_models: List[str] = ["mock"]
3985
  external_api_models = openai_models + anthropic_models
3986
  data_classification_policy = ["public"]
@@ -4819,12 +4828,12 @@ def validate_subgroup_types(
4819
  for subgroup_name, score_list in subgroup_scores_dict.items()
4820
  }
4821
  )
4822
- assert isinstance(
4823
- control_subgroup_types, list
4824
- ), "control_subgroup_types must be a list"
4825
- assert isinstance(
4826
- comparison_subgroup_types, list
4827
- ), "comparison_subgroup_types must be a list"
4828
  # make sure each list is unique, so that labels aren't double-counted
4829
  control_subgroup_types = list(set(control_subgroup_types))
4830
  comparison_subgroup_types = list(set(comparison_subgroup_types))
@@ -4979,9 +4988,9 @@ def normalized_cohens_h(
4979
 
4980
  # requires scores to be in [0,1]
4981
  for subgroup_name, score_list in subgroup_scores_dict.items():
4982
- assert all(
4983
- 0 <= score <= 1 for score in score_list
4984
- ), f"all {subgroup_name} scores must be in [0,1]"
4985
 
4986
  # combine all scores from each label (if there are more than 1 in each group) into a list
4987
  group_scores_list = [
@@ -5967,9 +5976,9 @@ class RandomForestMetricsEnsemble(MetricsEnsemble):
5967
  return json.load(file)
5968
 
5969
  def ensemble(self, instance):
5970
- assert (
5971
- self.weights is not None
5972
- ), "RandomForestMetricsEnsemble must set self.weights before it can be used"
5973
  ensemble_model = self.decode_forest(self.weights)
5974
 
5975
  prediction_lst = []
@@ -6378,7 +6387,7 @@ class SQLExecutionAccuracy(InstanceMetric):
6378
  ]
6379
 
6380
  prediction_type = "Any" # string representation is compared
6381
- sql_timeout = 100.0
6382
 
6383
  _requirements_list = ["sqlglot", "func_timeout"]
6384
 
@@ -6445,6 +6454,7 @@ class SQLExecutionAccuracy(InstanceMetric):
6445
 
6446
  Comparison is column order independent, and could optionally be row order independent.
6447
  We interpret "subset" as follows:
 
6448
  - For each row in df1, there must be a matching (or superset) row in df2, i.e. the set of values
6449
  in the df1 row is a subset of the set of values in that df2 row. Then do the same check in reverse.
6450
  - If either condition (df1 is subset of df2 OR df2 is subset of df1) is satisfied, return True.
@@ -6458,6 +6468,7 @@ class SQLExecutionAccuracy(InstanceMetric):
6458
 
6459
  Returns:
6460
  bool: True if df1 is a subset of df2 or vice versa, based on the specified row-order condition.
 
6461
  """
6462
  df1_array = df1.values.astype(str)
6463
  df2_array = df2.values.astype(str)
 
71
 
72
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
73
 
74
+
75
  @retry_connection_with_exponential_backoff(backoff_factor=2)
76
  def hf_evaluate_load(path: str, *args, **kwargs):
77
  if settings.hf_offline_metrics_path is not None:
 
793
  n_resamples: int = None
794
  confidence_level: float = 0.95
795
  ci_scores: List[str] = None
796
+ ci_method: str = "BCa"
797
 
798
  @staticmethod
799
  def new_random_generator():
 
909
  n_resamples=self.n_resamples,
910
  confidence_level=self.confidence_level,
911
  random_state=self.new_random_generator(),
912
+ method=self.ci_method,
913
  ).confidence_interval
914
  full_score_name = ci_score_prefix + score_name
915
  result[f"{full_score_name}_ci_low"] = ci.low
 
1010
  n_resamples=self.n_resamples,
1011
  confidence_level=self.confidence_level,
1012
  random_state=random_gen,
1013
+ method=self.ci_method,
1014
  ).confidence_interval
1015
  result["score_ci_low"] = float(ci.low)
1016
  result["score_ci_high"] = float(ci.high)
 
1197
  )
1198
 
1199
  for reduction, fields in self.reduction_map.items():
1200
+ assert reduction in self.implemented_reductions, (
1201
+ f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
1202
+ )
1203
 
1204
  if reduction == "mean":
1205
  for field_name in fields:
 
1468
  def _validate_group_mean_task_data(self, instance):
1469
  # instances need to all have task_data field with field group_id
1470
  assert "task_data" in instance, "each instance must have an task_data field"
1471
+ assert isinstance(instance["task_data"], dict), (
1472
+ "each instance must have an task_data field that is a dict"
1473
+ )
1474
+ assert "group_id" in instance["task_data"], (
1475
+ "each instance task_data dict must have a key group_id"
1476
+ )
1477
 
1478
  def _validate_group_mean_reduction(self):
1479
  """Ensure that group_mean reduction_map is properly formatted.
 
1526
  2 'Why are ants eating my food?' 'original'
1527
  """
1528
  # validate the reduction_map
1529
+ assert "group_mean" in self.reduction_map, (
1530
+ "reduction_map must have a 'group_mean' key"
1531
+ )
1532
  fields = self.reduction_map["group_mean"]
1533
  # for group_mean, expects a dict
1534
  assert isinstance(fields, dict)
1535
+ assert "agg_func" in fields, (
1536
+ "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
1537
+ )
1538
+ assert isinstance(fields["agg_func"], list), (
1539
+ "fields['agg_func'] should be a list"
1540
+ )
1541
+ assert len(fields["agg_func"]) == 3, (
1542
+ "fields['agg_func'] should be a 3-element list"
1543
+ )
1544
+ assert isinstance(fields["agg_func"][0], str), (
1545
+ "first item in fields['agg_func'] should be a string name of a function"
1546
+ )
1547
+ assert callable(fields["agg_func"][1]), (
1548
+ "second item in fields['agg_func'] should be a callable function"
1549
+ )
1550
+ assert isinstance(fields["agg_func"][2], bool), (
1551
+ "third item in fields['agg_func'] should be a boolean value"
1552
+ )
1553
  if "score_fields" in fields:
1554
  assert isinstance(fields["score_fields"], list)
1555
 
 
1557
  instance_scores = self.compute_instance_scores(stream)
1558
  global_score = {"num_of_instances": len(instance_scores)}
1559
  for reduction_type, reduction_params in self.reduction_map.items():
1560
+ assert reduction_type in self.implemented_reductions, (
1561
+ f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
1562
+ )
1563
 
1564
  field_name_full_prefix = ""
1565
  # used for passing to the bootstrapping, depends on whether the groups are fixed or not
 
1657
  assert (
1658
  "task_data" in instance
1659
  and self.subgroup_column in instance["task_data"]
1660
+ ), (
1661
+ f"each instance task_data dict must have a key {self.subgroup_column}"
1662
+ )
1663
 
1664
  task_data = instance["task_data"] if "task_data" in instance else {}
1665
 
 
2255
 
2256
  def verify(self):
2257
  super().verify()
2258
+ assert self.metric is not None, (
2259
+ f"'metric' is not set in {self.get_metric_name()}"
2260
+ )
2261
+ assert self.main_score is not None, (
2262
+ f"'main_score' is not set in {self.get_metric_name()}"
2263
+ )
2264
+ assert isinstance(self.metric, Metric), (
2265
+ f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
2266
+ )
2267
  if self.postpreprocess_steps is not None:
2268
  depr_message = "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
2269
  warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
 
2284
  and isinstance(self.postprocess_steps, list)
2285
  and len(self.postprocess_steps) > 0
2286
  )
2287
+ assert not (has_postpreprocess and has_postprocess), (
2288
+ "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
2289
+ )
2290
  if has_postpreprocess:
2291
  self.postprocess_steps = self.postpreprocess_steps
2292
  self.prepare_score = SequentialOperator(
 
2363
 
2364
  assert self.hf_additional_input_fields is None or isoftype(
2365
  self.hf_additional_input_fields, List[str]
2366
+ ), (
2367
+ f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."
2368
+ )
2369
  assert self.hf_additional_input_fields_pass_one_value is None or isoftype(
2370
  self.hf_additional_input_fields_pass_one_value, List[str]
2371
+ ), (
2372
+ f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."
2373
+ )
2374
 
2375
  return super().verify()
2376
 
 
2387
  ) -> dict:
2388
  passed_task_data = {}
2389
  for additional_input_field in self.hf_additional_input_fields:
2390
+ assert additional_input_field in task_data[0], (
2391
+ f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2392
+ )
2393
  passed_task_data[additional_input_field] = [
2394
  additional_input[additional_input_field]
2395
  for additional_input in task_data
2396
  ]
2397
  for additional_input_field in self.hf_additional_input_fields_pass_one_value:
2398
+ assert additional_input_field in task_data[0], (
2399
+ f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2400
+ )
2401
 
2402
  values = {
2403
  additional_input[additional_input_field]
2404
  for additional_input in task_data
2405
  }
2406
+ assert len(values) == 1, (
2407
+ f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
2408
+ )
2409
 
2410
  passed_task_data[additional_input_field] = next(iter(values))
2411
 
 
2420
  result[self.main_score] = float(result[self.hf_main_score])
2421
  del result[self.hf_main_score]
2422
  if self.scale != 1.0:
2423
+ assert self.scaled_fields is not None, (
2424
+ f"Scaling factor was set to {self.scale}, but no fields specified"
2425
+ )
2426
  for key in self.scaled_fields:
2427
+ assert key in result, (
2428
+ f"Trying to scale field '{key}' which is not in results of metrics: {result}"
2429
+ )
2430
  if isinstance(result[key], list):
2431
+ assert all(isinstance(v, float) for v in result[key]), (
2432
+ "Not all scaled field '{key}' values are floats: {result[key]}"
2433
+ )
2434
  result[key] = [v / self.scale for v in result[key]]
2435
  else:
2436
+ assert isinstance(result[key], float), (
2437
+ "Scaled field '{key}' is not float: {result[key]}"
2438
+ )
2439
  result[key] /= self.scale
2440
  if self.main_score in result:
2441
  result[self.main_score] = float(result[self.main_score])
 
2462
  ) -> List[Dict[str, Any]]:
2463
  passed_task_data = {}
2464
  for additional_input_field in self.hf_additional_input_fields:
2465
+ assert additional_input_field in task_data[0], (
2466
+ f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
2467
+ )
2468
  passed_task_data[additional_input_field] = [
2469
  additional_input[additional_input_field]
2470
  for additional_input in task_data
 
2801
  response = requests.get(url)
2802
  response.raise_for_status()
2803
  content = response.content
2804
+ assert hashlib.md5(content).hexdigest() == hash_of_script, (
2805
+ f'URL ("{url}") is different than expected. Make sure you added the right one.'
2806
+ )
2807
 
2808
  with open(local_path, "wb") as file:
2809
  file.write(content)
 
2935
  labels=labels_param,
2936
  )
2937
  if isinstance(result[self.metric], numpy.ndarray):
2938
+ assert len(result[self.metric]) == len(labels), (
2939
+ f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
2940
+ )
2941
  final_result = {self.main_score: nan_mean(result[self.metric])}
2942
  for i, label in enumerate(labels):
2943
  final_result[self.metric + "_" + label] = result[self.metric][i]
 
3424
 
3425
 
3426
  class KeyValueExtraction(GlobalMetric):
 
3427
  prediction_type = Dict[str, str]
3428
  metric: Metric
3429
  single_reference_per_prediction = True
 
3987
  prediction_type = str
3988
  reduction_map: Dict[str, List[str]] = None
3989
  openai_models: List[str] = ["gpt-3.5-turbo"]
3990
+ anthropic_models: List[
3991
+ str
3992
+ ] = [] # this is here for the sake of documentation for future models
3993
  mock_models: List[str] = ["mock"]
3994
  external_api_models = openai_models + anthropic_models
3995
  data_classification_policy = ["public"]
 
4828
  for subgroup_name, score_list in subgroup_scores_dict.items()
4829
  }
4830
  )
4831
+ assert isinstance(control_subgroup_types, list), (
4832
+ "control_subgroup_types must be a list"
4833
+ )
4834
+ assert isinstance(comparison_subgroup_types, list), (
4835
+ "comparison_subgroup_types must be a list"
4836
+ )
4837
  # make sure each list is unique, so that labels aren't double-counted
4838
  control_subgroup_types = list(set(control_subgroup_types))
4839
  comparison_subgroup_types = list(set(comparison_subgroup_types))
 
4988
 
4989
  # requires scores to be in [0,1]
4990
  for subgroup_name, score_list in subgroup_scores_dict.items():
4991
+ assert all(0 <= score <= 1 for score in score_list), (
4992
+ f"all {subgroup_name} scores must be in [0,1]"
4993
+ )
4994
 
4995
  # combine all scores from each label (if there are more than 1 in each group) into a list
4996
  group_scores_list = [
 
5976
  return json.load(file)
5977
 
5978
  def ensemble(self, instance):
5979
+ assert self.weights is not None, (
5980
+ "RandomForestMetricsEnsemble must set self.weights before it can be used"
5981
+ )
5982
  ensemble_model = self.decode_forest(self.weights)
5983
 
5984
  prediction_lst = []
 
6387
  ]
6388
 
6389
  prediction_type = "Any" # string representation is compared
6390
+ sql_timeout = 30.0
6391
 
6392
  _requirements_list = ["sqlglot", "func_timeout"]
6393
 
 
6454
 
6455
  Comparison is column order independent, and could optionally be row order independent.
6456
  We interpret "subset" as follows:
6457
+
6458
  - For each row in df1, there must be a matching (or superset) row in df2, i.e. the set of values
6459
  in the df1 row is a subset of the set of values in that df2 row. Then do the same check in reverse.
6460
  - If either condition (df1 is subset of df2 OR df2 is subset of df1) is satisfied, return True.
 
6468
 
6469
  Returns:
6470
  bool: True if df1 is a subset of df2 or vice versa, based on the specified row-order condition.
6471
+
6472
  """
6473
  df1_array = df1.values.astype(str)
6474
  df2_array = df2.values.astype(str)
parsing_utils.py CHANGED
@@ -51,9 +51,9 @@ def consume_name_val(instring: str) -> Tuple[Any, str]:
51
  instring = instring[len(name_val) :].strip()
52
  name_val = name_val.strip()
53
 
54
- if name_val == "True":
55
  return (True, instring)
56
- if name_val == "False":
57
  return (False, instring)
58
  if name_val == "None":
59
  return (None, instring)
 
51
  instring = instring[len(name_val) :].strip()
52
  name_val = name_val.strip()
53
 
54
+ if name_val.lower() == "true":
55
  return (True, instring)
56
+ if name_val.lower() == "false":
57
  return (False, instring)
58
  if name_val == "None":
59
  return (None, instring)
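
With this change, boolean literals in key=value strings are recognized case-insensitively. A hedged illustration through the higher-level parser in the same module (the exact output is an assumption based on the logic shown above):

```python
from unitxt.parsing_utils import parse_key_equals_value_string_to_dict

# Previously only the exact strings "True"/"False" became booleans; now "true"/"false" do too.
print(parse_key_equals_value_string_to_dict("do_sample=true,num_demos=3,template=None"))
# expected (assumption): {'do_sample': True, 'num_demos': 3, 'template': None}
```
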
processors.py CHANGED
@@ -430,32 +430,86 @@ class AddPrefix(FieldOperator):
430
 
431
 
432
  class GetSQL(FieldOperator):
 
 
 
 
 
 
 
 
433
  def process_value(self, text: str) -> str:
434
- """Extracts the first SQL query from a given text.
435
 
436
  Args:
437
- text: The input string containing the SQL query.
 
438
 
439
  Returns:
440
- The first SQL query found in the text, or None if no query is found.
 
441
  """
442
- match = re.search(
443
- r"(?:```)?.*?(SELECT.*?(?:FROM|WITH|;|$).*?)(?:```|;|$)",
444
- text,
445
- re.IGNORECASE | re.DOTALL,
446
- )
447
 
448
- if match:
449
- out = (
450
- text[match.start() : match.end()]
451
- .replace("```", "")
452
- .replace(";", "")
453
- .strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  else:
456
- out = "No query found in generation"
457
 
458
- return out
 
 
 
459
 
460
 
461
  class ScaleNumberToZeroOneReturnZeroIfFails(FieldOperator):
 
430
 
431
 
432
  class GetSQL(FieldOperator):
433
+ """Operator to extract the most likely SQL query from text, often generated by language models.
434
+
435
+ It prioritizes SQL within markdown code blocks (```sql or ```)
436
+ and defaults to finding the last SELECT statement in the text
437
+ if no code blocks are found. It attempts to remove trailing text
438
+ after the first semicolon in the identified query.
439
+ """
440
+
441
  def process_value(self, text: str) -> str:
442
+ """Extracts the most plausible SQL query from the given text.
443
 
444
  Args:
445
+ text: The input string potentially containing an SQL query
446
+ and other text (e.g., explanations, markdown).
447
 
448
  Returns:
449
+ The extracted SQL query string, or a message indicating
450
+ no query was found.
451
  """
452
+ if not isinstance(text, str):
453
+ return "Input must be a string" # Basic type check
 
 
 
454
 
455
+ sql_query_candidate = None # Renamed to indicate it might need cleanup
456
+
457
+ # 1. Try to find ```sql ... ``` code blocks
458
+ sql_blocks = re.findall(
459
+ r"```sql\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE
460
+ )
461
+ if sql_blocks:
462
+ # Use the content of the last ```sql block
463
+ sql_query_candidate = sql_blocks[-1].strip()
464
+ else:
465
+ # 2. If no ```sql blocks, try to find generic ``` ... ``` blocks
466
+ generic_blocks = re.findall(r"```\s*(.*?)\s*```", text, re.DOTALL)
467
+ if generic_blocks:
468
+ # Check if the last block looks like SQL (starts with SELECT, INSERT, etc.)
469
+ last_block_content = generic_blocks[-1].strip()
470
+ # Allow common SQL starting keywords
471
+ sql_keywords = (
472
+ r"^(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|WITH|DROP|TRUNCATE)\b"
473
+ )
474
+ if re.match(sql_keywords, last_block_content, re.IGNORECASE):
475
+ sql_query_candidate = last_block_content
476
+
477
+ # 3. If no suitable code blocks found, search the entire text for the last relevant SQL keyword
478
+ if sql_query_candidate is None:
479
+ # Find the start index of the *last* common SQL keyword (case-insensitive)
480
+ last_match = None
481
+ # Expand search beyond just SELECT for better fallback
482
+ sql_keywords_search = (
483
+ r"\b(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|WITH|DROP|TRUNCATE)\b"
484
  )
485
+ for match in re.finditer(sql_keywords_search, text, re.IGNORECASE):
486
+ last_match = match
487
+
488
+ if last_match:
489
+ # Extract from the last keyword to the end of the string
490
+ sql_query_candidate = text[last_match.start() :].strip()
491
+
492
+ # 4. Cleanup: Truncate at first semicolon and strip whitespace
493
+ if sql_query_candidate:
494
+ # Find the first semicolon in the candidate string
495
+ first_semicolon_index = sql_query_candidate.find(";")
496
+ if first_semicolon_index != -1:
497
+ # If found, take everything before it
498
+ sql_query = sql_query_candidate[:first_semicolon_index].strip()
499
+ else:
500
+ # If no semicolon, use the candidate as is (after stripping)
501
+ sql_query = sql_query_candidate.strip()
502
+
503
+ # clean the ```sql\n from the start and the \n``` in case it is there
504
+ sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
505
+
506
  else:
507
+ sql_query = None # Ensure sql_query is None if no candidate was found
508
 
509
+ # 5. Return result or 'not found' message
510
+ return (
511
+ sql_query if sql_query is not None else "No query found in generation"
512
+ ) # Check for None explicitly
513
 
514
 
515
  class ScaleNumberToZeroOneReturnZeroIfFails(FieldOperator):
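
The rewritten `GetSQL` above prefers fenced SQL code blocks, then generic code fences that start with an SQL keyword, then the last bare SQL statement in the text, and finally truncates at the first semicolon. A hedged usage sketch (assumes `GetSQL` is instantiated with a `field` argument like other FieldOperators; the expected output follows the extraction logic shown above):

```python
from unitxt.processors import GetSQL

op = GetSQL(field="prediction")  # field name is an example

response = (
    "Sure, here is the query you asked for:\n"
    "```sql\n"
    "SELECT name, age FROM users WHERE age > 30;\n"
    "-- anything after the first semicolon is dropped\n"
    "```\n"
    "This lists every user older than 30."
)

print(op.process_value(response))
# expected: SELECT name, age FROM users WHERE age > 30
```
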
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.22.1"
 
1
+ version = "1.22.2"