summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/quantizer/kmeans.py
blob: a34c381e6fdc917066697ae62edcaeb87c7257cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""K-means clustering for embeddings."""
from typing import Tuple

from einops import repeat
import torch
from torch import Tensor

from text_recognizer.networks.quantizer.utils import norm, sample_vectors


def kmeans(
    samples: Tensor, num_clusters: int, num_iters: int = 10
) -> Tuple[Tensor, Tensor]:
    """Compute k-means clusters."""
    D = samples.shape[-1]

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        dists = samples @ means.t()
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, D).type_as(samples)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples)
        new_means /= bins_min_clamped[..., None]
        new_means = norm(new_means)
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins