summaryrefslogtreecommitdiff
path: root/models/sparse_embedding.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/sparse_embedding.py')
-rw-r--r--models/sparse_embedding.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py
new file mode 100644
index 0000000..c701524
--- /dev/null
+++ b/models/sparse_embedding.py
@@ -0,0 +1,132 @@
+from typing import Union
+
+import torch
+from torch import nn
+import torch.distributed as dist
+from torch.optim.optimizer import Optimizer, ParamsT
+
+from models.common import trunc_normal_init_
+
+
+class CastedSparseEmbedding(nn.Module):
+ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Real Weights
+ # Truncated LeCun normal init
+ self.weights = nn.Buffer(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
+ )
+
+ # Local weights and IDs
+ # Local embeddings, with gradient, not persistent
+ self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
+ # Local embedding IDs, not persistent
+ self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if not self.training:
+ # Test mode, no gradient
+ return self.weights[inputs].to(self.cast_to)
+
+ # Training mode, fill puzzle embedding from weights
+ with torch.no_grad():
+ self.local_weights.copy_(self.weights[inputs])
+ self.local_ids.copy_(inputs)
+
+ return self.local_weights.to(self.cast_to)
+
+
+class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
+ def __init__(
+ self,
+ params: ParamsT,
+
+ world_size: int,
+ lr: Union[float, torch.Tensor] = 1e-3,
+ weight_decay: float = 1e-2,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= weight_decay:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+ defaults = dict(
+ lr=lr,
+ weight_decay=weight_decay,
+ world_size=world_size
+ )
+ super().__init__(params, defaults)
+
+ @torch.no_grad
+ def step(self, closure=None): # type: ignore
+ for group in self.param_groups:
+ # Find the sparse embedding weights
+ local_weights_grad = None
+ local_ids = None
+ weights = None
+
+ assert len(group["params"]) == 3
+ for p in group["params"]:
+ if p.requires_grad:
+ local_weights_grad = p.grad
+ elif p.ndim == 1:
+ local_ids = p
+ elif p.ndim == 2:
+ weights = p
+ else:
+ assert False
+
+ assert local_weights_grad is not None
+ assert local_ids is not None
+ assert weights is not None
+
+ # Apply SignSGD
+ # Adam ≈ SignSGD if gradient is very sparse
+ _sparse_emb_signsgd_dist(
+ local_weights_grad,
+ local_ids,
+ weights,
+
+ lr=group["lr"],
+ weight_decay=group["weight_decay"],
+ world_size=group["world_size"]
+ )
+
+
+def _sparse_emb_signsgd_dist(
+ local_weights_grad: torch.Tensor,
+ local_ids: torch.Tensor,
+ weights: torch.Tensor,
+
+ lr: float,
+ weight_decay: float,
+ world_size: int
+) -> None:
+ N, D = local_weights_grad.shape
+
+ # All-gather
+ all_weights_grad = local_weights_grad
+ all_ids = local_ids
+
+ if world_size > 1:
+ all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
+ all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
+
+ dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
+ dist.all_gather_into_tensor(all_ids, local_ids)
+
+ # Unique
+ grad_ids, inv = all_ids.unique(return_inverse=True)
+
+ grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
+ grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
+
+ # SignSGD with decoupled weight decay
+ p = weights[grad_ids]
+
+ p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
+
+ # Write updated slices back
+ weights[grad_ids] = p