summaryrefslogtreecommitdiff
path: root/research/flossing/engelken_python_flossing.py
blob: b83be2ff19708a62ef4717bbb40f1eecd4b4189f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""PyTorch port of Rainer Engelken's GradientFlossing_ExampleCode.jl.

This script is intentionally a vanilla-RNN delayed-XOR reproduction scaffold,
not an HRM/TRM experiment. It keeps the algorithmic choices that matter for
checking gradient flossing itself:

  - flossing is a separate phase, not mixed with task loss;
  - flossing optimizes only input/recurrent/bias parameters, not readout;
  - flossing uses mean((lambda_i - target)^2);
  - flossing batch size is one, as in the Julia code;
  - QR is differentiable and participates in the backward pass;
  - Lyapunov accumulation discards the initial transient.

Reference:
  research/flossing/external/GradientFlossing/GradientFlossing_ExampleCode.jl
"""

from __future__ import annotations

import argparse
import json
import math
import time
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F


def recabs(bits: np.ndarray) -> np.ndarray:
    """Fold ``bits`` along its first axis with repeated absolute differences.

    Computes ``|...||bits[0] - bits[1]| - bits[2]| ... - bits[-1]|``. For 0/1
    inputs this is the parity (chained XOR) of the rows.
    """
    acc = bits[0]
    for k in range(1, len(bits)):
        acc = np.abs(acc - bits[k])
    return acc


def generate_input_output(
    batch_size: int,
    steps: int,
    seed: int,
    input_dim: int,
    input_scale: float,
    delay: int,
    task: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample random bit streams and their delayed-task labels.

    Port of generateInputOutput in GradientFlossing_XOR.jl. ``steps + delay``
    raw steps of 0/1 bits are drawn from ``np.random.default_rng(seed)``; the
    first ``delay`` raw steps are dropped from both returned tensors.

    Label at raw step ``i`` (input channel 0 only):
      * ``task == 0``: the input bit at raw step ``i - delay`` (delayed copy);
      * otherwise: the chained ``|a - b|`` fold (parity for 0/1 bits) of the
        ``2**abs(task)`` bits at raw steps ``i - k*delay`` for
        ``k = 2**abs(task), ..., 1`` (oldest first). It is defined once
        ``i >= delay * 2**abs(task)`` and is zero before that.

    Returns:
        ``(s, target)`` shaped ``[steps, input_dim, batch_size]`` and
        ``[steps, 1, batch_size]`` (float32); ``s`` is scaled by ``input_scale``.
    """
    total = steps + delay
    rng = np.random.default_rng(seed)
    bits = rng.integers(
        0, 2, size=(total, input_dim, batch_size), dtype=np.int64
    ).astype(np.float32)
    labels = np.zeros((total, 1, batch_size), dtype=np.float32)

    if task == 0:
        if total > delay:
            labels[delay:] = bits[: total - delay, :1, :]
    else:
        taps = 2 ** abs(task)
        start = delay * taps
        n_valid = total - start
        if n_valid > 0:
            # Oldest tap first, then fold in newer taps -- same order as the
            # np.flip + recabs formulation, applied to all steps at once.
            acc = bits[:n_valid, 0, :]
            for lag in range(taps - 1, 0, -1):
                offset = start - delay * lag
                acc = np.abs(acc - bits[offset : offset + n_valid, 0, :])
            labels[start:, 0, :] = acc

    inputs = torch.from_numpy(bits[delay:] * np.float32(input_scale)).to(device)
    targets = torch.from_numpy(labels[delay:]).to(device)
    return inputs, targets


