From 769e5b3cb3108e11a939ca7f697ace61e0d7090b Mon Sep 17 00:00:00 2001
From: Cat Tom
Date: Mon, 5 May 2025 02:29:05 +0800
Subject: [PATCH] Data Weaving & Preparation

---
 data_weaving_preparation.py |  36 ++
 movielens_processor.py      | 839 ++++++++++++++++++++++++++++++++++++
 2 files changed, 875 insertions(+)
 create mode 100644 data_weaving_preparation.py
 create mode 100644 movielens_processor.py

diff --git a/data_weaving_preparation.py b/data_weaving_preparation.py
new file mode 100644
index 0000000..528fb6c
--- /dev/null
+++ b/data_weaving_preparation.py
@@ -0,0 +1,36 @@
+"""Driver script: run the MovieLens preparation pipeline end to end."""
+
+from movielens_processor import MovieLensProcessor
+
+
+def main():
+    """Load, clean, analyse, split and persist the MovieLens dataset."""
+    # Initialise the processor and point it at the raw data directory
+    processor = MovieLensProcessor(data_path='./dataset')
+
+    # 1. Load the data
+    processor.load_data(verbose=True)
+
+    # 2. Clean the data
+    processor.clean_data(verbose=True)
+
+    # 3. Analyse data quality
+    quality_metrics = processor.analyze_data_quality(plot=True)
+
+    # 4. Build the rating matrix (sparse representation)
+    rating_sparse = processor.create_rating_matrix(sparse=True)
+
+    # 5. Split into train and test sets (time based)
+    train_ratings, test_ratings = processor.split_train_test(
+        test_ratio=0.2,
+        method='time'
+    )
+
+    # 6. Save the processed data
+    processor.save_processed_data(output_dir='./processed_data')
+
+    return quality_metrics, rating_sparse, train_ratings, test_ratings
+
+
+if __name__ == '__main__':
+    main()
diff --git a/movielens_processor.py b/movielens_processor.py
new file mode 100644
index 0000000..0694cdf
--- /dev/null
+++ b/movielens_processor.py
@@ -0,0 +1,839 @@
+# -*- coding: utf-8 -*-
+"""
+MovieLens dataset weaving and cleaning toolkit
+----------------------------------------
+Author: Chief Data Scientist
+Version: 1.0.0
+"""
+
+import os
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy.sparse import csr_matrix
+from datetime import datetime
+import re
+import warnings
+from tqdm import tqdm
+
+# Configure pandas display options
+pd.set_option('display.max_columns', 50)
+pd.set_option('display.width', 1000)
+warnings.filterwarnings('ignore')
+
+
+class MovieLensProcessor:
+    """Processor for the MovieLens dataset."""
+
+    # Statistic columns that clean_data() merges into users and movies
+    _STAT_COLUMNS = (
+        'rating_count', 'avg_rating', 'rating_std',
+        'min_rating', 'max_rating', 'latest_rating',
+    )
+
+    def __init__(self, data_path='./'):
+        """
+        Initialise the processor.
+
+        Args:
+            data_path (str): Directory containing the data files.
+        """
+        self.data_path = data_path
+        self.ratings = None
+        self.users = None
+        self.movies = None
+        self.rating_matrix = None
+        self.rating_sparse = None
+
+        # ID <-> matrix index mappings, set by create_rating_matrix(sparse=True)
+        self.user_id_map = None
+        self.movie_id_map = None
+        self.user_id_reverse_map = None
+        self.movie_id_reverse_map = None
+
+        # Data quality metrics
+        self.quality_metrics = {}
+
+    def load_data(self, verbose=True):
+        """
+        Load the MovieLens dataset from ``self.data_path``.
+
+        Reads ``ratings.dat``, ``users.dat`` and ``movies.dat`` and adds
+        ``datetime``, ``year``, ``month``, ``dayofweek`` and ``hour``
+        columns to the ratings, derived from the ``timestamp`` column.
+
+        Args:
+            verbose (bool): Whether to print loading progress.
+
+        Returns:
+            MovieLensProcessor: The processor instance itself.
+
+        Raises:
+            Exception: Errors raised while reading a file are printed
+                and re-raised.
+        """
+        if verbose:
+            print("开始加载MovieLens数据集...")
+
+        # Compact dtypes to reduce memory usage
+        ratings_dtypes = {
+            'user_id': np.int32,
+            'movie_id': np.int32,
+            'rating': np.float32,
+            'timestamp': np.int64
+        }
+
+        users_dtypes = {
+            'user_id': np.int32,
+            'gender': 'category',
+            'age': np.int8,
+            'occupation': np.int8,
+            'zip_code': 'category'
+        }
+
+        # Ratings: the largest file, so read it in chunks
+        if verbose:
+            print("加载评分数据 (ratings.dat)...")
+
+        try:
+            ratings_chunks = pd.read_csv(
+                os.path.join(self.data_path, 'ratings.dat'),
+                sep='::',
+                names=['user_id', 'movie_id', 'rating', 'timestamp'],
+                engine='python',
+                dtype=ratings_dtypes,
+                chunksize=250000  # 250K rows per chunk
+            )
+
+            # Collect the chunks, with a progress bar when verbose
+            if verbose:
+                ratings_chunks = tqdm(ratings_chunks, desc="读取评分数据")
+            self.ratings = pd.concat(list(ratings_chunks))
+
+            # Convert the Unix timestamp (seconds) to datetime
+            self.ratings['datetime'] = pd.to_datetime(self.ratings['timestamp'], unit='s')
+
+            # Derive calendar features
+            self.ratings['year'] = self.ratings['datetime'].dt.year
+            self.ratings['month'] = self.ratings['datetime'].dt.month
+            self.ratings['dayofweek'] = self.ratings['datetime'].dt.dayofweek
+            self.ratings['hour'] = self.ratings['datetime'].dt.hour
+
+            if verbose:
+                print(f"评分数据加载完成: {len(self.ratings)} 条记录")
+
+        except Exception as e:
+            print(f"加载评分数据时出错: {e}")
+            raise
+
+        # Users
+        if verbose:
+            print("加载用户数据 (users.dat)...")
+
+        try:
+            self.users = pd.read_csv(
+                os.path.join(self.data_path, 'users.dat'),
+                sep='::',
+                names=['user_id', 'gender', 'age', 'occupation', 'zip_code'],
+                engine='python',
+                dtype=users_dtypes
+            )
+
+            if verbose:
+                print(f"用户数据加载完成: {len(self.users)} 条记录")
+
+        except Exception as e:
+            print(f"加载用户数据时出错: {e}")
+            raise
+
+        # Movies
+        if verbose:
+            print("加载电影数据 (movies.dat)...")
+
+        try:
+            self.movies = pd.read_csv(
+                os.path.join(self.data_path, 'movies.dat'),
+                sep='::',
+                names=['movie_id', 'title', 'genres'],
+                engine='python',
+                encoding='latin-1'  # handle special characters
+            )
+
+            if verbose:
+                print(f"电影数据加载完成: {len(self.movies)} 条记录")
+
+        except Exception as e:
+            print(f"加载电影数据时出错: {e}")
+            raise
+
+        return self
+
+    def clean_data(self, verbose=True):
+        """
+        Clean the dataset and derive user and movie features.
+
+        Removes duplicate (user, movie) ratings keeping the latest one,
+        clips ratings to the range 1-5, extracts the release year from
+        movie titles, one-hot encodes genres, decodes user demographics,
+        drops ratings that reference unknown users or movies, and merges
+        per-user and per-movie rating statistics.
+
+        Calling the method again replaces the merged statistic columns
+        instead of duplicating them.
+
+        Args:
+            verbose (bool): Whether to print progress.
+
+        Returns:
+            MovieLensProcessor: The processor instance itself.
+
+        Raises:
+            ValueError: If the data has not been loaded yet.
+        """
+        if any(df is None for df in [self.ratings, self.users, self.movies]):
+            raise ValueError("请先加载数据 (使用 load_data 方法)")
+
+        if verbose:
+            print("开始数据清洗流程...")
+
+        # 1. Record the original row counts
+        original_counts = {
+            'ratings': len(self.ratings),
+            'users': len(self.users),
+            'movies': len(self.movies)
+        }
+
+        # 2. Detect and handle duplicate records
+        if verbose:
+            print("检测重复记录...")
+
+        # int() keeps numpy scalars out of quality_metrics, so the JSON
+        # report stores numbers instead of stringified values
+        ratings_duplicates = int(
+            self.ratings.duplicated(subset=['user_id', 'movie_id']).sum()
+        )
+        if ratings_duplicates > 0:
+            if verbose:
+                print(f"发现 {ratings_duplicates} 条重复评分记录,保留最新评分...")
+
+            # Among duplicates keep only the most recent rating
+            self.ratings = self.ratings.sort_values('timestamp').drop_duplicates(
+                subset=['user_id', 'movie_id'],
+                keep='last'
+            )
+
+        # 3. Detect and handle outliers
+        if verbose:
+            print("检测评分异常值...")
+
+        # Ratings are expected to lie between 1 and 5
+        invalid_ratings = int(
+            ((self.ratings['rating'] < 1) | (self.ratings['rating'] > 5)).sum()
+        )
+        if invalid_ratings > 0:
+            if verbose:
+                print(f"发现 {invalid_ratings} 条评分超出有效范围 (1-5)...")
+
+            # Clip out-of-range ratings into the valid range
+            self.ratings['rating'] = self.ratings['rating'].clip(1, 5)
+
+        # 4. Extract movie information
+        if verbose:
+            print("处理电影标题和年份信息...")
+
+        # A trailing "(YYYY)" is the release year. Vectorised extraction
+        # does not depend on the frame having a default 0..n-1 index,
+        # unlike positional writes through .at[].
+        year_pattern = r'\((\d{4})\)$'
+        titles = self.movies['title']
+        years = titles.str.extract(year_pattern, expand=False)
+
+        # Nullable integer dtype: titles without a year stay <NA>
+        self.movies['year'] = pd.to_numeric(years).astype('Int64')
+
+        # Strip the year suffix; titles without a year are left untouched
+        without_year = titles.str.replace(year_pattern, '', regex=True).str.strip()
+        self.movies['clean_title'] = titles.where(years.isna(), without_year)
+
+        # 5. Process genre data
+        if verbose:
+            print("处理电影流派信息...")
+
+        # Collect every genre that occurs in the pipe-separated lists
+        all_genres = set()
+        for genres in self.movies['genres']:
+            all_genres.update(genres.split('|'))
+
+        # One indicator column per genre, in sorted order
+        for genre in sorted(all_genres):
+            self.movies[f'genre_{genre}'] = self.movies['genres'].apply(
+                lambda x: 1 if genre in x.split('|') else 0
+            )
+
+        # Number of genres per movie
+        self.movies['genre_count'] = self.movies['genres'].apply(lambda x: len(x.split('|')))
+
+        # 6. Process user features
+        if verbose:
+            print("处理用户人口统计特征...")
+
+        # Encode gender
+        self.users['gender_code'] = self.users['gender'].map({'M': 1, 'F': 0})
+
+        # Map the coded age values to age-group labels
+        age_mapping = {
+            1: 'Under 18',
+            18: '18-24',
+            25: '25-34',
+            35: '35-44',
+            45: '45-49',
+            50: '50-55',
+            56: '56+'
+        }
+
+        self.users['age_group'] = self.users['age'].map(age_mapping)
+
+        # Map occupation codes to names
+        occupation_mapping = {
+            0: 'other',
+            1: 'academic/educator',
+            2: 'artist',
+            3: 'clerical/admin',
+            4: 'college/grad student',
+            5: 'customer service',
+            6: 'doctor/health care',
+            7: 'executive/managerial',
+            8: 'farmer',
+            9: 'homemaker',
+            10: 'K-12 student',
+            11: 'lawyer',
+            12: 'programmer',
+            13: 'retired',
+            14: 'sales/marketing',
+            15: 'scientist',
+            16: 'self-employed',
+            17: 'technician/engineer',
+            18: 'tradesman/craftsman',
+            19: 'unemployed',
+            20: 'writer'
+        }
+
+        self.users['occupation_name'] = self.users['occupation'].map(occupation_mapping)
+
+        # 7. Enforce referential integrity
+        if verbose:
+            print("验证数据引用完整性...")
+
+        # Users and movies referenced by ratings but missing from their tables
+        invalid_users = set(self.ratings['user_id']) - set(self.users['user_id'])
+        invalid_movies = set(self.ratings['movie_id']) - set(self.movies['movie_id'])
+
+        if invalid_users:
+            if verbose:
+                print(f"评分数据中发现 {len(invalid_users)} 个无效用户ID")
+
+            # Drop ratings made by unknown users
+            self.ratings = self.ratings[~self.ratings['user_id'].isin(invalid_users)]
+
+        if invalid_movies:
+            if verbose:
+                print(f"评分数据中发现 {len(invalid_movies)} 个无效电影ID")
+
+            # Drop ratings of unknown movies
+            self.ratings = self.ratings[~self.ratings['movie_id'].isin(invalid_movies)]
+
+        # 8. Compute user and movie statistics
+        if verbose:
+            print("计算用户和电影统计特征...")
+
+        self.users = self._merge_rating_stats(self.users, 'user_id')
+        self.movies = self._merge_rating_stats(self.movies, 'movie_id')
+
+        # 9. Record the cleaning results
+        cleaned_counts = {
+            'ratings': len(self.ratings),
+            'users': len(self.users),
+            'movies': len(self.movies)
+        }
+
+        self.quality_metrics = {
+            'original_counts': original_counts,
+            'cleaned_counts': cleaned_counts,
+            'removed_ratings': original_counts['ratings'] - cleaned_counts['ratings'],
+            'duplicate_ratings': ratings_duplicates,
+            'invalid_ratings': invalid_ratings,
+            'invalid_users': len(invalid_users),
+            'invalid_movies': len(invalid_movies)
+        }
+
+        if verbose:
+            print("数据清洗完成!")
+            print(f"原始评分数: {original_counts['ratings']}, 清洗后: {cleaned_counts['ratings']}")
+            print(f"移除的评分: {self.quality_metrics['removed_ratings']}")
+
+        return self
+
+    def _merge_rating_stats(self, frame, key):
+        """
+        Left-join per-``key`` rating statistics onto ``frame``.
+
+        Args:
+            frame (pd.DataFrame): ``self.users`` or ``self.movies``.
+            key (str): Grouping and join column, e.g. ``'user_id'``.
+
+        Returns:
+            pd.DataFrame: ``frame`` with the statistic columns appended.
+        """
+        stats = self.ratings.groupby(key).agg(
+            rating_count=('rating', 'count'),
+            avg_rating=('rating', 'mean'),
+            rating_std=('rating', 'std'),
+            min_rating=('rating', 'min'),
+            max_rating=('rating', 'max'),
+            latest_rating=('timestamp', 'max')
+        ).reset_index()
+
+        # The standard deviation is NaN for groups with a single rating
+        stats['rating_std'] = stats['rating_std'].fillna(0)
+
+        # Drop statistics left by an earlier clean_data() call; merging on
+        # top of them would produce *_x / *_y duplicate columns
+        stale = [col for col in self._STAT_COLUMNS if col in frame.columns]
+        return pd.merge(frame.drop(columns=stale), stats, on=key, how='left')
+
+    def create_rating_matrix(self, sparse=True):
+        """
+        Build the user-movie rating matrix.
+
+        The dense pivot is always built and stored in ``self.rating_matrix``.
+        With ``sparse=True`` a CSR matrix is stored in ``self.rating_sparse``
+        together with the ID <-> index mappings (``user_id_map``,
+        ``movie_id_map`` and their reverse maps).
+
+        Args:
+            sparse (bool): Whether to build and return the sparse matrix.
+
+        Returns:
+            pd.DataFrame or scipy.sparse.csr_matrix: The rating matrix.
+
+        Raises:
+            ValueError: If no ratings are loaded. ``DataFrame.pivot`` also
+                raises ``ValueError`` when a (user, movie) pair occurs more
+                than once, so run ``clean_data`` first.
+        """
+        if self.ratings is None:
+            raise ValueError("请先加载和清洗数据")
+
+        # Build the dense matrix with pivot; unrated cells are NaN
+        rating_matrix = self.ratings.pivot(
+            index='user_id',
+            columns='movie_id',
+            values='rating'
+        )
+
+        # Keep the dense matrix
+        self.rating_matrix = rating_matrix
+
+        # Sparsity: share of cells without a rating
+        sparsity = 1.0 - rating_matrix.count().sum() / (rating_matrix.shape[0] * rating_matrix.shape[1])
+        print(f"评分矩阵稀疏度: {sparsity:.4f}")
+
+        # Sparse representation, if requested
+        if sparse:
+            sparse_data = self.ratings['rating'].values
+            user_cat = self.ratings['user_id'].astype('category')
+            movie_cat = self.ratings['movie_id'].astype('category')
+            row_indices = user_cat.cat.codes.values
+            col_indices = movie_cat.cat.codes.values
+
+            # A category's position is its matrix index, so the mappings
+            # come straight from the category lists (plain Python ints)
+            self.user_id_map = {
+                int(uid): idx for idx, uid in enumerate(user_cat.cat.categories)
+            }
+            self.movie_id_map = {
+                int(mid): idx for idx, mid in enumerate(movie_cat.cat.categories)
+            }
+
+            # Reverse mappings
+            self.user_id_reverse_map = {v: k for k, v in self.user_id_map.items()}
+            self.movie_id_reverse_map = {v: k for k, v in self.movie_id_map.items()}
+
+            # Build the CSR matrix from (value, (row, col)) triplets
+            self.rating_sparse = csr_matrix(
+                (sparse_data, (row_indices, col_indices)),
+                shape=(len(self.user_id_map), len(self.movie_id_map))
+            )
+
+            print(f"稀疏矩阵大小: {self.rating_sparse.shape}, 非零元素: {self.rating_sparse.nnz}")
+            return self.rating_sparse
+
+        return self.rating_matrix
+
+    def analyze_data_quality(self, plot=True):
+        """
+        Analyse data quality, print a report and return the metrics.
+
+        With ``plot=True`` a 4x2 figure is saved to
+        ``movielens_data_quality.png`` in the working directory and shown.
+
+        Args:
+            plot (bool): Whether to generate the charts.
+
+        Returns:
+            dict: ``self.quality_metrics`` updated with the new metrics.
+
+        Raises:
+            ValueError: If the data has not been loaded and cleaned.
+        """
+        if any(df is None for df in [self.ratings, self.users, self.movies]):
+            raise ValueError("请先加载和清洗数据")
+
+        # The report reads columns that only clean_data() creates; fail
+        # up front instead of with a KeyError halfway through the report
+        required_users = {'age_group', 'occupation_name'}
+        if not required_users <= set(self.users.columns) or 'year' not in self.movies.columns:
+            raise ValueError("请先加载和清洗数据")
+
+        print("===== 数据质量分析报告 =====")
+
+        # 1. Basic statistics
+        print("\n1. 基础统计:")
+        print(f"用户数: {len(self.users)}")
+        print(f"电影数: {len(self.movies)}")
+        print(f"评分数: {len(self.ratings)}")
+
+        # Rating sparsity
+        total_possible = len(self.users) * len(self.movies)
+        actual_ratings = len(self.ratings)
+        sparsity = 1 - (actual_ratings / total_possible)
+        print(f"评分稀疏度: {sparsity:.6f} ({sparsity:.2%})")
+
+        # 2. Rating distribution
+        print("\n2. 评分分布:")
+        rating_counts = self.ratings['rating'].value_counts().sort_index()
+        for rating, count in rating_counts.items():
+            print(f" {rating}星: {count} ({count/len(self.ratings):.2%})")
+
+        # 3. User behaviour
+        print("\n3. 用户行为分析:")
+        user_rating_counts = self.ratings.groupby('user_id')['rating'].count()
+        print(f" 平均每用户评分数: {user_rating_counts.mean():.2f}")
+        print(f" 中位数每用户评分数: {user_rating_counts.median():.2f}")
+        print(f" 最小每用户评分数: {user_rating_counts.min()}")
+        print(f" 最大每用户评分数: {user_rating_counts.max()}")
+
+        # 4. Movie popularity
+        print("\n4. 电影流行度分析:")
+        movie_rating_counts = self.ratings.groupby('movie_id')['rating'].count()
+        print(f" 平均每电影评分数: {movie_rating_counts.mean():.2f}")
+        print(f" 中位数每电影评分数: {movie_rating_counts.median():.2f}")
+        print(f" 最小每电影评分数: {movie_rating_counts.min()}")
+        print(f" 最大每电影评分数: {movie_rating_counts.max()}")
+
+        # Share of rarely rated movies (long tail)
+        low_rated_movies = (movie_rating_counts < 20).sum()
+        print(f" 评分少于20次的电影: {low_rated_movies} ({low_rated_movies/len(movie_rating_counts):.2%})")
+
+        # 5. Time trend
+        print("\n5. 时间趋势分析:")
+        min_date = self.ratings['datetime'].min()
+        max_date = self.ratings['datetime'].max()
+        print(f" 数据时间范围: {min_date.date()} 至 {max_date.date()}")
+
+        # Ratings per calendar month
+        monthly_ratings = self.ratings.set_index('datetime').resample('M')['rating'].count()
+        print(f" 月均评分数: {monthly_ratings.mean():.2f}")
+
+        # 6. User demographics
+        print("\n6. 用户人口统计:")
+        gender_dist = self.users['gender'].value_counts()
+        print(f" 性别分布: 男 {gender_dist.get('M', 0)} ({gender_dist.get('M', 0)/len(self.users):.2%}), "
+              f"女 {gender_dist.get('F', 0)} ({gender_dist.get('F', 0)/len(self.users):.2%})")
+
+        age_dist = self.users['age_group'].value_counts().sort_index()
+        print(" 年龄分布:")
+        for age_group, count in age_dist.items():
+            print(f" {age_group}: {count} ({count/len(self.users):.2%})")
+
+        # 7. Genre analysis
+        print("\n7. 电影类型分析:")
+        genre_columns = [col for col in self.movies.columns if col.startswith('genre_') and col != 'genre_count']
+        genre_counts = self.movies[genre_columns].sum().sort_values(ascending=False)
+
+        print(" 流派分布 (前10):")
+        for genre, count in genre_counts.iloc[:10].items():
+            genre_name = genre.replace('genre_', '')
+            print(f" {genre_name}: {count} ({count/len(self.movies):.2%})")
+
+        # Generate the charts
+        if plot:
+            print("\n生成数据质量可视化图表...")
+
+            # 4x2 grid layout
+            fig, axs = plt.subplots(4, 2, figsize=(18, 20))
+            fig.suptitle('MovieLens数据集质量分析', fontsize=16)
+
+            # 1. Rating distribution histogram
+            sns.histplot(self.ratings['rating'], bins=9, kde=True, ax=axs[0, 0])
+            axs[0, 0].set_title('评分分布')
+            axs[0, 0].set_xlabel('评分值')
+            axs[0, 0].set_ylabel('频次')
+
+            # 2. Ratings per user
+            sns.histplot(user_rating_counts, bins=50, kde=True, ax=axs[0, 1])
+            axs[0, 1].set_title('用户评分数量分布')
+            axs[0, 1].set_xlabel('每用户评分数')
+            axs[0, 1].set_ylabel('用户数')
+            axs[0, 1].set_yscale('log')
+
+            # 3. Ratings per movie
+            sns.histplot(movie_rating_counts, bins=50, kde=True, ax=axs[1, 0])
+            axs[1, 0].set_title('电影评分数量分布')
+            axs[1, 0].set_xlabel('每电影评分数')
+            axs[1, 0].set_ylabel('电影数')
+            axs[1, 0].set_yscale('log')
+
+            # 4. Monthly rating volume
+            monthly_ratings.plot(ax=axs[1, 1])
+            axs[1, 1].set_title('月度评分数量趋势')
+            axs[1, 1].set_xlabel('日期')
+            axs[1, 1].set_ylabel('评分数')
+
+            # 5. User age distribution
+            sns.countplot(y='age_group', data=self.users, ax=axs[2, 0], order=age_dist.index)
+            axs[2, 0].set_title('用户年龄分布')
+            axs[2, 0].set_xlabel('用户数')
+            axs[2, 0].set_ylabel('年龄组')
+
+            # 6. User occupation distribution
+            occupation_counts = self.users['occupation_name'].value_counts().sort_values(ascending=False).head(15)
+            sns.barplot(y=occupation_counts.index, x=occupation_counts.values, ax=axs[2, 1])
+            axs[2, 1].set_title('用户职业分布 (前15)')
+            axs[2, 1].set_xlabel('用户数')
+            axs[2, 1].set_ylabel('职业')
+
+            # 7. Genre distribution
+            top_genres = genre_counts.head(15).sort_values()
+            sns.barplot(y=top_genres.index.str.replace('genre_', ''), x=top_genres.values, ax=axs[3, 0])
+            axs[3, 0].set_title('电影流派分布 (前15)')
+            axs[3, 0].set_xlabel('电影数')
+            axs[3, 0].set_ylabel('流派')
+
+            # 8. Release decade distribution. right=False makes each bin
+            # [start, end), so 1960 is counted under '1960s'; the default
+            # right-closed bins would file it under '1950s'. The open-ended
+            # last bin keeps movies released in 2000 or later.
+            release_years = self.movies['year'].dropna().astype(int)
+            decade_bins = pd.cut(
+                release_years,
+                bins=[1900, 1950, 1960, 1970, 1980, 1990, 2000, np.inf],
+                labels=['1950前', '1950s', '1960s', '1970s', '1980s', '1990s', '2000s'],
+                right=False
+            )
+            decade_counts = decade_bins.value_counts().sort_index()
+            sns.barplot(x=decade_counts.index, y=decade_counts.values, ax=axs[3, 1])
+            axs[3, 1].set_title('电影发行年代分布')
+            axs[3, 1].set_xlabel('年代')
+            axs[3, 1].set_ylabel('电影数')
+
+            plt.tight_layout(rect=[0, 0, 1, 0.97])
+            plt.savefig('movielens_data_quality.png', dpi=300, bbox_inches='tight')
+            plt.show()
+
+        # Build and return the quality metrics
+        quality_metrics = {
+            'users_count': len(self.users),
+            'movies_count': len(self.movies),
+            'ratings_count': len(self.ratings),
+            'sparsity': sparsity,
+            'avg_rating': self.ratings['rating'].mean(),
+            'median_rating': self.ratings['rating'].median(),
+            'ratings_per_user': {
+                'mean': user_rating_counts.mean(),
+                'median': user_rating_counts.median(),
+                'min': user_rating_counts.min(),
+                'max': user_rating_counts.max()
+            },
+            'ratings_per_movie': {
+                'mean': movie_rating_counts.mean(),
+                'median': movie_rating_counts.median(),
+                'min': movie_rating_counts.min(),
+                'max': movie_rating_counts.max()
+            },
+            'rating_distribution': rating_counts.to_dict(),
+            'low_rated_movies_ratio': low_rated_movies/len(movie_rating_counts),
+            'date_range': {
+                'start': min_date.strftime('%Y-%m-%d'),
+                'end': max_date.strftime('%Y-%m-%d'),
+                'days': (max_date - min_date).days
+            },
+            'gender_ratio': {
+                'male': gender_dist.get('M', 0) / len(self.users),
+                'female': gender_dist.get('F', 0) / len(self.users),
+            }
+        }
+
+        # Merge into the metrics kept on the instance
+        self.quality_metrics.update(quality_metrics)
+
+        return self.quality_metrics
+
+    def split_train_test(self, test_ratio=0.2, method='time', seed=42):
+        """
+        Split the ratings into a training set and a test set.
+
+        With 'time', ratings at or after the ``1 - test_ratio`` timestamp
+        quantile form the test set. With 'user', a random share of each
+        user's ratings goes to the test set while 10 ratings per user are
+        kept for training; users with 10 or fewer ratings stay entirely in
+        the training set. Any other value splits all ratings at random.
+
+        Args:
+            test_ratio (float): Test set share (between 0 and 1).
+            method (str): Split method ('random', 'time', 'user').
+            seed (int): Random seed.
+
+        Returns:
+            tuple: (train_ratings, test_ratings)
+
+        Raises:
+            ValueError: If no ratings are loaded.
+        """
+        if self.ratings is None:
+            raise ValueError("请先加载数据")
+
+        # A local RandomState yields the same draws as seeding the global
+        # numpy generator, without resetting state other code may rely on
+        rng = np.random.RandomState(seed)
+
+        if method == 'time':
+            # Time-based split
+            time_threshold = self.ratings['timestamp'].quantile(1 - test_ratio)
+            train_ratings = self.ratings[self.ratings['timestamp'] < time_threshold]
+            test_ratings = self.ratings[self.ratings['timestamp'] >= time_threshold]
+
+        elif method == 'user':
+            # Split each user's ratings. The parts are collected in lists
+            # and concatenated once instead of growing a frame per user.
+            train_parts = []
+            test_parts = []
+
+            for _, user_ratings in self.ratings.groupby('user_id'):
+                n_user_ratings = len(user_ratings)
+                n_test = int(test_ratio * n_user_ratings)
+
+                # At least one test rating while keeping 10 for training.
+                # The outer max() stops the size going negative for users
+                # with fewer than 10 ratings, which made the draw raise.
+                n_test = max(0, min(max(1, n_test), n_user_ratings - 10))
+
+                # Randomly pick the test ratings
+                test_indices = rng.choice(user_ratings.index, size=n_test, replace=False)
+
+                test_parts.append(user_ratings.loc[test_indices])
+                train_parts.append(user_ratings.drop(test_indices))
+
+            if train_parts:
+                train_ratings = pd.concat(train_parts)
+                test_ratings = pd.concat(test_parts)
+            else:
+                # No ratings at all: empty frames with the usual columns
+                train_ratings = self.ratings.iloc[0:0]
+                test_ratings = self.ratings.iloc[0:0]
+
+        else:  # Random split (also used for unrecognised methods)
+            shuffled_indices = rng.permutation(len(self.ratings))
+            test_size = int(test_ratio * len(self.ratings))
+
+            test_indices = shuffled_indices[:test_size]
+            train_indices = shuffled_indices[test_size:]
+
+            train_ratings = self.ratings.iloc[train_indices]
+            test_ratings = self.ratings.iloc[test_indices]
+
+        # max() avoids a division by zero when there are no ratings
+        total = max(len(self.ratings), 1)
+        print(f"训练集: {len(train_ratings)} 评分 ({len(train_ratings)/total:.2%})")
+        print(f"测试集: {len(test_ratings)} 评分 ({len(test_ratings)/total:.2%})")
+
+        # Check the split: test users/movies unseen in training
+        missing_users = set(test_ratings['user_id']) - set(train_ratings['user_id'])
+        missing_movies = set(test_ratings['movie_id']) - set(train_ratings['movie_id'])
+
+        if missing_users:
+            print(f"警告: 测试集中有 {len(missing_users)} 个用户在训练集中不存在")
+
+        if missing_movies:
+            print(f"警告: 测试集中有 {len(missing_movies)} 部电影在训练集中不存在")
+
+        return train_ratings, test_ratings
+
+    def save_processed_data(self, output_dir='./processed'):
+        """
+        Save the processed data to ``output_dir``.
+
+        Writes ratings, users and movies as CSV; the dense rating matrix
+        (pickle), the sparse matrix (``.npz``) and the ID mappings
+        (``.npy``) when they exist; and the quality metrics as JSON.
+
+        Args:
+            output_dir (str): Output directory, created if missing.
+
+        Returns:
+            bool: True on success, False if writing failed (the error is
+                printed rather than raised).
+
+        Raises:
+            ValueError: If the data has not been loaded.
+        """
+        if any(df is None for df in [self.ratings, self.users, self.movies]):
+            raise ValueError("请先加载和清洗数据")
+
+        def to_json_native(value):
+            # numpy scalars (int64 counts, float32 means, ...) are not JSON
+            # serialisable; .item() converts them to Python numbers so the
+            # report stores numbers instead of strings
+            if isinstance(value, np.generic):
+                return value.item()
+            return str(value)
+
+        # Create the output directory
+        os.makedirs(output_dir, exist_ok=True)
+
+        try:
+            # Save the DataFrames as CSV
+            self.ratings.to_csv(os.path.join(output_dir, 'ratings_processed.csv'), index=False)
+            self.users.to_csv(os.path.join(output_dir, 'users_processed.csv'), index=False)
+            self.movies.to_csv(os.path.join(output_dir, 'movies_processed.csv'), index=False)
+
+            # Save the dense rating matrix
+            if self.rating_matrix is not None:
+                self.rating_matrix.to_pickle(os.path.join(output_dir, 'rating_matrix.pkl'))
+
+            # Save the sparse matrix (if it was created)
+            if self.rating_sparse is not None:
+                import scipy.sparse as sp
+                sp.save_npz(os.path.join(output_dir, 'rating_sparse.npz'), self.rating_sparse)
+
+            # Save the ID mappings (set together with the sparse matrix).
+            # np.save pickles a dict, so loading them back requires
+            # np.load(..., allow_pickle=True).
+            if self.user_id_map is not None and self.movie_id_map is not None:
+                np.save(os.path.join(output_dir, 'user_id_map.npy'), self.user_id_map)
+                np.save(os.path.join(output_dir, 'movie_id_map.npy'), self.movie_id_map)
+
+            # Save the quality report
+            if self.quality_metrics:
+                import json
+                path = os.path.join(output_dir, 'quality_metrics.json')
+                with open(path, 'w', encoding='utf-8') as f:
+                    json.dump(self.quality_metrics, f, indent=4, default=to_json_native)
+
+            print(f"处理后的数据已保存至 {output_dir}")
+            return True
+
+        except Exception as e:
+            print(f"保存数据时出错: {e}")
+            return False