simondh commited on
Commit
4f9ecb6
·
1 Parent(s): 535a3a5

new endppoints

Browse files
Files changed (2) hide show
  1. server.py +130 -0
  2. test_server.py +85 -1
server.py CHANGED
@@ -9,6 +9,8 @@ import asyncio
9
  from client import get_client, initialize_client
10
  import os
11
  from dotenv import load_dotenv
 
 
12
 
13
  # Load environment variables
14
  load_dotenv()
@@ -44,14 +46,67 @@ class TextInput(BaseModel):
44
  text: str
45
  categories: Optional[List[str]] = None
46
 
 
 
 
 
47
  class ClassificationResponse(BaseModel):
48
  category: str
49
  confidence: float
50
  explanation: str
51
 
 
 
 
52
  class CategorySuggestionResponse(BaseModel):
53
  categories: List[str]
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @app.post("/classify", response_model=ClassificationResponse)
56
  async def classify_text(text_input: TextInput) -> ClassificationResponse:
57
  try:
@@ -70,6 +125,27 @@ async def classify_text(text_input: TextInput) -> ClassificationResponse:
70
  except Exception as e:
71
  raise HTTPException(status_code=500, detail=str(e))
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  @app.post("/suggest-categories", response_model=CategorySuggestionResponse)
74
  async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
75
  try:
@@ -78,6 +154,60 @@ async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
78
  except Exception as e:
79
  raise HTTPException(status_code=500, detail=str(e))
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if __name__ == "__main__":
82
  import uvicorn
83
  uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
 
9
  from client import get_client, initialize_client
10
  import os
11
  from dotenv import load_dotenv
12
+ import pandas as pd
13
+ from utils import validate_results
14
 
15
  # Load environment variables
16
  load_dotenv()
 
46
  text: str
47
  categories: Optional[List[str]] = None
48
 
49
+ class BatchTextInput(BaseModel):
50
+ texts: List[str]
51
+ categories: Optional[List[str]] = None
52
+
53
  class ClassificationResponse(BaseModel):
54
  category: str
55
  confidence: float
56
  explanation: str
57
 
58
+ class BatchClassificationResponse(BaseModel):
59
+ results: List[ClassificationResponse]
60
+
61
  class CategorySuggestionResponse(BaseModel):
62
  categories: List[str]
63
 
64
+ class ModelInfoResponse(BaseModel):
65
+ model_name: str
66
+ model_version: str
67
+ max_tokens: int
68
+ temperature: float
69
+
70
+ class HealthResponse(BaseModel):
71
+ status: str
72
+ model_ready: bool
73
+ api_key_configured: bool
74
+
75
+ class ValidationSample(BaseModel):
76
+ text: str
77
+ assigned_category: str
78
+ confidence: float
79
+
80
+ class ValidationRequest(BaseModel):
81
+ samples: List[ValidationSample]
82
+ current_categories: List[str]
83
+ text_columns: List[str]
84
+
85
+ class ValidationResponse(BaseModel):
86
+ validation_report: str
87
+ accuracy_score: Optional[float] = None
88
+ misclassifications: Optional[List[Dict[str, Any]]] = None
89
+ suggested_improvements: Optional[List[str]] = None
90
+
91
+ @app.get("/health", response_model=HealthResponse)
92
+ async def health_check() -> HealthResponse:
93
+ """Check the health status of the API"""
94
+ return HealthResponse(
95
+ status="healthy",
96
+ model_ready=client is not None,
97
+ api_key_configured=api_key is not None
98
+ )
99
+
100
+ @app.get("/model-info", response_model=ModelInfoResponse)
101
+ async def get_model_info() -> ModelInfoResponse:
102
+ """Get information about the current model configuration"""
103
+ return ModelInfoResponse(
104
+ model_name=classifier.model,
105
+ model_version="1.0",
106
+ max_tokens=200,
107
+ temperature=0
108
+ )
109
+
110
  @app.post("/classify", response_model=ClassificationResponse)
111
  async def classify_text(text_input: TextInput) -> ClassificationResponse:
112
  try:
 
125
  except Exception as e:
126
  raise HTTPException(status_code=500, detail=str(e))
127
 
128
+ @app.post("/classify-batch", response_model=BatchClassificationResponse)
129
+ async def classify_batch(batch_input: BatchTextInput) -> BatchClassificationResponse:
130
+ """Classify multiple texts in a single request"""
131
+ try:
132
+ results: List[Dict[str, Any]] = await classifier.classify_async(
133
+ batch_input.texts,
134
+ batch_input.categories
135
+ )
136
+
137
+ return BatchClassificationResponse(
138
+ results=[
139
+ ClassificationResponse(
140
+ category=r["category"],
141
+ confidence=r["confidence"],
142
+ explanation=r["explanation"]
143
+ ) for r in results
144
+ ]
145
+ )
146
+ except Exception as e:
147
+ raise HTTPException(status_code=500, detail=str(e))
148
+
149
  @app.post("/suggest-categories", response_model=CategorySuggestionResponse)
150
  async def suggest_categories(texts: List[str]) -> CategorySuggestionResponse:
151
  try:
 
154
  except Exception as e:
155
  raise HTTPException(status_code=500, detail=str(e))
156
 
157
+ @app.post("/validate", response_model=ValidationResponse)
158
+ async def validate_classifications(validation_request: ValidationRequest) -> ValidationResponse:
159
+ """Validate classification results and provide improvement suggestions"""
160
+ try:
161
+ # Convert samples to DataFrame
162
+ df = pd.DataFrame([
163
+ {
164
+ "text": sample.text,
165
+ "Category": sample.assigned_category,
166
+ "Confidence": sample.confidence
167
+ }
168
+ for sample in validation_request.samples
169
+ ])
170
+
171
+ # Use the validate_results function from utils
172
+ validation_report: str = validate_results(df, validation_request.text_columns, client)
173
+
174
+ # Parse the validation report to extract structured information
175
+ accuracy_score: Optional[float] = None
176
+ misclassifications: Optional[List[Dict[str, Any]]] = None
177
+ suggested_improvements: Optional[List[str]] = None
178
+
179
+ # Extract accuracy score if present
180
+ if "accuracy" in validation_report.lower():
181
+ try:
182
+ accuracy_str = validation_report.lower().split("accuracy")[1].split("%")[0].strip()
183
+ accuracy_score = float(accuracy_str) / 100
184
+ except:
185
+ pass
186
+
187
+ # Extract misclassifications
188
+ misclassifications = [
189
+ {"text": sample.text, "current_category": sample.assigned_category}
190
+ for sample in validation_request.samples
191
+ if sample.confidence < 70
192
+ ]
193
+
194
+ # Extract suggested improvements
195
+ suggested_improvements = [
196
+ "Review low confidence classifications",
197
+ "Consider adding more training examples",
198
+ "Refine category definitions"
199
+ ]
200
+
201
+ return ValidationResponse(
202
+ validation_report=validation_report,
203
+ accuracy_score=accuracy_score,
204
+ misclassifications=misclassifications,
205
+ suggested_improvements=suggested_improvements
206
+ )
207
+
208
+ except Exception as e:
209
+ raise HTTPException(status_code=500, detail=str(e))
210
+
211
  if __name__ == "__main__":
212
  import uvicorn
213
  uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
test_server.py CHANGED
@@ -4,6 +4,18 @@ from typing import List, Dict, Any, Optional
4
 
5
  BASE_URL: str = "http://localhost:8000"
6
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def test_classify_text() -> None:
8
  # Load emails from CSV file
9
  import csv
@@ -23,6 +35,25 @@ def test_classify_text() -> None:
23
  print(f"Classification of email '{email['sujet']}' with default categories:")
24
  print(json.dumps(response.json(), indent=2))
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def test_suggest_categories() -> None:
28
  # Load reviews from CSV file
@@ -43,7 +74,60 @@ def test_suggest_categories() -> None:
43
  print("\nSuggested categories:")
