python3 Softmax函数
阅读原文时间:2023年07月10日阅读:2

Softmax的作用简单的说就计算一组数值中每个值的占比

import torch
import torch.nn.functional as F

# 原始数据tensor
y = torch.rand(size=[2, 3, 4])
print(y, '\n')


tensor([[[0.6898, 0.0193, 0.0913, 0.9597],
         [0.2965, 0.6402, 0.3175, 0.2141],
         [0.6842, 0.6477, 0.1265, 0.2181]],

        [[0.7287, 0.9654, 0.8608, 0.1618],
         [0.4583, 0.4862, 0.3352, 0.1108],
         [0.1539, 0.0863, 0.1511, 0.6078]]])


dim0 = F.softmax(y, dim=0)
print('dim=0 softmax:\n', dim0)
print('dim=0, tensor:')
for i in range(2):
    print(y[i, :, :].reshape(-1))
# dim = 0指第一个维度,例子中第一个维度的size是2


dim=0 softmax:
 tensor([[[0.4903, 0.2797, 0.3166, 0.6895],
         [0.4596, 0.5384, 0.4956, 0.5258],
         [0.6296, 0.6368, 0.4938, 0.4038]],

        [[0.5097, 0.7203, 0.6834, 0.3105],
         [0.5404, 0.4616, 0.5044, 0.4742],
         [0.3704, 0.3632, 0.5062, 0.5962]]])
dim=0, tensor:
tensor([0.6898, 0.0193, 0.0913, 0.9597, 0.2965, 0.6402, 0.3175, 0.2141, 0.6842,
        0.6477, 0.1265, 0.2181])
tensor([0.7287, 0.9654, 0.8608, 0.1618, 0.4583, 0.4862, 0.3352, 0.1108, 0.1539,
        0.0863, 0.1511, 0.6078])


dim1 = F.softmax(y, dim=1)
print('dim=1 softmax:\n', dim1)
print('dim=1, tensor:')
for i in range(3):
    print(y[:, i, :].reshape(-1))
# dim = 1指第二个维度,例子中第一个维度的size是3


dim=1 softmax:
 tensor([[[0.3746, 0.2112, 0.3040, 0.5126],
         [0.2528, 0.3929, 0.3811, 0.2432],
         [0.3725, 0.3959, 0.3149, 0.2442]],

        [[0.4299, 0.4915, 0.4801, 0.2847],
         [0.3281, 0.3044, 0.2838, 0.2705],
         [0.2420, 0.2041, 0.2361, 0.4447]]])
dim=1, tensor:
tensor([0.6898, 0.0193, 0.0913, 0.9597, 0.7287, 0.9654, 0.8608, 0.1618])
tensor([0.2965, 0.6402, 0.3175, 0.2141, 0.4583, 0.4862, 0.3352, 0.1108])
tensor([0.6842, 0.6477, 0.1265, 0.2181, 0.1539, 0.0863, 0.1511, 0.6078])


dim2 = F.softmax(y, dim=2)
print('dim=2 softmax:\n', dim2)
print('dim=2, tensor:')
for i in range(4):
    print(y[:, :, i].reshape(-1))
# dim = 2指第三个维度,例子中第一个维度的size是4


dim=2 softmax:
 tensor([[[0.2967, 0.1517, 0.1631, 0.3886],
         [0.2298, 0.3240, 0.2346, 0.2116],
         [0.3161, 0.3047, 0.1809, 0.1983]],

        [[0.2515, 0.3187, 0.2870, 0.1427],
         [0.2763, 0.2841, 0.2443, 0.1952],
         [0.2219, 0.2074, 0.2213, 0.3494]]])
dim=2, tensor:
tensor([0.6898, 0.2965, 0.6842, 0.7287, 0.4583, 0.1539])
tensor([0.0193, 0.6402, 0.6477, 0.9654, 0.4862, 0.0863])
tensor([0.0913, 0.3175, 0.1265, 0.8608, 0.3352, 0.1511])
tensor([0.9597, 0.2141, 0.2181, 0.1618, 0.1108, 0.6078])

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章