Datagnosis Tutorial 01 - simple tabular example

If you prefer, this tutorial is also available on Google Colab.

In this tutorial we will see how to use “hardness characterization method” (HCM) plugins to calculate hardness scores for the data points in a dataset. We will also plot these scores and extract some data points based on them. For this tutorial we will be using the iris dataset from scikit-learn. For a more realistic dataset, check out tutorials 2 and 3!

OK, let's start!

First we import our logger from datagnosis and set the logging level to “INFO”. If something goes wrong and you want to see more detailed logs, you can change the logging level to “DEBUG”. Conversely, if you don’t want to see any logs, you can remove them with log.remove().

[1]:
import sys
import datagnosis.logger as log
log.add(sink=sys.stderr, level="INFO")

Load the dataset

[2]:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True, as_frame=True)
df = X.copy(deep=True)
df['target'] = y
display(df)
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2

150 rows × 5 columns

Do some pre-processing on the data if you like, such as scaling.

The next key step is to pass the data to the DataHandler object provided by Datagnosis. The features and the labels are passed separately. The features can be a pandas.DataFrame, numpy.ndarray or torch.Tensor. The labels can be a pandas.Series, numpy.ndarray or torch.Tensor.

[3]:

from datagnosis.plugins.core.datahandler import DataHandler
from datagnosis.plugins.core.models.simple_mlp import SimpleMLP
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn

# hold out 25% of the data, then fit the scaler on the training split only
std_scaler = StandardScaler()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
X_train = std_scaler.fit_transform(X_train)
X_test = std_scaler.transform(X_test)

# features and labels are passed to the DataHandler separately
datahander = DataHandler(X_train, y_train, batch_size=32)
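
As a rough illustration of what the scaling step above does, here is a minimal pure-Python sketch for a single feature column. It is not part of the tutorial's own code: the helper names `fit_scaler` and `transform` and the sample values are made up for illustration. The point is that the mean and standard deviation are estimated from the training split only and then reused on the test split.

```python
from statistics import fmean, pstdev


def fit_scaler(column):
    """Estimate mean and population std from the training values only."""
    mean = fmean(column)
    std = pstdev(column)  # StandardScaler also uses the population std (ddof=0)
    return mean, (std if std > 0 else 1.0)  # avoid dividing by zero


def transform(column, mean, std):
    """Standardise values with previously fitted statistics."""
    return [(v - mean) / std for v in column]


# illustrative values, not taken from the actual train/test split
train = [5.1, 4.9, 6.3, 5.8]
test = [5.0, 6.7]

mean, std = fit_scaler(train)
train_scaled = transform(train, mean, std)
test_scaled = transform(test, mean, std)  # reuses the training statistics
```

After this, `train_scaled` has zero mean and unit variance, while `test_scaled` generally does not, which is expected.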

Now we define some values which we will pass to the plugin, such as the model that we want to use to classify the data.

[4]:

# creating our model object, which we both want to use downstream,
# but also we will use to judge the hardness of the data points
model = SimpleMLP()

# creating our optimizer and loss function objects
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
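
For readers unfamiliar with the loss used above, the following is a minimal sketch of softmax cross-entropy for a single sample, written in plain Python. The function name and the example logits are illustrative only. nn.CrossEntropyLoss takes raw logits in the same way and, by default, averages this quantity over the batch.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# illustrative logits for a 3-class problem such as iris
confident = cross_entropy([4.0, 0.5, -1.0], target=0)  # true class has the largest logit
wrong = cross_entropy([4.0, 0.5, -1.0], target=2)      # true class has the smallest logit
uniform = cross_entropy([0.0, 0.0, 0.0], target=1)     # equals log(3)
```

A low loss means the model puts most of its probability on the true label. Several of the hardness scores in this tutorial build on that same per-sample probability.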

Import the Plugins object from Datagnosis. Then, by calling list() on it, we can see all the available plugins that we can use.

[5]:
# datagnosis absolute
from datagnosis.plugins import Plugins

plugins = Plugins().list()
print(plugins)

['large_loss', 'conf_agree', 'confident_learning', 'grand', 'data_maps', 'allsh', 'el2n', 'data_iq', 'forgetting', 'prototypicality', 'vog', 'aum']

Now we can call get() to load up a specific plugin from the list.

[6]:
hcm = Plugins().get(
    "data_iq",
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lr=learning_rate,
    epochs=10,
    num_classes=3,
    logging_interval=1,
)

Next we need to fit() the plugin.

[7]:

hcm.fit(
    datahandler=datahander,
    use_caches_if_exist=True,
)
[2023-08-29T15:24:24.230806+0100][20574][INFO] Fitting data_iq
[2023-08-29T15:24:25.672759+0100][20574][INFO] Epoch 1/10: Loss=0.6950
[2023-08-29T15:24:25.683451+0100][20574][INFO] Epoch 2/10: Loss=0.2694
[2023-08-29T15:24:25.693171+0100][20574][INFO] Epoch 3/10: Loss=0.1799
[2023-08-29T15:24:25.702520+0100][20574][INFO] Epoch 4/10: Loss=0.0955
[2023-08-29T15:24:25.711981+0100][20574][INFO] Epoch 5/10: Loss=0.0779
[2023-08-29T15:24:25.721269+0100][20574][INFO] Epoch 6/10: Loss=0.1023
[2023-08-29T15:24:25.730489+0100][20574][INFO] Epoch 7/10: Loss=0.0663
[2023-08-29T15:24:25.740372+0100][20574][INFO] Epoch 8/10: Loss=0.0481
[2023-08-29T15:24:25.750483+0100][20574][INFO] Epoch 9/10: Loss=0.0544
[2023-08-29T15:24:25.760021+0100][20574][INFO] Epoch 10/10: Loss=0.0513
[7]:
<datagnosis.plugins.generic.plugin_data_iq.DataIQPlugin at 0x7fbb47227510>

Now that the plugin has been fit, we can access the scores. First, let's get a description of the scores, then print them.

[8]:
print(hcm.score_description())
print(hcm.scores)

Compute scores returns two scores for this data_iq plugin. The first is the Aleatoric
Uncertainty and the second is the Confidence. Aleatoric uncertainty permits a principled characterization
and then subsequent stratification of data examples into three distinct subgroups (Easy, Ambiguous, Hard).
Confidence is a measure of the model's confidence in its prediction. High Confidence predictions
define the category `Easy`. Low Confidence scores define `Hard`. High Aleatoric Uncertainty scores define ambiguous.

(array([0.99943835, 0.99170929, 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.9221698 , 1.        , 1.        ,
       1.        , 0.99793088, 0.99968648, 0.99883789, 0.99995375,
       0.46798903, 1.        , 0.99981993, 1.        , 1.        ,
       0.99546427, 1.        , 1.        , 0.99993074, 0.99949455,
       1.        , 0.99999726, 0.99968672, 1.        , 0.99958581,
       0.73497045, 0.99876583, 0.88869566, 0.99999678, 0.99999595,
       1.        , 0.99989605, 0.99786913, 0.99623692, 1.        ,
       0.99995482, 0.99875927, 0.99916625, 0.99989748, 0.99858725,
       1.        , 1.        , 0.94479924, 1.        , 0.31462792,
       0.99977821, 0.9994086 , 1.        , 0.99992156, 0.99995446,
       0.99444205, 0.99857104, 0.99801838, 0.9772346 , 1.        ,
       1.        , 0.99871087, 0.99287075, 0.99999702, 1.        ,
       0.99780184, 0.98650694, 1.        , 0.97743428, 0.93475664,
       1.        , 0.99998736, 1.        , 0.99855858, 1.        ,
       0.99977607, 0.8169474 , 0.9999665 , 0.99627388, 0.9994874 ,
       0.99947435, 0.99857438, 0.99835324, 0.99951637, 0.99879587,
       0.90839595, 1.        , 0.99986982, 1.        , 0.99963391,
       0.99742484, 0.44496712, 0.9999603 , 0.99999988, 0.9986279 ,
       1.        , 0.99992907, 0.99993491, 0.99860638, 0.99991441,
       1.        , 0.99986458, 1.        , 0.95097935, 1.        ,
       1.        , 1.        , 0.99627388, 0.83549249, 0.99996197,
       0.99943608, 0.99930859]), array([5.61339112e-04, 8.22197222e-03, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.17726561e-02,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.06483440e-03,
       3.13422136e-04, 1.16076126e-03, 4.62510650e-05, 2.48975298e-01,
       0.00000000e+00, 1.80033208e-04, 0.00000000e+00, 0.00000000e+00,
       4.51516176e-03, 0.00000000e+00, 0.00000000e+00, 6.92558002e-05,
       5.05191911e-04, 0.00000000e+00, 2.74180614e-06, 3.13183867e-04,
       0.00000000e+00, 4.14021121e-04, 1.94788887e-01, 1.23265059e-03,
       9.89156860e-02, 3.21864046e-06, 4.05309942e-06, 0.00000000e+00,
       1.03939695e-04, 2.12632546e-03, 3.74891887e-03, 0.00000000e+00,
       4.51782795e-05, 1.23919087e-03, 8.33054632e-04, 1.02509479e-04,
       1.41075343e-03, 0.00000000e+00, 0.00000000e+00, 5.21536322e-02,
       0.00000000e+00, 2.15637190e-01, 2.21739693e-04, 5.91047535e-04,
       0.00000000e+00, 7.84335597e-05, 4.55358749e-05, 5.52706346e-03,
       1.42691982e-03, 1.97768922e-03, 2.22471347e-02, 0.00000000e+00,
       0.00000000e+00, 1.28746740e-03, 7.07842572e-03, 2.98022336e-06,
       0.00000000e+00, 2.19332779e-03, 1.33109984e-02, 0.00000000e+00,
       2.20565106e-02, 6.09866669e-02, 0.00000000e+00, 1.26360250e-05,
       0.00000000e+00, 1.43934144e-03, 0.00000000e+00, 2.23884504e-04,
       1.49544345e-01, 3.34966883e-05, 3.71224076e-03, 5.12337186e-04,
       5.25377051e-04, 1.42359149e-03, 1.64404532e-03, 4.83398188e-04,
       1.20268310e-03, 8.32127513e-02, 0.00000000e+00, 1.30159598e-04,
       0.00000000e+00, 3.65957705e-04, 2.56852763e-03, 2.46971382e-01,
       3.96951176e-05, 1.19209275e-07, 1.37021627e-03, 0.00000000e+00,
       7.09244963e-05, 6.50840356e-05, 1.39167403e-03, 8.55849439e-05,
       0.00000000e+00, 1.35403414e-04, 0.00000000e+00, 4.66176241e-02,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.71224076e-03,
       1.37444788e-01, 3.80263173e-05, 5.63601539e-04, 6.90935826e-04]))
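
To make the two scores more concrete, here is a minimal sketch of how they can be computed for a single sample. It assumes the usual Data-IQ definitions: confidence is the mean probability assigned to the true label across training checkpoints, and aleatoric uncertainty is the mean of p(1 - p) over the same checkpoints. This is not Datagnosis's own implementation; the function name and the probabilities are made up.

```python
def data_iq_scores(true_label_probs):
    """Return (confidence, aleatoric_uncertainty) for one sample.

    true_label_probs: probability the model gave to the sample's true
    label at each training checkpoint (for example, once per epoch).
    """
    n = len(true_label_probs)
    confidence = sum(true_label_probs) / n
    aleatoric = sum(p * (1 - p) for p in true_label_probs) / n
    return confidence, aleatoric


# hypothetical per-epoch probabilities for three kinds of sample
easy = data_iq_scores([0.98, 0.99, 1.0])        # high confidence, low uncertainty
ambiguous = data_iq_scores([0.45, 0.5, 0.55])   # uncertainty close to its maximum of 0.25
hard = data_iq_scores([0.02, 0.05, 0.03])       # low confidence, low uncertainty
```

The numbers printed above look consistent with this reading. At index 15 the two arrays hold 0.46798903 and 2.48975298e-01, and 0.468 × (1 - 0.468) ≈ 0.249. That suggests the first printed array holds the confidence values and the second holds the aleatoric uncertainty, but treat this as an inference from the values rather than a documented guarantee.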

The printed scores are difficult to digest, so now we will plot them instead. We can plot 1-dimensional scores in two different ways, with plot_type="dist" or plot_type="scatter". Why not have a look at both types and compare?

[9]:

hcm.plot_scores(axis=1, plot_type="dist")
[2023-08-29T15:24:25.784857+0100][20574][INFO] Plotting data_iq scores
[Figure: distribution plot of the data_iq scores]
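
One way to read such scores is to bucket the samples into the three subgroups named in the score description (Easy, Ambiguous, Hard). The sketch below does this in plain Python. The function name and all three thresholds are hypothetical choices for illustration, not values used by Datagnosis. The first three score pairs are copied from the printed output, read as (confidence, aleatoric uncertainty), and the fourth is made up to show a hard example.

```python
def stratify(confidence, aleatoric, conf_low=0.25, conf_high=0.75, unc_cut=0.2):
    """Label each sample Easy, Ambiguous or Hard (thresholds are hypothetical)."""
    groups = []
    for c, u in zip(confidence, aleatoric):
        if u >= unc_cut:
            groups.append("Ambiguous")  # high aleatoric uncertainty
        elif c >= conf_high:
            groups.append("Easy")       # confidently correct
        elif c <= conf_low:
            groups.append("Hard")       # confidently wrong
        else:
            groups.append("Ambiguous")  # neither clearly easy nor clearly hard
    return groups


confidence = [0.99943835, 0.46798903, 0.73497045, 0.05]         # last value is made up
aleatoric = [5.61339112e-04, 2.48975298e-01, 1.94788887e-01, 0.04]  # last value is made up

labels = stratify(confidence, aleatoric)
print(labels)
```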

Finally, the extract_datapoints method can be used to select data points based on their HCM score. The available extraction methods are "top_n", "threshold" and "index". Give them all a go!

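The following cell uses the "index" method to extract the data points at indices 1, 2, 6 and 10, then summarises them in a pandas.DataFrame. Note that the variable is called hardest_10 even though the selection here is by index rather than by hardness.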

[10]:
import pandas as pd
print(f"Data points that are hard to classify have scores that are: {hcm.hard_direction()}")
hardest_10 = hcm.extract_datapoints(method="index", indices=[1,2,6,10])

display(pd.DataFrame(
    data={
        "indices":hardest_10[0][2],
        f"{X.columns[0]}": hardest_10[0][0].transpose(0,1)[0],
        f"{X.columns[1]}": hardest_10[0][0].transpose(0,1)[1],
        f"{X.columns[2]}": hardest_10[0][0].transpose(0,1)[2],
        f"{X.columns[3]}": hardest_10[0][0].transpose(0,1)[3],
        "labels": hardest_10[0][1],
        "scores": hardest_10[1],
    }
))
Data points that are hard to classify have scores that are: low
indices sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) labels scores
0 1 -0.099845 -1.040395 0.113560 -0.029841 1 0.991709
1 2 1.053005 -0.119255 0.950314 1.127075 2 1.000000
2 6 -0.560985 1.492741 -1.281031 -1.315303 0 1.000000
3 10 0.130725 -1.961535 0.671396 0.355798 2 1.000000
[11]:
print(hardest_10)
((tensor([[-0.0998, -1.0404,  0.1136, -0.0298],
        [ 1.0530, -0.1193,  0.9503,  1.1271],
        [-0.5610,  1.4927, -1.2810, -1.3153],
        [ 0.1307, -1.9615,  0.6714,  0.3558]]), tensor([1, 2, 0, 2]), [1, 2, 6, 10]), array([0.99170929, 1.        , 1.        , 1.        ]))
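
To clarify what the three extraction methods mean conceptually, here is a minimal pure-Python sketch of each selection strategy. These helpers are hypothetical and are not the Datagnosis implementation or its signature. They only illustrate the idea on the first eight scores printed earlier, with "low" as the hard direction, as reported by hard_direction().

```python
def select_top_n(scores, n, hard_direction="low"):
    """Indices of the n hardest points, hardest first."""
    order = sorted(
        range(len(scores)),
        key=lambda i: scores[i],
        reverse=(hard_direction == "high"),
    )
    return order[:n]


def select_threshold(scores, threshold, hard_direction="low"):
    """Indices whose score is on the hard side of the threshold."""
    if hard_direction == "low":
        return [i for i, s in enumerate(scores) if s <= threshold]
    return [i for i, s in enumerate(scores) if s >= threshold]


def select_index(scores, indices):
    """(index, score) pairs for explicitly requested indices."""
    return [(i, scores[i]) for i in indices]


# first eight scores from the first printed array
scores = [0.99943835, 0.99170929, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9221698]

top_2 = select_top_n(scores, n=2)            # [7, 1]
below = select_threshold(scores, 0.995)      # [1, 7]
by_index = select_index(scores, [1, 2, 6])   # matches the scores shown for indices 1, 2 and 6
```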