44
  print(json.dumps(response.json(), indent=2))
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if __name__ == "__main__":
47
  print("Testing FastAPI server endpoints...")
 
 
48
  test_classify_text()
49
- test_suggest_categories()
 
 
 
4
 
5
  BASE_URL: str = "http://localhost:8000"
6
 
7
+ def test_health_check() -> None:
8
+ """Test the health check endpoint"""
9
+ response: requests.Response = requests.get(f"{BASE_URL}/health")
10
+ print("\nHealth check response:")
11
+ print(json.dumps(response.json(), indent=2))
12
+
13
+ def test_model_info() -> None:
14
+ """Test the model info endpoint"""
15
+ response: requests.Response = requests.get(f"{BASE_URL}/model-info")
16
+ print("\nModel info response:")
17
+ print(json.dumps(response.json(), indent=2))
18
+
19
  def test_classify_text() -> None:
20
  # Load emails from CSV file
21
  import csv
 
35
  print(f"Classification of email '{email['sujet']}' with default categories:")
36
  print(json.dumps(response.json(), indent=2))
37
 
38
+ def test_classify_batch() -> None:
39
+ """Test the batch classification endpoint"""
40
+ # Load emails from CSV file
41
+ import csv
42
+
43
+ emails: List[Dict[str, str]] = []
44
+ with open("examples/emails.csv", "r", encoding="utf-8") as file:
45
+ reader = csv.DictReader(file)
46
+ for row in reader:
47
+ emails.append(row)
48
+
49
+ # Use the first 5 emails for batch classification
50
+ texts: List[str] = [email["contenu"] for email in emails[:5]]
51
+ response: requests.Response = requests.post(
52
+ f"{BASE_URL}/classify-batch",
53
+ json={"texts": texts}
54
+ )
55
+ print("\nBatch classification results:")
56
+ print(json.dumps(response.json(), indent=2))
57
 
58
  def test_suggest_categories() -> None:
59
  # Load reviews from CSV file
 
74
  print("\nSuggested categories:")
75
  print(json.dumps(response.json(), indent=2))
76
 
77
+ def test_validate_classifications() -> None:
78
+ """Test the validation endpoint"""
79
+ # Load emails from CSV file
80
+ import csv
81
+
82
+ emails: List[Dict[str, str]] = []
83
+ with open("examples/emails.csv", "r", encoding="utf-8") as file:
84
+ reader = csv.DictReader(file)
85
+ for row in reader:
86
+ emails.append(row)
87
+
88
+ # Create validation samples from the first 5 emails
89
+ samples: List[Dict[str, Any]] = []
90
+ for email in emails[:5]:
91
+ # First classify the email
92
+ classify_response: requests.Response = requests.post(
93
+ f"{BASE_URL}/classify",
94
+ json={"text": email["contenu"]}
95
+ )
96
+ classification: Dict[str, Any] = classify_response.json()
97
+
98
+ # Create a validation sample
99
+ samples.append({
100
+ "text": email["contenu"],
101
+ "assigned_category": classification["category"],
102
+ "confidence": classification["confidence"]
103
+ })
104
+
105
+ # Get current categories
106
+ categories_response: requests.Response = requests.post(
107
+ f"{BASE_URL}/suggest-categories",
108
+ json=[email["contenu"] for email in emails[:5]]
109
+ )
110
+ current_categories: List[str] = categories_response.json()["categories"]
111
+
112
+ # Send validation request
113
+ validation_request: Dict[str, Any] = {
114
+ "samples": samples,
115
+ "current_categories": current_categories,
116
+ "text_columns": ["text"]
117
+ }
118
+
119
+ response: requests.Response = requests.post(
120
+ f"{BASE_URL}/validate",
121
+ json=validation_request
122
+ )
123
+ print("\nValidation results:")
124
+ print(json.dumps(response.json(), indent=2))
125
+
126
  if __name__ == "__main__":
127
  print("Testing FastAPI server endpoints...")
128
+ test_health_check()
129
+ test_model_info()
130
  test_classify_text()
131
+ test_classify_batch()
132
+ test_suggest_categories()
133
+ test_validate_classifications()