RuntimeError: Assertion cur_target 0 cur_target n_classes failed
生活随笔
收集整理的這篇文章主要介紹了
RuntimeError: Assertion cur_target 0 cur_target n_classes failed
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
問題描述
使用pytorch的函數 torch.nn.CrossEntropyLoss()計算Loss時報錯:
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed
- 1
報錯原因
直觀上看,函數要求目標分類數大于等于0并且小于等于輸入的類別。所以一般而言,都是網絡中輸出的種類數和標簽中設置的種類數量不同造成的。
解決方案
針對于不同原因,主要從兩方面考慮解決。
方向一:模型輸出與分類數不一致
- 看一下模型的輸出尺寸與分類數差異是否明顯,核查代碼是否存在錯誤。
- 如果沒有錯誤,只是映射維度不對,可以考慮在模型的最后一層加一層FC層,將輸出尺寸映射到分類大小。
方向二:標簽的設置不是從0開始
- 如果模型的輸出尺寸與分類數大小相同,看一下標簽的設定是否是從0開始的。
- 如果標簽是從1開始設置的,重新設置標簽。這里存在的坑是:在使用CrossEntropyLoss()這個函數進行驗證時,標簽必須從0開始設置,否則便會報錯。
</div><link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-b6c3c6d139.css" rel="stylesheet"><div class="more-toolbox"><div class="left-toolbox"><ul class="toolbox-list"><li class="tool-item tool-active is-like "><a href="javascript:;"><svg class="icon" aria-hidden="true"><use xlink:href="#csdnc-thumbsup"></use></svg><span class="name">點贊</span><span class="count">2</span></a></li><li class="tool-item tool-active is-collection "><a href="javascript:;" data-report-click="{"mod":"popu_824"}"><svg class="icon" aria-hidden="true"><use xlink:href="#icon-csdnc-Collection-G"></use></svg><span class="name">收藏</span></a></li><li class="tool-item tool-active is-share"><a href="javascript:;" data-report-click="{"mod":"1582594662_002"}"><svg class="icon" aria-hidden="true"><use xlink:href="#icon-csdnc-fenxiang"></use></svg>分享</a></li><!--打賞開始--><!--打賞結束--><li class="tool-item tool-more"><a><svg t="1575545411852" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5717" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><defs><style type="text/css"></style></defs><path d="M179.176 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5718"></path><path d="M509.684 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5719"></path><path d="M846.175 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5720"></path></svg></a><ul class="more-box"><li class="item"><a class="article-report">文章舉報</a></li></ul></li></ul></div></div><div class="person-messagebox"><div class="left-message"><a href="https://blog.csdn.net/m0_37369043"><img src="https://profile.csdnimg.cn/9/9/0/3_m0_37369043" class="avatar_pic" username="m0_37369043"><img src="https://g.csdnimg.cn/static/user-reg-year/2x/3.png" class="user-years"></a></div><div class="middle-message"><div class="title"><span class="tit"><a href="https://blog.csdn.net/m0_37369043" data-report-click="{"mod":"popu_379"}" target="_blank">唐申庚</a></span></div><div class="text"><span>發布了10 篇原創文章</span> · <span>獲贊 17</span> · <span>訪問量 1萬+</span></div></div><div class="right-message"><a href="https://im.csdn.net/im/main.html?userName=m0_37369043" target="_blank" class="btn btn-sm btn-red-hollow bt-button personal-letter">私信</a><a class="btn btn-sm bt-button personal-watch" data-report-click="{"mod":"popu_379"}">關注</a></div></div></div>
</article>
總結
以上是生活随笔為你收集整理的RuntimeError: Assertion cur_target 0 cur_target n_classes failed的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python pandas 如何找到Na
- 下一篇: np.percentile()函数超详解