249 lines
8.0 KiB
Python
249 lines
8.0 KiB
Python
# src/paper/broker.py
|
|
"""
Paper broker (spot, long-only) with positions tracked PER STRATEGY.

- Tracks cash, equity and PnL.
- Positions are keyed by position_id (e.g. "BTC/USDT::MA_Crossover").
- Each Position stores the REAL symbol ("BTC/USDT") so mark-to-market uses
  the correct price.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Optional, Dict, Any, List
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
@dataclass
class PaperTrade:
    """
    Record of a single executed paper fill.

    Fields are constructor arguments in the order declared below.
    """

    # Real ticker that was traded, e.g. "BTC/USDT".
    symbol: str
    # Logical position key, e.g. "BTC/USDT::MA_Crossover".
    position_id: str
    # Order direction: "BUY" or "SELL".
    side: str
    # Filled quantity.
    qty: float
    # Execution price, i.e. the reference price after slippage.
    price: float
    # Commission charged for this fill.
    fee: float
    # Execution price times filled quantity.
    notional: float
    # PnL realized by this fill.
    realized_pnl: float
    # Fill time as an ISO-8601 string.
    timestamp: str
    # Free-form metadata supplied by the caller.
    meta: Dict[str, Any]
|
|
|
|
|
|
@dataclass
class Position:
    """
    Long-only holding stored under one position_id.

    ``symbol`` is the real ticker used to look up the mark price.
    ``position_id`` is the logical key (e.g. "BTC/USDT::MA_Crossover").
    ``qty`` and ``avg_entry`` default to 0.0, which represents a flat position.
    """

    symbol: str
    position_id: str
    qty: float = 0.0
    avg_entry: float = 0.0

    def is_open(self) -> bool:
        """Return True only when the held quantity is strictly positive."""
        return 0.0 < self.qty
|
|
|
|
|
|
class PaperBroker:
    """
    Paper-trading broker for spot, long-only trading with per-strategy positions.

    State:
      - ``cash``: BUY fills debit notional + fee; SELL fills credit notional - fee.
      - ``positions``: position_id -> Position. Each Position keeps its real
        symbol, so mark-to-market reads ``last_price[pos.symbol]``.
      - ``last_price``: symbol -> last known (pre-slippage) price.
      - ``realized_pnl``: cumulative (exec_price - avg_entry) * qty over SELL
        fills. Fees are NOT subtracted here; they only reduce ``cash``.
      - ``trades``: in-memory list of executed PaperTrade records.
    """

    # Relative tolerance for position-quantity comparisons on SELL. It absorbs
    # float rounding, so selling "everything" neither fails nor leaves dust open
    # (e.g. after a partial BUY whose qty was derived from cash).
    _QTY_REL_TOL = 1e-9

    def __init__(
        self,
        initial_cash: float,
        commission_rate: float = 0.001,
        slippage_rate: float = 0.0005,
    ):
        """
        Create a broker holding ``initial_cash`` and no positions.

        Args:
            initial_cash: starting cash; must be finite and > 0.
            commission_rate: fee as a fraction of fill notional; must be in [0, 1).
            slippage_rate: fraction applied against the trader on every fill
                (BUY pays more, SELL receives less); must be in [0, 1) so a
                SELL never executes at a non-positive price.

        Raises:
            ValueError: if any argument is out of range, including NaN or inf.
        """
        initial_cash = float(initial_cash)
        commission_rate = float(commission_rate)
        slippage_rate = float(slippage_rate)

        if not self._is_positive_finite(initial_cash):
            raise ValueError("initial_cash must be > 0")
        # Chained comparisons are False for NaN, so NaN rates are rejected too.
        if not 0.0 <= commission_rate < 1.0:
            raise ValueError("commission_rate must be in [0, 1)")
        if not 0.0 <= slippage_rate < 1.0:
            raise ValueError("slippage_rate must be in [0, 1)")

        self.initial_cash = initial_cash
        self.cash = initial_cash

        self.commission_rate = commission_rate
        self.slippage_rate = slippage_rate

        # position_id -> Position
        self.positions: Dict[str, Position] = {}

        # symbol -> last price
        self.last_price: Dict[str, float] = {}

        self.realized_pnl: float = 0.0
        self.trades: List[PaperTrade] = []

    @staticmethod
    def _is_positive_finite(value: float) -> bool:
        """Return True if ``value`` is > 0 and finite (rejects NaN and +inf)."""
        # NaN fails every comparison, so this chained comparison is False for it.
        return 0.0 < value < float("inf")

    # -----------------------------
    # Pricing / MTM
    # -----------------------------
    def update_price(self, symbol: str, price: float) -> None:
        """
        Store ``price`` as the last known price for ``symbol``.

        Raises:
            ValueError: if ``price`` is not a finite number > 0.
        """
        if not self._is_positive_finite(price):
            raise ValueError("price must be > 0")
        self.last_price[symbol] = float(price)

    def get_position(self, position_id: str, symbol: Optional[str] = None) -> Position:
        """
        Return the Position stored under ``position_id``.

        If none exists, a flat Position is created and stored. Creating one
        requires ``symbol`` so the new Position carries its real ticker.

        Raises:
            ValueError: if the position does not exist and ``symbol`` is None.
        """
        if position_id not in self.positions:
            if symbol is None:
                raise ValueError("symbol is required to create a new position")
            self.positions[position_id] = Position(symbol=symbol, position_id=position_id)
        return self.positions[position_id]

    def get_unrealized_pnl(self, position_id: str) -> float:
        """
        Return the mark-to-market PnL of an open position.

        The PnL is (last_price - avg_entry) * qty. It is 0.0 when the position
        does not exist, is flat, or its symbol has no known price yet.
        """
        pos = self.positions.get(position_id)
        if not pos or not pos.is_open():
            return 0.0
        lp = self.last_price.get(pos.symbol)
        if lp is None:
            return 0.0
        return (lp - pos.avg_entry) * pos.qty

    def get_equity(self) -> float:
        """
        Return equity = cash + sum(pos.qty * mark) over open positions.

        The mark is ``last_price[pos.symbol]``. A position whose symbol has no
        known price is valued at its ``avg_entry``.
        """
        equity = self.cash
        for pos in self.positions.values():
            if pos.is_open():
                equity += pos.qty * self.last_price.get(pos.symbol, pos.avg_entry)
        return float(equity)

    # -----------------------------
    # Execution
    # -----------------------------
    def _now_iso(self) -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _apply_slippage(self, side: str, price: float) -> float:
        """
        Return ``price`` moved against the trader by ``slippage_rate``.

        Raises:
            ValueError: if ``side`` is not BUY or SELL (case-insensitive).
        """
        s = side.upper()
        if s == "BUY":
            return price * (1.0 + self.slippage_rate)
        elif s == "SELL":
            return price * (1.0 - self.slippage_rate)
        raise ValueError(f"Invalid side: {side}")

    def _fee(self, notional: float) -> float:
        """Return the commission charged on ``notional``."""
        return abs(notional) * self.commission_rate

    def place_market_order(
        self,
        *,
        position_id: str,
        symbol: str,
        side: str,
        qty: float,
        price: float,
        meta: Optional[Dict[str, Any]] = None,
        allow_partial: bool = False,
    ) -> PaperTrade:
        """
        Execute a market order immediately and record it as a PaperTrade.

        The fill price is ``price`` moved by ``slippage_rate`` against the
        trader. The fee is ``commission_rate`` * notional. After the fill, the
        raw (pre-slippage) ``price`` becomes the symbol's last price.

        Args:
            position_id: logical position key (e.g. "BTC/USDT::MA_Crossover").
            symbol: real ticker; must match the existing position's symbol.
            side: "BUY" or "SELL" (case- and surrounding-whitespace-insensitive).
            qty: requested quantity; finite and > 0.
            price: reference price; finite and > 0.
            meta: optional metadata, copied into the trade record.
            allow_partial: BUY only. When cash is insufficient, buy the largest
                quantity the cash allows instead of raising.

        Returns:
            The recorded PaperTrade, which is also appended to ``self.trades``.

        Raises:
            ValueError: on invalid qty/price/side, a symbol mismatch with an
                existing position, insufficient cash for a BUY, or a SELL qty
                larger than the held quantity. Broker state is left unchanged
                when this is raised.
        """
        side = side.upper().strip()
        if not self._is_positive_finite(qty):
            raise ValueError("qty must be > 0")
        if not self._is_positive_finite(price):
            raise ValueError("price must be > 0")

        # Raises ValueError("Invalid side: ...") before any state is touched.
        exec_price = self._apply_slippage(side, float(price))

        existing = self.positions.get(position_id)
        if existing is not None and existing.symbol and existing.symbol != symbol:
            # A mismatched symbol would mark the position against the wrong price.
            raise ValueError(
                f"position_id {position_id!r} belongs to symbol "
                f"{existing.symbol!r}, not {symbol!r}"
            )

        # Copy so later mutations of the caller's dict don't rewrite the record.
        meta = dict(meta) if meta else {}

        if side == "BUY":
            fill_qty, notional, fee = self._fill_buy(
                position_id, symbol, float(qty), exec_price, allow_partial
            )
            realized = 0.0
        else:
            fill_qty, notional, fee, realized = self._fill_sell(
                position_id, symbol, qty, exec_price
            )

        trade = PaperTrade(
            symbol=symbol,
            position_id=position_id,
            side=side,
            qty=float(fill_qty),
            price=float(exec_price),
            fee=float(fee),
            notional=float(notional),
            realized_pnl=float(realized),
            timestamp=self._now_iso(),
            meta=meta,
        )

        self.trades.append(trade)
        self.update_price(symbol, price)

        return trade

    def _fill_buy(
        self,
        position_id: str,
        symbol: str,
        qty: float,
        exec_price: float,
        allow_partial: bool,
    ) -> tuple[float, float, float]:
        """Validate cash, apply a BUY fill and return (filled_qty, notional, fee)."""
        notional = exec_price * qty
        fee = self._fee(notional)
        total_cost = notional + fee
        if total_cost > self.cash:
            if not allow_partial:
                raise ValueError(
                    f"Not enough cash for BUY. Need {total_cost:.2f}, have {self.cash:.2f}"
                )
            # Largest qty whose notional + fee fits in the available cash.
            qty = max(self.cash / (exec_price * (1.0 + self.commission_rate)), 0.0)
            if qty <= 0:
                raise ValueError("Not enough cash to buy even a minimal quantity.")
            notional = exec_price * qty
            fee = self._fee(notional)

        # Create the position only once the fill is known to succeed.
        pos = self.get_position(position_id, symbol=symbol)
        if not pos.symbol:
            # E.g. restored from a snapshot that lacked the symbol.
            pos.symbol = symbol

        new_qty = pos.qty + qty
        if pos.qty == 0:
            pos.avg_entry = exec_price
        else:
            pos.avg_entry = (pos.avg_entry * pos.qty + exec_price * qty) / new_qty
        pos.qty = new_qty

        # In the partial path, notional + fee may exceed cash by a rounding
        # error; never leave a negative balance behind.
        self.cash = max(self.cash - (notional + fee), 0.0)
        return qty, notional, fee

    def _fill_sell(
        self,
        position_id: str,
        symbol: str,
        qty: float,
        exec_price: float,
    ) -> tuple[float, float, float, float]:
        """Validate holdings, apply a SELL fill and return (qty, notional, fee, realized)."""
        pos = self.positions.get(position_id)
        held = pos.qty if pos is not None else 0.0
        tol = held * self._QTY_REL_TOL
        if qty > held + tol:
            raise ValueError(f"Cannot SELL more than position qty. Have {held}, want {qty}")

        qty = float(qty)
        closes_position = held - qty <= tol
        if closes_position:
            # Sell the exact holding so no float dust stays open.
            qty = held

        notional = exec_price * qty
        fee = self._fee(notional)
        realized = (exec_price - pos.avg_entry) * qty
        self.realized_pnl += realized

        if closes_position:
            pos.qty = 0.0
            pos.avg_entry = 0.0
        else:
            pos.qty -= qty
        if not pos.symbol:
            pos.symbol = symbol

        self.cash += notional - fee
        return qty, notional, fee, realized

    # -----------------------------
    # Serialization
    # -----------------------------
    def snapshot(self) -> Dict[str, Any]:
        """
        Return a plain-dict view of the broker state.

        Trades are summarized as ``trades_count`` only. ``equity`` and
        ``updated_at`` are computed at call time.
        """
        return {
            "initial_cash": self.initial_cash,
            "cash": self.cash,
            "commission_rate": self.commission_rate,
            "slippage_rate": self.slippage_rate,
            "realized_pnl": self.realized_pnl,
            "equity": self.get_equity(),
            "positions": {pid: asdict(pos) for pid, pos in self.positions.items()},
            "last_price": dict(self.last_price),
            "trades_count": len(self.trades),
            "updated_at": self._now_iso(),
        }

    def restore(self, state: Dict[str, Any]) -> None:
        """
        Load broker state from a dict shaped like ``snapshot()`` output.

        Scalar fields missing from ``state`` keep their current values.
        ``last_price`` and ``positions`` are replaced wholesale. The trade log
        is cleared, because snapshots carry only ``trades_count``.
        """
        self.initial_cash = float(state.get("initial_cash", self.initial_cash))
        self.cash = float(state.get("cash", self.cash))
        self.commission_rate = float(state.get("commission_rate", self.commission_rate))
        self.slippage_rate = float(state.get("slippage_rate", self.slippage_rate))
        self.realized_pnl = float(state.get("realized_pnl", self.realized_pnl))

        self.last_price = {k: float(v) for k, v in (state.get("last_price") or {}).items()}

        self.positions = {}
        for pid, raw in (state.get("positions") or {}).items():
            symbol = raw.get("symbol") or ""
            if not symbol and "::" in str(pid):
                # Recover the real ticker from the "<symbol>::<strategy>" id.
                symbol = str(pid).split("::", 1)[0]
            self.positions[pid] = Position(
                symbol=symbol,
                position_id=raw.get("position_id", pid),
                qty=float(raw.get("qty", 0.0)),
                avg_entry=float(raw.get("avg_entry", 0.0)),
            )

        self.trades = []
|