Skip to content

Easy Data Wrapper Tutorial

The data construction covered in the Data Management tutorial might be too complicated for users without prior experience in PyTorch. This tutorial offers a helper class to wrap the dataset, all the user needs to know is

(1) loading data-frames to Python, Pandas provides one-line solution to loading various types of data files including CSV, TSV, Stata, and Excel.

(2) basic usage of pandas.

We aim to make this tutorial as self-contained as possible, so you don't need to be worried if you haven't went through the Data Management tutorial. But we invite you to go through that tutorial to obtain a more in-depth understanding of data management in this project.

Author: Tianyu Du

Date: May. 20, 2022

Update: Jul. 9, 2022

__author__ = 'Tianyu Du'

Let's import a few necessary packages.

import pandas as pd
import torch
from torch_choice.utils.easy_data_wrapper import EasyDatasetWrapper

References and Background for Stata Users

This tutorial aim to show how to manage choice datasets using the torch-choice package, we will follow the Stata documentation here to offer a seamless experience for the user to transfer prior knowledge in other packages to our package.

From Stata Documentation: Choice models (CM) are models for data with outcomes that are choices. The choices are selected by a decision maker, such as a person or a business (i.e., the user), from a set of possible alternatives (i.e., the items). For instance, we could model choices made by consumers who select a breakfast cereal from several different brands. Or we could model choices made by businesses who chose whether to buy TV, radio, Internet, or newspaper advertising.

Models for choice data come in two varieties—models for discrete choices and models for rank-ordered alternatives. When each individual selects a single alternative, say, he or she purchases one box of cereal, the data are discrete choice data. When each individual ranks the choices, say, he or she orders cereals from most favorite to least favorite, the data are rank-ordered data. Stata has commands for fitting both discrete choice models and rank-ordered models.

Our torch-choice package handles the discrete choice models in the Stata document above.

Motivations

In the following parts, we demonstrate how to convert a long-format data (e.g., the one used in Stata) to the ChoiceDataset data format expected by our package.

But first, Why do we want another ChoiceDataset object instead of just one long-format data-frame? In earlier versions of Stata, we can only have one single data-frame loaded in memory, this would introduce memory error especially when teh dataset is large. For example, you have a dataset of a million decisions recorded, each consists of four items, and each item has a persistent built quality that stay the same in all observations. The Stata format would make a million copy of these variables, which is very inefficient.

We would need to collect a couple of data-frames as the essential pieces to build our ChoiceDataset. Don't worry, as soon as you have the data-frames ready, the EasyDataWrapper helper class would take care of the rest.

We call a single statistical observation a "purchase record" and use this terminology throughout the tutorial.

df = pd.read_stata('https://www.stata-press.com/data/r17/carchoice.dta')

We load the artificial dataset from the Stata website. Here we borrow the description of dataset reported from the describe command in Stata.

Contains data from https://www.stata-press.com/data/r17/carchoice.dta
 Observations:         3,160                  Car choice data
    Variables:             6                  30 Jul 2020 14:58
---------------------------------------------------------------------------------------------------------------------------------------------------
Variable      Storage   Display    Value
    name         type    format    label      Variable label
---------------------------------------------------------------------------------------------------------------------------------------------------
consumerid      int     %8.0g                 ID of individual consumer
car             byte    %9.0g      nation     Nationality of car
purchase        byte    %10.0g                Indicator of car purchased
gender          byte    %9.0g      gender     Gender: 0 = Female, 1 = Male
income          float   %9.0g                 Income (in $1,000)
dealers         byte    %9.0g                 No. of dealerships in community
---------------------------------------------------------------------------------------------------------------------------------------------------
Sorted by: consumerid  car

In this dataset, the first four rows with consumerid == 1 corresponds to the first purchasing record, it means the consumer with ID 1 was making the decision among four types of cars (i.e., items) and chose American car (since the purchase == 1 in that row of American car).

Even though there were four types of cars, not all of them were available all the time. For example, for the purchase record by consumer with ID 4, only American, Japanese, and European cars were available (note that there is no row in the dataset with consumerid == 4 and car == 'Korean', this indicates unavailability of a certain item.)

