Learning graph algorithms from source code: StellarGraph, part 1: GraphSAGE
2020-04-29
logi
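To get more familiar with basic graph methods, I read through an excellent open-source project on GitHub, https://github.com/stellargraph/stellargraph , and took some brief notes.
Introduction
The project implements common graph algorithms such as GraphSAGE, GCN and Node2Vec on top of TensorFlow 2.0.
Project structure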
├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── CONTRIBUTORS
├── LICENSE
├── MANIFEST.in
├── README.md
├── RELEASE_PROCEDURE.md
├── codecov.yml
├── demos # examples; the best place to start learning
│ ├── README.md
│ ├── basics
│ ├── calibration
│ ├── community_detection
│ ├── connector
│ ├── embeddings
│ ├── ensembles
│ ├── graph-classification
│ ├── interpretability
│ ├── link-prediction
│ ├── node-classification
│ └── use-cases
├── docker # Docker environments
│ ├── stellargraph
│ ├── stellargraph-ci-runner
│ ├── stellargraph-neo4j
│ └── stellargraph-treon
├── docker-compose.yml
├── docs # documentation
│ ├── Makefile
│ ├── README.md -> ../README.md
│ ├── api.txt
│ ├── conf.py
│ ├── hinsage.txt
│ ├── images
│ ├── index.txt
│ └── requirements.txt
├── meta.yaml
├── pytest.ini
├── requirements.txt
├── scripts
│ ├── README.md
│ ├── format_notebooks.py
│ ├── test_demos.py
│ └── whitespace.sh
├── setup.py
├── stellar-graph-banner.png
├── stellargraph # core library code
│ ├── __init__.py
│ ├── calibration.py
│ ├── connector
│ ├── core
│ ├── data
│ ├── datasets # loads datasets and builds graphs (worth reading)
│ ├── ensemble.py
│ ├── globalvar.py
│ ├── interpretability
│ ├── layer # the model (algorithm) implementations
│ ├── losses.py
│ ├── mapper
│ ├── random.py
│ ├── utils
│ └── version.py
└── tests
├── __init__.py
├── core
├── data
├── datasets
├── interpretability
├── layer
├── mapper
├── reproducibility
├── resources
├── test_calibration.py
├── test_ensemble.py
├── test_losses.py
├── test_random.py
└── test_utils
39 directories, 39 files
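1. Data loading
Taking the directed-graphsage-on-cora-example demo as an example:
from IPython.display import display, HTML
from stellargraph import datasets

dataset = datasets.Cora()
display(HTML(dataset.description))  # show the dataset description in a notebook
G, node_subjects = dataset.load(directed=True)  # directed graph + node labels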
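The dataset.load(directed=True) call reads the data into DataFrames and builds the graph; the entry-point code is in datasets/datasets.py.
The call order is:
load -> _load_cora_or_citeseer (loads the Cora data) -> cls (builds the graph; datasets.py, line 77)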
The loaded data consists of:
node_data: the node features
edgelist: the edge information
These are passed to the graph class as:
graph = cls({"paper": features}, {"cites": edgelist})  # node features and edge information, see datasets.py, line 84
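To make the two inputs concrete, here is a minimal stdlib-only sketch of what _load_cora_or_citeseer does conceptually. It assumes the raw LINQS Cora text layout (cora.content: paper ID, binary word features, subject; cora.cites: cited paper ID, then citing paper ID). The function parse_cora and the dicts/tuples it returns are hypothetical; the real loader uses pandas DataFrames.

```python
def parse_cora(content_text, cites_text):
    """Hypothetical loader for Cora-style text (not StellarGraph's API).

    Returns (features, subjects, edgelist):
      features: {paper_id: [float, ...]}  -> the "paper" node features
      subjects: {paper_id: subject}       -> the node labels
      edgelist: [(source, target), ...]   -> the "cites" edges, citing -> cited
    """
    features, subjects = {}, {}
    for line in content_text.splitlines():
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        paper_id, *words, subject = parts
        features[paper_id] = [float(w) for w in words]
        subjects[paper_id] = subject

    edgelist = []
    for line in cites_text.splitlines():
        parts = line.split()
        if not parts:
            continue
        cited, citing = parts  # assumed column order: cited paper first
        edgelist.append((citing, cited))  # edge points from citing paper to cited paper
    return features, subjects, edgelist


content = "31336\t0\t1\t0\tNeural_Networks\n1061127\t1\t0\t0\tRule_Learning\n"
cites = "31336\t1061127\n"
features, subjects, edgelist = parse_cora(content, cites)
print(edgelist)  # [('1061127', '31336')]
```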
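2. Graph construction
For a directed graph, the cls on line 77 is the StellarDiGraph class.
StellarDiGraph (stellargraph/core/graph.py) builds the graph; load then returns this graph together with the subject (class label) of every node.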
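What the directed sampler later needs from the graph is, for every node, its in-neighbours and its out-neighbours. The sketch below builds those two maps from an edge list. TinyDiGraph is a hypothetical toy class for illustration only; it is not how StellarDiGraph stores its data internally.

```python
from collections import defaultdict


class TinyDiGraph:
    """Hypothetical toy directed graph: node features plus in/out adjacency lists."""

    def __init__(self, features, edgelist):
        self.features = dict(features)
        self.in_adj = defaultdict(list)   # node -> nodes with an edge pointing to it
        self.out_adj = defaultdict(list)  # node -> nodes it points to
        for source, target in edgelist:
            self.out_adj[source].append(target)
            self.in_adj[target].append(source)

    def in_nodes(self, node):
        return list(self.in_adj.get(node, []))

    def out_nodes(self, node):
        return list(self.out_adj.get(node, []))


g = TinyDiGraph({"a": [1.0], "b": [2.0], "c": [3.0]}, [("a", "b"), ("c", "b"), ("b", "c")])
print(g.in_nodes("b"))   # ['a', 'c']
print(g.out_nodes("b"))  # ['c']
```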
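3. Node sampling and feature construction
from stellargraph.mapper import DirectedGraphSAGENodeGenerator

batch_size = 50  # 50 target (head) nodes per training batch
in_samples = [5, 2]  # in-neighbours: sample 5 at hop 1, then 2 per node at hop 2
out_samples = [5, 2]  # out-neighbours: same sizes for outgoing edges
# creates the generator that samples neighbourhoods and builds (shuffled) batches
generator = DirectedGraphSAGENodeGenerator(G, batch_size, in_samples, out_samples)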
Parameters:
G (StellarDiGraph): The machine-learning ready graph.
batch_size (int): Size of batch to return.
in_samples (list): The number of in-node samples per layer (hop) to take.
out_samples (list): The number of out-node samples per layer (hop) to take.
seed (int): [Optional] Random seed for the node sampler.
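With in_samples=[5, 2] and out_samples=[5, 2], every hop splits each group of sampled nodes into an in-neighbour group and an out-neighbour group, so the number of sampled nodes per target node grows quickly. The arithmetic can be checked with a small function. directed_slot_sizes is a hypothetical helper, and the breadth-first slot order [head, in, out, in-in, in-out, out-in, out-out] is my reading of DirectedBreadthFirstNeighbours, so treat it as an assumption.

```python
def directed_slot_sizes(in_samples, out_samples):
    """Sampled-node count per slot of the directed neighbourhood tree (one head node)."""
    if len(in_samples) != len(out_samples):
        raise ValueError("in_samples and out_samples must have the same length")
    sizes = [1]     # slot 0: the head (target) node itself
    frontier = [1]  # slot sizes at the current hop
    for n_in, n_out in zip(in_samples, out_samples):
        next_frontier = []
        for size in frontier:
            next_frontier.append(size * n_in)   # in-neighbours of every node in the slot
            next_frontier.append(size * n_out)  # out-neighbours of every node in the slot
        sizes.extend(next_frontier)
        frontier = next_frontier
    return sizes


sizes = directed_slot_sizes([5, 2], [5, 2])
print(sizes)       # [1, 5, 5, 10, 10, 10, 10]
print(sum(sizes))  # 51
```

So each target node pulls in 51 feature rows, and one batch of 50 target nodes needs 50 × 51 = 2550.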
Source: stellargraph/mapper/sampled_node_generators.py
DirectedGraphSAGENodeGenerator -> DirectedBreadthFirstNeighbours (stellargraph/data/explorer.py; sampler initialization) -> sample_features (sampling) -> returns the features
DirectedBreadthFirstNeighbours
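Samples neighbours hop by hop according to in_samples and out_samples; the sampled nodes are then flattened and their features concatenated into the feature arrays. (This is called from NodeSequence each time a batch is generated.)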
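The sampling and flattening step can be sketched in plain Python. This is a simplified illustration, not the library code: sample_directed and gather_features are hypothetical names, sampling is with replacement, and nodes without neighbours are padded with None and given zero features, which is my understanding of how StellarGraph handles them (an assumption).

```python
import random


def sample_directed(in_adj, out_adj, head_nodes, in_samples, out_samples, seed=None):
    """Directed breadth-first neighbour sampling (simplified sketch).

    Returns flattened slots in breadth-first order:
    [heads, in, out, in-in, in-out, out-in, out-out, ...].
    """
    rng = random.Random(seed)

    def sample(adj, nodes, k):
        sampled = []
        for node in nodes:
            neighbours = adj.get(node, []) if node is not None else []
            if neighbours:
                sampled.extend(rng.choices(neighbours, k=k))  # with replacement
            else:
                sampled.extend([None] * k)  # pad: nothing to sample
        return sampled

    slots = [list(head_nodes)]
    frontier = [list(head_nodes)]
    for k_in, k_out in zip(in_samples, out_samples):
        next_frontier = []
        for nodes in frontier:
            next_frontier.append(sample(in_adj, nodes, k_in))
            next_frontier.append(sample(out_adj, nodes, k_out))
        slots.extend(next_frontier)
        frontier = next_frontier
    return slots


def gather_features(slots, features, dim):
    """Per-slot feature lists; padded (None) nodes get zero vectors."""
    zeros = [0.0] * dim
    return [[features[n] if n is not None else zeros for n in slot] for slot in slots]


# toy graph with edges a->b, c->b, b->d
in_adj = {"b": ["a", "c"], "d": ["b"]}
out_adj = {"a": ["b"], "c": ["b"], "b": ["d"]}
features = {"a": [1.0], "b": [2.0], "c": [3.0], "d": [4.0]}

slots = sample_directed(in_adj, out_adj, ["b"], [2, 1], [2, 1], seed=42)
feats = gather_features(slots, features, dim=1)
print([len(s) for s in slots])  # [1, 2, 2, 2, 2, 2, 2]
```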
4. Training
stellargraph/mapper/sampled_node_generators.py, the flow function (line 103)
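from sklearn import model_selection, preprocessing
from stellargraph.layer import DirectedGraphSAGE
from tensorflow.keras import layers, losses, optimizers, Model

# train/test split and one-hot targets (assumption: mirrors the demo's earlier preprocessing)
train_subjects, test_subjects = model_selection.train_test_split(
    node_subjects, train_size=0.1, test_size=None, stratify=node_subjects
)
train_targets = preprocessing.LabelBinarizer().fit_transform(train_subjects)

# build the batching/shuffling sequence
train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)
# build the GraphSAGE model
graphsage_model = DirectedGraphSAGE(
    layer_sizes=[32, 32], generator=generator, bias=False, dropout=0.5,
)
# build the network: GraphSAGE embeddings followed by a softmax classifier
x_inp, x_out = graphsage_model.in_out_tensors()
prediction = layers.Dense(units=train_targets.shape[1], activation="softmax")(x_out)
model = Model(inputs=x_inp, outputs=prediction)
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.005),
    loss=losses.categorical_crossentropy,
    metrics=["acc"],
)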
Flow:
flow -> NodeSequence (provides batching and shuffling) -> DirectedGraphSAGE -> in_out_tensors()
NodeSequence
Provides batching and shuffling; each batch returns batch_feats, batch_targets.
batch_feats: the features of the sampled nodes
batch_targets: the targets (labels) of the head (center) nodes
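NodeSequence follows the Keras Sequence protocol: __len__ gives the number of batches, __getitem__ returns one (features, targets) batch, and Keras calls on_epoch_end after every epoch, which is where reshuffling happens. Below is a simplified, hypothetical stand-in (TinyNodeSequence); the sample_function argument plays the role of the generator's sample_features, and the exact shuffling details of the real class are an assumption.

```python
import math
import random


class TinyNodeSequence:
    """Hypothetical, simplified stand-in for stellargraph's NodeSequence."""

    def __init__(self, sample_function, ids, targets=None, batch_size=50, shuffle=True, seed=None):
        if targets is not None and len(targets) != len(ids):
            raise ValueError("ids and targets must have the same length")
        self.sample_function = sample_function
        self.ids = list(ids)
        self.targets = targets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        self.indices = list(range(len(self.ids)))
        self.on_epoch_end()

    def __len__(self):
        return math.ceil(len(self.ids) / self.batch_size)  # last batch may be smaller

    def __getitem__(self, batch_num):
        start = batch_num * self.batch_size
        batch_idx = self.indices[start:start + self.batch_size]
        head_ids = [self.ids[i] for i in batch_idx]
        batch_feats = self.sample_function(head_ids)  # sample neighbours + gather features
        batch_targets = None if self.targets is None else [self.targets[i] for i in batch_idx]
        return batch_feats, batch_targets

    def on_epoch_end(self):
        if self.shuffle:
            self.rng.shuffle(self.indices)


# identity "sampler" so the batches are easy to inspect
seq = TinyNodeSequence(lambda ids: ids, ["a", "b", "c", "d", "e"], [0, 1, 2, 3, 4], batch_size=2, seed=0)
for batch_num in range(len(seq)):
    feats, targets = seq[batch_num]
    print(batch_num, feats, targets)
seq.on_epoch_end()  # reshuffle before the next epoch, as Keras would
```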
graphsage_model.in_out_tensors()
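Inputs (x_inp): the features of each group of sampled nodes, wrapped as Keras Input tensors.
Outputs (x_out): obtained by passing the input tensors through apply_layer, which adds the MeanAggregator layers together with dropout.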
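The idea behind stacking the layers can be illustrated with plain lists. Each layer updates every node that still has a hop below it, using its in-children and out-children, so the tree shrinks by one hop per layer (7 slots -> 3 -> 1 for two hops). This sketch assumes a heap-style slot layout where the in-children of slot i sit in slot 2*i + 1 and the out-children in slot 2*i + 2, grouped contiguously per parent. apply_layers and toy_aggregate are hypothetical; the library does this bookkeeping with Keras tensors.

```python
def apply_layers(slots, num_layers, aggregate):
    """Collapse a heap-ordered directed neighbourhood tree by one hop per layer.

    slots[i] holds one feature per node; the in-children of slot i are in slot
    2*i + 1 and its out-children in slot 2*i + 2, grouped contiguously per parent.
    """
    for _ in range(num_layers):
        num_parents = (len(slots) - 1) // 2
        new_slots = []
        for i in range(num_parents):
            parents = slots[i]
            in_children, out_children = slots[2 * i + 1], slots[2 * i + 2]
            k_in = len(in_children) // len(parents)    # in-samples per parent
            k_out = len(out_children) // len(parents)  # out-samples per parent
            new_slots.append([
                aggregate(
                    h,
                    in_children[j * k_in:(j + 1) * k_in],
                    out_children[j * k_out:(j + 1) * k_out],
                )
                for j, h in enumerate(parents)
            ])
        slots = new_slots
    return slots


def toy_aggregate(h_self, h_in, h_out):
    """Stand-in aggregator on scalar features: self + mean(in) + mean(out)."""
    return h_self + sum(h_in) / len(h_in) + sum(h_out) / len(h_out)


# 2 hops, 2 in- and 2 out-samples per hop, every feature equal to 1.0
slots = [[1.0]] + [[1.0] * 2] * 2 + [[1.0] * 4] * 4
print([len(s) for s in slots])                # [1, 2, 2, 4, 4, 4, 4]
print(apply_layers(slots, 2, toy_aggregate))  # [[9.0]]
```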
MeanAggregator
stellargraph/layer/graphsage.py, line 307
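As I read it, MeanAggregator averages each node's sampled neighbour features, transforms the node's own features and each neighbour mean with separate weight matrices, and concatenates the parts before the activation, roughly h' = act([W_self h_v ‖ W_in mean(in) ‖ W_out mean(out)]) for the directed model. The plain-Python sketch below follows that reading. The function names are hypothetical, ReLU is used as the activation, the bias is omitted (matching bias=False in the demo), and how the library splits the output dimensions between the parts is not reproduced.

```python
def matvec(W, x):
    """Dense matrix-vector product W @ x with plain lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    if not vectors:
        raise ValueError("need at least one neighbour vector (padded nodes count as zeros)")
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def mean_aggregate(h_self, neighbour_groups, W_self, W_groups):
    """ReLU(concat(W_self h_self, W_g mean(group g) for each neighbour group g))."""
    parts = matvec(W_self, h_self)
    for h_neigh, W in zip(neighbour_groups, W_groups):
        parts += matvec(W, mean_vector(h_neigh))  # list concatenation, not addition
    return [max(0.0, v) for v in parts]  # ReLU


h_self = [1.0, 0.0]
h_in = [[0.0, 2.0], [2.0, 0.0]]    # mean [1.0, 1.0]
h_out = [[4.0, 0.0], [0.0, -4.0]]  # mean [2.0, -2.0]
W_self = [[1.0, 0.0], [0.0, 1.0]]
W_in, W_out = [[1.0, 1.0]], [[0.0, 1.0]]
print(mean_aggregate(h_self, [h_in, h_out], W_self, [W_in, W_out]))  # [1.0, 0.0, 2.0, 0.0]
```

The out-neighbour part evaluates to -2.0 and is clipped to 0.0 by the ReLU.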