Description
Is your feature request related to a problem? Please describe.
I'm always frustrated when I try to use the SAR model with userCol and itemCol as string types. Currently, the SAR model only accepts these columns as integer types, so string IDs must first be converted to integers in an extra preprocessing step. This is cumbersome and time-consuming, especially on large datasets where user and item IDs are naturally represented as strings.
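To make the extra preprocessing step concrete, here is a minimal pure-Python sketch of the round trip it implies: string IDs are encoded to dense integers before training and decoded back afterwards. The helper `build_index` and the sample IDs are hypothetical and only illustrate the idea; on Spark the same mapping is commonly done with `StringIndexer` and `IndexToString` from `pyspark.ml.feature`.

```python
from typing import Dict, Iterable, List, Tuple


def build_index(ids: Iterable[str]) -> Tuple[Dict[str, int], List[str]]:
    """Assign a dense integer index to each distinct string ID, in first-seen order."""
    to_int: Dict[str, int] = {}
    to_str: List[str] = []
    for raw_id in ids:
        if raw_id not in to_int:
            to_int[raw_id] = len(to_str)
            to_str.append(raw_id)
    return to_int, to_str


# Hypothetical interactions with string user and item IDs.
interactions = [
    ("alice", "tt0111161", 1.0),
    ("bob", "tt0068646", 1.0),
    ("alice", "tt0068646", 1.0),
]

user_to_int, int_to_user = build_index(u for u, _, _ in interactions)
item_to_int, int_to_item = build_index(i for _, i, _ in interactions)

# Integer-encoded rows: the shape an integer-only recommender expects.
encoded = [(user_to_int[u], item_to_int[i], r) for u, i, r in interactions]

# After scoring, integer IDs have to be mapped back to the original strings.
decoded = [(int_to_user[u], int_to_item[i], r) for u, i, r in encoded]
```

Both lookup tables have to be kept alongside the trained model so that its output can be translated back, which is the bookkeeping this request aims to avoid.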
Describe the solution you'd like
I would like the SAR model to support userCol and itemCol as string types. Many real-world datasets use string identifiers for users and items, so accepting them directly would remove the extra preprocessing step and make the model easier to use.
Additional context
Example Code
Here is an example of how the feature could be used if implemented:
import io
import zipfile

import pandas as pd
import requests
from pyspark.sql.types import DoubleType, LongType, StringType
from synapse.ml.recommendation import SAR

# Download the MovieLens 25M archive and read ratings.csv into pandas.
url = "http://files.grouplens.org/datasets/movielens/ml-25m.zip"
response = requests.get(url)
with zipfile.ZipFile(io.BytesIO(response.content)) as z:
    with z.open('ml-25m/ratings.csv') as csvfile:
        pdf_ratings = pd.read_csv(csvfile)

# Set every rating to 1.0 to turn the explicit ratings into implicit feedback.
pdf_ratings["rating"] = 1.0

# Convert the pandas DataFrame to a Spark DataFrame.
# `spark` is the SparkSession that the notebook environment provides.
spark_df_ratings = spark.createDataFrame(pdf_ratings)

# Print the schema to check each column's data type.
print("Before casting:")
spark_df_ratings.printSchema()

# Cast the columns explicitly; userId and movieId become strings.
spark_df_ratings = spark_df_ratings.withColumn("userId", spark_df_ratings["userId"].cast(StringType()))
spark_df_ratings = spark_df_ratings.withColumn("movieId", spark_df_ratings["movieId"].cast(StringType()))
spark_df_ratings = spark_df_ratings.withColumn("rating", spark_df_ratings["rating"].cast(DoubleType()))
spark_df_ratings = spark_df_ratings.withColumn("timestamp", spark_df_ratings["timestamp"].cast(LongType()))

# Print the schema again to confirm the casts.
print("After casting:")
spark_df_ratings.printSchema()

# Configure the SAR model.
sar = SAR(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    timeCol="timestamp",
    implicitPrefs=True,
    activityTimeFormat="epoch"
)

# Train the model.
model = sar.fit(spark_df_ratings)