As mentioned last time, I am going to look at different sampling techniques to get the next token on each iteration. I plan to look at temperature (not strictly a sampling technique), top-k and top-p. There are also other decoding techniques like beam search that I will probably not tackle at this time.
Let’s have a look at how each of the three works. I was thinking about including detailed examples, numeric or visual, but have decided to keep things to a small sketch for each. If you have been following along, I think the following brief explanations should be adequate.
Temperature
Temperature reshapes the probability distribution by scaling the logits. A low temperature, less than 1, concentrates the probability mass on the most likely next words. It does so by pushing the higher logits higher and the lower ones lower. A temperature greater than 1 does the opposite: the logits are made less extreme, shrinking toward zero, i.e. it flattens the probability distribution.
Implementation is basically a simple division of the logits by the temperature before generating the probability distribution. As you have likely guessed, our current code effectively uses a temperature of 1.
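As a minimal sketch (using three made-up logits rather than anything from the model), dividing by the temperature before the softmax sharpens or flattens the resulting distribution:

import torch
import torch.nn as nn

logits = torch.tensor([2.0, 1.0, 0.1])   # made-up logits for three tokens
for temp in (0.5, 1.0, 2.0):
    print(temp, nn.functional.softmax(logits / temp, dim=0))
# 0.5 piles most of the mass on the first token; 2.0 flattens things out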
Top-k Sampling
In this case, we select the \(k\) tokens with the highest probabilities, re-normalize the probability distribution over those \(k\) tokens, and then select the next token by sampling from those \(k\) options. Fortunately for us, PyTorch provides a .topk() method on tensors that will do the heavy lifting.
The smaller the value of \(k\), the more the sampling focuses on the most likely or coherent tokens. The larger the value, the greater the potential for creativity. And, of course, for incoherent output. Again, you have likely realized that our current \(k\) value is the size of the vocabulary (14,224).
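Very roughly, and ignoring the plumbing in the real script, the idea looks something like the sketch below (the vocabulary size is the real one, but the random logits and the torch.multinomial call are just for illustration; the script samples with numpy’s rng.choice instead):

import torch
import torch.nn as nn

k = 50
logits = torch.randn(14224)                # stand-in for the model's final-layer output
top_vals, top_idx = torch.topk(logits, k)  # the k largest logits and their vocabulary indices
probs = nn.functional.softmax(top_vals, dim=0)             # re-normalize over just those k tokens
nxt_tk_idx = top_idx[torch.multinomial(probs, 1)].item()   # sample, then map back to a vocab index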
Top-p Sampling (aka Nucleus Sampling)
In this case, we keep only the smallest set of highest-probability tokens whose cumulative probability reaches some specified threshold, and filter out the rest. For example, for a threshold of \(0.75\), we would select the highest probability tokens whose probabilities sum to a value that equals or exceeds the threshold. The remainder would be excluded from the sample. The selected tokens/probabilities are referred to as the nucleus.
This approach provides for more diversity while still ignoring very low probability tokens.
But we are, I believe, going to have to code it ourselves (a small numeric sketch follows). I won’t bother discussing the implementation here.
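As a tiny numeric sketch (made-up, already-sorted token probabilities), a threshold of \(0.75\) would keep the first three tokens here:

import torch

probs = torch.tensor([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])  # made-up, already sorted
cum = torch.cumsum(probs, dim=0)           # 0.40, 0.65, 0.80, 0.90, 0.96, 1.00
keep = int((cum < 0.75).sum()) + 1         # include the token that crosses the threshold
print(probs[:keep], probs[:keep].sum())    # tensor([0.4000, 0.2500, 0.1500]) summing to 0.80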
Greedy Sampling
Not discussed and not used. We have used random sampling from the get-go. But, if we only ever selected the highest probability token, we would have been using greedy sampling. Needless to say, that would likely not generate particularly interesting output.
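In code it would just mean replacing the random draw with an argmax over the logits, something along the lines of:

nxt_tk_idx = torch.argmax(logits).item()   # always take the single most likely token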
What Next
I plan to write, iteratively, a function to generate model output for a given prompt. The function’s parameters will allow the selection of one or more of the three output moderating techniques above. Temperature will always be applied first, followed by top-k, and, if all three are requested, top-p will be applied last.
A number of articles I read suggested temperature and top-p were enough to control the creativity and/or coherence of model outputs. But I am probably going to look at using all three in various combinations.
And I think I may need to get a little creative in order to be able to compare the outputs using different combinations of those techniques.
Refactor Output Generation
For now, I am going to write, iteratively, a function that only scales the logits based on the passed parameters.
Okay, let’s start on that new function, apply_sampling(). To begin with, I will pass in the logits and a temperature value (defaulting to 1). I will process the logits accordingly and return them to the caller.
I am doing it this way because I want to be able to use various combinations on the same prompt and compare the results as I develop the various bits of sampling code. It will probably be rather messy code, but…
Allow Setting Temperature
My approach and code have turned out even messier than I expected. But I am getting a textual visualization of how temperature may be affecting the output, which is something I wanted to investigate as I went along.
Because I plan to use a match statement in various places to control program flow, I am already adding parameters for all three sampling techniques to the function definition.
The code in the final match block will, as written, not work for the top-k and top-p cases. We will also need a list of the modified vocabulary indices. So I am trying to get things set up for the future code additions.
... ...
def apply_sampling(logits, temp=1, topk=0, topp=1):
    # don't want to alter the logits tensor passed as parameter
    nw_logits = logits.clone()
    if temp != 1:
        nw_logits = nw_logits / temp
    if topk > 0:
        ...
    if topp < 1:
        ...
    return nw_logits
... ...
tst_temp = True # run code to use temp, top-k and top-p sampling
... ...
# including the whole block for context
if eval_model:
    do_dev = True
    cp_pth = cfg.sv_dir/f"lstm_{cfg.start_ep}.pt"
    chkpt = utl.ld_chkpt(cp_pth, lstm, optr, rtn_chk=False)
    if False:
        # let's check the corpus and vocab variables
        for obj in [corpus, corpus.vocab]:
            print(f"{obj.__class__.__name__}")
            for i in inspect.getmembers(obj):
                # to remove private and protected functions
                if not i[0].startswith('_'):
                    # To remove other methods that don't start with a underscore
                    if not inspect.ismethod(i[1]):
                        print(f"\t{i[0]} {len(i[1]) if obj==corpus.vocab or i[0]=="vocab" else ""}")
    lstm.eval()
    p_txt = "The prince left".lower().split(' ')
    p_txt = "The prince was".lower().split(' ')
    max_wds = len(p_txt) + (12 if do_dev else 50)
    print(f"initial input: {p_txt}")
    # batch of 1
    hh, hc = lstm.init_hdn(1)
    if not tst_temp:
        while len(p_txt) < max_wds:
            inp = torch.tensor([[corpus.vocab.word2idx[w] for w in p_txt]])
            inps = inp.to(cfg.device)
            outp, (hh, hc) = lstm(inps, (hh, hc))
            logits = outp[0][-1]
            p = nn.functional.softmax(logits, dim=0).detach().cpu().numpy()
            nxt_tk_idx = cfg.rng.choice(len(logits), p=p)
            p_txt.append(corpus.vocab.idx2word[nxt_tk_idx])
        fn_txt = " ".join(p_txt)
        print(f"generated text: {tidy_output(fn_txt)}")
    else:
        use_what = "temperature"
        t_temps = [0.1, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
        t_topk = [25, 50, 100, 500, 1000, 2500, 5000, 7500, 0]
        t_topp = [.15, .25, .5, .75, .9, 1]
        s_vals = []
        match use_what:
            case "temperature":
                s_vals.extend(t_temps)
            case "top-k":
                s_vals.extend(t_topk)
            case "top-p":
                s_vals.extend(t_topp)
        outs = [p_txt[:] for _ in range(len(s_vals))]
        # print(outs, len(outs), len(outs[0]))
        with torch.no_grad():
            while len(outs[0]) < max_wds:
                for ip, s_val in enumerate(s_vals):
                    # print(f"{ip} ({s_val}, {len(outs[0])}): {outs[ip]}")
                    inp = torch.tensor([[corpus.vocab.word2idx[w] for w in outs[ip]]])
                    inps = inp.to(cfg.device)
                    outp, (hh, hc) = lstm(inps, (hh, hc))
                    logits = outp[0][-1]
                    # print(f"logits: {logits.shape}")
                    match use_what:
                        case "temperature":
                            nw_logits = apply_sampling(logits, temp=s_val)
                        case "top-k":
                            nw_logits = apply_sampling(logits, topk=s_val)
                        case "top-p":
                            nw_logits = apply_sampling(logits, topp=s_val)
                        case _:
                            nw_logits = logits
                    if len(outs[0]) == len(p_txt):
                        print(f"nw_logits: {len(nw_logits)}")
                    p = nn.functional.softmax(nw_logits, dim=0).detach().cpu().numpy()
                    nxt_tk_idx = cfg.rng.choice(len(nw_logits), p=p)
                    outs[ip].append(corpus.vocab.idx2word[nxt_tk_idx])
                # print(outs, len(outs), len(outs[0]))
        for io, outp in enumerate(outs):
            o_txt = tidy_output(" ".join(outp))
            print(f"{use_what} {s_vals[io]} -> {o_txt}")
And here’s the output for the two prompts shown above.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
loading runs\rk1_sv\lstm_50.pt
initial input: ['the', 'prince', 'left']
nw_logits: 14224
temperature 0.1 -> the prince left him. the priest was a tall, thinnish man of a
temperature 0.25 -> the prince left the servant, and the coachman were a little, and stooping
temperature 0.5 -> the prince left the servant, and heard the colonel and the old man at
temperature 0.75 -> the prince left: the carriage, were a little back in his chest,
temperature 1.0 -> the prince left. red with whom the old man had finished, looked with
temperature 1.25 -> the prince left one lady smiled. on here, running up to his children
temperature 1.5 -> the prince left away from. “oblonsky had generally seen your dog with his children
temperature 1.75 -> the prince left looking from side of what admitting you you mean. her doors
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
loading runs\rk1_sv\lstm_50.pt
initial input: ['the', 'prince', 'was']
nw_logits: 14224
temperature 0.1 -> the prince was in the same way. he was not a quarter of the
temperature 0.25 -> the prince was in a continual state of discomfort, and he was not merely
temperature 0.5 -> the prince was in moscow. the prince was not able to think of the
temperature 0.75 -> the prince was in lying. he went up to his brother and the porter
temperature 1.0 -> the prince was brought up. but together she was ashamed, and he could
temperature 1.25 -> the prince was more into his high brother, and nikolay felt that still,
temperature 1.5 -> the prince was highly, while he was surprised at something we know he was
temperature 1.75 -> the prince was seen with everything high stricken and friendly his fascinated by children,
Can definitely see the incoherence increasing with the temperature value. Though a lack of sufficient training is clearly evident in all the outputs.
Allow Setting Top-k
Okay, let’s add top-k and do a wee test similar to the above. We need to return the tensor of retained vocabulary indices since torch.topk reduces the logits tensor to match the topk value.
... ...
def apply_sampling(logits, temp=1, topk=0, topp=1):
    # don't want to alter the logits tensor passed as parameter
    nw_logits = logits.clone()
    l_ndx = []
    if temp != 1:
        nw_logits = nw_logits / temp
    if topk > 0:
        # this will reduce the size of the logits tensor
        nw_logits, l_ndx = torch.topk(logits, topk)
    if topp < 1:
        ...
    return nw_logits, l_ndx
... ...
use_what = "top-k"
... ...
l_lens = []
... ...
match use_what:
    case "temperature":
        nw_logits, t_ndx = apply_sampling(logits, temp=s_val)
    case "top-k":
        nw_logits, t_ndx = apply_sampling(logits, topk=s_val)
        l_lens.append(len(nw_logits))
    case "top-p":
        nw_logits, t_ndx = apply_sampling(logits, topp=s_val)
    case _:
        nw_logits, t_ndx = logits, []
... ...
match use_what:
    case "top-k" | "top-p":
        new_ndx = t_ndx[nxt_tk_idx]
    case _:
        new_ndx = nxt_tk_idx
outs[ip].append(corpus.vocab.idx2word[new_ndx])
... ...
if use_what == "temperature":
    print(f"{use_what} {s_vals[io]} -> {o_txt}")
else:
    print(f"{use_what} {s_vals[io]} ({l_lens[io]}) -> {o_txt}")
And the test output.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
initial input: ['the', 'prince', 'was']
nw_logits: 25
top-k 25 (25) -> the prince was so well. alexey alexandrovitch saw what was happening, and he
top-k 50 (50) -> the prince was standing here, and a footman with his lips held, levin
top-k 100 (100) -> the prince was all at the mention of twenty; that alexey alexandrovitch was more
top-k 500 (500) -> the prince was dissatisfied as though the word were. “yes, but what’
top-k 1000 (1000) -> the prince was told by what was the chief next thing. the boy was
top-k 2500 (2500) -> the prince was rapidly having to say. it was still for the days that
top-k 5000 (5000) -> the prince was involved in that the wiles of whose early years the old painter
top-k 7500 (7500) -> the prince was shown that he would working all here. instead of his own
top-k 0 (14224) -> the prince was through this. how was it he was not on the railway
Allow Setting Top-p
Okay, let’s move on to top-p, aka nucleus, sampling. This took me quite some time to sort out.
The code does the following:
- sorts the logits (we will be doing a cumulative sum); fortunately for us, torch.sort returns both the sorted logits and the indices of those values in the original logits tensor
- because we are after the sum of probabilities, we need to generate those probabilities from the sorted logits
- we use torch.cumsum to get the cumulative sums of the probabilities
- we next determine which cumulative sums are greater than the specified threshold
- and adjust that so as not to exclude the first sum greater than the threshold (we want probabilities that sum to at least the threshold, not below); we shift the mask and set that item’s value to 0 (i.e. False) so that it will not be removed
- we index the sorted indices with that boolean mask (True values mark positions to remove), which gives us the vocabulary indices of the tokens to be removed
- we use that last tensor to set the appropriate logits to \(-\infty\), so that when the logits are converted to probabilities the removed indices will have a probability of \(0\)
... ...
def apply_sampling(logits, temp=1, topk=0, topp=1):
    # don't want to alter the logits tensor passed as parameter
    nw_logits = logits.clone()
    l_ndx = []
    if temp != 1:
        nw_logits = nw_logits / temp
    if topk > 0:
        nw_logits, l_ndx = torch.topk(logits, topk)
    if topp < 1:
        sort_logs, sort_ndxs = torch.sort(nw_logits, descending=True)
        cum_probs = torch.cumsum(nn.functional.softmax(sort_logs, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sort_ndxs_remove = cum_probs > topp
        # keep the first token above the threshold
        sort_ndxs_remove[..., 1:] = sort_ndxs_remove[..., :-1].clone()
        sort_ndxs_remove[..., 0] = 0
        sort_ndxs_remove = sort_ndxs[sort_ndxs_remove]
        nw_logits[sort_ndxs_remove] = -float('Inf')
        l_ndx = sort_ndxs
    return nw_logits, l_ndx
... ...
use_what = "top-p"
... ...
case "top-p":
nw_logits, t_ndx = apply_sampling(logits, topp=s_val)
l_lens.append(nw_logits.gt(-float('Inf')).sum())
case _:
nw_logits, t_ndx = logits, []
if len(outs[0]) == len(p_txt):
print(f"nw_logits: {nw_logits.gt(-float('Inf')).sum()}, t_ndx: {len(t_ndx)}")
p = nn.functional.softmax(nw_logits, dim=0).detach().cpu().numpy()
nxt_tk_idx = cfg.rng.choice(len(nw_logits), p=p)
match use_what:
case "top-k":
if s_val == 0:
new_ndx = nxt_tk_idx
else:
new_ndx = t_ndx[nxt_tk_idx]
case "top-p":
new_ndx = nxt_tk_idx
case _:
new_ndx = nxt_tk_idx
And a sample test output.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
initial input: ['the', 'prince', 'was']
nw_logits: 5, t_ndx: 14224
top-p 0.15 (5) -> the prince was in the same state of mind, and so he was conscious
top-p 0.25 (9) -> the prince was standing with a significant smile, and was pulling his head off
top-p 0.5 (40) -> the prince was speaking of the professor, who had gone to his office,
top-p 0.75 (123) -> the prince was crying. “i’m so glad!” said kitty.
top-p 0.9 (289) -> the prince was, when the peasants’” said the colonel. “you were
top-p 1 (14224) -> the prince was explaining to him with something, and incapable of more aware of
Again, one can see that the incoherence increases with increasing top-p threshold values.
Done
I think that’s it for this one. Lots of things to think about. A fair bit of code, with a few interesting methodologies. Way too much terminal output. And, just the right amount of fun for a single blog post.
Next time I will look at combining the various sampling techniques rather than just using one of them.
Until then, may your endeavours bring you much fun.
Afterthought
The night after I figured the draft of this post was done, I woke up in the middle of the night. Got thinking about things, which eventually ended up with me thinking about how my code would work if I used top-k followed by top-p. After all, my top-k code reduced the size of the logits array and returned a new set of indices as well. Well, for all but one case.
So, I have decided to code the top-k block in the apply function differently. In a way that should not affect the top-p code.
We use torch.topk to determine the value of the last logit it would keep, then set all logits with a value less than that to \(-\infty\). We could end up with more values than specified by the topk parameter (ties at that value), but I expect there would not be many more, nor would that be a significant problem.
And we no longer need to return a meaningful index tensor; an empty placeholder keeps the calling code’s two-value unpacking happy. So here’s the revised function and related changes.
def apply_sampling(logits, temp=1, topk=0, topp=1):
    nw_logits = logits.clone()
    l_ndx = []    # no longer populated; returned so the calling code can still unpack two values
    if temp != 1:
        nw_logits = nw_logits / temp
    if topk > 0:
        ndx_remove = nw_logits < torch.topk(nw_logits, topk)[0][..., -1, None]
        nw_logits[ndx_remove] = -float('Inf')  # assign a very low value, so probability will be 0
    if topp < 1:
        sort_logs, sort_ndxs = torch.sort(nw_logits, descending=True)
        cum_probs = torch.cumsum(nn.functional.softmax(sort_logs, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sort_ndxs_remove = cum_probs > topp
        # keep the first token above the threshold
        sort_ndxs_remove[..., 1:] = sort_ndxs_remove[..., :-1].clone()
        sort_ndxs_remove[..., 0] = 0
        sort_ndxs_remove = sort_ndxs[sort_ndxs_remove]
        nw_logits[sort_ndxs_remove] = -float('Inf')
    return nw_logits, l_ndx
... ...
use_what = "top-k"
... ...
match use_what:
    case "temperature":
        nw_logits, t_ndx = apply_sampling(logits, temp=s_val)
    case "top-k":
        nw_logits, t_ndx = apply_sampling(logits, topk=s_val)
        l_lens.append(nw_logits.gt(-float('Inf')).sum())
    case "top-p":
        nw_logits, t_ndx = apply_sampling(logits, topp=s_val)
        l_lens.append(nw_logits.gt(-float('Inf')).sum())
    case _:
        nw_logits, t_ndx = logits, []
if len(outs[0]) == len(p_txt):
    print(f"nw_logits: {nw_logits.gt(-float('Inf')).sum()}, t_ndx: {len(t_ndx)}")
p = nn.functional.softmax(nw_logits, dim=0).detach().cpu().numpy()
nxt_tk_idx = cfg.rng.choice(len(nw_logits), p=p)
And a test, to see if that works more or less as before.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
initial input: ['the', 'prince', 'was']
nw_logits: 25, t_ndx: 0
top-k 25 (25) -> the prince was talking of what was brought to a wife’s quarrel.
top-k 50 (50) -> the prince was with him. “how can she fail to be dining today?
top-k 100 (100) -> the prince was weak, but a new woman. when she had finished the
top-k 500 (500) -> the prince was dead. it was very queer. levin was a good natured
top-k 1000 (1000) -> the prince was broken up. if levin saw too, he never were asleep
top-k 2500 (2500) -> the prince was at home, and now was doubly pleased, and all the
top-k 5000 (5000) -> the prince was with a look of energy, so not to speak of anything
top-k 7500 (7500) -> the prince was asleep. the leaders parted in terror, and was pleased by
top-k 0 (14224) -> the prince was sunday, and the old party, who always could to love
And, that appears to work as desired. We will see, in the next post, if that prevents any issues with the scenario I mentioned at the start of this section.
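For instance, a combined call like the hypothetical one below (combinations are the subject of the next post) should now just work, since the top-k step no longer shrinks the tensor that the top-p code sorts:

# hypothetical combined call: temperature, then top-k, then top-p, all on a full-size tensor
nw_logits, _ = apply_sampling(logits, temp=0.8, topk=50, topp=0.9)
p = nn.functional.softmax(nw_logits, dim=0).detach().cpu().numpy()
nxt_tk_idx = cfg.rng.choice(len(nw_logits), p=p)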
2nd Afterthought
The following night, another concern: was it actually okay to no longer return a meaningful index tensor from the function?
So, I decided to do a bit of a test.
... ...
tst_sampling = True
use_what = "top-k"
... ...
if tst_sampling:
    mx_ndx = torch.argmax(logits)
    st, nd = mx_ndx - 3, mx_ndx + 4
    print(f"logits[{st}:{nd}]: {logits[st:nd]} ({len(logits)})")
    nw_logits, t_ndx = apply_sampling(logits, temp=t_temps[0])
    print(f" tempr({t_temps[0]}): {nw_logits[st:nd]} (ratio: {logits[3] / nw_logits[3]}) ({len(nw_logits)})")
    nw_logits, t_ndx = apply_sampling(logits, topk=t_topk[0])
    print(f" top-k({t_topk[0]}): {nw_logits[st:nd]} -> {nw_logits.gt(-float('Inf')).sum()} ({len(nw_logits)})")
    print(f"logits[{st}:{nd}]: {logits[st:nd]} ({len(logits)})")
    nw_logits, t_ndx = apply_sampling(logits, topp=t_topp[0])
    print(f" top-p({t_topp[0]}): {nw_logits[st:nd]} -> {nw_logits.gt(-float('Inf')).sum()} ({len(nw_logits)})")
    exit(0)
Which, after a bit of messing around, produced the following in the terminal window.
(mclp-3.12) PS F:\learn\mcl_pytorch\proj8> python nlp.py -rn rk1 -bs 32 -se 50
... ...
initial input: ['the', 'prince', 'was']
logits[7:14]: tensor([-6.2827, -5.5919, 2.3910, 2.8982, -0.3905, -3.0204, 0.2236],
device='cuda:1') (14224)
tempr(0.1): tensor([-62.8268, -55.9192, 23.9095, 28.9821, -3.9046, -30.2041, 2.2357],
device='cuda:1') (ratio: 0.10000000149011612) (14224)
top-k(25): tensor([ -inf, -inf, 2.3910, 2.8982, -inf, -inf, -inf],
device='cuda:1') -> 25 (14224)
logits[7:14]: tensor([-6.2827, -5.5919, 2.3910, 2.8982, -0.3905, -3.0204, 0.2236],
device='cuda:1') (14224)
top-p(0.15): tensor([ -inf, -inf, 2.3910, 2.8982, -inf, -inf, -inf],
device='cuda:1') -> 5 (14224)
And, that, to me, looks exactly like what we want to happen. Hopefully I will now get a good night’s rest.
Resources
- docs > python > match statement
- docs > torch > torch.cumsum
- docs > torch > torch.sort
- docs > torch > torch.topk
- docs > torch.Tensor > torch.Tensor.sum
- docs > torch > Slicing, Indexing, and Masking
- How to generate text using language models?
- Setting Top-K, Top-P and Temperature in LLMs
- Trying out Sampling Techniques for Language Models