1. Prepare the dataset
Training data: risk_assessment_host.csv, risk_assessment_guest.csv
2. Prepare the configuration files
upload_host.json
{
    "file": "/data/aby_lr/data/risk_assessment_host.csv",
    "head": true,
    "partitions": 20,
    "extend_sid": true,
    "meta": {
        "delimiter": ",",
        "match_id_name": "id"
    },
    "namespace": "experiment",
    "name": "risk_assessment_host"
}
upload_guest.json
{
    "file": "/data/aby_lr/data/risk_assessment_guest.csv",
    "head": true,
    "partitions": 20,
    "extend_sid": true,
    "meta": {
        "delimiter": ",",
        "label_name": "y",
        "match_id_name": "id"
    },
    "namespace": "experiment",
    "name": "risk_assessment_guest"
}
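Before uploading, it can save a round trip to sanity-check the two configs. A small sketch that verifies the fields used above are present (the key list is taken from these two files, not from an exhaustive FATE schema):

```python
import json

# keys that appear in both upload configs above
REQUIRED_KEYS = {"file", "head", "partitions", "extend_sid",
                 "meta", "namespace", "name"}

def check_upload_config(text, expect_label=False):
    """Parse an upload config and verify the fields used in this
    tutorial are present. Returns the parsed dict, raising
    ValueError on problems."""
    cfg = json.loads(text)
    missing = REQUIRED_KEYS - cfg.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    meta = cfg["meta"]
    if "match_id_name" not in meta:
        raise ValueError("meta.match_id_name is required for PSI matching")
    if expect_label and "label_name" not in meta:
        raise ValueError("guest config must declare meta.label_name")
    return cfg
```

Run it on upload_host.json as-is, and on upload_guest.json with expect_label=True, since only the guest side holds the label.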
3. Upload the data
Log in to the servers:
host server IP: 10.248.202.131
guest server IP: 10.248.202.216
Change to the working directory: cd /data/aby_lr/
Initialize the FATE environment: source /data/projects/fate/bin/init_env.sh
Both parties place their data files under /data/aby_lr/data
On the host side, run:
pipeline init --ip 10.248.202.131 --port 9380
flow data upload -c json/upload_host.json
pipeline init --ip 10.248.202.216 --port 9380
Note: the pipeline IP must be switched before the data can be uploaded to the local FATE Flow, and switched back before the training script can run.
On the guest side, run:
flow data upload -c json/upload_guest.json
4. Training script train_lr.py
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import PSI, Reader, SSHELR

# create pipeline for training
pipeline = FateFlowPipeline().set_parties(guest="9999", host="10000")

# read the training datasets
reader_0 = Reader("reader_0")
reader_0.guest.task_parameters(namespace="experiment", name="risk_assessment_guest")
reader_0.hosts[0].task_parameters(namespace="experiment", name="risk_assessment_host")

# private set intersection to align samples across the two parties
psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])

# secret-sharing-based logistic regression
lr_0 = SSHELR(
    "lr_0",
    train_data=psi_0.outputs["output_data"],
    batch_size=10000,     # mini-batch size
    tol=1e-5,             # convergence tolerance
    epochs=1,             # number of training epochs
    learning_rate=0.25,   # learning rate
    early_stop="diff",    # stop when the loss change falls below tol
)

pipeline.add_tasks([reader_0, psi_0, lr_0])

# compile and train
pipeline.compile()
pipeline.fit()
Run python train_lr.py on either the host or the guest.
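To build intuition for the SSHELR parameters above, here is a plaintext analogue of the logistic regression that the component computes under secret sharing. This is only an educational sketch: the real protocol never exposes plaintext features, labels, or gradients, and this version uses full-batch gradient descent rather than mini-batches.

```python
import math
import random

def sigmoid(z):
    # clamp to avoid overflow in math.exp for extreme inputs
    z = max(min(z, 35.0), -35.0)
    return 1.0 / (1.0 + math.exp(-z))

def train_lr(X, y, learning_rate=0.25, epochs=100, tol=1e-5):
    """Plaintext full-batch logistic regression with a 'diff'-style
    early stop: training ends once the loss change drops below tol."""
    n_feat = len(X[0])
    w = [0.0] * n_feat
    b = 0.0
    prev_loss = None
    for _ in range(epochs):
        grad_w = [0.0] * n_feat
        grad_b = 0.0
        loss = 0.0
        for xi, yi in zip(X, y):
            p = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b)
            err = p - yi
            for j, xj in enumerate(xi):
                grad_w[j] += err * xj
            grad_b += err
            loss -= yi * math.log(p + 1e-12) + (1 - yi) * math.log(1 - p + 1e-12)
        n = len(X)
        w = [wj - learning_rate * gj / n for wj, gj in zip(w, grad_w)]
        b -= learning_rate * grad_b / n
        loss /= n
        # early_stop="diff": stop when successive losses differ by < tol
        if prev_loss is not None and abs(prev_loss - loss) < tol:
            break
        prev_loss = loss
    return w, b
```

With learning_rate=0.25 and tol=1e-5 as in the script above, the sketch converges quickly on linearly separable toy data; the real SSHELR additionally shards each quantity into secret shares between guest and host.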
5. View the logs
cat /data/projects/fate/fate_flow/logs/${task_id}/guest/9999/INFO