
Logits to Text

How to generate text from a transformer model

Generating text from a Transformer model’s logits means converting its raw output scores into human-readable tokens. In this post we will cover the simplest way to do this, greedy decoding, and briefly touch on a few other text-generation techniques.

If you are looking for more, Hugging Face has a great blog post on text generation using different techniques.

Okay let’s go. Here’s a step-by-step outline of the process:

  1. Obtain the Logits: These are the raw scores output by the Transformer model before applying any probability function. Typically, the shape of the logits is [batch_size, sequence_length, vocab_size].
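
For example, with a Hugging Face GPT-2 model (one convenient way to get logits; the snippet below assumes the transformers library):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
inputs = tokenizer("Hello, I'm a language model,", return_tensors='pt')
logits = model(**inputs).logits  # [batch_size, sequence_length, vocab_size]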

  2. Apply a Softmax Function: Convert the logits into probabilities. Softmax turns the raw scores at each position into a probability distribution over the vocabulary.

import torch.nn.functional as F
probabilities = F.softmax(logits, dim=-1)
  3. Sample or Argmax: Convert probabilities to actual token indices. There are different strategies to do this:
  • Greedy Decoding (Argmax): Choose the token with the highest probability.
predicted_indices = torch.argmax(probabilities, dim=-1)
  • Sampling: Sample from the probability distribution.
predicted_indices = torch.multinomial(probabilities[0], num_samples=1).squeeze(-1)  # multinomial expects at most 2D, so index into the batch dimension
  4. Convert Token Indices to Text: Use a tokenizer to convert the token indices back to text.
# Assuming you have a tokenizer that can convert indices back to tokens
generated_text = tokenizer.decode(predicted_indices, skip_special_tokens=True)
  5. Here’s a complete example assuming you have the logits from a Transformer model and a tokenizer:
import torch
import torch.nn.functional as F

# Example logits tensor
logits = torch.randn(1, 10, 50257)  # [batch_size, sequence_length, vocab_size]; 50257 is the GPT-2 vocab size

# Apply softmax to get probabilities
probabilities = F.softmax(logits, dim=-1)

# Greedy decoding
predicted_indices = torch.argmax(probabilities, dim=-1)

# Assuming you have a tokenizer that can decode the indices to text
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Convert indices to text
generated_text = tokenizer.decode(predicted_indices[0], skip_special_tokens=True)

print(generated_text)

I previously wrote a post on creating a GPT-2 model from scratch. We can use that model to generate some text.

# GPT2 here is the model class from my from-scratch post
gpt2_model = GPT2.from_pretrained()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = "Hello, I'm a language model,"
encoded_input = tokenizer(text, return_tensors='pt')
print(encoded_input)

with torch.no_grad():
    my_gpt_logits = gpt2_model(encoded_input['input_ids'])
    print(my_gpt_logits)

probabilities = F.softmax(my_gpt_logits, dim=-1)
predicted_indices = torch.argmax(probabilities, dim=-1)
print("Argmax sampling: ", predicted_indices)
generated_text = tokenizer.decode(predicted_indices[0], skip_special_tokens=True)
print("Generated text: \'", generated_text, "\'")
{'input_ids': tensor([[15496,    11,   314,  1101,   257,  3303,  2746,    11]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
Argmax sampling:  tensor([[  11,  314, 1101,  407, 1310, 5887,   13,  290]])
Generated text: ' , I'm not little expert. and '

That looks coherent enough!
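
Note that a single forward pass like this only predicts the next token at each position of the prompt, so the decoded text above is the model’s one-step-ahead prediction of the input, not a continuation. To generate a real continuation you append the predicted token and run the model again. A minimal greedy loop, assuming gpt2_model returns logits of shape [batch_size, sequence_length, vocab_size] as above:

input_ids = encoded_input['input_ids']
for _ in range(20):  # generate 20 new tokens
    with torch.no_grad():
        logits = gpt2_model(input_ids)
    next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # greedy pick at the last position
    input_ids = torch.cat([input_ids, next_id], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))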

Notes on other Strategies:

  • Temperature Scaling: You might want to scale the logits before applying softmax to control the randomness of the predictions. Lower temperature makes the model more confident (less random), while a higher temperature makes it less confident (more random).
    temperature = 1.0  # values < 1 sharpen the distribution (more confident), values > 1 flatten it
    scaled_logits = logits / temperature
    probabilities = F.softmax(scaled_logits, dim=-1)
    
  • Top-k and Top-p Sampling: These sampling methods produce more coherent yet still diverse text by limiting the sampling pool, either to the k most probable tokens (top-k) or to the smallest set of tokens whose cumulative probability exceeds p (top-p, also called nucleus sampling); see the sketch below.
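
    A minimal sketch of combined top-k/top-p sampling for a single next-token distribution (the function name and defaults are illustrative, not from any library):

    import torch
    import torch.nn.functional as F

    def top_k_top_p_sample(logits, k=50, p=0.9, temperature=1.0):
        # logits: [vocab_size] scores for the next token
        logits = logits / temperature
        # Top-k: keep only the k highest-scoring tokens (sorted descending)
        top_logits, top_ids = torch.topk(logits, k)
        probs = F.softmax(top_logits, dim=-1)
        # Top-p: keep the smallest prefix whose cumulative probability
        # exceeds p; the most probable token is always kept
        cumulative = torch.cumsum(probs, dim=-1)
        keep = cumulative - probs < p
        probs = probs * keep
        probs = probs / probs.sum()  # renormalize over the surviving tokens
        choice = torch.multinomial(probs, num_samples=1)
        return top_ids[choice]

    # e.g. next_id = top_k_top_p_sample(my_gpt_logits[0, -1, :])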

  • Beam Search: This decoding method keeps track of multiple hypotheses (beams) during generation and selects the most likely sequence overall; a minimal sketch follows.
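
    A minimal beam-search sketch, assuming a model callable that returns logits of shape [1, seq_len, vocab_size] like gpt2_model above (beam width and length are illustrative):

    import torch
    import torch.nn.functional as F

    def beam_search(model, input_ids, beam_width=3, max_new_tokens=20):
        # Each beam is a (token_ids, cumulative_log_prob) pair
        beams = [(input_ids, 0.0)]
        for _ in range(max_new_tokens):
            candidates = []
            for ids, score in beams:
                with torch.no_grad():
                    logits = model(ids)  # [1, seq_len, vocab_size]
                log_probs = F.log_softmax(logits[0, -1, :], dim=-1)
                top_lp, top_ids = torch.topk(log_probs, beam_width)
                for lp, tid in zip(top_lp, top_ids):
                    new_ids = torch.cat([ids, tid.view(1, 1)], dim=-1)
                    candidates.append((new_ids, score + lp.item()))
            # Keep only the beam_width best-scoring sequences overall
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        return beams[0][0]  # highest-scoring sequence

    # e.g. best_ids = beam_search(gpt2_model, encoded_input['input_ids'])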