df.head(30)
consumerid car purchase gender income dealers
0 1 American 1 Male 46.699997 9
1 1 Japanese 0 Male 46.699997 11
2 1 European 0 Male 46.699997 5
3 1 Korean 0 Male 46.699997 1
4 2 American 1 Male 26.100000 10
5 2 Japanese 0 Male 26.100000 7
6 2 European 0 Male 26.100000 2
7 2 Korean 0 Male 26.100000 1
8 3 American 0 Male 32.700001 8
9 3 Japanese 1 Male 32.700001 6
10 3 European 0 Male 32.700001 2
11 4 American 1 Female 49.199997 5
12 4 Japanese 0 Female 49.199997 4
13 4 European 0 Female 49.199997 3
14 5 American 0 Male 24.299999 8
15 5 Japanese 0 Male 24.299999 3
16 5 European 1 Male 24.299999 3
17 6 American 1 Female 39.000000 10
18 6 Japanese 0 Female 39.000000 6
19 6 European 0 Female 39.000000 1
20 7 American 0 Male 33.000000 10
21 7 Japanese 0 Male 33.000000 6
22 7 European 1 Male 33.000000 4
23 7 Korean 0 Male 33.000000 1
24 8 American 1 Male 20.299999 6
25 8 Japanese 0 Male 20.299999 5
26 8 European 0 Male 20.299999 3
27 9 American 0 Male 38.000000 9
28 9 Japanese 1 Male 38.000000 9
29 9 European 0 Male 38.000000 2

Components of the Consumer Choice Modelling Problem

We begin with essential component of the consumer choice modelling problem. Walking through these components should help you understand what kind of data our models are working on.

Purchasing Record

Each row (record) of the dataset is called a purchasing record, which includes who bought what at when and where. Let \(B\) denote the number of purchasing records in the dataset (i.e., number of rows of the dataset). Each row \(b \in \{1,2,\dots, B\}\) corresponds to a purchase record (i.e., who bought what at where and when).

Items and Categories

To begin with, there are \(I\) items indexed by \(i \in \{1,2,\dots,I\}\) under our consideration.

Further, the researcher can optionally partition the set items into \(C\) categories indexed by \(c \in \{1,2,\dots,C\}\). Let \(I_c\) denote the collection of items in category \(c\), it is easy to verify that

\[ \bigcup_{c \in \{1, 2, \dots, C\}} I_c = \{1, 2, \dots I\} \]

If the researcher does not wish to model different categories differently, the researcher can simply put all items in one single category: \(I_1 = \{1, 2, \dots I\}\), so that all items belong to the same category.

Note: since we will be using PyTorch to train our model, we represent their identities with integer values instead of the raw human-readable names of items (e.g., Dell 24 inch LCD monitor). Raw item names can be encoded easily with sklearn.preprocessing.OrdinalEncoder.

Users

Each purchaing reocrd is naturally associated with an user indexed by \(u \in \{1,2,\dots,U\}\) (who) as well.

Sessions

Our data structure encompasses where and when using a notion called session indexed by \(s \in \{1,2,\dots, S\}\). For example, when the data came from a single store over the period of a year. In this case, the notion of where does not matter that much, and session \(s\) is simply the date of purchase.

Another example is that we have the purchase record from different stores, the session \(s\) can be defined as a pair of (date, store) instead.

If the researcher does not wish to handle records from different sessions differently, the researcher can assign the same session ID to all rows of the dataset.

To summarize, each purchasing record \(b\) in the dataset is characterized by a user-session-item tuple \((u, s, i)\).

When there are multiple items bought by the same user in the same session, there will be multiple rows in the dataset with the same \((u, s)\) corresponding to the same receipt.

Format the Dataset a Little Bit

The wrapper we built requires several data frames, providing the correct information is all we need to do in this tutorial, the data wrapper will handle the construction of ChoiceDataset for you.

