`
Guy Blanc¹  Steffen Rendle²

¹Work done during internship at Google, Mountain View, USA. ²Google, Mountain View, USA. Correspondence to: Guy Blanc <guy.blanc@gmail.com>, Steffen Rendle <srendle@google.com>.
`
`Abstract
`Softmax is the most commonly used output func-
`tion for multiclass problems and is widely used in
`areas such as vision, natural language processing,
`and recommendation. A softmax model has lin-
`ear costs in the number of classes which makes it
`too expensive for many real-world problems. A
`common approach to speed up training involves
`sampling only some of the classes at each train-
`ing step. It is known that this method is biased
`and that the bias increases the more the sampling
`distribution deviates from the output distribution.
`Nevertheless, almost all recent work uses simple
`sampling distributions that require a large sample
`size to mitigate the bias. In this work, we propose
`a new class of kernel based sampling methods and
`develop an efficient sampling algorithm. Kernel
`based sampling adapts to the model as it is trained,
`thus resulting in low bias. It can also be easily ap-
`plied to many models because it relies only on the
`model’s last hidden layer. We empirically study
`the trade-off of bias, sampling distribution and
`sample size and show that kernel based sampling
`results in low bias with few samples.
`
`1. Introduction
`Classification problems with a large number of classes are
common in many language tasks (Mikolov et al., 2013; Bengio & Sénécal, 2008) and recommender systems (Covington
`et al., 2016). A standard and effective approach to these
`classification tasks is to use some model, such as a neural
`network, to compute a logit for each class, and assume that
`the class probabilities are a softmax of the logits. Comput-
`ing class probabilities with softmax involves a normalization
`step where a partition function over the logits of all classes
`is computed. For learning the model parameters, an opti-
`mization algorithm, e.g., stochastic gradient descent, needs
`to compute the gradients with respect to the loss. When the
`number of classes, n, is large, computing the probability of
`each class is often too slow, as the time for each training
`step grows linearly with n. Sampled softmax, which creates
`a sample of m < n classes in every update step, is com-
`monly used when the number of classes becomes too large.
It is well known that sampled softmax is biased (Bengio & Sénécal, 2008), i.e., it does not converge to the same loss as
`a full softmax – no matter how many update steps are taken.
`The only way to eliminate the bias is to sample from the
`softmax distribution which is not efficient. For any other
`sampling distribution, there are two directions to mitigate
`the bias: (i) choose a sampling distribution that is closer
`to softmax, or (ii) increase the sample size, m – which is
trivial but costly. Early work (Bengio & Sénécal, 2008) has
`shown that a good sampling distribution should be adaptive
`and should depend on the model’s output.
`While the importance of the sampling distribution is known,
`surprisingly, almost all recent applications use simple sam-
`pling distributions such as uniform or global popularity,
`which require large sample sizes to achieve an acceptable
`bias. One reason for this trend could be that the models
have tended to get more complex, e.g., stacked LSTMs, very deep networks, convolutional NNs, etc., which makes it hard
`to design an efficient sampling distribution that adapts to
`the model.
`In this work, we propose a new class of sampling distribu-
`tions that approximate softmax but are efficient to compute.
`The proposed sampling distributions are defined over the
`model’s output, making them adaptive to the input, the
`model’s structure, and the current model parameters. The
`main idea is to sample proportionally to a non-negative
`kernel. We show that kernels allow us to compute the parti-
`tion function efficiently in the kernel space. This result can
`be used in a divide and conquer algorithm that samples in
`O(D log n) time, where D is the dimension of the kernel
`space. We suggest the quadratic kernel as an approximation
`for (absolute) softmax. See Section 3.3 for details. Kernel
`based sampling is generic and can be applied directly to
`any model where the final layer is a dot product between a
`hidden layer and class embeddings.
`We study the bias of uniform, quadratic kernel and softmax
`sampling empirically and show that the quadratic kernel
needs one to two orders of magnitude fewer samples than
`uniform to reach the same quality as full softmax. A second
`observation is that once the bias is eliminated, more samples
`usually do not increase the convergence speed.
`
2. Modelling Large Multiclass Problems

In this section, we first formalize the multiclass softmax and then recap its sampling version.

Let y ∈ [0, 1]^n with ∑_{i=1}^n y_i = 1 be a distribution over n classes for an input x ∈ X. The goal of supervised learning is to find a function that explains a set of observed pairs (x, y) of input x and label y. Let o : X × Θ → R^n be such a function that maps an input x to a raw score for each class. The model function o is parameterized by model parameters θ ∈ Θ. To shorten notation, we drop the arguments x and θ from o and all derived functions whenever the dependency is clear.
`
`2.1. Full Softmax
`
A softmax model links the model outputs o to a class probability distribution p ∈ [0, 1]^n with ∑_{i=1}^n p_i = 1 by applying an exponential function

$$p_i := \frac{\exp(o_i)}{\sum_{j=1}^{n} \exp(o_j)} \qquad (1)$$
`
`The denominator of pi is also known as the partition func-
`tion and takes at least O(n) time to compute. For softmax,
`the output o is often referred to as the logits. The loss L
`of a parameter setting θ is measured by the cross entropy
`between y and p
`
$$L(y, p) := -\sum_{i=1}^{n} y_i \log p_i = \log \sum_{i=1}^{n} \exp(o_i) - \sum_{i=1}^{n} y_i\, o_i$$
`
`This full softmax loss depends on all classes. Thus, learning
`a full softmax is expensive when the number of classes, n,
`is large.
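For concreteness, here is a minimal sketch of the full softmax loss for a dot-product model (our own illustrative code, not the authors' implementation); the O(n) partition sum is exactly what becomes prohibitive for large n:

```python
import numpy as np

def full_softmax_loss(hidden, class_emb, label):
    """Cross entropy of a full softmax for one example with a one-hot label.

    hidden:    (d,)   query/context embedding h
    class_emb: (n, d) class embedding matrix W
    label:     int    index of the positive class
    """
    logits = class_emb @ hidden                    # o = W h, costs O(n d)
    logits = logits - logits.max()                 # shift for numerical stability
    log_partition = np.log(np.exp(logits).sum())   # O(n) partition function
    return log_partition - logits[label]           # log sum_j exp(o_j) - o_label
```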
`
`2.2. Sampled Softmax
`
`Sampled softmax aims to approximate a full softmax during
model training (Bengio & Sénécal, 2008; 2003). Rather
`than computing the loss over all classes, only the positive
`class and a sample of m negative classes are considered.
`Each negative class is sampled with probability qi with
`replacement. For the rest of the paper, we assume w.l.o.g.
`that there is one positive class per training example, i.e.,
`y ∈ {0, 1}n. The vector s ∈ {1, . . . , n}m+1 represents a
`sample of classes and stores the index of the positive and
`the index of the m sampled negative classes. For instance,
`s = (2, 6, 7, 6, 3), represents a sample of size m = 4 with
`the positive class at index 2 and four negative classes, where
`the class at index 6 was sampled twice, and the classes at
`index 7 and 3 once each.
`
Just as o, y, and p with cardinality n refer to important characteristics of all the classes, o', y', and p' with cardinality m + 1 reflect similar values for a sample s. First, each index i ∈ {1, . . . , m + 1} of the sample s is assigned an adjusted logit o'_i:

$$o'_i := \begin{cases} o_{s_i} - \ln(m\, q_{s_i}) & \text{if } y_{s_i} = 0 \\ o_{s_i} - \ln(1) = o_{s_i} & \text{else} \end{cases} \qquad (2)$$

The adjusted logit corrects the true logit o_{s_i} by the logarithm of the expected number of occurrences of the class s_i in the sample s. This correction ensures that in the limit of m → ∞, sampled softmax is unbiased (Bengio & Sénécal, 2008).
Second, p' is the softmax probability distribution computed over the adjusted logits o', and y' is a projection of the original labels y to the sample s:

$$p'_i := \frac{\exp(o'_i)}{\sum_{j=1}^{m+1} \exp(o'_j)}, \qquad y'_i := y_{s_i} \qquad (3)$$

The loss of a sample s is the cross entropy L(y', p') between the predicted probabilities p' and the labels y'. In contrast to full softmax, the loss of sampled softmax depends only on (at most) m + 1 different classes.
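The following sketch puts eq. (2) and (3) together for one training example (illustrative naming and a NumPy implementation of our own, not the paper's code): m negative classes are drawn from q with replacement, their logits are corrected by ln(m q), and the cross entropy is computed over the m + 1 adjusted logits.

```python
import numpy as np

def sampled_softmax_loss(hidden, class_emb, label, q, m, rng):
    """Sampled softmax cross entropy for one example with a single positive class.

    q: (n,) sampling distribution over all classes (must sum to 1).
    """
    negatives = rng.choice(len(q), size=m, replace=True, p=q)   # sample s_2, ..., s_{m+1}
    s = np.concatenate(([label], negatives))                    # positive class first
    logits = class_emb[s] @ hidden                              # only m + 1 dot products
    # Adjusted logits, eq. (2): subtract ln(m q_{s_i}) for negatives, ln(1) = 0 for the positive.
    correction = np.concatenate(([0.0], np.log(m * q[negatives])))
    adjusted = logits - correction
    adjusted = adjusted - adjusted.max()                        # numerical stability
    log_partition = np.log(np.exp(adjusted).sum())              # partition over m + 1 classes
    return log_partition - adjusted[0]                          # cross entropy with y' = (1, 0, ..., 0)

# Example usage with toy sizes:
# rng = np.random.default_rng(0)
# W, h = np.random.randn(50, 8), np.random.randn(8)
# loss = sampled_softmax_loss(h, W, label=3, q=np.full(50, 1 / 50), m=10, rng=rng)
```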
`
`2.3. Importance of the Sampling Distribution
`
`Sampled softmax can be viewed as an algorithm that gener-
`ates an estimator for the full softmax gradient with respect
`to the logits. The full softmax gradient with respect to a
`logit oi is
`
$$\frac{\partial L(p, y)}{\partial o_i} = p_i - y_i \qquad (4)$$

whereas the sampled softmax gradient with respect to an original logit o_i reads

$$\frac{\partial L(p', y')}{\partial o_i} = \sum_{j=1}^{m+1} \mathbb{I}(s_j = i)\,(p'_j - y'_j) = \sum_{j=1}^{m+1} \mathbb{I}(s_j = i)\, p'_j - y_i \qquad (5)$$
`
`Ideally, we would like to pick a sampling distribution such
`that sampled softmax converges to the same value as full
`softmax. At the very least, we would like to guarantee
`convergence with infinitely small step size and infinitely
`many steps. That is guaranteed if the sampled softmax
`estimator is unbiased:
`
$$\mathbb{E}\!\left[\frac{\partial L(p', y')}{\partial o_i}\right] \overset{?}{=} \frac{\partial L(p, y)}{\partial o_i} \qquad (6)$$

$$\Leftrightarrow \quad \mathbb{E}\!\left[\sum_{j=1}^{m+1} \mathbb{I}(s_j = i)\, p'_j\right] \overset{?}{=} p_i \qquad (7)$$
`
Bengio & Sénécal (2008) have shown that sampling proportional to the softmax probability, q_i = p_i ∝ exp(o_i), yields an unbiased estimator. In fact, q_i = p_i ∝ exp(o_i) is the only unbiased estimator.

Theorem 2.1. The gradient of sampled softmax is an unbiased estimator of the full softmax gradient iff q_i = p_i ∝ exp(o_i).

We include a detailed proof in the appendix.
`
`2.4. Properties of a Good Sampling Distribution
`
The last sections argued that sampled softmax is biased and that the only ways to mitigate the bias are to (1) choose a sampling distribution q_i that is closer to the softmax p_i or (2) increase the sample size, m. The closer the sampling distribution q_i reflects p_i, the smaller the sample size needed for a low bias. Finally, we highlight three properties of the softmax distribution, p_i ∝ exp(o_i(x, θ)), that a good sampling distribution, q, should meet as well.
`
`1. Example dependent: Every input, x, has an individ-
`ual sampling distribution, because the output, o(x),
`depends on the input, x.
`
`2. Model structure dependent: The sampling depends
`on the functional structure of o. For instance, if o
`is an LSTM, the sampling distribution should not be
`represented by simple bigrams.
`
`3. Model parameter dependent: The sampling distribu-
`tion changes while the model is learned, because o
`depends on the model parameters.
`
`Common sampling schemes such as uniform or popularity
`based sampling are neither example nor model dependent.
`In the following section, we introduce a sampling algorithm
`that meets these criteria and is efficient.
`
`3. Kernel Based Sampling
`Sampling directly from qi ∝ exp(oi) requires computing
`the partition function and is as expensive as computing the
`full softmax. The motivation for sampling is to avoid that
`inefficiency, so sampling from qi ∝ exp(oi) is not a good
`option. In this section, we propose efficient sampling distri-
`butions that depend on the example x, the model structure
`o and the model parameter θ as highlighted in Section 2.4.
So far, we have ignored how the logits o are computed. In the following, we assume that o_i is a dot product between a context or query embedding, h ∈ R^d, and a class embedding, w_i ∈ R^d. This type of model is extremely common, with many examples such as deep neural networks and factorization models. For example, h could be the last hidden layer of a deep neural network and W ∈ R^{n×d} the last matrix of weights, such that o = W^T h. The cost of computing the full softmax on a dot product model is O(nd).

3.1. Kernel Based Distributions

We consider sampling distributions that are proportional to some function K : R^d × R^d → R_+. We assume that K is a kernel function for a D-dimensional space, i.e., there exists a mapping φ : R^d → R^D such that K(a, b) = ⟨φ(a), φ(b)⟩. Thus, the sampling distribution can be written as:
`
$$q_i = \frac{K(h, w_i)}{\sum_{j=1}^{n} K(h, w_j)} = \frac{K(h, w_i)}{\Big\langle \phi(h),\ \underbrace{\textstyle\sum_{j=1}^{n} \phi(w_j)}_{=:\, z \,\in\, \mathbb{R}^D} \Big\rangle} \qquad (8)$$
`
`The last step shows the key property that we gain from a
`kernel: the summation over all classes can be isolated from
`the query h – i.e., the partition function becomes a simple
`dot product between a query vector and a summary vector z.
`This summary vector is independent of the query and can
`be precomputed.
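As an illustration of eq. (8) (our own sketch with illustrative names, not the authors' code), the summary vector z is precomputed once, and q_i then costs two D-dimensional dot products; the feature map used here is the quadratic kernel introduced in Section 3.3:

```python
import numpy as np

def phi_quadratic(v, alpha=100.0):
    """Feature map of K(a, b) = alpha * <a, b>^2 + 1, so <phi(a), phi(b)> = K(a, b)."""
    return np.concatenate((np.sqrt(alpha) * np.outer(v, v).ravel(), [1.0]))

n, d = 1000, 16
W = np.random.randn(n, d)                                  # toy class embeddings
features = np.array([phi_quadratic(w) for w in W])         # (n, D) with D = d^2 + 1
z = features.sum(axis=0)                                   # summary vector, independent of h

def q(i, hidden):
    """q_i = K(h, w_i) / <phi(h), z>: the partition function is a single dot product."""
    phi_h = phi_quadratic(hidden)
    return np.dot(phi_h, features[i]) / np.dot(phi_h, z)
```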
`
`3.2. Sampling with Divide and Conquer
`
The kernel gives the ability to compute the probability of one class efficiently. Next, we discuss how this property can be used for efficient sampling from all classes. Instead of sampling a class directly from all the possible classes, we sample a subset of classes recursively until the subset has only one class (see Figure 1(a)). To formalize this algorithm, we introduce C ⊆ {1, . . . , n} as a set of classes and define z(C) := ∑_{j∈C} φ(w_j). Let C' ∪ C'' = C be a partition of C into two disjoint sets C' and C'' = C \ C'. We define the probability, q_{C'|C}, of sampling the set C' from C, as the sum of the probabilities of its elements:

$$q_{C'|C} := \frac{\sum_{j \in C'} K(h, w_j)}{\sum_{l \in C} K(h, w_l)} = \frac{\big\langle \phi(h), \sum_{j \in C'} \phi(w_j) \big\rangle}{\big\langle \phi(h), \sum_{l \in C} \phi(w_l) \big\rangle} = \frac{\langle \phi(h), z(C') \rangle}{\langle \phi(h), z(C) \rangle} \qquad (9)$$

If we know z for C and C', we can sample from this distribution in O(D) time. This scheme can be applied recursively to the sampled subset until the subset contains exactly one class. With n classes and two sets of equal size at each step, this takes log_2 n steps, and in total the time for sampling a class proportional to q is O(D log_2 n).
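A minimal sketch of the divide and conquer sampler follows (our own illustration; the heap-style tree layout and the names are not from the paper, and every node stores a full z vector, i.e., the O(nd) memory optimization of Section 3.2.2 is omitted):

```python
import numpy as np

class DivideAndConquerSampler:
    """Binary tree over classes; node C stores z(C) = sum of phi(w_j) for all classes j in C."""

    def __init__(self, features):                 # features[i] = phi(w_i), shape (n, D)
        self.n, D = features.shape
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        self.features = features
        self.z = np.zeros((2 * self.size, D))     # heap layout: children of node v are 2v and 2v + 1
        self.z[self.size:self.size + self.n] = features
        for v in range(self.size - 1, 0, -1):
            self.z[v] = self.z[2 * v] + self.z[2 * v + 1]

    def sample(self, phi_h, rng):
        """Draw one class i with probability q_i = <phi_h, phi(w_i)> / <phi_h, z(root)>."""
        v = 1
        while v < self.size:
            left = np.dot(phi_h, self.z[2 * v])
            right = np.dot(phi_h, self.z[2 * v + 1])
            v = 2 * v if rng.random() < left / (left + right) else 2 * v + 1   # eq. (9)
        return v - self.size

    def update(self, i, new_feature):
        """After w_i changes, refresh z on the path from leaf i to the root: O(D log n)."""
        delta = new_feature - self.features[i]
        self.features[i] = new_feature
        v = self.size + i
        while v >= 1:
            self.z[v] += delta
            v //= 2

# Example usage with the quadratic kernel (alpha = 1 for brevity):
# phi = lambda v: np.concatenate((np.outer(v, v).ravel(), [1.0]))
# W, rng = np.random.randn(1000, 8), np.random.default_rng(0)
# sampler = DivideAndConquerSampler(np.array([phi(w) for w in W]))
# negatives = [sampler.sample(phi(np.random.randn(8)), rng) for _ in range(5)]
```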
`
Figure 1. Divide and conquer algorithm for sampling from a kernel distribution q. Figure 1(a) shows how to sample subsets, starting from all classes {1, . . . , n}, until a single class i is reached. After the class embedding w_i of class i changes from w_i^old to w_i^new, all statistics z on the sampling path of i are updated by ∆φ(w_i) := φ(w_i^new) − φ(w_i^old) (Figure 1(b)). To minimize storage costs for the statistics z, it is beneficial to use a higher branching factor of O(D/d) for the leaves (Figure 1(c)).

3.2.1. ANALYSIS

Correctness  The correctness of the divide and conquer algorithm, i.e., that it samples proportional to the kernel distribution (eq. 8), is easy to show. Assume the algorithm samples class i and that the intermediate sets were C_1, C_2, . . . , C_{log n - 1}. The probability of sampling class i with the divide and conquer algorithm is equal to q_i:

$$q_{C_1|\{1,\ldots,n\}}\; q_{C_2|C_1} \cdots q_{\{i\}|C_{\log n - 1}} = \frac{\langle\phi(h), z(C_1)\rangle}{\langle\phi(h), z(\{1,\ldots,n\})\rangle} \cdot \frac{\langle\phi(h), z(C_2)\rangle}{\langle\phi(h), z(C_1)\rangle} \cdots \frac{\langle\phi(h), \phi(w_i)\rangle}{\langle\phi(h), z(C_{\log n - 1})\rangle} = \frac{K(h, w_i)}{\langle\phi(h), z\rangle} = q_i$$
`
Runtime  The divide and conquer algorithm assumes that z(C) is known for every set that is involved in sampling. As sampling is independent of the particular choice of the splits, we can choose any arbitrary (binary and balanced) split and keep it fixed. In total, there are n many sets that are arranged in a tree-like structure, and each class appears in exactly log_2 n many sets. This allows us to precompute z(C) for any of the n sets. If we update an embedding w_i during training, we can also update all sets in which i appears in time O(D log n) by updating z(C) for every node along the path from the root to that embedding. Figure 1(b) illustrates the update process.

3.2.2. PRACTICAL CONSIDERATIONS

Less Memory  The structure described so far has O(n) nodes in total, each of which must store O(D) information for z. This means O(nD) space is required to store it. Here we describe how to reduce that to O(nd) space while maintaining fast sampling and updating. Instead of splitting sets until they reach the trivial size 1, we suggest to stop splitting as soon as the size of a set is O(D/d). This leads to the tree having a total of O(nd/D) sets, and requires O(nd) memory. Figure 1(c) sketches the sampling process with a larger branching factor for the leaves.

Increasing the branching factor seems very costly for the final step because the algorithm has to sample from a set of O(D/d) many classes. However, for most kernels, K(a, b) can be computed efficiently in O(d) time, e.g., for kernels of the form K(a, b) = f(⟨a, b⟩). Thus, performing the last step in the original space takes O(d · D/d) = O(D) time even with a naive implementation. The proposed modification decreases the height of the tree from O(log_2 n) to O(log_2 (nd/D)), and adds a final step to sampling with time O(D). The total sampling time is thus O(D(1 + log_2 (nd/D))), which is still O(D log_2 n).

Multiple Partial Samples  Usually, we want to sample several negatives from q. Instead of applying the divide and conquer algorithm m times, a single run could return all the leaf nodes. This would require an additional correction in sampled softmax to accept a weight on each sample. Then, instead of q_i being the probability of sampling a particular class, it is the probability of sampling a class multiplied by the weight given to that class when it is sampled. The drawback of this approach is that the samples are not independent and likely more total samples would be needed. We do not further investigate this approach, but in some applications, faster sampling might justify the cost of requiring a few more samples.
`
3.3. Quadratic Kernel

One obvious choice for a kernel is a quadratic function K(h, w_i) = α⟨h, w_i⟩² + 1. This function is conveniently always positive. Its feature representation is

$$\phi(a) = \left[\sqrt{\alpha}\, \mathrm{vec}(a \otimes a),\ 1\right] \qquad (10)$$

with D = O(d²), allowing for O(d² log n) sampling.
`
It is also a reasonably good approximation of exp near the origin, where many logits tend to be. However, a quadratic function is a poor approximation for negative logits and would oversample classes with negative logits. To align the sampling distribution q better with the prediction distribution p, we suggest a modification of the softmax probability, p, in eq. (1) to an absolute softmax

$$p_i = \frac{\exp(|o_i|)}{\sum_{j=1}^{n} \exp(|o_j|)} \qquad (11)$$
This modified prediction distribution does not negatively impact the expressiveness because softmax is shift invariant, i.e., p_i ∝ exp(o_i) ∝ exp(o_i) exp(c) = exp(o_i + c) for any constant c ∈ R. In particular, any softmax solution has a corresponding absolute softmax solution, obtained by shifting the logits o of the softmax solution by any c large enough to make all the logits nonnegative. We also investigated empirically the quality of softmax and absolute softmax as the prediction distribution when learning without sampling, i.e., with a full softmax, and both performed very similarly¹ on the datasets of Section 4.1.1. Finally, analogous to Section 2.3, for absolute softmax as the prediction distribution, the only unbiased sampling distribution is absolute softmax. This follows directly from Theorem 6.1 in the appendix: the analysis was shown for p_i ∝ exp(o_i) and any output o_i, so it also holds for the modified output |o_i|. Therefore, we suggest using an absolute softmax as the prediction distribution when sampling from a symmetric kernel, such as the quadratic kernel, and a standard softmax in other cases.

Another way to look at absolute softmax is to add an additional layer to o that computes |o| and then to pass the result to a standard softmax.
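A minimal sketch of that view (our own code): an elementwise absolute-value layer placed in front of a standard softmax yields the absolute softmax of eq. (11).

```python
import numpy as np

def absolute_softmax(logits):
    """p_i = exp(|o_i|) / sum_j exp(|o_j|), eq. (11)."""
    a = np.abs(logits)
    a = a - a.max()          # shift for numerical stability
    e = np.exp(a)
    return e / e.sum()
```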
`
`4. Experiments
`In this section, we empirically investigate the trade-off be-
`tween bias, sampling distribution, and number of samples.
`
`4.1. Experimental Setup
`
`4.1.1. DATASETS AND MODELS
`
`We study sampled softmax on a natural language processing
`(NLP) problem and a recommender system dataset.
`
`Penn Tree Bank For the NLP problem, we learn a lan-
`guage model on the Penn Tree Bank dataset (Marcus et al.,
`1999), a dataset with approximately 1 million training words
and a vocabulary of size 10,000. We use the well-studied "medium regularized LSTM" implementation² of Zaremba et al. (2014). We made one minor modification, and changed the units per layer from 650 to 200. Doing so ensures that the expressiveness of the model is small enough that we do not need to worry about early-stopping, and dropout on its own is a sufficient regularizer. We report the perplexity loss as in (Zaremba et al., 2014).

¹ Similar empirical findings were obtained by Brébisson & Vincent (2015) on various tasks.
² https://www.tensorflow.org/tutorials/recurrent
`
`YouTube
`In this recommendation dataset, we predict
`which video a user will watch next based upon various
`user features and the three previously watched videos. We
`train a deep neural network where the user features and
`previous videos are the input and the output is the watch
`probability over all videos. To study the effect on sampling,
we created two versions of the dataset, YouTube10k and YouTube100k, with 10,000 and 100,000 videos (= classes), respectively. The 10k dataset has about 113 million training
`examples, and the 100k dataset about 187 million examples.
`For recommender systems, a common evaluation protocol is
`to rank videos by their scores and then use some ranking met-
`ric (e.g. mean average precision) to measure the quality of
`a model. Here, we only wish to measure how well sampled
`softmax approximates a full softmax. Thus, we measure
`the cross-entropy loss of full softmax. In our YouTube ex-
`periments, the cross-entropy loss was also highly correlated
`with ranking metrics such as mean average precision.
`
`4.1.2. SAMPLING DISTRIBUTIONS
`
`We test the performance of three sampling distributions:
1. Uniform distribution, q_i ∝ 1, where every class is sampled with the same probability. This provides a convenient baseline.
2. Softmax distribution, q_i ∝ exp(o_i), which is the ideal sampling distribution as shown in Theorem 6.1, but is very expensive to sample from.
3. Quadratic distribution, q_i ∝ 100·o_i² + 1, as proposed in Section 3.3 (see the sketch below).
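For concreteness, the three distributions can be written as follows, given the logits o of one example (a sketch with our own naming; in the experiments the quadratic distribution would be sampled with the kernel trick of Section 3 rather than by materializing all n logits):

```python
import numpy as np

def uniform_q(o):
    return np.full(len(o), 1.0 / len(o))

def quadratic_q(o, alpha=100.0):
    w = alpha * o**2 + 1.0            # q_i proportional to 100 * o_i^2 + 1
    return w / w.sum()

def softmax_q(o):
    e = np.exp(o - o.max())
    return e / e.sum()
```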
`
`4.2. Results and Analysis
`
`4.2.1. BIAS OF SAMPLING
`
`First, we study the bias of sampled softmax empirically.
`According to Section 2.3, any sampled softmax is biased
`unless softmax is chosen as the sampling distribution, and
`this bias decreases as the sample size, m, increases. We
`visualize the bias by learning models with different sampling
`strategies until convergence and reporting the final accuracy.
`Very biased models perform poorly even when they are run
`until convergence.
The results are shown in Figure 2. As expected from Theorem 6.1, the quality of softmax sampling, i.e., q ∝ exp(o), is independent of the number of samples m. This verifies that a "good" sampling distribution does not need many samples.
`
Figure 2. Final model quality when training a sampled softmax with different sampling distributions (uniform, quadratic, softmax) and number of samples, m. The quadratic distribution needs one to two orders of magnitude fewer samples than uniform sampling to learn a low-bias model. Penn Tree Bank includes additional results for a unigram and a bigram sampler, which are common sampling distributions in NLP sequence tasks. The results for Penn Tree Bank also include a quartic sampler, which is a 4th-degree polynomial kernel with q_i ∝ o_i^4 + 1.
`
Figure 3. Convergence speed for a varying sample size m ∈ {10, 20, 40, . . .}. Once enough samples are taken to remove the bias, adding more samples does not increase convergence speed considerably. Additional results for YouTube10k and YouTube100k as well as other samplers for Penn Tree Bank show a similar behavior and can be found in Figures 5 and 6.
`
On the other hand, uniform and quadratic sampling are both biased, and their final quality improves with increasing sample size, m. Again, it is important to note that training for more epochs does not solve this issue, because the loss that sampled softmax optimizes is biased when sampling uniformly or according to a quadratic kernel for any fixed size m. On all datasets, quadratic sampling has a much lower bias than uniform sampling and approaches the loss of full softmax with 10s to 100s of samples.
`
`4.2.2. CONVERGENCE SPEED
`
`Second, we study the speed of convergence by measur-
`ing the progress of the loss against the number of training
`epochs. Every update step consists of reading a batch of
`training examples, sampling m negative classes per example
`
`and performing the update with sampled softmax. We plot
`loss against epochs instead of wall runtime to eliminate any
`implementation specific artifacts. Please note that the larger
`the sample size m, the more computationally expensive an
`epoch.
`
`Sample Size First, we study how the sample size, m, influ-
`ences convergence speed. Figure 3 shows the convergence
`for the three sampling strategies. As already discussed, we
`see that the number of samples has a large effect on the
`accuracy of the model for the uniform and quadratic sam-
`pler. Interestingly, once enough samples are taken to remove
`the bias, adding more samples appears to have a small and
`mostly unobservable effect on convergence speed. We previ-
`ously discussed how bias of the sampled softmax estimator
`affects the final optimum achieved. The variance of this
`estimator affects how many steps we need to converge. The
`variance has two sources: (i) The gradient computed on
`a batch is a noisy (but unbiased) estimator of the gradient
`on the entire training set and (ii) the gradient given a set
`of sampled classes is an estimator of the gradient on that
`batch. While taking a larger sample size can reduce the
`variance from source (ii), if the variance from source (i) is
`the dominate source, doing so will not appreciably increase
`convergence speed. For our data sets, we found that once
`we take a reasonable number of samples (only 10s), adding
`more does not noticeably increase convergence speed. This
`is likely because the variance from source (i) dominates that
`from source (ii). For instance on Penn Tree Bank, quadratic
`sampling with m ∈ {160, 320, . . .} samples does not show
`any difference in convergence speed.
`To summarize, the sample size m influences the bias but the
`influence on the convergence speed is small and often not
`noticeable.
`
Figure 4. Convergence speed of different sampling distributions for a fixed sampling size (m = 80 for Penn Tree Bank and YouTube 10k, m = 160 for YouTube 100k). The convergence speed of all distributions is similar; only the bias is different. Figure 7 shows more comparisons.

Sampling Distribution  Finally, we fix the number of samples m and vary the sampling distribution. Figure 4 shows that all three sampling distributions have a comparable speed of convergence; however, uniform converges to a much worse loss due to its high bias. Quadratic and softmax converge similarly, although quadratic has a slightly worse loss throughout the whole training process due to its bias.
`
`5. Related Work
`In this section, we summarize the main approaches for train-
`ing classification models over many classes. All of them
`make some approximation of the full softmax to lower the
`computational complexity.
`
`5.1. Sampled Softmax
`
`Other works on sampled softmax have noted that a good
`sampling distribution can boost performance and attempted
to come up with such distributions. Bengio & Sénécal (2008) propose an adaptive sampler for language models. They argue that the sampling distribution should track the model distribution as closely as possible. They propose to learn a mixture of unigrams, bigrams, trigrams, etc. that is adapted while training. While the work of Bengio & Sénécal (2008) needs a second model to track the trained model, our work uses the trained model directly for sampling. This makes our approach much easier to apply. Secondly, kernel based sampling is more appealing for sophisticated model structures where it is hard to come up with a simple model that can track the trained model well. Labeau & Allauzen (2017) study sampling distributions for noise contrastive estimation (NCE) (Gutmann & Hyvärinen, 2010). Their experiments highlight the issues of simple sampling distributions such as uniform or unigram. Another idea to improve the
`sampling distribution is the Two-Pass Approximate Adap-
`tive Sampling for Softmax (TAPAS). In that work, Bai et al.
`(2017) propose taking one large sample of classes, which
`might be in the order of 100,000 (20% of all classes in their
`case) and computing the logits from that sample. Then, a
`smaller number of classes, e.g., 1,000, is chosen from those
`100,000 classes based on the computed logits. This second
`sample of 1,000 classes is used for the sampled softmax. By
`using a distributed implementation and GPUs, it is possible
`to compute the logits of the larger sample quickly. While
`the TAPAS sampler is adaptive and depends on the current
`model’s output as in our work, it is computationally much
`more expensive. Bakhtiary et al. (2015) also explore se-
`lectively computing logits using hashing to obtain faster
`training steps for large batch sizes.
`
`5.2. Hierarchical Softmax and Its Variations
`
Hierarchical Softmax (HSM) is an approximation of a full softmax introduced in (Goodman, 2001) that is quickly computable. It involves grouping the classes into clusters, where each cluster is a latent variable.³ If c_j is the j-th cluster and class i is in c_j, then we factor p_i as p(y_i|x) = p(c_j|x) p(y_i|c_j). If we set the number of classes in each cluster to be O(√n) and the cluster probabilities can be computed in time O(d), then this version of hierarchical softmax can be done in O(d√n).
`Morin & Bengio (2005) extend this structure to a tree. In-
`stead of having one layer of clusters they use a binary tree
`where each internal node is a cluster and the leaf nodes are
`the classes. The probability of a class is then the product of
`the conditional probability of each node along the path from
`the root to that class. Such a structure allows for O(d log n)
`training time.
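As a sketch of the two-level variant (our own illustration with an arbitrary class-to-cluster assignment, not code from the cited work), the class probability factorizes as p(y_i | x) = p(c_j | x) · p(y_i | c_j), so only the cluster logits and the classes of one cluster need to be scored:

```python
import numpy as np

def two_level_softmax_prob(h, cluster_emb, class_emb, cluster_of, i):
    """p(y_i | x) = p(cluster(i) | x) * p(y_i | cluster(i)).

    cluster_emb: (k, d) embeddings of the k latent clusters (k ~ sqrt(n))
    class_emb:   (n, d) class embeddings
    cluster_of:  (n,)   cluster index of each class
    """
    c = cluster_of[i]
    cl = cluster_emb @ h
    p_cluster = np.exp(cl - cl.max())
    p_cluster /= p_cluster.sum()
    # Naive member lookup shown for brevity; a real implementation keeps per-cluster class lists.
    members = np.flatnonzero(cluster_of == c)
    within = class_emb[members] @ h              # O(sqrt(n) * d) if clusters are balanced
    p_within = np.exp(within - within.max())
    p_within /= p_within.sum()
    return p_cluster[c] * p_within[np.searchsorted(members, i)]
```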
`While hierarchical softmax can be much faster than a full
`softmax, it often performs worse at convergence. For in-
`stance, Chen et al. (2015) found full softmax to achieve a
`perplexity more than 10% better than hierarchical softmax.
`They also note that while hierarchical softmax can speed up
`training, it slows down inference if the goal is to compute
`the class or classes with the highest logits. In particular,
`both a full softmax and sampled softmax can treat inference
`as a maximum inn