# data_detective/tools/analyze_data.py
# maiurilorenzo
# Update tools/analyze_data.py
# e7e4046 (verified)
from smolagents import tool
import pandas as pd
@tool
def read_data(file_path: str) -> pd.DataFrame:
    """Read an Excel or CSV file from a given path and return it as a pandas DataFrame.

    The file format is chosen from the file extension, compared case-insensitively.

    Args:
        file_path: The path to the Excel (.xlsx or .xls) or CSV (.csv) file.

    Returns:
        A pandas DataFrame containing the data from the file.

    Raises:
        Exception: If the extension is not supported or the file cannot be
            read. The message starts with "Error reading the file: ".
    """
    try:
        # Lower-case once so "DATA.CSV" or "report.XLSX" are accepted as well.
        normalized_path = file_path.lower()
        if normalized_path.endswith(".csv"):
            return pd.read_csv(file_path)
        # Accept both Excel extensions: a plain ".xls" suffix test does not
        # match ".xlsx", which is the format the documentation promises.
        if normalized_path.endswith((".xlsx", ".xls")):
            return pd.read_excel(file_path)
        # A real exception type is required here; raising a bare string is a
        # TypeError in Python 3 and would hide the actual reason.
        raise ValueError(f"Unsupported file extension: {file_path}")
    except Exception as e:
        # Keep the generic Exception callers already catch, but chain the
        # cause so the underlying traceback is preserved.
        raise Exception(f"Error reading the file: {str(e)}") from e
@tool
def get_data_summary(df: pd.DataFrame) -> dict:
    """Summarize a DataFrame: its dimensions plus a preview of its first rows.

    Args:
        df: A pandas DataFrame.

    Returns:
        A dictionary with "num_rows" and "num_columns" taken from the frame's
        shape, and "preview" holding ``df.head()`` converted to a dictionary
        (column name -> {row index -> value}).

    Raises:
        Exception: If the summary cannot be built; the message starts with
            "Error in analyzing the dataset: ".
    """
    try:
        # Index into shape explicitly (rather than unpacking) so inputs with
        # an unexpected shape fail exactly the same way as before.
        row_total = df.shape[0]
        column_total = df.shape[1]
        first_rows = df.head()
        summary = dict(
            num_rows=row_total,
            num_columns=column_total,
            preview=first_rows.to_dict(),
        )
    except Exception as e:
        raise Exception(f"Error in analyzing the dataset: {str(e)}")
    return summary
import pandas as pd
@tool
def get_dataframe_statistics(data: dict) -> dict:
    """Compute descriptive statistics for tabular data given as a dictionary.

    The data is loaded into a pandas DataFrame and summarized with
    ``DataFrame.describe()``.

    Args:
        data: A dictionary where keys are column names and values are lists of column values.

    Returns:
        A dictionary mapping each described column to its statistics as
        produced by ``describe()`` (for numeric columns: count, mean, std,
        min, 25%, 50%, 75%, max). NaN statistics are returned as None.

    Raises:
        Exception: If the data cannot be converted or described; the message
            starts with "error: ".
    """
    try:
        described = pd.DataFrame(data).describe().to_dict()
        cleaned = {}
        for column, metrics in described.items():
            column_result = {}
            for metric_name, metric_value in metrics.items():
                # NaN has no JSON representation, so expose it as None.
                column_result[metric_name] = None if pd.isna(metric_value) else metric_value
            cleaned[column] = column_result
        return cleaned
    except Exception as e:
        raise Exception(f"error: {str(e)}")
@tool
def get_missing_values(data: dict) -> dict:
    """Count missing values per column and express them as a percentage of rows.

    Args:
        data: A dictionary where keys are column names and values are lists of column values.

    Returns:
        A dictionary with column names as keys and, for each column, a
        dictionary holding "missing_count" (int) and "missing_percentage"
        (float between 0 and 100). When the data has no rows the percentage
        is 0.0. If the data cannot be processed, a dictionary of the form
        {"error": message} is returned instead.
    """
    try:
        df = pd.DataFrame(data)
        total_rows = len(df)
        missing_count = df.isnull().sum()
        report = {}
        for col in df.columns:
            count = int(missing_count[col])
            # With zero rows the ratio would be 0/0 (NaN), which is neither
            # meaningful nor JSON-compatible; report 0.0 instead.
            percentage = (count / total_rows) * 100 if total_rows else 0.0
            report[col] = {
                "missing_count": count,
                "missing_percentage": float(percentage),
            }
        return report
    except Exception as e:
        return {"error": str(e)}
@tool
def get_duplicate_rows(data: dict) -> dict:
    """Find rows that occur more than once in tabular data given as a dictionary.

    Args:
        data: A dictionary where keys are column names and values are lists of column values.

    Returns:
        A dictionary with "duplicate_count", the number of rows that repeat an
        earlier row, and "duplicate_rows", a list of records (one dictionary
        per row) containing every row that has at least one identical copy,
        first occurrences included. If the data cannot be processed, a
        dictionary of the form {"error": message} is returned instead.
    """
    try:
        frame = pd.DataFrame(data)
        # keep=False flags every member of a duplicate group, not only the repeats.
        in_duplicate_group = frame.duplicated(keep=False)
        repeated_rows = frame.loc[in_duplicate_group]
        # The default (keep="first") flags only the second and later copies.
        extra_copies = frame.duplicated().sum()
        result = {}
        result["duplicate_count"] = int(extra_copies)
        result["duplicate_rows"] = repeated_rows.to_dict(orient="records")
        return result
    except Exception as e:
        return {"error": str(e)}
@tool
def get_correlation_matrix(data: dict) -> dict:
    """Calculate the correlation matrix for the numerical columns of tabular data.

    Non-numerical columns (for example text) are ignored.

    Args:
        data: A dictionary where keys are column names and values are lists of column values.

    Returns:
        A dictionary representing the correlation matrix: each numerical
        column name maps to a dictionary of its correlation with every
        numerical column. Undefined correlations (for example against a
        constant column) are returned as None. If the data cannot be
        processed, a dictionary of the form {"error": message} is returned
        instead.
    """
    try:
        df = pd.DataFrame(data)
        # Select numeric columns explicitly: recent pandas versions no longer
        # drop non-numeric columns in corr() and raise on text data instead.
        numeric_df = df.select_dtypes(include=["number", "bool"])
        correlation_matrix = numeric_df.corr().to_dict()
        # Convert NaN values to None for JSON compatibility.
        return {
            col: {
                other: (None if pd.isna(value) else value)
                for other, value in row.items()
            }
            for col, row in correlation_matrix.items()
        }
    except Exception as e:
        return {"error": str(e)}