Note: The dataset in this tutorial is a bit over-simplified, we only have one purchase record for each user in each session, so the consumerid column identifies all of the user, the session, and the purchase record (because we have different dealers for the same type of car, we define each purchase record of it's session instead of assigning all purchase records to the same session). That is, we have a single user makes a single choice in each single session.

The main dataset should contain the following columns:

  1. purchase_record_column: a column identifies purchase record (also called case in Stata syntax). this tutorial, the consumerid column is the identifier. For example, the first 4 rows of the dataset (see above) has consumerid == 1, this means we should look at the first 4 rows together and they constitute the first purchase record.
  2. item_name_column: a column identifies names of items, which is car in the dataset above. This column provides information above the availability as well. As mentioned above, there is no column with car == Korean in the fourth purchasing record (consumerid == 4), so we know that Korean car was not available that time.
  3. choice_column: a column identifies the choice made by the consumer in each purchase record, which is the purchase column in our example. Exactly one row per purchase record (i.e., rows with the same values in purchase_record_column) should have 1, while the values are zeros for all other rows.
  4. user_index_column: a optional column identifies the user making the choice, which is also consumerid in our case.
  5. session_index_column: a optional column identifies the session of the choice, which is also consumerid in our case.

As you might have noticed, the consumerid column in the data-frame identifies multiple pieces of information: purchase_record, user_index, and session_index. This is not a mistake, you can use the same column in df to supply multiple pieces of information.

df.gender.value_counts(dropna=False)
Male      2283
Female     854
NaN         23
Name: gender, dtype: int64

The only modification required is to convert gender (with values of Male, Female or NaN) to integers because PyTorch does not handle strings. For simplicity, we will assume all NaN gender to be Female (you should not do this in a real application!) and re-define the gender variable as \(\mathbb{I}\{\texttt{gender} == \texttt{Male}\}\).

# we change gender to binary 0/1 because pytorch doesn't handle strings.
df['gender'] = (df['gender'] == 'Male').astype(int)

Now the gender column contains only binary integers.

df.gender.value_counts(dropna=False)
1    2283
0     877
Name: gender, dtype: int64

The data-frame looks like the following right now:

df.head()
consumerid car purchase gender income dealers
0 1 American 1 1 46.699997 9
1 1 Japanese 0 1 46.699997 11
2 1 European 0 1 46.699997 5
3 1 Korean 0 1 46.699997 1
4 2 American 1 1 26.100000 10

Adding the Observables

The next step is to identify observables going into the model.

Specifically, we would want to add: 1. gender and income as user-specific observables 2. and dealers as (session, item)-specific observable. Such observables are called price observables in our setting, why? because price is the most typical (session, item)-specific observable.

Method 1: Adding Observables by Extracting Columns of the Dataset

As you can see, gender, income and dealers are already encompassed in df, the first way to add observables is simply mentioning these columns while initializing the EasyDatasetWrapper object.

You can supply a list of names of columns to each of {user, item, session, price}_observable_columns keyword argument. For example, we use user_observable_columns=['gender', 'income'] to inform the EasyDatasetWrapper that we wish to derive user-specific observables from the gender and income columns of df.

Also, we inform the EasyDatasetWrapper that we want to derive (session, item)-specific (i.e., price observable) by specifying price_observable_columns=['dealers'].

Since our package leverages GPU-acceleration, it is necessary to supply the device on which the dataset should reside. The EasyDatasetWrapper also takes a device keyword, which can be either 'cpu' or an appropriate CUDA device.

if torch.cuda.is_available():
    device = 'cuda'  # use GPU if available
else:
    device = 'cpu'  # use CPU otherwise
data_1 = EasyDatasetWrapper(main_data=df,
                            # TODO: better naming convention? Need to discuss.
                            # after discussion, we add it to the default value
                            # in the data wrapper class.
                            # these are just names.
                            purchase_record_column='consumerid',
                            choice_column='purchase',
                            item_name_column='car',
                            user_index_column='consumerid',
                            session_index_column='consumerid',
                            # it can be derived from columns of the dataframe or supplied as
                            user_observable_columns=['gender', 'income'],
                            price_observable_columns=['dealers'],
                            device=device)
Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.

The dataset has a summary() method, which can be used to print out the summary of the dataset.

data_1.summary()
* purchase record index range: [1 2 3] ... [883 884 885]
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase  gender     income  dealers
0              1  American         1       1  46.699997        9
1              1  Japanese         0       1  46.699997       11
2              1  European         0       1  46.699997        5
3              1    Korean         0       1  46.699997        1
4              2  American         1       1  26.100000       10
...          ...       ...       ...     ...        ...      ...
3155         884  Japanese         1       1  20.900000       10
3156         884  European         0       1  20.900000        4
3157         885  American         1       1  30.600000       10
3158         885  Japanese         0       1  30.600000        5
3159         885  European         0       1  30.600000        4

[3160 rows x 6 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 1], user_income=[885, 1], itemsession_dealers=[885, 4, 1], device=cpu)

