MXNET源码中NDArray数据的获取和打印
生活随笔
收集整理的這篇文章主要介紹了
MXNET源码中NDArray数据的获取和打印
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
雖然本人也很想寫一個系列的分析文章,奈何水平不足,零零碎碎學到一點就寫一點吧
本人是想學習MXNET的源碼,首先想要添加一些打印,debug一下,第一個問題是如何在C++源碼中打印出NDArray結構的值,
今天嘗試如下,可以打印出來,
文件 incubator-mxnet/src/c_api/c_api.cc 中,函數MXNDArraySlice修改如下:
int MXNDArraySlice(NDArrayHandle handle,mx_uint slice_begin,mx_uint slice_end,NDArrayHandle *out) {NDArray *ptr = new NDArray();API_BEGIN();std::cout << "slice_begin:" << slice_begin << std::endl;std::cout << "slice_end:" << slice_end << std::endl;*ptr = static_cast<NDArray*>(handle)->SliceWithRecord(slice_begin, slice_end);*out = ptr;float *p = (float *)ptr->data().dptr_;std::cout << "p[0] = " << p[0] << std::endl;std::cout << "p[1] = " << p[1] << std::endl;API_END_HANDLE_ERROR(delete ptr);
}
Python測試代碼如下
from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)x = nd.arange(2, 7).reshape((5, 1))
print(x[2:4].asnumpy())
打印結果為:
#python3 mxnet_test.py
1.5.0
slice_begin:2
slice_end:4
p[0] = 4
p[1] = 5
[[4.][5.]]
Great,可以驗證出來實際的數值就是在NDArray的data()函數的dptr_指針中,
?
____________________________________________
但是在操作時有時會無法得到預期的結果,如同文件中函數MXNDArrayGetGrad,如果按照上面的代碼進行打印的話,會發現打印出的值全為0,這時需要在代碼中添加一行WaitToRead,如下可正常打印
int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) {API_BEGIN();NDArray *arr = static_cast<NDArray*>(handle);NDArray ret = arr->grad();if (ret.is_none()) {*out = NULL;} else {std::cout << "ret.shape().ndim() = " << ret.shape().ndim() << std::endl;std::cout << "ret.shape()[0] = " << ret.shape()[0] << std::endl;std::cout << "ret.shape()[1] = " << ret.shape()[1] << std::endl;*out = new NDArray(ret);ret.WaitToRead();float *p_float = (float *)(ret.data().dptr_);for (int i = 0; i < ret.shape()[0] * ret.shape()[1]; i++){std::cout << "p_float[" << i << "] = " << p_float[i] << std::endl;}}API_END();
}
Python 測試代碼為:
from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)x = nd.arange(2, 7).reshape((5, 1))x.attach_grad()with autograd.record():y = 2 * nd.dot(x.T, x)y.backward()# assert (x.grad - 4 * x).norm().asscalar() == 0
print(x.grad)
輸出為:
# python3 autograd_test.py
1.5.0
ret.shape().ndim() = 2
ret.shape()[0] = 5
ret.shape()[1] = 1
p_float[0] = 8
p_float[1] = 12
p_float[2] = 16
p_float[3] = 20
p_float[4] = 24[[ 8.][12.][16.][20.][24.]]
<NDArray 5x1 @cpu(0)>
?
總結
以上是生活随笔為你收集整理的MXNET源码中NDArray数据的获取和打印的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Makefile中的几个调试方法
- 下一篇: MXNET源码中TShape值的获取和打