Token-Based Indexing
When indexing hidden states for specific tokens, use `.token[<idx>]` or `.t[<idx>]`.
As a preliminary example, let's get a hidden state from the model using `.t[<idx>]`.
[1]:
from nnsight import LanguageModel

# Wrap GPT-2 in nnsight's LanguageModel and request placement on the CUDA device.
model = LanguageModel(
    "openai-community/gpt2",
    device_map="cuda",
)
[2]:
with model.trace('The Eiffel Tower is in the city of') as tracer:
    # Hidden state at token index 0, taken from the first element of the
    # last transformer block's output.
    hidden_states = model.transformer.h[-1].output[0].t[0].save()
    # Save the full model output so it can be inspected after the trace.
    output = model.output.save()

print(hidden_states.shape)
# `output` is a CausalLMOutputWithCrossAttentions container, not a tensor.
# Report the shape of its `logits` field rather than dumping every tensor
# it holds (logits plus all past_key_values) into the cell output.
print(output.logits.shape)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
torch.Size([1, 768])
CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -36.2874, -35.0114, -38.0793, ..., -40.5163, -41.3759,
-34.9193],
[ -68.8886, -70.1562, -71.8408, ..., -80.4195, -78.2552,
-71.1206],
[ -82.2950, -81.6519, -83.9941, ..., -94.4878, -94.5194,
-85.6998],
...,
[-113.8675, -111.8628, -113.6634, ..., -116.7652, -114.8267,
-112.3621],
[ -81.8531, -83.3006, -91.8192, ..., -92.9943, -89.8382,
-85.6897],
[-103.9307, -102.5054, -105.1563, ..., -109.3099, -110.4195,
-103.1395]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-0.9420, 1.9023, 0.8722, ..., -1.2703, -0.4792, 1.2469],
[-1.9590, 2.7141, 2.8423, ..., -1.1633, -1.6173, 2.1507],
[-2.6123, 2.0937, 0.9679, ..., -0.9763, -1.2243, 2.0279],
...,
[-2.4282, 2.4462, 2.1550, ..., -0.5916, -1.6641, 2.1119],
[-3.5624, 3.6804, 2.5053, ..., -0.3572, -2.5960, 0.9592],
[-2.6021, 2.8035, 1.7291, ..., -0.8557, -2.1589, 2.6881]],
[[ 0.1103, 0.6967, -1.1409, ..., -0.1243, 1.8249, -0.1592],
[ 0.3364, -2.3421, -3.0033, ..., -0.9075, 3.9665, 0.2082],
[-1.2822, -2.8345, 0.1537, ..., 0.6516, 2.4424, 0.7518],
...,
[-0.1554, -1.0321, -2.5109, ..., -0.9747, 4.8222, -1.8171],
[-1.3993, -2.2428, -0.0644, ..., -0.9444, 3.5096, 0.4326],
[ 0.5759, -0.8102, -1.8774, ..., -1.4308, 3.0181, -2.2393]],
[[-0.0985, -0.0323, 0.7536, ..., -1.1902, -1.6401, 0.6545],
[ 1.1513, -0.7019, 0.2992, ..., -1.8075, -0.1072, 2.0486],
[-0.1089, -1.0244, 0.4639, ..., -1.8416, -0.2348, 1.0322],
...,
[ 0.3554, 0.3485, 0.0083, ..., -3.3077, 0.8817, 1.4423],
[ 0.3027, 0.2488, -0.2483, ..., -2.8617, 0.7589, 0.7380],
[ 0.5146, -0.1207, 0.6076, ..., -2.7679, 1.1288, 1.7932]],
...,
[[ 0.6009, -0.0877, -0.2693, ..., 0.1756, 0.7995, 0.5978],
[ 0.0456, 0.2891, -0.1535, ..., 1.0184, 1.0627, 0.3627],
[-0.0681, -0.5138, 0.5735, ..., 0.7821, 0.8516, 0.4657],
...,
[-0.0383, -0.2532, 0.0525, ..., 1.2245, 0.5464, 0.4056],
[ 0.2111, 0.9947, 0.0403, ..., 1.1817, -0.7079, 0.5290],
[ 0.1136, -0.0611, 0.1199, ..., 1.2025, 0.4589, 0.6644]],
[[ 1.4709, 1.5225, -0.4336, ..., -0.1837, 1.0947, -1.6615],
[ 0.7999, 0.0324, -1.5696, ..., -0.7550, 1.4671, -1.1099],
[ 0.6255, 0.4108, 0.0984, ..., -1.2564, 1.9016, -0.0603],
...,
[ 1.1553, 0.5795, -0.6220, ..., -0.7993, 0.4428, -0.6729],
[ 1.5863, -0.0730, 0.1822, ..., -0.5310, 0.4560, -0.4558],
[ 0.9170, -0.7168, -0.4214, ..., -0.8926, 0.4736, -0.0411]],
[[ 0.6260, 0.2122, 0.2527, ..., -0.6377, 0.2275, 1.5142],
[-0.3332, 1.5151, -0.3315, ..., 1.2160, 0.2653, 2.6735],
[ 0.1930, 0.0467, -0.3682, ..., -0.1827, 0.1576, 0.5612],
...,
[-0.1787, 1.2580, -0.2565, ..., -0.6601, 1.2289, 0.2853],
[ 0.9067, 0.6444, 0.2020, ..., 0.1291, 0.2002, 0.8276],
[-0.5779, 0.4654, -0.8867, ..., 1.4954, 1.3435, -0.6073]]]],
device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[-1.3066e-02, -1.4464e-02, 1.2694e-01, ..., -4.9182e-02,
1.0464e-01, 2.3067e-02],
[ 2.0469e-01, -1.3684e-02, 8.0588e-02, ..., -1.9410e-02,
5.5186e-02, 7.3562e-02],
[ 7.2231e-03, 1.8508e-01, 1.0139e-01, ..., -7.6448e-02,
2.9932e-01, -8.5228e-03],
...,
[-1.6686e-01, 1.9638e-02, 1.2153e-01, ..., 6.1965e-02,
9.3590e-02, -1.0460e-01],
[ 1.5657e-01, 1.5053e-01, 5.7654e-02, ..., -4.2498e-01,
-5.2136e-02, 3.0045e-02],
[-1.0558e-02, -8.6992e-02, -7.6297e-02, ..., -6.3531e-02,
-5.0926e-02, 1.9987e-01]],
[[ 5.9014e-01, 1.0051e-01, -2.0716e-01, ..., -6.9383e-01,
-2.7763e-01, 2.0517e-01],
[ 6.3339e-01, 1.1631e-01, 2.4300e-01, ..., 1.9035e-01,
8.8391e-02, -5.1286e-02],
[ 2.7510e-01, -7.9842e-02, 2.0712e-01, ..., 2.0180e-01,
1.4190e-01, -1.3274e-01],
...,
[ 8.3555e-01, -9.4205e-02, 7.4023e-02, ..., -1.7617e-01,
1.3164e-01, 1.1117e-01],
[ 3.2692e-01, 4.5032e-02, 2.5904e-01, ..., 7.9349e-02,
2.0154e-01, -5.9558e-03],
[ 6.5303e-01, -8.9489e-02, -4.5211e-01, ..., 7.0391e-04,
4.9327e-01, 1.5887e-01]],
[[-2.8404e-02, -1.1449e-01, -2.1676e-02, ..., 3.9217e-03,
7.8844e-02, -3.9936e-03],
[-4.9779e-02, 1.8518e-01, -1.9874e-01, ..., -4.4753e-02,
-9.1100e-02, -2.1138e-02],
[ 4.4283e-02, 5.9255e-02, 5.3522e-02, ..., -1.4617e-02,
-3.3558e-01, 1.7041e-01],
...,
[-5.3011e-01, -6.1409e-04, -6.0240e-01, ..., -1.9026e-01,
-7.1861e-02, 3.2019e-01],
[ 4.0913e-01, -1.1379e-01, -1.6436e-01, ..., 8.2217e-02,
-1.0437e-01, -7.6691e-02],
[ 3.1223e-01, 3.6828e-01, 6.0183e-01, ..., -2.8972e-02,
1.1367e-01, -2.5661e-01]],
...,
[[-1.0771e-01, -2.1316e-01, -2.1841e-02, ..., -2.3210e-01,
2.1270e-02, -6.6547e-02],
[-3.1855e-01, 4.4327e-01, -1.6764e-01, ..., 5.3822e-02,
-7.5202e-02, 1.9941e-01],
[-1.4129e-01, -6.2872e-02, 2.3989e-01, ..., 2.2710e-01,
1.2402e-01, 4.0053e-01],
...,
[-1.0040e-01, -4.9095e-01, 2.2476e-02, ..., 5.5608e-02,
-1.4735e-01, 2.3780e-01],
[-3.0879e-01, 7.1592e-01, 1.2739e-01, ..., 2.9476e-02,
-1.5573e-01, -1.7634e-02],
[-1.6235e-01, -2.5231e-01, -6.0719e-02, ..., -3.7746e-01,
-6.9728e-03, -2.2533e-01]],
[[ 9.3223e-02, -1.0404e-01, -2.1104e-01, ..., 1.8502e-01,
2.2378e-01, -3.1989e-02],
[-4.5714e-01, 6.4180e-02, -1.5538e-01, ..., -2.6814e-01,
2.0829e-01, 8.7156e-03],
[ 2.4634e-03, 1.8372e-01, 7.3725e-03, ..., -4.8131e-01,
1.2558e-01, 6.1276e-02],
...,
[-2.2317e-01, 1.2418e-01, -3.6774e-02, ..., 2.8985e-01,
5.9641e-02, -8.6951e-03],
[-1.8944e-01, -1.7414e-02, -2.9084e-02, ..., -4.5319e-02,
-5.7796e-02, 4.7680e-01],
[ 9.2804e-02, 9.9442e-02, -6.0471e-02, ..., -7.9065e-02,
-1.6836e-01, 7.1764e-02]],
[[-2.4770e-02, -3.7828e-01, 1.1838e-01, ..., 1.1582e-02,
-2.4843e-01, -1.1559e-01],
[ 5.8631e-02, 1.6256e-01, 1.3249e-01, ..., 2.6460e-01,
9.5267e-02, 1.0518e-01],
[ 5.0756e-02, -1.4601e-01, -2.3191e-01, ..., -2.2047e-01,
3.0730e-01, 2.6307e-01],
...,
[ 2.8193e-02, 1.9202e-01, -8.7550e-02, ..., 6.9838e-02,
-4.0262e-02, -5.9197e-03],
[ 2.6708e-01, 1.3450e-01, -8.2224e-02, ..., -6.0210e-04,
-1.4364e-01, 1.5347e-01],
[ 1.5456e-01, -1.1916e-01, 2.8118e-01, ..., 1.1415e-01,
2.5977e-01, 1.8767e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-0.3465, 1.8232, -1.4522, ..., 1.4427, -0.8329, 1.0962],
[ 0.5499, 2.4760, -0.3377, ..., -0.7511, -2.3585, -0.8001],
[ 1.4355, 3.1091, -0.4843, ..., -0.2537, -1.2017, 0.3102],
...,
[-0.2717, 1.2955, -0.1360, ..., 0.4590, -0.8015, -0.3853],
[-0.8938, 1.0282, 1.2864, ..., 0.0868, -1.8080, -0.7916],
[-0.3702, 0.8338, 0.2453, ..., 0.4522, -1.6840, -0.4447]],
[[-0.8245, -0.3510, -0.5746, ..., -0.2983, 0.9754, -0.5511],
[-0.5012, -0.1667, -1.3464, ..., -0.2514, 0.3567, -0.8778],
[-0.6851, 0.3835, -2.6297, ..., -0.2897, 1.0644, 0.0080],
...,
[-0.1188, 0.6849, -1.4834, ..., -0.3863, 0.2342, -0.6524],
[-0.1495, -0.0715, -2.1277, ..., -0.1427, -1.4979, -0.9024],
[ 0.2886, 1.0222, -1.8740, ..., -0.4773, -0.1798, -0.9583]],
[[ 0.3444, 0.0273, 0.0736, ..., -1.2545, 0.2919, -0.1958],
[ 0.3155, 0.4886, -0.0817, ..., -0.6271, -0.1437, 0.3061],
[ 0.0337, 0.2995, -0.4871, ..., -0.6719, 0.3822, 0.1488],
...,
[-0.1552, 0.0079, -0.0330, ..., -0.7431, 0.4124, 0.2261],
[ 0.1261, -0.2049, -0.1360, ..., -1.0366, -0.2439, 0.4584],
[-0.1678, -0.0727, 0.0116, ..., -0.8169, 0.1020, 0.1669]],
...,
[[-0.3655, -0.2597, -0.7410, ..., -0.8024, 0.4248, -0.4208],
[-0.0382, -0.9516, 2.3876, ..., 1.0405, -1.6202, -0.3938],
[ 0.2023, 0.8736, 1.8099, ..., -0.0363, -0.3845, -0.2317],
...,
[-0.7035, 0.4357, 0.7795, ..., -1.1353, 0.6211, -0.0097],
[ 0.8421, -0.1923, 1.8707, ..., 1.6848, 0.5796, -1.7162],
[ 1.1454, -0.0537, 1.8554, ..., 0.6414, -0.5141, -0.9465]],
[[-1.1789, -2.8546, 0.1095, ..., 1.7660, 1.5671, -1.5985],
[ 0.2043, 0.9542, -0.5255, ..., -0.7796, 0.6046, -0.0179],
[ 0.0663, 0.5924, -0.6268, ..., -0.5561, 0.5178, 0.0364],
...,
[ 0.1530, 0.2313, -0.7300, ..., -0.0290, 0.7031, -0.0167],
[-0.3269, 0.4781, -0.8025, ..., -0.3819, 0.8895, 0.2701],
[-0.0928, 0.2595, -0.7860, ..., -0.7578, 0.9077, 0.4650]],
[[ 1.0171, 1.8497, 0.6378, ..., -0.8035, 0.1293, 0.6040],
[ 0.0156, 2.1213, 2.3286, ..., 1.4849, 0.3781, -2.2845],
[ 0.0148, 2.0272, 2.2091, ..., 1.7553, -0.1371, -2.4600],
...,
[-0.1495, 1.9038, -0.4051, ..., 0.6891, -0.3348, -0.9027],
[ 0.0852, 1.0575, -0.2086, ..., 0.6766, 0.7277, -1.3756],
[ 1.3943, 2.3425, 0.3973, ..., -0.1781, -0.6685, -0.5852]]]],
device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[ 2.8582e-01, -5.7216e-02, 7.6864e-03, ..., 1.7198e-01,
-2.0384e-01, -1.4675e-02],
[ 5.7858e-02, 3.3398e-01, -4.4943e-01, ..., -2.9586e-01,
4.0076e-01, 1.6289e-01],
[-3.2822e-01, -2.3642e-01, 5.2615e-02, ..., -5.6616e-02,
9.3489e-02, 4.1029e-01],
...,
[ 5.1043e-01, 1.1239e-01, 1.1384e-01, ..., 1.5822e-01,
-2.3087e-01, 1.6409e-01],
[-3.5905e-01, -1.7121e-02, -3.3058e-01, ..., 5.2365e-01,
-1.0570e-01, -2.3161e-01],
[ 2.5310e-01, 3.4903e-01, 3.1141e-01, ..., -1.0517e-01,
-2.5150e-01, 3.6372e-01]],
[[ 1.9847e-01, -5.7434e-02, -1.4368e-01, ..., 4.2654e-02,
-4.2432e-01, 3.0031e-02],
[-6.5083e-01, 5.3207e-01, 1.1575e+00, ..., 1.7449e-01,
-1.8299e-01, -5.5981e-01],
[ 5.5086e-01, 1.7719e-01, -4.3369e-01, ..., -1.6054e-01,
-6.7744e-01, 3.3818e-01],
...,
[ 5.4245e-01, 1.8869e-01, -3.8369e-01, ..., 5.7259e-01,
4.3979e-01, -7.9122e-01],
[-1.2102e-01, 6.4750e-01, -6.5491e-01, ..., 4.7697e-01,
-1.5043e-01, 4.5935e-01],
[ 8.3022e-01, 1.8053e-01, 2.6405e-01, ..., -9.6267e-01,
4.0402e-02, 6.3098e-02]],
[[ 4.5447e-02, -1.3822e-01, -5.6856e-02, ..., -5.8098e-01,
7.3407e-02, 3.3327e-02],
[ 4.8486e-01, 3.0004e-01, 2.2786e-01, ..., -6.6910e-01,
-6.2471e-02, -1.1193e-02],
[ 6.1447e-01, 1.9622e-01, 2.6388e-01, ..., -6.4158e-01,
1.2114e-01, -5.8668e-02],
...,
[ 5.2907e-01, 4.2188e-02, 1.2375e-01, ..., -3.8698e-01,
-8.2161e-03, 1.2330e-02],
[ 5.4567e-01, 3.7348e-01, 3.6478e-01, ..., -2.9497e-01,
-1.4520e-01, 1.0892e-01],
[ 7.6291e-01, 3.8150e-01, 1.0935e-01, ..., -6.2610e-02,
3.5521e-01, -2.2801e-01]],
...,
[[ 1.4935e-01, 5.5228e-01, -5.8729e-02, ..., 1.0577e-01,
-1.1574e+00, -1.3604e-02],
[-5.1389e-01, 3.5528e-01, -1.9673e-01, ..., -3.9717e-02,
-7.1098e-01, -7.2981e-02],
[-2.6157e-01, 3.2165e-01, 1.1155e+00, ..., -1.7352e-01,
-6.4596e-01, -6.2517e-02],
...,
[ 2.2258e-01, 3.6877e-01, 1.1427e-01, ..., 3.0327e-02,
-5.1350e-01, 8.5692e-02],
[-4.9558e-02, 4.0748e-01, 1.1954e-01, ..., 3.3549e-01,
-1.6801e-01, -3.9417e-01],
[ 1.0553e-01, -6.5443e-01, 7.7109e-02, ..., 6.8254e-02,
-5.4058e-01, -2.9753e-01]],
[[ 3.3667e-01, -2.2411e-01, -1.9435e-01, ..., 3.1839e-01,
-3.5641e+00, -1.4767e-01],
[ 9.3890e-02, -1.3407e-01, 2.1106e-01, ..., 5.0133e-01,
-3.0602e-02, 6.8217e-02],
[-1.3945e-01, 6.5872e-01, -3.0873e-01, ..., -1.3002e-01,
4.3156e-02, 1.0483e-01],
...,
[ 2.0381e-01, -1.6960e-01, 1.2699e-01, ..., -1.8644e-01,
-5.4254e-02, 2.5791e-02],
[ 3.7103e-01, 9.3224e-02, -5.2816e-02, ..., -2.0285e-01,
3.0910e-01, -4.9643e-02],
[ 2.4020e-01, 6.1041e-01, -1.3586e-01, ..., -1.0973e-01,
-1.0275e-01, -1.1933e-01]],
[[ 1.0385e-01, -6.1643e-02, -8.1606e-02, ..., -1.9553e-01,
1.6988e-01, 7.9386e-02],
[-1.5159e-01, -3.4469e-02, -4.1837e-01, ..., -1.9089e-01,
3.7003e-01, 4.1017e-01],
[-3.2069e-02, 1.7096e-03, 1.6785e-01, ..., -1.9622e-01,
-4.2147e-02, 3.2936e-02],
...,
[ 2.0300e-01, 2.4707e-01, -1.9529e-01, ..., -3.4163e-01,
1.1657e-01, 2.1782e-01],
[ 4.1669e-02, 3.8711e-01, 2.9755e-01, ..., -1.9858e-01,
1.6213e-01, 1.9509e-02],
[-1.7768e-02, 1.2535e-01, 4.7679e-02, ..., -1.9662e-01,
-6.9517e-02, 1.8849e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-1.1258e-01, -1.1088e+00, 2.8973e-01, ..., -6.3421e-01,
-1.7726e-01, 3.7837e-03],
[ 4.6296e-02, -2.7619e+00, 1.5496e-01, ..., 1.8012e-01,
1.5414e-01, 5.2871e-01],
[ 1.6581e+00, -3.4251e+00, 1.3515e+00, ..., -5.9394e-01,
-1.3542e-01, 8.0459e-01],
...,
[-1.8019e-01, -7.5081e-01, -8.3324e-01, ..., 7.3177e-01,
-3.3206e-01, 1.8634e-01],
[ 5.0296e-01, -2.4352e+00, -4.8550e-01, ..., 3.0387e-01,
-5.1740e-01, 4.8016e-01],
[ 6.5937e-02, -8.9230e-01, 4.7684e-01, ..., 1.5388e+00,
5.9400e-02, -2.0516e-01]],
[[-5.2778e-01, 3.9466e-01, -3.2121e-01, ..., 1.1577e+00,
-5.7213e-01, -4.4179e-01],
[-1.4074e+00, 1.4519e-01, -6.0831e-01, ..., 5.8401e-01,
4.8111e-01, 2.3295e-01],
[ 7.7871e-01, 4.0712e-01, -1.4060e+00, ..., 8.3845e-01,
6.0770e-01, 4.0896e-01],
...,
[-1.6353e+00, -5.3745e-01, 9.7606e-01, ..., -2.2841e-01,
-2.6943e-01, 1.3457e+00],
[-1.6745e+00, -2.2275e+00, -1.1276e+00, ..., 2.0199e-01,
6.4847e-01, -3.5962e-01],
[-2.7640e-01, 1.0089e+00, -2.0779e+00, ..., 1.8831e-01,
9.1014e-01, -2.6479e-01]],
[[ 1.2186e+00, 3.0460e+00, 3.7205e+00, ..., 6.2533e-01,
1.6636e+00, -7.7160e-01],
[-1.6707e+00, 1.7284e+00, -4.4143e+00, ..., -3.4466e+00,
3.8242e+00, 3.0921e-01],
[-2.1330e+00, 8.3251e-01, -3.8796e+00, ..., -3.3404e+00,
4.1093e+00, 7.7233e-01],
...,
[-4.6060e+00, -7.7079e-01, -2.9684e+00, ..., -3.9691e+00,
1.9058e+00, 4.0182e-01],
[-2.7102e+00, -1.9871e+00, -3.9144e+00, ..., -3.3751e+00,
2.4657e+00, 1.6757e+00],
[-4.6400e+00, -1.8988e+00, -4.4826e+00, ..., -3.7084e+00,
2.3719e+00, 7.5402e-01]],
...,
[[ 1.3252e+00, -2.7270e+00, -2.7397e+00, ..., 9.5941e-01,
4.3134e-01, 2.7145e+00],
[-1.3666e+00, 1.3817e+00, 7.7345e-01, ..., 7.0400e-03,
-2.1723e+00, -1.2306e+00],
[-3.3139e+00, 2.7426e+00, 7.5008e-01, ..., -6.9737e-01,
-2.8426e+00, -8.4870e-01],
...,
[-3.3808e+00, 3.6548e+00, 1.0053e+00, ..., -2.0039e-01,
-3.4632e+00, -7.8443e-01],
[-3.2769e+00, 3.0825e+00, 1.8556e+00, ..., -1.5148e+00,
-4.2361e+00, -2.2974e+00],
[-3.5654e+00, 5.0539e+00, 1.9514e+00, ..., -1.8862e+00,
-3.2081e+00, -1.6764e+00]],
[[ 1.7317e+00, 4.5925e-01, 9.2368e-01, ..., 2.3702e-02,
-1.0098e+00, -3.0809e-01],
[ 1.9795e+00, 7.9756e-01, 1.2470e+00, ..., 4.3697e-01,
-1.4722e+00, -1.6632e+00],
[ 2.5675e+00, 6.1664e-01, 1.3403e+00, ..., -1.5458e-01,
-1.6354e+00, -1.2597e+00],
...,
[ 1.8166e+00, 5.7684e-01, 1.0326e+00, ..., 3.2898e-01,
-2.0770e+00, -9.4286e-01],
[ 2.6763e+00, 3.3288e-01, 1.1887e+00, ..., 1.4417e-01,
-1.4898e+00, -8.5120e-01],
[ 2.0895e+00, 4.4764e-01, 8.9569e-01, ..., -3.5181e-01,
-2.2835e+00, -9.7895e-01]],
[[-2.5208e-01, 1.5814e-01, -5.6552e-01, ..., 3.1414e-01,
2.8355e-01, 2.2655e-01],
[ 1.5601e-01, 1.0595e+00, 1.6655e-01, ..., -3.0340e-01,
7.3512e-02, -3.7588e-01],
[ 4.5730e-01, 7.6759e-01, 6.7810e-02, ..., -3.8059e-01,
-4.8543e-02, -1.7695e-01],
...,
[ 3.0563e-01, -4.7971e-01, -7.2877e-01, ..., 2.4584e-01,
2.6441e-01, 7.5348e-01],
[-7.7925e-01, 4.4917e-01, -8.4968e-01, ..., 6.4183e-01,
6.5270e-01, -4.6915e-01],
[ 1.4354e+00, -2.2250e-01, 1.3349e-01, ..., -1.4484e-02,
4.1740e-01, 1.9530e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[ 4.2124e-03, 5.0817e-02, -1.2287e-01, ..., 2.0558e-02,
4.7877e-02, -5.4793e-01],
[ 2.3754e-02, 5.2620e-02, 1.8422e-01, ..., -4.3334e-01,
-2.6523e-01, 7.3234e-01],
[-5.8956e-01, 1.4897e-01, 4.0976e-01, ..., 4.9391e-01,
9.4187e-01, 5.9001e-01],
...,
[ 7.7799e-01, 1.2943e+00, 2.9846e-01, ..., 3.6765e-01,
1.0873e-01, 8.7409e-01],
[ 5.5421e-01, -6.7936e-01, 3.6075e-01, ..., -4.4645e-01,
1.1377e-01, 4.7029e-01],
[-1.4159e-01, 3.3504e-01, -3.2649e-01, ..., 2.2850e-02,
-3.0386e-01, 5.8188e-01]],
[[ 1.3280e-02, -2.8979e-02, 5.5474e-02, ..., -4.4208e-02,
6.2167e-03, 3.4543e-02],
[ 9.9526e-03, 1.0667e-01, -3.8910e-01, ..., 8.4887e-03,
-1.0204e-01, 5.6042e-02],
[-5.3040e-01, -4.0489e-02, -2.2295e-01, ..., 2.6226e-01,
-7.2703e-01, -3.2483e-01],
...,
[ 8.4138e-02, 2.8657e-02, 1.4982e-01, ..., 6.1274e-02,
-1.4219e-01, 3.3446e-01],
[ 6.4882e-02, 1.2559e+00, -6.3594e-01, ..., 2.0117e+00,
4.0344e-01, -2.6313e-01],
[ 3.0646e-01, 5.3148e-02, -6.7682e-03, ..., 2.5434e-01,
-1.0172e-03, 2.6347e-01]],
[[ 1.3993e-02, -8.0731e-01, -3.6560e-02, ..., 7.3044e-02,
8.4237e-03, -2.6865e-02],
[-6.1524e-02, -7.5551e-01, 2.6525e-01, ..., 2.5373e-01,
1.0760e-01, 2.4019e-01],
[-2.8497e-02, -1.3058e+00, -1.9403e-01, ..., -6.1217e-01,
-3.0187e-02, 3.7932e-01],
...,
[-2.6733e-01, -1.9554e+00, -3.6751e-02, ..., 6.9920e-02,
-3.7356e-02, 7.0360e-02],
[-1.6400e-01, -1.9207e+00, 7.5180e-01, ..., 7.4416e-02,
4.1090e-01, -5.3913e-01],
[-4.9150e-01, -7.3422e-01, -6.0788e-01, ..., -5.1220e-01,
5.3674e-02, -9.2746e-02]],
...,
[[ 1.4164e-02, -7.0936e-02, 1.3435e+00, ..., -7.3744e-02,
2.0103e-01, -2.8246e-02],
[ 7.4353e-03, 1.5017e-01, 1.0513e+00, ..., 1.1080e-01,
-1.5743e-01, 6.2338e-02],
[-2.2182e-01, -4.2225e-01, 1.2976e+00, ..., 1.1032e-01,
2.1166e-01, -1.7599e-01],
...,
[-5.7066e-02, 4.6036e-01, 2.5958e+00, ..., -3.7224e-01,
1.4200e-01, -5.8588e-01],
[-1.6861e-01, -1.0492e-01, 1.0817e+00, ..., -2.9654e-01,
4.1565e-01, 1.0558e+00],
[-4.5311e-01, 2.1572e-01, 2.5959e+00, ..., 1.0544e-01,
-5.5539e-01, -8.8997e-02]],
[[-3.1299e-02, -9.0694e-02, -1.5058e-01, ..., 1.1410e-01,
1.2924e-01, 1.9263e-01],
[ 1.3471e-01, 1.8759e-01, -1.1682e-03, ..., 6.0341e-01,
6.3717e-02, -2.9939e-02],
[ 3.2003e-01, -1.8200e-01, 3.5685e-01, ..., -2.3825e-01,
-5.0240e-03, -2.7442e-01],
...,
[-6.7477e-01, 1.9573e-02, -2.9475e-01, ..., -1.5251e-01,
7.6870e-01, 2.2547e-01],
[ 1.0340e-02, -2.6907e-02, -3.2965e-02, ..., 2.9814e-02,
4.7616e-02, -8.8962e-01],
[-4.5525e-01, 1.1967e-01, 5.1916e-01, ..., 1.1242e+00,
-1.9753e-01, -5.9083e-01]],
[[ 1.9867e-02, 7.5011e-03, 2.6127e-02, ..., -3.7370e-03,
2.2394e-01, 1.3843e-02],
[-8.1746e-01, -3.7520e-01, -4.9655e-01, ..., 1.2893e-01,
-1.7921e+00, 3.8990e-01],
[ 1.1536e+00, -2.9730e-01, 1.5741e-01, ..., 2.0536e-02,
-2.0358e+00, 1.8910e-01],
...,
[ 2.2302e-01, -7.9008e-01, -4.9336e-01, ..., 4.5226e-01,
-2.0080e+00, -2.7697e-01],
[ 1.9959e-01, -9.6224e-02, -4.0063e-01, ..., 7.3084e-01,
-2.1308e+00, 1.2104e-01],
[-4.0285e-01, -3.4758e-02, 9.0270e-02, ..., 2.2755e-01,
-1.7261e+00, -1.9400e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.0508e-02, -2.2260e-01, 1.5264e-01, ..., -8.9290e-01,
7.5289e-01, -1.1935e+00],
[-2.1560e-01, 1.6692e-01, -9.1999e-02, ..., -5.7826e-01,
-1.4156e-01, 3.9989e-01],
[ 7.9431e-01, -8.6379e-01, -2.1202e-01, ..., 1.0457e+00,
-1.7295e+00, 2.5042e+00],
...,
[-2.0043e+00, 2.7582e-01, -7.2417e-01, ..., -5.6687e-01,
-8.2403e-01, 5.5432e-01],
[-2.2761e-01, 1.2259e+00, -2.9970e+00, ..., 6.2500e-01,
2.7706e-01, 2.3849e+00],
[-1.5395e-01, 3.0106e-01, -2.4458e-01, ..., 1.5246e+00,
-2.5438e-01, 1.0539e+00]],
[[ 8.1046e-01, 1.8663e-01, -6.4914e-03, ..., -1.5619e-01,
-1.0988e+00, -1.8967e-01],
[ 1.2836e+00, -1.7571e+00, 5.6523e-01, ..., 1.6203e-01,
5.3114e+00, -4.3371e-01],
[ 1.0058e+00, -1.4609e+00, 7.6332e-01, ..., -2.0171e-01,
6.9416e+00, 1.9571e+00],
...,
[ 6.8441e-01, -1.1089e+00, -1.3350e+00, ..., 2.8493e-01,
6.0420e+00, 1.7264e+00],
[ 3.6634e-01, -1.0104e-01, -2.1439e+00, ..., 2.6209e-01,
6.2713e+00, 2.7654e+00],
[ 9.3607e-01, -1.5236e+00, -1.6343e+00, ..., -1.6642e-01,
6.7665e+00, 2.8936e+00]],
[[ 3.4120e-01, -3.7081e-01, -3.0869e-01, ..., 3.6410e-01,
1.4399e+00, 2.5755e-01],
[-8.2960e-01, -6.1458e+00, -9.4992e-01, ..., -2.7348e+00,
-3.3308e+00, -6.2488e+00],
[-1.6905e+00, -5.8666e+00, -1.3552e+00, ..., -3.8865e+00,
-2.7315e+00, -6.3022e+00],
...,
[-2.7337e+00, -6.8483e+00, -1.6320e+00, ..., -2.6826e+00,
-3.3623e+00, -4.6536e+00],
[-1.3503e+00, -5.0896e+00, -1.5456e+00, ..., -3.7086e+00,
-3.3961e+00, -4.4885e+00],
[-4.4174e+00, -7.9308e+00, -2.5600e+00, ..., -3.4487e+00,
-1.4507e+00, -5.1099e+00]],
...,
[[ 2.2782e-01, 1.7694e+00, 5.3587e-01, ..., 2.5913e-01,
4.6895e-01, -1.6825e+00],
[-1.5099e+00, -4.9805e+00, -4.3378e-01, ..., -8.1371e-01,
-1.6987e+00, 6.6884e+00],
[ 7.1447e-01, -5.3976e+00, 2.6015e-02, ..., -1.4007e+00,
-1.7186e+00, 6.7642e+00],
...,
[-2.4670e+00, -7.2994e+00, 1.0229e+00, ..., -2.6515e+00,
-1.9696e+00, 6.8497e+00],
[-1.2711e+00, -5.8679e+00, -9.0149e-01, ..., -1.3883e+00,
-1.2481e+00, 7.0832e+00],
[-1.3249e+00, -7.4333e+00, 9.6733e-01, ..., -3.6964e+00,
-3.3296e+00, 8.0495e+00]],
[[ 5.4848e-02, -2.5668e-02, 1.4165e-01, ..., -9.0587e-02,
-8.5854e-02, -1.3366e-01],
[ 1.8120e+00, -8.8290e-01, 1.5319e+00, ..., -6.6818e-01,
2.6287e-01, -2.2123e-01],
[ 1.7768e+00, 1.8999e-01, -1.2643e+00, ..., 5.9022e-01,
-4.8334e-01, -6.0814e-01],
...,
[ 8.9763e-02, -9.0987e-01, -7.7684e-01, ..., -4.5400e-02,
5.2621e-01, 3.5608e-01],
[ 3.4736e-01, -6.2006e-01, -4.0945e-01, ..., -8.0235e-01,
7.8035e-01, -3.7293e-01],
[-8.1439e-01, -7.3308e-01, -2.7807e-01, ..., -6.0827e-01,
9.7116e-01, 1.3668e-01]],
[[ 3.9688e-01, -6.6984e-02, 1.8968e+00, ..., -2.4403e-01,
-2.1560e-01, -1.0126e+00],
[ 2.8065e+00, 1.3883e+00, -2.6807e+00, ..., 1.2111e+00,
3.3058e-01, 2.3798e+00],
[ 2.9053e+00, 6.0659e-01, -2.8907e+00, ..., 9.0771e-01,
4.6567e-01, 1.9132e+00],
...,
[ 2.9192e+00, 8.8462e-01, -2.8681e+00, ..., 1.4205e+00,
-2.0973e-01, 2.6587e+00],
[ 2.3439e+00, 1.7387e+00, -1.9872e+00, ..., 2.2289e+00,
-1.4917e+00, 3.7927e+00],
[ 2.2603e+00, 1.6628e+00, -2.2689e+00, ..., 8.3023e-01,
9.3067e-01, 4.7189e+00]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[ 4.1882e-02, 6.4490e-02, -9.3272e-03, ..., 1.6031e-02,
1.0052e-01, 3.5980e-02],
[ 3.2562e-01, -8.8795e-01, -4.9085e-02, ..., 3.6603e-01,
4.3858e-01, -6.1412e-01],
[ 7.5744e-02, -4.6294e-01, 1.1798e+00, ..., 1.4153e-02,
1.2599e-01, -8.4334e-01],
...,
[ 1.8532e-01, 6.7643e-02, -4.8084e-01, ..., 3.4835e-01,
-1.7105e-01, -2.2380e-01],
[-2.3345e-01, -3.6575e-01, 3.9444e-02, ..., -3.5445e-01,
-2.8266e-01, -3.4278e-01],
[-5.4808e-01, -3.0808e-01, -8.5808e-01, ..., 2.2267e-01,
-2.2272e-01, 1.7892e+00]],
[[-3.3985e-02, -1.3486e-03, 8.3480e-02, ..., -4.4081e-02,
-3.3619e-02, -5.0814e-02],
[ 1.6667e-01, 2.3407e-01, 3.0234e-01, ..., 2.8978e-01,
6.3130e-01, 3.1103e-01],
[ 3.4845e-01, -1.3173e+00, 3.9288e-01, ..., 3.7700e-01,
5.9143e-01, -7.0978e-01],
...,
[ 3.0984e-01, 8.3614e-02, 1.8268e-01, ..., -5.9535e-02,
1.1795e-01, 2.4208e-01],
[ 1.1407e-01, -1.4961e+00, -4.7466e-01, ..., 1.1233e+00,
4.0220e-01, 2.3130e-01],
[ 3.3782e-01, -7.3522e-01, -2.8408e-01, ..., 1.5294e-01,
1.7005e-01, -1.5641e-01]],
[[ 3.5987e-02, -1.0232e-01, -4.4135e-02, ..., -2.5785e-02,
8.6479e-02, -1.4895e-01],
[-4.2532e-01, -8.7885e-02, 4.6666e-01, ..., -1.6055e-01,
1.9573e-01, -3.9598e-01],
[-2.4160e-01, -2.7997e-01, 6.6099e-02, ..., -1.2111e-01,
3.1895e-01, -8.4615e-02],
...,
[-2.6233e-01, -4.7053e-01, 6.4113e-01, ..., 3.2162e-01,
-2.1453e-01, -1.4517e-01],
[ 3.8219e-01, 4.5365e-02, 1.6591e-02, ..., 2.6456e-01,
-2.6929e-01, 2.3254e-01],
[-1.4182e-01, 1.8406e-01, -9.6458e-01, ..., -5.1126e-01,
4.8666e-01, -2.5828e-02]],
...,
[[-2.3814e-02, 1.2032e-01, -7.7499e-03, ..., -2.3475e-02,
5.9122e-02, -4.1239e-02],
[ 5.3979e-01, 6.3833e-02, -8.8802e-01, ..., 2.3857e-01,
-3.4962e-01, -1.8191e-01],
[ 3.6208e-01, -8.7884e-02, 2.2842e-01, ..., 2.4545e-02,
-2.7802e-01, 3.5545e-01],
...,
[-2.0331e-02, -1.8373e-01, -2.5179e-01, ..., -4.0274e-02,
7.3222e-02, -3.7319e-04],
[-1.8016e-01, -3.4777e-01, -3.3174e-01, ..., 4.5794e-01,
-1.3380e-01, -1.7557e-01],
[ 1.7746e-01, -3.2240e-01, 2.8709e-01, ..., -1.9587e-01,
3.0646e-01, -3.8693e-01]],
[[-1.5903e-01, -1.2174e-01, -6.9957e-02, ..., -2.3985e-01,
-1.6513e-02, -3.9757e-02],
[-6.4740e-02, -2.3349e-02, 3.1592e-01, ..., 1.0429e+00,
-8.8360e-02, 8.3591e-01],
[-2.1363e+00, -7.1707e-01, -5.2893e-01, ..., 7.2560e-01,
-1.5213e+00, 9.7697e-01],
...,
[ 2.0967e-01, -6.3169e-01, 1.0254e+00, ..., 3.4685e-01,
1.1476e-01, 6.6754e-01],
[ 3.9468e-01, 3.6171e-01, 1.4690e+00, ..., 6.8382e-01,
2.1922e-01, 6.0305e-01],
[ 5.0124e-01, 4.3996e-01, 9.9185e-01, ..., 2.1516e-01,
-1.6881e-01, 3.6948e-01]],
[[ 1.1859e-01, -8.5552e-02, -2.7463e-02, ..., -1.7703e-02,
-9.0103e-02, -9.4804e-02],
[ 1.4253e-01, 4.5364e-01, -3.9422e-01, ..., 2.6342e-01,
6.0791e-01, 2.2922e-02],
[ 2.5674e-01, -4.9790e-02, 2.4432e-01, ..., 6.9650e-01,
3.0487e-01, -2.1513e-02],
...,
[-9.6273e-02, -1.2066e+00, -8.2821e-02, ..., -2.2924e-01,
4.0816e-01, -3.1922e-01],
[-3.7043e-01, 3.5381e-01, 3.8201e-01, ..., 1.6970e-01,
-4.3569e-01, -1.4287e-01],
[ 7.0933e-02, -1.3498e-01, 4.6551e-02, ..., -5.1041e-01,
-7.0377e-02, 1.3301e-02]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8230e-01, -1.4005e-01, 3.3679e-01, ..., -9.7407e-01,
1.9440e-02, -2.9572e+00],
[ 1.0084e+00, 1.0470e+00, -4.7232e+00, ..., -2.5201e+00,
-1.9014e+00, 8.8033e+00],
[ 1.2540e+00, 4.6455e-01, -3.4971e+00, ..., -1.8074e+00,
-1.3556e-01, 1.0052e+01],
...,
[-1.1896e-01, -4.2050e-01, -2.9757e+00, ..., -2.9992e+00,
-2.6175e+00, 8.8424e+00],
[-3.3484e-01, 4.7369e-01, -4.5834e+00, ..., -1.6147e+00,
-2.4933e+00, 1.0348e+01],
[ 9.3118e-01, -2.6527e+00, -8.8920e-01, ..., -8.0500e-01,
-4.1080e+00, 9.5769e+00]],
[[ 3.7102e-01, -7.1686e-02, 4.7841e-01, ..., -1.3458e-01,
-6.9245e-02, -2.2292e+00],
[-2.2462e+00, 1.5379e+00, 3.5270e+00, ..., -1.5324e+00,
-3.7864e+00, 6.7854e+00],
[-4.6717e-01, 8.8104e-01, 2.0514e+00, ..., 5.7600e-01,
-2.9924e+00, 7.1979e+00],
...,
[-3.3413e+00, -3.0525e-01, 3.8961e+00, ..., -1.8470e-01,
-1.0056e+00, 6.0956e+00],
[-3.8802e+00, -5.7609e-02, 1.5106e+00, ..., -5.5892e-01,
-2.4805e+00, 6.4896e+00],
[-2.2596e+00, 5.5595e-01, 3.5477e+00, ..., -7.9328e-01,
-2.3755e+00, 5.4057e+00]],
[[ 1.1640e-01, -6.5411e-01, -2.1746e-01, ..., 1.4848e-01,
2.7450e-01, -1.7022e-01],
[ 6.3319e-01, 2.4561e+00, -3.0855e-01, ..., 7.4445e-01,
8.7801e-01, -1.2477e+00],
[ 3.7108e-01, 3.1974e+00, 1.9218e-02, ..., -1.9633e+00,
2.6208e-01, -6.9583e-01],
...,
[-8.3210e-01, 3.1895e+00, -1.0288e-01, ..., -6.3625e-01,
7.0729e-01, 5.2103e-01],
[-5.0214e-02, 3.4132e+00, 1.9143e-01, ..., -3.9255e-01,
1.5079e+00, -1.2919e+00],
[ 7.7568e-01, 3.1053e+00, -2.3148e-01, ..., -7.4599e-01,
4.5266e-01, -3.9879e-01]],
...,
[[-3.8907e-01, 2.1348e-02, 5.1327e-03, ..., 1.2555e+00,
6.2963e-02, 1.7806e+00],
[-4.0056e-01, -1.2783e+00, 6.7603e-01, ..., -1.8325e+00,
-7.7475e-01, -1.6207e+00],
[-9.4218e-02, -2.5972e+00, 1.3495e+00, ..., -1.7981e+00,
-2.3495e+00, -1.4134e+00],
...,
[ 6.4192e-01, -7.7718e-01, 5.8877e-01, ..., -2.8845e+00,
-7.1077e-01, -1.4360e+00],
[ 5.6177e-01, 3.9009e-01, 5.3228e-01, ..., -1.8859e+00,
-2.0232e+00, -2.1431e-01],
[ 4.9119e-01, 2.7792e-01, 9.4523e-01, ..., -3.2013e+00,
-7.1947e-01, 6.3064e-01]],
[[-3.3280e-01, -1.1855e-01, 2.2112e-01, ..., 2.5834e-01,
-3.4137e-02, 1.9058e-02],
[-2.2876e-01, -8.1290e-01, -2.3541e-01, ..., 2.0128e-01,
-7.7041e-02, 1.6235e-01],
[ 3.5779e-01, -1.2911e+00, -2.4718e-01, ..., 8.0769e-01,
1.6120e-01, -4.3466e-01],
...,
[-3.7774e-01, -7.7144e-02, 1.0574e-01, ..., 1.1080e+00,
2.6553e-01, 1.8789e+00],
[-1.5887e+00, -1.8580e+00, 9.7814e-01, ..., 6.1661e-01,
2.0212e+00, 6.2472e-01],
[ 8.5004e-01, -1.1226e+00, 8.0483e-01, ..., 1.1392e+00,
1.7727e+00, 5.1945e-01]],
[[ 3.3855e+00, 2.1669e+00, -2.1247e+00, ..., -2.8605e+00,
-3.8921e+00, -1.1778e+00],
[-9.7418e-01, -4.1155e-01, 2.6531e+00, ..., -4.8420e+00,
1.1042e+01, 2.1530e+00],
[ 1.5355e+00, 5.0286e-01, 5.4398e+00, ..., -2.1528e+00,
9.8180e+00, -2.0278e+00],
...,
[-5.7341e+00, -1.2663e+00, 3.5021e+00, ..., -3.7820e+00,
1.3191e+01, 6.6441e+00],
[-2.8189e+00, -2.5276e+00, 3.8678e+00, ..., -4.2633e+00,
9.9248e+00, 3.7757e+00],
[-3.2653e+00, -1.6548e+00, 5.4356e+00, ..., -5.3807e+00,
1.1478e+01, 3.6439e+00]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[-2.8015e-03, -4.2843e-02, 2.1946e-02, ..., 5.9664e-02,
2.8541e-02, 7.5017e-02],
[ 4.7948e-02, -2.7571e-01, 5.6533e-01, ..., -1.1318e+00,
4.9030e-01, 3.7480e-01],
[-1.7491e-01, -1.1429e-01, 3.9419e-01, ..., 3.6174e-01,
6.4979e-01, -2.4254e-02],
...,
[ 5.3288e-02, -1.5184e-01, -6.9785e-04, ..., 1.3779e-01,
-9.9556e-02, 4.6005e-01],
[ 2.4405e-01, -3.2254e-01, -4.3253e-01, ..., -2.5381e-01,
5.0629e-01, -2.1757e-01],
[ 4.3920e-01, -2.6082e-01, 4.6502e-02, ..., -2.6076e-01,
-4.6805e-01, 8.1107e-02]],
[[-6.7511e-02, -1.7221e-02, -1.4186e-01, ..., -4.4599e-02,
4.3422e-02, -1.4441e-02],
[-1.5872e-01, 2.0812e-01, 2.3949e-01, ..., 3.2546e-01,
2.2967e-01, 3.7792e-01],
[ 6.9327e-01, 9.3650e-02, -1.0073e+00, ..., 7.9522e-01,
-2.5250e-01, -2.6677e-01],
...,
[ 1.6569e-01, -5.1011e-01, 6.5104e-02, ..., -3.7376e-02,
1.4629e-01, 3.6823e-01],
[-3.1743e-01, -3.3498e-01, 3.5960e-01, ..., -8.5223e-01,
-3.3373e-01, -3.6818e-01],
[ 2.5580e-01, -4.3256e-01, -4.2976e-01, ..., -7.5222e-01,
-5.0993e-02, 9.3812e-03]],
[[ 7.0874e-02, 8.7366e-02, 8.0504e-02, ..., 1.9967e-02,
-8.7679e-02, 1.5890e-03],
[-8.3180e-01, 2.3641e-01, -4.9481e-01, ..., -1.0852e-01,
6.0099e-01, -4.9572e-02],
[ 3.7039e-01, 2.2074e+00, 7.2645e-01, ..., -1.3161e-01,
1.3500e+00, 8.9254e-01],
...,
[-5.4997e-01, 1.6803e-02, -2.1176e-01, ..., 5.1826e-01,
-8.0877e-01, -5.2248e-01],
[-1.0285e+00, 4.6817e-01, 7.8376e-01, ..., -7.7642e-01,
1.1818e+00, -3.8882e-01],
[ 5.9641e-02, 1.2021e-01, 2.2528e-01, ..., 1.6642e-03,
-2.0485e-01, 7.6625e-01]],
...,
[[-7.1236e-03, 7.8853e-02, -8.0076e-02, ..., 4.3391e-02,
4.2254e-02, -1.3848e-01],
[ 7.1242e-01, 2.3731e-01, 2.5584e-01, ..., -5.2713e-01,
-5.9722e-01, 6.7882e-01],
[ 2.6651e-01, -3.6351e-01, -3.3804e-01, ..., 2.8267e-01,
-6.4981e-01, 2.8016e-01],
...,
[-1.4381e-01, -1.6843e-01, -3.1744e-01, ..., 3.1899e-01,
4.8887e-03, -5.6601e-02],
[ 6.6585e-01, -5.8674e-01, -6.6238e-01, ..., -5.9955e-01,
9.0017e-02, -6.8978e-01],
[ 5.5239e-01, 9.5414e-01, -5.1927e-02, ..., -2.9619e-01,
5.8290e-01, -2.2422e-01]],
[[-1.2629e-01, -4.8935e-02, 1.1181e-01, ..., -6.4385e-02,
4.6905e-02, -1.4047e-02],
[-7.7673e-01, -8.3885e-01, -9.9223e-01, ..., 7.1531e-03,
-6.5990e-01, 5.3350e-02],
[-1.3458e+00, -1.0354e+00, -6.1002e-01, ..., 9.9597e-01,
3.9742e-01, 2.1499e-01],
...,
[-1.5937e-01, 6.1105e-01, 1.3167e+00, ..., 9.3129e-04,
1.4881e-01, 8.5424e-01],
[ 1.0006e+00, 3.9797e-01, 6.0946e-01, ..., 2.0232e+00,
1.9867e-01, 5.5082e-01],
[ 1.2180e-01, -6.8567e-01, -8.9236e-01, ..., 6.9602e-01,
-1.2220e+00, 2.4029e-01]],
[[-3.7457e-03, -9.2607e-03, -2.4621e-02, ..., -2.7746e-02,
5.3264e-03, -1.2048e-02],
[-3.3242e-02, 5.1789e-01, -9.5424e-02, ..., -1.9423e-02,
9.1498e-02, -3.0327e-01],
[-3.9111e-01, 2.6925e-01, -8.8339e-01, ..., -3.5068e-01,
2.4670e-02, -8.1843e-01],
...,
[ 1.2169e-01, -1.4590e-01, -1.0231e-01, ..., -2.1047e-01,
-9.8497e-02, 3.3168e-01],
[-1.3166e-02, -1.5918e-01, 4.4679e-02, ..., -2.8859e-01,
7.7163e-02, -8.9428e-02],
[-1.5484e-01, 4.7036e-01, 4.8949e-01, ..., -5.6709e-01,
-4.7696e-01, 5.4917e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[ 2.4425e-02, -2.9901e-01, 2.2618e-01, ..., 1.7002e+00,
-2.1425e-01, -7.5452e-02],
[-6.7344e-01, 1.9831e+00, 3.7440e-02, ..., -3.6278e+00,
2.9602e-01, -1.7517e+00],
[ 8.0027e-01, 8.3515e-02, -1.5843e+00, ..., -4.8216e+00,
-1.1108e+00, -2.5017e+00],
...,
[ 6.5923e-01, 5.1771e-01, 1.1880e-01, ..., -3.2431e+00,
-2.4071e+00, -9.8829e-01],
[ 9.1825e-01, -4.9183e-01, 9.0978e-01, ..., -3.6721e+00,
-4.4284e-01, -7.8973e-01],
[-1.8351e+00, -9.1667e-01, 4.5614e-01, ..., -4.2885e+00,
-2.8697e-02, -5.2182e-01]],
[[ 1.6053e-01, 9.7512e-01, -1.4212e+00, ..., -1.1826e-01,
2.6721e-01, 9.2271e-01],
[-6.4489e-01, -4.8048e+00, 1.7009e+00, ..., 4.8935e-01,
6.5187e-01, -2.0910e+00],
[-4.3839e+00, -6.9272e+00, 1.9357e+00, ..., -3.3397e+00,
-2.2280e+00, -4.7667e+00],
...,
[ 6.6274e-02, -4.7992e+00, 2.9413e+00, ..., 1.5957e+00,
-9.6139e-04, -3.1260e+00],
[ 4.0182e-02, -4.1015e+00, 2.5708e+00, ..., 9.0249e-02,
1.7288e-01, -2.9824e+00],
[ 1.5543e+00, -6.1758e+00, 7.6686e+00, ..., 9.0697e-01,
-1.1387e+00, -6.2601e+00]],
[[-6.6854e-01, 2.4147e-01, -4.0604e-02, ..., 1.7321e-01,
3.9602e-02, -2.9668e-01],
[ 2.4421e+00, -9.1462e-01, -8.2885e-01, ..., -8.3236e-01,
-1.9725e-01, -2.7875e-01],
[ 1.0662e+00, -6.5159e-01, -8.8043e-02, ..., -1.3173e+00,
1.0929e+00, -3.4334e-02],
...,
[ 1.3428e+00, -1.4381e-01, 1.3997e+00, ..., -9.9791e-01,
-5.2218e-01, -1.3676e+00],
[ 2.1379e+00, -1.6316e+00, -7.0705e-03, ..., 1.3214e+00,
2.4926e+00, -1.0005e+00],
[ 3.3410e+00, -1.9271e+00, -2.6213e-01, ..., -8.0559e-01,
1.2947e+00, -5.7613e-01]],
...,
[[-2.9694e-02, 1.2897e-01, 1.5171e-01, ..., -1.0306e-01,
2.4162e-02, 1.6164e-01],
[ 8.7051e-01, 8.9658e-01, -2.1684e-01, ..., 1.1384e+00,
-3.0451e-01, 7.1239e-01],
[ 2.0829e-01, -5.3203e-01, 1.8881e-01, ..., 1.1098e+00,
-1.0654e+00, 1.7870e+00],
...,
[ 5.8492e-01, 1.2304e+00, 8.5839e-01, ..., 1.8900e+00,
-2.5703e-02, 6.3271e-01],
[ 8.2500e-01, 5.2305e-01, -2.0326e-01, ..., 8.7598e-01,
7.1087e-01, -8.2672e-01],
[-8.7811e-01, 1.5896e+00, -1.4590e+00, ..., 1.2713e+00,
5.9443e-01, 1.0977e+00]],
[[-3.0105e+00, 4.0288e-01, -3.0686e-02, ..., -4.7620e-01,
-3.6181e-01, 1.2402e+00],
[ 4.3091e+00, -3.6355e-02, -1.2733e+00, ..., -1.3761e-01,
-1.0568e+00, 3.7705e-01],
[ 3.5537e+00, -1.2094e+00, 7.3616e-01, ..., -1.2436e+00,
-1.0722e+00, 1.5382e+00],
...,
[ 5.0508e+00, 9.7999e-01, -9.1501e-01, ..., -8.3978e-01,
-7.1137e-01, -9.0957e-01],
[ 5.3393e+00, 1.0491e+00, 1.3785e+00, ..., 2.3665e-01,
7.1385e-01, -8.4908e-01],
[ 4.7278e+00, -7.3271e-01, 2.4610e-01, ..., -8.1648e-01,
8.1328e-01, -4.9938e-01]],
[[-1.0533e-02, -2.4607e-01, 2.7443e-03, ..., -1.7876e-01,
3.2764e-01, 8.2148e-02],
[ 8.6160e-01, 1.4717e-01, 1.1344e+00, ..., -5.2077e-01,
5.9950e-01, -3.6589e-01],
[ 6.8364e-01, -1.4481e+00, -2.4772e-01, ..., 7.1361e-01,
8.1152e-01, -4.3211e-01],
...,
[-6.3802e-01, -7.4226e-01, 2.0134e-02, ..., -2.2923e-02,
-3.0777e-01, -1.3769e+00],
[ 7.0286e-01, -7.6536e-01, -3.0677e-01, ..., 1.0340e+00,
4.4649e-01, -8.3763e-01],
[-9.1978e-01, 3.4823e-01, 7.0052e-01, ..., -7.4284e-01,
9.8658e-01, -5.9086e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[-0.0260, -0.0201, 0.0090, ..., -0.0063, -0.0343, 0.3522],
[ 1.0229, 0.3651, -0.2132, ..., -1.4306, 0.3113, -0.5238],
[ 1.5326, -0.5122, 0.1662, ..., -0.8095, -0.3197, 0.2882],
...,
[ 0.1370, -0.3330, -0.3261, ..., -0.0797, 0.4292, -1.0034],
[-0.1414, 0.3330, 0.3330, ..., 0.0371, 0.1574, -0.5799],
[-0.8609, -2.4307, -0.8873, ..., -0.3572, -0.6777, -0.1879]],
[[-0.0041, -0.0135, 0.0183, ..., -0.0215, 0.0231, 0.0108],
[ 1.3730, 0.6940, 0.7279, ..., -0.9966, 0.9910, -0.2261],
[ 0.3186, 0.4918, 1.7542, ..., -0.2500, -0.0790, -0.1540],
...,
[-0.5563, -0.5748, 0.8149, ..., 0.5584, 1.4183, 0.3196],
[ 0.8510, -0.7996, -0.0429, ..., -0.2517, 0.5159, -0.4891],
[ 0.3454, -1.2026, -0.2286, ..., -0.5790, 0.3819, 0.1525]],
[[-0.0566, 0.0091, -0.0393, ..., -0.0407, 0.0051, -0.0818],
[-0.6938, -1.2322, -1.1038, ..., -0.1297, 1.1394, -0.9904],
[-0.0594, 0.3513, -0.5823, ..., -0.1540, 1.2982, 1.5220],
...,
[-0.5360, 0.0138, -0.0441, ..., -0.5549, -1.2158, -0.0624],
[-0.2674, -0.5419, -0.1500, ..., -0.7724, 0.6361, -0.8990],
[ 0.1295, 0.5581, 0.0097, ..., 0.6635, -1.1220, -0.1898]],
...,
[[-0.3324, -0.1904, -0.0728, ..., -0.4847, 0.2203, 0.1020],
[ 1.2505, -0.1432, -0.2780, ..., 1.0645, -0.9976, 0.3491],
[-0.6492, -2.1424, 1.6660, ..., -0.4581, -1.0233, -1.4598],
...,
[-0.5372, -0.5731, -0.6616, ..., 0.6041, -0.6421, -0.1693],
[ 2.2314, -0.7433, -0.1765, ..., 0.3478, -1.5081, 0.6716],
[ 1.1512, -1.5820, 0.2295, ..., -0.0358, 0.5782, -0.6780]],
[[-0.0805, -0.1338, -0.0433, ..., -0.1887, -0.1398, 0.1256],
[-0.1725, -0.9856, 0.0283, ..., 0.1987, 0.2887, -0.2125],
[-0.2868, -1.2376, -0.8224, ..., 1.6796, 0.9134, 0.3702],
...,
[-0.2671, 0.1030, -0.0178, ..., -0.5901, -0.1262, 0.1294],
[-0.8149, 0.5106, -0.9661, ..., 1.2941, -0.3053, 0.1000],
[-0.2159, 0.3548, -0.1119, ..., 0.1122, 0.0045, 0.7179]],
[[-0.0283, -0.0370, 0.0912, ..., 0.0785, -0.0412, 0.0123],
[-0.3245, 0.4457, -0.2747, ..., -0.8789, -0.0971, 0.6181],
[-1.7313, 0.3132, -0.7923, ..., -0.9594, -1.0080, 1.0855],
...,
[-0.0484, 0.5334, 0.0299, ..., 0.4920, -0.5462, 0.0646],
[-1.3020, 1.6345, 0.7471, ..., -1.0761, -0.1769, 0.0983],
[-0.3621, 1.1414, -0.7577, ..., -0.3690, 1.1279, 0.0493]]]],
device='cuda:0', grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3578e-01, 8.6494e-01, -1.5782e-01, ..., 1.1243e+00,
-1.7831e-01, 1.3638e-01],
[ 6.2278e-01, -2.4989e+00, 1.1164e+00, ..., -2.6288e+00,
-1.5201e+00, 1.1154e+00],
[ 2.5411e+00, -4.5835e+00, 1.0466e+00, ..., -3.0950e+00,
-1.3118e+00, 2.0691e+00],
...,
[ 1.4910e-01, -5.1488e+00, 1.4900e+00, ..., -3.5783e+00,
-1.2938e+00, 1.6264e+00],
[ 7.1818e-01, -4.5352e+00, 3.0051e+00, ..., -3.4607e+00,
5.2433e-01, 3.7942e-01],
[-3.1768e-01, -4.0519e+00, 3.3414e+00, ..., -4.1427e+00,
3.7519e-01, 4.0434e-01]],
[[ 6.0723e-02, 8.6226e-01, -6.4320e-01, ..., -4.2258e-02,
2.9290e-01, 8.5186e-03],
[-8.5910e-01, 2.4301e+00, -2.0816e-03, ..., -8.2199e-01,
-1.9111e-01, -5.7361e-01],
[-1.7353e-01, -1.0406e-02, 2.9922e+00, ..., 1.7248e+00,
-1.3871e+00, -9.8745e-01],
...,
[ 1.4691e+00, -6.5255e-01, 1.0800e+00, ..., 2.4355e+00,
2.8071e-01, -9.0710e-01],
[-4.0101e-01, -1.8545e-01, 1.2093e+00, ..., 1.2111e-01,
2.2229e-01, -8.8183e-01],
[-1.4461e+00, -1.5290e+00, -7.7958e-01, ..., -1.7141e-01,
1.2369e+00, 1.6131e-01]],
[[-3.1530e-01, 1.2470e-01, -9.8533e-01, ..., -3.5174e-01,
-6.0202e-02, -1.3882e-01],
[-8.9693e-01, 2.8848e-01, 4.4991e+00, ..., 5.4666e-01,
-3.4291e-01, 2.7847e-01],
[-5.8518e-01, -4.2858e-01, 3.9293e+00, ..., -4.4146e-02,
-1.5199e-01, -3.5618e-01],
...,
[-8.3947e-01, -3.5598e-01, 2.7114e+00, ..., 4.2037e-01,
6.0501e-01, 1.4794e+00],
[-1.1543e+00, -7.0097e-01, 3.8364e+00, ..., 9.3257e-01,
5.2762e-01, -4.1934e-02],
[-2.1461e-01, 1.0343e-01, 2.4779e+00, ..., -1.0749e+00,
-1.1717e-01, -8.4882e-01]],
...,
[[ 3.8411e-01, 8.0610e-02, -7.7487e-02, ..., -3.1953e-02,
2.1622e-01, 1.0605e-02],
[-6.5354e-01, -3.1612e-02, -1.0638e+00, ..., -5.8875e-01,
1.0573e+00, -1.2601e-01],
[ 2.7933e-01, 1.6202e+00, 5.9959e+00, ..., -2.2502e+00,
-4.2706e+00, -3.7783e-01],
...,
[-1.3431e-01, -2.3244e+00, -1.0504e+00, ..., -1.6168e+00,
-3.2933e-01, 2.5949e-01],
[ 7.4608e-01, -1.4878e+00, -6.3726e-01, ..., -2.2942e+00,
-9.9182e-02, 7.8056e-02],
[-5.6625e-01, -1.2365e+00, -1.7069e+00, ..., -3.6532e+00,
5.9468e-01, 1.1633e+00]],
[[ 2.0019e-01, 6.5883e-02, 3.2484e-01, ..., 4.1552e-01,
1.3110e-02, 2.2381e-01],
[ 1.0125e+00, 2.4286e-01, 9.5984e-01, ..., -4.9362e-01,
-8.9649e-02, 4.7713e-01],
[ 7.9506e-01, 3.0244e-01, 1.0567e+00, ..., -8.4306e-01,
-1.5668e+00, 3.5901e-01],
...,
[ 5.9118e-01, 1.3707e-01, 1.1906e+00, ..., -1.1724e+00,
-1.5317e-01, -3.6017e-01],
[ 1.5434e+00, 6.0157e-01, 9.3965e-01, ..., -9.6186e-01,
-1.3340e+00, 1.4163e+00],
[ 1.2101e+00, -5.5004e-01, 1.9818e+00, ..., -4.1132e-01,
3.9219e-01, 1.3851e+00]],
[[-3.0144e+00, 5.5559e-01, 5.5824e-01, ..., -9.2879e-01,
3.1234e-01, 2.1565e-01],
[ 7.9342e+00, -1.3099e+00, -1.6237e+00, ..., 2.3727e+00,
-1.3910e-01, 2.4147e-01],
[ 8.7548e+00, 3.1521e-01, -3.0962e+00, ..., 2.1272e+00,
2.3255e+00, -9.4848e-01],
...,
[ 1.0275e+01, -4.5047e-01, -2.9359e+00, ..., 1.8721e+00,
-1.5067e+00, -5.9854e-01],
[ 9.8252e+00, -1.5088e+00, -3.1082e+00, ..., 1.3357e+00,
-8.5859e-01, -1.5895e+00],
[ 8.9484e+00, -3.9686e+00, -1.0389e+00, ..., 2.1457e+00,
1.6358e-01, -7.7658e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[ 5.1141e-02, -4.3904e-02, 1.0421e-02, ..., -7.0718e-02,
4.0156e-03, -8.7774e-02],
[-3.6166e-02, -4.8795e-01, 9.8731e-03, ..., 5.3657e-02,
-3.1082e-01, 2.3718e-01],
[ 2.2784e-01, -1.5567e-01, -3.7655e-01, ..., -8.5816e-01,
-2.3128e-01, 5.7785e-01],
...,
[ 9.1042e-01, 1.2031e-01, 1.9260e-01, ..., 4.4779e-01,
9.7324e-01, -1.5580e+00],
[-7.3678e-03, -7.9632e-03, -1.6174e+00, ..., -2.6624e-01,
1.3078e-01, -6.4462e-01],
[-8.8761e-01, 4.9878e-01, -8.4630e-01, ..., 4.2564e-01,
6.2020e-01, -7.2440e-01]],
[[ 6.9888e-02, 1.9044e-02, -2.6442e-02, ..., -2.8091e-02,
8.5949e-03, -8.9132e-03],
[ 3.0326e-01, 9.8972e-01, -3.9176e-01, ..., 7.6245e-02,
-2.2104e-01, 4.6061e-01],
[-8.6927e-01, 6.3877e-02, 1.8170e-01, ..., -1.3091e-01,
-3.3398e-02, 9.1734e-01],
...,
[-4.1310e-01, 2.4178e-01, -1.0092e-02, ..., 6.3579e-01,
-8.6536e-01, 1.3382e+00],
[-1.0261e+00, 2.0221e-01, 5.9535e-01, ..., 5.4258e-02,
-3.7277e-01, 6.7027e-01],
[ 6.3411e-01, 2.9404e-01, -2.1462e-01, ..., 5.9986e-01,
-5.7667e-01, 2.7974e-01]],
[[ 7.8181e-02, 1.2402e-02, 3.8624e-03, ..., 2.5141e-02,
-7.0518e-02, -6.1704e-02],
[-1.8774e-02, 1.2058e-01, -3.6527e-01, ..., -2.2870e-01,
6.9500e-01, 1.5349e+00],
[-1.6330e+00, 1.2746e+00, 9.4037e-01, ..., -9.1634e-01,
-3.8535e-01, -9.6021e-01],
...,
[-1.0440e-01, -6.7721e-01, -9.5557e-01, ..., -6.2216e-01,
1.2931e-01, 5.3809e-01],
[-1.6070e-01, 8.0107e-01, -6.2008e-01, ..., -5.0174e-01,
1.1805e+00, 9.2189e-01],
[-1.6578e-01, -7.5677e-01, -4.0900e-01, ..., -1.2848e+00,
5.3773e-01, -8.9344e-01]],
...,
[[ 4.7203e-03, 2.5398e-02, 1.7264e-02, ..., -7.3335e-02,
-1.9856e-02, 1.3880e-02],
[-1.7961e+00, -1.7059e-01, -2.1769e-01, ..., 5.8683e-01,
-8.8736e-01, -1.7493e+00],
[-5.4797e-01, -1.0059e+00, -6.7425e-01, ..., 1.4185e-01,
1.2526e+00, -4.5236e-01],
...,
[-7.9344e-01, -6.5969e-01, -4.0526e-01, ..., 6.3055e-01,
-5.9347e-01, -1.1570e+00],
[-2.6411e-01, 3.5041e-01, 9.5562e-01, ..., 9.8356e-01,
6.3222e-01, -1.1216e+00],
[ 1.9708e-01, -8.2632e-01, -1.3375e-02, ..., -6.5847e-01,
-8.4580e-01, -1.7935e-03]],
[[ 4.5729e-02, -1.4005e-02, 2.1403e-02, ..., 3.1315e-02,
-8.0189e-03, 2.6656e-03],
[ 7.4114e-01, 1.2927e+00, -8.6742e-02, ..., 4.5960e-01,
-6.8074e-01, 1.6970e-01],
[ 5.9978e-01, 1.2898e+00, 9.5936e-01, ..., -5.8497e-01,
-6.0349e-01, 6.4145e-01],
...,
[ 3.8285e-01, 8.4960e-02, 4.7635e-01, ..., -5.7194e-01,
-3.3707e-01, -1.5005e-01],
[-9.8847e-01, -8.6587e-01, -1.8648e-01, ..., -1.0971e+00,
2.6862e-01, -2.5423e-01],
[-7.6857e-01, -3.6633e-02, -1.9877e+00, ..., -4.6666e-02,
-8.1418e-01, -2.1952e-01]],
[[ 7.2741e-02, -1.9836e-01, -7.2501e-02, ..., -3.1681e-02,
1.9165e-01, -4.7206e-02],
[-6.0899e-01, -1.0427e+00, -2.6415e-02, ..., -6.7561e-01,
-1.8160e-01, 1.0336e+00],
[ 6.4349e-01, 5.6121e-02, -3.9025e-01, ..., -9.6565e-01,
1.0276e+00, -7.7055e-01],
...,
[-1.0704e-01, -6.4438e-01, -1.1632e-01, ..., 4.5279e-01,
-3.9950e-01, -6.8103e-02],
[ 8.0036e-02, -6.1467e-01, 1.7871e-01, ..., -6.1554e-01,
1.4653e-01, -1.2892e+00],
[ 1.3268e+00, -1.8086e-01, 1.1401e+00, ..., -1.1366e+00,
-4.4702e-01, -4.2245e-02]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0499e+00, -2.6094e-01, -1.4624e-01, ..., 6.4066e-01,
7.3481e-01, -3.1200e-01],
[-4.3803e+00, -1.7161e+00, 1.2953e+00, ..., 4.6987e-01,
-5.0588e+00, 4.3037e-01],
[-2.9824e+00, -1.4321e+00, 1.6383e+00, ..., -1.0036e+00,
-6.1600e+00, -6.0011e-01],
...,
[-6.8617e+00, -4.6154e+00, 2.7333e-01, ..., -8.7715e-01,
-5.9364e+00, -1.4132e+00],
[-6.7327e+00, -2.1619e+00, 8.2593e-02, ..., 2.3139e-01,
-4.8873e+00, -1.5720e+00],
[-6.2051e+00, -3.0823e+00, 8.0622e-01, ..., 4.3842e-01,
-5.3239e+00, -3.8893e-01]],
[[-1.4052e-01, -6.4056e-02, 1.6708e-01, ..., -5.0385e-02,
-8.8839e-01, -1.9052e-01],
[ 2.1709e-01, 1.1619e+00, 5.7067e-01, ..., -1.3953e+00,
3.2450e-01, 1.2646e+00],
[-1.5323e+00, 1.3859e+00, 1.5808e+00, ..., 8.2068e-01,
7.3862e-01, 1.8539e+00],
...,
[ 6.0382e-01, -1.1719e+00, 2.8846e-01, ..., 5.3128e-01,
-1.8594e+00, 1.5157e-01],
[-9.1678e-01, -4.0683e-01, 6.5201e-01, ..., -3.7854e-01,
3.6177e-01, 7.1991e-01],
[-1.1788e+00, -8.7759e-01, -1.1967e+00, ..., -3.4275e-01,
-4.2240e-01, 4.9525e-01]],
[[ 1.9705e-01, 3.1511e-01, 1.1296e+00, ..., -4.5656e-01,
4.3585e-01, -4.9803e-01],
[-2.4683e-01, -1.8572e+00, -1.1624e+00, ..., -3.6835e-01,
-1.8631e+00, 2.6684e+00],
[-8.0515e-01, 1.7749e-01, -2.0522e+00, ..., -9.9005e-01,
-1.4308e+00, 2.0506e+00],
...,
[ 1.7602e+00, -9.8673e-01, -3.5310e+00, ..., -4.7019e-01,
-3.3827e+00, 2.8605e+00],
[-7.9944e-01, -8.1130e-01, -2.5108e+00, ..., -4.9938e-02,
-1.6870e+00, 2.8972e+00],
[-1.8172e-01, -1.3924e+00, -3.6942e+00, ..., 1.2416e+00,
-1.7045e+00, 3.0463e+00]],
...,
[[-1.5123e-01, 7.1820e-02, -2.2811e-01, ..., 4.1857e-04,
1.5596e-01, 2.9428e-02],
[-2.7472e+00, 8.9720e-01, -8.9011e-01, ..., 2.0976e+00,
4.3475e-01, 3.6547e-01],
[-3.9650e+00, -1.7982e+00, 1.3452e+00, ..., -4.0594e-03,
1.0195e+00, 8.0354e-01],
...,
[-9.3269e-01, -1.4746e-01, -1.1404e+00, ..., 1.7352e+00,
-3.7063e-01, -1.1835e+00],
[-1.5874e+00, -4.5612e-01, 5.5681e-01, ..., 1.8682e+00,
4.1865e-01, -1.6453e+00],
[-3.8839e-01, 1.3997e-01, 6.3983e-01, ..., 4.3045e-01,
-1.2466e+00, -2.4118e-01]],
[[-3.5156e-01, -2.1933e+00, 1.1372e-01, ..., -8.6831e-02,
-4.2204e-02, 9.1871e-01],
[ 1.6741e+00, 2.0308e+00, 2.3405e-02, ..., 6.5852e-01,
-1.0524e+00, 1.5907e-01],
[ 1.9101e-01, 5.3070e+00, -5.7532e-01, ..., -3.4579e+00,
-2.8675e+00, -2.0603e+00],
...,
[-6.3685e-01, 3.5661e+00, -5.8024e-01, ..., 1.1246e-01,
-3.5499e+00, 4.4231e-01],
[ 1.2368e+00, 2.0089e+00, -5.1367e-01, ..., -3.6685e-01,
-1.5970e+00, 6.9527e-01],
[ 1.7322e+00, 6.2705e+00, -2.0639e+00, ..., -1.0150e-01,
-2.8591e+00, -1.2347e+00]],
[[ 3.6793e-01, 7.5982e-02, -1.3786e-01, ..., 6.4646e-01,
1.3555e-01, 2.5748e-01],
[ 1.0353e-01, 3.1763e-02, -7.6792e-01, ..., -4.7279e-01,
6.0120e-01, -5.2789e-01],
[-1.2707e+00, -1.6400e+00, 1.3822e+00, ..., -1.4171e+00,
1.8270e+00, -1.3366e+00],
...,
[-1.9160e+00, -1.9547e-01, 1.0904e+00, ..., -1.0474e+00,
-4.5762e-01, -6.0643e-01],
[-1.0998e+00, -9.2082e-01, -2.4341e-01, ..., -2.0234e-01,
-1.1176e-01, -1.7190e+00],
[-2.0762e+00, 5.8654e-01, 2.3925e-01, ..., -2.1948e+00,
-7.5425e-01, -1.1309e+00]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[-3.6090e-02, 5.8404e-02, -4.7430e-02, ..., -1.6448e-02,
6.5290e-03, 2.0813e-02],
[ 3.3981e-01, -2.0113e-01, -4.2505e-01, ..., 2.1069e-01,
-2.2217e-01, 1.4801e-02],
[ 5.1794e-01, -2.6169e-01, -2.3763e-01, ..., 7.0568e-01,
5.7194e-01, -3.8482e-01],
...,
[ 2.5554e-01, 1.4591e-01, 6.5668e-01, ..., -9.7976e-01,
7.0952e-01, -5.4556e-01],
[ 4.0079e-01, -5.0342e-03, -7.9301e-01, ..., 1.2865e-02,
-1.8381e-01, -3.0065e-02],
[-2.7750e-02, -1.1973e+00, -6.5840e-01, ..., -2.8713e-01,
-1.2795e-01, -5.5134e-01]],
[[ 1.0002e-02, -2.6317e-02, 2.8788e-02, ..., 2.0878e-02,
-5.0190e-02, 1.7778e-02],
[ 7.7976e-01, 3.3234e-01, -6.3060e-01, ..., 9.9832e-02,
-2.1819e-01, 6.5724e-01],
[ 9.6898e-01, -9.7757e-02, -1.4808e-01, ..., -8.0822e-03,
-5.7036e-01, 2.0158e-01],
...,
[ 1.8208e-01, 3.9556e-01, -9.8637e-01, ..., 7.4350e-02,
9.4404e-01, 1.4957e-01],
[-4.6217e-01, 1.1374e+00, -1.9512e+00, ..., -1.6531e-01,
7.4679e-01, -1.0709e-01],
[ 8.1311e-02, 9.7903e-01, -1.7733e+00, ..., -1.1529e-03,
2.6891e-01, -2.6389e-01]],
[[ 2.8050e-02, -3.8394e-02, 4.2667e-02, ..., 1.8323e-02,
-7.2851e-03, 6.6780e-03],
[ 8.1693e-03, -6.7934e-01, -1.6569e+00, ..., 2.1917e+00,
-8.0082e-01, 2.0747e+00],
[-9.8902e-02, -9.6721e-01, -1.9402e+00, ..., 2.0296e-01,
-1.9917e+00, -5.7315e-01],
...,
[-1.3532e+00, -4.1404e-01, -9.7093e-01, ..., -2.2521e-01,
-1.0802e+00, 1.7622e-01],
[-8.9913e-01, 8.1342e-01, -4.2660e-01, ..., -4.5484e-01,
-4.4985e-01, -5.2983e-01],
[ 1.4025e+00, 1.8062e-01, -1.2818e+00, ..., 1.9920e-01,
-5.2821e-01, -2.5628e-01]],
...,
[[-1.8594e-01, 9.0682e-02, 4.9690e-02, ..., 4.3440e-02,
3.8206e-02, -1.2885e-01],
[ 7.9794e-01, 6.0255e-01, -9.1034e-01, ..., 1.2564e+00,
2.1356e+00, 1.4350e+00],
[ 1.6348e+00, 2.4815e+00, -7.9229e-01, ..., 4.7704e-02,
4.2612e-01, 4.7895e-01],
...,
[-8.7949e-02, 5.8050e-01, 8.1942e-01, ..., -6.4081e-01,
-8.1574e-01, 3.7447e-01],
[ 2.8842e-01, 4.7873e-01, 1.0562e-01, ..., -3.5749e-01,
4.3560e-01, -3.4558e-01],
[ 5.2590e-01, 2.0273e-02, -2.1104e-01, ..., 4.7948e-01,
-4.2571e-01, -1.4939e+00]],
[[-5.9095e-01, -5.1091e-03, 4.9711e-02, ..., -6.8233e-03,
1.0263e-02, -1.0583e-03],
[-9.2941e-01, -5.9255e-01, 1.0062e+00, ..., -3.8768e-01,
-5.7188e-01, 8.9663e-01],
[-3.1489e+00, -1.1034e-01, -4.2379e-01, ..., 2.2262e+00,
1.2680e+00, 2.7568e-01],
...,
[-1.5769e+00, 1.9565e-01, -5.7385e-01, ..., 1.3197e+00,
-1.3651e-01, -3.4665e-01],
[-7.1268e-01, 9.8680e-01, -7.7558e-01, ..., 7.7028e-01,
1.1927e-03, -6.3204e-01],
[-1.0811e+00, 1.0461e+00, 1.1175e+00, ..., 6.5761e-01,
-1.2024e+00, -8.7618e-01]],
[[ 4.1802e-03, 8.6289e-02, -4.7156e-02, ..., 6.4784e-02,
3.7624e-02, -3.6941e-02],
[ 3.8048e-01, -1.1929e-01, -7.9124e-01, ..., -4.4751e-01,
-1.1959e+00, -4.3578e-01],
[ 2.4210e-01, -7.7073e-01, -3.8349e-01, ..., -7.1458e-01,
-8.4916e-01, 5.2514e-01],
...,
[ 1.4012e+00, 2.0892e-01, -2.9313e-01, ..., -2.8728e-01,
-7.7548e-01, -3.9739e-01],
[ 5.1188e-01, -8.7199e-01, 8.7336e-01, ..., -4.0334e-01,
-3.4130e+00, -6.1112e-01],
[-6.1907e-01, -8.8935e-01, -1.0951e-01, ..., -7.4700e-01,
-2.2108e+00, -3.7028e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-2.6786e-02, -2.3484e+00, 1.7383e-01, ..., -2.3387e-01,
-1.8498e-01, 8.0406e-02],
[-6.3816e-01, 4.5202e+00, -4.8769e-01, ..., -6.2377e-01,
-1.1104e+00, 2.6615e-01],
[ 1.6524e-01, 5.3601e+00, -1.5751e+00, ..., 5.1479e-01,
1.4985e-01, 7.5970e-01],
...,
[ 1.1661e+00, 4.1413e+00, 8.5835e-01, ..., -1.3668e+00,
1.1824e+00, -3.7819e-01],
[ 8.8634e-02, 4.8891e+00, -6.3595e-01, ..., -1.6221e-01,
-6.4306e-01, 1.7413e-01],
[-2.2818e-01, 4.9449e+00, -1.4196e-01, ..., 2.3876e-01,
-7.2030e-01, -4.8695e-01]],
[[-8.1155e-01, 2.2757e-01, 4.6853e-01, ..., -5.2139e-01,
1.0647e+00, 1.1201e+00],
[ 7.4454e-01, -6.9863e-02, -1.2254e+00, ..., -2.5475e-01,
1.8946e+00, -6.1198e-01],
[-4.1235e-01, 1.7778e-01, -8.0618e-01, ..., -1.3256e+00,
7.5310e-01, 1.6034e-01],
...,
[-9.3006e-01, 7.3088e-01, 9.9920e-01, ..., -4.3810e-01,
-6.7822e-01, -1.3033e+00],
[-6.1311e-01, -1.2247e-01, -7.6474e-01, ..., 9.3844e-02,
1.2695e+00, -1.8101e+00],
[-3.0981e-01, -3.6292e-01, -7.4754e-01, ..., 4.7145e-01,
1.9980e+00, -1.5107e+00]],
[[-8.6583e-01, 4.7891e-01, 2.1436e-02, ..., 4.8864e-01,
-2.2923e-01, 1.1553e+00],
[ 2.7614e+00, 6.3526e-01, 1.0754e+00, ..., 2.0836e+00,
4.9057e-01, 1.7316e+00],
[ 1.6238e+00, 1.3626e+00, 2.4669e-01, ..., 1.3097e+00,
7.0308e-01, 7.0389e-01],
...,
[ 5.1406e-01, -7.3369e-02, -5.0398e-01, ..., -4.1342e-01,
9.2301e-01, 2.7006e-01],
[ 8.8824e-01, -1.9462e-01, -2.1208e-01, ..., 4.7626e-01,
1.5474e+00, -2.8075e-01],
[ 1.4880e+00, 9.3271e-01, 1.3420e+00, ..., 1.0434e+00,
4.5726e-01, 5.6136e-01]],
...,
[[-3.0037e-01, -1.3118e-01, 1.4130e-01, ..., 1.9293e-01,
1.7425e+00, -2.8649e+00],
[-1.0344e+00, 2.0608e-01, -3.1066e-01, ..., -6.3949e-01,
-4.2152e+00, 4.7219e+00],
[-1.0119e+00, -3.4674e-01, 6.4960e-01, ..., -1.1140e+00,
-5.8159e+00, 5.5243e+00],
...,
[ 9.9599e-01, 7.7899e-01, 6.0444e-01, ..., -2.3092e-01,
-5.5457e+00, 5.7162e+00],
[-1.0503e-01, -3.1278e-02, 2.6437e-01, ..., -7.0391e-01,
-5.0515e+00, 4.5615e+00],
[-2.0179e+00, -6.2885e-01, 4.4828e-01, ..., -8.9845e-01,
-5.6423e+00, 5.0915e+00]],
[[ 1.8530e-01, 3.6999e-01, 2.1407e-01, ..., -2.1567e-01,
3.0667e-02, -1.3970e-01],
[-4.0770e-01, -1.6711e+00, -1.0513e+00, ..., 1.9773e-01,
1.9950e-01, 1.3481e+00],
[-1.2989e+00, -2.9822e-01, -7.9525e-01, ..., -2.1794e-02,
-3.6022e-02, 1.3295e+00],
...,
[ 3.4041e-02, -1.3368e+00, 3.7133e-01, ..., -1.5534e-01,
4.7558e-01, 1.7969e+00],
[-2.4719e-01, -2.1765e+00, 9.6421e-01, ..., -1.7622e-02,
8.2257e-01, 2.6881e+00],
[-1.2455e+00, -1.3100e+00, 1.0811e+00, ..., -5.0272e-01,
5.9302e-01, 2.6775e+00]],
[[ 3.7035e-01, 1.1777e-01, 6.1432e-01, ..., 5.2012e-01,
5.8044e-01, -3.3140e-01],
[ 9.5881e-02, -1.2821e+00, -6.4381e-01, ..., -2.1402e+00,
-2.8265e+00, 7.5561e-01],
[ 5.9308e-01, 1.6217e+00, -1.8073e+00, ..., -1.3962e+00,
-5.3295e+00, 3.7561e-01],
...,
[ 3.4402e-01, 7.1351e-04, -1.0847e+00, ..., -1.5327e+00,
-4.7644e+00, 1.2317e+00],
[ 2.0750e-01, 4.2706e-01, -8.2447e-01, ..., -2.4184e+00,
-4.3488e+00, 5.6796e-01],
[ 2.1829e-01, 1.2217e+00, -7.6039e-01, ..., -1.6768e+00,
-5.0688e+00, 1.1125e+00]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[ 6.5136e-02, -1.0029e-02, -2.2108e-02, ..., 1.1689e-01,
-7.1039e-02, -3.8524e-02],
[ 6.3169e-01, 5.2662e-01, -3.0022e-01, ..., -2.2283e-01,
2.2716e-01, 8.7254e-01],
[-1.1552e+00, -9.5870e-01, 1.1238e+00, ..., 8.6759e-01,
1.5205e+00, 3.0764e-01],
...,
[ 8.4761e-01, -1.3798e+00, -1.2645e+00, ..., 3.3378e-01,
2.3198e-01, 3.6541e-02],
[ 3.6719e-03, 5.9048e-01, 2.4346e+00, ..., 5.0436e-01,
5.6727e-02, -1.4954e-01],
[-1.0332e-01, -1.7860e-01, 5.9478e-01, ..., 1.8129e-01,
1.1238e-01, 5.6843e-01]],
[[ 1.1602e-02, 2.9395e-02, 5.1380e-02, ..., -4.0433e-04,
5.4506e-03, -5.0795e-03],
[-8.1963e-01, 1.2248e+00, 2.0325e+00, ..., -4.0260e-01,
1.3450e+00, -6.8389e-02],
[ 4.5896e-01, -1.3906e+00, 1.5941e+00, ..., -1.3976e+00,
3.3245e-01, 5.1057e-01],
...,
[ 3.7079e-01, -1.0431e+00, 7.7085e-01, ..., -7.9126e-01,
1.4188e+00, -2.7265e-01],
[-1.4974e+00, -1.0762e+00, -1.5428e-01, ..., 5.5193e-01,
6.3376e-01, -5.6096e-01],
[ 1.4492e+00, -3.8667e-01, -3.8344e-01, ..., 2.6724e-01,
5.4262e-01, -4.2422e-01]],
[[ 5.2304e-02, -3.7324e-02, 6.3114e-02, ..., 5.8532e-02,
-6.3338e-02, -5.8268e-02],
[-7.0165e-01, 5.3844e-01, -5.8273e-01, ..., -5.1593e-01,
-6.2940e-01, -1.8096e-01],
[-2.7908e+00, 1.0558e-01, -5.3334e-01, ..., -1.5726e-01,
2.0836e-01, 6.0709e-01],
...,
[-5.1843e-02, 5.4624e-01, 2.0531e-01, ..., 1.7076e+00,
5.7584e-01, 6.2327e-01],
[-3.6931e-02, -7.4176e-01, 3.6640e-01, ..., 7.2651e-01,
7.1034e-01, -5.0419e-01],
[-1.2319e-01, 1.2123e-01, -4.3683e-01, ..., -2.6635e-01,
1.0085e-01, -5.3599e-01]],
...,
[[-9.0498e-02, -4.4621e-02, 4.2299e-02, ..., -9.2541e-02,
3.0748e-02, 1.1883e-02],
[-6.4144e-01, 1.2397e+00, -1.8991e-01, ..., 6.5226e-01,
1.0988e+00, 8.0597e-01],
[ 1.4071e-02, 1.8777e+00, -5.3901e-01, ..., 6.4912e-01,
5.2703e-01, -1.2243e+00],
...,
[ 2.1232e-01, 1.3418e+00, -1.7877e-01, ..., -1.4944e+00,
-6.7317e-02, 8.2103e-01],
[-9.6293e-01, -1.5243e+00, 1.1825e-01, ..., 8.9055e-01,
-2.7738e-01, 1.1285e-01],
[-6.8869e-01, -6.7001e-01, -1.1208e+00, ..., 8.9605e-01,
1.1449e+00, 1.4222e-01]],
[[ 1.4738e-01, -6.2497e-02, 1.3729e-01, ..., 6.0634e-02,
3.2148e-02, -1.2945e-01],
[ 7.4502e-01, 3.8916e-01, 8.7663e-02, ..., 9.7815e-01,
-1.7061e+00, 6.5918e-01],
[ 4.5599e-01, -8.9115e-01, -1.1831e+00, ..., 6.0562e-01,
-1.3010e+00, 5.5591e-01],
...,
[ 5.1827e-02, 9.4449e-01, -1.0665e+00, ..., -5.5041e-01,
1.6840e+00, 3.5078e-01],
[-2.1572e-01, 1.8307e+00, -1.0656e+00, ..., -2.8209e-01,
1.1047e+00, -3.8316e-01],
[ 3.0666e-01, 1.2266e-01, -4.6971e-01, ..., 7.9809e-01,
7.9670e-01, 9.7310e-01]],
[[ 2.0906e-01, -4.6573e-02, -6.0253e-02, ..., 3.5716e-02,
4.1975e-02, 2.2183e-02],
[ 1.1508e+00, -4.2915e-02, 1.7938e+00, ..., -2.8495e+00,
-1.4776e+00, 1.4759e-01],
[ 1.6741e-01, -6.4878e-01, 1.5115e+00, ..., -2.5187e+00,
-2.1148e+00, 4.7344e-01],
...,
[-6.9654e-01, -5.8817e-01, 1.1112e+00, ..., -5.3462e-01,
2.5145e-01, 1.2652e+00],
[ 4.9990e-01, -6.9669e-01, -7.5295e-01, ..., -7.6324e-01,
1.9860e-01, -1.0063e+00],
[ 8.3529e-02, -2.6978e-01, 5.3723e-01, ..., -6.3546e-01,
-1.3855e+00, 1.8679e+00]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[ 2.5910e-02, -2.6886e-01, -4.6291e-01, ..., 2.9839e-01,
3.2603e-01, 3.5748e-01],
[ 3.1461e-01, 1.1503e+00, -1.2219e+00, ..., -5.6552e-01,
5.9567e-02, 8.4989e-02],
[ 6.7203e-01, 2.6852e+00, -1.6149e+00, ..., -4.6337e-01,
-1.2918e-01, -9.1665e-01],
...,
[-2.3362e-02, -1.1149e+00, -1.5101e-01, ..., -8.8944e-01,
-3.1759e-01, -1.8159e+00],
[ 1.3122e+00, -1.0426e+00, -1.0986e+00, ..., 3.0899e-01,
-8.2710e-01, 4.9917e-01],
[ 9.7887e-01, -1.6260e-01, -3.5367e-01, ..., -3.7958e-01,
-5.1196e-01, 9.1400e-02]],
[[-2.8813e-01, 1.6128e-01, 1.1108e-01, ..., 5.3097e-02,
-1.1501e+00, -1.5568e-01],
[ 2.0048e-01, -6.1896e-01, -9.0740e-01, ..., 1.5258e+00,
5.4791e-02, 1.0740e+00],
[ 1.5828e+00, 2.4231e+00, 3.0855e-01, ..., 7.5064e-01,
1.4106e+00, -2.2212e-01],
...,
[-1.0661e+00, 1.2544e+00, 2.0837e-03, ..., -8.5617e-01,
-8.1316e-01, 4.9132e-01],
[-2.6743e-02, 5.4352e-01, -2.2571e+00, ..., -1.3261e+00,
3.2357e-01, 7.3165e-01],
[-7.4855e-01, 1.7948e+00, -1.0619e+00, ..., -1.2624e+00,
2.3455e-01, 1.4526e-01]],
[[-1.2375e+00, -1.2080e-01, 5.5275e-01, ..., -6.6115e-01,
4.6177e-01, -2.7181e-01],
[ 2.1969e+00, -1.5818e-01, -5.5458e-01, ..., 1.1694e-01,
2.1649e-02, 5.6354e-01],
[ 2.0254e+00, 6.7081e-01, -4.4748e-01, ..., 1.4967e+00,
-1.2851e+00, 6.7248e-01],
...,
[ 1.5326e+00, -3.3540e-01, -1.4324e+00, ..., 1.2954e+00,
-1.1284e+00, 4.7197e-01],
[ 8.7941e-01, 2.4923e+00, -5.3935e-01, ..., 8.4928e-01,
1.0220e+00, 3.7455e-01],
[ 5.1570e-01, 1.6997e+00, -5.6758e-02, ..., 1.1494e+00,
-6.0428e-01, -1.0172e+00]],
...,
[[ 7.9977e-01, -9.0907e-01, -3.9427e-01, ..., -1.0356e+00,
-4.3145e-01, 4.9180e-01],
[-1.4421e-01, -8.1030e-01, -1.1014e+00, ..., 1.0065e+00,
-9.9178e-01, -2.7286e+00],
[ 2.2822e+00, -1.0592e+00, -3.2463e+00, ..., 3.9767e-01,
-1.2019e-01, -1.5965e+00],
...,
[ 1.5727e+00, -1.2017e+00, -1.8086e+00, ..., 2.2871e+00,
-9.3960e-01, -5.4853e-01],
[ 1.1026e+00, -1.0645e+00, -1.2469e+00, ..., 1.6859e+00,
1.0864e+00, -1.8350e+00],
[ 9.6080e-01, -1.1839e+00, -5.1685e-01, ..., 1.9371e+00,
5.2251e-01, -6.6630e-01]],
[[-8.9759e-01, 2.5559e+00, 3.0324e-01, ..., 3.4015e-01,
1.9589e+00, -5.4084e-01],
[ 8.5055e-01, -2.9616e+00, 1.0523e+00, ..., 3.1944e-01,
-4.3024e+00, -5.3262e-01],
[-2.1991e-01, -2.6238e+00, -5.0961e-01, ..., -1.0572e-02,
-5.2822e+00, 5.7309e-01],
...,
[ 1.7774e+00, -3.0262e+00, -6.2749e-02, ..., -3.9953e-01,
-3.5927e+00, -1.8122e+00],
[ 3.4358e-01, -2.7537e+00, 5.5012e-01, ..., 9.0664e-01,
-5.0459e+00, -5.0702e-01],
[ 8.0552e-01, -3.1555e+00, -2.6789e-01, ..., -6.1371e-01,
-4.8818e+00, -4.0512e-01]],
[[-2.0290e+00, -3.6629e-01, -1.1161e+00, ..., -4.0021e-01,
6.4893e-02, 2.4993e-01],
[ 1.8800e+00, 6.3801e-01, 1.8007e+00, ..., 1.2802e+00,
-7.8783e-01, 9.2138e-01],
[ 1.2802e+00, 4.2201e+00, 1.5276e+00, ..., 2.1099e-01,
3.6180e-03, -4.3861e-01],
...,
[ 2.9910e+00, 1.1778e+00, 5.4143e-01, ..., -6.6982e-01,
6.7589e-01, -8.9655e-02],
[ 2.0554e+00, 1.6939e+00, 1.7272e+00, ..., 1.1095e-01,
-4.6233e-01, 9.0875e-02],
[ 3.3871e+00, 8.2065e-02, 5.6282e-01, ..., -1.0472e-01,
-6.4104e-03, 8.9617e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>), tensor([[[[-4.9252e-02, -8.7662e-02, 1.9245e-02, ..., 1.1583e-01,
-2.2480e-02, 2.4090e-02],
[ 7.3998e-01, 9.2088e-02, -1.2068e+00, ..., -1.1689e+00,
9.8164e-01, -3.2493e-01],
[ 1.8654e+00, 1.8521e+00, -1.1428e+00, ..., -7.8538e-01,
-1.3162e-01, -9.1184e-01],
...,
[ 5.7013e-01, 3.7659e-02, 3.2336e-01, ..., -6.8518e-01,
7.8144e-01, 3.1578e-01],
[ 3.5704e-01, -1.2859e+00, 8.4548e-01, ..., 7.8927e-01,
-4.1751e-01, -4.3254e-01],
[-7.8941e-02, 4.2914e-01, 2.8057e-01, ..., 1.4775e+00,
-4.4810e-01, -3.5821e-01]],
[[ 1.7039e-02, 1.7426e-02, -3.8615e-02, ..., 3.8622e-02,
1.0783e-02, 3.5067e-02],
[-1.8710e+00, 2.5931e-01, 1.9364e+00, ..., 1.2818e+00,
-2.1004e+00, 8.0968e-01],
[-2.0571e+00, 2.1252e+00, 7.7616e-01, ..., -2.1614e-01,
3.0920e-01, -1.3792e-01],
...,
[ 5.2768e-01, 7.6009e-01, 1.3738e-01, ..., 7.8285e-02,
-1.0129e+00, -1.0930e+00],
[-1.6376e+00, 1.4974e-01, -3.4155e-01, ..., -5.2083e-01,
-3.6068e-01, -6.0488e-01],
[ 5.6162e-01, -1.3109e-01, -1.0858e+00, ..., -5.4658e-01,
3.5207e-01, 2.5588e-01]],
[[ 3.8827e-02, 3.3007e-02, -8.3370e-02, ..., 2.5783e-03,
-1.6268e-02, 9.1886e-04],
[ 6.7396e-01, -7.4608e-01, -1.1437e+00, ..., -4.9611e-01,
1.8052e-02, -9.8371e-01],
[ 1.3927e+00, 2.0229e-01, 1.1575e+00, ..., 9.9906e-01,
6.1328e-01, 1.2476e+00],
...,
[ 4.2289e-01, 2.9166e-01, -1.7195e-01, ..., -9.1719e-01,
-2.4385e-01, -4.0700e-01],
[ 7.7449e-01, 1.9776e+00, 7.9667e-01, ..., -9.7234e-02,
-1.0700e+00, -2.4079e-01],
[ 6.5026e-01, 3.4379e-01, 1.9756e+00, ..., -7.0553e-01,
-8.3279e-01, 5.4796e-01]],
...,
[[-2.6492e-02, 2.4856e-02, -1.0858e-02, ..., -1.9138e-02,
-1.6549e-02, 2.3261e-02],
[ 1.8255e+00, 2.1934e+00, 6.5594e-01, ..., 2.6818e+00,
8.3834e-01, -6.3335e-01],
[-4.0868e+00, 2.9853e+00, 1.5028e+00, ..., 2.8890e-01,
9.7505e-01, -3.2399e+00],
...,
[-7.2338e-01, -1.4377e+00, -5.4115e-01, ..., 6.3100e-01,
1.2142e+00, -9.7212e-01],
[ 1.6900e-02, 2.4184e-02, -7.9526e-01, ..., 1.2693e+00,
7.9554e-01, -9.4520e-01],
[ 4.1660e-01, 6.7636e-01, -7.9182e-01, ..., 4.9160e-01,
-5.0815e-02, -6.8480e-02]],
[[-6.4021e-02, -3.2721e-02, 2.4496e-02, ..., 2.8996e-03,
-5.4889e-02, -1.0845e-01],
[-4.6883e-01, 1.5272e-01, 9.2087e-02, ..., 8.4295e-01,
-2.0841e+00, -3.0567e-02],
[ 3.6565e-01, -1.4821e+00, 9.6754e-01, ..., -9.3504e-02,
-8.7716e-01, -4.0122e-01],
...,
[-4.2841e-01, 1.2406e+00, -4.3889e-01, ..., 4.2291e-01,
-2.8603e-01, -8.7372e-01],
[ 1.4675e-01, 1.5823e+00, -6.8927e-01, ..., 9.1472e-02,
-2.9227e-01, 1.0596e-01],
[-9.5282e-01, 1.2839e+00, 2.9624e-01, ..., -3.9504e-01,
1.3045e+00, -3.3763e-01]],
[[-3.9009e-04, 4.5721e-02, -4.4665e-02, ..., -1.7350e-02,
1.0824e-02, 2.1425e-03],
[-1.1021e+00, -1.0736e+00, -1.3544e-01, ..., -6.8339e-01,
2.5608e-01, -1.5562e+00],
[-1.9336e+00, -2.8961e-01, -1.8405e-01, ..., 1.8993e+00,
4.5613e-02, -1.1907e+00],
...,
[-3.1445e-01, -6.2644e-02, 3.0309e-01, ..., 1.1828e+00,
-1.3786e-02, -7.7986e-01],
[-1.2743e+00, 4.1847e-01, 1.0675e-01, ..., 1.1067e+00,
1.1204e+00, -5.8878e-01],
[ 5.5352e-01, -1.4446e+00, 6.2596e-01, ..., -2.2840e-01,
3.1991e-01, -1.0959e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5113, 0.4863, -0.8478, ..., -1.0067, -1.3222, 0.2112],
[-0.2776, 0.8382, 1.8349, ..., 2.3225, 0.6648, 0.2540],
[ 1.3453, -0.5017, 1.2809, ..., 4.1645, -1.7425, -0.8501],
...,
[ 1.4667, -0.1365, 2.8200, ..., 2.1293, 0.9827, -1.4491],
[ 0.9350, 1.2324, 1.3160, ..., 2.0101, -0.9095, 0.8805],
[ 2.2349, -0.3502, 2.3561, ..., 4.5559, 0.2889, -0.6012]],
[[ 0.8578, -2.0592, 0.1552, ..., 0.2365, -2.4501, -0.4252],
[-0.5257, 2.9826, -0.0066, ..., 1.6393, -0.1439, -0.3203],
[ 0.1126, 6.1312, -1.8119, ..., -1.3806, 1.5269, 0.0770],
...,
[ 1.8046, 3.1371, -0.0435, ..., -1.1984, 4.0211, -0.3047],
[ 1.9017, 3.2856, 0.5921, ..., -0.0992, 1.5096, -0.1924],
[ 1.6355, 2.5339, -1.0278, ..., 0.2948, 2.4720, -0.0789]],
[[ 1.0126, 0.3729, -0.1703, ..., -0.8086, -1.4068, -0.3650],
[-0.7648, -1.0051, 0.2495, ..., 0.6701, -0.5403, -0.2351],
[-1.9147, -0.8003, 0.3430, ..., 0.9768, -0.6854, 0.0126],
...,
[-1.7370, -0.6033, -1.3323, ..., 1.1623, -1.2578, -1.9024],
[-1.4760, -0.7168, -0.2280, ..., 1.5871, -0.8728, -0.5783],
[-0.7148, -0.8849, -1.7579, ..., 1.1032, -1.6794, -1.3470]],
...,
[[ 0.2134, -0.5814, 0.4874, ..., -0.6810, 1.0950, 0.3049],
[-0.6622, 0.7595, -1.3103, ..., -2.7334, -1.3875, 0.5355],
[-0.9839, 0.5835, -2.7718, ..., -0.7272, -0.6065, 2.6364],
...,
[-0.9617, 1.3015, -1.3307, ..., -2.1932, -0.2873, 0.0981],
[-1.4241, 0.0142, -1.4190, ..., -3.0027, -1.1089, -0.0661],
[-1.7032, 0.8917, -1.3572, ..., -3.3537, 1.2818, 0.3163]],
[[ 0.2561, 0.5523, 0.5580, ..., 0.7246, 0.0652, 0.8349],
[-1.2620, 0.3286, 0.9976, ..., 0.6765, 0.7315, 1.6506],
[ 0.9963, -0.4575, -0.9883, ..., 0.1085, -0.5890, 2.3868],
...,
[-0.1060, 1.5261, 0.9926, ..., -2.7033, -0.0864, -0.1516],
[-3.7611, -0.0674, 1.6440, ..., -0.0794, 0.9523, 1.6051],
[ 0.3993, 0.0166, 0.7194, ..., -0.1722, 0.4801, 0.2899]],
[[-0.7296, 0.3057, -1.5837, ..., -0.3948, 0.2267, -1.3335],
[ 1.0017, -0.9907, -1.0231, ..., 0.4469, 0.4642, -0.7319],
[ 1.4362, -0.7307, -0.9360, ..., -2.7721, -2.5525, -2.4089],
...,
[-1.0814, 0.4487, 0.2135, ..., 2.1611, 3.0587, -1.7137],
[-0.4406, -0.4945, -0.1256, ..., 1.1526, 1.1095, 0.3055],
[-0.2776, 1.5215, -0.5332, ..., 1.1829, 2.3322, -1.0183]]]],
device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[-3.1461e-03, 4.0217e-02, -7.3887e-02, ..., 5.2969e-02,
-2.6784e-02, -8.3422e-02],
[ 6.9965e-01, 8.4726e-01, -1.3827e-02, ..., -1.4325e+00,
-7.6671e-01, 1.1938e-01],
[-6.3505e-01, -1.8155e-02, -1.1006e+00, ..., -9.8568e-01,
-1.9473e+00, 4.5972e-01],
...,
[-1.8922e-02, -1.5553e+00, -9.6255e-01, ..., 3.4609e-01,
7.2390e-01, -7.2770e-01],
[-5.7521e-01, -5.4535e-01, -1.5828e+00, ..., 5.6929e-01,
-6.4923e-01, 6.9475e-01],
[ 1.0637e+00, 4.4950e-01, 5.3918e-02, ..., -4.3740e-01,
-3.2741e-02, -2.3343e-01]],
[[ 4.5011e-02, 1.3549e-02, 2.6805e-02, ..., -3.5686e-02,
-3.5387e-02, -4.4578e-03],
[ 8.8991e-01, 5.9314e-01, -7.2167e-02, ..., -9.4840e-01,
1.0844e+00, 1.2220e+00],
[ 1.2285e+00, 1.8533e+00, 3.6534e-01, ..., -8.7942e-01,
3.1773e+00, -1.4437e-01],
...,
[ 8.9366e-01, 5.2548e-03, -4.4113e-01, ..., 2.4077e-01,
-1.0397e+00, 3.4268e-01],
[-8.1721e-02, -2.7034e-01, -7.1395e-01, ..., -1.3899e-01,
1.1629e+00, -3.0129e-01],
[ 2.9050e-01, 1.2897e+00, -4.2347e-01, ..., -3.6758e-01,
-3.3340e-01, 5.5617e-01]],
[[ 3.8629e-03, 2.0467e-02, -4.0017e-02, ..., -1.7782e-03,
3.0812e-02, 5.9422e-02],
[-1.4746e+00, -7.6598e-01, -4.9963e-01, ..., -1.1237e+00,
4.0888e-01, 3.7366e-02],
[ 3.1216e+00, 3.4719e+00, -1.5450e+00, ..., 3.7155e-01,
1.7554e+00, 1.0825e+00],
...,
[-2.2221e+00, 1.7920e+00, 1.1178e+00, ..., -3.7278e-01,
2.8963e-01, 4.8757e-01],
[ 1.8054e-01, -9.5324e-01, -4.6999e-01, ..., -2.6873e-01,
-3.4931e-01, -2.0291e-01],
[-4.6330e-01, 9.4874e-01, 1.1230e+00, ..., 2.4846e-01,
-5.4433e-01, 1.9094e-01]],
...,
[[-8.9160e-02, 4.4159e-02, -7.4543e-03, ..., -4.1528e-02,
-1.3703e-02, 1.0209e-02],
[ 4.5365e-01, -9.8718e-02, 3.7648e-01, ..., 5.5908e-01,
-1.6181e+00, 2.0091e-01],
[-1.0051e+00, -9.2160e-01, -1.8300e-01, ..., 1.5032e-01,
-3.5601e-01, 1.1393e+00],
...,
[ 3.8664e-01, 8.8380e-01, 6.9114e-01, ..., -7.1387e-01,
-1.8584e+00, -1.8991e+00],
[-2.6471e-01, -1.8073e-01, 1.5080e+00, ..., 1.0224e+00,
-5.7313e-01, -5.2180e-01],
[-1.1091e+00, -1.0401e+00, -1.0134e+00, ..., 6.1705e-01,
-8.1685e-01, -7.5424e-02]],
[[ 7.4462e-02, 1.1415e-02, 6.6753e-02, ..., 2.6585e-02,
5.9258e-02, 2.8430e-02],
[ 8.3152e-01, -2.2886e+00, 2.2605e-01, ..., 3.4404e+00,
-1.3195e+00, -2.8798e+00],
[ 1.4743e-02, 2.1297e-01, -2.4953e+00, ..., -1.3823e+00,
1.4250e+00, -9.5225e-01],
...,
[ 9.1060e-01, 1.6273e+00, 1.8072e+00, ..., -1.6699e+00,
2.7229e+00, -9.9341e-01],
[ 2.0697e+00, -1.0451e+00, -1.6918e+00, ..., -7.6083e-01,
-6.6113e-01, -1.1681e-01],
[ 7.3505e-01, 9.9351e-01, 8.8695e-01, ..., -2.8755e-01,
1.9784e+00, -2.1858e+00]],
[[-1.0551e-01, 3.7056e-02, -5.2910e-02, ..., -8.2814e-02,
8.1320e-02, -1.1580e-02],
[ 7.9286e-01, 1.0883e+00, -2.8319e-01, ..., -1.6656e-01,
8.1304e-01, 2.0001e+00],
[ 7.3818e-01, 2.7335e+00, 2.0108e+00, ..., -8.0340e-01,
-6.2005e-01, 3.4431e-01],
...,
[-2.1225e-01, 1.5986e+00, -4.7704e-01, ..., 1.7525e+00,
2.5042e-01, -7.3810e-01],
[-1.0866e+00, 2.4015e-01, 1.7822e-01, ..., -5.7931e-01,
-3.7555e-01, -4.3796e-01],
[-6.3663e-01, -9.6326e-01, 6.6193e-01, ..., -4.6179e-01,
-1.1061e+00, -3.5319e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7171, -0.3168, -0.2978, ..., 0.1637, 0.3394, -0.5041],
[-0.5649, -0.0569, -0.1757, ..., 0.7798, -0.9593, 0.1274],
[ 0.6217, 0.6556, -0.2829, ..., 1.7264, -0.9492, -0.3003],
...,
[ 0.5127, -0.0550, 0.4555, ..., -0.1798, 0.4189, -0.3237],
[ 1.4306, -0.2665, -0.6924, ..., 0.4072, 0.0356, 0.2801],
[ 0.3647, 0.3681, 0.2000, ..., -0.0545, -0.4125, -0.4987]],
[[ 0.1125, -0.0720, 2.2769, ..., 0.2411, 0.0791, -0.1868],
[ 0.3754, 0.3120, -1.5438, ..., 0.7578, 0.4480, -0.2125],
[ 0.6421, 0.4066, -0.9517, ..., 1.9527, -0.8776, -0.8329],
...,
[ 1.4697, 0.6897, -2.0477, ..., -0.3531, -0.8718, 0.1144],
[ 0.4871, -0.1398, -1.8025, ..., 0.0553, 0.1603, -0.1879],
[ 1.1846, 0.8657, -1.2089, ..., -0.3741, -0.3004, 0.4547]],
[[-0.1971, 1.0731, 0.4710, ..., -0.5335, 0.3173, -0.1099],
[-0.6160, -0.6787, -0.1886, ..., 1.2765, -0.2427, 0.0198],
[-0.1753, -1.3161, 0.6825, ..., 2.5880, -0.0398, -1.8585],
...,
[-0.5711, 0.6618, -0.4905, ..., 1.6065, -0.7117, 0.6734],
[-0.4640, 0.6161, 0.2619, ..., 2.0119, 1.0717, -0.6101],
[-0.5517, 0.4712, -1.2532, ..., 1.3872, 0.1236, -0.9117]],
...,
[[ 0.5514, 0.9770, -0.8534, ..., -0.7274, 0.6981, 0.8435],
[-0.4633, -0.9677, 1.0696, ..., -1.3604, -0.4768, 0.4836],
[ 0.3272, 1.9423, -0.0427, ..., -1.0563, -1.5149, 1.0777],
...,
[-0.5303, -0.3247, 0.4965, ..., -1.4606, 0.8117, -0.3339],
[-0.8559, 0.3017, 0.1236, ..., -1.0954, 1.3234, -0.2345],
[-0.2782, 0.0201, -0.3758, ..., -1.2297, 0.2962, -0.0464]],
[[-0.4173, 0.3735, 0.3492, ..., 0.7022, 0.0200, -0.0874],
[-0.6046, 1.8319, -0.4594, ..., -0.1303, -0.2883, -2.4248],
[ 0.8539, -0.2973, 1.2250, ..., 0.9134, -1.1105, 1.3329],
...,
[-0.9636, 0.0477, -0.1330, ..., 0.8147, -0.9786, 1.4489],
[-0.2483, -0.0954, -1.6667, ..., 1.4701, 0.9545, 0.7833],
[-0.7331, 0.7880, -0.9207, ..., 0.3971, -0.3445, 1.1242]],
[[-0.7286, -0.0169, 0.4363, ..., -0.1016, 0.0149, -0.0735],
[-0.3077, -0.2044, 0.9706, ..., 0.5311, -0.4540, 1.4855],
[ 0.7991, -1.3786, -2.1901, ..., -0.3034, -2.3266, -0.9113],
...,
[ 1.4139, -1.7224, 0.0692, ..., 0.2342, -2.1317, 0.4961],
[-1.1990, -0.0822, 1.2112, ..., -0.0050, -2.2928, 0.5032],
[-0.4787, 0.2465, -0.0146, ..., -0.7903, -1.8445, 0.4185]]]],
device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[ 8.6463e-02, -1.3589e-01, -2.0528e-01, ..., -2.7072e-01,
2.5288e-01, -1.4601e-01],
[-1.3737e+00, 1.6262e+00, -3.9644e-01, ..., 2.1092e+00,
-1.2491e+00, 2.2709e+00],
[ 9.3736e-02, 8.0827e-01, 4.1821e-01, ..., 2.4003e+00,
-6.1362e-01, -8.2722e-01],
...,
[ 4.9494e-01, -1.4327e-01, -9.4550e-01, ..., 2.5292e+00,
-3.0418e+00, 1.7612e+00],
[-9.5852e-01, -4.7755e-01, -8.8891e-02, ..., 4.1432e+00,
-4.8426e-01, 9.5120e-01],
[ 1.7663e+00, -5.3790e-01, 2.6296e-01, ..., 1.9204e+00,
-6.1751e-01, 2.3158e+00]],
[[ 8.4597e-02, -3.0458e-02, 4.1542e-02, ..., -3.1874e-02,
-9.8648e-02, 1.6476e-01],
[-1.0960e+00, 1.6559e+00, -2.8393e-01, ..., 1.0959e+00,
7.6596e-01, -1.3732e+00],
[-4.0602e-02, 4.3989e-01, 1.2150e-01, ..., 1.6780e+00,
-9.5549e-01, -8.1111e-01],
...,
[ 1.4156e+00, -3.6635e-01, 3.3233e-01, ..., 1.0604e+00,
-1.4212e-01, -6.3028e-01],
[ 1.3465e+00, 1.7579e+00, 5.1865e-01, ..., 4.6014e-01,
8.0379e-01, 1.0797e+00],
[ 5.4950e-01, 1.1675e+00, 9.0108e-01, ..., -4.9434e-01,
5.4417e-01, 5.9731e-01]],
[[-1.1752e-03, 2.7713e-02, -4.8585e-02, ..., 4.0027e-02,
2.4395e-02, 5.1549e-02],
[ 6.0015e-01, -6.0283e-01, -3.2347e+00, ..., -7.8959e-01,
8.7755e-01, -2.7308e-01],
[ 1.5926e+00, -3.3798e-02, -1.3016e+00, ..., 1.7952e+00,
-9.1188e-01, 5.9074e-01],
...,
[-3.3754e-01, 4.6598e-01, -1.0240e+00, ..., -1.0126e+00,
-1.6857e+00, -1.9817e+00],
[-7.5659e-01, 1.8637e-01, -1.3627e+00, ..., 5.3148e-01,
-1.6594e+00, -1.3408e+00],
[ 4.6923e-01, -3.9860e-01, -1.3944e+00, ..., 1.0094e+00,
1.0449e-01, -8.4683e-01]],
...,
[[ 5.7201e-04, -5.6053e-03, 1.0676e-01, ..., 7.0169e-02,
-1.8888e-02, 4.3763e-02],
[-7.7385e-01, 2.5806e+00, -7.1687e-01, ..., 1.3104e+00,
1.0888e+00, 2.5205e-01],
[ 2.4686e+00, 4.1175e-01, 2.8059e-01, ..., -1.0749e+00,
3.4196e+00, -2.0002e+00],
...,
[-3.1486e-01, -1.0806e+00, 4.2761e-01, ..., -8.3092e-02,
-1.4934e+00, -2.3053e-01],
[-1.4704e-01, 2.0300e+00, 4.0726e-01, ..., -7.3362e-01,
5.5142e-01, -1.8850e+00],
[-2.5338e-01, 1.3151e+00, -4.0857e-01, ..., -4.7356e-01,
-1.0699e+00, -1.0685e+00]],
[[-2.0099e-01, -6.8959e-02, 7.5206e-02, ..., -5.4001e-02,
4.2309e-02, -9.4743e-02],
[ 1.9182e+00, 7.7161e-03, -1.2696e+00, ..., -1.3677e+00,
3.1515e+00, -3.5326e-01],
[ 4.6426e-02, -1.0435e+00, -3.7551e-01, ..., -2.0669e+00,
8.7258e-01, -2.2646e-01],
...,
[-1.0673e+00, -1.1056e+00, -2.9885e-01, ..., -3.9416e-01,
-2.2161e-01, -4.9800e-01],
[ 5.8389e-01, -2.5311e-02, -1.2006e+00, ..., 1.2141e+00,
-1.7424e+00, -9.9682e-02],
[-9.4210e-01, 2.2320e-01, -4.9680e-01, ..., -3.3195e-02,
-2.9205e-01, -5.4909e-02]],
[[ 1.0339e-01, -1.4692e-01, 1.6582e-01, ..., -1.4900e-01,
-1.5749e-02, -1.8594e-01],
[-1.5155e+00, 5.3209e-01, 6.6421e-01, ..., -5.4169e-01,
2.2736e+00, -2.5641e-01],
[-9.2274e-02, -5.9834e-01, -1.2739e+00, ..., -1.5369e+00,
-5.6380e-01, -1.3742e+00],
...,
[-9.9035e-01, -8.1416e-01, -1.5702e-01, ..., 2.5352e-01,
-2.9972e-01, -1.1698e-01],
[-1.3327e+00, -1.1053e+00, -6.8016e-01, ..., -8.5201e-01,
1.1720e+00, -6.5619e-01],
[ 2.2502e-02, -6.0447e-02, 4.3160e-02, ..., 2.4946e-01,
1.6794e-01, -6.6439e-01]]]], device='cuda:0',
grad_fn=<PermuteBackward0>))), hidden_states=None, attentions=None, cross_attentions=None)
Let's see why token-based indexing is necessary.
In this example, we create two invocations whose inputs have different tokenized lengths. We then (incorrectly) index into the model's input tokens using normal Python indexing.
[3]:
from rich import print

# Trace two prompts of different tokenized length in the same trace and read
# position 0 of each with plain positional indexing (`[:, 0]`).
# This is the deliberately *incorrect* approach: plain indexing does not
# account for any padding added to the shorter prompt.
with model.trace() as tracer:
    with tracer.invoke('The') as invoker:
        incorrect_a = model.transformer.input[0][0][:, 0].save()

    with tracer.invoke('The Eiffel Tower is in the city of') as invoker:
        incorrect_b = model.transformer.input[0][0][:, 0].save()

# The saved values are read once the trace context has exited.
print(f"Shorter input: {incorrect_a.value}")
print(f"Longer input: {incorrect_b.value}")
Shorter input: tensor([50256], device='cuda:0')
Longer input: tensor([464], device='cuda:0')
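Notice how we indexed into the first token for both strings but received a different result from each invoke. This is because, when there are multiple invocations, padding is performed on the left side, so position 0 of the shorter input is a padding token rather than its first real token. The `.token[<idx>]` / `.t[<idx>]` helpers account for this by indexing from the back.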
Let’s correctly index into the input tokens using token-based indexing.
[4]:
# Same two prompts as in the previous cell, but position 0 is now selected
# with the token-aware `.t[0]` helper instead of plain `[:, 0]` indexing.
short_prompt = 'The'
long_prompt = 'The Eiffel Tower is in the city of'

with model.trace() as tracer:
    # One invocation per prompt; both run inside the same trace.
    with tracer.invoke(short_prompt) as invoker:
        correct_a = model.transformer.input[0][0].t[0].save()

    with tracer.invoke(long_prompt) as invoker:
        correct_b = model.transformer.input[0][0].t[0].save()

# The saved values are read once the trace context has exited.
print(f"Shorter input: {correct_a.value}")
print(f"Longer input: {correct_b.value}")
Shorter input: tensor([464], device='cuda:0')
Longer input: tensor([464], device='cuda:0')
Now we have the correct tokens!