@dataclass
class Config:
    """Hyper-parameters for one run.

    Every field is exposed by ``parse_args`` as a ``--kebab-case`` CLI flag
    whose type and default come from the value below.
    """

    hidden_size: int = 80  # recurrent units n; J is (n, n)
    train_epochs: int = 3001  # task-update epochs; also scales input seeds in run()
    inter_period: int = 100  # interleaved flossing at epochs divisible by this
    inter_epochs: int = 500  # flossing steps per interleaved episode
    pre_epochs: int = 500  # flossing steps at epoch 1, before the first task step
    max_inter_episodes: int = 2  # cap on the number of interleaved flossing episodes
    batch_size: int = 16  # task batch size; x_init has this many columns
    input_dim: int = 1  # input channels S
    train_steps: int = 300  # sequence length used for task loss / accuracy
    lyap_steps: int = 55  # steps per Lyapunov estimate, including the discarded transient
    floss_input_steps: int = 300  # length of the input sequence drawn per flossing step
    seed_ic: int = 1  # seed for the random state relaxed into x_init
    seed_input: int = 1  # multiplied by train_epochs to offset input-sequence seeds
    seed_net: int = 1  # global torch seed set before weight initialisation
    seed_ons: int = 1  # seed for the initial orthonormal basis in lyapunov_spectrum
    lr: float = 1e-3  # shared by the task and flossing AdamW optimizers
    beta1: float = 0.9  # AdamW beta1 (both optimizers)
    beta2: float = 0.999  # AdamW beta2 (both optimizers)
    init_type: int = 1  # 1: Gaussian J; 3: recurrent_gain * Q from a QR of a Gaussian
    recurrent_gain: float = 1.0  # scale of J
    recurrent_mean_gain: float = 0.0  # init_type 1 only: adds this / n to every J entry
    input_scale: float = 1.0  # multiplies the 0/1 input bits
    delay: int = 10  # task delay in time steps
    ws_std: float = 1.0  # input weights: Gaussian std
    ws_mean: float = 0.0  # input weights: mean
    wr_std: float = 1.0  # readout weights: Gaussian std
    wr_mean: float = 0.0  # readout weights: mean
    b_std: float = 0.1  # bias width; drawn with torch.rand, i.e. uniform on [b_mean, b_mean + b_std)
    b_mean: float = 0.0  # bias lower bound
    n_lyap: int = 75  # number of exponents estimated and flossed; must be <= hidden_size
    task: int = -1  # 0: delayed copy; nonzero: parity of 2**abs(task) bits spaced `delay` apart
    lyap_target: float = 0.0  # target value for every exponent in the flossing loss
    eval_every: int = 100  # evaluation period in epochs (first and last epoch always evaluated)
    eval_batches: int = 4  # held-out batches per evaluation
    log_every_floss: int = 50  # flossing log period in steps (first and last step always logged)
    device: str = "cpu"  # passed to torch.device
    out: str = "research/flossing/engelken_python/run.json"  # JSON log output path


