当前位置: 首页> 健康> 母婴 > 咨询公司经营范围大全_供应链管理的主要内容_网站优化员seo招聘_什么叫外链

咨询公司经营范围大全_供应链管理的主要内容_网站优化员seo招聘_什么叫外链

时间:2025/7/8 22:13:50来源:https://blog.csdn.net/sinat_41942180/article/details/143023898 浏览次数:0次
咨询公司经营范围大全_供应链管理的主要内容_网站优化员seo招聘_什么叫外链

任务类型:单任务多类别分类任务  
用途:`Planetoid` 数据集用于处理文献引用图(如 Cora、Citeseer 和 Pubmed),每个节点表示一篇论文,边表示引用关系,任务目标是根据节点特征和图结构,对节点进行多类别分类,如将论文分为不同的主题类别。

是半监督节点分类的经典基准数据集之一,常用于测试图卷积网络(GCN)等图神经网络模型。在此任务中,给定图中部分节点的标签,模型需要基于图结构和节点特征预测未标记节点的类别。

from helpers.dataset_classes.classic_datasets import Planetoid

import torch
import pickle as pkl
import sys
import networkx as nx
import numpy as np
import scipy.sparse as sp
import os.path as ospfrom typing import Optional, Callable, List
from torch_geometric.data import InMemoryDataset, Data
from torch_sparse import coalesce
from torch_geometric.utils.undirected import to_undirected
from torch_geometric.utils import remove_self_loopsclass Planetoid(InMemoryDataset):def __init__(self, root: str, name: str,transform: Optional[Callable] = None,pre_transform: Optional[Callable] = None):self.name = namesuper().__init__(root, transform, pre_transform)self.data, self.slices = torch.load(self.processed_paths[0])data = self.get(0)self.data, self.slices = self.collate([data])@propertydef raw_dir(self) -> str:return osp.join(self.root, self.name, 'raw')@propertydef processed_dir(self) -> str:return osp.join(self.root, self.name, 'processed')@propertydef raw_file_names(self) -> List[str]:names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']return [f'ind.{self.name.lower()}.{name}' for name in names]@propertydef processed_file_names(self) -> str:return 'data.pt'def download(self):passdef process(self):data = full_load_citation(self.name, self.raw_dir)data = data if self.pre_transform is None else self.pre_transform(data)torch.save(self.collate([data]), self.processed_paths[0])def __repr__(self) -> str:return f'{self.name}()'def parse_index_file(filename):"""Code taken from https://github.com/Yujun-Yan/Heterophily_and_oversmoothing/blob/main/process.py#L18""Parse index file."""index = []for line in open(filename):index.append(int(line.strip()))return indexdef full_load_citation(dataset_str, raw_dir):"""Code adapted from https://github.com/Yujun-Yan/Heterophily_and_oversmoothing/blob/main/process.py#L33"""names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']objects = []for i in range(len(names)):path = osp.join(raw_dir, "ind.{}.{}".format(dataset_str, names[i]))with open(path, 'rb') as f:if sys.version_info > (3, 0):objects.append(pkl.load(f, encoding='latin1'))else:objects.append(pkl.load(f))x, y, tx, ty, allx, ally, graph = tuple(objects)test_idx_reorder = parse_index_file(osp.join(raw_dir, "ind.{}.test.index".format(dataset_str)))test_idx_range = np.sort(test_idx_reorder)test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))if len(test_idx_range_full) != len(test_idx_range):# Fix citeseer dataset (there are some isolated nodes in the graph)# Find isolated nodes, add them as zero-vecs into the right position, mark them# Follow H2GCN codetx_extended[test_idx_range - min(test_idx_range), :] = txtx = tx_extendedty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))ty_extended[test_idx_range - min(test_idx_range), :] = tyty = ty_extendednon_valid_samples = set(test_idx_range_full) - set(test_idx_range)else:non_valid_samples = set()features = sp.vstack((allx, tx)).tolil()features[test_idx_reorder, :] = features[test_idx_range, :]adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))labels = np.vstack((ally, ty))labels[test_idx_reorder, :] = labels[test_idx_range, :]non_valid_samples = list(non_valid_samples.union(set(list(np.where(labels.sum(1) == 0)[0]))))labels = np.argmax(labels, axis=-1)features = features.todense()# Prepare in PyTorch Geometric Formatsparse_mx = sp.coo_matrix(adj).astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))shape = torch.Size(sparse_mx.shape)edge_index, _ = coalesce(indices, None, shape[0], shape[1])# Remove self-loopsedge_index, _ = remove_self_loops(edge_index)# Make the graph undirectededge_index = to_undirected(edge_index)assert (np.array_equal(np.unique(labels), np.arange(len(np.unique(labels)))))features = torch.FloatTensor(features)labels = torch.LongTensor(labels)non_valid_samples = torch.LongTensor(non_valid_samples)return Data(x=features, edge_index=edge_index, y=labels, num_node_features=features.size(1),non_valid_samples=non_valid_samples)

1. 导入模块

import torch
import pickle as pkl
import sys
import networkx as nx
import numpy as np
import scipy.sparse as sp
import os.path as ospfrom typing import Optional, Callable,
关键字:咨询公司经营范围大全_供应链管理的主要内容_网站优化员seo招聘_什么叫外链

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: