Didier commited on
Commit
2854813
·
verified ·
1 Parent(s): 89bc57d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +185 -0
  2. icij_utils.py +307 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """app.py
2
+
3
+ Smolagents agent given an SQL tool over a SQLite database built with data files
4
+ from the Internation Consortium of Investigative Journalism (ICIJ.org).
5
+
6
+ Agentic framework:
7
+ - smolagents
8
+
9
+ Database:
10
+ - SQLite
11
+
12
+ Generation:
13
+ - Mistral
14
+
15
+ :author: Didier Guillevic
16
+ :date: 2025-01-12
17
+ """
18
+
19
+ import gradio as gr
20
+ import icij_utils
21
+ import smolagents
22
+ import os
23
+
24
+ #
25
+ # Init a SQLite database with the data files from ICIJ.org
26
+ #
27
+ ICIJ_LEAKS_DB_NAME = 'icij_leaks.db'
28
+ ICIJ_LEAKS_DATA_DIR = './icij_data'
29
+
30
+ # Remove existing database (if present), since we will recreate it below.
31
+ icij_db_path = Path(ICIJ_LEAKS_DB_NAME)
32
+ icij_db_path.unlink(missing_ok=True)
33
+
34
+ # Load ICIJ data files into an SQLite database
35
+ loader = icij_utils.ICIJDataLoader(ICIJ_LEAKS_DB_NAME)
36
+ loader.load_all_files(ICIJ_LEAKS_DATA_DIR)
37
+
38
+ #
39
+ # Init an SQLAchemy instane (over the SQLite database)
40
+ #
41
+ db = icij_utils.ICIJDatabaseConnector(ICIJ_LEAKS_DB_NAME)
42
+ schema = db.get_full_database_schema()
43
+
44
+ #
45
+ # Build an SQL tool
46
+ #
47
+ schema = db.get_full_database_schema()
48
+ metadata = icij_utils.ICIJDatabaseMetadata()
49
+
50
+ tool_description = (
51
+ "Tool for querying the ICIJ offshore database containing financial data leaks. "
52
+ "This tool can execute SQL queries and return the results. "
53
+ "Beware that this tool's output is a string representation of the execution output.\n"
54
+ "It can use the following tables:"
55
+ )
56
+
57
+ # Add table documentation
58
+ for table, doc in metadata.TABLE_DOCS.items():
59
+ tool_description += f"\n\nTable: {table}\n"
60
+ tool_description += f"Description: {doc.strip()}\n"
61
+ tool_description += "Columns:\n"
62
+
63
+ # Add column documentation and types
64
+ if table in schema:
65
+ for col_name, col_type in schema[table].items():
66
+ col_doc = metadata.COLUMN_DOCS.get(table, {}).get(col_name, "No documentation available")
67
+ #tool_description += f" - {col_name}: {col_type}: {col_doc}\n"
68
+ tool_description += f" - {col_name}: {col_type}\n"
69
+
70
+ # Add source documentation
71
+ #tool_description += "\n\nSource IDs:\n"
72
+ #for source_id, descrip in metadata.SOURCE_IDS.items():
73
+ # tool_description += f"- {source_id}: {descrip}\n"
74
+
75
+ @smolagents.tool
76
+ def sql_tool(query: str) -> str:
77
+ """Description to be set beloiw...
78
+
79
+ Args:
80
+ query: The query to perform. This should be correct SQL.
81
+ """
82
+ output = ""
83
+ with db.get_engine().connect() as con:
84
+ rows = con.execute(sqlalchemy.text(query))
85
+ for row in rows:
86
+ output += "\n" + str(row)
87
+ return output
88
+
89
+ sql_tool.description = tool_description
90
+
91
+ #
92
+ # language models
93
+ #
94
+ default_model = smolagents.HfApiModel()
95
+
96
+ mistral_api_key = os.environ["MISTRAL_API_KEY"]
97
+ mistral_model_id = "mistral/codestral-latest"
98
+ mistral_model = smolagents.LiteLLMModel(
99
+ model_id=mistral_model_id, api_key=mistral_api_key)
100
+
101
+ #
102
+ # Define the agent
103
+ #
104
+ agent = smolagents.CodeAgent(
105
+ tools=[sql_engine],
106
+ model=mistral_model
107
+ )
108
+
109
+ def generate_response(query: str) -> str:
110
+ """Generate a response given query.
111
+
112
+ Args:
113
+
114
+ Returns:
115
+ - the response from the agent having access to a database over the ICIJ
116
+ data and a large language model.
117
+ """
118
+ agent_output = agent.run(query)
119
+ return agent_output
120
+
121
+
122
+ #
123
+ # User interface
124
+ #
125
+ with gr.Blocks() as demo:
126
+ gr.Markdown("""
127
+ # SQL agent
128
+ Database: ICIJ data on offshore financial data leaks.
129
+ """)
130
+
131
+ # Inputs: question
132
+ question = gr.Textbox(
133
+ label="Question to answer",
134
+ placeholder=""
135
+ )
136
+
137
+ # Response
138
+ response = gr.Textbox(
139
+ label="Response",
140
+ placeholder=""
141
+ )
142
+
143
+ # Button
144
+ with gr.Row():
145
+ response_button = gr.Button("Submit", variant='primary')
146
+ clear_button = gr.Button("Clear", variant='secondary')
147
+
148
+ # Example questions given default provided PDF file
149
+ with gr.Accordion("Sample questions", open=False):
150
+ gr.Examples(
151
+ [
152
+ ["",],
153
+ ["",],
154
+ ],
155
+ inputs=[question,],
156
+ outputs=[response,],
157
+ fn=generate_response,
158
+ cache_examples=False,
159
+ label="Sample questions"
160
+ )
161
+
162
+ # Documentation
163
+ with gr.Accordion("Documentation", open=False):
164
+ gr.Markdown("""
165
+ - Agentic framework: smolagents
166
+ - Data: icij.org
167
+ - Database: SQLite, SQLAlchemy
168
+ - Generation: Mistral
169
+ - Examples: Generated using Claude.ai
170
+ """)
171
+
172
+ # Click actions
173
+ response_button.click(
174
+ fn=generate_response,
175
+ inputs=[question,],
176
+ outputs=[response,]
177
+ )
178
+ clear_button.click(
179
+ fn=lambda: ('', ''),
180
+ inputs=[],
181
+ outputs=[question, response]
182
+ )
183
+
184
+
185
+ demo.launch(show_api=False)
icij_utils.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """icij_utils.py
2
+
3
+ Building an SQL agent over the ICIJ financial data leaks files.
4
+
5
+ :author: Didier Guillevic
6
+ :date: 2025-01-12
7
+ """
8
+
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ import pandas as pd
14
+ import sqlite3
15
+ import os
16
+ from pathlib import Path
17
+
18
+ from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, Float
19
+ from sqlalchemy.ext.declarative import declarative_base
20
+ from sqlalchemy.orm import sessionmaker
21
+
22
+
23
+ class ICIJDataLoader:
24
+ def __init__(self, db_path='icij_data.db'):
25
+ """Initialize the data loader with database path."""
26
+ self.db_path = db_path
27
+ self.table_mappings = {
28
+ 'nodes-addresses.csv': 'addresses',
29
+ 'nodes-entities.csv': 'entities',
30
+ 'nodes-intermediaries.csv': 'intermediaries',
31
+ 'nodes-officers.csv': 'officers',
32
+ 'nodes-others.csv': 'others',
33
+ 'relationships.csv': 'relationships'
34
+ }
35
+
36
+ def create_connection(self):
37
+ """Create a database connection."""
38
+ try:
39
+ conn = sqlite3.connect(self.db_path)
40
+ return conn
41
+ except sqlite3.Error as e:
42
+ print(f"Error connecting to database: {e}")
43
+ return None
44
+
45
+ def create_table_from_csv(self, csv_path, table_name, conn):
46
+ """Create a table based on CSV structure and load data."""
47
+ try:
48
+ # Read the first few rows to get column names and types
49
+ df = pd.read_csv(csv_path, nrows=5)
50
+
51
+ # Create table with appropriate columns
52
+ columns = []
53
+ for col in df.columns:
54
+ # Determine SQLite type based on pandas dtype
55
+ dtype = df[col].dtype
56
+ if 'int' in str(dtype):
57
+ sql_type = 'INTEGER'
58
+ elif 'float' in str(dtype):
59
+ sql_type = 'REAL'
60
+ else:
61
+ sql_type = 'TEXT'
62
+ columns.append(f'"{col}" {sql_type}')
63
+
64
+ # Create table
65
+ create_table_sql = f'''CREATE TABLE IF NOT EXISTS {table_name}
66
+ ({', '.join(columns)})'''
67
+ conn.execute(create_table_sql)
68
+
69
+ # Load data in chunks to handle large files
70
+ chunksize = 10000
71
+ for chunk in pd.read_csv(csv_path, chunksize=chunksize):
72
+ chunk.to_sql(table_name, conn, if_exists='append', index=False)
73
+
74
+ print(f"Successfully loaded {table_name}")
75
+ return True
76
+
77
+ except Exception as e:
78
+ print(f"Error processing {csv_path}: {e}")
79
+ return False
80
+
81
+ def load_all_files(self, data_directory):
82
+ """Load all recognized CSV files from the directory into SQLite."""
83
+ conn = self.create_connection()
84
+ if not conn:
85
+ return False
86
+
87
+ try:
88
+ data_path = Path(data_directory)
89
+ files_processed = 0
90
+
91
+ for csv_file, table_name in self.table_mappings.items():
92
+ file_path = data_path / csv_file
93
+ if file_path.exists():
94
+ print(f"Processing {csv_file}...")
95
+ if self.create_table_from_csv(file_path, table_name, conn):
96
+ files_processed += 1
97
+
98
+ # Create indexes for better query performance
99
+ self.create_indexes(conn)
100
+
101
+ conn.commit()
102
+ print(f"Successfully processed {files_processed} files")
103
+ return True
104
+
105
+ except Exception as e:
106
+ print(f"Error during data loading: {e}")
107
+ return False
108
+
109
+ finally:
110
+ conn.close()
111
+
112
+ def create_indexes(self, conn):
113
+ """Create indexes for better query performance."""
114
+ index_definitions = [
115
+ 'CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(name)',
116
+ 'CREATE INDEX IF NOT EXISTS idx_officers_name ON officers(name)',
117
+ 'CREATE INDEX IF NOT EXISTS idx_relationships_from ON relationships(node_id_start)',
118
+ 'CREATE INDEX IF NOT EXISTS idx_relationships_to ON relationships(node_id_end)'
119
+ ]
120
+
121
+ for index_sql in index_definitions:
122
+ try:
123
+ conn.execute(index_sql)
124
+ except sqlite3.Error as e:
125
+ print(f"Error creating index: {e}")
126
+
127
+
128
+ class ICIJDatabaseConnector:
129
+ def __init__(self, db_path='icij_leaks.db'):
130
+ # Create the SQLAlchemy engine
131
+ self.engine = create_engine(f'sqlite:///{db_path}', echo=False)
132
+
133
+ # Create declarative base
134
+ self.Base = declarative_base()
135
+
136
+ # Create session factory
137
+ self.Session = sessionmaker(bind=self.engine)
138
+
139
+ # Initialize metadata
140
+ self.metadata = MetaData()
141
+
142
+ # Reflect existing tables
143
+ self.metadata.reflect(bind=self.engine)
144
+
145
+ def get_engine(self):
146
+ """Return the SQLAlchemy engine."""
147
+ return self.engine
148
+
149
+ def get_session(self):
150
+ """Create and return a new session."""
151
+ return self.Session()
152
+
153
+ def get_table(self, table_name):
154
+ """Get a table by name from the metadata."""
155
+ return self.metadata.tables.get(table_name)
156
+
157
+ def list_tables(self):
158
+ """List all available tables in the database."""
159
+ return list(self.metadata.tables.keys())
160
+
161
+ def get_table_schema(self, table_name):
162
+ """Get column names and their types for a specific table."""
163
+ table = self.get_table(table_name)
164
+ if table is not None:
165
+ return {column.name: str(column.type) for column in table.columns}
166
+ return {}
167
+
168
+ def get_full_database_schema(self):
169
+ """Get the schema for all tables in the database."""
170
+ schema = {}
171
+ for table_name in self.list_tables():
172
+ schema[table_name] = self.get_table_schema(table_name)
173
+ return schema
174
+
175
+ def get_table_columns(self, table_name):
176
+ """Get column names for a specific table."""
177
+ table = self.get_table(table_name)
178
+ if table is not None:
179
+ return [column.name for column in table.columns]
180
+ return []
181
+
182
+ def query_table(self, table_name, limit=1):
183
+ """Execute a simple query on a table."""
184
+ table = self.get_table(table_name)
185
+ if table is not None:
186
+ stmt = select(table).limit(limit)
187
+ with self.engine.connect() as connection:
188
+ result = connection.execute(stmt)
189
+ return [dict(row) for row in result]
190
+ return []
191
+
192
+
193
+ class ICIJDatabaseMetadata:
194
+ """Holds detailed documentation about the ICIJ database structure."""
195
+
196
+ # Comprehensive table documentation
197
+ TABLE_DOCS = {
198
+ 'entities': (
199
+ "Contains information about companies, trusts, and other entities mentioned in the leaks. "
200
+ "These are typically offshore entities created in tax havens."
201
+ ),
202
+
203
+ 'officers': (
204
+ "Contains information about people or organizations connected to offshore entities. "
205
+ "Officers can be directors, shareholders, beneficiaries, or have other roles."
206
+ ),
207
+ 'intermediaries': (
208
+ "Contains information about professional firms that help create and manage offshore entities. "
209
+ "These are typically law firms, banks, or corporate service providers."
210
+ ),
211
+ 'addresses': (
212
+ "Contains physical address information connected to entities, officers, or intermediaries. "
213
+ "Addresses can be shared between multiple parties."
214
+ ),
215
+ 'others': (
216
+ "Contains information about miscellaneous parties that don't fit into other categories. "
217
+ "This includes vessel names, legal cases, events, and other related parties mentioned "
218
+ "in the leaks that aren't classified as entities, officers, or intermediaries."
219
+ ),
220
+ 'relationships': (
221
+ "Defines connections between different nodes (entities, officers, intermediaries) in the database. "
222
+ "Shows how different parties are connected to each other."
223
+ )
224
+ }
225
+
226
+ # Detailed column documentation for each table
227
+ COLUMN_DOCS = {
228
+ 'entities': {
229
+ 'name': "Legal name of the offshore entity",
230
+ 'original_name': "Name in original language/character set",
231
+ 'former_name': "Previous names of the entity",
232
+ 'jurisdiction': "Country/region where the entity is registered",
233
+ 'jurisdiction_description': "Detailed description of the jurisdiction",
234
+ 'company_type': "Legal structure of the entity (e.g., corporation, trust)",
235
+ 'address': "Primary registered address",
236
+ 'internal_id': "Unique identifier within the leak data",
237
+ 'incorporation_date': "Date when the entity was created",
238
+ 'inactivation_date': "Date when the entity became inactive",
239
+ 'struck_off_date': "Date when entity was struck from register",
240
+ 'dorm_date': "Date when entity became dormant",
241
+ 'status': "Current status of the entity",
242
+ 'service_provider': "Firm that provided offshore services",
243
+ 'source_id': "Identifier for the leak source"
244
+ },
245
+
246
+ 'others': {
247
+ 'name': "Name of the miscellaneous party or item",
248
+ 'type': "Type of the other party (e.g., vessel, legal case)",
249
+ 'incorporation_date': "Date of incorporation or creation if applicable",
250
+ 'jurisdiction': "Jurisdiction associated with the party",
251
+ 'countries': "Countries associated with the party",
252
+ 'status': "Current status",
253
+ 'internal_id': "Unique identifier within the leak data",
254
+ 'address': "Associated address if available",
255
+ 'source_id': "Identifier for the leak source",
256
+ 'valid_until': "Date until which the information is valid"
257
+ },
258
+
259
+ 'officers': {
260
+ 'name': "Name of the individual or organization",
261
+ 'country_codes': "Countries connected to the officer",
262
+ 'source_id': "Identifier for the leak source",
263
+ 'valid_until': "Date until which the information is valid",
264
+ 'status': "Current status of the officer",
265
+ 'internal_id': "Unique identifier within the leak data"
266
+ },
267
+
268
+ 'intermediaries': {
269
+ 'name': "Name of the professional firm",
270
+ 'internal_id': "Unique identifier within the leak data",
271
+ 'address': "Business address",
272
+ 'status': "Current status",
273
+ 'country_codes': "Countries where intermediary operates",
274
+ 'source_id': "Identifier for the leak source"
275
+ },
276
+
277
+ 'addresses': {
278
+ 'address': "Full address text",
279
+ 'name': "Name associated with address",
280
+ 'country_codes': "Country codes for the address",
281
+ 'countries': "Full country names",
282
+ 'source_id': "Identifier for the leak source",
283
+ 'valid_until': "Date until which address is valid",
284
+ 'internal_id': "Unique identifier within the leak data"
285
+ },
286
+
287
+ 'relationships': {
288
+ 'from_id': "Internal ID of the source node",
289
+ 'to_id': "Internal ID of the target node",
290
+ 'rel_type': "Type of relationship (e.g., shareholder, director)",
291
+ 'link': "Additional details about the relationship",
292
+ 'start_date': "When the relationship began",
293
+ 'end_date': "When the relationship ended",
294
+ 'source_id': "Identifier for the leak source",
295
+ 'status': "Current status of the relationship"
296
+ }
297
+ }
298
+
299
+ # Source documentation
300
+ SOURCE_IDS = {
301
+ "PANAMA_PAPERS": "Data from Panama Papers leak (2016)",
302
+ "PARADISE_PAPERS": "Data from Paradise Papers leak (2017)",
303
+ "BAHAMAS_LEAKS": "Data from Bahamas Leaks (2016)",
304
+ "OFFSHORE_LEAKS": "Data from Offshore Leaks (2013)",
305
+ "PANDORA_PAPERS": "Data from Pandora Papers leak (2021)"
306
+ }
307
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ smolagents
3
+ sqlite3
4
+ sqlalchemy