class VanillaRNN:
    """Vanilla tanh RNN with a scalar logistic readout.

    State update, applied to every column (batch element) of ``x``::

        x_t = J @ tanh(x_{t-1}) + b + (W_s @ s_t) / input_dim

    Readout: ``logit_t = W_r @ x_t + offset``. This is not an ``nn.Module``.
    Parameters are bare ``torch.nn.Parameter`` attributes, grouped explicitly
    by ``task_parameters()`` and ``floss_parameters()``.
    """

    def __init__(self, cfg: "Config", device: torch.device):
        """Initialise all parameters from ``cfg`` on ``device``.

        Seeds the *global* torch RNG with ``cfg.seed_net``. The draw order
        below (W_s, W_r, b, J) therefore determines the initial weights.

        Raises:
            ValueError: if ``cfg.init_type`` is not 1 or 3.
        """
        self.cfg = cfg
        self.device = device
        torch.manual_seed(cfg.seed_net)
        n = cfg.hidden_size

        # Input weights W_s: (n, input_dim), Gaussian.
        self.ws = torch.nn.Parameter(
            cfg.ws_std * torch.randn(n, cfg.input_dim, device=device) + cfg.ws_mean
        )
        # Readout weights W_r: (1, n), Gaussian.
        self.wr = torch.nn.Parameter(
            cfg.wr_std * torch.randn(1, n, device=device) + cfg.wr_mean
        )
        # Bias uses torch.rand: uniform on [b_mean, b_mean + b_std).
        self.b = torch.nn.Parameter(
            cfg.b_mean + cfg.b_std * torch.rand(n, device=device)
        )
        # Readout bias.
        self.offset = torch.nn.Parameter(torch.tensor([0.001], device=device))

        if cfg.init_type == 1:
            # Gaussian J, entry std recurrent_gain / sqrt(n), optional constant mean.
            j = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            if cfg.recurrent_mean_gain != 0:
                j = j + cfg.recurrent_mean_gain / n
        elif cfg.init_type == 3:
            # Scaled orthogonal J: Q factor of a Gaussian matrix times the gain.
            j0 = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            q, _ = torch.linalg.qr(j0)
            j = cfg.recurrent_gain * q
        else:
            raise ValueError(f"unsupported init_type={cfg.init_type}")
        self.j = torch.nn.Parameter(j)

        # Learnable initial state (n, batch_size): a random state run for 100
        # input-free steps of the freshly initialised network.
        self.x_init = torch.nn.Parameter(self._relax_initial_condition(cfg.seed_ic, 100))

    def _relax_initial_condition(self, seed: int, steps: int) -> torch.Tensor:
        """Return a (hidden_size, batch_size) state after ``steps`` autonomous updates.

        The starting state is Gaussian, drawn from a private generator seeded
        with ``seed``. The updates run without autograd and use no input.
        """
        gen = torch.Generator(device=self.device).manual_seed(seed)
        x = torch.randn(self.cfg.hidden_size, self.cfg.batch_size, generator=gen, device=self.device)
        with torch.no_grad():
            for _ in range(steps):
                x = self.j.detach() @ torch.tanh(x) + self.b.detach().unsqueeze(1)
        return x

    def task_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by the task optimizer (includes readout and x_init)."""
        return [self.ws, self.j, self.wr, self.b, self.offset, self.x_init]

    def floss_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by flossing: input weights, recurrent weights, bias."""
        return [self.ws, self.j, self.b]

    def task_loss_and_accuracy(self, s: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, float]:
        """Run the network over ``s`` and score the readout against ``target``.

        Args:
            s: inputs with ``s[t]`` of shape ``(input_dim, batch_size)``; only the
                first ``cfg.train_steps`` steps are read.
            target: 0/1 labels with ``target[t]`` of shape ``(1, batch_size)``.

        Returns:
            ``(loss, accuracy)``:
            - ``loss`` is the summed BCE-with-logits over the scored steps,
              divided by ``batch_size * scored_steps``.
            - ``accuracy`` is the fraction of scored outputs whose
              thresholded prediction equals ``target``.

            Scoring starts at a delay-dependent index ported from Julia's
            ``firstlossidx`` (converted to 0-based).
        """
        cfg = self.cfg
        first_loss_idx = 2 * cfg.delay if cfg.task > 0 else cfg.delay * (2 ** abs(cfg.task))
        # Julia is 1-indexed and starts loss at ti >= firstlossidx.
        first_loss_idx = max(first_loss_idx - 1, 0)
        x = self.x_init
        loss = torch.zeros((), device=self.device)
        # Count matches on-device and sync once at the end, not once per step.
        n_correct = torch.zeros((), dtype=torch.int64, device=self.device)
        count = 0
        for ti in range(cfg.train_steps):
            x = self.j @ torch.tanh(x) + self.b.unsqueeze(1) + (self.ws @ s[ti]) / cfg.input_dim
            if ti >= first_loss_idx:
                logits = self.wr @ x + self.offset.unsqueeze(1)
                loss = loss + F.binary_cross_entropy_with_logits(logits, target[ti], reduction="sum")
                pred = (torch.sigmoid(logits) >= 0.5).to(target.dtype)
                n_correct = n_correct + (pred == target[ti]).sum()
                count += target[ti].numel()
        denom = cfg.batch_size * max(cfg.train_steps - first_loss_idx, 1)
        correct = n_correct.item()
        return loss / denom, correct / max(count, 1)

    def lyapunov_spectrum(self, s: torch.Tensor) -> torch.Tensor:
        """Estimate the leading ``cfg.n_lyap`` Lyapunov exponents along ``s``.

        Method:
        - Drive the state seeded from column 0 of ``x_init`` with ``s[:, :, 0]``.
        - Each step, propagate an orthonormal basis through the tangent map
          ``diag(sech^2(x_t)) @ J``, then re-orthonormalise it by QR.
        - Average ``log|diag(R)|`` over the steps after a transient of
          ``10 + ceil(lyap_steps / 10)`` steps.

        The QR factorisations stay in the autograd graph, so the result can be
        back-propagated into ``ws``, ``j`` and ``b``. The initial state is
        detached, so ``x_init`` receives no gradient.

        Returns:
            Tensor of shape ``(n_lyap,)`` in log-growth per step (empty if
            ``n_lyap <= 0``).

        Raises:
            ValueError: if ``s`` has fewer than ``lyap_steps`` steps, or if no
                steps remain after the transient discard.
        """
        cfg = self.cfg
        if cfg.n_lyap <= 0:
            return torch.empty(0, device=self.device)
        if s.shape[0] < cfg.lyap_steps:
            raise ValueError(
                f"input has {s.shape[0]} steps but lyap_steps={cfg.lyap_steps}"
            )

        # Detached: flossing updates only ws/j/b. Otherwise the backward pass
        # would also write x_init.grad, which belongs to the task optimizer.
        x = self.x_init[:, 0].detach().clone()
        gen = torch.Generator(device=self.device).manual_seed(cfg.seed_ons)
        q0 = torch.randn(cfg.hidden_size, cfg.n_lyap, generator=gen, device=self.device)
        q, _ = torch.linalg.qr(q0)
        accum = torch.zeros(cfg.n_lyap, device=self.device)
        tsim = 0
        transient = 10
        transient_ons = math.ceil(cfg.lyap_steps / 10)

        for n in range(1, cfg.lyap_steps + 1):
            x = self.j @ torch.tanh(x) + self.b + (self.ws @ s[n - 1, :, 0]) / cfg.input_dim
            phi_prime = 1.0 / torch.cosh(x).square()
            # Row-scaling J by phi' gives diag(phi'(x_t)) @ J.
            tangent = phi_prime.unsqueeze(1) * self.j
            q, r = torch.linalg.qr(tangent @ q)
            if n > transient + transient_ons:
                # clamp keeps log finite if a direction collapses exactly.
                accum = accum + r.diagonal().abs().clamp_min(1e-30).log()
                tsim += 1

        if tsim == 0:
            raise ValueError("lyap_steps too short after transient discard")
        return accum / tsim


