BioLLM: Cell Type Annotation

1. Prediction

scGPT:

from biollm.tasks.cell_annotation import CellAnnotation

# Annotation config for scGPT; edit the input/output paths in this TOML file to match your data.
config_file = './config/anno/scgpt.toml'
obj = CellAnnotation(config_file)
# Run the annotation task described by the config (the evaluation step below reads
# predict_list.pk from the config's output directory).
obj.run()

Geneformer:

from biollm.tasks.cell_annotation import CellAnnotation

# Annotation config for Geneformer; edit the input/output paths in this TOML file to match your data.
config_file = './config/anno/geneformer.toml'
obj = CellAnnotation(config_file)
# Run the annotation task described by the config (the evaluation step below reads
# predict_list.pk from the config's output directory).
obj.run()

scFoundation:

from biollm.tasks.cell_annotation import CellAnnotation

# Annotation config for scFoundation; edit the input/output paths in this TOML file to match your data.
config_file = './config/anno/scfoundation.toml'
obj = CellAnnotation(config_file)
# Run the annotation task described by the config (the evaluation step below reads
# predict_list.pk from the config's output directory).
obj.run()

scBERT:

from biollm.tasks.cell_annotation import CellAnnotation

# Annotation config for scBERT; edit the input/output paths in this TOML file to match your data.
config_file = './config/anno/scbert.toml'
obj = CellAnnotation(config_file)
# Run the annotation task described by the config (the evaluation step below reads
# predict_list.pk from the config's output directory).
obj.run()

CellPLM:

from biollm.tasks.cell_annotation import CellAnnotation

# Annotation config for CellPLM; edit the input/output paths in this TOML file to match your data.
config_file = './config/anno/cellplm_ms.toml'
obj = CellAnnotation(config_file)
# Run the annotation task described by the config.
obj.run()

Note: The configuration files live in `biollm/config/anno`. Before running, edit the input and output paths in the corresponding file to point to your own data and output directory.

2. Evaluation

scGPT:

import os
import pickle

import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score

# Must match the output directory set in the scGPT config file.
path = './output/scgpt/'
# os.path.join works whether or not `path` ends with a slash; the `with` block
# closes the file handle after loading.
# NOTE: pickle can execute arbitrary code on load -- only load files you produced yourself.
with open(os.path.join(path, 'predict_list.pk'), 'rb') as fh:
    predict_label = pickle.load(fh)

adata = sc.read_h5ad('./zheng68k.h5ad')
# Ground-truth cell types. Assumes the cells in adata are in the same order as the
# saved predictions (sklearn raises ValueError if the lengths differ).
labels = adata.obs['celltype'].values
acc = accuracy_score(labels, predict_label)
macro_f1 = f1_score(labels, predict_label, average='macro')
res = {'acc': acc, 'macro_f1': macro_f1}
print(acc, macro_f1)

Geneformer:

import os
import pickle

import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score

# Must match the output directory set in the Geneformer config file.
path = './output/geneformer/'
# os.path.join works whether or not `path` ends with a slash; the `with` block
# closes the file handle after loading.
# NOTE: pickle can execute arbitrary code on load -- only load files you produced yourself.
with open(os.path.join(path, 'predict_list.pk'), 'rb') as fh:
    predict_label = pickle.load(fh)

adata = sc.read_h5ad('./zheng68k.h5ad')
# Ground-truth cell types. Assumes the cells in adata are in the same order as the
# saved predictions (sklearn raises ValueError if the lengths differ).
labels = adata.obs['celltype'].values
acc = accuracy_score(labels, predict_label)
macro_f1 = f1_score(labels, predict_label, average='macro')
res = {'acc': acc, 'macro_f1': macro_f1}
print(acc, macro_f1)

scFoundation:

import os
import pickle

import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score

# Must match the output directory set in the scFoundation config file.
path = './output/scfoundation/'
# os.path.join works whether or not `path` ends with a slash; the `with` block
# closes the file handle after loading.
# NOTE: pickle can execute arbitrary code on load -- only load files you produced yourself.
with open(os.path.join(path, 'predict_list.pk'), 'rb') as fh:
    predict_label = pickle.load(fh)

adata = sc.read_h5ad('./zheng68k.h5ad')
# Ground-truth cell types. Assumes the cells in adata are in the same order as the
# saved predictions (sklearn raises ValueError if the lengths differ).
labels = adata.obs['celltype'].values
acc = accuracy_score(labels, predict_label)
macro_f1 = f1_score(labels, predict_label, average='macro')
res = {'acc': acc, 'macro_f1': macro_f1}
print(acc, macro_f1)

scBERT:

import os
import pickle

import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score

# Must match the output directory set in the scBERT config file.
path = './output/scbert/'
# os.path.join works whether or not `path` ends with a slash; the `with` block
# closes the file handle after loading.
# NOTE: pickle can execute arbitrary code on load -- only load files you produced yourself.
with open(os.path.join(path, 'predict_list.pk'), 'rb') as fh:
    predict_label = pickle.load(fh)

