File size: 2,431 Bytes
6a1e686
 
c4ad33b
6a1e686
 
5d27647
6a1e686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d27647
6a1e686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4ad33b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a1e686
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import click
from .predict import predict_single
from .evaluate import evaluate_batch
import warnings
from transformers import logging as hf_logging
from .config import HF_REPO

def configure_logging(debug):
    """Set Python warning filters and transformers log verbosity.

    Args:
        debug: When truthy, transformers logging is set to info level and
            the warnings filter is reset to "default". Otherwise the
            "Some weights of the model checkpoint" warning is ignored and
            transformers logging is reduced to errors only.
    """
    if debug:
        # Verbose path: let library info messages and warnings through.
        hf_logging.set_verbosity_info()
        warnings.simplefilter("default")
        return

    # Quiet path: hide the checkpoint-weights warning and anything below
    # error level coming from transformers.
    warnings.filterwarnings(
        "ignore",
        message="Some weights of the model checkpoint",
    )
    hf_logging.set_verbosity_error()

@click.group()
@click.option(
    "--debug",
    is_flag=True,
    help="Enable debug output including warnings",
)
@click.pass_context
def cli(ctx, debug):
    """Qwen Multi-label Classifier CLI"""
    # NOTE: the docstring above doubles as the group's --help text, so it is
    # kept short; implementation notes live in comments instead.

    # Make sure ctx.obj is a dict, then record the flag there so that
    # subcommands can read it back via ctx.obj['DEBUG'].
    ctx.ensure_object(dict)
    shared_state = ctx.obj
    shared_state["DEBUG"] = debug

    # Apply the warning / log-verbosity settings before any subcommand runs.
    configure_logging(debug)

@cli.command()
@click.argument("text")
@click.option(
    "--hf-token",
    envvar="HF_TOKEN",
    help="HF API token (or set HF_TOKEN env variable)",
)
@click.option("--hf-repo", default=HF_REPO, help="Hugging Face model repo")
@click.option(
    "--backend",
    type=click.Choice(["local", "hf"], case_sensitive=False),
    default="local",
    help="Inference backend: 'local' (your machine) or 'hf' (Hugging Face API)",
)
@click.pass_context
def predict(ctx, text, hf_repo, backend, hf_token):
    """Make prediction on a single text"""
    # NOTE: the docstring above is rendered by click as this command's help
    # text, so it stays a one-liner; details are documented here instead.
    #
    # ctx      -- click context; ctx.obj['DEBUG'] is set by the `cli` group.
    # text     -- the TEXT argument to classify.
    # hf_repo  -- value of --hf-repo (defaults to HF_REPO).
    # backend  -- value of --backend: 'local' or 'hf'.
    # hf_token -- value of --hf-token, or the HF_TOKEN environment variable.
    debug_enabled = ctx.obj["DEBUG"]
    if debug_enabled:
        click.echo("Debug mode enabled - showing all warnings")

    # Backend selection and credentials are forwarded as keyword arguments.
    inference_options = {"backend": backend, "hf_token": hf_token}
    prediction = predict_single(text, hf_repo, **inference_options)

    click.echo(f"Prediction results: {prediction}")

@cli.command()
@click.argument('file_path')
@click.option('--hf-token', envvar="HF_TOKEN", help="HF API token (or set HF_TOKEN env variable)")
@click.option('--hf-repo', default=HF_REPO, help="Hugging Face model repo")
@click.option('--backend',
              type=click.Choice(['local', 'hf'], case_sensitive=False),
              default='local',
              help="Inference backend: 'local' (your machine) or 'hf' (Hugging Face API)")
@click.pass_context
def evaluate(ctx, file_path, hf_repo, backend, hf_token):
    """Run batch evaluation on the data at FILE_PATH"""
    # The docstring above is what click shows for `evaluate --help` and in the
    # command listing. It previously duplicated the `predict` help text
    # ("Make prediction on a single text"), which misdescribed this command.
    #
    # ctx       -- click context; ctx.obj['DEBUG'] is set by the `cli` group.
    # file_path -- the FILE_PATH argument, passed through to evaluate_batch.
    # hf_repo   -- value of --hf-repo (defaults to HF_REPO).
    # backend   -- value of --backend: 'local' or 'hf'.
    # hf_token  -- value of --hf-token, or the HF_TOKEN environment variable.
    if ctx.obj['DEBUG']:
        click.echo("Debug mode enabled - showing all warnings")

    results = evaluate_batch(
        file_path,
        hf_repo,
        backend=backend,
        hf_token=hf_token,
    )
    # NOTE(review): the label still reads "Prediction results" although this is
    # the evaluate command; left unchanged in case callers parse the output.
    click.echo(f"Prediction results: {results}")