summaryrefslogtreecommitdiff
path: root/experiments/boundary_ablation.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/boundary_ablation.py')
-rw-r--r--experiments/boundary_ablation.py590
1 files changed, 590 insertions, 0 deletions
diff --git a/experiments/boundary_ablation.py b/experiments/boundary_ablation.py
new file mode 100644
index 0000000..64d08c9
--- /dev/null
+++ b/experiments/boundary_ablation.py
@@ -0,0 +1,590 @@
+"""
+Phase 3: Boundary-condition ablation on credit bridge.
+
+Test different terminal conditioning codes:
+ s1 = e_T (current default, softmax error)
+ s2 = delta_L (grad of CE w.r.t. h_L, output-layer-local)
+ s3 = concat(e_T, proj(h_L)) -- h_L projected to smaller dim
+ s4 = concat(delta_L, proj(h_L))
+
+Also ablate:
+ - terminal gradient matching weight: w_term in {0, 0.25, 1.0, 4.0}
+ - warmup ratio: r_warm in {0, 0.05, 0.2, 0.5}
+
+Run on best regimes from Phase 1/2.
+"""
+import os
+import sys
+import json
+import argparse
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+# =============================================================================
+# Reuse teacher and student from synth ladder
+# =============================================================================
+class TeacherNet:
+ def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0):
+ rng = np.random.RandomState(seed)
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+ self.num_classes = num_classes
+ self.alpha = alpha
+ self.Ws = []
+ for l in range(num_blocks):
+ W = rng.randn(d_hidden, d_hidden).astype(np.float32)
+ W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3
+ self.Ws.append(torch.from_numpy(W))
+ U = rng.randn(num_classes, d_hidden).astype(np.float32)
+ U = U / (np.linalg.norm(U, ord=2) + 1e-8)
+ self.U = torch.from_numpy(U)
+
+ def to(self, device):
+ self.Ws = [W.to(device) for W in self.Ws]
+ self.U = self.U.to(device)
+ return self
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h0):
+ h = h0
+ hiddens = [h]
+ for l in range(self.num_blocks):
+ f = F.linear(self.phi(h), self.Ws[l])
+ h = h + f
+ hiddens.append(h)
+ logits = F.linear(h, self.U)
+ return logits, hiddens
+
+
+def generate_dataset(teacher, num_samples, d_hidden, device, seed=0):
+ torch.manual_seed(seed)
+ X = torch.randn(num_samples, d_hidden, device=device)
+ with torch.no_grad():
+ logits, _ = teacher.forward(X)
+ Y = logits.argmax(dim=-1)
+ return X, Y
+
+
+class StudentBlock(nn.Module):
+ def __init__(self, d_hidden, alpha):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w = nn.Linear(d_hidden, d_hidden, bias=False)
+ self.alpha = alpha
+ nn.init.normal_(self.w.weight, std=0.01)
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h):
+ return self.w(self.phi(self.ln(h)))
+
+
+class StudentNet(nn.Module):
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def forward(self, x, return_hidden=False):
+ h = x
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+ logits = self.out_head(h)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h, start_layer):
+ for i in range(start_layer, self.num_blocks):
+ h = h + self.blocks[i](h)
+ return self.out_head(h)
+
+
+# =============================================================================
+# Extended ValueNet that supports different s_dim
+# =============================================================================
+class ValueNetFlex(nn.Module):
+ """Value net with flexible s_dim."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, 1))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp).squeeze(-1)
+
+
+# =============================================================================
+# Terminal conditioning code computation
+# =============================================================================
+def compute_s(s_type, model, hiddens, logits, y, device, hL_proj=None):
+ """
+ Compute terminal conditioning code s based on s_type.
+
+ Args:
+ s_type: 'eT', 'deltaL', 'eT_hL', 'deltaL_hL'
+ model: student net
+ hiddens: list of hidden states
+ logits: model logits
+ y: true labels
+ device: torch device
+ hL_proj: fixed random projection matrix for h_L (d_hidden x proj_dim)
+
+ Returns:
+ s: (batch, s_dim)
+ """
+ batch = logits.shape[0]
+ hL_det = hiddens[-1].detach()
+
+ if s_type == 'eT':
+ e_T = logits.softmax(dim=-1).detach()
+ e_T[torch.arange(batch), y] -= 1
+ return e_T
+
+ elif s_type == 'deltaL':
+ # grad of CE w.r.t. h_L (output-layer-local)
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_local = model.out_head(hL_req)
+ loss_local = F.cross_entropy(logits_local, y, reduction='sum')
+ delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach()
+ return delta_L
+
+ elif s_type == 'eT_hL':
+ e_T = logits.softmax(dim=-1).detach()
+ e_T[torch.arange(batch), y] -= 1
+ hL_proj_emb = hL_det @ hL_proj # (batch, proj_dim)
+ return torch.cat([e_T, hL_proj_emb], dim=-1)
+
+ elif s_type == 'deltaL_hL':
+ hL_req = hL_det.clone().requires_grad_(True)
+ logits_local = model.out_head(hL_req)
+ loss_local = F.cross_entropy(logits_local, y, reduction='sum')
+ delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach()
+ hL_proj_emb = hL_det @ hL_proj
+ return torch.cat([delta_L, hL_proj_emb], dim=-1)
+
+ else:
+ raise ValueError(f"Unknown s_type: {s_type}")
+
+
+def get_s_dim(s_type, num_classes, d_hidden, proj_dim=32):
+ if s_type == 'eT':
+ return num_classes
+ elif s_type == 'deltaL':
+ return d_hidden
+ elif s_type == 'eT_hL':
+ return num_classes + proj_dim
+ elif s_type == 'deltaL_hL':
+ return d_hidden + proj_dim
+ else:
+ raise ValueError(f"Unknown s_type: {s_type}")
+
+
+# =============================================================================
+# Credit bridge training with configurable boundary conditions
+# =============================================================================
+def train_credit_bridge_ablation(model, train_loader, test_loader, device, args,
+ s_type='eT', term_grad_weight=1.0, warmup_ratio=0.2,
+ hL_proj=None):
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+ warmup_epochs = max(1, int(args.epochs * warmup_ratio))
+
+ s_dim = get_s_dim(s_type, C, d, proj_dim=32)
+ value_net = ValueNetFlex(d_hidden=d, s_dim=s_dim, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [],
+ 'value_loss': [], 'term_loss': [], 'bridge_loss': [], 'tgrad_loss': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ if warmup_epochs == 0:
+ credit_blend = 1.0
+ elif epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ # Compute s with the specified type
+ s = compute_s(s_type, model, hiddens, logits, y, device, hL_proj)
+ hL_det = hiddens[-1].detach()
+
+ # Also need e_T for DFA fallback
+ with torch.no_grad():
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+
+ # Train value net
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req2)
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # Compute credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # Update output head
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ test_acc = 0
+ model.eval()
+ with torch.no_grad():
+ tc, tt = 0, 0
+ for x, y in test_loader:
+ x, y = x.to(device), y.to(device)
+ logits = model(x)
+ tc += (logits.argmax(1) == y).sum().item()
+ tt += x.size(0)
+ test_acc = tc / tt
+ log['test_acc'].append(test_acc)
+ log['value_loss'].append(total_vloss / total)
+
+ return log, value_net
+
+
+def compute_diagnostics(model, value_net, test_loader, device, args,
+ s_type='eT', hL_proj=None):
+ model.eval()
+ value_net.eval()
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+
+ for x, y in test_loader:
+ x, y = x.to(device), y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # BP gradients
+ h = x.detach().requires_grad_(True)
+ hiddens_bp = [h]
+ for block in model.blocks:
+ f = block(hiddens_bp[-1])
+ hiddens_bp.append(hiddens_bp[-1] + f)
+ logits_bp = model.out_head(hiddens_bp[-1])
+ loss_bp = F.cross_entropy(logits_bp, y)
+ grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
+ bp_grads = {l: grads[l].detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+
+ s = compute_s(s_type, model, hiddens, logits, y, device, hL_proj)
+
+ results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.01': []}}
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(curr)
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=0.01)
+ results['nudging']['0.01'].append(nud)
+
+ return results
+
+
+def run_ablation(args, device):
+ d = args.d_hidden
+ C = args.num_classes
+ alpha = args.alpha
+ L = args.L
+
+ teacher = TeacherNet(d, L, C, alpha, seed=0).to(device)
+ X_train, Y_train = generate_dataset(teacher, args.n_train, d, device, seed=args.seed)
+ X_test, Y_test = generate_dataset(teacher, args.n_test, d, device, seed=args.seed + 10000)
+ train_ds = TensorDataset(X_train, Y_train)
+ test_ds = TensorDataset(X_test, Y_test)
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
+ test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)
+
+ # h_L projection matrix (fixed random)
+ proj_dim = 32
+ hL_proj = torch.randn(d, proj_dim, device=device) / np.sqrt(d)
+
+ results = {}
+
+ for s_type in args.s_types:
+ for tgw in args.term_grad_weights:
+ for wr in args.warmup_ratios:
+ key = f"s_{s_type}_tgw{tgw}_wr{wr}"
+ print(f"\n === {key} ===")
+ t0 = time.time()
+
+ torch.manual_seed(args.seed)
+ model = StudentNet(d, C, L, alpha).to(device)
+
+ log, vnet = train_credit_bridge_ablation(
+ model, train_loader, test_loader, device, args,
+ s_type=s_type, term_grad_weight=tgw, warmup_ratio=wr,
+ hL_proj=hL_proj
+ )
+
+ diag = compute_diagnostics(model, vnet, test_loader, device, args,
+ s_type=s_type, hL_proj=hL_proj)
+
+ mean_gamma = np.mean(diag['bp_cosine'])
+ mean_rho = np.mean(diag['perturbation_rho'])
+ mean_nudge = np.mean(diag['nudging']['0.01'])
+ test_acc = log['test_acc'][-1]
+
+ results[key] = {
+ 'test_acc': test_acc,
+ 'mean_bp_cosine': float(mean_gamma),
+ 'mean_rho': float(mean_rho),
+ 'mean_nudge': float(mean_nudge),
+ 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']],
+ 'rho_per_layer': [float(x) for x in diag['perturbation_rho']],
+ 'final_value_loss': log['value_loss'][-1],
+ 's_type': s_type,
+ 'term_grad_weight': tgw,
+ 'warmup_ratio': wr,
+ }
+
+ elapsed = time.time() - t0
+ print(f" Done in {elapsed:.0f}s: acc={test_acc:.4f} Gamma={mean_gamma:.4f} "
+ f"rho={mean_rho:.4f} nudge={mean_nudge:.6f}")
+
+ return results
+
+
+def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Boundary Condition Ablation')
+ parser.add_argument('--alpha', type=float, default=1.0)
+ parser.add_argument('--L', type=int, default=4)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--d_hidden', type=int, default=128)
+ parser.add_argument('--num_classes', type=int, default=10)
+ parser.add_argument('--n_train', type=int, default=10000)
+ parser.add_argument('--n_test', type=int, default=2000)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--epochs', type=int, default=80)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--s_types', type=str, nargs='+',
+ default=['eT', 'deltaL', 'eT_hL', 'deltaL_hL'])
+ parser.add_argument('--term_grad_weights', type=float, nargs='+',
+ default=[0.0, 0.25, 1.0, 4.0])
+ parser.add_argument('--warmup_ratios', type=float, nargs='+',
+ default=[0.0, 0.05, 0.2, 0.5])
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/boundary_ablation')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+ print(f"alpha={args.alpha}, L={args.L}, seed={args.seed}")
+ print(f"s_types: {args.s_types}")
+ print(f"term_grad_weights: {args.term_grad_weights}")
+ print(f"warmup_ratios: {args.warmup_ratios}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ results = run_ablation(args, device)
+
+ out_path = os.path.join(args.output_dir, f'ablation_a{args.alpha}_L{args.L}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(serialize(results), f, indent=2)
+
+ # Print summary
+ print("\n" + "=" * 100)
+ print("BOUNDARY CONDITION ABLATION SUMMARY")
+ print("=" * 100)
+ print(f"{'Config':<40} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 100)
+ for key in sorted(results.keys()):
+ r = results[key]
+ print(f"{key:<40} {r['test_acc']:>8.4f} {r['mean_bp_cosine']:>8.4f} "
+ f"{r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}")
+
+
+if __name__ == '__main__':
+ main()