## KL divergence & cross entropy

### Definitions and computation

- Entropy:
$$
H(p)=-\sum_{x\in \mathcal X}p(x)\log p(x)=\sum_{x\in\mathcal X}p(x)\log\frac1{p(x)}
$$

- Cross entropy:
$$
H(p,q)=-\sum_{x\in\mathcal X}p(x)\log q(x)=\sum_{x\in\mathcal X}p(x)\log\frac1{q(x)}
$$

- KL divergence (relative entropy):
$$
\begin{align}
D_\text{KL}(p\parallel q) &=-\sum_{x\in\mathcal X}p(x)\log\frac{q(x)}{p(x)}\\
&=-\sum_{x\in \mathcal X}\left(p(x)\log q(x)-p(x)\log p(x)\right)\\
&=-\sum_{x\in\mathcal X}p(x)\log q(x)-\left(-\sum_{x\in\mathcal X}p(x)\log p(x)\right)\\
&=H(p,q)-H(p)
\end{align}
$$
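
The three definitions and the identity $D_\text{KL}(p\parallel q)=H(p,q)-H(p)$ can be checked numerically. The following is a minimal sketch in plain Python, assuming natural logarithms (nats) and the convention that terms with $p(x)=0$ contribute zero. The helper names `entropy`, `cross_entropy`, and `kl_divergence` and the two distributions are illustrative and not taken from the notebook.

```python
import math

def entropy(p):
    """H(p) = -sum_x p(x) log p(x); terms with p(x) = 0 contribute 0."""
    return -sum(px * math.log(px) for px in p if px > 0)

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x); assumes q(x) > 0 wherever p(x) > 0."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

def kl_divergence(p, q):
    """D_KL(p || q) = sum_x p(x) log(p(x) / q(x))."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

# Illustrative distributions over three outcomes (an assumption for this sketch).
p = [0.5, 0.25, 0.25]
q = [0.25, 0.25, 0.5]

h_p = entropy(p)
h_pq = cross_entropy(p, q)
kl_pq = kl_divergence(p, q)

# The last two printed values should agree up to floating-point error.
print(h_p, h_pq, kl_pq, h_pq - h_p)
```

Here the KL divergence is computed directly from its definition, so the comparison with `h_pq - h_p` is a genuine check of the derivation rather than a restatement of it.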

### Properties

- $H(p)\geq 0$
- $H(p,q)\geq 0$
    - Not symmetric: in general $H(p,q)\neq H(q,p)$
    - $H(p,p)=H(p)$
- $D_{\text{KL}}(p\parallel q)\geq 0$ (Gibbs' inequality)
    - Equivalently, $H(p,q)\geq H(p)$
    - $D_{\text{KL}}(p\parallel q) = 0$ if and only if $p = q$
    - Not symmetric: in general $D_{\text{KL}}(p\parallel q) \neq D_{\text{KL}}(q\parallel p)$
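
These properties can be spot-checked on a concrete pair of distributions. The sketch below uses plain Python with natural logarithms; the helper names and the distributions `p` and `q` are illustrative assumptions, and a numerical check on one example is of course not a proof.

```python
import math

def entropy(p):
    return -sum(px * math.log(px) for px in p if px > 0)

def cross_entropy(p, q):
    # Assumes q(x) > 0 wherever p(x) > 0.
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

def kl_divergence(p, q):
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

# Illustrative distributions (an assumption for this sketch).
p = [0.7, 0.2, 0.1]
q = [0.1, 0.3, 0.6]

kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)

print("D_KL(p||q) =", kl_pq)                  # non-negative
print("D_KL(q||p) =", kl_qp)                  # non-negative, differs from D_KL(p||q)
print("D_KL(p||p) =", kl_divergence(p, p))    # zero when both arguments coincide
print("H(p,p) - H(p) =", cross_entropy(p, p) - entropy(p))
print("H(p,q) >= H(p):", cross_entropy(p, q) >= entropy(p))
```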
        

## PyTorch API

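- $p$: true (target) distribution, i.e. the label;
    - $p(y|x)$: one-hot, a special case of a probability distribution
- $q$: predicted (input) distribution, i.e. the model's prediction (a distribution);
    - $q(y|x)$: dense probability distribution
- $H(p,q)=-\sum p(x)\log q(x)$ (reduces to the log loss when $p$ is one-hot)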
    - $H(q,p)=-\sum q(x)\log p(x)$
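
To see why a one-hot $p$ reduces the cross entropy to the log loss, note that only the term for the true class survives the sum. The following plain-Python sketch illustrates this; the class index `label` and the predicted distribution `q` are made-up values, and `cross_entropy` is a hypothetical helper.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x); terms with p(x) = 0 contribute 0."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

label = 2                      # true class index (illustrative)
q = [0.1, 0.2, 0.6, 0.1]       # dense predicted distribution (illustrative)
p = [1.0 if i == label else 0.0 for i in range(len(q))]  # one-hot target

ce = cross_entropy(p, q)
log_loss = -math.log(q[label])  # negative log-probability of the true class

print(ce, log_loss)  # the two values coincide
```

The reversed quantity $H(q,p)$ is not useful in this setting: a one-hot $p$ has $p(x)=0$ for every wrong class, so $\log p(x)$ diverges wherever $q(x)>0$.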

### Wikipedia example

In [49]:
import torch
from torch import nn
import torch.nn.functional as F

In [50]:
kl_loss = nn.KLDivLoss(reduction="batchmean")

In [52]:
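# KLDivLoss expects `input` as log-probabilities (here log q, with q uniform)
input = torch.log(torch.FloatTensor([[1/3, 1/3, 1/3]]))
print(input.shape, input)
# ... and `target` as probabilities (here p)
target = torch.FloatTensor([[9/25, 12/25, 4/25]])
print(target.shape)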

torch.Size([1, 3]) tensor([[-1.0986, -1.0986, -1.0986]])
torch.Size([1, 3])


In [53]:
kl_loss(input, target)

tensor(0.0853)
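
The value above can be reproduced by hand. According to the PyTorch documentation, `nn.KLDivLoss` computes the pointwise terms `target * (log(target) - input)`, and `reduction="batchmean"` sums them and divides by the batch size. The sketch below redoes that arithmetic in plain Python under this assumption, and also evaluates the reverse direction to illustrate the asymmetry; the variable names are illustrative.

```python
import math

p = [9/25, 12/25, 4/25]            # target probabilities
q = [1/3, 1/3, 1/3]                # predicted probabilities
log_q = [math.log(v) for v in q]   # what is passed as `input`

batch_size = 1  # the tensors in the cell above have shape [1, 3]

# Pointwise terms target * (log(target) - input), then the "batchmean" reduction.
pointwise = [pi * (math.log(pi) - lqi) for pi, lqi in zip(p, log_q)]
kl_pq = sum(pointwise) / batch_size

# Reverse direction D_KL(q || p), i.e. with the roles of input and target swapped.
kl_qp = sum(qi * (math.log(qi) - math.log(pi)) for pi, qi in zip(p, q))

print(round(kl_pq, 4), round(kl_qp, 4))
```

The first number should match the `tensor(0.0853)` printed by `kl_loss(input, target)`; the second is a different value, since the KL divergence is not symmetric.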

### KL loss vs. CE loss

- KL loss (`nn.KLDivLoss`):
    - `input.shape == target.shape`; `input` holds log-probabilities and `target` holds probabilities

In [54]:
kl_loss = nn.KLDivLoss(reduction="batchmean")
ce_loss = nn.CrossEntropyLoss()

In [55]:
input = torch.randn(3, 5, requires_grad=True)
input

tensor([[ 0.3286,  0.4247,  1.5101, -0.4628, -0.6365],
        [-0.1079, -0.6032, -1.1911, -0.1564,  0.9218],
        [ 1.5526, -0.2029,  0.7564, -2.0878, -1.1310]], requires_grad=True)

In [57]:
target = torch.empty(3, dtype=torch.long).random_(5)
target

tensor([0, 2, 0])

In [58]:
ce_value = ce_loss(input, target)
ce_value

tensor(1.7296, grad_fn=<NllLossBackward0>)

In [60]:
input_log_softmax = F.log_softmax(input, dim=1)
input_log_softmax

tensor([[-1.8236, -1.7275, -0.6421, -2.6150, -2.7887],
        [-1.7406, -2.2359, -2.8238, -1.7891, -0.7109],
        [-0.5414, -2.2969, -1.3376, -4.1818, -3.2250]],
       grad_fn=<LogSoftmaxBackward0>)

In [66]:
# one-hot encoding of target = [0, 2, 0], so that its shape matches the input
target_shaped = torch.FloatTensor([[1, 0, 0, 0, 0],
                                   [0, 0, 1, 0, 0],
                                   [1, 0, 0, 0, 0]])

In [62]:
kl_loss(input_log_softmax, target_shaped)

tensor(1.7296, grad_fn=<DivBackward0>)

In [63]:
# CrossEntropyLoss also accepts class probabilities as the target
ce_loss(input, target_shaped)

tensor(1.7296, grad_fn=<DivBackward1>)

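- The two losses agree here because $p$ is one-hot, so $H(p)=0$ and $D_{\text{KL}}(p\parallel q)=H(p,q)-H(p)=H(p,q)$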
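
The agreement between the two losses can be retraced without PyTorch. The sketch below takes the logits and class indices as printed in the cells above (rounded to four decimals, so results match only to about that precision) and computes both quantities in plain Python. It assumes the mean over the batch for the cross entropy and the sum divided by the batch size for the KL term; `log_softmax` is a hypothetical helper.

```python
import math

# Logits and class indices as printed above (rounded to 4 decimals).
logits = [
    [0.3286, 0.4247, 1.5101, -0.4628, -0.6365],
    [-0.1079, -0.6032, -1.1911, -0.1564, 0.9218],
    [1.5526, -0.2029, 0.7564, -2.0878, -1.1310],
]
labels = [0, 2, 0]

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

log_q = [log_softmax(row) for row in logits]

# Cross entropy with class-index targets: batch mean of -log q(label).
ce = sum(-log_q[i][y] for i, y in enumerate(labels)) / len(labels)

# KL with one-hot targets: sum of p * (log p - log q) divided by the batch size.
# Terms with p = 0 contribute 0; for p = 1 the term reduces to -log q.
one_hot = [[1.0 if j == y else 0.0 for j in range(5)] for y in labels]
kl = sum(
    p * (math.log(p) - lq)
    for p_row, lq_row in zip(one_hot, log_q)
    for p, lq in zip(p_row, lq_row)
    if p > 0
) / len(labels)

print(round(ce, 4), round(kl, 4))
```

Both numbers should come out close to the 1.7296 shown in the notebook outputs.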

### $H(p)\neq 0$

In [69]:
kl_loss = torch.nn.KLDivLoss(reduction='none')
ce_loss = torch.nn.CrossEntropyLoss(reduction='none')

In [71]:
# logits of q(x)
input = torch.Tensor([[-0.1, 0.2, -0.4, 0.3]])
# logits of p(x)
target = torch.Tensor([[-0.7, 0.1, -0.1, 0.1]])

In [73]:
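# kl(p, q): with reduction='none' these are the elementwise terms p(x) * (log p(x) - log q(x))
kl_output = kl_loss(F.log_softmax(input, dim=1), torch.softmax(target, dim=1))
print(kl_output)
print(kl_output.mean())  # average over elements, not the KL divergence
print(kl_output.sum())   # D_KL(p || q)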

tensor([[-0.0635,  0.0116,  0.1097, -0.0190]])
tensor(0.0097)
tensor(0.0389)


In [74]:
# H(p, q)
ce_loss(input, torch.softmax(target, dim=1))

tensor([1.3832])

In [76]:
# H(p) == H(p, p)
ce_loss(target, torch.softmax(target, dim=1))

tensor([1.3443])

In [78]:
# H(p, q) - H(p) == kl(p, q)
1.3832 - 1.3443

0.038899999999999935
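
The subtraction above uses values rounded to four decimals. As a cross-check, the same quantities can be recomputed at full precision in plain Python from the two logit vectors used in this section. This is a sketch with a hypothetical `softmax` helper; it also shows that individual KL terms may be negative even though their sum is not.

```python
import math

def softmax(logits):
    """Numerically stable softmax of a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

q = softmax([-0.1, 0.2, -0.4, 0.3])   # predicted distribution
p = softmax([-0.7, 0.1, -0.1, 0.1])   # target distribution

# Elementwise KL terms p(x) * log(p(x) / q(x)); some of them are negative.
kl_terms = [px * math.log(px / qx) for px, qx in zip(p, q)]
kl = sum(kl_terms)

h_pq = -sum(px * math.log(qx) for px, qx in zip(p, q))  # H(p, q)
h_p = -sum(px * math.log(px) for px in p)               # H(p)

print([round(t, 4) for t in kl_terms])
print(round(kl, 4), round(h_pq - h_p, 4))
```

The elementwise terms should match the `reduction='none'` output, and the two values on the last line should agree, confirming $H(p,q)-H(p)=D_\text{KL}(p\parallel q)$ for a non-one-hot target.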