You can access the ChoiceDataset object constructed by calling the data.choice_dataset object.

data_1.choice_dataset
ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 1], user_income=[885, 1], itemsession_dealers=[885, 4, 1], device=cpu)

Method 2: Adding Observables as Data Frames

We can also construct data frames and use data frames to supply different observables. This is useful when you have a large dataset, for example, if there are many purchase records for the same user (to be concrete, say \(U\) users and \(N\) purchase records for each user, resulting \(U \times N\) total purchase records). Using a single data-frame requires a lot of memory: you need to store \(U \times N\) entires of user genders in total. However, user genders should be persistent across all purchasing records, if we use a separate data-frame mapping user index to gender of the user, we only need to store \(U\) entries (i.e., one for each user) of gender information.

Similarly, the long-format data requires storing each piece of item-specific information for number of purchase records times, which leads to inefficient usage of disk/memory space.

How Do Observable Data-frame Look Like?

Our package natively support the following four types of observables:

  1. User Observables: user-specific observables (e.g., gender and income) should (1) have length equal to the number of unique users in the dataset (885 here); (2) contains a column named as user_index_column (user_index_column is a variable, the actual column name should be the value of variable user_index_column! E.g., here the user observable data-frame should have a column named 'consumerid'); (3) the user observable can have any number of other columns beside the user_index_column column, each of them corresponding to a user-specific observable. For example, a data-frame containing \(X\) user-specific observables has shape (num_users, X + 1).

  2. Item Observables item-specific observables (not shown in this tutorial) should be (1) have length equal to the number of unique items in the dataset (4 here); (2) contain a column named as item_index_column (item_index_column is a variable, the actual column name should be the value of variable item_index_column! E.g., here the item observable data-frame should have a column named 'car'); (3) the item observable can have any number of other columns beside the item_index_column column, each of them corresponding to a item-specific observable.

  3. Session Observable session-specific observables (not shown in this tutorial) should be (1) have length equal to the number of unique sessions in the dataset; (2) contain a column named as session_index_column (session_index_column is a variable, the actual column name should be the value of variable session_index_column! E.g., here the session observable data-frame should have a column named 'consumerid'); (3) the session observable can have any number of other columns beside the session_index_column column, each of them corresponding to a session-specific observable.

  4. Price Observables (session, item)-specific observables (e.g., dealers) should be (1) contains a column named as session_index_column (e.g., consumerid in our example) and a column named as item_name_column (e.g., car in our example), (2) the price observable can have any number of other columns beside the session_index_column and item_name_column columns, each of them corresponding to a (session, item)-specific observable. For example, a data-frame containing \(X\) (session, item)-specific observables has shape (num_sessions, num_items, X + 2).

We encourage the reader to review the Data Management Tutorial for more details on types of observables.

Suggested Procedure of Storing and Loading Data

  1. Suppose SESSION_INDEX column in df_main is the index of the session, ALTERNATIVES column is the index of the car.
  2. For user-specific observables, you should have a CSV on disk with columns {consumerid, var_1, var_2, ...}.
  3. You load the user-specific dataset as user_obs = pd.read_csv(..., index='consumerid').

Let's first construct the data frame for user genders first.

gender = df.groupby('consumerid')['gender'].first().reset_index()

The user-observable data-frame contains a column of user IDs (the consumerid column), this column should have exactly the same name as the column containing user indices. Otherwise, the wrapper won't know which column corresponds to user IDs and which column corresponds to variables.

gender.head()
consumerid gender
0 1 1
1 2 1
2 3 1
3 4 0
4 5 1

Then, let's build the data-frame for user-specific income variables.

income = df.groupby('consumerid')['income'].first().reset_index()
income.head()
consumerid income
0 1 46.699997
1 2 26.100000
2 3 32.700001
3 4 49.199997
4 5 24.299999

Please note that we can have multiple observables contained in the same data-frame as well.

