Source code for cornac.eval_methods.ratio_split

# Copyright 2018 The Cornac Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from math import ceil

from .base_method import BaseMethod
from ..utils import get_rng
from ..utils.common import safe_indexing

class RatioSplit(BaseMethod):
    """Splitting data into training, validation, and test sets based on provided sizes.
    Data is always shuffled before split.

    Parameters
    ----------
    data: array-like, required
        Raw preference data in the triplet format [(user_id, item_id, rating_value)].

    test_size: float, optional, default: 0.2
        The proportion of the test set when smaller than 1;
        otherwise it is treated as the absolute size of the test set.

    val_size: float, optional, default: 0.0
        The proportion of the validation set when smaller than 1;
        otherwise it is treated as the absolute size of the validation set.

    rating_threshold: float, optional, default: 1.0
        Threshold used to binarize rating values into positive or negative feedback for
        model evaluation using ranking metrics (rating metrics are not affected).

    seed: int, optional, default: None
        Random seed for reproducibility.

    exclude_unknowns: bool, optional, default: True
        If `True`, unknown users and items will be ignored during model evaluation.

    verbose: bool, optional, default: False
        Output running log.

    """

    def __init__(self, data, test_size=0.2, val_size=0.0, rating_threshold=1.0,
                 seed=None, exclude_unknowns=True, verbose=False, **kwargs):
        super().__init__(data=data, rating_threshold=rating_threshold, seed=seed,
                         exclude_unknowns=exclude_unknowns, verbose=verbose, **kwargs)

        # Resolve the requested proportions / counts into absolute sizes
        # before performing the split.
        self.train_size, self.val_size, self.test_size = self.validate_size(
            val_size, test_size, len(self._data))
        self._split()

    @staticmethod
    def validate_size(val_size, test_size, num_ratings):
        """Convert validation/test sizes into absolute counts and derive the train size.

        Parameters
        ----------
        val_size: float or None
            Proportion (< 1) or absolute count (>= 1) of the validation set.
            `None` is treated as 0.

        test_size: float or None
            Proportion (< 1) or absolute count (>= 1) of the test set.
            `None` is treated as 0.

        num_ratings: int
            Total number of ratings to be split.

        Returns
        -------
        tuple of int
            `(train_size, val_size, test_size)` as absolute counts.

        Raises
        ------
        ValueError
            If a size is negative, is not smaller than `num_ratings`, or if
            the validation and test sets together leave no training data.
        """
        if val_size is None:
            val_size = 0.0
        elif val_size < 0:
            raise ValueError('val_size={} should be greater than zero'.format(val_size))
        elif val_size >= num_ratings:
            raise ValueError(
                'val_size={} should be less than the number of ratings {}'.format(val_size, num_ratings))

        if test_size is None:
            test_size = 0.0
        elif test_size < 0:
            raise ValueError('test_size={} should be greater than zero'.format(test_size))
        elif test_size >= num_ratings:
            raise ValueError(
                'test_size={} should be less than the number of ratings {}'.format(test_size, num_ratings))

        # Values below 1 are proportions; round up so that any non-zero
        # proportion yields at least one rating.
        if val_size < 1:
            val_size = ceil(val_size * num_ratings)
        if test_size < 1:
            test_size = ceil(test_size * num_ratings)

        if val_size + test_size >= num_ratings:
            raise ValueError(
                'The sum of val_size and test_size ({}) should be smaller than the number of ratings {}'.format(
                    val_size + test_size, num_ratings))

        train_size = num_ratings - (val_size + test_size)

        return int(train_size), int(val_size), int(test_size)

    def _split(self):
        """Shuffle the rating indices and partition them into train / validation / test.

        The shuffled indices are laid out as [train | validation | test].
        Empty validation or test partitions are passed on as `None`.
        """
        data_idx = self.rng.permutation(len(self._data))

        # Use explicit forward offsets instead of negative slicing:
        # `data_idx[-0:]` would select *every* index when test_size is 0,
        # and `data_idx[train_size:-0]` would always be empty.
        val_end = self.train_size + self.val_size
        train_idx = data_idx[:self.train_size]
        val_idx = data_idx[self.train_size:val_end]
        test_idx = data_idx[val_end:]

        train_data = safe_indexing(self._data, train_idx)
        test_data = safe_indexing(self._data, test_idx) if len(test_idx) > 0 else None
        val_data = safe_indexing(self._data, val_idx) if len(val_idx) > 0 else None

        # NOTE(review): the head of this call was truncated in this copy of
        # the source and has been reconstructed from the surviving keyword
        # arguments -- confirm that the base class exposes
        # `build(train_data, test_data, val_data)`.
        self.build(train_data=train_data, test_data=test_data, val_data=val_data)