Fine-tuning a Classifier to Improve Truthfulness

Written by Michael Schade

Motivation

It is often relatively easy to get high quality output from instruct models some percentage of the time, which is enough to impress when looking at cherry-picked examples, but it may not be reliable enough to deploy to production.

Examples

This often happens in a class of problems we will call generative transformation of the input, where the task is to transform information presented in one format into another format. These tasks include:

  • Creating an engaging product description based on a structured input, such as product name, color, size, category

  • Summarizing a number of customer reviews into a descriptive, neutral-tone summary, an advert or a catchy tagline

  • Rewriting a piece of content into a particular brand style and format, while focusing on the topic of interest (e.g. based on a press release about a new smartphone, write an advert aimed at amateur photographers).

  • Answering a question based on the context provided.

For the purposes of this guide, we will focus on generating an ad based on the context provided. We would like the ad to be truthful, supported by the context, and engaging.

Idea

Better prompts can improve performance on this task to give good results about 50% to 65% of the time, but often not beyond. However, when looking at ~5 different generations produced with a slightly higher temperature, a human expert can usually find one generation of high, publishable quality.

We will fine-tune a classifier to perform the same discriminative function as a human expert: to effectively select the best out of a number of generated samples. What counts as best may vary from domain to domain, but truthfulness is usually the main constraint on being able to "productionize" a prototype. For example, a slightly less engaging advertisement every now and then is tolerable, but an untruthful ad, not supported by the inputs, is not.

Additionally, we could create other discriminators focused on validating how engaging an advertisement is to readers.

Approach

The approach can be broken down into the following steps:

  1. Create a prompt for generating plausible completions, some of which will be high quality. Alternatively, fine-tune a model on the desired generative task. We will call this model the generator.

  2. Fine-tune an ada binary classifier to rate each completion for truthfulness based on a few hundred to a thousand expert-labelled examples, predicting ` yes` or ` no`. Alternatively, use a generic pre-built truthfulness and entailment model we trained. We will call this model the discriminator.

  3. Generate a number of different completions (10-50), which is most easily achieved by increasing the temperature of the generator (see the sketch after this list).

  4. Rank each of those completions by the predicted log probability of the ` yes` label from the discriminator in step 2, and pick only a completion whose confidence for truthfulness is high enough for your application. If no completion achieves that threshold, you could try generating more samples with higher temperature, or you could return a special output saying that none of the generated samples were truthful enough.
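As a minimal sketch of step 3, assuming the legacy (pre-1.0) openai Python library; build_ad_prompt is a hypothetical helper, and davinci-instruct-beta stands in for whatever generator you use:

import openai  # legacy (pre-1.0) client; requires openai.api_key to be set

def build_ad_prompt(context):
    # hypothetical prompt format; adapt to your own generator
    return f"Context: {context}\nWrite an engaging, truthful ad.\nAd:"

def generate_candidates(context, n=10, temperature=0.8):
    response = openai.Completion.create(
        model="davinci-instruct-beta",  # or your fine-tuned generator
        prompt=build_ad_prompt(context),
        max_tokens=120,
        temperature=temperature,  # slightly higher temperature for diversity
        n=n,                      # request several completions in one call
    )
    return [choice["text"].strip() for choice in response["choices"]]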

How to fine-tune a discriminator?

Please read the case study "Is the model making untrue statements?" for more background on this topic.

Format the input

{"prompt": "Context:<elaborate dry context>\nAd:<generated ad>\nSupported:", "completion": " yes"}

{"prompt": "Context:<elaborate dry context>\nAd:<generated ad>\nSupported:", "completion": " no"}

{"prompt": "Context:<elaborate dry context>\nAd:<generated ad>\nSupported:", "completion": " yes"}
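A short sketch of producing this file in Python; the labelled_examples data below is a hypothetical placeholder for your expert-labelled (context, ad, supported) records:

import json

labelled_examples = [
    ("<elaborate dry context>", "<generated ad>", True),  # placeholder rows
]

def to_training_line(context, ad, supported):
    # keep the leading space in the completion: " yes" and " no" are
    # single tokens, which makes the classifier's output easy to score
    return json.dumps({
        "prompt": f"Context:{context}\nAd:{ad}\nSupported:",
        "completion": " yes" if supported else " no",
    })

with open("train.jsonl", "w") as f:
    for context, ad, supported in labelled_examples:
        f.write(to_training_line(context, ad, supported) + "\n")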

Set parameters

We recommend using ada, since it is the fastest engine and is capable of making good predictions in a classification task after fine-tuning. To get better performance on classification with fine-tuning, compared to few-shot learning, we normally need at least 100 examples per class. Performance tends to increase roughly linearly with each doubling of the number of examples.
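As a sketch, with the legacy (pre-1.0) openai Python library, uploading the training file and starting an ada fine-tune looks roughly like this:

import openai  # legacy (pre-1.0) client

# upload the JSONL training file prepared above
train_file = openai.File.create(file=open("train.jsonl", "rb"), purpose="fine-tune")

# start a fine-tune on the ada base model
job = openai.FineTune.create(training_file=train_file["id"], model="ada")
print(job["id"])  # poll openai.FineTune.retrieve(job["id"]) until it finishes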

The higher the accuracy of the discriminator, the easier it will be to find a sample for which the model is confident enough.

How to use a discriminator to get confidence?

The log probability of the first generated completion token can be used to determine confidence. To get the log probability, you can add the logprobs=2 and logit_bias={"645": 100, "3763": 100} arguments to the completion request, where 645 and 3763 are the token IDs of ` no` and ` yes` respectively. See more details in the last section of the classification guide. The higher the log probability for the ` yes` token, the more confident the prediction is that the output is supported.
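A minimal sketch of that request, again with the legacy (pre-1.0) openai Python library; the model name below is a placeholder for your fine-tuned ada discriminator:

import openai  # legacy (pre-1.0) client

def yes_logprob(context, ad, model="ada:ft-your-org-2022-01-01-00-00-00"):  # placeholder name
    response = openai.Completion.create(
        model=model,
        prompt=f"Context:{context}\nAd:{ad}\nSupported:",
        max_tokens=1,                          # only the label token is needed
        temperature=0,
        logprobs=2,                            # return the top-2 token log probabilities
        logit_bias={"645": 100, "3763": 100},  # restrict output to " no" / " yes"
    )
    top = response["choices"][0]["logprobs"]["top_logprobs"][0]
    return top.get(" yes", float("-inf"))      # higher = more confident the ad is supported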

How to determine a log probability threshold?

To determine a threshold above which the ad is likely to be supported more than 98% of the time, we can:

  1. Use the discriminator to predict the probability of ` yes` on a held-out dataset

  2. Convert the log probability measure into percentiles. This can be achieved by sorting the predictions and assigning each log probability a percentile rank, which represents the percentage of predictions that have a lower log probability.

  3. For each percentile, compute the precision, which is the share of actually truthful ads found above that threshold.

  4. Then find the percentile at which the precision is just above 98%. The log probability threshold needed to obtain a precision of at least 98% is then the log probability at this percentile on the held-out dataset (see the sketch after this list).
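A sketch of these steps with NumPy, assuming you have collected the ` yes` log probabilities (e.g. with the yes_logprob helper above) and the true labels for a held-out set:

import numpy as np

def calibrate_threshold(logprobs, labels, target_precision=0.98):
    # sort predictions from most to least confident
    order = np.argsort(logprobs)[::-1]
    sorted_labels = np.asarray(labels, dtype=float)[order]
    # precision among the top-k most confident predictions, for every k
    precision_at_k = np.cumsum(sorted_labels) / np.arange(1, len(sorted_labels) + 1)
    # largest k whose precision still meets the target
    valid = np.where(precision_at_k >= target_precision)[0]
    if len(valid) == 0:
        return None  # the target precision is not reachable on this dataset
    return np.asarray(logprobs)[order][valid[-1]]

The returned value is the log probability threshold to apply at inference time.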

Here is a graph of precision and truthfulness percentiles on the held-out dataset. On this dataset, 98% precision is achieved at the 0.58 percentile, which corresponds to a log probability of -0.000685.

As can be seen in the graph, if we accept any sample with a log probability for ` yes` above this threshold, we would expect such samples to be supported ~98% of the time. This means that if we were to use a single sample, we would achieve sufficient precision only 56% of the time.

Increasing truthfulness by generating more samples

By generating several samples and then picking the one with the highest log probability, we can increase the probability that the selected ad is indeed truthful. By generating 3 samples we can reach the sufficient truthfulness threshold 98.3% of the time. This rises to 99.6% with 10 samples, and 99.8% with 17 samples. We observe diminishing returns as the number of samples increases significantly.
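Putting the pieces together, a sketch of the best-of-n selection loop, assuming the generate_candidates and yes_logprob helpers and the calibrated threshold from the earlier sketches:

def pick_truthful_ad(context, threshold, n=10, temperature=0.8):
    candidates = generate_candidates(context, n=n, temperature=temperature)
    # score every candidate and keep the most confidently supported one
    scored = [(yes_logprob(context, ad), ad) for ad in candidates]
    best_logprob, best_ad = max(scored)
    if best_logprob >= threshold:
        return best_ad
    return None  # nothing confident enough; retry with more samples or higher temperature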

How many samples do you need for training a good discriminator?

Area under the precision-recall curve (AUPRC) is commonly used in machine learning to evaluate discriminator performance. We compare the performance of the fine-tuned discriminator, as we increase the number of training examples, to the zero-shot davinci-instruct-beta baseline, which achieved 0.8 AUPRC.

As we can see, the fine-tuning performance on the task increases with the number of training examples. Usually around a hundred examples per class are needed to achieve better performance than with prompt design, which can also be observed on this dataset. Performance then keeps increasing roughly linearly as we double the size of the fine-tuning training set.

We also wanted to test how well a model trained on general inference tasks performs on this problem. It seems to perform roughly as well as a model fine-tuned on 20 examples, or a well-designed zero-shot davinci-instruct-beta prompt.

Possible extensions to multiple discriminators

We could train multiple discriminators, and combine their outputs in interesting ways.

For example, we could train a discriminator to predict how engaging the generated ads are, and then select the most engaging ad out of the ads which are truthful enough, as sketched below.
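A sketch of that combination, assuming a second, hypothetical engagement discriminator exposed through an engaging_logprob function analogous to yes_logprob:

def pick_engaging_truthful_ad(context, truth_threshold, n=10):
    candidates = generate_candidates(context, n=n)
    # keep only the candidates the truthfulness discriminator is confident about
    truthful = [ad for ad in candidates if yes_logprob(context, ad) >= truth_threshold]
    if not truthful:
        return None
    # among those, pick the ad the (hypothetical) engagement discriminator rates highest
    return max(truthful, key=lambda ad: engaging_logprob(context, ad))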

Conclusion

We were able to increase the percentage of publishable content to over 99.5% by sampling 10 generations using the technique described.
