【图卷积网络】GCN基础原理简单python实现

基础原理讲解

应用路径

卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。

基础原理

其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~1/2A~D~1/2HlWl)
其中:

  • σ \sigma σ 是非线性激活函数
  • D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
  • A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I A A A是原始邻接矩阵, I I I是单位矩阵
  • H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
  • W l W^l Wl是第 l l l层的学习权重矩阵

步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。

加自环的邻接矩阵

A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。

图卷积操作

图像卷积和图卷积
图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X

度矩阵求解

D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
度矩阵的求解

标准化

在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~1/2A~D~1/2X

简单python实现

基于cora数据集实现节点分类

  • cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)

raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)

features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)

# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):
    lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]

# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)

# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))

# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):
    x = map[i] ; y = map[j]
    matrix[x][y] = matrix[y][x] = 1   # 无向图:有引用关系的样本点之间取1

# 查看邻接矩阵的元素
print(matrix.shape)
  • GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        rowsum = torch.sum(adj,dim=1)
        d_inv_sqrt = torch.pow(rowsum,-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
        adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)
        
        out = torch.mm(adj_normalized,x)
        out = self.linear(out)
        return out
class GCN(nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(n_features, n_hidden)
        self.gcn2 = GCNLayer(n_hidden, n_classes)

    def forward(self, x, adj):
        x = self.gcn1(x, adj)
        x = F.relu(x)
        x = self.gcn2(x, adj)
        return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)

features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)

# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))

# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):
    model.train()
    features, labels = features.cuda(), labels.cuda()
    adj = adj.cuda()
    optimizer.zero_grad()
    output = model(features, adj)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 20 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")

参考

cora数据集及简介
图卷积网络详细介绍
GCN讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/773631.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言 指针和数组——指针的算术运算

目录 指针的算术运算 指针加上一个整数 指针减去一个整数 指针相减 指针的关系比较运算 小结 指针的算术运算 指针加上一个整数 指针减去一个整数 指针相减 指针的关系比较运算 小结  指针变量 – 指针类型的变量,保存地址型数据  指针变量与其他类型…

关于SQL NOT IN判断失效的情况记录

1.准备测试数据 CREATE TABLE tmp_1 (val integer);CREATE TABLE tmp_2 (val integer, val2 integer);INSERT INTO tmp_1 (val) VALUES (1); INSERT INTO tmp_1 (val) VALUES (2); INSERT INTO tmp_2 (val) VALUES (1); INSERT INTO tmp_2 (val, val2) VALUES (NULL,0);2.测…

swiftui中设置建议最多5个tabItem项,多个tabItem项会被自动折叠起来

在swiftui中设置底部的菜单栏的时候,最多建议设置5个,如果超过了,会被自动折叠到More中,点击More就会出现类似list的样式显示,不是很友好。 最多按照5个默认设置的话,就会正常全部显示出来: 测…

(七)glDrawArry绘制

几何数据&#xff1a;vao和vbo 材质程序&#xff1a;vs和fs(顶点着色器和片元着色器) 接下来只需要告诉GPU&#xff0c;使用几何数据和材质程序来进行绘制。 #include <glad/glad.h>//glad必须在glfw头文件之前包含 #include <GLFW/glfw3.h> #include <iostrea…

红海云签约海新域集团,产业服务运营领军企业加速人力资源数字化转型

北京海新域城市更新集团有限公司&#xff08;以下简称“海新域集团”&#xff09;是北京市海淀国有资产投资集团有限公司一级监管企业&#xff0c;致力于成为国内领先的产业服务运营商。集团积极探索城市和产业升级新模式&#xff0c;通过对老旧、低效等空间载体重新定位规划、…

2024年新考纲下的PMP考试有多难?全面解读!

一、PMP考试是什么&#xff0c;PMP证书有什么用&#xff1f; PMP&#xff08;Project Management Professional&#xff09;是指项目管理专业人士。PMP考试由美国PMI发起&#xff0c;旨在严格评估项目管理人员的知识和技能&#xff0c;以确定其是否具备高品质的资格认证。 PM…

c进阶篇(四):内存函数

内存函数以字节为单位更改 1.memcpy memcpy 是 C/C 中的一个标准库函数&#xff0c;用于内存拷贝操作。它的原型通常定义在 <cstring> 头文件中&#xff0c;其作用是将一块内存中的数据复制到另一块内存中。 函数原型&#xff1a;void *memcpy(void *dest, const void…

手机如何充当电脑摄像头,新手使用教程分享(新)

