# Source code for neer_match_utilities.split

import pandas as pd

class SplitError(Exception):
    """Custom exception for errors in data splitting."""
def _reset_with_original_index(df: pd.DataFrame):
    """Return a positionally re-indexed copy of `df` and a label -> position map."""
    df = df.copy()  # never write into a slice of the caller's frame
    df['index_original'] = df.index
    df = df.reset_index(drop=True)
    return df, dict(zip(df['index_original'], df.index))


def _build_subset(matches_part: pd.DataFrame, left: pd.DataFrame, right: pd.DataFrame):
    """Select the rows referenced by `matches_part` and remap its ids to new positions."""
    left_part, left_map = _reset_with_original_index(
        left[left.index.isin(matches_part['left'])]
    )
    right_part, right_map = _reset_with_original_index(
        right[right.index.isin(matches_part['right'])]
    )

    matches_part = matches_part.copy()
    matches_part['left'] = matches_part['left'].map(left_map)
    matches_part['right'] = matches_part['right'].map(right_map)
    matches_part = (
        matches_part.sort_values(by='left', ascending=True).reset_index(drop=True)
    )
    return left_part, right_part, matches_part


def split_test_train(
    left: pd.DataFrame,
    right: pd.DataFrame,
    matches: pd.DataFrame,
    test_ratio: float = 0.3,
    validation_ratio: float = 0.1,
    random_state=None,
):
    """
    Split datasets into training, validation, and testing subsets.

    Only observations from `left` and `right` that are referenced in the
    `matches` DataFrame are included in the split. The rows of `matches` are
    shuffled once and cut into three disjoint parts; for each part the
    referenced `left` / `right` rows are selected, re-indexed from 0, and given
    an ``index_original`` column holding their previous index label. The
    ``left`` / ``right`` columns of each returned matches frame refer to the
    new positional indices of the corresponding subset.

    Parameters
    ----------
    left : pd.DataFrame
        The left dataset to split.
    right : pd.DataFrame
        The right dataset to split.
    matches : pd.DataFrame
        A DataFrame containing matching pairs between the left and right
        datasets. It must include columns 'left' and 'right', referencing
        indices in `left` and `right`.
    test_ratio : float, optional
        The proportion of the matches to be used for testing (default is 0.3).
    validation_ratio : float, optional
        The proportion of the matches to be used for validation (default is 0.1).
    random_state : optional
        Seed or random generator forwarded to ``DataFrame.sample`` to make the
        split reproducible. The default ``None`` keeps the split
        non-deterministic.

    Returns
    -------
    tuple
        A tuple containing:
        - left_train : pd.DataFrame
        - right_train : pd.DataFrame
        - matches_train : pd.DataFrame
        - left_validation : pd.DataFrame
        - right_validation : pd.DataFrame
        - matches_validation : pd.DataFrame
        - left_test : pd.DataFrame
        - right_test : pd.DataFrame
        - matches_test : pd.DataFrame

    Raises
    ------
    SplitError
        If the requested ratios yield a negative number of observations for
        any subset (a negative ratio, or ratios whose counts exceed the number
        of matches).

    Notes
    -----
    Match rows that reference a label missing from `left` or `right` are kept;
    their remapped id becomes NaN because the label has no new position.
    """
    # Restrict both sides to records that take part in at least one match.
    left = left[left.index.isin(matches['left'])]
    right = right[right.index.isin(matches['right'])]

    total = len(matches)
    matches_test_count = int(round(test_ratio * total))
    matches_validation_count = int(round(validation_ratio * total))
    # Training receives the remainder, so the three counts always sum to `total`;
    # the only way the request can be infeasible is a negative count.
    matches_train_count = total - matches_test_count - matches_validation_count

    if min(matches_test_count, matches_validation_count, matches_train_count) < 0:
        raise SplitError(
            "Split does not work correctly: test_ratio and validation_ratio must be "
            "non-negative and must not request more observations than matches contains "
            f"(test={matches_test_count}, validation={matches_validation_count}, "
            f"train={matches_train_count}, matches={total})"
        )

    # Shuffle once and cut by position. Slicing by position (rather than
    # excluding already-drawn index labels) keeps the parts disjoint even when
    # the index of `matches` is not unique.
    shuffled = matches.sample(frac=1, random_state=random_state)
    test_end = matches_test_count
    validation_end = test_end + matches_validation_count

    left_test, right_test, matches_test = _build_subset(
        shuffled.iloc[:test_end], left, right
    )
    left_validation, right_validation, matches_validation = _build_subset(
        shuffled.iloc[test_end:validation_end], left, right
    )
    left_train, right_train, matches_train = _build_subset(
        shuffled.iloc[validation_end:], left, right
    )

    return (
        left_train, right_train, matches_train,
        left_validation, right_validation, matches_validation,
        left_test, right_test, matches_test,
    )