Utilities to train and utilize SASRec

class sasrec.util.SASRecDataSet(filename=None, col_sep='\t')[source]

A class for creating a SASRec-specific dataset used during training, validation, and testing.

Parameters
  • filename (str) – Data filename.

  • col_sep (str) – Column separator in the data file.
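The file layout is not specified here. As a minimal sketch, assuming each line holds an integer user ID and an integer item ID separated by col_sep, and that each user's lines appear in chronological order, the User, Items, usernum, and itemnum attributes could be derived as below. The function name read_interactions is hypothetical and is not part of sasrec.util.

```python
import io
from collections import defaultdict
from typing import Dict, List, Set, TextIO, Tuple


def read_interactions(
    handle: TextIO, col_sep: str = "\t"
) -> Tuple[Dict[int, List[int]], Set[int], int, int]:
    """Hypothetical reader: one 'user<col_sep>item' pair per line, in time order."""
    user_items: Dict[int, List[int]] = defaultdict(list)
    items: Set[int] = set()
    for line in handle:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        fields = line.split(col_sep)
        user, item = int(fields[0]), int(fields[1])
        user_items[user].append(item)  # keeps file (assumed chronological) order
        items.add(item)
    # Counts of distinct users and items seen in the file.
    usernum, itemnum = len(user_items), len(items)
    return dict(user_items), items, usernum, itemnum


sample = io.StringIO("1\t10\n1\t11\n2\t10\n")
user_items, items, usernum, itemnum = read_interactions(sample, col_sep="\t")
print(user_items)        # {1: [10, 11], 2: [10]}
print(usernum, itemnum)  # 2 2
```

In practice the handle would come from open(filename); io.StringIO is used here only to keep the example self-contained.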

usernum

Total number of users.

Type

int

itemnum

Total number of items.

Type

int

User

All the users (keys) with items as values.

Type

dict

Items

Set of all the items.

Type

set

user_train

Subset of User used for training.

Type

dict

user_valid

Subset of User used for validation.

Type

dict

user_test

Subset of User used for testing.

Type

dict

filename

Data filename. Defaults to None.

Type

str

col_sep

Column separator in the data file. Defaults to ‘\t’.

Type

str
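The user_train, user_valid, and user_test attributes are described only as subsets of User. A common convention for sequential recommenders such as SASRec is a leave-one-out split per user; the sketch below assumes that convention (last item for testing, second-to-last for validation, the rest for training, and users with fewer than 3 interactions kept entirely in training). The helper leave_one_out_split is hypothetical and is not part of sasrec.util.

```python
from typing import Dict, List, Tuple

Split = Dict[int, List[int]]


def leave_one_out_split(user_items: Split) -> Tuple[Split, Split, Split]:
    """Hypothetical leave-one-out split of each user's chronological items."""
    user_train: Split = {}
    user_valid: Split = {}
    user_test: Split = {}
    for user, seq in user_items.items():
        if len(seq) < 3:
            # Too short to hold out two items: keep everything for training.
            user_train[user] = list(seq)
            user_valid[user] = []
            user_test[user] = []
        else:
            user_train[user] = seq[:-2]   # all but the last two items
            user_valid[user] = [seq[-2]]  # second-to-last item
            user_test[user] = [seq[-1]]   # most recent item
    return user_train, user_valid, user_test


train, valid, test = leave_one_out_split({1: [10, 11, 12, 13], 2: [20, 21]})
print(train)  # {1: [10, 11], 2: [20, 21]}
print(valid)  # {1: [12], 2: []}
print(test)   # {1: [13], 2: []}
```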

Examples

>>> data = SASRecDataSet(filename='filename', col_sep='\t')

sasrec.util.filter_k_core(data, core_num=0, col_user='userID', col_item='itemID')[source]

Filter a rating DataFrame so that every user and every item appears at least core_num times, by repeatedly applying min_rating_filter until the condition is satisfied.

Parameters
  • data (pd.DataFrame) – DataFrame to filter.

  • core_num (int, optional) – Minimum number of times each user and each item must appear in the data. Defaults to 0.

  • col_user (str, optional) – User column name. Defaults to “userID”.

  • col_item (str, optional) – Item column name. Defaults to “itemID”.

Returns

pd.DataFrame – Filtered DataFrame.
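The repeated filtering matters because dropping rows for a sparse user can push an item below the threshold, and vice versa. The following is a minimal pandas sketch of that loop, not the library's implementation; the names k_core_sketch and _min_count_filter are hypothetical, and counts are taken as rows per user or item.

```python
import pandas as pd


def _min_count_filter(df: pd.DataFrame, min_count: int, col: str) -> pd.DataFrame:
    # Map each row to the number of rows sharing its value in `col`.
    counts = df[col].map(df[col].value_counts())
    return df[counts >= min_count]


def k_core_sketch(data: pd.DataFrame, core_num: int = 0,
                  col_user: str = "userID", col_item: str = "itemID") -> pd.DataFrame:
    df = data
    while True:
        n_before = len(df)
        # Filter users, then items; each pass can invalidate the other.
        df = _min_count_filter(df, core_num, col_user)
        df = _min_count_filter(df, core_num, col_item)
        if len(df) == n_before:  # nothing removed: both conditions hold
            return df


ratings = pd.DataFrame({
    "userID": [1, 1, 1, 2, 2, 3, 3],
    "itemID": [10, 11, 12, 10, 11, 12, 13],
})
core = k_core_sketch(ratings, core_num=2)
print(core)
```

Here removing item 13 leaves user 3 with one row, removing user 3 leaves item 12 with one row, and so on, until only users 1 and 2 with items 10 and 11 remain.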

sasrec.util.load_model(path, exp_name='sas_experiment')[source]

Load a saved SASRec model.

Parameters
  • path (str) – Path where the model is saved.

  • exp_name (str, optional) – Experiment name (folder name). Defaults to ‘sas_experiment’.

Returns

model.SASREC – loaded SASRec model

sasrec.util.min_rating_filter_pandas(data, min_rating=1, filter_by='user', col_user='userID', col_item='itemID')[source]

Filter a rating DataFrame by a minimum number of ratings per user or per item.

Filtering a rating DataFrame by a minimum number of ratings per user or item is useful for generating a new DataFrame that contains only warm users or items. Warmth is defined by the min_rating argument; for example, with min_rating=4, a user is called warm if they have rated at least 4 items.

Parameters
  • data (pd.DataFrame) – DataFrame of user-item tuples. The user and item columns must be present in the DataFrame; other columns such as rating or timestamp are optional.

  • min_rating (int) – Minimum number of ratings for each user or item. Defaults to 1.

  • filter_by (str) – Either “user” or “item”, depending on which of the two is filtered by min_rating. Defaults to “user”.

  • col_user (str) – Column name of user ID. Defaults to “userID”.

  • col_item (str) – Column name of item ID. Defaults to “itemID”.

Returns

pandas.DataFrame – DataFrame, containing at least the user and item columns, filtered according to the given specifications.
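A minimal sketch of the warm-user idea using pandas groupby(...).filter, assuming one row per rating so that the row count per user equals the number of rated items. The name min_rating_filter_sketch is hypothetical and this is not necessarily how sasrec.util implements min_rating_filter_pandas.

```python
import pandas as pd


def min_rating_filter_sketch(data: pd.DataFrame, min_rating: int = 1,
                             filter_by: str = "user", col_user: str = "userID",
                             col_item: str = "itemID") -> pd.DataFrame:
    if filter_by == "user":
        col = col_user
    elif filter_by == "item":
        col = col_item
    else:
        raise ValueError("filter_by must be 'user' or 'item'")
    # Keep every row whose user (or item) group has at least min_rating rows.
    return data.groupby(col).filter(lambda group: len(group) >= min_rating)


ratings = pd.DataFrame({
    "userID": [1, 1, 1, 1, 2, 2],
    "itemID": [10, 11, 12, 13, 10, 11],
    "rating": [5, 4, 3, 5, 2, 4],
})
warm = min_rating_filter_sketch(ratings, min_rating=4, filter_by="user")
print(warm["userID"].unique())  # [1]
```

User 1 rated 4 items and is kept as warm; user 2 rated only 2 and is dropped. Extra columns such as rating pass through unchanged.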