adata = sc.read_h5ad('./zheng68k.h5ad')
# Ground-truth cell types. Assumes the cells in adata are in the same order as the
# saved predictions (sklearn raises ValueError if the lengths differ).
labels = adata.obs['celltype'].values
acc = accuracy_score(labels, predict_label)
macro_f1 = f1_score(labels, predict_label, average='macro')
res = {'acc': acc, 'macro_f1': macro_f1}
print(acc, macro_f1)

3. Visualization

import pandas as pd
from typing import Optional
from plottable import ColumnDefinition, Table
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import matplotlib
from plottable.plots import bar


_METRIC_TYPE = "Metric Type"  # index label of the row that tags each column with its metric


def _reversed_cmap(name: str, start: float = 0.0) -> LinearSegmentedColormap:
    """Return matplotlib colormap ``name`` restricted to ``[start, 1]`` and reversed."""
    colors = plt.get_cmap(name)(np.linspace(start, 1, 256))
    return LinearSegmentedColormap.from_list('modified_magma', colors[::-1], N=256)


def _bar_column_definitions(df, cols, cmap) -> list:
    """Build one bar-chart ``ColumnDefinition`` per column in ``cols``, coloured by ``cmap``."""
    return [
        ColumnDefinition(
            col,
            width=1,
            # Presumably strips the ".1"-style suffix pandas adds to repeated CSV headers,
            # so the same model gets the same title under both metric groups.
            title=col.split('.')[0],
            plot_fn=bar,
            plot_kw={
                "cmap": cmap,
                "plot_bg_bar": False,
                "annotate": True,
                "height": 0.9,
                "formatter": "{:.2f}",
            },
            # Columns sharing a metric type are drawn under one group header.
            group=df.loc[_METRIC_TYPE, col],
        )
        for col in cols
    ]


def plot_results_table(df: pd.DataFrame, show: bool = True, save_path: Optional[str] = None) -> Table:
    """Plot the benchmarking results as bar charts for Accuracy and Macro F1 using different colormaps.

    Parameters
    ----------
    df
        Results table indexed by dataset. The row labelled ``"Metric Type"`` tags each
        column as ``"Accuracy"`` or ``"Macro F1"``. Every other row holds scores that
        must be convertible to float (numeric strings are accepted, which is what
        ``read_csv`` produces for columns that also contain the metric label).
    show
        Whether to show the plot.
    save_path
        The path to save the plot to. If `None`, the plot is not saved.

    Returns
    -------
    Table
        The rendered plottable table.
    """
    # Drop the metric-type row: it only labels column groups and is not displayed.
    # Then make the scores numeric. The string labels in that row force pandas to
    # load every column as text, which the bar plots and "{:.2f}" formatter reject.
    plot_df = df.drop(_METRIC_TYPE, axis=0).astype(np.float64)
    num_rows = plot_df.shape[0]

    # Show the dataset names as the first (index) column.
    plot_df["Dataset"] = plot_df.index

    accuracy_cols = df.columns[df.loc[_METRIC_TYPE] == "Accuracy"]
    macro_f1_cols = df.columns[df.loc[_METRIC_TYPE] == "Macro F1"]

    column_definitions = [
        ColumnDefinition("Dataset", width=1.5, textprops={"ha": "left", "weight": "bold"}),
        # Accuracy bars use a truncated, reversed PRGn; Macro F1 bars a reversed YlGnBu.
        *_bar_column_definitions(df, accuracy_cols, _reversed_cmap('PRGn', start=0.25)),
        *_bar_column_definitions(df, macro_f1_cols, _reversed_cmap('YlGnBu')),
    ]

    plt.rcParams['pdf.fonttype'] = 42  # Embed TrueType fonts in PDFs (kept global as before)
    with matplotlib.rc_context({"svg.fonttype": "none"}):
        fig, ax = plt.subplots(figsize=(len(df.columns) * 1.3, 3 + 0.35 * num_rows))
        ax.patch.set_facecolor("white")
        tab = Table(
            plot_df,
            cell_kw={
                "linewidth": 0,
                "edgecolor": "k",
            },
            column_definitions=column_definitions,
            ax=ax,
            row_dividers=True,
            footer_divider=True,
            textprops={"fontsize": 10, "ha": "center"},
            row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 5))},
            col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
            column_border_kw={"linewidth": 1, "linestyle": "-"},
            index_col="Dataset",
        ).autoset_fontcolors(colnames=plot_df.columns)

        # Save inside the rc_context: "svg.fonttype" is read when the file is written,
        # so saving after the context exits would silently ignore it. Save before
        # showing so the figure is written even if the backend closes it on show().
        if save_path is not None:
            fig.savefig(save_path, facecolor=ax.get_facecolor(), dpi=300)

    if show:
        plt.show()

    return tab

# Benchmark table: Accuracy and Macro F1 of the foundation models and three other
# annotation tools (scANVI, celltypist and singleR) on several datasets, one row per dataset.
df = pd.read_csv('./annotation_performance.csv').set_index("dataset")
plot_results_table(df)

Sample Data

img_5.png

Figure

img_3.png