手机如何充当电脑摄像头&#xff1f;随着科技的发展&#xff0c;智能手机已经成为我们日常生活中不可或缺的一部分。手机的摄像头除了拍摄记录美好瞬间之外&#xff0c;其实还有个妙用&#xff0c;那就是充当电脑的摄像头。手机摄像头充当电脑摄像头使用的话&#xff0c;我们就…

上位机图像处理和嵌入式模块部署(mcu 项目1:固件编写)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 说完了上位机的开发&#xff0c;接下来就是固件的开发。前面我们说过&#xff0c;目前使用的开发板是极海apm32f103的开发板。它自身包含了iap示例…

Linux - Shell 以及 权限问题

目录 Shell的运行原理 Linux权限问题 Linux权限的概念 如何实现用户账号之间的切换 如何仅提升当前指令的权限 如何将普通用户添加到信任列表 Linux权限管理 文件访问者的分类&#xff08;人&#xff09; 文件类型和访问权限&#xff08;事物属性&#xff09; 文件权限值的表…

【linux高级IO(一)】理解五种IO模型

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到精通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; Linux高级IO 1. 前言2. 重谈对…

Linux之文本三剑客

Linux之三剑客 Linux的三个命令,主要是用来处理文本,grep,sed,awk,处理日志的时候使用的非常多 1 grep 对文本的内容进行查找 1) 基础用法 语法 grep 选项 内容|正则表达式 文件选项: -i 不区分大小写 -v 排除,反选 -n 显示行号 -c 统计个数查看文件里包含有的内容 [roo…

【项目管理】项目风险管理(Word原件)

风险和机会管理就是在一个项目开发过程中对风险进行识别、跟踪、控制的手段。风险和机会管理提供了对可能出现的风险进行持续评估&#xff0c;确定重要的风险机会以及实施处理的策略的一种规范化的环境。包括识别、分析、制定处理和减缓行动、跟踪 。合理的风险和机会管理应尽力…

【TB作品】体重监控系统,ATMEGA16单片机,Proteus仿真

机电荷2018级课程设计题目及要求 题1:电子称重器设计 功能要求: 1)开机显示时间(小时、分)、时分可修改; 2)用滑动变阻器模拟称重传感器(测量范围0- 200g),数码管显示当前重量值,当重量值高于高 值时,红灯长亮; 3)当重量值低于低值时,黄灯长亮; 4)当重量值在正常值时,绿灯亮; 5…

Exploting an API endpoiint using documentation

HTTP request methods https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods 第一步:burp抓包刷新页面 httphistory中只能看到两个记录,可以看下Response,是HTML页面,说明这里有HTML页面 ,但是没有发现特定的API接口。 第二步:用户登录 转到用户登录的功能点处…

Android --- Service

出自于此&#xff0c;写得很清楚。关于Android Service真正的完全详解&#xff0c;你需要知道的一切_android service-CSDN博客 出自【zejian的博客】 什么是Service? Service(服务)是一个一种可以在后台执行长时间运行操作而没有用户界面的应用组件。 服务可由其他应用组件…

【Python】已解决:ValueError: Worksheet named ‘Sheet’ not found

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决&#xff1a;ValueError: Worksheet named ‘Sheet’ not found 一、分析问题背景 在Python编程中&#xff0c;处理Excel文件是一个常见的任务。通常&#xff0c;我们会使用…

DFS之搜索顺序——AcWing 1116. 马走日

DFS之搜索顺序 定义 DFS之搜索顺序是指在执行深度优先搜索时&#xff0c;遍历图或树中节点的策略。具体而言&#xff0c;DFS会沿着一条路径深入到底&#xff0c;当无法继续深入时回溯&#xff0c;然后选择另一条未探索的路径继续深入。搜索顺序直接影响到搜索效率和剪枝的可能…

线性代数|机器学习-P21概率定义和Markov不等式

文章目录 1. 样本期望和方差1.1 样本期望 E ( X ) \mathrm{E}(X) E(X)1.2 样本期望 D ( X ) \mathrm{D}(X) D(X) 2. Markov 不等式&Chebyshev不等式2.1 Markov不等式公式 概述2.2 Markov不等式公式 证明&#xff1a;2.3 Markov不等式公式 举例&#xff1a;2.4 Chebyshev不…

HarmonyOS - 通过.p7b文件获取fingerprint

1、查询工程所对应的 .p7b 文件 通常新工程运行按照需要通过 DevEco Studio 的 Project Structure 勾选 Automatically generate signature 自动生成签名文件&#xff0c;自动生成的 .p7b 文件通常默认在系统用户目录下. 如&#xff1a;C:/Users/zhangsan/.ohos/config/default…