def evaluate(model: VanillaRNN, cfg: Config, epoch: int) -> dict[str, float]:
    """Mean task loss and accuracy of ``model`` over ``cfg.eval_batches`` batches.

    Batch seeds depend only on ``epoch`` and the batch index, so calls with the
    same ``epoch`` evaluate on the same data. Runs under ``torch.no_grad()``.
    Note that seeds of different epochs overlap if ``eval_batches > 1000``.
    """
    base_seed = 10_000_000 + epoch * 1000
    batch_losses: list[float] = []
    batch_accs: list[float] = []
    with torch.no_grad():
        for batch_idx in range(cfg.eval_batches):
            inputs, labels = generate_input_output(
                cfg.batch_size,
                cfg.train_steps,
                base_seed + batch_idx,
                cfg.input_dim,
                cfg.input_scale,
                cfg.delay,
                cfg.task,
                model.device,
            )
            batch_loss, batch_acc = model.task_loss_and_accuracy(inputs, labels)
            batch_losses.append(float(batch_loss.item()))
            batch_accs.append(batch_acc)
    mean_loss = float(np.mean(batch_losses))
    mean_acc = float(np.mean(batch_accs))
    return {"loss": mean_loss, "accuracy": mean_acc}


def _floss_phase(
    model: VanillaRNN,
    cfg: Config,
    floss_optim: torch.optim.Optimizer,
    lyap_target: torch.Tensor,
    epoch: int,
    kind: str,
    n_steps: int,
    t0: float,
    log: dict,
) -> None:
    """Run ``n_steps`` gradient-flossing updates; log and print sampled steps.

    Each step:
    - draws a fresh single-sequence input (batch size 1);
    - computes the Lyapunov spectrum;
    - takes one ``floss_optim`` step on ``mean((lambda_i - lyap_target)^2)``.

    Step 1, step ``n_steps`` and every ``cfg.log_every_floss``-th step (when
    that is positive) are appended to ``log["floss"]``.
    """
    for floss_step in range(1, n_steps + 1):
        # NOTE(review): this seed can coincide with task-batch seeds of later
        # epochs and with seeds of other floss episodes; kept for reproducibility.
        seed = cfg.train_epochs * cfg.seed_input + epoch + floss_step
        s, _ = generate_input_output(
            1, cfg.floss_input_steps, seed, cfg.input_dim,
            cfg.input_scale, cfg.delay, cfg.task, model.device,
        )
        floss_optim.zero_grad(set_to_none=True)
        spectrum = model.lyapunov_spectrum(s)
        floss_loss = (spectrum - lyap_target).square().mean()
        floss_loss.backward()
        floss_optim.step()

        periodic = cfg.log_every_floss > 0 and floss_step % cfg.log_every_floss == 0
        if floss_step == 1 or floss_step == n_steps or periodic:
            rec = {
                "epoch": epoch,
                "kind": kind,
                "floss_step": floss_step,
                "loss": float(floss_loss.item()),
                "lambda_mean": float(spectrum.detach().mean().item()),
                "lambda_1": float(spectrum.detach()[0].item()),
                "elapsed": time.time() - t0,
            }
            log["floss"].append(rec)
            print(
                f"F {kind} epoch={epoch} step={floss_step}/{n_steps} "
                f"loss={rec['loss']:.6f} lambda1={rec['lambda_1']:+.4f}",
                flush=True,
            )


