对 torch 中 dim 的总结和理解
生活随笔
收集整理的這篇文章主要介紹了
对 torch 中 dim 的总结和理解
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
pytorch 中,使用到 dim 參數(shù)的 api 都是跟集合有關(guān)的,比如 max(), min(), mean(), softmax() 等。當(dāng)指定某個(gè) dim 時(shí),表示使用該維度的所有元素進(jìn)行集合運(yùn)算,一個(gè) tensor 的 shape 為 (3, 4, 5),分別對(duì)應(yīng)的 dim 如下所示
| 0 | 3 |
| 1 | 4 |
| 2 | 5 |
當(dāng)使用 max(dim=1) 時(shí),表示使用第二個(gè)維度中全部四個(gè)元素中的每個(gè)元素參與求最大值計(jì)算,計(jì)算后的 shape 變?yōu)?(3,5),因?yàn)橹粡?四個(gè)中求得最大的那個(gè)作為結(jié)果。如果 shape 的長(zhǎng)度為 3,則 dim 的取值只能在區(qū)間 [-3, 2],否則將報(bào)錯(cuò)。
Example
>>> a = torch.randn(3,4,5) # 求得第二個(gè)維度的最大值 >>> torch.max(a,1) torch.return_types.max( values=tensor([[0.7700, 0.1390, 0.6952, 1.9428, 0.8477],[1.0085, 0.7961, 0.9462, 2.1287, 0.9356],[1.1520, 2.1478, 0.8291, 1.0854, 0.7780]]), indices=tensor([[1, 1, 2, 2, 0],[1, 2, 2, 3, 0],[0, 1, 3, 3, 3]]))# 第二個(gè)維度縮減為只有一個(gè)元素,即 (3,1,5),api 將維度為 1 的去掉了 >>> torch.max(a,1).values.shape torch.Size([3, 5])# 第三個(gè)維度縮減為只有一個(gè)元素,即 (3,4,1),api 將維度為 1 的去掉了 >>> torch.max(a,2).values.shape torch.Size([3, 4])# 超出 dim 范圍,報(bào)錯(cuò) >>> torch.max(a,3).values.shape Traceback (most recent call last):File "<stdin>", line 1, in <module> IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)總結(jié):
1、dim 是一種集合運(yùn)算的參數(shù),表示將某個(gè)維度的所有元素參與集合運(yùn)算
2、dim 的取值和 shape 的長(zhǎng)度密切相關(guān),dim 的取值為 [-len(shape), len(shape)-1]
總結(jié)
以上是生活随笔為你收集整理的对 torch 中 dim 的总结和理解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 揭秘淘宝搜索量快速暴增的秘密
- 下一篇: 一步一步搭建前端监控系统:如何记录用户行