Commit 75b9b3a

add movie recommendation example.
1 parent 345cf7d commit 75b9b3a

5 files changed (+215 -2 lines)

Dockerfile (+2 -2)

```
@@ -1,4 +1,4 @@
-FROM tensorflow/tensorflow:0.12.0
+FROM tensorflow/tensorflow:1.0.0
 
 ENV LANG C.UTF-8
 RUN apt-get update && apt-get install -y bc
@@ -9,7 +9,7 @@ RUN rm -rf /notebooks/*
 
 COPY caicloud.tensorflow /caicloud.tensorflow
 COPY Deep_Learning_with_TensorFlow/datasets /notebooks/Deep_Learning_with_TensorFlow/datasets
-COPY Deep_Learning_with_TensorFlow/0.12.0 /notebooks/Deep_Learning_with_TensorFlow/0.12.0
+COPY Deep_Learning_with_TensorFlow/1.0.0 /notebooks/Deep_Learning_with_TensorFlow/1.0.0
 COPY run_tf.sh /run_tf.sh
 
 CMD ["/run_tf.sh"]
```
README.md (+41)

@@ -0,0 +1,41 @@

# Solving a Recommendation Problem with TensorFlow

## Dataset

This example uses the [MovieLens ratings dataset](http://grouplens.org/datasets/movielens/) to simulate a recommendation problem. Records in the dataset have the following format:

```
1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
1::3408::4::978300275
1::2355::5::978824291
```

Each line contains one user's rating of one movie. For example, the first line says that user 1 rated movie 1193 as 5. The last column is a timestamp, which this example does not use. The goal is, for a given (user, movie) pair, to predict the rating that user would give that movie.
Run the following command to download the data:
```
./download_data.sh
```
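
Once downloaded, the ratings can be loaded for inspection. Below is a minimal sketch using pandas, mirroring the `get_data` function in train.py further down (the path is where download_data.sh unpacks the data):

```python
import pandas as pd

# "::" is a multi-character separator, so pandas needs the Python engine.
ratings = pd.read_csv("/tmp/movielens/ml-1m/ratings.dat",
                      sep="::", header=None, engine="python",
                      names=["user", "item", "rate", "st"])

# MovieLens IDs are 1-based; shift to 0-based so they can index embedding rows.
ratings["user"] -= 1
ratings["item"] -= 1
print(ratings.head())
```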
## Training

The model can be trained locally with the following script:
```
./train_model.sh
```

Running the script produces output similar to the following:
```
Training begins @ 2017-05-18 00:24:33.373159
Eval RMSE at round 0 is: 2.81291127205
Eval RMSE at round 2000 is: 0.945966959
Eval RMSE at round 4000 is: 0.933194696903
Eval RMSE at round 6000 is: 0.927836835384
Eval RMSE at round 8000 is: 0.923974812031
Eval RMSE at round 10000 is: 0.92291110754
Eval RMSE at round 12000 is: 0.919465661049
Eval RMSE at round 14000 is: 0.918680250645
Eval RMSE at round 16000 is: 0.917023718357
Eval RMSE at round 18000 is: 0.915674805641
Eval RMSE at round 20000 is: 0.91452050209
Eval RMSE at round 22000 is: 0.915164649487
```
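
The reported metric is root-mean-square error on the held-out 10% test split, i.e. the square root of the mean squared difference between predicted and actual ratings (see `_rmse` in train.py below). A NumPy equivalent:

```python
import numpy as np

def rmse(predicted, actual):
    # Root-mean-square error over all test (user, movie) pairs.
    return np.sqrt(np.mean(np.square(np.asarray(predicted) - np.asarray(actual))))

print(rmse([4.1, 3.2], [5.0, 3.0]))  # ~0.65
```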
download_data.sh (+7)

@@ -0,0 +1,7 @@

```
#!/usr/bin/env bash

DATA_DIR=/tmp/movielens
SIZE=1m
mkdir -p ${DATA_DIR}
wget http://files.grouplens.org/datasets/movielens/ml-${SIZE}.zip -O ${DATA_DIR}/ml-${SIZE}.zip
unzip ${DATA_DIR}/ml-${SIZE}.zip -d ${DATA_DIR}
```
train.py (+141)

@@ -0,0 +1,141 @@

```python
# coding=utf-8

import numpy as np
import pandas as pd
import tensorflow as tf

from caicloud.clever.tensorflow import dist_base
from caicloud.clever.tensorflow import model_exporter

tf.app.flags.DEFINE_string("export_dir",
                           "/tmp/saved_model/movie",
                           "model export directory path.")
tf.app.flags.DEFINE_integer("batch_size", 128, "training batch size.")
tf.app.flags.DEFINE_integer("embedding_dim", 50, "embedding dimension.")

FLAGS = tf.app.flags.FLAGS

# Number of distinct users and movies in the MovieLens 1M dataset.
USER_NUM = 6040
ITEM_NUM = 3952


def get_data():
    """Loads the ratings, shuffles them, and splits 90%/10% into train/test."""
    col_names = ["user", "item", "rate", "st"]
    df = pd.read_csv("/tmp/movielens/ml-1m/ratings.dat", sep="::",
                     header=None, names=col_names, engine="python")

    # MovieLens IDs are 1-based; shift to 0-based for embedding lookups.
    df["user"] -= 1
    df["item"] -= 1
    for col in ("user", "item"):
        df[col] = df[col].astype(np.int32)
    df["rate"] = df["rate"].astype(np.float32)

    rows = len(df)
    print("Total number of instances: {}".format(rows))
    df = df.iloc[np.random.permutation(rows)].reset_index(drop=True)
    split_index = int(rows * 0.9)
    return df[0:split_index], df[split_index:]


class ShuffleIterator(object):
    """Yields randomly sampled mini-batches of (user, item, rate) columns."""

    def __init__(self, inputs, batch_size=10):
        self.batch_size = batch_size
        self.num_cols = len(inputs)
        self.len = len(inputs[0])
        # Stack the columns into a single (num_rows, num_cols) array.
        self.inputs = np.transpose(np.vstack(
            [np.array(inputs[i]) for i in range(self.num_cols)]))

    def __len__(self):
        return self.len

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        # Sample batch_size row indices with replacement.
        ids = np.random.randint(0, self.len, (self.batch_size,))
        out = self.inputs[ids, :]
        return [out[:, i] for i in range(self.num_cols)]


_train, _test = get_data()
_iter_train = ShuffleIterator([_train["user"], _train["item"], _train["rate"]],
                              batch_size=FLAGS.batch_size)
_train_op = None
_infer = None
_global_step = None
_user_batch = None
_item_batch = None
_rate_batch = None
_cost = None
_rmse = None
_local_step = 0


def inference(user_batch, item_batch, dim):
    """Predicts ratings from concatenated user/item embeddings and a linear layer."""
    w_user = tf.get_variable("embd_user", shape=[USER_NUM, dim],
                             initializer=tf.truncated_normal_initializer(stddev=0.02))
    w_item = tf.get_variable("embd_item", shape=[ITEM_NUM, dim],
                             initializer=tf.truncated_normal_initializer(stddev=0.02))

    # Look up the embedding vector of every user and item in the batch.
    user_embedding = tf.nn.embedding_lookup(w_user, user_batch)
    item_embedding = tf.nn.embedding_lookup(w_item, item_batch)
    inputs = tf.concat([user_embedding, item_embedding], 1)

    w = tf.get_variable("w", shape=[2 * dim, 1],
                        initializer=tf.truncated_normal_initializer(stddev=0.02))
    b = tf.get_variable("b", shape=[1], initializer=tf.constant_initializer(1))
    infer = tf.transpose(tf.matmul(inputs, w) + b, name="infer")
    return infer


def model_fn(sync, num_replicas):
    global _train_op, _infer, _user_batch, _item_batch, _rate_batch, _rmse, _cost, _global_step

    _user_batch = tf.placeholder(tf.int32, shape=[None], name="user")
    _item_batch = tf.placeholder(tf.int32, shape=[None], name="item")
    _rate_batch = tf.placeholder(tf.float32, shape=[None], name="rate")

    _infer = inference(_user_batch, _item_batch, FLAGS.embedding_dim)
    _global_step = tf.contrib.framework.get_or_create_global_step()

    # Per-example squared error; minimize() implicitly sums over the batch.
    _cost = tf.square(_infer - _rate_batch)
    optimizer = tf.train.AdamOptimizer(0.001)
    _train_op = optimizer.minimize(_cost, global_step=_global_step)

    _rmse = tf.sqrt(tf.reduce_mean(_cost))

    def rmse_evaluate_fn(session):
        return session.run(_rmse, feed_dict={
            _user_batch: _test["user"], _item_batch: _test["item"], _rate_batch: _test["rate"]})

    # Define the model export configuration.
    model_export_spec = model_exporter.ModelExportSpec(
        export_dir=FLAGS.export_dir,
        input_tensors={"user": _user_batch, "item": _item_batch},
        output_tensors={"infer": _infer})

    # Define how the evaluation metric (RMSE) is computed.
    model_metric_ops = {
        "rmse": rmse_evaluate_fn
    }

    return dist_base.ModelFnHandler(
        global_step=_global_step,
        optimizer=optimizer,
        model_metric_ops=model_metric_ops,
        model_export_spec=model_export_spec,
        summary_op=None)


def train_fn(session, num_global_step):
    global _train_op, _infer, _user_batch, _item_batch, _rate_batch, _rmse, _local_step, _cost

    # One training step on a random mini-batch.
    users, items, rates = next(_iter_train)
    session.run(_train_op, feed_dict={_user_batch: users, _item_batch: items, _rate_batch: rates})

    # Evaluate RMSE on the test split every 2000 local steps.
    if _local_step % 2000 == 0:
        rmse = session.run(_rmse, feed_dict={
            _user_batch: _test["user"], _item_batch: _test["item"], _rate_batch: _test["rate"]})
        print("Eval RMSE at round {} is: {}".format(num_global_step, rmse))

    _local_step += 1
    return False


if __name__ == "__main__":
    distTfRunner = dist_base.DistTensorflowRunner(model_fn=model_fn, gen_init_fn=None)
    distTfRunner.run(train_fn)
```
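
To make the graph above concrete: `inference` scores a (user, movie) pair by concatenating the two learned embeddings and applying a single linear layer. Below is a NumPy sketch of the same forward pass, using random stand-in weights (`w_user`, `w_item`, `w`, `b` are hypothetical here; in the real model they are the trained variables of the same names):

```python
import numpy as np

def predict_rating(user_id, item_id, w_user, w_item, w, b):
    # Concatenate the user and item embedding vectors: shape (2 * dim,).
    x = np.concatenate([w_user[user_id], w_item[item_id]])
    # The linear layer produces a single scalar rating prediction.
    return float(np.dot(x, w[:, 0]) + b[0])

# Random stand-in weights with the shapes used in train.py (dim = 50).
dim = 50
w_user = 0.02 * np.random.randn(6040, dim)
w_item = 0.02 * np.random.randn(3952, dim)
w = 0.02 * np.random.randn(2 * dim, 1)
b = np.ones(1)

print(predict_rating(0, 1192, w_user, w_item, w, b))  # user 1, movie 1193 (0-based)
```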
train_model.sh (+24)

@@ -0,0 +1,24 @@

```
#!/bin/bash
#
# Copyright 2017 Caicloud authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

rm -rf /tmp/caicloud-dist-tf
rm -rf /tmp/saved_model/movie

export TF_MAX_STEPS=30000
export TF_SAVE_CHECKPOINTS_SECS=60
export TF_SAVE_SUMMARIES_STEPS=1000
python train.py
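```

After training, the model exported to /tmp/saved_model/movie can be used for serving. This commit does not show the internals of caicloud's `model_exporter`, so the sketch below assumes it writes a standard TensorFlow SavedModel with the serving tag; the tensor names ("user", "item", "infer") come from train.py:

```python
import tensorflow as tf

export_dir = "/tmp/saved_model/movie"  # FLAGS.export_dir in train.py

with tf.Session(graph=tf.Graph()) as sess:
    # Assumption: the export is a standard SavedModel with the "serve" tag;
    # the actual tag set used by caicloud's model_exporter may differ.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    # "user", "item" and "infer" are the tensor names defined in train.py.
    infer = sess.graph.get_tensor_by_name("infer:0")
    rating = sess.run(infer, feed_dict={"user:0": [0], "item:0": [1192]})
    print(rating)
```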