gender_and_income = df.groupby('consumerid')[['gender', 'income']].first().reset_index()
gender_and_income
consumerid gender income
0 1 1 46.699997
1 2 1 26.100000
2 3 1 32.700001
3 4 0 49.199997
4 5 1 24.299999
... ... ... ...
880 881 1 45.700001
881 882 1 69.800003
882 883 0 45.599998
883 884 1 20.900000
884 885 1 30.600000

885 rows × 3 columns

The price observable data-frame contains two columns identifying session (i.e., the consumerid column) and item (i.e., the car column). The session index column should have exactly the same name as the session index column in df and the column indexing columns should have exactly the same name as the item-name-column in df.

dealers = df[['consumerid', 'car', 'dealers']]
dealers.head()
consumerid car dealers
0 1 American 9
1 1 Japanese 11
2 1 European 5
3 1 Korean 1
4 2 American 10

Build Datasets using EasyDatasetWrapper with Observables as Data-Frames

We can observables as data-frames using {user, item, session, price}_observable_data keyword arguments.

data_2 = EasyDatasetWrapper(main_data=df,
                            purchase_record_column='consumerid',
                            choice_column='purchase',
                            item_name_column='car',
                            user_index_column='consumerid',
                            session_index_column='consumerid',
                            # above are the same as before, but we update the following.
                            user_observable_data={'gender': gender, 'income': income},
                            price_observable_data={'dealers': dealers},
                            device=device)
Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.
# Use summary to see what's inside the data wrapper.
data_2.summary()
* purchase record index range: [1 2 3] ... [883 884 885]
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase  gender     income  dealers
0              1  American         1       1  46.699997        9
1              1  Japanese         0       1  46.699997       11
2              1  European         0       1  46.699997        5
3              1    Korean         0       1  46.699997        1
4              2  American         1       1  26.100000       10
...          ...       ...       ...     ...        ...      ...
3155         884  Japanese         1       1  20.900000       10
3156         884  European         0       1  20.900000        4
3157         885  American         1       1  30.600000       10
3158         885  Japanese         0       1  30.600000        5
3159         885  European         0       1  30.600000        4

[3160 rows x 6 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 1], user_income=[885, 1], itemsession_dealers=[885, 4, 1], device=cpu)

Alternatively, we can supply user income and gender as a single dataframe, instead of user_gender and user_income tensors, now the constructed ChoiceDataset contains a single user_gender_and_income tensor with shape (885, 2) encompassing both income and gender of users.

data_3 = EasyDatasetWrapper(main_data=df,
                            purchase_record_column='consumerid',
                            choice_column='purchase',
                            item_name_column='car',
                            user_index_column='consumerid',
                            session_index_column='consumerid',
                            # above are the same as before, but we update the following.
                            user_observable_data={'gender_and_income': gender_and_income},
                            price_observable_data={'dealers': dealers},
                            device=device)
Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.
data_3.summary()
* purchase record index range: [1 2 3] ... [883 884 885]
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase  gender     income  dealers
0              1  American         1       1  46.699997        9
1              1  Japanese         0       1  46.699997       11
2              1  European         0       1  46.699997        5
3              1    Korean         0       1  46.699997        1
4              2  American         1       1  26.100000       10
...          ...       ...       ...     ...        ...      ...
3155         884  Japanese         1       1  20.900000       10
3156         884  European         0       1  20.900000        4
3157         885  American         1       1  30.600000       10
3158         885  Japanese         0       1  30.600000        5
3159         885  European         0       1  30.600000        4

[3160 rows x 6 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender_and_income=[885, 2], itemsession_dealers=[885, 4, 1], device=cpu)

Method 3: Mixing Method 1 and Method 2

The EasyDataWrapper also support supplying observables as a mixture of above methods. The following example supplies gender user observable as a data-frame but income and dealers as column names.

data_4 = EasyDatasetWrapper(main_data=df,
                            purchase_record_column='consumerid',
                            choice_column='purchase',
                            item_name_column='car',
                            user_index_column='consumerid',
                            session_index_column='consumerid',
                            # above are the same as before, but we update the following.
                            user_observable_data={'gender': gender},
                            user_observable_columns=['income'],
                            price_observable_columns=['dealers'],
                            device=device)
