Source code for cornac.data.reader

# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import ast
import itertools
from collections import Counter


def ui_parser(tokens, line_idx, id_inline=False, **kwargs):
    if id_inline:
        return [(str(line_idx + 1), iid, 1.0) for iid in tokens]
    else:
        return [(tokens[0], iid, 1.0) for iid in tokens[1:]]
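    # A minimal sketch of `ui_parser` (token values are hypothetical): with
    # id_inline=True, the user id is the 1-based line number and every token
    # is an item id; otherwise the first token is the user id.
    # >>> ui_parser(["i1", "i2"], line_idx=0, id_inline=True)
    # [('1', 'i1', 1.0), ('1', 'i2', 1.0)]
    # >>> ui_parser(["u1", "i1", "i2"], line_idx=0)
    # [('u1', 'i1', 1.0), ('u1', 'i2', 1.0)]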


def uir_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], float(tokens[2]))]


def review_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], tokens[2])]


def uirt_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], float(tokens[2]), int(tokens[3]))]


def tup_parser(tokens, **kwargs):
    return [
        (
            tokens[0],
            tokens[1],
            [tuple(tup.split(kwargs.get("tup_sep"))) for tup in tokens[2:]],
        )
    ]
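    # A minimal sketch of `tup_parser` (token values are hypothetical):
    # trailing tokens are themselves split by `tup_sep` into tuples.
    # >>> tup_parser(["u1", "i1", "3:0.5", "7:0.2"], tup_sep=":")
    # [('u1', 'i1', [('3', '0.5'), ('7', '0.2')])]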


def ubi_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], tokens[2])]


def ubit_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], tokens[2], int(tokens[3]))]


def ubitjson_parser(tokens, **kwargs):
    return [
        (tokens[0], tokens[1], tokens[2], int(tokens[3]), ast.literal_eval(tokens[4]))
    ]


def sit_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], int(tokens[2]))]


def sitjson_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], int(tokens[2]), ast.literal_eval(tokens[3]))]


def usit_parser(tokens, **kwargs):
    return [(tokens[0], tokens[1], tokens[2], int(tokens[3]))]


def usitjson_parser(tokens, **kwargs):
    return [
        (tokens[0], tokens[1], tokens[2], int(tokens[3]), ast.literal_eval(tokens[4]))
    ]


PARSERS = {
    "UI": ui_parser,
    "UIR": uir_parser,
    "UIRT": uirt_parser,
    "UITup": tup_parser,
    "UIReview": review_parser,
    "UBI": ubi_parser,
    "UBIT": ubit_parser,
    "UBITJson": ubitjson_parser,
    "SIT": sit_parser,
    "SITJson": sitjson_parser,
    "USIT": usit_parser,
    "USITJson": usitjson_parser,
}
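# A minimal sketch of how the PARSERS registry behaves (sample values are
# hypothetical): each parser maps the tokens of one line to a list of tuples.
# >>> PARSERS["UIR"]("u1\ti1\t4.0".split("\t"))
# [('u1', 'i1', 4.0)]
# >>> PARSERS["UIRT"]("u1\ti1\t4.0\t882606924".split("\t"))
# [('u1', 'i1', 4.0, 882606924)]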


class Reader:
    """Reader class for reading data in different formats.

    Parameters
    ----------
    user_set: set, default = None
        Set of users to be retained when reading data.
        If `None`, all users will be included.

    item_set: set, default = None
        Set of items to be retained when reading data.
        If `None`, all items will be included.

    min_user_freq: int, default = 1
        The minimum frequency of a user to be retained.
        If `min_user_freq = 1`, all users will be included.

    min_item_freq: int, default = 1
        The minimum frequency of an item to be retained.
        If `min_item_freq = 1`, all items will be included.

    num_top_freq_user: int, default = 0
        The number of top popular users to be retained.
        If `num_top_freq_user = 0`, all users will be included.

    num_top_freq_item: int, default = 0
        The number of top popular items to be retained.
        If `num_top_freq_item = 0`, all items will be included.

    min_basket_size: int, default = 1
        The minimum number of items of a basket to be retained.
        If `min_basket_size = 1`, all baskets will be included.

    max_basket_size: int, default = -1
        The maximum number of items of a basket to be retained.
        If `max_basket_size = -1`, all baskets will be included.

    min_basket_sequence: int, default = 1
        The minimum number of baskets of a user to be retained.
        If `min_basket_sequence = 1`, all baskets will be included.

    min_sequence_size: int, default = 1
        The minimum number of items of a sequence to be retained.
        If `min_sequence_size = 1`, all sequences will be included.

    max_sequence_size: int, default = -1
        The maximum number of items of a sequence to be retained.
        If `max_sequence_size = -1`, all sequences will be included.

    bin_threshold: float, default = None
        The rating threshold to binarize rating values (turn explicit
        feedback into implicit feedback). For example, if
        `bin_threshold = 3.0`, all rating values >= 3.0 will be set to 1.0,
        and the rest (< 3.0) will be discarded.

    encoding: str, default = `utf-8`
        Encoding used to decode the file.

    errors: str, default = None
        Optional string that specifies how encoding errors are to be handled.
        Pass 'strict' to raise a ValueError exception if there is an encoding
        error (None has the same effect), or pass 'ignore' to ignore errors.

    """
""" def __init__( self, user_set=None, item_set=None, min_user_freq=1, min_item_freq=1, num_top_freq_user=0, num_top_freq_item=0, min_basket_size=1, max_basket_size=-1, min_basket_sequence=1, min_sequence_size=1, max_sequence_size=-1, bin_threshold=None, encoding="utf-8", errors=None, ): self.user_set = ( user_set if (user_set is None or isinstance(user_set, set)) else set(user_set) ) self.item_set = ( item_set if (item_set is None or isinstance(item_set, set)) else set(item_set) ) self.min_uf = min_user_freq self.min_if = min_item_freq self.num_top_freq_user = num_top_freq_user self.num_top_freq_item = num_top_freq_item self.min_basket_size = min_basket_size self.max_basket_size = max_basket_size self.min_basket_sequence = min_basket_sequence self.min_sequence_size = min_sequence_size self.max_sequence_size = max_sequence_size self.bin_threshold = bin_threshold self.encoding = encoding self.errors = errors def _filter(self, tuples, fmt="UIR"): i_pos = fmt.find("I") u_pos = fmt.find("U") r_pos = fmt.find("R") if self.bin_threshold is not None and r_pos >= 0: def binarize(t): t = list(t) t[r_pos] = 1.0 return tuple(t) tuples = [binarize(t) for t in tuples if t[r_pos] >= self.bin_threshold] if self.num_top_freq_user > 0: user_freq = Counter(t[u_pos] for t in tuples) top_freq_users = set( k for (k, _) in user_freq.most_common(self.num_top_freq_user) ) tuples = [t for t in tuples if t[u_pos] in top_freq_users] if self.num_top_freq_item > 0: item_freq = Counter(t[i_pos] for t in tuples) top_freq_items = set( k for (k, _) in item_freq.most_common(self.num_top_freq_item) ) tuples = [t for t in tuples if t[i_pos] in top_freq_items] if self.user_set is not None: tuples = [t for t in tuples if t[u_pos] in self.user_set] if self.item_set is not None: tuples = [t for t in tuples if t[i_pos] in self.item_set] if self.min_uf > 1: user_freq = Counter(t[u_pos] for t in tuples) tuples = [t for t in tuples if user_freq[t[u_pos]] >= self.min_uf] if self.min_if > 1: item_freq = Counter(t[i_pos] for t in tuples) tuples = [t for t in tuples if item_freq[t[i_pos]] >= self.min_if] return tuples def _filter_basket(self, tuples, fmt="UBI"): u_pos = fmt.find("U") b_pos = fmt.find("B") if self.min_basket_size > 1: sizes = Counter(t[b_pos] for t in tuples) tuples = [t for t in tuples if sizes[t[b_pos]] >= self.min_basket_size] if self.max_basket_size > 1: sizes = Counter(t[b_pos] for t in tuples) tuples = [t for t in tuples if sizes[t[b_pos]] <= self.max_basket_size] if self.min_basket_sequence > 1: basket_sequence = Counter( u for (u, _) in set((t[u_pos], t[b_pos]) for t in tuples) ) tuples = [ t for t in tuples if basket_sequence[t[u_pos]] >= self.min_basket_sequence ] return tuples def _filter_sequence(self, tuples, fmt="SIT"): s_pos = fmt.find("S") if self.min_sequence_size > 1: sizes = Counter(t[s_pos] for t in tuples) tuples = [t for t in tuples if sizes[t[s_pos]] >= self.min_sequence_size] if self.max_sequence_size > 1: sizes = Counter(t[s_pos] for t in tuples) tuples = [t for t in tuples if sizes[t[s_pos]] <= self.max_sequence_size] return tuples
    def read(
        self,
        fpath,
        fmt="UIR",
        sep="\t",
        skip_lines=0,
        id_inline=False,
        parser=None,
        **kwargs
    ):
        """Read data and parse line by line based on provided `fmt` or `parser`.

        Parameters
        ----------
        fpath: str
            Path to the data file.

        fmt: str, default: 'UIR'
            Line format to be parsed ('UI', 'UIR', 'UIRT', 'UITup',
            'UIReview', 'UBI', 'UBIT', 'UBITJson', 'SIT', 'SITJson',
            'USIT', or 'USITJson').

        sep: str, default: '\t'
            The delimiter string.

        skip_lines: int, default: 0
            Number of first lines to skip.

        id_inline: bool, default: False
            If `True`, user ids correspond to the line numbers of the file,
            and all tokens in each line are item ids.

        parser: function, default: None
            A function that takes a list of `str` tokens split by `sep` and
            returns a list of tuples to be appended to the final results.
            If `None`, the parser will be determined based on `fmt`.

        Returns
        -------
        tuples: list
            Data in the form of a list of tuples. What is inside each tuple
            depends on `parser` or `fmt`.

        """
        parser = PARSERS.get(fmt, None) if parser is None else parser
        if parser is None:
            raise ValueError(
                "Invalid line format: {}\n"
                "Supported formats: {}".format(fmt, PARSERS.keys())
            )

        with open(fpath, encoding=self.encoding, errors=self.errors) as f:
            tuples = [
                tup
                for idx, line in enumerate(itertools.islice(f, skip_lines, None))
                for tup in parser(
                    line.strip().split(sep),
                    line_idx=idx,
                    id_inline=id_inline,
                    **kwargs
                )
            ]
            tuples = self._filter(tuples=tuples, fmt=fmt)
            if fmt in {"UBI", "UBIT", "UBITJson"}:
                tuples = self._filter_basket(tuples=tuples, fmt=fmt)
            elif fmt in {"SIT", "SITJson", "USIT", "USITJson"}:
                tuples = self._filter_sequence(tuples=tuples, fmt=fmt)
            return tuples
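# A usage sketch for `read` (the file name and its contents are hypothetical):
# given "ratings.txt" whose lines look like "u1<TAB>i1<TAB>4.0",
# >>> reader = Reader()
# >>> data = reader.read("ratings.txt", fmt="UIR", sep="\t")
# >>> data[0]
# ('u1', 'i1', 4.0)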
def read_text(fpath, sep=None, encoding="utf-8", errors=None):
    """Read a text file and return two lists of text documents and
    corresponding ids. If `sep` is None, return only one list whose elements
    are the lines of text in the original file.

    Parameters
    ----------
    fpath: str
        Path to the data file.

    sep: str, default = None
        The delimiter string used to split `id` and `text`. Each line is
        assumed to contain an `id` followed by the corresponding `text`
        document. If `None`, each line will be a `str` in the returned list.

    encoding: str, default = `utf-8`
        Encoding used to decode the file.

    errors: str, default = None
        Optional string that specifies how encoding errors are to be handled.
        Pass 'strict' to raise a ValueError exception if there is an encoding
        error (None has the same effect), or pass 'ignore' to ignore errors.

    Returns
    -------
    texts, ids (optional): list, list
        Return a list of text strings, with a list of corresponding ids if
        `sep` is not None.

    """
    with open(fpath, encoding=encoding, errors=errors) as f:
        if sep is None:
            return [line.strip() for line in f]
        else:
            texts, ids = [], []
            for line in f:
                tokens = line.strip().split(sep)
                ids.append(tokens[0])
                texts.append(sep.join(tokens[1:]))
            return texts, ids
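# A usage sketch for `read_text` (file names are hypothetical): with
# sep="\t", each line "doc_id<TAB>some text ..." is split into an id and a
# text document; with sep=None, the stripped lines are returned as-is.
# >>> texts, ids = read_text("docs.txt", sep="\t")
# >>> lines = read_text("plain.txt")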