def run(cfg: Config) -> dict:
    """Train a VanillaRNN on the configured task, with optional gradient flossing.

    Per epoch:
    - Optionally run a flossing episode: ``pre_epochs`` steps at epoch 1, or
      ``inter_epochs`` steps at epochs divisible by ``inter_period`` (at most
      ``max_inter_episodes`` times).
    - Take one AdamW task step with gradient-norm clipping.

    Evaluation runs at the first epoch, the last epoch and every
    ``eval_every``-th epoch. The log is written as JSON to ``cfg.out`` and
    returned.

    ``n_lyap <= 0`` disables flossing. A non-positive ``inter_period``,
    ``eval_every`` or ``log_every_floss`` disables that periodic action rather
    than dividing by zero.

    Raises:
        ValueError: if ``n_lyap > hidden_size``. With flossing enabled, also if
            ``lyap_steps`` leaves no steps after the transient discard, or if
            it exceeds ``floss_input_steps``.
    """
    device = torch.device(cfg.device)
    if cfg.n_lyap > cfg.hidden_size:
        raise ValueError("n_lyap must be <= hidden_size")
    # lyapunov_spectrum returns an empty, graph-less tensor for n_lyap <= 0,
    # which cannot be back-propagated or logged, so that means "no flossing".
    floss_enabled = cfg.n_lyap > 0
    if floss_enabled:
        # Same transient as lyapunov_spectrum: 10 + ceil(lyap_steps / 10) steps.
        if cfg.lyap_steps <= 10 + math.ceil(cfg.lyap_steps / 10):
            raise ValueError("lyap_steps too short for official transient discard")
        # lyapunov_spectrum indexes the first lyap_steps input steps.
        if cfg.lyap_steps > cfg.floss_input_steps:
            raise ValueError("lyap_steps must be <= floss_input_steps")

    model = VanillaRNN(cfg, device)
    task_optim = torch.optim.AdamW(
        model.task_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    floss_optim = torch.optim.AdamW(
        model.floss_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    lyap_target = torch.full((cfg.n_lyap,), cfg.lyap_target, device=device)
    log: dict = {"config": asdict(cfg), "evals": [], "floss": [], "task": []}

    t0 = time.time()
    inter_count = 0
    for epoch in range(1, cfg.train_epochs + 1):
        should_prefloss = floss_enabled and epoch == 1 and cfg.pre_epochs > 0
        should_interfloss = (
            floss_enabled
            and cfg.inter_period > 0
            and epoch % cfg.inter_period == 0
            and inter_count < cfg.max_inter_episodes
            and cfg.inter_epochs > 0
        )
        if should_prefloss or should_interfloss:
            kind = "pre" if should_prefloss else "inter"
            n_steps = cfg.pre_epochs if should_prefloss else cfg.inter_epochs
            if should_interfloss:
                # Counted even when a pre-floss episode pre-empts it at epoch 1.
                inter_count += 1
            _floss_phase(model, cfg, floss_optim, lyap_target, epoch, kind, n_steps, t0, log)

        seed = cfg.train_epochs * cfg.seed_input + epoch
        s, target_batch = generate_input_output(
            cfg.batch_size, cfg.train_steps, seed, cfg.input_dim,
            cfg.input_scale, cfg.delay, cfg.task, device,
        )
        task_optim.zero_grad(set_to_none=True)
        task_loss, task_acc = model.task_loss_and_accuracy(s, target_batch)
        task_loss.backward()
        # Fixed global gradient-norm clip for the task step.
        torch.nn.utils.clip_grad_norm_(model.task_parameters(), 0.03)
        task_optim.step()

        periodic_eval = cfg.eval_every > 0 and epoch % cfg.eval_every == 0
        if epoch == 1 or epoch == cfg.train_epochs or periodic_eval:
            eval_metrics = evaluate(model, cfg, epoch)
            rec = {
                "epoch": epoch,
                "train_loss": float(task_loss.item()),
                "train_accuracy": float(task_acc),
                "eval_loss": eval_metrics["loss"],
                "eval_accuracy": eval_metrics["accuracy"],
                "elapsed": time.time() - t0,
            }
            log["evals"].append(rec)
            print(
                f"T epoch={epoch}/{cfg.train_epochs} loss={rec['train_loss']:.4f} "
                f"train_acc={rec['train_accuracy']:.3f} eval_acc={rec['eval_accuracy']:.3f}",
                flush=True,
            )

    out_path = Path(cfg.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(log, indent=2), encoding="utf-8")
    return log


def parse_args(argv: list[str] | None = None) -> Config:
    """Parse command-line flags into a ``Config``.

    Each ``Config`` field ``foo_bar`` becomes ``--foo-bar``, typed and defaulted
    from the dataclass default. ``bool`` fields get a ``--foo/--no-foo`` pair.

    Args:
        argv: arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, as before; pass a list to use it programmatically.
    """
    parser = argparse.ArgumentParser()
    for field_name, default in asdict(Config()).items():
        flag = "--" + field_name.replace("_", "-")
        if isinstance(default, bool):
            parser.add_argument(
                flag,
                action=argparse.BooleanOptionalAction,
                default=default,
                help="default: %(default)s",
            )
        else:
            parser.add_argument(
                flag, type=type(default), default=default, help="default: %(default)s"
            )
    return Config(**vars(parser.parse_args(argv)))


if __name__ == "__main__":
    # Script entry point: build the config from CLI flags, then train.
    config = parse_args()
    run(config)