# Install TA-Lib (see instructions above), then: pip install TA-Lib
import ccxt | |
import numpy as np | |
import pandas as pd | |
import time | |
from sklearn.neighbors import KNeighborsClassifier | |
from scipy.linalg import svd | |
import gradio as gr | |
import concurrent.futures | |
import traceback | |
from datetime import datetime, timezone, timedelta | |
import logging | |
import sys | |
import talib # Import TA-Lib | |
import threading | |
# --- Setup Logging --- | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - [%(threadName)s:%(funcName)s] - %(message)s', | |
stream=sys.stdout | |
) | |
logging.getLogger().handlers[0].flush = sys.stdout.flush  # route the handler's flush() to stdout so explicit flushes (see train_all_models) surface logs promptly
# --- Parameters --- | |
L = 10 | |
LAG = 11 | |
MINUTES_PER_HOUR = 60 | |
PREDICTION_WINDOW_HOURS = 2 | |
TRAINING_WINDOW_HOURS = 12 | |
TOTAL_WINDOW_HOURS = TRAINING_WINDOW_HOURS + PREDICTION_WINDOW_HOURS | |
K = TRAINING_WINDOW_HOURS * MINUTES_PER_HOUR # 720 | |
WINDOW = TOTAL_WINDOW_HOURS * MINUTES_PER_HOUR # 840 | |
FEATURES = ['open', 'high', 'low', 'close', 'volume'] | |
D = 5 | |
OVERLAP_STEP = 60 | |
MIN_TRAINING_EXAMPLES = 20 | |
MAX_COINS_TO_DISPLAY = 10 | |
USE_SYNTHETIC_DATA_FOR_LOW_VOLUME = False | |
NUM_WORKERS_TRAINING = 4 | |
NUM_WORKERS_PREDICTION = 10 | |
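# Derived window sizes (from the constants above): each training example spans
# TOTAL_WINDOW_HOURS = 14h of 1-minute candles (WINDOW = 840 rows). The first K = 720
# rows (12h) form the model input; the remaining 120 rows (2h) only determine the
# Rise/Fall label. After the LLT transform each example becomes a flat vector of
# len(FEATURES) * 2 * D = 5 * 2 * 5 = 50 variance features.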
# --- TA & Risk Parameters --- | |
TA_DATA_POINTS = 200 # Candles needed for TA calculation | |
RSI_PERIOD = 14 | |
MACD_FAST = 12 | |
MACD_SLOW = 26 | |
MACD_SIGNAL = 9 | |
ATR_PERIOD = 14 | |
CONFIDENCE_THRESHOLD = 0.65 # Min confidence for Rise signal | |
TP1_ATR_MULTIPLIER = 1.5 | |
TP2_ATR_MULTIPLIER = 3.0 | |
SL_ATR_MULTIPLIER = 1.0 | |
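# Illustrative trade-level arithmetic (see calculate_trade_levels below) for a
# hypothetical "Rise" signal at entry 100.0 with ATR(14) = 2.0:
#   TP1 = 100.0 + 1.5 * 2.0 = 103.0
#   TP2 = 100.0 + 3.0 * 2.0 = 106.0
#   SL  = 100.0 - 1.0 * 2.0 =  98.0  (floored at 0.01)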
# --- CCXT Initialization --- | |
try: | |
exchange = ccxt.bitget({ | |
'enableRateLimit': True, | |
'rateLimit': 1100, | |
'timeout': 45000, | |
'options': {'adjustForTimeDifference': True} | |
}) | |
logging.info(f"Initialized {exchange.id} exchange.") | |
except Exception as e: | |
logging.exception("FATAL: Could not initialize CCXT exchange.") | |
    sys.exit(1)
# --- Global Caches and Variables --- | |
markets_cache = None | |
last_markets_update = None | |
data_cache = {} | |
trained_models = {} | |
last_update_time = datetime.now(timezone.utc) | |
# --- Functions --- | |
def format_datetime(dt, default="N/A"): | |
# (Keep this function as is) | |
if pd.isna(dt) or dt is None: | |
return default | |
try: | |
if isinstance(dt, (datetime, pd.Timestamp)): | |
if dt.tzinfo is None: | |
dt = dt.replace(tzinfo=timezone.utc) | |
return dt.strftime('%Y-%m-%d %H:%M:%S %Z') | |
else: | |
return str(dt) | |
except Exception: | |
return default | |
def get_all_usdt_pairs(): | |
# (Keep this function as is - no changes needed) | |
global markets_cache, last_markets_update | |
current_time = time.time() | |
cache_duration = 3600 # 1 hour | |
if markets_cache is not None and last_markets_update is not None: | |
if current_time - last_markets_update < cache_duration: | |
logging.info("Using cached markets list.") | |
if isinstance(markets_cache, list) and markets_cache: | |
return markets_cache | |
else: | |
logging.warning("Cached market list was invalid, fetching fresh.") | |
logging.info("Fetching markets from Bitget...") | |
try: | |
exchange.load_markets(reload=True) | |
all_symbols = list(exchange.markets.keys()) | |
usdt_pairs = [ | |
symbol for symbol in all_symbols | |
if isinstance(symbol, str) | |
and symbol.endswith('/USDT') | |
and exchange.markets.get(symbol, {}).get('active', False) | |
and exchange.markets.get(symbol, {}).get('spot', False) | |
and 'LEVERAGED' not in exchange.markets.get(symbol, {}).get('type', 'spot').upper() | |
and not exchange.markets.get(symbol, {}).get('inverse', False) | |
] | |
logging.info(f"Found {len(usdt_pairs)} active USDT spot pairs initially.") | |
if not usdt_pairs: | |
logging.warning("No active USDT spot pairs found.") | |
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT'] | |
logging.info(f"Fetching tickers for {len(usdt_pairs)} pairs for volume sorting...") | |
volumes = {} | |
symbols_to_fetch = usdt_pairs[:] | |
fetched_tickers = {} | |
try: | |
if exchange.has['fetchTickers']: | |
batch_size_tickers = 100 | |
for i in range(0, len(symbols_to_fetch), batch_size_tickers): | |
batch_symbols = symbols_to_fetch[i:i+batch_size_tickers] | |
logging.info(f"Fetching ticker batch {i//batch_size_tickers + 1}/{ (len(symbols_to_fetch) + batch_size_tickers -1)//batch_size_tickers }...") | |
retries = 2 | |
for attempt in range(retries): | |
try: | |
batch_tickers = exchange.fetch_tickers(symbols=batch_symbols) | |
fetched_tickers.update(batch_tickers) | |
time.sleep(exchange.rateLimit / 1000 * 1.5) # Add delay | |
break | |
except (ccxt.RequestTimeout, ccxt.NetworkError) as e_timeout: | |
logging.warning(f"Ticker fetch timeout/network error on attempt {attempt+1}/{retries}: {e_timeout}, retrying after delay...") | |
time.sleep(3 * (attempt + 1)) | |
except ccxt.RateLimitExceeded: | |
logging.warning(f"Rate limit exceeded fetching tickers, sleeping...") | |
time.sleep(10 * (attempt+1)) # Longer sleep for rate limit | |
except Exception as e_ticker: | |
logging.error(f"Error fetching ticker batch (attempt {attempt+1}): {e_ticker}") | |
if attempt == retries - 1: raise # Rethrow last error | |
time.sleep(2 * (attempt + 1)) | |
logging.info(f"Fetched {len(fetched_tickers)} tickers using fetchTickers.") | |
else: | |
raise ccxt.NotSupported("fetchTickers not supported/enabled. Volume sorting requires it.") | |
except Exception as e: | |
logging.exception(f"Could not fetch tickers for volume sorting: {e}. Volume sorting unavailable.") | |
markets_cache = usdt_pairs[:MAX_COINS_TO_DISPLAY] | |
last_markets_update = current_time | |
logging.warning(f"Returning top {len(markets_cache)} unsorted pairs due to ticker error.") | |
return markets_cache | |
for symbol, ticker in fetched_tickers.items(): | |
try: | |
                quote_volume = ticker.get('quoteVolume')  # prefer the unified ccxt field
                if quote_volume is None:
                    quote_volume = ticker.get('info', {}).get('quoteVolume')  # fall back to the raw exchange payload
last_price = ticker.get('last') | |
base_volume = ticker.get('baseVolume') | |
# Ensure values are convertible to float before calculation | |
valid_last = last_price is not None | |
valid_base = base_volume is not None | |
valid_quote = quote_volume is not None | |
if valid_quote: | |
volumes[symbol] = float(quote_volume) | |
elif valid_base and valid_last: | |
volumes[symbol] = float(base_volume) * float(last_price) | |
else: | |
volumes[symbol] = 0 | |
except (TypeError, ValueError, KeyError, AttributeError) as e: | |
logging.warning(f"Could not parse volume/price for {symbol} from ticker: {ticker}. Error: {e}") | |
volumes[symbol] = 0 | |
valid_volume_pairs = {k: v for k, v in volumes.items() if v > 0} | |
logging.info(f"Found {len(valid_volume_pairs)} pairs with non-zero volume.") | |
if not valid_volume_pairs: | |
logging.warning("No pairs with valid volume found. Returning default list.") | |
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT'] | |
sorted_pairs = sorted(valid_volume_pairs.items(), key=lambda item: item[1], reverse=True) | |
num_pairs_to_take = min(MAX_COINS_TO_DISPLAY, len(sorted_pairs)) | |
top_pairs = [pair[0] for pair in sorted_pairs[:num_pairs_to_take]] | |
logging.info(f"Selected Top {len(top_pairs)} pairs by volume. Top 5: {[p[0] for p in sorted_pairs[:5]]}") | |
markets_cache = top_pairs | |
last_markets_update = current_time | |
return top_pairs | |
except ccxt.NetworkError as e: | |
logging.error(f"Network error getting USDT pairs: {e}") | |
except ccxt.ExchangeError as e: | |
logging.error(f"Exchange error getting USDT pairs: {e}") | |
except Exception as e: | |
logging.exception("General error getting USDT pairs.") | |
logging.warning("Error fetching markets, returning default fallback list.") | |
return ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'BNB/USDT', 'XRP/USDT'] | |
def clean_and_process_ohlcv(ohlcv_list, symbol, expected_candles): | |
# (Keep this function as is - no changes needed) | |
if not ohlcv_list: | |
return None | |
try: | |
df = pd.DataFrame(ohlcv_list, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) | |
initial_len = len(df) | |
if initial_len == 0: return None | |
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) | |
df = df.drop_duplicates(subset=['timestamp']) | |
df = df.sort_values('timestamp') | |
len_after_dupes = len(df) | |
numeric_cols = ['open', 'high', 'low', 'close', 'volume'] | |
for col in numeric_cols: | |
df[col] = pd.to_numeric(df[col], errors='coerce') | |
# Drop rows with NaN in essential price/volume features needed for TA-Lib | |
df = df.dropna(subset=numeric_cols) | |
len_after_na = len(df) | |
df.reset_index(drop=True, inplace=True) | |
logging.debug(f"Data cleaning for {symbol}: Initial Fetched={initial_len}, AfterDupes={len_after_dupes}, AfterNA={len_after_na}") | |
if len(df) >= expected_candles: | |
final_df = df.iloc[-expected_candles:].copy() # Take the most recent ones | |
return final_df | |
else: | |
return None | |
except Exception as e: | |
logging.exception(f"Error processing DataFrame for {symbol}") | |
return None | |
def fetch_historical_data(symbol, timeframe='1m', total_candles=WINDOW): | |
# (Keep this function as is - no changes needed) | |
cache_key = f"{symbol}_{timeframe}_{total_candles}" | |
current_time = time.time() | |
cache_validity_seconds = 300 # 5 minutes | |
if cache_key in data_cache: | |
cache_time, cached_data = data_cache[cache_key] | |
if current_time - cache_time < cache_validity_seconds: | |
if isinstance(cached_data, pd.DataFrame) and len(cached_data) == total_candles: | |
logging.debug(f"Using valid cached data for {symbol} ({len(cached_data)} candles)") | |
return cached_data.copy() | |
else: | |
logging.warning(f"Cache for {symbol} invalid or wrong size ({len(cached_data) if isinstance(cached_data, pd.DataFrame) else 'N/A'} vs {total_candles}), fetching fresh.") | |
if cache_key in data_cache: del data_cache[cache_key] | |
if not exchange.has['fetchOHLCV']: | |
logging.error(f"Exchange {exchange.id} does not support fetchOHLCV.") | |
return None | |
logging.debug(f"Fetching {total_candles} candles for {symbol} (timeframe: {timeframe})") | |
final_df = None | |
fetch_start_time = time.time() | |
duration_ms = exchange.parse_timeframe(timeframe) * 1000 | |
now_ms = exchange.milliseconds() | |
# --- Strategy 1: Try Single Large Fetch --- | |
single_fetch_limit = total_candles + 200 # Buffer | |
single_fetch_since = now_ms - single_fetch_limit * duration_ms | |
try: | |
ohlcv_list = exchange.fetch_ohlcv(symbol, timeframe, limit=single_fetch_limit, since=single_fetch_since) | |
if ohlcv_list: | |
processed_df = clean_and_process_ohlcv(ohlcv_list, symbol, total_candles) | |
if processed_df is not None and len(processed_df) == total_candles: | |
final_df = processed_df | |
except ccxt.RateLimitExceeded as e: | |
logging.warning(f"Rate limit hit during single fetch for {symbol}, falling back: {e}") | |
time.sleep(5) | |
except (ccxt.RequestTimeout, ccxt.NetworkError) as e: | |
logging.warning(f"Timeout/Network error during single fetch for {symbol}, falling back: {e}") | |
time.sleep(2) | |
except ccxt.ExchangeNotAvailable as e: | |
logging.error(f"Exchange not available during fetch for {symbol}: {e}") | |
return None | |
except ccxt.AuthenticationError as e: | |
logging.error(f"Authentication error fetching {symbol}: {e}") | |
return None | |
except ccxt.ExchangeError as e: | |
logging.warning(f"Exchange error during single fetch for {symbol}, falling back: {e}") | |
except Exception as e: | |
logging.exception(f"Unexpected error during single fetch for {symbol}, falling back.") | |
# --- Strategy 2: Fallback to Iterative Chunking --- | |
if final_df is None: | |
logging.debug(f"Falling back to iterative chunk fetching for {symbol}.") | |
limit_per_call = exchange.safe_integer(exchange.limits.get('fetchOHLCV', {}), 'max', 1000) | |
limit_per_call = min(limit_per_call, 1000) | |
all_ohlcv_chunks = [] | |
required_start_time_ms = now_ms - (total_candles + 5) * duration_ms | |
current_chunk_end_time_ms = now_ms | |
max_chunk_attempts = 15 | |
attempts = 0 | |
while attempts < max_chunk_attempts: | |
attempts += 1 | |
oldest_ts_in_hand = all_ohlcv_chunks[0][0] if all_ohlcv_chunks else current_chunk_end_time_ms | |
if oldest_ts_in_hand <= required_start_time_ms: | |
logging.debug(f"Chunking: Collected enough historical range for {symbol}.") | |
break | |
fetch_limit = limit_per_call | |
chunk_fetch_since = oldest_ts_in_hand - fetch_limit * duration_ms | |
params = {} | |
try: | |
ohlcv_chunk = exchange.fetch_ohlcv(symbol, timeframe, since=chunk_fetch_since, limit=fetch_limit, params=params) | |
if not ohlcv_chunk: | |
logging.debug(f"Chunking: No more data received for {symbol} from API.") | |
break | |
new_chunk = [c for c in ohlcv_chunk if c[0] < oldest_ts_in_hand] | |
if not new_chunk: | |
break | |
new_chunk.sort(key=lambda x: x[0]) | |
all_ohlcv_chunks = new_chunk + all_ohlcv_chunks | |
if len(new_chunk) < limit_per_call // 20 and attempts > 5: | |
logging.warning(f"Chunking: Received very few new candles ({len(new_chunk)}) repeatedly for {symbol}.") | |
break | |
time.sleep(exchange.rateLimit / 1000 * 1.1) | |
except ccxt.RateLimitExceeded as e: | |
logging.warning(f"Rate limit hit during chunking for {symbol}, sleeping 10s: {e}") | |
time.sleep(10 * (attempts/3 + 1)) | |
except (ccxt.NetworkError, ccxt.RequestTimeout) as e: | |
logging.error(f"Network/Timeout error during chunking for {symbol}: {e}. Stopping.") | |
break | |
except ccxt.ExchangeError as e: | |
logging.error(f"Exchange error during chunking for {symbol}: {e}. Stopping.") | |
break | |
except Exception as e: | |
logging.exception(f"Generic error during chunking fetch for {symbol}") | |
break | |
if attempts >= max_chunk_attempts: | |
logging.warning(f"Max chunk fetch attempts reached for {symbol}.") | |
if all_ohlcv_chunks: | |
processed_df = clean_and_process_ohlcv(all_ohlcv_chunks, symbol, total_candles) | |
if processed_df is not None and len(processed_df) == total_candles: | |
final_df = processed_df | |
else: | |
logging.error(f"No data obtained from chunk fetching for {symbol}.") | |
# --- Final Check and Caching --- | |
if final_df is not None and len(final_df) == total_candles: | |
expected_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] | |
if all(col in final_df.columns for col in expected_cols): | |
data_cache[cache_key] = (current_time, final_df.copy()) | |
return final_df | |
else: | |
logging.error(f"Final DataFrame for {symbol} missing expected columns. Won't cache.") | |
return None | |
else: | |
logging.error(f"Failed to fetch exactly {total_candles} candles for {symbol}. Found: {len(final_df) if final_df is not None else 0}") | |
return None | |
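# Fetch strategy summary: fetch_historical_data first tries a single fetch_ohlcv call
# (requested candles plus a 200-candle buffer) and, if that does not yield exactly
# total_candles rows, pages backwards in chunks until the required time range is covered.
# Results are cached per (symbol, timeframe, count) for 5 minutes, and only DataFrames of
# exactly the requested length are cached or returned.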
# --- Embedding, LLT, Normalize, Training Prep (Largely unchanged) --- | |
# Keep create_embedding, llt_transform, normalize_data, prepare_training_data, train_model | |
# as they don't depend on the TA library choice. | |
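# create_embedding builds a time-delay embedding: for a 1-D series x of length n it
# returns a matrix A whose rows are A[t] = [x[t], x[t+lag], ..., x[t+(l-1)*lag]], i.e.
# n - (l-1)*lag rows of l columns. With the defaults (L=10, LAG=11) a K=720-minute input
# series yields a 621 x 10 matrix; series too short to fill one row return an empty array.
# Illustrative example:
#   >>> create_embedding(np.arange(25), l=3, lag=5)[:2]
#   array([[ 0.,  5., 10.],
#          [ 1.,  6., 11.]])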
def create_embedding(data, l=L, lag=LAG): | |
# (Keep this function as is) | |
n = len(data) | |
rows = n - (l - 1) * lag | |
if rows <= 0: | |
logging.debug(f"Cannot create embedding: data length {n} too short for L={l}, Lag={lag}") | |
return np.array([]) | |
A = np.zeros((rows, l)) | |
try: | |
for t in range(rows): | |
indices = t + np.arange(l) * lag | |
A[t] = data[indices] | |
return A | |
except IndexError as e: | |
logging.error(f"IndexError during embedding: n={n}, l={l}, lag={lag}. Error: {e}") | |
return np.array([]) | |
except Exception as e: | |
logging.exception("Error in create_embedding") | |
return np.array([]) | |
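# llt_transform (the "LLT" step of LLT-KNN): for every training instance, feature, and
# class label it embeds the series (create_embedding), forms the L x L Gram matrix
# S = A.T @ A, and stores the right-singular vector of S with the smallest singular value
# as that class's "law". At transform time each instance's own S is projected onto the
# stored class laws (S @ V) and the D smallest column variances per class become the
# features, giving an output dimension of len(FEATURES) * 2 * D.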
def llt_transform(X_train, y_train, X_test): | |
# (Keep this function as is) | |
if not isinstance(X_train, np.ndarray) or X_train.ndim != 3 or \ | |
not isinstance(y_train, np.ndarray) or y_train.ndim != 1 or \ | |
not isinstance(X_test, np.ndarray) or (X_test.size > 0 and X_test.ndim != 3): | |
logging.error(f"LLT input type/shape error.") | |
return np.array([]), np.array([]) | |
if X_train.shape[0] != y_train.shape[0]: | |
logging.error(f"LLT input mismatch: len(X_train) != len(y_train)") | |
return np.array([]), np.array([]) | |
if X_train.size == 0 or y_train.size == 0: | |
logging.error("LLT requires non-empty training data.") | |
return np.array([]), np.array([]) | |
if X_test.size > 0 and X_test.shape[1:] != X_train.shape[1:]: | |
logging.error(f"LLT train/test shape mismatch") | |
return np.array([]), np.array([]) | |
try: | |
num_features = X_train.shape[2] | |
if num_features != len(FEATURES): | |
logging.error(f"LLT: Feature count mismatch.") | |
return np.array([]), np.array([]) | |
V = {j: {'0': [], '1': []} for j in range(num_features)} | |
laws_computed_count = {j: {'0': 0, '1': 0} for j in range(num_features)} | |
for i in range(len(X_train)): | |
label = str(int(y_train[i])) | |
if label not in ['0', '1']: continue | |
for j in range(num_features): | |
feature_data = X_train[i, :, j] | |
A = create_embedding(feature_data, l=L, lag=LAG) | |
if A.shape[0] < L: continue | |
if np.isnan(A).any() or np.isinf(A).any(): continue | |
try: | |
S = A.T @ A | |
if np.isnan(S).any() or np.isinf(S).any(): continue | |
U, s, Vt = svd(S, full_matrices=False) | |
if Vt.shape[0] < L or Vt.shape[1] != L: continue | |
if s[-1] < 1e-9: continue | |
v = Vt[-1] | |
norm = np.linalg.norm(v) | |
if norm < 1e-9: continue | |
V[j][label].append(v / norm) | |
laws_computed_count[j][label] += 1 | |
except np.linalg.LinAlgError: pass | |
except Exception: pass | |
valid_laws_exist = False | |
for j in V: | |
for c in ['0', '1']: | |
if laws_computed_count[j][c] > 0: | |
valid_vecs = [vec for vec in V[j][c] if isinstance(vec, np.ndarray) and vec.shape == (L,)] | |
if not valid_vecs: | |
V[j][c] = np.zeros((L, 0)) | |
continue | |
try: | |
V[j][c] = np.array(valid_vecs).T | |
if V[j][c].shape[0] != L: | |
V[j][c] = np.zeros((L, 0)) | |
else: | |
valid_laws_exist = True | |
except Exception: V[j][c] = np.zeros((L, 0)) | |
else: V[j][c] = np.zeros((L, 0)) | |
if not valid_laws_exist: | |
logging.error("LLT ERROR: No valid laws computed.") | |
return np.array([]), np.array([]) | |
def transform_instance(X_instance): | |
transformed_features = [] | |
if X_instance.ndim != 2 or X_instance.shape[0] != K or X_instance.shape[1] != num_features: | |
return np.zeros(num_features * 2 * D) | |
for j in range(num_features): | |
feature_data = X_instance[:, j] | |
A = create_embedding(feature_data, l=L, lag=LAG) | |
if A.shape[0] < L: | |
transformed_features.extend([0.0] * (2 * D)) | |
continue | |
if np.isnan(A).any() or np.isinf(A).any(): | |
transformed_features.extend([0.0] * (2 * D)) | |
continue | |
try: | |
S = A.T @ A | |
if np.isnan(S).any() or np.isinf(S).any(): | |
transformed_features.extend([0.0] * (2 * D)) | |
continue | |
for c in ['0', '1']: | |
if V[j][c].shape[1] == 0: | |
transformed_features.extend([0.0] * D) | |
continue | |
S_V = S @ V[j][c] | |
if S_V.size == 0 or np.isnan(S_V).any() or np.isinf(S_V).any(): | |
transformed_features.extend([0.0] * D) | |
continue | |
variances = np.var(S_V, axis=0) | |
if variances.size == 0: | |
transformed_features.extend([0.0] * D) | |
continue | |
variances = np.nan_to_num(variances, nan=np.finfo(variances.dtype).max, posinf=np.finfo(variances.dtype).max, neginf=np.finfo(variances.dtype).max) | |
num_vars_available = variances.size | |
num_vars_to_select = min(D, num_vars_available) | |
smallest_indices = np.argpartition(variances, num_vars_to_select -1)[:num_vars_to_select] | |
smallest_vars = np.sort(variances[smallest_indices]) | |
padded_vars = np.pad(smallest_vars, (0, D - num_vars_to_select), 'constant', constant_values=0.0) | |
if np.isnan(padded_vars).any() or np.isinf(padded_vars).any(): | |
padded_vars = np.nan_to_num(padded_vars, nan=0.0, posinf=0.0, neginf=0.0) | |
transformed_features.extend(padded_vars) | |
except Exception: | |
current_len = len(transformed_features) | |
expected_len_after_feature = (j + 1) * 2 * D | |
num_missing = expected_len_after_feature - current_len | |
if num_missing > 0: transformed_features.extend([0.0] * num_missing) | |
transformed_features = transformed_features[:expected_len_after_feature] | |
correct_len = num_features * 2 * D | |
if len(transformed_features) != correct_len: | |
if len(transformed_features) < correct_len: transformed_features.extend([0.0] * (correct_len - len(transformed_features))) | |
else: transformed_features = transformed_features[:correct_len] | |
return np.array(transformed_features) | |
X_train_t = np.array([transform_instance(X) for X in X_train]) | |
X_test_t = np.array([]) | |
if X_test.size > 0: X_test_t = np.array([transform_instance(X) for X in X_test]) | |
expected_dim = num_features * 2 * D | |
if X_train_t.shape[0] != len(X_train) or (X_train_t.size > 0 and X_train_t.shape[1] != expected_dim): | |
logging.error(f"LLT Train transform resulted in unexpected shape.") | |
return np.array([]), np.array([]) | |
if X_test.size > 0 and (X_test_t.shape[0] != len(X_test) or (X_test_t.size > 0 and X_test_t.shape[1] != expected_dim)): | |
logging.error(f"LLT Test transform resulted in unexpected shape.") | |
return X_train_t, np.array([]) | |
return X_train_t, X_test_t | |
except Exception as e: | |
logging.exception("Error in llt_transform function") | |
return np.array([]), np.array([]) | |
def normalize_data(df): | |
# (Keep this function as is) | |
normalized_df = df.copy() | |
if not isinstance(df, pd.DataFrame): | |
logging.error("Normalize_data received non-DataFrame input.") | |
return None | |
for feature in FEATURES: | |
if feature == 'timestamp': continue | |
if feature not in df.columns: | |
normalized_df[feature] = 0.0 | |
continue | |
if pd.api.types.is_numeric_dtype(df[feature]): | |
mean = df[feature].mean() | |
std = df[feature].std() | |
if std is not None and not pd.isna(std) and std > 1e-9: | |
normalized_df[feature] = (df[feature] - mean) / std | |
else: | |
normalized_df[feature] = 0.0 | |
if normalized_df[feature].isnull().any(): | |
normalized_df[feature] = normalized_df[feature].fillna(0.0) | |
else: | |
normalized_df[feature] = 0.0 | |
return normalized_df | |
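# Normalization note: normalize_data z-scores each feature column over the whole DataFrame
# it receives (mean 0, std 1 per column; near-constant columns become 0). In
# prepare_training_data it is applied once to the full fetched history before slicing
# windows, while predict_real_time normalizes only the most recent K rows.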
def generate_synthetic_data(symbol, total_candles=WINDOW): | |
# (Keep this function as is) | |
logging.info(f"Generating synthetic data for {symbol} ({total_candles} candles)") | |
np.random.seed(int(time.time() * 1000) % (2**32 - 1)) | |
end_time = pd.Timestamp.now(tz='UTC') | |
    timestamps = pd.date_range(end=end_time, periods=total_candles, freq='min')  # 'min' instead of the deprecated 'T' alias
volatility = np.random.uniform(0.005, 0.03) | |
base_price = np.random.uniform(1, 5000) | |
prices = [base_price] | |
for _ in range(1, total_candles): | |
change = np.random.normal(0, volatility / np.sqrt(1440)) | |
prices.append(prices[-1] * (1 + change)) | |
prices = np.maximum(0.01, prices) | |
close_prices = np.array(prices) | |
open_prices = close_prices * (1 + np.random.normal(0, volatility / np.sqrt(1440) / 2, total_candles)) | |
high_prices = np.maximum(close_prices, open_prices) * (1 + np.random.uniform(0, volatility / np.sqrt(1440), total_candles)) | |
low_prices = np.minimum(close_prices, open_prices) * (1 - np.random.uniform(0, volatility / np.sqrt(1440), total_candles)) | |
high_prices = np.maximum.reduce([high_prices, open_prices, close_prices]) | |
low_prices = np.minimum.reduce([low_prices, open_prices, close_prices]) | |
volumes = np.random.poisson(base_price * np.random.uniform(1, 10)) * (1 + np.abs(np.diff(close_prices, prepend=close_prices[0])) / close_prices * 5) | |
volumes = np.maximum(1, volumes) | |
df = pd.DataFrame({ | |
'timestamp': timestamps, 'open': open_prices, 'high': high_prices, | |
'low': low_prices, 'close': close_prices, 'volume': volumes | |
}) | |
for col in FEATURES: df[col] = pd.to_numeric(df[col]) | |
df.reset_index(drop=True, inplace=True) | |
return df | |
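# prepare_training_data slices the fetched history into overlapping WINDOW-minute windows,
# stepping back OVERLAP_STEP (60) minutes at a time. For each window the normalized first
# K rows become the model input and the label is 1 ("Rise") if the raw close at the
# window's last row exceeds the close at row K-1, otherwise 0 ("Fall"). Classes are then
# downsampled to the minority-class count so the KNN trains on a balanced set.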
def prepare_training_data(symbol, total_candles_to_fetch=WINDOW + OVERLAP_STEP * 20): | |
# (Keep this function as is) | |
logging.info(f"Preparing training data for {symbol}...") | |
try: | |
required_base_candles = WINDOW | |
estimated_candles_needed = required_base_candles + (MIN_TRAINING_EXAMPLES * 2) * OVERLAP_STEP + 500 | |
fetch_candle_count = max(WINDOW + 500, estimated_candles_needed) | |
logging.info(f"Fetching {fetch_candle_count} candles for {symbol} training prep...") | |
df = fetch_historical_data(symbol, timeframe='1m', total_candles=fetch_candle_count) | |
if df is None or len(df) < WINDOW: | |
logging.error(f"Insufficient data fetched for {symbol} ({len(df) if df is not None else 0} < {WINDOW}).") | |
if USE_SYNTHETIC_DATA_FOR_LOW_VOLUME: | |
logging.warning(f"Attempting synthetic data generation for {symbol}.") | |
df = generate_synthetic_data(symbol, total_candles=WINDOW + OVERLAP_STEP * 10) | |
if df is None or len(df) < WINDOW: | |
logging.error(f"Synthetic data generation failed or insufficient for {symbol}.") | |
return None, None | |
else: logging.info(f"Using synthetic data ({len(df)} points) for {symbol}.") | |
else: return None, None | |
df_normalized = normalize_data(df) | |
if df_normalized is None: | |
logging.error(f"Normalization failed for {symbol}.") | |
return None, None | |
if df_normalized[FEATURES].isnull().any().any(): | |
logging.warning(f"NaN values found after normalization for {symbol}. Filling with 0.") | |
df_normalized = df_normalized.fillna(0.0) | |
X, y = [], [] | |
end_index = len(df) | |
start_index = WINDOW | |
num_windows_created = 0 | |
for i in range(end_index, start_index - 1, -OVERLAP_STEP): | |
window_end_idx = i | |
window_start_idx = i - WINDOW | |
if window_start_idx < 0: continue | |
window_orig = df.iloc[window_start_idx:window_end_idx] | |
window_norm = df_normalized.iloc[window_start_idx:window_end_idx] | |
if len(window_orig) != WINDOW or len(window_norm) != WINDOW: continue | |
input_data_norm = window_norm.iloc[:K][FEATURES].values | |
if input_data_norm.shape[0] != K or input_data_norm.shape[1] != len(FEATURES): continue | |
if np.isnan(input_data_norm).any(): continue | |
start_price_iloc_idx = K - 1 | |
end_price_iloc_idx = WINDOW - 1 | |
start_price = window_orig['close'].iloc[start_price_iloc_idx] | |
end_price = window_orig['close'].iloc[end_price_iloc_idx] | |
if pd.isna(start_price) or pd.isna(end_price) or start_price <= 0: continue | |
X.append(input_data_norm) | |
y.append(1 if end_price > start_price else 0) | |
num_windows_created += 1 | |
if not X: | |
logging.error(f"No valid windows created for {symbol}.") | |
return None, None | |
X = np.array(X) | |
y = np.array(y) | |
unique_classes, class_counts = np.unique(y, return_counts=True) | |
class_dist_str = ", ".join([f"Class {cls}: {count}" for cls, count in zip(unique_classes, class_counts)]) | |
logging.info(f"Class distribution BEFORE balancing for {symbol}: {class_dist_str}") | |
if len(unique_classes) < 2: | |
logging.error(f"ONLY ONE CLASS ({unique_classes[0]}) present for {symbol}.") | |
return None, None | |
min_class_count = min(class_counts) | |
if min_class_count * 2 < MIN_TRAINING_EXAMPLES: | |
logging.error(f"Minority class count ({min_class_count}) too low for {symbol}.") | |
return None, None | |
samples_per_class = min_class_count | |
balanced_indices = [] | |
for class_val in unique_classes: | |
class_indices = np.where(y == class_val)[0] | |
num_to_choose = min(samples_per_class, len(class_indices)) | |
chosen_indices = np.random.choice(class_indices, size=num_to_choose, replace=False) | |
balanced_indices.extend(chosen_indices) | |
np.random.shuffle(balanced_indices) | |
X_balanced = X[balanced_indices] | |
y_balanced = y[balanced_indices] | |
final_unique, final_counts = np.unique(y_balanced, return_counts=True) | |
logging.info(f"Balanced dataset for {symbol}: {len(X_balanced)} instances. Final counts: {dict(zip(final_unique, final_counts))}") | |
if len(X_balanced) < MIN_TRAINING_EXAMPLES: | |
logging.error(f"Insufficient data ({len(X_balanced)}) for {symbol} AFTER balancing.") | |
return None, None | |
if X_balanced.ndim != 3 or X_balanced.shape[0] == 0 or X_balanced.shape[1] != K or X_balanced.shape[2] != len(FEATURES): | |
logging.error(f"Final balanced data has unexpected shape {X_balanced.shape} for {symbol}.") | |
return None, None | |
return X_balanced, y_balanced | |
except Exception as e: | |
logging.exception(f"Error preparing training data for {symbol}") | |
return None, None | |
def train_model(symbol): | |
# (Keep this function as is) | |
logging.info(f"--- Attempting to train model for {symbol} ---") | |
np.random.seed(int(time.time()) % (2**32 - 1)) | |
X, y = prepare_training_data(symbol) | |
if X is None or y is None: | |
logging.error(f"Failed to prepare training data for {symbol}. Training aborted.") | |
return None, None, None | |
try: | |
        accuracy = -1.0
        validation_ok = True  # set to False if the validation set cannot be transformed or scored
if len(X) < MIN_TRAINING_EXAMPLES + 2: | |
logging.warning(f"Dataset for {symbol} too small ({len(X)}). Training on all data.") | |
X_train, y_train = X, y | |
X_val, y_val = np.array([]), np.array([]) | |
else: | |
indices = np.random.permutation(len(X)) | |
val_size = max(1, int(len(X) * 0.2)) | |
split_idx = len(X) - val_size | |
train_indices, val_indices = indices[:split_idx], indices[split_idx:] | |
if len(train_indices) == 0 or len(val_indices) == 0: | |
logging.error(f"Train/Val split resulted in zero samples. Training on all data.") | |
X_train, y_train = X, y | |
X_val, y_val = np.array([]), np.array([]) | |
else: | |
X_train, X_val = X[train_indices], X[val_indices] | |
y_train, y_val = y[train_indices], y[val_indices] | |
if len(np.unique(y_train)) < 2: | |
logging.error(f"Only one class in TRAINING set after split for {symbol}. Aborting.") | |
return None, None, None | |
if len(np.unique(y_val)) < 2: | |
logging.warning(f"Only one class in VALIDATION set after split for {symbol}.") | |
if X_val.size == 0: X_val_shaped = np.empty((0, K, len(FEATURES))) | |
else: X_val_shaped = X_val | |
X_train_t, X_val_t = llt_transform(X_train, y_train, X_val_shaped) | |
if X_train_t.size == 0: | |
logging.error(f"LLT training transformation failed for {symbol}. Training aborted.") | |
return None, None, None | |
        if X_val.size > 0 and X_val_t.size == 0:
            logging.warning(f"LLT validation transformation failed for {symbol}.")
            validation_ok = False
if np.isnan(X_train_t).any() or np.isinf(X_train_t).any(): | |
logging.error(f"NaN/Inf in LLT transformed TRAINING data for {symbol}. Training aborted.") | |
return None, None, None | |
        if X_val_t.size > 0 and (np.isnan(X_val_t).any() or np.isinf(X_val_t).any()):
            logging.warning(f"NaN/Inf in LLT transformed VALIDATION data for {symbol}.")
            validation_ok = False
n_neighbors = min(5, len(y_train) - 1) if len(y_train) > 1 else 1 | |
n_neighbors = max(1, n_neighbors) | |
if n_neighbors > 1 and n_neighbors % 2 == 0: n_neighbors -= 1 | |
model = KNeighborsClassifier(n_neighbors=n_neighbors, weights='distance') | |
model.fit(X_train_t, y_train) | |
        if validation_ok and X_val_t.size > 0:
try: | |
accuracy = model.score(X_val_t, y_val) | |
logging.info(f"Model for {symbol} trained. Validation Accuracy: {accuracy:.3f}") | |
except Exception as eval_e: | |
logging.exception(f"Error during KNN validation scoring for {symbol}: {eval_e}") | |
accuracy = -1.0 | |
        elif X_val.size > 0:
            logging.info(f"Model for {symbol} trained. Validation skipped or failed.")
        else:
            logging.info(f"Model for {symbol} trained. No validation data.")
return model, X_train, y_train | |
except Exception as e: | |
logging.exception(f"Error during model training pipeline for {symbol}") | |
return None, None, None | |
def predict_real_time(symbol, model_data): | |
# (Keep this function as is) | |
if model_data is None: return "Model N/A", 0.0 | |
model, X_train_orig_for_llt, y_train_orig_for_llt = model_data | |
if model is None or X_train_orig_for_llt is None or y_train_orig_for_llt is None: | |
logging.error(f"Invalid model data tuple for prediction on {symbol}") | |
return "Model Error", 0.0 | |
if X_train_orig_for_llt.size == 0 or y_train_orig_for_llt.size == 0: | |
logging.error(f"Training data for LLT laws is empty for {symbol}") | |
return "LLT Data Error", 0.0 | |
try: | |
df = fetch_historical_data(symbol, timeframe='1m', total_candles=K + 60) | |
if df is None or len(df) < K: | |
return "Data Error", 0.0 | |
df_recent = df.iloc[-K:] | |
if len(df_recent) != K: | |
return "Data Error", 0.0 | |
df_recent_normalized = normalize_data(df_recent) | |
if df_recent_normalized is None: return "Norm Error", 0.0 | |
if df_recent_normalized[FEATURES].isnull().any().any(): | |
df_recent_normalized = df_recent_normalized.fillna(0.0) | |
X_predict_input = np.array([df_recent_normalized[FEATURES].values]) | |
_, X_predict_transformed = llt_transform(X_train_orig_for_llt, y_train_orig_for_llt, X_predict_input) | |
if X_predict_transformed.size == 0 or X_predict_transformed.shape[0] != 1: | |
return "Transform Error", 0.0 | |
if np.isnan(X_predict_transformed).any() or np.isinf(X_predict_transformed).any(): | |
X_predict_transformed = np.nan_to_num(X_predict_transformed, nan=0.0, posinf=0.0, neginf=0.0) | |
try: | |
probabilities = model.predict_proba(X_predict_transformed) | |
if probabilities.shape[0] != 1 or probabilities.shape[1] != 2: | |
return "Predict Error", 0.0 | |
prob_class_1 = probabilities[0, 1] | |
prediction_label = "Rise" if prob_class_1 >= 0.5 else "Fall" | |
confidence = prob_class_1 if prediction_label == "Rise" else probabilities[0, 0] | |
return prediction_label, confidence | |
except Exception as knn_e: | |
logging.exception(f"Error during KNN prediction probability for {symbol}") | |
return "Predict Error", 0.0 | |
except Exception as e: | |
logging.exception(f"Error in predict_real_time for {symbol}") | |
return "Error", 0.0 | |
# --- TA Calculation Function (Using TA-Lib) --- | |
def calculate_ta_indicators(df_ta): | |
""" | |
Calculates TA indicators (RSI, MACD, VWAP, ATR) using TA-Lib. | |
Requires df_ta to have 'open', 'high', 'low', 'close', 'volume' columns. | |
""" | |
indicators = {'RSI': np.nan, 'MACD': np.nan, 'MACD_Signal': np.nan, 'MACD_Hist': np.nan, 'VWAP': np.nan, 'ATR': np.nan} | |
required_cols = ['open', 'high', 'low', 'close', 'volume'] | |
    # MACD's signal line needs roughly (slow + signal) candles before TA-Lib returns non-NaN values
    min_len_needed = max(RSI_PERIOD, MACD_SLOW + MACD_SIGNAL, ATR_PERIOD) + 1
if df_ta is None or len(df_ta) < min_len_needed: | |
logging.warning(f"Insufficient data ({len(df_ta) if df_ta is not None else 0} < {min_len_needed}) for TA-Lib calculations.") | |
return indicators | |
# Ensure columns exist | |
if not all(col in df_ta.columns for col in required_cols): | |
logging.error(f"Missing required columns for TA-Lib: Have {df_ta.columns}, Need {required_cols}") | |
return indicators | |
# --- Prepare data for TA-Lib (NumPy arrays, handle NaNs) --- | |
df_ta = df_ta[required_cols].copy() # Work on a copy with only needed columns | |
# Check for NaNs BEFORE converting to numpy, TA-Lib generally dislikes them | |
if df_ta.isnull().values.any(): | |
nan_count = df_ta.isnull().sum().sum() | |
logging.warning(f"Found {nan_count} NaN(s) in TA input data. Applying ffill()...") | |
df_ta.ffill(inplace=True) # Forward fill NaNs | |
# Check again after ffill - if NaNs remain (e.g., at the start), need more robust handling | |
if df_ta.isnull().values.any(): | |
logging.error(f"NaNs still present after ffill. Cannot proceed with TA-Lib.") | |
return indicators # Return NaNs | |
try: | |
# Convert to NumPy arrays of type float | |
open_p = df_ta['open'].values.astype(float) | |
high_p = df_ta['high'].values.astype(float) | |
low_p = df_ta['low'].values.astype(float) | |
close_p = df_ta['close'].values.astype(float) | |
volume_p = df_ta['volume'].values.astype(float) | |
# --- Calculate Indicators using TA-Lib --- | |
# RSI | |
rsi_values = talib.RSI(close_p, timeperiod=RSI_PERIOD) | |
indicators['RSI'] = rsi_values[-1] if len(rsi_values) > 0 else np.nan | |
# MACD | |
macd_line, signal_line, hist = talib.MACD(close_p, fastperiod=MACD_FAST, slowperiod=MACD_SLOW, signalperiod=MACD_SIGNAL) | |
indicators['MACD'] = macd_line[-1] if len(macd_line) > 0 else np.nan | |
indicators['MACD_Signal'] = signal_line[-1] if len(signal_line) > 0 else np.nan | |
indicators['MACD_Hist'] = hist[-1] if len(hist) > 0 else np.nan | |
# ATR | |
atr_values = talib.ATR(high_p, low_p, close_p, timeperiod=ATR_PERIOD) | |
indicators['ATR'] = atr_values[-1] if len(atr_values) > 0 else np.nan | |
# VWAP (Manual Calculation - TA-Lib doesn't have it built-in) | |
typical_price = (high_p + low_p + close_p) / 3.0 | |
tp_vol = typical_price * volume_p | |
cumulative_volume = np.cumsum(volume_p) | |
# Avoid division by zero if volume is zero for initial periods | |
if cumulative_volume[-1] > 1e-12: # Check if there's significant volume | |
vwap_values = np.cumsum(tp_vol) / np.maximum(cumulative_volume, 1e-12) # Avoid div by zero strictly | |
indicators['VWAP'] = vwap_values[-1] | |
else: | |
indicators['VWAP'] = np.nan # VWAP undefined if no volume | |
        # Cast results to plain floats; TA-Lib leaves NaN for each indicator's warm-up period
        indicators = {key: (float(value) if not pd.isna(value) else np.nan) for key, value in indicators.items()}
# logging.debug(f"TA-Lib Indicators calculated: {indicators}") | |
return indicators | |
except Exception as ta_e: | |
logging.exception(f"Error calculating TA indicators using TA-Lib: {ta_e}") | |
return {k: np.nan for k in indicators} # Return NaNs on error | |
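# Note on VWAP above: TA-Lib has no built-in VWAP, so it is computed manually as
# cumsum(typical_price * volume) / cumsum(volume) over the window passed in
# (TA_DATA_POINTS candles), i.e. a window-anchored rather than session-anchored VWAP.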
# --- Trade Level Calculation (Unchanged) --- | |
def calculate_trade_levels(prediction, confidence, current_price, atr): | |
# (Keep this function as is - no changes needed) | |
levels = {'Entry': np.nan, 'TP1': np.nan, 'TP2': np.nan, 'SL': np.nan} | |
if pd.isna(current_price) or current_price <= 0 or pd.isna(atr) or atr <= 0: | |
return levels | |
if prediction == "Rise" and confidence >= CONFIDENCE_THRESHOLD: | |
entry_price = current_price | |
levels['Entry'] = entry_price | |
levels['TP1'] = entry_price + TP1_ATR_MULTIPLIER * atr | |
levels['TP2'] = entry_price + TP2_ATR_MULTIPLIER * atr | |
levels['SL'] = entry_price - SL_ATR_MULTIPLIER * atr | |
levels['SL'] = max(0.01, levels['SL']) | |
# Add Fall logic here if needed | |
return levels | |
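# With the default multipliers the stop sits 1 ATR below entry, so TP1 and TP2 correspond
# to reward-to-risk ratios of 1.5:1 and 3:1. Short ("Fall") levels are not generated (see
# the placeholder comment above); only "Rise" signals at or above CONFIDENCE_THRESHOLD
# produce entries.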
# --- Concurrency Wrappers (Unchanged) --- | |
def train_model_task(coin): | |
# (Keep this function as is) | |
try: | |
result = train_model(coin) | |
if result != (None, None, None): | |
model, X_train_orig, y_train_orig = result | |
return coin, (model, X_train_orig, y_train_orig) | |
else: | |
return coin, None | |
except Exception as e: | |
logging.exception(f"Unhandled exception in train_model_task for {coin}") | |
return coin, None | |
def train_all_models(coin_list=None, num_workers=NUM_WORKERS_TRAINING): | |
# (Keep this function as is) | |
global trained_models | |
start_time = time.time() | |
if coin_list is None or not coin_list: | |
logging.info("No coin list provided, fetching top coins by volume...") | |
try: | |
coin_list = get_all_usdt_pairs() | |
if not coin_list: | |
msg = "Failed to fetch coin list even with fallback. Training aborted." | |
logging.error(msg) | |
return msg | |
except Exception as e: | |
msg = f"Error fetching coin list: {e}. Training aborted." | |
logging.exception(msg) | |
return msg | |
logging.info(f"Starting training for {len(coin_list)} coins using {num_workers} workers...") | |
results_log = [] | |
successful_trains = 0 | |
failed_trains = 0 | |
new_models = {} | |
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers, thread_name_prefix='TrainWorker') as executor: | |
future_to_coin = {executor.submit(train_model_task, coin): coin for coin in coin_list} | |
processed_count = 0 | |
total_coins = len(coin_list) | |
for future in concurrent.futures.as_completed(future_to_coin): | |
processed_count += 1 | |
coin = future_to_coin[future] | |
try: | |
returned_coin, model_data = future.result() | |
if returned_coin == coin and model_data is not None: | |
new_models[returned_coin] = model_data | |
results_log.append(f"✅ {returned_coin}: Model trained successfully.") | |
successful_trains += 1 | |
else: | |
results_log.append(f"❌ {coin}: Model training failed (check logs).") | |
failed_trains += 1 | |
except Exception as e: | |
results_log.append(f"❌ {coin}: Training task generated exception: {e}") | |
failed_trains += 1 | |
logging.exception(f"Exception from training future for {coin}") | |
if processed_count % 10 == 0 or processed_count == total_coins: | |
logging.info(f"Training progress: {processed_count}/{total_coins} coins processed.") | |
logging.getLogger().handlers[0].flush() | |
trained_models.update(new_models) | |
logging.info(f"Updated global models dictionary. Total models now: {len(trained_models)}") | |
end_time = time.time() | |
duration = end_time - start_time | |
completion_message = ( | |
f"Training run completed in {duration:.2f} seconds.\n" | |
f"Successfully trained: {successful_trains}\n" | |
f"Failed to train: {failed_trains}\n" | |
f"Total models available now: {len(trained_models)}" | |
) | |
logging.info(completion_message) | |
return completion_message + "\n\n" + "\n".join(results_log[-20:]) | |
# --- Update Predictions Table (Mostly Unchanged, uses new TA function) --- | |
def update_predictions_table(): | |
# (This function structure remains the same, it just calls the new calculate_ta_indicators) | |
global last_update_time | |
logging.info("--- Updating Predictions Table ---") | |
start_time = time.time() | |
predictions_data = {} | |
current_models = trained_models.copy() | |
if not current_models: | |
msg = "No models available. Please train first." | |
logging.warning(msg) | |
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] | |
return pd.DataFrame([], columns=cols), msg | |
symbols_with_models = list(current_models.keys()) | |
logging.info(f"Step 1: Generating predictions for {len(symbols_with_models)} models...") | |
# --- Stage 1: Get Predictions Concurrently --- | |
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='PredictWorker') as executor: | |
future_to_coin_pred = {executor.submit(predict_real_time, coin, model_data): coin for coin, model_data in current_models.items()} | |
pred_success = 0 | |
pred_fail = 0 | |
for future in concurrent.futures.as_completed(future_to_coin_pred): | |
coin = future_to_coin_pred[future] | |
try: | |
pred, conf = future.result() | |
if pred not in ["Model N/A", "Model Error", "Data Error", "Norm Error", "LLT Data Error", "Transform Error", "Predict Error", "Error"]: | |
predictions_data[coin] = {'prediction': pred, 'confidence': float(conf)} | |
pred_success += 1 | |
else: | |
predictions_data[coin] = {'prediction': pred, 'confidence': 0.0} | |
pred_fail += 1 | |
except Exception as e: | |
logging.exception(f"Error getting prediction result for {coin}") | |
predictions_data[coin] = {'prediction': "Future Error", 'confidence': 0.0} | |
pred_fail +=1 | |
logging.info(f"Step 1 Complete: Predictions generated ({pred_success} success, {pred_fail} fail).") | |
# --- Stage 2: Fetch Current Tickers & TA Data Concurrently --- | |
symbols_to_fetch_data = list(predictions_data.keys()) | |
if not symbols_to_fetch_data: | |
logging.warning("No symbols with predictions to fetch data for.") | |
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] | |
return pd.DataFrame([], columns=cols), "No symbols processed." | |
logging.info(f"Step 2: Fetching Tickers and {TA_DATA_POINTS} OHLCV candles for {len(symbols_to_fetch_data)} symbols...") | |
tickers_data = {} | |
ohlcv_data = {} | |
try: # Fetch Tickers | |
batch_size_tickers = 100 | |
fetched_tickers_batch = {} | |
for i in range(0, len(symbols_to_fetch_data), batch_size_tickers): | |
batch_symbols = symbols_to_fetch_data[i:i+batch_size_tickers] | |
try: | |
batch_tickers = exchange.fetch_tickers(symbols=batch_symbols) | |
fetched_tickers_batch.update(batch_tickers) | |
time.sleep(exchange.rateLimit / 1000 * 0.5) | |
except Exception as e: | |
logging.error(f"Failed to fetch ticker batch starting with {batch_symbols[0]}: {e}") | |
tickers_data = fetched_tickers_batch | |
logging.info(f"Fetched {len(tickers_data)} tickers.") | |
except Exception as e: | |
logging.exception(f"Error fetching tickers in prediction update: {e}") | |
# Fetch OHLCV for TA | |
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS_PREDICTION, thread_name_prefix='TADataWorker') as executor: | |
future_to_coin_ohlcv = {executor.submit(fetch_historical_data, coin, '1m', TA_DATA_POINTS): coin for coin in symbols_to_fetch_data} | |
for future in concurrent.futures.as_completed(future_to_coin_ohlcv): | |
coin = future_to_coin_ohlcv[future] | |
try: | |
df_ta = future.result() | |
if df_ta is not None and len(df_ta) == TA_DATA_POINTS: | |
# Ensure standard column names expected by calculate_ta_indicators | |
df_ta.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] | |
ohlcv_data[coin] = df_ta | |
except Exception as e: | |
logging.exception(f"Error fetching TA OHLCV data for {coin}") | |
logging.info(f"Step 2 Complete: Fetched TA data for {len(ohlcv_data)} symbols.") | |
# --- Stage 3: Calculate TA & Trade Levels --- | |
logging.info(f"Step 3: Calculating TA (using TA-Lib) and Trade Levels...") | |
final_results = [] | |
processing_time = datetime.now(timezone.utc) | |
for symbol in symbols_to_fetch_data: | |
pred_info = predictions_data.get(symbol, {'prediction': 'Missing Pred', 'confidence': 0.0}) | |
ticker = tickers_data.get(symbol) | |
df_ta = ohlcv_data.get(symbol) # This df should have standard columns now | |
current_price, quote_volume = np.nan, np.nan | |
ta_indicators = {k: np.nan for k in ['RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'VWAP', 'ATR']} | |
trade_levels = {k: np.nan for k in ['Entry', 'TP1', 'TP2', 'SL']} | |
entry_time, exit_time = pd.NaT, pd.NaT | |
if ticker and isinstance(ticker, dict): | |
current_price = ticker.get('last', np.nan) | |
            quote_volume = ticker.get('quoteVolume') or ticker.get('info', {}).get('quoteVolume')  # prefer the unified ccxt field
if quote_volume is None: | |
base_volume = ticker.get('baseVolume') | |
if base_volume is not None and current_price is not None: | |
try: quote_volume = float(base_volume) * float(current_price) | |
except (ValueError, TypeError): quote_volume = np.nan | |
try: current_price = float(current_price) if current_price is not None else np.nan | |
except (ValueError, TypeError): current_price = np.nan | |
try: quote_volume = float(quote_volume) if quote_volume is not None else np.nan | |
except (ValueError, TypeError): quote_volume = np.nan | |
# Calculate TA using the new function | |
if df_ta is not None: | |
ta_indicators = calculate_ta_indicators(df_ta) # Calls the TA-Lib version | |
if pred_info['prediction'] in ["Rise", "Fall"] and not pd.isna(current_price) and not pd.isna(ta_indicators['ATR']): | |
trade_levels = calculate_trade_levels(pred_info['prediction'], pred_info['confidence'], current_price, ta_indicators['ATR']) | |
if not pd.isna(trade_levels['Entry']): | |
entry_time = processing_time | |
exit_time = processing_time + timedelta(hours=PREDICTION_WINDOW_HOURS) | |
final_results.append({ | |
'coin': symbol.split('/')[0], 'full_symbol': symbol, | |
'prediction': pred_info['prediction'], 'confidence': pred_info['confidence'], | |
'price': current_price, 'volume': quote_volume, | |
'entry': trade_levels['Entry'], 'entry_time': entry_time, 'exit_time': exit_time, | |
'tp1': trade_levels['TP1'], 'tp2': trade_levels['TP2'], 'sl': trade_levels['SL'], | |
'rsi': ta_indicators['RSI'], 'macd_hist': ta_indicators['MACD_Hist'], | |
'vwap': ta_indicators['VWAP'], 'atr': ta_indicators['ATR'] | |
}) | |
logging.info("Step 3 Complete: TA and Trade Levels calculated.") | |
# --- Stage 4: Sort and Format (Unchanged) --- | |
def sort_key(item): | |
pred, conf = item['prediction'], item['confidence'] | |
if pred == "Rise" and conf >= CONFIDENCE_THRESHOLD and not pd.isna(item['entry']): return (0, -conf) | |
elif pred == "Rise": return (1, -conf) | |
elif pred == "Fall": return (2, -conf) | |
else: return (3, 0) | |
final_results.sort(key=sort_key) | |
formatted_output = [] | |
for i, p in enumerate(final_results[:MAX_COINS_TO_DISPLAY]): | |
formatted_output.append([ | |
i + 1, p['coin'], p['prediction'], f"{p['confidence']:.3f}", | |
f"{p['price']:.4f}" if not pd.isna(p['price']) else "N/A", | |
f"{p['volume']:,.0f}" if not pd.isna(p['volume']) else "N/A", | |
f"{p['entry']:.4f}" if not pd.isna(p['entry']) else "N/A", | |
format_datetime(p['entry_time'], "N/A"), format_datetime(p['exit_time'], "N/A"), | |
f"{p['tp1']:.4f}" if not pd.isna(p['tp1']) else "N/A", | |
f"{p['tp2']:.4f}" if not pd.isna(p['tp2']) else "N/A", | |
f"{p['sl']:.4f}" if not pd.isna(p['sl']) else "N/A", | |
f"{p['rsi']:.2f}" if not pd.isna(p['rsi']) else "N/A", | |
f"{p['macd_hist']:.4f}" if not pd.isna(p['macd_hist']) else "N/A", | |
f"{p['vwap']:.4f}" if not pd.isna(p['vwap']) else "N/A", | |
f"{p['atr']:.4f}" if not pd.isna(p['atr']) else "N/A", | |
]) | |
output_columns = [ | |
'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', | |
'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', | |
'RSI', 'MACD Hist', 'VWAP', 'ATR' | |
] | |
output_df = pd.DataFrame(formatted_output, columns=output_columns) | |
end_time = time.time() | |
duration = end_time - start_time | |
last_update_time = processing_time | |
status_message = f"Predictions updated ({len(final_results)} symbols processed) in {duration:.2f}s. Last update: {format_datetime(last_update_time)}" | |
logging.info(status_message) | |
return output_df, status_message | |
# --- Gradio UI Handlers (Unchanged) --- | |
def handle_train_click(coin_input, num_workers): | |
# (Keep this function as is) | |
logging.info(f"Train button clicked. Workers: {num_workers}") | |
coins = None | |
num_workers = int(num_workers) | |
if coin_input and coin_input.strip(): | |
raw_coins = coin_input.replace(',', ' ').split() | |
coins = [] | |
valid = True | |
for c in raw_coins: | |
coin_upper = c.strip().upper() | |
if '/' not in coin_upper: coin_upper += '/USDT' | |
if coin_upper.endswith('/USDT'): coins.append(coin_upper) | |
else: | |
valid = False | |
logging.error(f"Invalid coin format: {c}. Must be SYMBOL or SYMBOL/USDT.") | |
break | |
if not valid: return "Error: Custom coins must be valid SYMBOL or SYMBOL/USDT pairs." | |
logging.info(f"Training requested for custom coin list: {coins}") | |
else: | |
logging.info("Training requested for top coins by volume.") | |
train_status = train_all_models(coin_list=coins, num_workers=num_workers) | |
return f"--- Training Run ---:\n{train_status}\n\n---> Press 'Refresh Predictions' <---" | |
def handle_refresh_click(): | |
# (Keep this function as is) | |
logging.info("Refresh button clicked.") | |
try: | |
df, status = update_predictions_table() | |
return df, status | |
except Exception as e: | |
logging.exception("Error during handle_refresh_click") | |
cols = ['Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', 'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', 'RSI', 'MACD Hist', 'VWAP', 'ATR'] | |
return pd.DataFrame([], columns=cols), f"Error updating predictions: {e}" | |
# --- Gradio Interface Definition (Unchanged) --- | |
with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
gr.Markdown("# Cryptocurrency Prediction & TA Signal Explorer (LLT-KNN + TA-Lib)") # Updated title slightly | |
gr.Markdown(f""" | |
Predicts **{PREDICTION_WINDOW_HOURS}-hour** price direction (Rise/Fall) using LLT-KNN. | |
Displays current price, volume, TA indicators (RSI, MACD, VWAP, ATR calculated using **TA-Lib**), and potential trade levels for **Rise** signals meeting confidence >= **{CONFIDENCE_THRESHOLD}**. | |
TP/SL levels based on **{TP1_ATR_MULTIPLIER}x / {TP2_ATR_MULTIPLIER}x / {SL_ATR_MULTIPLIER}x ATR({ATR_PERIOD})**. | |
**Warning:** Educational. High risk. Not financial advice. Ensure TA-Lib is correctly installed. | |
""") | |
with gr.Row(): | |
with gr.Column(scale=4): | |
prediction_df = gr.Dataframe( | |
headers=[ | |
'Rank', 'Coin', 'Prediction', 'Confidence', 'Price', 'Volume (Quote)', | |
'Entry', 'Entry Time', 'Exit Time', 'TP1', 'TP2', 'SL', | |
'RSI', 'MACD Hist', 'VWAP', 'ATR' | |
], | |
datatype=[ | |
'number', 'str', 'str', 'str', 'str', 'str', | |
'str', 'str', 'str', 'str', 'str', 'str', | |
'str', 'str', 'str', 'str' | |
], | |
row_count=15, col_count=(16, "fixed"), label="Predictions & TA Signals", wrap=True, | |
) | |
with gr.Column(scale=1): | |
with gr.Accordion("Train Models", open=True): | |
coin_input = gr.Textbox(label="Train Specific Coins (e.g., BTC, ETH/USDT)", placeholder="Leave empty for top coins by volume") | |
max_workers_slider = gr.Slider(minimum=1, maximum=10, value=NUM_WORKERS_TRAINING, step=1, label="Parallel Training Workers") | |
train_button = gr.Button("Start Training", variant="primary") | |
refresh_button = gr.Button("Refresh Predictions", variant="secondary") | |
status_text = gr.Textbox(label="Status Log", lines=15, interactive=False, max_lines=30) | |
gr.Markdown( | |
""" | |
## Notes | |
- **TA-Lib**: This version uses the TA-Lib library for indicators. Ensure it's installed correctly (can be tricky). | |
- **Data**: Fetches OHLCV data (Bitget, 1-min). Uses cache. Handles rate limits. | |
- **Training**: Uses past ~14h data (12h train, 2h predict). Normalizes, balances classes, applies LLT, trains KNN. | |
- **Prediction**: Uses latest 12h data for KNN input. | |
- **Trade Levels**: Only shown for 'Rise' predictions above confidence threshold. Based on current price and ATR volatility. **Highly speculative.** | |
- **Sorting**: Table sorted by (Potential Rise Signals > Other Rise > Fall > Errors), then by confidence descending. | |
- **Refresh**: Fetches latest prices/TA and re-evaluates signals. | |
""" | |
) | |
train_button.click(fn=handle_train_click, inputs=[coin_input, max_workers_slider], outputs=status_text) | |
refresh_button.click(fn=handle_refresh_click, inputs=None, outputs=[prediction_df, status_text]) | |
# --- Startup Initialization (Unchanged) --- | |
def initialize_models_on_startup(): | |
# (Keep this function as is) | |
logging.info("----- Initializing Models (Startup Thread) -----") | |
default_coins = ['BTC/USDT', 'ETH/USDT', 'SOL/USDT', 'XRP/USDT', 'DOGE/USDT'] | |
try: | |
initial_status = train_all_models(default_coins, num_workers=2) | |
logging.info("----- Initial Model Training Complete -----") | |
logging.info(initial_status) | |
except Exception as e: | |
logging.exception("Error during startup initialization.") | |
# --- Main Execution (Unchanged) --- | |
if __name__ == "__main__": | |
logging.info("Starting application...") | |
# Check if TA-Lib import worked (basic check) | |
try: | |
# Try accessing a TA-Lib function | |
_ = talib.RSI(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) | |
logging.info("TA-Lib library seems accessible.") | |
except NameError: | |
logging.error("FATAL: TA-Lib library not found or import failed. Please install it correctly.") | |
sys.exit(1) | |
except Exception as ta_init_e: | |
logging.error(f"FATAL: Error testing TA-Lib library: {ta_init_e}. Please check installation.") | |
sys.exit(1) | |
init_thread = threading.Thread(target=initialize_models_on_startup, name="StartupTrainThread", daemon=True) | |
init_thread.start() | |
logging.info("Launching Gradio Interface...") | |
try: | |
demo.launch(server_name="0.0.0.0") | |
except Exception as e: | |
logging.exception("Failed to launch Gradio interface.") | |
finally: | |
logging.info("Gradio Interface stopped.") |