Training a SOTA model for Thermostability Prediction
I stumbled upon a paper that went viral last month, which designed proteins able to remain folded and functional even under extreme heat (150°C). It inspired me to hack on a project called Chai-Thermo, a model trained to predict the thermostability of a protein. Given that protein folding models have an immense implicit base knowledge of the physical and chemical properties of proteins, I wondered whether it would be possible to repurpose a protein folding model to solve a different problem – instead of predicting a protein’s structure from its sequence, could it predict the protein’s thermostability instead?
Data
For this task, we use the MegaScale dataset: roughly 260k samples of proteins, their mutations, and the resulting change in thermostability. For example:
Sample 0:
Protein: r10_437_TrROS_Hall.pdb
Mutation: E1Q
ΔΔG: -0.295 kcal/mol
Sequence (47 aa): EPELVFKVRVRTKDGRELEIEVSAEDLEKLLEALPDIEEV...
This sample tells us that the original sequence was mutated from E to Q at position 1, and that the change ΔΔG = -0.295 kcal/mol means this mutation was destabilizing (the MegaScale dataset labels negative ΔΔG as destabilizing).
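For illustration, a record like this can be held in a simple structure. A minimal sketch follows; the field names and the apply_mutation helper are mine, not the official MegaScale schema.

```python
from dataclasses import dataclass

@dataclass
class MegaScaleRecord:
    """One mutation record (field names are illustrative, not the official schema)."""
    protein_id: str   # e.g. "r10_437_TrROS_Hall.pdb"
    mutation: str     # e.g. "E1Q": wild-type residue, 1-indexed position, mutant residue
    ddg: float        # ΔΔG in kcal/mol; negative = destabilizing in this labeling
    sequence: str     # wild-type amino acid sequence

def apply_mutation(seq: str, mutation: str) -> str:
    """Return the mutant sequence implied by a code like 'E1Q'."""
    wt, pos, mut = mutation[0], int(mutation[1:-1]), mutation[-1]
    assert seq[pos - 1] == wt, f"expected {wt} at position {pos}, found {seq[pos - 1]}"
    return seq[:pos - 1] + mut + seq[pos:]
```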

Chai-1
Chai-1 is a protein folding model which takes in a protein sequence and outputs a 3D structure. We want to retain the “brain” of the model, while training it to do something other than 3D structure prediction. Looking into the Chai-1 architecture, we see that Chai-1 has a trunk, which is the model’s “mental map of protein physical/chemical space” and a head, which flattens this embedding into 3D coordinates.
To achieve our goal, we can use the trunk of Chai-1 to generate embeddings, then train a model that predicts thermostability from these embeddings. This is also known as transfer learning.
Chai-1’s trunk exposes two types of embeddings: single and pair. The single embedding is an [L, 384] tensor, the model’s internal residue-by-residue representation of the entire protein. The pair embedding is an [L, L, 256] tensor that captures the pairwise relationship between any two amino acids in the protein. We use both to train the thermostability prediction model.
Architecture
We first train the simplest possible model, an MLP, on these embeddings. To train it, we take the two embeddings that come out of the box from the Chai-1 trunk and add a few more engineered features, including the mean of all pairs, the mean over the closest neighbors, the mean of the top-k strongest interactions, and so on.
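To make those shapes and pooled features concrete, here is a rough sketch with dummy tensors standing in for the real trunk outputs (only the [L, 384] and [L, L, 256] shapes come from above; the exact feature definitions here are illustrative). Pooled vectors like these are what get concatenated into the MLP input shown in the diagram below.

```python
import torch

L = 47                                  # sequence length of the example protein
single = torch.randn(L, 384)            # stand-in for the trunk's single embedding
pair = torch.randn(L, L, 256)           # stand-in for the trunk's pair embedding
pos = 0                                 # 0-indexed mutation site (E1Q -> residue 1)

local_single = single[pos]              # [384] embedding of the mutated residue
global_single = single.mean(dim=0)      # [384] whole-protein average
pair_global = pair.mean(dim=(0, 1))     # [256] mean over all residue pairs

# One possible "top-k strongest interactions" feature for the mutated residue
row = pair[pos]                         # [L, 256] pair embeddings involving the mutation site
strength = row.norm(dim=-1)             # [L] rough magnitude of each interaction
pair_topk = row[strength.topk(8).indices].mean(dim=0)   # [256] mean of the 8 strongest
```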
```
[Input Features]               [MLP Architecture]

local_single    ──┐
global_single   ──┤
pair_global     ──┤
pair_local_seq  ──┼──▶  Concat (1577)
pair_structural ──┤        │
mutation_feat   ──┘      Linear (1577→512) + GELU
                         Linear (512→512) + GELU
                         Linear (512→256) + GELU
                         Linear (256→1)  ──▶  ΔΔG
```
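A minimal PyTorch sketch of this head, following the layer sizes in the diagram (the per-feature dimensions that add up to 1577 are not broken out here, so the input is treated as one concatenated vector):

```python
import torch
import torch.nn as nn

class ThermoMLP(nn.Module):
    """ΔΔG regression head matching the diagram: 1577 -> 512 -> 512 -> 256 -> 1."""
    def __init__(self, in_dim: int = 1577):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, 256), nn.GELU(),
            nn.Linear(256, 1),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: [batch, 1577] concatenation of the pooled Chai-1 features
        return self.net(features).squeeze(-1)   # [batch] predicted ΔΔG

model = ThermoMLP()
ddg_pred = model(torch.randn(8, 1577))          # toy batch of 8 feature vectors
```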
The model, which has ~5M parameters, already does decently out of the box on the test set, achieving a mean Spearman of 0.7218 and an RMSE of 0.7175. That is approximately the same result as the ThermoMPNN model on the same data splits, from a model of comparable size (~5M vs. ThermoMPNN’s ~4M parameters).
Can we do better?
Alternative Architectures
I trained two other models with alternative architectures. First, I implemented an MPNN (message passing neural network), a graph neural network of the kind used by other well-known bio models such as ProteinMPNN. It builds a graph over all the amino acids in the sequence, where the edges between nodes are represented by the pair embeddings from the Chai-1 trunk. Each node sums the messages it receives from its neighbors, and the network is trained end-to-end to predict the final ΔΔG.
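A stripped-down sketch of that idea, not the actual Chai-Thermo implementation (hidden sizes, update rule, and readout are my assumptions): nodes start from the single embedding, edges carry the pair embedding, and a few rounds of summed messages are pooled into a ΔΔG prediction.

```python
import torch
import torch.nn as nn

class TinyMPNN(nn.Module):
    """Toy message-passing network over the Chai-1 embeddings (illustrative only)."""
    def __init__(self, node_dim: int = 384, edge_dim: int = 256, hidden: int = 128, rounds: int = 2):
        super().__init__()
        self.node_in = nn.Linear(node_dim, hidden)
        self.message = nn.Linear(2 * hidden + edge_dim, hidden)
        self.update = nn.GRUCell(hidden, hidden)
        self.readout = nn.Linear(hidden, 1)
        self.rounds = rounds

    def forward(self, single: torch.Tensor, pair: torch.Tensor) -> torch.Tensor:
        # single: [L, node_dim], pair: [L, L, edge_dim] (dense graph over all residue pairs)
        L = single.shape[0]
        h = torch.relu(self.node_in(single))                     # [L, hidden] node states
        for _ in range(self.rounds):
            hi = h.unsqueeze(1).expand(L, L, -1)                 # sender states
            hj = h.unsqueeze(0).expand(L, L, -1)                 # receiver states
            m = torch.relu(self.message(torch.cat([hi, hj, pair], dim=-1)))  # [L, L, hidden]
            agg = m.sum(dim=0)                                   # sum of messages arriving at each node
            h = self.update(agg, h)                              # GRU-style node update
        return self.readout(h.mean(dim=0)).squeeze(-1)           # pool the graph -> ΔΔG

ddg = TinyMPNN()(torch.randn(47, 384), torch.randn(47, 47, 256))   # scalar ΔΔG for one protein
```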
Lastly, I also trained a transformer. While a regular transformer would have to learn from scratch which amino acids are close to each other, we modify this transformer to use the pairwise embeddings from Chai-1 directly. We also add a “gate” so the model isn’t forced to rely on the pairwise embeddings and can still learn its own features.
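A toy single-head sketch of the gating idea (the dimensions and the exact gating mechanism are assumptions, not the model’s actual design): the pair embedding is projected to a per-pair bias on the attention logits, and a learned gate controls how much of that bias is used.

```python
import torch
import torch.nn as nn

class GatedPairAttention(nn.Module):
    """One attention head with a gated bias from Chai-1 pair embeddings (illustrative)."""
    def __init__(self, dim: int = 384, pair_dim: int = 256):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.pair_bias = nn.Linear(pair_dim, 1)    # pair embedding -> scalar logit bias
        self.gate = nn.Parameter(torch.zeros(1))   # learned gate on the bias (sigmoid(0) = 0.5 at init)

    def forward(self, x: torch.Tensor, pair: torch.Tensor) -> torch.Tensor:
        # x: [L, dim] single embeddings, pair: [L, L, pair_dim] pair embeddings
        logits = self.q(x) @ self.k(x).T / x.shape[-1] ** 0.5    # [L, L] content-based attention
        bias = self.pair_bias(pair).squeeze(-1)                  # [L, L] structural bias
        logits = logits + torch.sigmoid(self.gate) * bias        # gate decides how much bias to use
        return torch.softmax(logits, dim=-1) @ self.v(x)         # [L, dim] updated residue features

out = GatedPairAttention()(torch.randn(47, 384), torch.randn(47, 47, 256))   # [47, 384]
```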
Results
| Model | Architecture Type | Mean Spearman (↑) | RMSE (kcal/mol, ↓) | Notes |
|---|---|---|---|---|
| ThermoMPNN | ProteinMPNN + LA | 0.725 | 0.708 | Official SOTA baseline from PNAS 2024. |
| Chai-Thermo MLP | Pair-Aware MLP | 0.7218 | 0.7175 | Fast baseline; near-parity with official SOTA using Chai-1 features. |
| Chai-Thermo Transformer | Gated Attention | 0.7456 | 0.7004 | Uses ranking loss & structural bias gating to outperform the paper. |
| Chai-Thermo MPNN | Graph Neural Network | 0.7683 | 0.6982 | Best overall; explicitly antisymmetric and mutation-centered. |
The MPNN performed the best of the three, with all models having similar parameter counts and using the same data splits. It also beats ThermoMPNN by a clear margin on mean Spearman, along with a slightly lower RMSE.
Conclusion
We can likely replicate this success on other biology prediction tasks – take a big protein folding model, then use transfer learning to repurpose it for a different task. For future work, we can run the same experiment with different base models, including AlphaFold itself or other similar models.