"""
Cerebrum Forex - Stacking Ensemble Model
Meta-learner that combines XGBoost, LightGBM, RandomForest predictions.
Replaces the LSTM model for faster training without a TensorFlow dependency.
"""
import logging
import pickle
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from config.settings import IS_FROZEN
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from .base_model import BaseModel
logger = logging.getLogger(__name__)


class StackingModel(BaseModel):
    """
    Stacking Ensemble Model using XGBoost, LightGBM, RandomForest as base learners
    and LogisticRegression as meta-learner.

    Benefits over LSTM:
    - 10-20x faster training (no deep learning)
    - No TensorFlow/GPU dependency
    - Often better for tabular data
    """

    def __init__(self, timeframe: str, model_dir: Path):
        super().__init__(timeframe, model_dir)
        self.model_path = model_dir / f"stacking_{timeframe}.pkl"

    @property
    def name(self) -> str:
        return "stacking"

    def _build_model(self, n_classes: int):
        """Build the stacking classifier."""
        # Base estimators (lighter versions for stacking)
        # NOTE: CatBoost removed - not compatible with sklearn 1.8 StackingClassifier
        # (CatBoostClassifier lacks __sklearn_tags__ method)
        from sklearn.ensemble import RandomForestClassifier

        base_estimators = [
            ('xgb', XGBClassifier(
                n_estimators=50,
                max_depth=4,
                learning_rate=0.1,
                subsample=0.8,           # Regularization: use 80% of samples
                colsample_bytree=0.8,    # Regularization: use 80% of features
                min_child_weight=5,      # Regularization: min sum of instance weights (hessian) per child
                objective='multi:softprob' if n_classes > 2 else 'binary:logistic',
                num_class=n_classes if n_classes > 2 else None,
                verbosity=0,
                n_jobs=1 if IS_FROZEN else -1
            )),
            ('lgb', LGBMClassifier(
                n_estimators=50,
                max_depth=4,
                learning_rate=0.1,
                subsample=0.8,           # Regularization: use 80% of samples
                colsample_bytree=0.8,    # Regularization: use 80% of features
                min_child_samples=20,    # Regularization: min samples per leaf
                objective='multiclass' if n_classes > 2 else 'binary',
                num_class=n_classes if n_classes > 2 else None,
                verbose=-1,
                n_jobs=1 if IS_FROZEN else -1
            )),
            ('rf', RandomForestClassifier(
                n_estimators=50,
                max_depth=6,
                max_samples=0.8,         # Regularization: use 80% of samples per tree
                min_samples_leaf=5,      # Regularization: min samples per leaf
                n_jobs=1 if IS_FROZEN else -1,
                random_state=42
            ))
        ]

        # Meta-learner (multi_class removed in sklearn 1.8+, auto-determined)
        meta_learner = LogisticRegression(
            max_iter=1000,
            solver='lbfgs',
            n_jobs=1 if IS_FROZEN else -1
        )

        return StackingClassifier(
            estimators=base_estimators,
            final_estimator=meta_learner,
            cv=3,                        # 3-fold CV for out-of-fold base predictions
            stack_method='predict_proba',
            n_jobs=1 if IS_FROZEN else -1,
            passthrough=False            # Only use base predictions, not original features
        )
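
    # What the meta-learner sees (a sketch of sklearn's StackingClassifier
    # behaviour, not project-specific logic): each base learner contributes its
    # out-of-fold predict_proba columns, so with three base learners the
    # LogisticRegression input has roughly 3 * n_classes columns (sklearn drops
    # one redundant probability column per learner in the binary case).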

    def train(self, X: np.ndarray, y: np.ndarray,
              X_val: np.ndarray = None, y_val: np.ndarray = None,
              class_weights: np.ndarray = None) -> float:
        """Train the stacking model and return its balanced accuracy."""
        try:
            n_classes = len(np.unique(y))
            logger.info(f"[Stacking {self.timeframe}] Training with {len(X)} samples, {n_classes} classes")

            # Build and train
            self.model = self._build_model(n_classes)

            # class_weights is passed straight through to fit() as sample_weight,
            # so it must already contain one weight per training row.
            sample_weight = class_weights if class_weights is not None else None
            self.model.fit(X, y, sample_weight=sample_weight)
            self.is_trained = True

            # Evaluate
            if X_val is not None and len(X_val) > 0:
                y_pred = self.model.predict(X_val)
                self.accuracy = balanced_accuracy_score(y_val, y_pred)
            else:
                # Fall back to training accuracy when no validation set is given
                y_pred = self.model.predict(X)
                self.accuracy = balanced_accuracy_score(y, y_pred)

            logger.info(f"[Stacking {self.timeframe}] ✓ Balanced Accuracy: {self.accuracy:.2%}")
            self.save()
            return self.accuracy
        except Exception as e:
            logger.error(f"Stacking training failed: {e}", exc_info=True)
            return 0.0

    def predict(self, X: np.ndarray) -> Tuple[str, float]:
        """Predict a signal and confidence for the most recent row of X."""
        if not self.is_trained and not self.load():
            return "NEUTRAL", 0.0
        try:
            # Handle feature mismatch between training and prediction
            # 1. New system: feature names saved alongside the model
            expected_features = self.feature_names
            # 2. Legacy fallback: feature names recorded by sklearn at fit time
            if not expected_features and hasattr(self.model, 'feature_names_in_'):
                expected_features = self.model.feature_names_in_

            # Use the determined expected features to align the input
            if hasattr(X, 'columns') and expected_features:
                # Check that all required features are present
                missing = [f for f in expected_features if f not in X.columns]
                if missing:
                    logger.warning(f"StackingModel mismatch: Missing {len(missing)} features ({missing[:3]}...). Returning NEUTRAL.")
                    return "NEUTRAL", 0.0
                # Check for unnamed (Column_X) vs named feature mismatch (legacy models)
                if len(expected_features) > 0 and str(expected_features[0]).startswith("Column_") and not str(X.columns[0]).startswith("Column_"):
                    logger.warning("StackingModel schema mismatch: Model expects raw features (Column_X) but got named features. Returning NEUTRAL. (Retrain required)")
                    return "NEUTRAL", 0.0
                # Select only the expected columns, in the expected order
                X = X[expected_features]

            # Ensure a 2D numpy array
            if hasattr(X, 'values'):
                X = X.values
            if X.ndim == 1:
                X = X.reshape(1, -1)

            # Get class probabilities
            proba = self.model.predict_proba(X)
            # Use the last row (most recent observation)
            pred_proba = proba[-1] if len(proba) > 1 else proba[0]
            pred_class = np.argmax(pred_proba)
            confidence = float(pred_proba[pred_class])
            signal = self.signal_from_prediction(pred_class)
            return signal, confidence
        except Exception as e:
            logger.error(f"Stacking prediction failed: {e}")
            return "NEUTRAL", 0.0

    def load(self) -> bool:
        """Load model from disk"""
        if not self.model_path.exists():
            return False
        try:
            with open(self.model_path, 'rb') as f:
                data = pickle.load(f)
            self.model = data['model']
            self.accuracy = data.get('accuracy', 0.0)
            self.feature_names = data.get('feature_names', [])
            self.is_trained = data.get('is_trained', True)
            logger.info(f"Stacking model loaded from {self.model_path} ({len(self.feature_names)} features)")
            return True
        except Exception as e:
            logger.error(f"Failed to load Stacking model: {e}")
            return False
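
    # NOTE: load() expects the pickle payload written by the corresponding
    # save() call (inherited from BaseModel, an assumption of this note);
    # that payload is read back as a dict shaped roughly like:
    #   {'model': <fitted StackingClassifier>, 'accuracy': float,
    #    'feature_names': list[str], 'is_trained': bool}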