# llt-prediction / app.py
# Requires the TA-Lib C library to be installed on the system, then: pip install TA-Lib
import ccxt
import numpy as np
import pandas as pd
import time
from sklearn.neighbors import KNeighborsClassifier
from scipy.linalg import svd
import gradio as gr
import concurrent.futures
import traceback
from datetime import datetime, timezone, timedelta
import logging
import sys
try:
    import talib  # TA-Lib; requires the native C library (see note at the top of this file)
except ImportError:
    talib = None  # Checked at startup in __main__ so a clear error can be logged instead of a raw traceback
import threading
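# Dependency note (a rough sketch of the implied requirements, not an official manifest):
# the imports above roughly correspond to `pip install ccxt numpy pandas scikit-learn scipy gradio TA-Lib`,
# with the TA-Lib wheel additionally requiring the native C library to be present on the system.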
# --- Setup Logging ---
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - [%(threadName)s:%(funcName)s] - %(message)s',
stream=sys.stdout
)
logging.getLogger().handlers[0].flush = sys.stdout.flush
# --- Parameters ---
L = 10
LAG = 11
MINUTES_PER_HOUR = 60
PREDICTION_WINDOW_HOURS = 2
TRAINING_WINDOW_HOURS = 12
TOTAL_WINDOW_HOURS = TRAINING_WINDOW_HOURS + PREDICTION_WINDOW_HOURS
K = TRAINING_WINDOW_HOURS * MINUTES_PER_HOUR # 720
WINDOW = TOTAL_WINDOW_HOURS * MINUTES_PER_HOUR # 840
FEATURES = ['open', 'high', 'low', 'close', 'volume']
D = 5
OVERLAP_STEP = 60
MIN_TRAINING_EXAMPLES = 20
MAX_COINS_TO_DISPLAY = 10
USE_SYNTHETIC_DATA_FOR_LOW_VOLUME = False
NUM_WORKERS_TRAINING = 4
NUM_WORKERS_PREDICTION = 10
# --- TA & Risk Parameters ---
TA_DATA_POINTS = 200 # Candles needed for TA calculation
RSI_PERIOD = 14
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
ATR_PERIOD = 14
CONFIDENCE_THRESHOLD = 0.65 # Min confidence for Rise signal
TP1_ATR_MULTIPLIER = 1.5
TP2_ATR_MULTIPLIER = 3.0
SL_ATR_MULTIPLIER = 1.0
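# Illustration of how calculate_trade_levels() (defined further below) applies these multipliers,
# using purely hypothetical numbers: for a "Rise" signal at price 100.00 with ATR(14) = 2.00,
#   Entry = 100.00, TP1 = 100.00 + 1.5*2.00 = 103.00, TP2 = 100.00 + 3.0*2.00 = 106.00,
#   SL = 100.00 - 1.0*2.00 = 98.00 (floored at 0.01).
# Levels are only emitted when the prediction is "Rise" and confidence >= CONFIDENCE_THRESHOLD.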
# --- CCXT Initialization ---
try:
exchange = ccxt.bitget({
'enableRateLimit': True,
'rateLimit': 1100,
'timeout': 45000,
'options': {'adjustForTimeDifference': True}
})
logging.info(f"Initialized {exchange.id} exchange.")
except Exception as e:
logging.exception("FATAL: Could not initialize CCXT exchange.")
sys.exit()
# --- Global Caches and Variables ---
markets_cache = None
last_markets_update = None
data_cache = {}
trained_models = {}
last_update_time = datetime.now(timezone.utc)
# --- Functions ---
def format_datetime(dt, default="N/A"):
# (Keep this function as is)
if pd.isna(dt) or dt is None:
return default
try:
if isinstance(dt, (datetime, pd.Timestamp)):
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.strftime('%Y-%m-%d %H:%M:%S %Z')
else:
return str(dt)
except Exception:
return default
def get_all_usdt_pairs():
# (Keep this function as is - no changes needed)
global markets_cache, last_markets_update
current_time = time.time()
cache_duration = 3600 # 1 hour
if markets_cache is not None and last_markets_update is not None:
if current_time - last_markets_update < cache_duration:
logging.info("Using cached markets list.")
if isinstance(markets_cache, list) and markets_cache:
return markets_cache
else:
logging.warning("Cached market list was invalid, fetching fresh.")
logging.info("Fetching markets from Bitget...")
try:
exchange.load_markets(reload=True)
all_symbols = list(exchange.markets.keys())
usdt_pairs = [
symbol for symbol in all_symbols
if isinstance(symbol, str)
and symbol.endswith('/USDT')
and exchange.markets.get(symbol, {}).get('active', False)
and exchange.markets.get(symbol, {}).get('spot', False)
and 'LEVERAGED' not in exchange.markets.get(symbol, {}).get('type', 'spot').upper()
and not exchange.markets.get(symbol, {}).get('inverse', False)
]
logging.info(f"Found {len(usdt_pairs)} active USDT spot pairs initially.")
if not usdt_pairs:
logging.warning("No active USDT spot pairs found.")
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']
logging.info(f"Fetching tickers for {len(usdt_pairs)} pairs for volume sorting...")
volumes = {}
symbols_to_fetch = usdt_pairs[:]
fetched_tickers = {}
try:
if exchange.has['fetchTickers']:
batch_size_tickers = 100
for i in range(0, len(symbols_to_fetch), batch_size_tickers):
batch_symbols = symbols_to_fetch[i:i+batch_size_tickers]
logging.info(f"Fetching ticker batch {i//batch_size_tickers + 1}/{ (len(symbols_to_fetch) + batch_size_tickers -1)//batch_size_tickers }...")
retries = 2
for attempt in range(retries):
try:
batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
fetched_tickers.update(batch_tickers)
time.sleep(exchange.rateLimit / 1000 * 1.5) # Add delay
break
except (ccxt.RequestTimeout, ccxt.NetworkError) as e_timeout:
logging.warning(f"Ticker fetch timeout/network error on attempt {attempt+1}/{retries}: {e_timeout}, retrying after delay...")
time.sleep(3 * (attempt + 1))
except ccxt.RateLimitExceeded:
logging.warning(f"Rate limit exceeded fetching tickers, sleeping...")
time.sleep(10 * (attempt+1)) # Longer sleep for rate limit
except Exception as e_ticker:
logging.error(f"Error fetching ticker batch (attempt {attempt+1}): {e_ticker}")
if attempt == retries - 1: raise # Rethrow last error
time.sleep(2 * (attempt + 1))
logging.info(f"Fetched {len(fetched_tickers)} tickers using fetchTickers.")
else:
raise ccxt.NotSupported("fetchTickers not supported/enabled. Volume sorting requires it.")
except Exception as e:
logging.exception(f"Could not fetch tickers for volume sorting: {e}. Volume sorting unavailable.")
markets_cache = usdt_pairs[:MAX_COINS_TO_DISPLAY]
last_markets_update = current_time
logging.warning(f"Returning top {len(markets_cache)} unsorted pairs due to ticker error.")
return markets_cache
for symbol, ticker in fetched_tickers.items():
try:
quote_volume = ticker.get('info', {}).get('quoteVolume') # Prefer quoteVolume if available
last_price = ticker.get('last')
base_volume = ticker.get('baseVolume')
# Ensure values are convertible to float before calculation
valid_last = last_price is not None
valid_base = base_volume is not None
valid_quote = quote_volume is not None
if valid_quote:
volumes[symbol] = float(quote_volume)
elif valid_base and valid_last:
volumes[symbol] = float(base_volume) * float(last_price)
else:
volumes[symbol] = 0
except (TypeError, ValueError, KeyError, AttributeError) as e:
logging.warning(f"Could not parse volume/price for {symbol} from ticker: {ticker}. Error: {e}")
volumes[symbol] = 0
valid_volume_pairs = {k: v for k, v in volumes.items() if v > 0}
logging.info(f"Found {len(valid_volume_pairs)} pairs with non-zero volume.")
if not valid_volume_pairs:
logging.warning("No pairs with valid volume found. Returning default list.")
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT']
sorted_pairs = sorted(valid_volume_pairs.items(), key=lambda item: item[1], reverse=True)
num_pairs_to_take = min(MAX_COINS_TO_DISPLAY, len(sorted_pairs))
top_pairs = [pair[0] for pair in sorted_pairs[:num_pairs_to_take]]
logging.info(f"Selected Top {len(top_pairs)} pairs by volume. Top 5: {[p[0] for p in sorted_pairs[:5]]}")
markets_cache = top_pairs
last_markets_update = current_time
return top_pairs
except ccxt.NetworkError as e:
logging.error(f"Network error getting USDT pairs: {e}")
except ccxt.ExchangeError as e:
logging.error(f"Exchange error getting USDT pairs: {e}")
except Exception as e:
logging.exception("General error getting USDT pairs.")
logging.warning("Error fetching markets, returning default fallback list.")
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'BNB/USDT', 'XRP/USDT']
def clean_and_process_ohlcv(ohlcv_list, symbol, expected_candles):
# (Keep this function as is - no changes needed)
if not ohlcv_list:
return None
try:
df = pd.DataFrame(ohlcv_list, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
initial_len = len(df)
if initial_len == 0: return None
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
df = df.drop_duplicates(subset=['timestamp'])
df = df.sort_values('timestamp')
len_after_dupes = len(df)
numeric_cols = ['open', 'high', 'low', 'close', 'volume']
for col in numeric_cols:
df[col] = pd.to_numeric(df[col], errors='coerce')
# Drop rows with NaN in essential price/volume features needed for TA-Lib
df = df.dropna(subset=numeric_cols)
len_after_na = len(df)
df.reset_index(drop=True, inplace=True)
logging.debug(f"Data cleaning for {symbol}: Initial Fetched={initial_len}, AfterDupes={len_after_dupes}, AfterNA={len_after_na}")
if len(df) >= expected_candles:
final_df = df.iloc[-expected_candles:].copy() # Take the most recent ones
return final_df
else:
return None
except Exception as e:
logging.exception(f"Error processing DataFrame for {symbol}")
return None
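# fetch_historical_data() below tries two strategies: first a single large fetch_ohlcv call with a
# ~200-candle buffer, then, if that does not yield exactly `total_candles` clean rows, an iterative
# fallback that pages backwards in chunks of up to ~1000 candles until the required time range is
# covered. Successful results are cached for ~5 minutes per (symbol, timeframe, size) key.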
def fetch_historical_data(symbol, timeframe='1m', total_candles=WINDOW):
# (Keep this function as is - no changes needed)
cache_key = f"{symbol}_{timeframe}_{total_candles}"
current_time = time.time()
cache_validity_seconds = 300 # 5 minutes
if cache_key in data_cache:
cache_time, cached_data = data_cache[cache_key]
if current_time - cache_time < cache_validity_seconds:
if isinstance(cached_data, pd.DataFrame) and len(cached_data) == total_candles:
logging.debug(f"Using valid cached data for {symbol} ({len(cached_data)} candles)")
return cached_data.copy()
else:
logging.warning(f"Cache for {symbol} invalid or wrong size ({len(cached_data) if isinstance(cached_data, pd.DataFrame) else 'N/A'} vs {total_candles}), fetching fresh.")
if cache_key in data_cache: del data_cache[cache_key]
if not exchange.has['fetchOHLCV']:
logging.error(f"Exchange {exchange.id} does not support fetchOHLCV.")
return None
logging.debug(f"Fetching {total_candles} candles for {symbol} (timeframe: {timeframe})")
final_df = None
fetch_start_time = time.time()
duration_ms = exchange.parse_timeframe(timeframe) * 1000
now_ms = exchange.milliseconds()
# --- Strategy 1: Try Single Large Fetch ---
single_fetch_limit = total_candles + 200 # Buffer
single_fetch_since = now_ms - single_fetch_limit * duration_ms
try:
ohlcv_list = exchange.fetch_ohlcv(symbol, timeframe, limit=single_fetch_limit, since=single_fetch_since)
if ohlcv_list:
processed_df = clean_and_process_ohlcv(ohlcv_list, symbol, total_candles)
if processed_df is not None and len(processed_df) == total_candles:
final_df = processed_df
except ccxt.RateLimitExceeded as e:
logging.warning(f"Rate limit hit during single fetch for {symbol}, falling back: {e}")
time.sleep(5)
except (ccxt.RequestTimeout, ccxt.NetworkError) as e:
logging.warning(f"Timeout/Network error during single fetch for {symbol}, falling back: {e}")
time.sleep(2)
except ccxt.ExchangeNotAvailable as e:
logging.error(f"Exchange not available during fetch for {symbol}: {e}")
return None
except ccxt.AuthenticationError as e:
logging.error(f"Authentication error fetching {symbol}: {e}")
return None
except ccxt.ExchangeError as e:
logging.warning(f"Exchange error during single fetch for {symbol}, falling back: {e}")
except Exception as e:
logging.exception(f"Unexpected error during single fetch for {symbol}, falling back.")
# --- Strategy 2: Fallback to Iterative Chunking ---
if final_df is None:
logging.debug(f"Falling back to iterative chunk fetching for {symbol}.")
limit_per_call = exchange.safe_integer(exchange.limits.get('fetchOHLCV', {}), 'max', 1000)
limit_per_call = min(limit_per_call, 1000)
all_ohlcv_chunks = []
required_start_time_ms = now_ms - (total_candles + 5) * duration_ms
current_chunk_end_time_ms = now_ms
max_chunk_attempts = 15
attempts = 0
while attempts < max_chunk_attempts:
attempts += 1
oldest_ts_in_hand = all_ohlcv_chunks[0][0] if all_ohlcv_chunks else current_chunk_end_time_ms
if oldest_ts_in_hand <= required_start_time_ms:
logging.debug(f"Chunking: Collected enough historical range for {symbol}.")
break
fetch_limit = limit_per_call
chunk_fetch_since = oldest_ts_in_hand - fetch_limit * duration_ms
params = {}
try:
ohlcv_chunk = exchange.fetch_ohlcv(symbol, timeframe, since=chunk_fetch_since, limit=fetch_limit, params=params)
if not ohlcv_chunk:
logging.debug(f"Chunking: No more data received for {symbol} from API.")
break
new_chunk = [c for c in ohlcv_chunk if c[0] < oldest_ts_in_hand]
if not new_chunk:
break
new_chunk.sort(key=lambda x: x[0])
all_ohlcv_chunks = new_chunk + all_ohlcv_chunks
if len(new_chunk) < limit_per_call // 20 and attempts > 5:
logging.warning(f"Chunking: Received very few new candles ({len(new_chunk)}) repeatedly for {symbol}.")
break
time.sleep(exchange.rateLimit / 1000 * 1.1)
except ccxt.RateLimitExceeded as e:
logging.warning(f"Rate limit hit during chunking for {symbol}, sleeping 10s: {e}")
time.sleep(10 * (attempts/3 + 1))
except (ccxt.NetworkError, ccxt.RequestTimeout) as e:
logging.error(f"Network/Timeout error during chunking for {symbol}: {e}. Stopping.")
break
except ccxt.ExchangeError as e:
logging.error(f"Exchange error during chunking for {symbol}: {e}. Stopping.")
break
except Exception as e:
logging.exception(f"Generic error during chunking fetch for {symbol}")
break
if attempts >= max_chunk_attempts:
logging.warning(f"Max chunk fetch attempts reached for {symbol}.")
if all_ohlcv_chunks:
processed_df = clean_and_process_ohlcv(all_ohlcv_chunks, symbol, total_candles)
if processed_df is not None and len(processed_df) == total_candles:
final_df = processed_df
else:
logging.error(f"No data obtained from chunk fetching for {symbol}.")
# --- Final Check and Caching ---
if final_df is not None and len(final_df) == total_candles:
expected_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
if all(col in final_df.columns for col in expected_cols):
data_cache[cache_key] = (current_time, final_df.copy())
return final_df
else:
logging.error(f"Final DataFrame for {symbol} missing expected columns. Won't cache.")
return None
else:
logging.error(f"Failed to fetch exactly {total_candles} candles for {symbol}. Found: {len(final_df) if final_df is not None else 0}")
return None
# --- Embedding, LLT, Normalize, Training Prep (Largely unchanged) ---
# Keep create_embedding, llt_transform, normalize_data, prepare_training_data, train_model
# as they don't depend on the TA library choice.
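# create_embedding() builds a time-delay embedding: row t of the output is
#   [x[t], x[t+lag], x[t+2*lag], ..., x[t+(l-1)*lag]].
# Tiny illustration with l=3, lag=2 and x = [x0, x1, x2, x3, x4, x5, x6]:
#   A = [[x0, x2, x4],
#        [x1, x3, x5],
#        [x2, x4, x6]]
# With the defaults L=10 and LAG=11, each K=720-minute feature series yields 720 - 9*11 = 621 rows.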
def create_embedding(data, l=L, lag=LAG):
# (Keep this function as is)
n = len(data)
rows = n - (l - 1) * lag
if rows <= 0:
logging.debug(f"Cannot create embedding: data length {n} too short for L={l}, Lag={lag}")
return np.array([])
A = np.zeros((rows, l))
try:
for t in range(rows):
indices = t + np.arange(l) * lag
A[t] = data[indices]
return A
except IndexError as e:
logging.error(f"IndexError during embedding: n={n}, l={l}, lag={lag}. Error: {e}")
return np.array([])
except Exception as e:
logging.exception("Error in create_embedding")
return np.array([])
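# llt_transform() implements the "law learning" step: for every training window, feature and class
# label it forms the embedding A, computes S = A.T @ A, and stores the right singular vector
# belonging to the smallest singular value of S as that example's "law". At transform time each
# instance's S is projected onto all stored laws of each class (S @ V) and the D smallest column
# variances per class are kept, giving a feature vector of length len(FEATURES) * 2 * D
# (= 5 * 2 * 5 = 50) that the KNN classifier consumes.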
def llt_transform(X_train, y_train, X_test):
# (Keep this function as is)
if not isinstance(X_train, np.ndarray) or X_train.ndim != 3 or \
not isinstance(y_train, np.ndarray) or y_train.ndim != 1 or \
not isinstance(X_test, np.ndarray) or (X_test.size > 0 and X_test.ndim != 3):
logging.error(f"LLT input type/shape error.")
return np.array([]), np.array([])
if X_train.shape[0] != y_train.shape[0]:
logging.error(f"LLT input mismatch: len(X_train) != len(y_train)")
return np.array([]), np.array([])
if X_train.size == 0 or y_train.size == 0:
logging.error("LLT requires non-empty training data.")
return np.array([]), np.array([])
if X_test.size > 0 and X_test.shape[1:] != X_train.shape[1:]:
logging.error(f"LLT train/test shape mismatch")
return np.array([]), np.array([])
try:
num_features = X_train.shape[2]
if num_features != len(FEATURES):
logging.error(f"LLT: Feature count mismatch.")
return np.array([]), np.array([])
V = {j: {'0': [], '1': []} for j in range(num_features)}
laws_computed_count = {j: {'0': 0, '1': 0} for j in range(num_features)}
for i in range(len(X_train)):
label = str(int(y_train[i]))
if label not in ['0', '1']: continue
for j in range(num_features):
feature_data = X_train[i, :, j]
A = create_embedding(feature_data, l=L, lag=LAG)
if A.shape[0] < L: continue
if np.isnan(A).any() or np.isinf(A).any(): continue
try:
S = A.T @ A
if np.isnan(S).any() or np.isinf(S).any(): continue
U, s, Vt = svd(S, full_matrices=False)
if Vt.shape[0] < L or Vt.shape[1] != L: continue
if s[-1] < 1e-9: continue
v = Vt[-1]
norm = np.linalg.norm(v)
if norm < 1e-9: continue
V[j][label].append(v / norm)
laws_computed_count[j][label] += 1
except np.linalg.LinAlgError: pass
except Exception: pass
valid_laws_exist = False
for j in V:
for c in ['0', '1']:
if laws_computed_count[j][c] > 0:
valid_vecs = [vec for vec in V[j][c] if isinstance(vec, np.ndarray) and vec.shape == (L,)]
if not valid_vecs:
V[j][c] = np.zeros((L, 0))
continue
try:
V[j][c] = np.array(valid_vecs).T
if V[j][c].shape[0] != L:
V[j][c] = np.zeros((L, 0))
else:
valid_laws_exist = True
except Exception: V[j][c] = np.zeros((L, 0))
else: V[j][c] = np.zeros((L, 0))
if not valid_laws_exist:
logging.error("LLT ERROR: No valid laws computed.")
return np.array([]), np.array([])
def transform_instance(X_instance):
transformed_features = []
if X_instance.ndim != 2 or X_instance.shape[0] != K or X_instance.shape[1] != num_features:
return np.zeros(num_features * 2 * D)
for j in range(num_features):
feature_data = X_instance[:, j]
A = create_embedding(feature_data, l=L, lag=LAG)
if A.shape[0] < L:
transformed_features.extend([0.0] * (2 * D))
continue
if np.isnan(A).any() or np.isinf(A).any():
transformed_features.extend([0.0] * (2 * D))
continue
try:
S = A.T @ A
if np.isnan(S).any() or np.isinf(S).any():
transformed_features.extend([0.0] * (2 * D))
continue
for c in ['0', '1']:
if V[j][c].shape[1] == 0:
transformed_features.extend([0.0] * D)
continue
S_V = S @ V[j][c]
if S_V.size == 0 or np.isnan(S_V).any() or np.isinf(S_V).any():
transformed_features.extend([0.0] * D)
continue
variances = np.var(S_V, axis=0)
if variances.size == 0:
transformed_features.extend([0.0] * D)
continue
variances = np.nan_to_num(variances, nan=np.finfo(variances.dtype).max, posinf=np.finfo(variances.dtype).max, neginf=np.finfo(variances.dtype).max)
num_vars_available = variances.size
num_vars_to_select = min(D, num_vars_available)
smallest_indices = np.argpartition(variances, num_vars_to_select -1)[:num_vars_to_select]
smallest_vars = np.sort(variances[smallest_indices])
padded_vars = np.pad(smallest_vars, (0, D - num_vars_to_select), 'constant', constant_values=0.0)
if np.isnan(padded_vars).any() or np.isinf(padded_vars).any():
padded_vars = np.nan_to_num(padded_vars, nan=0.0, posinf=0.0, neginf=0.0)
transformed_features.extend(padded_vars)
except Exception:
current_len = len(transformed_features)
expected_len_after_feature = (j + 1) * 2 * D
num_missing = expected_len_after_feature - current_len
if num_missing > 0: transformed_features.extend([0.0] * num_missing)
transformed_features = transformed_features[:expected_len_after_feature]
correct_len = num_features * 2 * D
if len(transformed_features) != correct_len:
if len(transformed_features) < correct_len: transformed_features.extend([0.0] * (correct_len - len(transformed_features)))
else: transformed_features = transformed_features[:correct_len]
return np.array(transformed_features)
X_train_t = np.array([transform_instance(X) for X in X_train])
X_test_t = np.array([])
if X_test.size > 0: X_test_t = np.array([transform_instance(X) for X in X_test])
expected_dim = num_features * 2 * D
if X_train_t.shape[0] != len(X_train) or (X_train_t.size > 0 and X_train_t.shape[1] != expected_dim):
logging.error(f"LLT Train transform resulted in unexpected shape.")
return np.array([]), np.array([])
if X_test.size > 0 and (X_test_t.shape[0] != len(X_test) or (X_test_t.size > 0 and X_test_t.shape[1] != expected_dim)):
logging.error(f"LLT Test transform resulted in unexpected shape.")
return X_train_t, np.array([])
return X_train_t, X_test_t
except Exception as e:
logging.exception("Error in llt_transform function")
return np.array([]), np.array([])
def normalize_data(df):
# (Keep this function as is)
normalized_df = df.copy()
if not isinstance(df, pd.DataFrame):
logging.error("Normalize_data received non-DataFrame input.")
return None
for feature in FEATURES:
if feature == 'timestamp': continue
if feature not in df.columns:
normalized_df[feature] = 0.0
continue
if pd.api.types.is_numeric_dtype(df[feature]):
mean = df[feature].mean()
std = df[feature].std()
if std is not None and not pd.isna(std) and std > 1e-9:
normalized_df[feature] = (df[feature] - mean) / std
else:
normalized_df[feature] = 0.0
if normalized_df[feature].isnull().any():
normalized_df[feature] = normalized_df[feature].fillna(0.0)
else:
normalized_df[feature] = 0.0
return normalized_df
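# Note: normalize_data() z-scores each OHLCV column using the mean/std of the window it is given,
# so every training window and every live prediction window is normalized independently; columns
# with (near-)zero variance are set to 0.0.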
def generate_synthetic_data(symbol, total_candles=WINDOW):
# (Keep this function as is)
logging.info(f"Generating synthetic data for {symbol} ({total_candles} candles)")
np.random.seed(int(time.time() * 1000) % (2**32 - 1))
end_time = pd.Timestamp.now(tz='UTC')
    timestamps = pd.date_range(end=end_time, periods=total_candles, freq='min')  # 'T' is a deprecated alias in newer pandas
volatility = np.random.uniform(0.005, 0.03)
base_price = np.random.uniform(1, 5000)
prices = [base_price]
for _ in range(1, total_candles):
change = np.random.normal(0, volatility / np.sqrt(1440))
prices.append(prices[-1] * (1 + change))
prices = np.maximum(0.01, prices)
close_prices = np.array(prices)
open_prices = close_prices * (1 + np.random.normal(0, volatility / np.sqrt(1440) / 2, total_candles))
high_prices = np.maximum(close_prices, open_prices) * (1 + np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
low_prices = np.minimum(close_prices, open_prices) * (1 - np.random.uniform(0, volatility / np.sqrt(1440), total_candles))
high_prices = np.maximum.reduce([high_prices, open_prices, close_prices])
low_prices = np.minimum.reduce([low_prices, open_prices, close_prices])
volumes = np.random.poisson(base_price * np.random.uniform(1, 10)) * (1 + np.abs(np.diff(close_prices, prepend=close_prices[0])) / close_prices * 5)
volumes = np.maximum(1, volumes)
df = pd.DataFrame({
'timestamp': timestamps, 'open': open_prices, 'high': high_prices,
'low': low_prices, 'close': close_prices, 'volume': volumes
})
for col in FEATURES: df[col] = pd.to_numeric(df[col])
df.reset_index(drop=True, inplace=True)
return df
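# prepare_training_data() slides an 840-minute window backwards through the fetched history in
# OVERLAP_STEP (60-minute) steps. Each training example is the first K=720 normalized candles of a
# window; its label is 1 ("Rise") if the raw close at minute 840 is above the raw close at minute
# 720, else 0 ("Fall"). Classes are then down-sampled to the minority count so the KNN sees a
# balanced set.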
def prepare_training_data(symbol, total_candles_to_fetch=WINDOW + OVERLAP_STEP * 20):
# (Keep this function as is)
logging.info(f"Preparing training data for {symbol}...")
try:
required_base_candles = WINDOW
estimated_candles_needed = required_base_candles + (MIN_TRAINING_EXAMPLES * 2) * OVERLAP_STEP + 500
fetch_candle_count = max(WINDOW + 500, estimated_candles_needed)
logging.info(f"Fetching {fetch_candle_count} candles for {symbol} training prep...")
df = fetch_historical_data(symbol, timeframe='1m', total_candles=fetch_candle_count)
if df is None or len(df) < WINDOW:
logging.error(f"Insufficient data fetched for {symbol} ({len(df) if df is not None else 0} < {WINDOW}).")
if USE_SYNTHETIC_DATA_FOR_LOW_VOLUME:
logging.warning(f"Attempting synthetic data generation for {symbol}.")
df = generate_synthetic_data(symbol, total_candles=WINDOW + OVERLAP_STEP * 10)
if df is None or len(df) < WINDOW:
logging.error(f"Synthetic data generation failed or insufficient for {symbol}.")
return None, None
else: logging.info(f"Using synthetic data ({len(df)} points) for {symbol}.")
else: return None, None
df_normalized = normalize_data(df)
if df_normalized is None:
logging.error(f"Normalization failed for {symbol}.")
return None, None
if df_normalized[FEATURES].isnull().any().any():
logging.warning(f"NaN values found after normalization for {symbol}. Filling with 0.")
df_normalized = df_normalized.fillna(0.0)
X, y = [], []
end_index = len(df)
start_index = WINDOW
num_windows_created = 0
for i in range(end_index, start_index - 1, -OVERLAP_STEP):
window_end_idx = i
window_start_idx = i - WINDOW
if window_start_idx < 0: continue
window_orig = df.iloc[window_start_idx:window_end_idx]
window_norm = df_normalized.iloc[window_start_idx:window_end_idx]
if len(window_orig) != WINDOW or len(window_norm) != WINDOW: continue
input_data_norm = window_norm.iloc[:K][FEATURES].values
if input_data_norm.shape[0] != K or input_data_norm.shape[1] != len(FEATURES): continue
if np.isnan(input_data_norm).any(): continue
start_price_iloc_idx = K - 1
end_price_iloc_idx = WINDOW - 1
start_price = window_orig['close'].iloc[start_price_iloc_idx]
end_price = window_orig['close'].iloc[end_price_iloc_idx]
if pd.isna(start_price) or pd.isna(end_price) or start_price <= 0: continue
X.append(input_data_norm)
y.append(1 if end_price > start_price else 0)
num_windows_created += 1
if not X:
logging.error(f"No valid windows created for {symbol}.")
return None, None
X = np.array(X)
y = np.array(y)
unique_classes, class_counts = np.unique(y, return_counts=True)
class_dist_str = ", ".join([f"Class {cls}: {count}" for cls, count in zip(unique_classes, class_counts)])
logging.info(f"Class distribution BEFORE balancing for {symbol}: {class_dist_str}")
if len(unique_classes) < 2:
logging.error(f"ONLY ONE CLASS ({unique_classes[0]}) present for {symbol}.")
return None, None
min_class_count = min(class_counts)
if min_class_count * 2 < MIN_TRAINING_EXAMPLES:
logging.error(f"Minority class count ({min_class_count}) too low for {symbol}.")
return None, None
samples_per_class = min_class_count
balanced_indices = []
for class_val in unique_classes:
class_indices = np.where(y == class_val)[0]
num_to_choose = min(samples_per_class, len(class_indices))
chosen_indices = np.random.choice(class_indices, size=num_to_choose, replace=False)
balanced_indices.extend(chosen_indices)
np.random.shuffle(balanced_indices)
X_balanced = X[balanced_indices]
y_balanced = y[balanced_indices]
final_unique, final_counts = np.unique(y_balanced, return_counts=True)
logging.info(f"Balanced dataset for {symbol}: {len(X_balanced)} instances. Final counts: {dict(zip(final_unique, final_counts))}")
if len(X_balanced) < MIN_TRAINING_EXAMPLES:
logging.error(f"Insufficient data ({len(X_balanced)}) for {symbol} AFTER balancing.")
return None, None
if X_balanced.ndim != 3 or X_balanced.shape[0] == 0 or X_balanced.shape[1] != K or X_balanced.shape[2] != len(FEATURES):
logging.error(f"Final balanced data has unexpected shape {X_balanced.shape} for {symbol}.")
return None, None
return X_balanced, y_balanced
except Exception as e:
logging.exception(f"Error preparing training data for {symbol}")
return None, None
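# train_model() holds out roughly 20% of the balanced windows for validation, fits the LLT laws on
# the training split only, and trains a distance-weighted KNeighborsClassifier with k capped at 5
# (and forced odd). The raw training arrays are returned alongside the fitted model because
# predict_real_time() re-derives the LLT laws from them for every live prediction.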
def train_model(symbol):
# (Keep this function as is)
logging.info(f"--- Attempting to train model for {symbol} ---")
np.random.seed(int(time.time()) % (2**32 - 1))
X, y = prepare_training_data(symbol)
if X is None or y is None:
logging.error(f"Failed to prepare training data for {symbol}. Training aborted.")
return None, None, None
try:
        accuracy = None  # None = not yet evaluated; -1.0 = validation skipped or failed
if len(X) < MIN_TRAINING_EXAMPLES + 2:
logging.warning(f"Dataset for {symbol} too small ({len(X)}). Training on all data.")
X_train, y_train = X, y
X_val, y_val = np.array([]), np.array([])
else:
indices = np.random.permutation(len(X))
val_size = max(1, int(len(X) * 0.2))
split_idx = len(X) - val_size
train_indices, val_indices = indices[:split_idx], indices[split_idx:]
if len(train_indices) == 0 or len(val_indices) == 0:
logging.error(f"Train/Val split resulted in zero samples. Training on all data.")
X_train, y_train = X, y
X_val, y_val = np.array([]), np.array([])
else:
X_train, X_val = X[train_indices], X[val_indices]
y_train, y_val = y[train_indices], y[val_indices]
if len(np.unique(y_train)) < 2:
logging.error(f"Only one class in TRAINING set after split for {symbol}. Aborting.")
return None, None, None
if len(np.unique(y_val)) < 2:
logging.warning(f"Only one class in VALIDATION set after split for {symbol}.")
if X_val.size == 0: X_val_shaped = np.empty((0, K, len(FEATURES)))
else: X_val_shaped = X_val
X_train_t, X_val_t = llt_transform(X_train, y_train, X_val_shaped)
if X_train_t.size == 0:
logging.error(f"LLT training transformation failed for {symbol}. Training aborted.")
return None, None, None
if X_val.size > 0 and X_val_t.size == 0:
logging.warning(f"LLT validation transformation failed for {symbol}.")
accuracy = -1.0
if np.isnan(X_train_t).any() or np.isinf(X_train_t).any():
logging.error(f"NaN/Inf in LLT transformed TRAINING data for {symbol}. Training aborted.")
return None, None, None
if X_val_t.size > 0 and (np.isnan(X_val_t).any() or np.isinf(X_val_t).any()):
logging.warning(f"NaN/Inf in LLT transformed VALIDATION data for {symbol}.")
accuracy = -1.0
n_neighbors = min(5, len(y_train) - 1) if len(y_train) > 1 else 1
n_neighbors = max(1, n_neighbors)
if n_neighbors > 1 and n_neighbors % 2 == 0: n_neighbors -= 1
model = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance')
model.fit(X_train_t, y_train)
        if accuracy is None and X_val_t.size > 0:
try:
accuracy = model.score(X_val_t, y_val)
logging.info(f"Model for {symbol} trained. Validation Accuracy: {accuracy:.3f}")
except Exception as eval_e:
logging.exception(f"Error during KNN validation scoring for {symbol}: {eval_e}")
accuracy = -1.0
elif accuracy == -1.0:
logging.info(f"Model for {symbol} trained. Validation skipped or failed.")
else:
logging.info(f"Model for {symbol} trained. No validation data.")
accuracy = -1.0
return model, X_train, y_train
except Exception as e:
logging.exception(f"Error during model training pipeline for {symbol}")
return None, None, None
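# Hedged usage sketch (names as used in this file): given a trained entry from `trained_models`,
#   model_data = trained_models.get('BTC/USDT')
#   label, confidence = predict_real_time('BTC/USDT', model_data)
# `label` is "Rise"/"Fall" (or a short error string) and `confidence` is the KNN probability of the
# predicted class over the next PREDICTION_WINDOW_HOURS hours.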
def predict_real_time(symbol, model_data):
# (Keep this function as is)
if model_data is None: return "Model N/A", 0.0
model, X_train_orig_for_llt, y_train_orig_for_llt = model_data
if model is None or X_train_orig_for_llt is None or y_train_orig_for_llt is None:
logging.error(f"Invalid model data tuple for prediction on {symbol}")
return "Model Error", 0.0
if X_train_orig_for_llt.size == 0 or y_train_orig_for_llt.size == 0:
logging.error(f"Training data for LLT laws is empty for {symbol}")
return "LLT Data Error", 0.0
try:
df = fetch_historical_data(symbol, timeframe='1m', total_candles=K + 60)
if df is None or len(df) < K:
return "Data Error", 0.0
df_recent = df.iloc[-K:]
if len(df_recent) != K:
return "Data Error", 0.0
df_recent_normalized = normalize_data(df_recent)
if df_recent_normalized is None: return "Norm Error", 0.0
if df_recent_normalized[FEATURES].isnull().any().any():
df_recent_normalized = df_recent_normalized.fillna(0.0)
X_predict_input = np.array([df_recent_normalized[FEATURES].values])
_, X_predict_transformed = llt_transform(X_train_orig_for_llt, y_train_orig_for_llt, X_predict_input)
if X_predict_transformed.size == 0 or X_predict_transformed.shape[0] != 1:
return "Transform Error", 0.0
if np.isnan(X_predict_transformed).any() or np.isinf(X_predict_transformed).any():
X_predict_transformed = np.nan_to_num(X_predict_transformed, nan=0.0, posinf=0.0, neginf=0.0)
try:
probabilities = model.predict_proba(X_predict_transformed)
if probabilities.shape[0] != 1 or probabilities.shape[1] != 2:
return "Predict Error", 0.0
prob_class_1 = probabilities[0, 1]
prediction_label = "Rise" if prob_class_1 >= 0.5 else "Fall"
confidence = prob_class_1 if prediction_label == "Rise" else probabilities[0, 0]
return prediction_label, confidence
except Exception as knn_e:
logging.exception(f"Error during KNN prediction probability for {symbol}")
return "Predict Error", 0.0
except Exception as e:
logging.exception(f"Error in predict_real_time for {symbol}")
return "Error", 0.0
# --- TA Calculation Function (Using TA-Lib) ---
def calculate_ta_indicators(df_ta):
"""
Calculates TA indicators (RSI, MACD, VWAP, ATR) using TA-Lib.
Requires df_ta to have 'open', 'high', 'low', 'close', 'volume' columns.
"""
indicators = {'RSI': np.nan, 'MACD': np.nan, 'MACD_Signal': np.nan, 'MACD_Hist': np.nan, 'VWAP': np.nan, 'ATR': np.nan}
required_cols = ['open', 'high', 'low', 'close', 'volume']
    min_len_needed = max(RSI_PERIOD, MACD_SLOW + MACD_SIGNAL, ATR_PERIOD) + 1  # MACD signal line needs slow + signal periods of data
if df_ta is None or len(df_ta) < min_len_needed:
logging.warning(f"Insufficient data ({len(df_ta) if df_ta is not None else 0} < {min_len_needed}) for TA-Lib calculations.")
return indicators
# Ensure columns exist
if not all(col in df_ta.columns for col in required_cols):
logging.error(f"Missing required columns for TA-Lib: Have {df_ta.columns}, Need {required_cols}")
return indicators
# --- Prepare data for TA-Lib (NumPy arrays, handle NaNs) ---
df_ta = df_ta[required_cols].copy() # Work on a copy with only needed columns
# Check for NaNs BEFORE converting to numpy, TA-Lib generally dislikes them
if df_ta.isnull().values.any():
nan_count = df_ta.isnull().sum().sum()
logging.warning(f"Found {nan_count} NaN(s) in TA input data. Applying ffill()...")
df_ta.ffill(inplace=True) # Forward fill NaNs
# Check again after ffill - if NaNs remain (e.g., at the start), need more robust handling
if df_ta.isnull().values.any():
logging.error(f"NaNs still present after ffill. Cannot proceed with TA-Lib.")
return indicators # Return NaNs
try:
# Convert to NumPy arrays of type float
open_p = df_ta['open'].values.astype(float)
high_p = df_ta['high'].values.astype(float)
low_p = df_ta['low'].values.astype(float)
close_p = df_ta['close'].values.astype(float)
volume_p = df_ta['volume'].values.astype(float)
# --- Calculate Indicators using TA-Lib ---
# RSI
rsi_values = talib.RSI(close_p, timeperiod=RSI_PERIOD)
indicators['RSI'] = rsi_values[-1] if len(rsi_values) > 0 else np.nan
# MACD
macd_line, signal_line, hist = talib.MACD(close_p, fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL)
indicators['MACD'] = macd_line[-1] if len(macd_line) > 0 else np.nan
indicators['MACD_Signal'] = signal_line[-1] if len(signal_line) > 0 else np.nan
indicators['MACD_Hist'] = hist[-1] if len(hist) > 0 else np.nan
# ATR
atr_values = talib.ATR(high_p, low_p, close_p, timeperiod=ATR_PERIOD)
indicators['ATR'] = atr_values[-1] if len(atr_values) > 0 else np.nan
# VWAP (Manual Calculation - TA-Lib doesn't have it built-in)
typical_price = (high_p + low_p + close_p) / 3.0
tp_vol = typical_price * volume_p
cumulative_volume = np.cumsum(volume_p)
# Avoid division by zero if volume is zero for initial periods
if cumulative_volume[-1] > 1e-12: # Check if there's significant volume
vwap_values = np.cumsum(tp_vol) / np.maximum(cumulative_volume, 1e-12) # Avoid div by zero strictly
indicators['VWAP'] = vwap_values[-1]
else:
indicators['VWAP'] = np.nan # VWAP undefined if no volume
# Final check for NaNs in results (TA-Lib might return NaN for initial periods)
for key, value in indicators.items():
if pd.isna(value):
indicators[key] = np.nan # Ensure consistent NaN representation
# logging.debug(f"TA-Lib Indicators calculated: {indicators}")
return indicators
except Exception as ta_e:
logging.exception(f"Error calculating TA indicators using TA-Lib: {ta_e}")
return {k: np.nan for k in indicators} # Return NaNs on error
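# Note on calculate_ta_indicators(): it returns only the most recent value of each indicator, e.g.
# {'RSI': ..., 'MACD': ..., 'MACD_Signal': ..., 'MACD_Hist': ..., 'VWAP': ..., 'ATR': ...}.
# The VWAP is computed manually as a cumulative value over whatever window is passed in
# (TA_DATA_POINTS = 200 one-minute candles here), not a session-anchored VWAP, since TA-Lib has no
# VWAP function.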
# --- Trade Level Calculation (Unchanged) ---
def calculate_trade_levels(prediction, confidence, current_price, atr):
# (Keep this function as is - no changes needed)
levels = {'Entry': np.nan, 'TP1': np.nan, 'TP2': np.nan, 'SL': np.nan}
if pd.isna(current_price) or current_price <= 0 or pd.isna(atr) or atr <= 0:
return levels
if prediction == "Rise" and confidence >= CONFIDENCE_THRESHOLD:
entry_price = current_price
levels['Entry'] = entry_price
levels['TP1'] = entry_price + TP1_ATR_MULTIPLIER * atr
levels['TP2'] = entry_price + TP2_ATR_MULTIPLIER * atr
levels['SL'] = entry_price - SL_ATR_MULTIPLIER * atr
levels['SL'] = max(0.01, levels['SL'])
# Add Fall logic here if needed
return levels
# --- Concurrency Wrappers (Unchanged) ---
def train_model_task(coin):
# (Keep this function as is)
try:
        result = train_model(coin)
        # train_model returns (None, None, None) when data preparation or training fails
        model, X_train_orig, y_train_orig = result
        if model is not None:
            return coin, (model, X_train_orig, y_train_orig)
        else:
            return coin, None
except Exception as e:
logging.exception(f"Unhandled exception in train_model_task for {coin}")
return coin, None
def train_all_models(coin_list=None, num_workers=NUM_WORKERS_TRAINING):
# (Keep this function as is)
global trained_models
start_time = time.time()
if coin_list is None or not coin_list:
logging.info("No coin list provided, fetching top coins by volume...")
try:
coin_list = get_all_usdt_pairs()
if not coin_list:
msg = "Failed to fetch coin list even with fallback. Training aborted."
logging.error(msg)
return msg
except Exception as e:
msg = f"Error fetching coin list: {e}. Training aborted."
logging.exception(msg)
return msg
logging.info(f"Starting training for {len(coin_list)} coins using {num_workers} workers...")
results_log = []
successful_trains = 0
failed_trains = 0
new_models = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix='TrainWorker') as executor:
future_to_coin = {executor.submit(train_model_task, coin): coin for coin in coin_list}
processed_count = 0
total_coins = len(coin_list)
for future in concurrent.futures.as_completed(future_to_coin):
processed_count += 1
coin = future_to_coin[future]
try:
returned_coin, model_data = future.result()
if returned_coin == coin and model_data is not None:
new_models[returned_coin] = model_data
results_log.append(f"✅ {returned_coin}: Model trained successfully.")
successful_trains += 1
else:
results_log.append(f"❌ {coin}: Model training failed (check logs).")
failed_trains += 1
except Exception as e:
results_log.append(f"❌ {coin}: Training task generated exception: {e}")
failed_trains += 1
logging.exception(f"Exception from training future for {coin}")
if processed_count % 10 == 0 or processed_count == total_coins:
logging.info(f"Training progress: {processed_count}/{total_coins} coins processed.")
logging.getLogger().handlers[0].flush()
trained_models.update(new_models)
logging.info(f"Updated global models dictionary. Total models now: {len(trained_models)}")
end_time = time.time()
duration = end_time - start_time
completion_message = (
f"Training run completed in {duration:.2f} seconds.\n"
f"Successfully trained: {successful_trains}\n"
f"Failed to train: {failed_trains}\n"
f"Total models available now: {len(trained_models)}"
)
logging.info(completion_message)
return completion_message + "\n\n" + "\n".join(results_log[-20:])
# --- Update Predictions Table (Mostly Unchanged, uses new TA function) ---
def update_predictions_table():
# (This function structure remains the same, it just calls the new calculate_ta_indicators)
global last_update_time
logging.info("--- Updating Predictions Table ---")
start_time = time.time()
predictions_data = {}
current_models = trained_models.copy()
if not current_models:
msg = "No models available. Please train first."
logging.warning(msg)
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
return pd.DataFrame([], columns=cols), msg
symbols_with_models = list(current_models.keys())
logging.info(f"Step 1: Generating predictions for {len(symbols_with_models)} models...")
# --- Stage 1: Get Predictions Concurrently ---
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='PredictWorker') as executor:
future_to_coin_pred = {executor.submit(predict_real_time, coin, model_data): coin for coin, model_data in current_models.items()}
pred_success = 0
pred_fail = 0
for future in concurrent.futures.as_completed(future_to_coin_pred):
coin = future_to_coin_pred[future]
try:
pred, conf = future.result()
if pred not in ["Model N/A", "Model Error", "Data Error", "Norm Error", "LLT Data Error", "Transform Error", "Predict Error", "Error"]:
predictions_data[coin] = {'prediction': pred, 'confidence': float(conf)}
pred_success += 1
else:
predictions_data[coin] = {'prediction': pred, 'confidence': 0.0}
pred_fail += 1
except Exception as e:
logging.exception(f"Error getting prediction result for {coin}")
predictions_data[coin] = {'prediction': "Future Error", 'confidence': 0.0}
pred_fail +=1
logging.info(f"Step 1 Complete: Predictions generated ({pred_success} success, {pred_fail} fail).")
# --- Stage 2: Fetch Current Tickers & TA Data Concurrently ---
symbols_to_fetch_data = list(predictions_data.keys())
if not symbols_to_fetch_data:
logging.warning("No symbols with predictions to fetch data for.")
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
return pd.DataFrame([], columns=cols), "No symbols processed."
logging.info(f"Step 2: Fetching Tickers and {TA_DATA_POINTS} OHLCV candles for {len(symbols_to_fetch_data)} symbols...")
tickers_data = {}
ohlcv_data = {}
try: # Fetch Tickers
batch_size_tickers = 100
fetched_tickers_batch = {}
for i in range(0, len(symbols_to_fetch_data), batch_size_tickers):
batch_symbols = symbols_to_fetch_data[i:i+batch_size_tickers]
try:
batch_tickers = exchange.fetch_tickers(symbols=batch_symbols)
fetched_tickers_batch.update(batch_tickers)
time.sleep(exchange.rateLimit / 1000 * 0.5)
except Exception as e:
logging.error(f"Failed to fetch ticker batch starting with {batch_symbols[0]}: {e}")
tickers_data = fetched_tickers_batch
logging.info(f"Fetched {len(tickers_data)} tickers.")
except Exception as e:
logging.exception(f"Error fetching tickers in prediction update: {e}")
# Fetch OHLCV for TA
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='TADataWorker') as executor:
future_to_coin_ohlcv = {executor.submit(fetch_historical_data, coin, '1m', TA_DATA_POINTS): coin for coin in symbols_to_fetch_data}
for future in concurrent.futures.as_completed(future_to_coin_ohlcv):
coin = future_to_coin_ohlcv[future]
try:
df_ta = future.result()
if df_ta is not None and len(df_ta) == TA_DATA_POINTS:
# Ensure standard column names expected by calculate_ta_indicators
df_ta.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
ohlcv_data[coin] = df_ta
except Exception as e:
logging.exception(f"Error fetching TA OHLCV data for {coin}")
logging.info(f"Step 2 Complete: Fetched TA data for {len(ohlcv_data)} symbols.")
# --- Stage 3: Calculate TA & Trade Levels ---
logging.info(f"Step 3: Calculating TA (using TA-Lib) and Trade Levels...")
final_results = []
processing_time = datetime.now(timezone.utc)
for symbol in symbols_to_fetch_data:
pred_info = predictions_data.get(symbol, {'prediction': 'Missing Pred', 'confidence': 0.0})
ticker = tickers_data.get(symbol)
df_ta = ohlcv_data.get(symbol) # This df should have standard columns now
current_price, quote_volume = np.nan, np.nan
ta_indicators = {k: np.nan for k in ['RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP', 'ATR']}
trade_levels = {k: np.nan for k in ['Entry', 'TP1', 'TP2', 'SL']}
entry_time, exit_time = pd.NaT, pd.NaT
if ticker and isinstance(ticker, dict):
current_price = ticker.get('last', np.nan)
quote_volume = ticker.get('info', {}).get('quoteVolume')
if quote_volume is None:
base_volume = ticker.get('baseVolume')
if base_volume is not None and current_price is not None:
try: quote_volume = float(base_volume) * float(current_price)
except (ValueError, TypeError): quote_volume = np.nan
try: current_price = float(current_price) if current_price is not None else np.nan
except (ValueError, TypeError): current_price = np.nan
try: quote_volume = float(quote_volume) if quote_volume is not None else np.nan
except (ValueError, TypeError): quote_volume = np.nan
# Calculate TA using the new function
if df_ta is not None:
ta_indicators = calculate_ta_indicators(df_ta) # Calls the TA-Lib version
if pred_info['prediction'] in ["Rise", "Fall"] and not pd.isna(current_price) and not pd.isna(ta_indicators['ATR']):
trade_levels = calculate_trade_levels(pred_info['prediction'], pred_info['confidence'], current_price, ta_indicators['ATR'])
if not pd.isna(trade_levels['Entry']):
entry_time = processing_time
exit_time = processing_time + timedelta(hours=PREDICTION_WINDOW_HOURS)
final_results.append({
'coin': symbol.split('/')[0], 'full_symbol': symbol,
'prediction': pred_info['prediction'], 'confidence': pred_info['confidence'],
'price': current_price, 'volume': quote_volume,
'entry': trade_levels['Entry'], 'entry_time': entry_time, 'exit_time': exit_time,
'tp1': trade_levels['TP1'], 'tp2': trade_levels['TP2'], 'sl': trade_levels['SL'],
'rsi': ta_indicators['RSI'], 'macd_hist': ta_indicators['MACD_Hist'],
'vwap': ta_indicators['VWAP'], 'atr': ta_indicators['ATR']
})
logging.info("Step 3 Complete: TA and Trade Levels calculated.")
# --- Stage 4: Sort and Format (Unchanged) ---
def sort_key(item):
pred, conf = item['prediction'], item['confidence']
if pred == "Rise" and conf >= CONFIDENCE_THRESHOLD and not pd.isna(item['entry']): return (0, -conf)
elif pred == "Rise": return (1, -conf)
elif pred == "Fall": return (2, -conf)
else: return (3, 0)
final_results.sort(key=sort_key)
formatted_output = []
for i, p in enumerate(final_results[:MAX_COINS_TO_DISPLAY]):
formatted_output.append([
i + 1, p['coin'], p['prediction'], f"{p['confidence']:.3f}",
f"{p['price']:.4f}" if not pd.isna(p['price']) else "N/A",
f"{p['volume']:,.0f}" if not pd.isna(p['volume']) else "N/A",
f"{p['entry']:.4f}" if not pd.isna(p['entry']) else "N/A",
format_datetime(p['entry_time'], "N/A"), format_datetime(p['exit_time'], "N/A"),
f"{p['tp1']:.4f}" if not pd.isna(p['tp1']) else "N/A",
f"{p['tp2']:.4f}" if not pd.isna(p['tp2']) else "N/A",
f"{p['sl']:.4f}" if not pd.isna(p['sl']) else "N/A",
f"{p['rsi']:.2f}" if not pd.isna(p['rsi']) else "N/A",
f"{p['macd_hist']:.4f}" if not pd.isna(p['macd_hist']) else "N/A",
f"{p['vwap']:.4f}" if not pd.isna(p['vwap']) else "N/A",
f"{p['atr']:.4f}" if not pd.isna(p['atr']) else "N/A",
])
output_columns = [
'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
'RSI', 'MACD Hist', 'VWAP', 'ATR'
]
output_df = pd.DataFrame(formatted_output, columns=output_columns)
end_time = time.time()
duration = end_time - start_time
last_update_time = processing_time
status_message = f"Predictions updated ({len(final_results)} symbols processed) in {duration:.2f}s. Last update: {format_datetime(last_update_time)}"
logging.info(status_message)
return output_df, status_message
# --- Gradio UI Handlers (Unchanged) ---
def handle_train_click(coin_input, num_workers):
# (Keep this function as is)
logging.info(f"Train button clicked. Workers: {num_workers}")
coins = None
num_workers = int(num_workers)
if coin_input and coin_input.strip():
raw_coins = coin_input.replace(',', ' ').split()
coins = []
valid = True
for c in raw_coins:
coin_upper = c.strip().upper()
if '/' not in coin_upper: coin_upper += '/USDT'
if coin_upper.endswith('/USDT'): coins.append(coin_upper)
else:
valid = False
logging.error(f"Invalid coin format: {c}. Must be SYMBOL or SYMBOL/USDT.")
break
if not valid: return "Error: Custom coins must be valid SYMBOL or SYMBOL/USDT pairs."
logging.info(f"Training requested for custom coin list: {coins}")
else:
logging.info("Training requested for top coins by volume.")
train_status = train_all_models(coin_list=coins, num_workers=num_workers)
return f"--- Training Run ---:\n{train_status}\n\n---> Press 'Refresh Predictions' <---"
def handle_refresh_click():
# (Keep this function as is)
logging.info("Refresh button clicked.")
try:
df, status = update_predictions_table()
return df, status
except Exception as e:
logging.exception("Error during handle_refresh_click")
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR']
return pd.DataFrame([], columns=cols), f"Error updating predictions: {e}"
# --- Gradio Interface Definition (Unchanged) ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Cryptocurrency Prediction & TA Signal Explorer (LLT-KNN + TA-Lib)") # Updated title slightly
gr.Markdown(f"""
Predicts **{PREDICTION_WINDOW_HOURS}-hour** price direction (Rise/Fall) using LLT-KNN.
Displays current price, volume, TA indicators (RSI, MACD, VWAP, ATR calculated using **TA-Lib**), and potential trade levels for **Rise** signals meeting confidence >= **{CONFIDENCE_THRESHOLD}**.
TP/SL levels based on **{TP1_ATR_MULTIPLIER}x / {TP2_ATR_MULTIPLIER}x / {SL_ATR_MULTIPLIER}x ATR({ATR_PERIOD})**.
**Warning:** Educational. High risk. Not financial advice. Ensure TA-Lib is correctly installed.
""")
with gr.Row():
with gr.Column(scale=4):
prediction_df = gr.Dataframe(
headers=[
'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)',
'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL',
'RSI', 'MACD Hist', 'VWAP', 'ATR'
],
datatype=[
'number', 'str', 'str', 'str', 'str', 'str',
'str', 'str', 'str', 'str', 'str', 'str',
'str', 'str', 'str', 'str'
],
row_count=15, col_count=(16, "fixed"), label="Predictions & TA Signals", wrap=True,
)
with gr.Column(scale=1):
with gr.Accordion("Train Models", open=True):
coin_input = gr.Textbox(label="Train Specific Coins (e.g., BTC, ETH/USDT)", placeholder="Leave empty for top coins by volume")
max_workers_slider = gr.Slider(minimum=1, maximum=10, value=NUM_WORKERS_TRAINING, step=1, label="Parallel Training Workers")
train_button = gr.Button("Start Training", variant="primary")
refresh_button = gr.Button("Refresh Predictions", variant="secondary")
status_text = gr.Textbox(label="Status Log", lines=15, interactive=False, max_lines=30)
gr.Markdown(
"""
## Notes
- **TA-Lib**: This version uses the TA-Lib library for indicators. Ensure it's installed correctly (can be tricky).
- **Data**: Fetches OHLCV data (Bitget, 1-min). Uses cache. Handles rate limits.
- **Training**: Uses past ~14h data (12h train, 2h predict). Normalizes, balances classes, applies LLT, trains KNN.
- **Prediction**: Uses latest 12h data for KNN input.
- **Trade Levels**: Only shown for 'Rise' predictions above confidence threshold. Based on current price and ATR volatility. **Highly speculative.**
- **Sorting**: Table sorted by (Potential Rise Signals > Other Rise > Fall > Errors), then by confidence descending.
- **Refresh**: Fetches latest prices/TA and re-evaluates signals.
"""
)
train_button.click(fn=handle_train_click, inputs=[coin_input, max_workers_slider], outputs=status_text)
refresh_button.click(fn=handle_refresh_click, inputs=None, outputs=[prediction_df, status_text])
# --- Startup Initialization (Unchanged) ---
def initialize_models_on_startup():
# (Keep this function as is)
logging.info("----- Initializing Models (Startup Thread) -----")
default_coins = ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'XRP/USDT', 'DOGE/USDT']
try:
initial_status = train_all_models(default_coins, num_workers=2)
logging.info("----- Initial Model Training Complete -----")
logging.info(initial_status)
except Exception as e:
logging.exception("Error during startup initialization.")
# --- Main Execution (Unchanged) ---
if __name__ == "__main__":
logging.info("Starting application...")
    # Check that TA-Lib imported correctly and is callable
    if talib is None:
        logging.error("FATAL: TA-Lib library not found or import failed. Please install it correctly.")
        sys.exit(1)
    try:
        # Smoke-test a TA-Lib call on a short dummy series
        _ = talib.RSI(np.arange(1.0, 20.0), timeperiod=RSI_PERIOD)
        logging.info("TA-Lib library seems accessible.")
    except Exception as ta_init_e:
        logging.error(f"FATAL: Error testing TA-Lib library: {ta_init_e}. Please check installation.")
        sys.exit(1)
init_thread = threading.Thread(target=initialize_models_on_startup, name="StartupTrainThread", daemon=True)
init_thread.start()
logging.info("Launching Gradio Interface...")
try:
demo.launch(server_name="0.0.0.0")
except Exception as e:
logging.exception("Failed to launch Gradio interface.")
finally:
logging.info("Gradio Interface stopped.")