import logging
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from huggingface_hub.inference._generated.types import ChatCompletionOutput
from huggingface_hub.utils import HfHubHTTPError

# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Load environment variables from .env so HF_TOKEN and the endpoint URL are
# available even when this module is imported on its own.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_INFERENCE_ENDPOINT_URL = os.getenv("HF_INFERENCE_ENDPOINT_URL")

# Default parameters for the LLM call
DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.1  # Lower temperature for more deterministic analysis

# Sentinel dictionary returned to callers when the endpoint responds with a 503
ERROR_503_DICT = {"error_type": "503", "message": "Service Unavailable"}
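
# Callers can treat ERROR_503_DICT as a retry signal. A minimal sketch of that
# pattern (the single retry and the 5-second backoff below are illustrative
# choices, not part of this module's contract):
#
#     import time
#
#     result = query_qwen_endpoint(messages)
#     if isinstance(result, dict) and result.get("error_type") == "503":
#         time.sleep(5)  # illustrative backoff before one retry
#         result = query_qwen_endpoint(messages)
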
def query_qwen_endpoint(
formatted_prompt: list[dict[str, str]], max_tokens: int = DEFAULT_MAX_TOKENS
) -> ChatCompletionOutput | dict | None:
"""
    Queries the specified Qwen Inference Endpoint with the formatted prompt.

    Args:
        formatted_prompt: A list of message dictionaries for the chat completion API.
        max_tokens: The maximum number of tokens to generate.

    Returns:
The ChatCompletionOutput object from the inference client,
a specific dictionary (ERROR_503_DICT) if a 503 error occurs,
or None if another error occurs.
"""
if not HF_INFERENCE_ENDPOINT_URL:
logging.error("HF_INFERENCE_ENDPOINT_URL environment variable not set.")
return None
if not HF_TOKEN:
logging.warning(
"HF_TOKEN environment variable not set. Requests might fail if the endpoint requires authentication."
)
        # Depending on the endpoint configuration, the request may still succeed without a token.
logging.info(f"Querying Inference Endpoint: {HF_INFERENCE_ENDPOINT_URL}")
client = InferenceClient(model=HF_INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
try:
response = client.chat_completion(
messages=formatted_prompt,
max_tokens=max_tokens,
temperature=DEFAULT_TEMPERATURE,
            # Qwen models often benefit from explicit stop sequences when known,
            # but we rely on max_tokens and the model's natural stopping for now.
            # stop=["<|im_end|>"]  # Example stop token for specific Qwen finetunes
)
logging.info("Successfully received response from Inference Endpoint.")
return response
except HfHubHTTPError as e:
# Check specifically for 503 Service Unavailable
if e.response is not None and e.response.status_code == 503:
logging.warning(
f"Encountered 503 Service Unavailable from endpoint: {HF_INFERENCE_ENDPOINT_URL}"
)
return ERROR_503_DICT # Return special dict for 503
else:
# Handle other HTTP errors
logging.error(f"HTTP error querying Inference Endpoint: {e}")
if e.response is not None:
logging.error(f"Response details: {e.response.text}")
return None # Return None for other HTTP errors
except Exception as e:
logging.error(f"An unexpected error occurred querying Inference Endpoint: {e}")
return None
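
# For reference, `formatted_prompt` uses the OpenAI-style chat schema that
# `InferenceClient.chat_completion` accepts. A hedged example call (the role
# contents here are made up for illustration):
#
#     messages = [
#         {"role": "system", "content": "You are a code analysis assistant."},
#         {"role": "user", "content": "Summarize what app.py does."},
#     ]
#     response = query_qwen_endpoint(messages, max_tokens=512)
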
def parse_qwen_response(response: ChatCompletionOutput | dict | None) -> str:
"""
    Parses the response from the Qwen model to extract the generated text.
    Handles potential None or error-dict inputs.

    Args:
        response: The ChatCompletionOutput object, ERROR_503_DICT, or None.

    Returns:
The extracted response text as a string, or an error message string.
"""
if response is None:
return "Error: Failed to get response from the language model."
# Check if it's our specific 503 error signal before trying to parse as ChatCompletionOutput
if isinstance(response, dict) and response.get("error_type") == "503":
return f"Error: {response['error_type']} {response['message']}"
# Check if it's likely the expected ChatCompletionOutput structure
if not hasattr(response, "choices"):
logging.error(
f"Unexpected response type received by parse_qwen_response: {type(response)}. Content: {response}"
)
return "Error: Received an unexpected response format from the language model endpoint."
try:
# Access the generated content according to the ChatCompletionOutput structure
if response.choices and len(response.choices) > 0:
content = response.choices[0].message.content
if content:
logging.info("Successfully parsed response content.")
return content.strip()
else:
logging.warning("Response received, but content is empty.")
return "Error: Received an empty response from the language model."
else:
logging.warning("Response received, but no choices found.")
return "Error: No response choices found in the language model output."
except AttributeError as e:
# This might catch cases where response looks like the object but lacks expected attributes
logging.error(
f"Attribute error parsing response: {e}. Response structure might be unexpected."
)
logging.error(f"Raw response object: {response}")
return "Error: Could not parse the structure of the language model response."
except Exception as e:
logging.error(f"An unexpected error occurred parsing the response: {e}")
return "Error: An unexpected error occurred while parsing the language model response."

# Example usage (for manual testing; requires .env configuration and prompts.py)
if __name__ == "__main__":
    # This example assumes a local prompts.py exposing format_code_for_analysis.
    try:
        from prompts import format_code_for_analysis

        # Build a minimal dummy prompt for testing
        test_files = {"app.py": "print('hello')"}
        test_prompt = format_code_for_analysis("test/minimal", test_files)
        print("--- Sending Test Prompt ---")
        print(test_prompt)

        api_response = query_qwen_endpoint(test_prompt)
        print("\n--- Raw API Response ---")
        print(api_response)

        print("\n--- Parsed Response ---")
        parsed_text = parse_qwen_response(api_response)
        print(parsed_text)
    except ImportError:
        print("Could not import prompts.py for testing. Run this test from the project root.")
    except Exception as e:
        print(f"An error occurred during testing: {e}")