Creating choice dataset from stata format data-frames...
Note: choice sets of different sizes found in different purchase records: {'size 4': 'occurrence 505', 'size 3': 'occurrence 380'}
Finished Creating Choice Dataset.
data_4.summary()
* purchase record index range: [1 2 3] ... [883 884 885]
* Space of 4 items:
                   0         1         2       3
item name  American  European  Japanese  Korean
* Number of purchase records/cases: 885.
* Preview of main data frame:
      consumerid       car  purchase  gender     income  dealers
0              1  American         1       1  46.699997        9
1              1  Japanese         0       1  46.699997       11
2              1  European         0       1  46.699997        5
3              1    Korean         0       1  46.699997        1
4              2  American         1       1  26.100000       10
...          ...       ...       ...     ...        ...      ...
3155         884  Japanese         1       1  20.900000       10
3156         884  European         0       1  20.900000        4
3157         885  American         1       1  30.600000       10
3158         885  Japanese         0       1  30.600000        5
3159         885  European         0       1  30.600000        4

[3160 rows x 6 columns]
* Preview of ChoiceDataset:
ChoiceDataset(label=[], item_index=[885], provided_num_items=[], user_index=[885], session_index=[885], item_availability=[885, 4], user_gender=[885, 1], user_income=[885, 1], itemsession_dealers=[885, 4, 1], device=cpu)

Sanity Checks

Lastly, let's check choice datasets constructed via different methods are actually the same.

The == method of choice datasets will compare the non-NAN entries of all tensors in datasets.

print(data_1.choice_dataset == data_2.choice_dataset)
print(data_1.choice_dataset == data_4.choice_dataset)
True
True

For data_3, we have income and gender combined:

data_3.choice_dataset.user_gender_and_income == torch.cat([data_1.choice_dataset.user_gender, data_1.choice_dataset.user_income], dim=1)
tensor([[True, True],
        [True, True],
        [True, True],
        ...,
        [True, True],
        [True, True],
        [True, True]])

Now let's compare what's inside the data structure and our raw data.

bought_raw = df[df['purchase'] == 1]['car'].values
bought_data = list()
encoder = {0: 'American', 1: 'European', 2: 'Japanese', 3: 'Korean'}
for b in data_1.choice_dataset.item_index:
    bought_data.append(encoder[float(b)])
all(bought_raw == bought_data)
True

Then, let's compare the income and gender variable contained in the dataset.

X = df.groupby('consumerid')['income'].first().values
Y = data_1.choice_dataset.user_income.cpu().numpy().squeeze()
all(X == Y)
True
X = df.groupby('consumerid')['gender'].first().values
Y = data_1.choice_dataset.user_gender.cpu().numpy().squeeze()
all(X == Y)
True

Lastly, let's compare the price_dealer variable. Since there are NAN-values in it for unavailable cars, we can't not use all(X == Y) to compare them. We will first fill NANs values with -1 and then compare resulted data-frames.

# rearrange columns to align it with the internal encoding scheme of the data wrapper.
X = df.pivot('consumerid', 'car', 'dealers')[['American', 'European', 'Japanese', 'Korean']]
X
car American European Japanese Korean
consumerid
1 9.0 5.0 11.0 1.0
2 10.0 2.0 7.0 1.0
3 8.0 2.0 6.0 NaN
4 5.0 3.0 4.0 NaN
5 8.0 3.0 3.0 NaN
... ... ... ... ...
881 8.0 2.0 10.0 NaN
882 8.0 6.0 8.0 1.0
883 9.0 5.0 8.0 1.0
884 12.0 4.0 10.0 NaN
885 10.0 4.0 5.0 NaN

885 rows × 4 columns

Y = data_1.choice_dataset.itemsession_dealers.squeeze(dim=-1)
Y
tensor([[ 9.,  5., 11.,  1.],
        [10.,  2.,  7.,  1.],
        [ 8.,  2.,  6., nan],
        ...,
        [ 9.,  5.,  8.,  1.],
        [12.,  4., 10., nan],
        [10.,  4.,  5., nan]])
print(X.fillna(-1).values == torch.nan_to_num(Y, -1).cpu().numpy())
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]
 ...
 [ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

This concludes our tutorial on building the dataset, if you wish more in-depth understanding of the data structure, please refer to the Data Management Tutorial.