diff options
Diffstat (limited to 'models/sparse_embedding.py')
| -rw-r--r-- | models/sparse_embedding.py | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py new file mode 100644 index 0000000..c701524 --- /dev/null +++ b/models/sparse_embedding.py @@ -0,0 +1,132 @@ +from typing import Union + +import torch +from torch import nn +import torch.distributed as dist +from torch.optim.optimizer import Optimizer, ParamsT + +from models.common import trunc_normal_init_ + + +class CastedSparseEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Real Weights + # Truncated LeCun normal init + self.weights = nn.Buffer( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True + ) + + # Local weights and IDs + # Local embeddings, with gradient, not persistent + self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) + # Local embedding IDs, not persistent + self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if not self.training: + # Test mode, no gradient + return self.weights[inputs].to(self.cast_to) + + # Training mode, fill puzzle embedding from weights + with torch.no_grad(): + self.local_weights.copy_(self.weights[inputs]) + self.local_ids.copy_(inputs) + + return self.local_weights.to(self.cast_to) + + +class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): + def __init__( + self, + params: ParamsT, + + world_size: int, + lr: Union[float, torch.Tensor] = 1e-3, + weight_decay: float = 1e-2, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + world_size=world_size + ) + super().__init__(params, defaults) + + @torch.no_grad + def step(self, closure=None): # type: ignore + for group in self.param_groups: + # Find the sparse embedding weights + local_weights_grad = None + local_ids = None + weights = None + + assert len(group["params"]) == 3 + for p in group["params"]: + if p.requires_grad: + local_weights_grad = p.grad + elif p.ndim == 1: + local_ids = p + elif p.ndim == 2: + weights = p + else: + assert False + + assert local_weights_grad is not None + assert local_ids is not None + assert weights is not None + + # Apply SignSGD + # Adam ≈ SignSGD if gradient is very sparse + _sparse_emb_signsgd_dist( + local_weights_grad, + local_ids, + weights, + + lr=group["lr"], + weight_decay=group["weight_decay"], + world_size=group["world_size"] + ) + + +def _sparse_emb_signsgd_dist( + local_weights_grad: torch.Tensor, + local_ids: torch.Tensor, + weights: torch.Tensor, + + lr: float, + weight_decay: float, + world_size: int +) -> None: + N, D = local_weights_grad.shape + + # All-gather + all_weights_grad = local_weights_grad + all_ids = local_ids + + if world_size > 1: + all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) + all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) + + dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) + dist.all_gather_into_tensor(all_ids, local_ids) + + # Unique + grad_ids, inv = all_ids.unique(return_inverse=True) + + grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) + grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) + + # SignSGD with decoupled weight decay + p = weights[grad_ids] + + p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) + + # Write updated slices back + weights[grad_ids] = p |
