知识表示学习 TransE 代码逻辑梳理 超详细解析
知識(shí)表示學(xué)習(xí)
網(wǎng)絡(luò)上已經(jīng)存在了大量知識(shí)庫(kù)(KBs),比如OpenCyc,WordNet,Freebase,Dbpedia等等。
這些知識(shí)庫(kù)是為了各種各樣的目的建立的,因此很難用到其他系統(tǒng)上面。為了發(fā)揮知識(shí)庫(kù)的圖(graph)性,也為了得到統(tǒng)計(jì)學(xué)習(xí)(包括機(jī)器學(xué)習(xí)和深度學(xué)習(xí))的優(yōu)勢(shì),我們需要將知識(shí)庫(kù)嵌入(embedding)到一個(gè)低維空間里(比如10、20、50維)。我們都知道,獲得了向量后,就可以運(yùn)用各種數(shù)學(xué)工具進(jìn)行分析。它為許多知識(shí)獲取任務(wù)和下游應(yīng)用鋪平了道路。
總的來(lái)說(shuō),廢話這么多,所謂知識(shí)表示學(xué)習(xí),就是將知識(shí)庫(kù)給映射成向量,同時(shí)滿足屬于同一個(gè)三元組的(h,t,l)滿足h+l≈t,而不是一個(gè)三元組的不滿足這個(gè)條件。
TransE思路
TranE是一篇Bordes等人2013年發(fā)表在NIPS上的文章提出的算法。它的提出,是為了解決多關(guān)系數(shù)據(jù)(multi-relational data)的處理問(wèn)題。我們現(xiàn)在有很多很多的知識(shí)庫(kù)數(shù)據(jù)knowledge bases (KBs),比如Freebase、 Google Knowledge Graph 、 GeneOntology等等。
TransE的直觀含義,就是TransE基于實(shí)體和關(guān)系的分布式向量表示,將每個(gè)三元組實(shí)例(head,relation,tail)中的關(guān)系relation看做從實(shí)體head到實(shí)體tail的翻譯(其實(shí)我一直很納悶為什么叫做translating,其實(shí)就是向量相加),通過(guò)不斷調(diào)整h、r和t(head、relation和tail的向量),使(h + r) 盡可能與 t 相等,即 h + r = t。
損失函數(shù)是
TransE代碼邏輯梳理
首先注明,該代碼不是出自我手,但由于最近需要使用并修改TransE,故從github上找到一個(gè)還不錯(cuò)的TransE實(shí)現(xiàn),對(duì)其進(jìn)行閱讀,并梳理其邏輯,為后續(xù)工作做好鋪墊。貼上其github鏈接,感謝前人辛苦付出。https://github.com/wuxiyu/transE/blob/master/tranE.py
下面對(duì)其代碼進(jìn)行分析。
首先,這里將整個(gè)代碼封裝成了一個(gè)類,該類的構(gòu)造方法(由于平常用的語(yǔ)言是java,python只當(dāng)做工具語(yǔ)言,沒(méi)有系統(tǒng)學(xué)過(guò)語(yǔ)法,所以用詞可能不當(dāng),見(jiàn)諒)中需要的參數(shù)如下所示:
首先,我們將目光放到main方法,從main方法開(kāi)始整個(gè)TransE的旅程。
dirEntity = "C:\\data\\entity2id.txt"entityIdNum, entityList = openDetailsAndId(dirEntity)dirRelation = "C:\\data\\relation2id.txt"relationIdNum, relationList = openDetailsAndId(dirRelation)dirTrain = "C:\\data\\train.txt"tripleNum, tripleList = openTrain(dirTrain)print("打開(kāi)TransE")transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100)print("TranE初始化")transE.initialize()transE.transE(15000)transE.writeRelationVector("c:\\relationVector.txt")transE.writeEntilyVector("c:\\entityVector.txt")首先是通過(guò)三個(gè)Open方法分別獲取實(shí)體數(shù)量和實(shí)體列表、關(guān)系總數(shù)量和關(guān)系列表、三元組總數(shù)量和三元組列表。獲取需要的數(shù)據(jù)。
例如,其中entityList是一個(gè)list,其樣式就為[05451384,04958634,00620424,....];
而relationList樣式為["_member_of_domain_topic","_member_meronym"...];
而tripleList例如[(03964744,04371774,_hyponym), (....)....],其中全是三元組,都是(h,t,l)的格式。
至于那些Num們,都只是用于計(jì)數(shù)?并沒(méi)發(fā)現(xiàn)用在哪里,也不用管
然后就是實(shí)例化TransE這個(gè)類了,將實(shí)體列表,關(guān)系列表,和三元組列表放進(jìn)去,設(shè)置間距γ為1(這個(gè)是超參數(shù),可以調(diào)),然后對(duì)于輸出向量,其維度設(shè)為100(這個(gè)也可以自己指定)。
之后調(diào)用transE的initialize()方法,進(jìn)行初始化。這里初始化具體做了什么呢?答曰初始化向量,構(gòu)建字典集合,分別來(lái)裝實(shí)體向量們和關(guān)系向量們。那么問(wèn)題就來(lái)了,這個(gè)向量如何生成呢,之前我們手里只有05451384這串?dāng)?shù)字來(lái)代表實(shí)體,但是,并沒(méi)有向量啊。這里采用的方式就是···隨機(jī)生成,對(duì)于個(gè)100維的向量,隨機(jī)生成它,方式為每一個(gè)數(shù)字都是在-6/(dim**0.5), 6/(dim**0.5)之間隨機(jī)生成,然后構(gòu)成一個(gè)100個(gè)元素的列表,即代表這個(gè)實(shí)體的向量,同時(shí),將這個(gè)實(shí)體和其對(duì)應(yīng)的隨機(jī)生成的向量放入新創(chuàng)建的字典entityVectorList中去,同理對(duì)于關(guān)系也是如此操作。當(dāng)然,在向量生成之后對(duì)其做一個(gè)歸一化,保證它是單位向量,做法就是每個(gè)元素除以元素總和的平方和的開(kāi)平方,具體見(jiàn)norm方法,這個(gè)很簡(jiǎn)單。
entityVectorList = {}relationVectorList = {}for entity in self.entityList:n = 0entityVector = []while n < self.dim:ram = init(self.dim)entityVector.append(ram) #注意到這里的ram和entity是毫無(wú)關(guān)系的,是一個(gè)隨機(jī)的值,所以這里append之后,就是一個(gè)dim個(gè)元素的列表n += 1entityVector = norm(entityVector)#歸一化entityVectorList[entity] = entityVector至此,我們便為每個(gè)關(guān)系和實(shí)體生成了一個(gè)向量,向量是一個(gè)100維的列表。
然后我們將entityList和relationList賦值成這兩個(gè)字典,也就是我們最初的entityList是列表,而經(jīng)過(guò)初始化之后卻變成了字典,字典的樣式為{實(shí)體名:對(duì)應(yīng)向量,…}
之后,下一步就是進(jìn)行訓(xùn)練了。調(diào)用transE的transE()方法,其中輸入的15000意為迭代的次數(shù)。
for cycleIndex in range(cI):#迭代cI次Sbatch = self.getSample(150) #隨機(jī)獲取150個(gè)三元組Tbatch = []#元組對(duì)(原三元組,打碎的三元組)的列表 :[((h,r,t),(h',r,t'))]for sbatch in Sbatch:#遍歷獲取到的元組,并獲取它們的打碎三元組,從而獲得<=150個(gè)元組對(duì)(防止重復(fù))tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch)) #將sbatch傳入,獲取打碎的三元組,然后構(gòu)成一個(gè)元組對(duì)if(tripletWithCorruptedTriplet not in Tbatch):Tbatch.append(tripletWithCorruptedTriplet)self.update(Tbatch)#對(duì)整個(gè)集合進(jìn)行更新if cycleIndex % 100 == 0:print("第%d次循環(huán)"%cycleIndex)print(self.loss)self.writeRelationVector("c:\\relationVector.txt")self.writeEntilyVector("c:\\entityVector.txt")self.loss = 0這里cI參數(shù)就是迭代次數(shù)。
首先是調(diào)用getSample()方法,該方法作用為在tripleList中隨機(jī)選取size個(gè)三元組并返回。所以這里的Sbatch就是隨機(jī)獲取的150個(gè)三元組。
然后Tbatch是一個(gè)新創(chuàng)建的列表,用于存儲(chǔ)元組(元組是tuple,是python中一種數(shù)據(jù)結(jié)構(gòu),而三元組是知識(shí)圖譜的一種結(jié)構(gòu),不要搞亂了),其中的樣式為[((h,r,t),(h',r,t'))]。
下面就是對(duì)Sbatch進(jìn)行遍歷,遍歷每一個(gè)三元組,調(diào)用getCorruptedTriplet()方法來(lái)獲取某個(gè)三元組的打碎的三元組,也就是在上面算法中提到的,對(duì)一個(gè)三元組,我們假定它是h+l=t的,此時(shí)我們創(chuàng)建一個(gè)范例,一個(gè)絕對(duì)不滿足假設(shè)的,如何創(chuàng)建呢,任意用別的h或t來(lái)替換掉我們這里的h或t,從而得到一個(gè)錯(cuò)誤的三元組,即打碎的三元組(我也不知道為啥叫打碎,不過(guò)挺有意思哈哈哈)。將打碎的三元組和正確的三元組放在一起組成一個(gè)新的元組,然后將其放入Tbatch列表中,當(dāng)然這里有個(gè)去重的判斷,很簡(jiǎn)單,就不說(shuō)了哈。
下面的操作就是最重要的了,進(jìn)行更新。
首先,要明確,這里的更新,只是針對(duì)我們隨機(jī)選出來(lái)的150個(gè)三元組進(jìn)行更新。然后,更新什么呢?當(dāng)然是更新它們的向量,所以假設(shè)我們的h,t都互不相同,那么這里最多也就更新了300個(gè)實(shí)體的向量,(關(guān)系因?yàn)閿?shù)量肯定沒(méi)那么多,就不舉例了)。然后更新的方式是什么,那就是通過(guò)梯度下降法來(lái)求得損失函數(shù)的最小值,從而獲得一個(gè)最優(yōu)的向量們。
好,下面我們來(lái)看這個(gè)更新操作,這里是調(diào)用update()方法,將剛才的Tbatch傳入。
首先在該方法的開(kāi)始,進(jìn)行了兩次拷貝,將實(shí)體列表(其實(shí)是實(shí)體-向量字典)和關(guān)系列表分別進(jìn)行拷貝,目的是為了之后更新,不相互影響。然后關(guān)于deepcopy和copy的區(qū)別大家可以去查一下,簡(jiǎn)單來(lái)說(shuō)就是前者copy的更徹底,列表或字典中的每個(gè)元素都單獨(dú)拷貝了一份。
然后便是遍歷這里的Tbatch,對(duì)每個(gè)元組進(jìn)行操作。
首先是前面是一長(zhǎng)串的賦值操作,選其中一個(gè)來(lái)說(shuō)明。
首先我們知道tripletWithCorruptedTriplet的格式是這樣的[((h,r,t),(h',r,t'))],那[0][0]就是獲取其中的h實(shí)體,然后根據(jù)h實(shí)體在entityList字典中獲取其對(duì)應(yīng)的向量。如此便是,其余也皆是同理。
然后根據(jù)L1參數(shù)是否為true來(lái)使用矩陣1范數(shù)或矩陣2范數(shù),因?yàn)椴煌稊?shù)它的梯度是不一樣的。
我們接下來(lái)矩陣2范數(shù)即L1==false來(lái)進(jìn)行說(shuō)明。此時(shí)進(jìn)行計(jì)算Loss損失函數(shù)的值,根據(jù)公式γ+d(h+l,t)?d(h′+l,t′)\gamma+d(h+l,t)-d(h'+l,t')γ+d(h+l,t)?d(h′+l,t′)來(lái)計(jì)算,當(dāng)然這里的d(h+l,t)d(h+l,t)d(h+l,t)要進(jìn)行展開(kāi),就是普通的距離公式,展開(kāi)之后的Loss函數(shù)為γ+(h+l?t)2?(h′+l?t′)2\gamma+(h+l-t)^{2}-(h'+l-t')^{2}γ+(h+l?t)2?(h′+l?t′)2,等一下,是不是主要到這里和之前說(shuō)的有些不同,對(duì)的,這里沒(méi)有求和符號(hào),因?yàn)檫@里相當(dāng)于是把總的Loss給分開(kāi)算的,所以沒(méi)有求和符號(hào)了。累加起來(lái)便有。
然后當(dāng)這個(gè)損失函數(shù)的值>0時(shí),才進(jìn)行更新,否則不進(jìn)行更新。這里解釋一下為什么這么操作。如此操作的原因在于我們喜歡正確的三元組的向量們滿足h+l≈t,而打碎的三元組不滿足,則正確三元組距離應(yīng)該接近于0,而錯(cuò)誤的應(yīng)為一個(gè)不小的正值(因?yàn)槭蔷仃?范數(shù)),然后此時(shí)必然有損失函數(shù)值e<0的情況。當(dāng)然,你也會(huì)說(shuō)那假如兩個(gè)值都不小,剛好前者小于后者呢,這種情況少,且沒(méi)必要要求這么高,畢竟可以近似,同時(shí)這是算法層級(jí)的問(wèn)題,這里不再討論。
當(dāng)e>0時(shí),我們進(jìn)行更新,這里更新的操作就是一個(gè)很簡(jiǎn)單的梯度下降方法。下面來(lái)介紹一下。首先損失函數(shù)Loss是γ+(h+l?t)2?(h′+l?t′)2\gamma+(h+l-t)^{2}-(h'+l-t')^{2}γ+(h+l?t)2?(h′+l?t′)2,我們對(duì)其h進(jìn)行求導(dǎo)得其梯度,則其結(jié)果是??h=2(h+l?t)\frac{\partial }{\partial h} = 2(h+l-t)?h??=2(h+l?t),則h更新為h?=h?u???h=h?u?2?(h+l?t)=h+u?2?(t?h?l)h^{*}=h-u*\frac{\partial }{\partial h}=h-u*2*(h+l-t)=h+u*2*(t-h-l)h?=h?u??h??=h?u?2?(h+l?t)=h+u?2?(t?h?l),這里的u是梯度下降的步長(zhǎng),也就是上面提到的學(xué)習(xí)率,同理,t的更新也是一樣,t?=t?u?2?(t?h?l)t^{*}=t-u*2*(t-h-l)t?=t?u?2?(t?h?l),然后同理l也是一樣l?=l+u?2?(t?h?l)?u?2?(t′?h′?l)l^{*}=l+u*2*(t-h-l)-u*2*(t'-h'-l)l?=l+u?2?(t?h?l)?u?2?(t′?h′?l)。
如此,進(jìn)行更新,然后進(jìn)行歸一化,最終更新總的entityList和relationList。
至此,更新過(guò)程結(jié)束,至于后面的向量寫(xiě)入文件這里就不贅述了。
完整代碼
這里代碼我都加上了較為詳細(xì)的注釋,可以結(jié)合上面的代碼梳理進(jìn)行理解。
from random import uniform, sample from numpy import * from copy import deepcopyclass TransE:def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):''':param entityList: 實(shí)體列表,讀取文本文件,實(shí)體+id:param relationList: 關(guān)系列表,讀取文本文件,關(guān)系+id:param tripleList: 三元組列表,讀取文本文件,實(shí)體+實(shí)體+關(guān)系:param margin: gamma,目標(biāo)函數(shù)的常數(shù):param learingRate: 學(xué)習(xí)率:param dim: 向量維度,也就是h,t,l向量的維度是1*dim:param L1: 距離公式'''self.margin = marginself.learingRate = learingRateself.dim = dim#向量維度self.entityList = entityList#一開(kāi)始,entityList是entity的list;初始化后,變?yōu)樽值?#xff0c;key是entity,values是其向量(使用narray)。self.relationList = relationList#理由同上self.tripleList = tripleList#理由同上self.loss = 0self.L1 = L1def initialize(self):'''初始化向量'''entityVectorList = {}relationVectorList = {}for entity in self.entityList:n = 0entityVector = []while n < self.dim:ram = init(self.dim)#初始化的范圍entityVector.append(ram) #注意到這里的ram和entity是毫無(wú)關(guān)系的,是一個(gè)隨機(jī)的值,所以這里append之后,就是一個(gè)dim個(gè)元素的列表n += 1entityVector = norm(entityVector)#歸一化entityVectorList[entity] = entityVectorprint("entityVector初始化完成,數(shù)量是%d"%len(entityVectorList))for relation in self. relationList:n = 0relationVector = []while n < self.dim:ram = init(self.dim)#初始化的范圍relationVector.append(ram)n += 1relationVector = norm(relationVector)#歸一化relationVectorList[relation] = relationVectorprint("relationVectorList初始化完成,數(shù)量是%d"%len(relationVectorList))self.entityList = entityVectorListself.relationList = relationVectorListdef transE(self, cI = 20):print("訓(xùn)練開(kāi)始")for cycleIndex in range(cI):#迭代cI次Sbatch = self.getSample(150) #隨機(jī)獲取150個(gè)三元組Tbatch = []#元組對(duì)(原三元組,打碎的三元組)的列表 :{((h,r,t),(h',r,t'))}for sbatch in Sbatch:#遍歷獲取到的元組,并獲取它們的打碎三元組,從而獲得<=150個(gè)元組對(duì)(防止重復(fù))tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch)) #將sbatch傳入,獲取打碎的三元組,然后構(gòu)成一個(gè)元組對(duì)if(tripletWithCorruptedTriplet not in Tbatch):Tbatch.append(tripletWithCorruptedTriplet)self.update(Tbatch)#對(duì)整個(gè)集合進(jìn)行更新if cycleIndex % 100 == 0:print("第%d次循環(huán)"%cycleIndex)print(self.loss)self.writeRelationVector("c:\\relationVector.txt")self.writeEntilyVector("c:\\entityVector.txt")self.loss = 0def getSample(self, size):'''隨機(jī)選取部分三元關(guān)系 sbatch:param size::return:'''return sample(self.tripleList, size) #從tripleList中隨機(jī)獲取size個(gè)元素def getCorruptedTriplet(self, triplet):'''training triplets with either the head or tail replaced by a random entity (but not both at the same time)隨機(jī)替換三元組的實(shí)體,h和t中任意一個(gè)被替換,但不同時(shí)替換。也就是構(gòu)建損壞的三元組集合:param triplet::return corruptedTriplet:'''i = uniform(-1, 1)if i < 0:#小于0,打壞三元組的第一項(xiàng)while True:entityTemp = sample(self.entityList.keys(), 1)[0]if entityTemp != triplet[0]:breakcorruptedTriplet = (entityTemp, triplet[1], triplet[2])else:#大于等于0,打壞三元組的第二項(xiàng)while True:entityTemp = sample(self.entityList.keys(), 1)[0]if entityTemp != triplet[1]:breakcorruptedTriplet = (triplet[0], entityTemp, triplet[2])return corruptedTripletdef update(self, Tbatch):'''進(jìn)行更新,更新的過(guò)程就是一個(gè)梯度下降:param Tbatch::return:'''copyEntityList = deepcopy(self.entityList) #copy和deepcopy的區(qū)別在于,copy只拷貝整體,若局部改變,則拷貝整體的局部也改變,而deepcopy則全部拷貝過(guò)去copyRelationList = deepcopy(self.relationList)for tripletWithCorruptedTriplet in Tbatch:#遍歷整個(gè)元組,最多迭代150次# 這里的索引很好理解((h,t,l)(h',t',l)) 但是copyEntityList[h]# 懂了,這里EntityList是類似于字典的,有id與向量這兩個(gè)東西,所以是輸入id,獲取向量headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元組和打碎的三元組的元組tupletailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]]relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]]headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]]tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]]#下面的也是一模一樣,感覺(jué)只是為了備份一份,進(jìn)行比較headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元組和打碎的三元組的元組tupletailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]]relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]]headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]]tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]]if self.L1:#這L1啥意思···哦是L1范數(shù)distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)else:#否則L2范數(shù)distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)eg = self.margin + distTriplet - distCorruptedTriplet #損失函數(shù) 就跟論文上公式是一樣的if eg > 0: #[function]+ 是一個(gè)取正值的函數(shù) 似乎是只有大于0時(shí)才進(jìn)行更新,想一下,也確實(shí),因?yàn)榍耙粋€(gè)距離應(yīng)該為0,后一個(gè)不為0,然后,0-正<0則不用改,正-正>則需要改self.loss += egif self.L1:#這個(gè)學(xué)習(xí)率有點(diǎn)懵tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)tempPositiveL1 = []tempNegtativeL1 = []for i in range(self.dim):#不知道有沒(méi)有pythonic的寫(xiě)法(比如列表推倒或者numpy的函數(shù))?if tempPositive[i] >= 0:tempPositiveL1.append(1)else:tempPositiveL1.append(-1)if tempNegtative[i] >= 0:tempNegtativeL1.append(1)else:tempNegtativeL1.append(-1)tempPositive = array(tempPositiveL1) tempNegtative = array(tempNegtativeL1)else:#這里學(xué)習(xí)率就是y?對(duì),應(yīng)該這里的學(xué)習(xí)率就是梯度下降中的步長(zhǎng)#然后括號(hào)里是t-h-ltempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)#進(jìn)行更新headEntityVector = headEntityVector + tempPositive #h* = h + 增量tailEntityVector = tailEntityVector - tempPositive #t* = t - 增量relationVector = relationVector + tempPositive - tempNegtative #l* = l +y*2(t-h-l) -y*2(t'-h'-l)headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative #同理tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative #同理#只歸一化這幾個(gè)剛更新的向量,而不是按原論文那些一口氣全更新了copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector)copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector)copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector)copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet)copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet)self.entityList = copyEntityList #進(jìn)行更新self.relationList = copyRelationListdef writeEntilyVector(self, dir):print("寫(xiě)入實(shí)體")entityVectorFile = open(dir, 'w')for entity in self.entityList.keys():entityVectorFile.write(entity+"\t")entityVectorFile.write(str(self.entityList[entity].tolist()))entityVectorFile.write("\n")entityVectorFile.close()def writeRelationVector(self, dir):print("寫(xiě)入關(guān)系")relationVectorFile = open(dir, 'w')for relation in self.relationList.keys():relationVectorFile.write(relation + "\t")relationVectorFile.write(str(self.relationList[relation].tolist()))relationVectorFile.write("\n")relationVectorFile.close()def init(dim):'''向量初始化,隨機(jī)生成值:param dim: 維度:return:'''return uniform(-6/(dim**0.5), 6/(dim**0.5)) #uniform(a, b)#隨機(jī)生成a,b之間的數(shù),左閉右開(kāi)def distanceL1(h, t ,r):s = h + r - tsum = fabs(s).sum()return sumdef distanceL2(h, t, r):'''這里是對(duì)向量進(jìn)行操作的,所以有個(gè)sum:param h: 這里的都是向量:param t::param r::return:'''s = h + r - tsum = (s*s).sum()return sumdef norm(list):'''歸一化:param 向量:return: 向量/向量的能量'''var = linalg.norm(list)i = 0while i < len(list):list[i] = list[i]/vari += 1return array(list)def openDetailsAndId(dir,sp="\t"):idNum = 0list = []with open(dir) as file:lines = file.readlines()for line in lines:DetailsAndId = line.strip().split(sp)list.append(DetailsAndId[0])idNum += 1return idNum, listdef openTrain(dir,sp="\t"):num = 0list = []with open(dir) as file:lines = file.readlines()for line in lines:triple = line.strip().split(sp)if(len(triple)<3):continuelist.append(tuple(triple))num += 1return num, listif __name__ == '__main__':dirEntity = "C:\\data\\entity2id.txt"entityIdNum, entityList = openDetailsAndId(dirEntity)dirRelation = "C:\\data\\relation2id.txt"relationIdNum, relationList = openDetailsAndId(dirRelation)dirTrain = "C:\\data\\train.txt"tripleNum, tripleList = openTrain(dirTrain)print("打開(kāi)TransE")transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100)print("TranE初始化")transE.initialize()transE.transE(15000)transE.writeRelationVector("c:\\relationVector.txt")transE.writeEntilyVector("c:\\entityVector.txt")參考資料
https://blog.csdn.net/u011274209/article/details/50991385
https://blog.csdn.net/jiayalu/article/details/100543909
https://github.com/wuxiyu/transE/blob/master/tranE.py
總結(jié)
以上是生活随笔為你收集整理的知识表示学习 TransE 代码逻辑梳理 超详细解析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 【自然语言处理】【知识图谱】知识图谱表示
- 下一篇: python第一周练习 货币转换