Token-Based Indexing

To index hidden states at specific token positions, use `.token[<idx>]` or its shorthand `.t[<idx>]`.

As a preliminary example, let's get a hidden state from the model using `.t[<idx>]`; below, `.t[0]` selects the hidden state of the first token.

[1]:
from nnsight import LanguageModel

# Checkpoint to wrap and where to place its weights; edit these to try another model/device.
MODEL_NAME = 'openai-community/gpt2'
DEVICE_MAP = 'cuda'  # assumes a CUDA-capable GPU is available

# `model` is referenced by the tracing cells that follow.
model = LanguageModel(MODEL_NAME, device_map=DEVICE_MAP)
[2]:
with model.trace('The Eiffel Tower is in the city of') as tracer:

    # Last entry of `transformer.h`; `.output[0]` is the first element of that module's
    # output, and `.t[0]` selects it at token position 0.
    hidden_states = model.transformer.h[-1].output[0].t[0].save()
    # Keep the full model output available for later cells.
    output = model.output.save()
    # Save the logits tensor on its own so we can print a compact shape. The full
    # output object also carries past_key_values, so printing it floods the notebook.
    logits = model.output.logits.save()

print(hidden_states.shape)
print(logits.shape)
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
torch.Size([1, 768])
CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -36.2874,  -35.0114,  -38.0793,  ...,  -40.5163,  -41.3759,
           -34.9193],
         [ -68.8886,  -70.1562,  -71.8408,  ...,  -80.4195,  -78.2552,
           -71.1206],
         [ -82.2950,  -81.6519,  -83.9941,  ...,  -94.4878,  -94.5194,
           -85.6998],
         ...,
         [-113.8675, -111.8628, -113.6634,  ..., -116.7652, -114.8267,
          -112.3621],
         [ -81.8531,  -83.3006,  -91.8192,  ...,  -92.9943,  -89.8382,
           -85.6897],
         [-103.9307, -102.5054, -105.1563,  ..., -109.3099, -110.4195,
          -103.1395]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-0.9420,  1.9023,  0.8722,  ..., -1.2703, -0.4792,  1.2469],
          [-1.9590,  2.7141,  2.8423,  ..., -1.1633, -1.6173,  2.1507],
          [-2.6123,  2.0937,  0.9679,  ..., -0.9763, -1.2243,  2.0279],
          ...,
          [-2.4282,  2.4462,  2.1550,  ..., -0.5916, -1.6641,  2.1119],
          [-3.5624,  3.6804,  2.5053,  ..., -0.3572, -2.5960,  0.9592],
          [-2.6021,  2.8035,  1.7291,  ..., -0.8557, -2.1589,  2.6881]],

         [[ 0.1103,  0.6967, -1.1409,  ..., -0.1243,  1.8249, -0.1592],
          [ 0.3364, -2.3421, -3.0033,  ..., -0.9075,  3.9665,  0.2082],
          [-1.2822, -2.8345,  0.1537,  ...,  0.6516,  2.4424,  0.7518],
          ...,
          [-0.1554, -1.0321, -2.5109,  ..., -0.9747,  4.8222, -1.8171],
          [-1.3993, -2.2428, -0.0644,  ..., -0.9444,  3.5096,  0.4326],
          [ 0.5759, -0.8102, -1.8774,  ..., -1.4308,  3.0181, -2.2393]],

         [[-0.0985, -0.0323,  0.7536,  ..., -1.1902, -1.6401,  0.6545],
          [ 1.1513, -0.7019,  0.2992,  ..., -1.8075, -0.1072,  2.0486],
          [-0.1089, -1.0244,  0.4639,  ..., -1.8416, -0.2348,  1.0322],
          ...,
          [ 0.3554,  0.3485,  0.0083,  ..., -3.3077,  0.8817,  1.4423],
          [ 0.3027,  0.2488, -0.2483,  ..., -2.8617,  0.7589,  0.7380],
          [ 0.5146, -0.1207,  0.6076,  ..., -2.7679,  1.1288,  1.7932]],

         ...,

         [[ 0.6009, -0.0877, -0.2693,  ...,  0.1756,  0.7995,  0.5978],
          [ 0.0456,  0.2891, -0.1535,  ...,  1.0184,  1.0627,  0.3627],
          [-0.0681, -0.5138,  0.5735,  ...,  0.7821,  0.8516,  0.4657],
          ...,
          [-0.0383, -0.2532,  0.0525,  ...,  1.2245,  0.5464,  0.4056],
          [ 0.2111,  0.9947,  0.0403,  ...,  1.1817, -0.7079,  0.5290],
          [ 0.1136, -0.0611,  0.1199,  ...,  1.2025,  0.4589,  0.6644]],

         [[ 1.4709,  1.5225, -0.4336,  ..., -0.1837,  1.0947, -1.6615],
          [ 0.7999,  0.0324, -1.5696,  ..., -0.7550,  1.4671, -1.1099],
          [ 0.6255,  0.4108,  0.0984,  ..., -1.2564,  1.9016, -0.0603],
          ...,
          [ 1.1553,  0.5795, -0.6220,  ..., -0.7993,  0.4428, -0.6729],
          [ 1.5863, -0.0730,  0.1822,  ..., -0.5310,  0.4560, -0.4558],
          [ 0.9170, -0.7168, -0.4214,  ..., -0.8926,  0.4736, -0.0411]],

         [[ 0.6260,  0.2122,  0.2527,  ..., -0.6377,  0.2275,  1.5142],
          [-0.3332,  1.5151, -0.3315,  ...,  1.2160,  0.2653,  2.6735],
          [ 0.1930,  0.0467, -0.3682,  ..., -0.1827,  0.1576,  0.5612],
          ...,
          [-0.1787,  1.2580, -0.2565,  ..., -0.6601,  1.2289,  0.2853],
          [ 0.9067,  0.6444,  0.2020,  ...,  0.1291,  0.2002,  0.8276],
          [-0.5779,  0.4654, -0.8867,  ...,  1.4954,  1.3435, -0.6073]]]],
       device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[-1.3066e-02, -1.4464e-02,  1.2694e-01,  ..., -4.9182e-02,
            1.0464e-01,  2.3067e-02],
          [ 2.0469e-01, -1.3684e-02,  8.0588e-02,  ..., -1.9410e-02,
            5.5186e-02,  7.3562e-02],
          [ 7.2231e-03,  1.8508e-01,  1.0139e-01,  ..., -7.6448e-02,
            2.9932e-01, -8.5228e-03],
          ...,
          [-1.6686e-01,  1.9638e-02,  1.2153e-01,  ...,  6.1965e-02,
            9.3590e-02, -1.0460e-01],
          [ 1.5657e-01,  1.5053e-01,  5.7654e-02,  ..., -4.2498e-01,
           -5.2136e-02,  3.0045e-02],
          [-1.0558e-02, -8.6992e-02, -7.6297e-02,  ..., -6.3531e-02,
           -5.0926e-02,  1.9987e-01]],

         [[ 5.9014e-01,  1.0051e-01, -2.0716e-01,  ..., -6.9383e-01,
           -2.7763e-01,  2.0517e-01],
          [ 6.3339e-01,  1.1631e-01,  2.4300e-01,  ...,  1.9035e-01,
            8.8391e-02, -5.1286e-02],
          [ 2.7510e-01, -7.9842e-02,  2.0712e-01,  ...,  2.0180e-01,
            1.4190e-01, -1.3274e-01],
          ...,
          [ 8.3555e-01, -9.4205e-02,  7.4023e-02,  ..., -1.7617e-01,
            1.3164e-01,  1.1117e-01],
          [ 3.2692e-01,  4.5032e-02,  2.5904e-01,  ...,  7.9349e-02,
            2.0154e-01, -5.9558e-03],
          [ 6.5303e-01, -8.9489e-02, -4.5211e-01,  ...,  7.0391e-04,
            4.9327e-01,  1.5887e-01]],

         [[-2.8404e-02, -1.1449e-01, -2.1676e-02,  ...,  3.9217e-03,
            7.8844e-02, -3.9936e-03],
          [-4.9779e-02,  1.8518e-01, -1.9874e-01,  ..., -4.4753e-02,
           -9.1100e-02, -2.1138e-02],
          [ 4.4283e-02,  5.9255e-02,  5.3522e-02,  ..., -1.4617e-02,
           -3.3558e-01,  1.7041e-01],
          ...,
          [-5.3011e-01, -6.1409e-04, -6.0240e-01,  ..., -1.9026e-01,
           -7.1861e-02,  3.2019e-01],
          [ 4.0913e-01, -1.1379e-01, -1.6436e-01,  ...,  8.2217e-02,
           -1.0437e-01, -7.6691e-02],
          [ 3.1223e-01,  3.6828e-01,  6.0183e-01,  ..., -2.8972e-02,
            1.1367e-01, -2.5661e-01]],

         ...,

         [[-1.0771e-01, -2.1316e-01, -2.1841e-02,  ..., -2.3210e-01,
            2.1270e-02, -6.6547e-02],
          [-3.1855e-01,  4.4327e-01, -1.6764e-01,  ...,  5.3822e-02,
           -7.5202e-02,  1.9941e-01],
          [-1.4129e-01, -6.2872e-02,  2.3989e-01,  ...,  2.2710e-01,
            1.2402e-01,  4.0053e-01],
          ...,
          [-1.0040e-01, -4.9095e-01,  2.2476e-02,  ...,  5.5608e-02,
           -1.4735e-01,  2.3780e-01],
          [-3.0879e-01,  7.1592e-01,  1.2739e-01,  ...,  2.9476e-02,
           -1.5573e-01, -1.7634e-02],
          [-1.6235e-01, -2.5231e-01, -6.0719e-02,  ..., -3.7746e-01,
           -6.9728e-03, -2.2533e-01]],

         [[ 9.3223e-02, -1.0404e-01, -2.1104e-01,  ...,  1.8502e-01,
            2.2378e-01, -3.1989e-02],
          [-4.5714e-01,  6.4180e-02, -1.5538e-01,  ..., -2.6814e-01,
            2.0829e-01,  8.7156e-03],
          [ 2.4634e-03,  1.8372e-01,  7.3725e-03,  ..., -4.8131e-01,
            1.2558e-01,  6.1276e-02],
          ...,
          [-2.2317e-01,  1.2418e-01, -3.6774e-02,  ...,  2.8985e-01,
            5.9641e-02, -8.6951e-03],
          [-1.8944e-01, -1.7414e-02, -2.9084e-02,  ..., -4.5319e-02,
           -5.7796e-02,  4.7680e-01],
          [ 9.2804e-02,  9.9442e-02, -6.0471e-02,  ..., -7.9065e-02,
           -1.6836e-01,  7.1764e-02]],

         [[-2.4770e-02, -3.7828e-01,  1.1838e-01,  ...,  1.1582e-02,
           -2.4843e-01, -1.1559e-01],
          [ 5.8631e-02,  1.6256e-01,  1.3249e-01,  ...,  2.6460e-01,
            9.5267e-02,  1.0518e-01],
          [ 5.0756e-02, -1.4601e-01, -2.3191e-01,  ..., -2.2047e-01,
            3.0730e-01,  2.6307e-01],
          ...,
          [ 2.8193e-02,  1.9202e-01, -8.7550e-02,  ...,  6.9838e-02,
           -4.0262e-02, -5.9197e-03],
          [ 2.6708e-01,  1.3450e-01, -8.2224e-02,  ..., -6.0210e-04,
           -1.4364e-01,  1.5347e-01],
          [ 1.5456e-01, -1.1916e-01,  2.8118e-01,  ...,  1.1415e-01,
            2.5977e-01,  1.8767e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-0.3465,  1.8232, -1.4522,  ...,  1.4427, -0.8329,  1.0962],
          [ 0.5499,  2.4760, -0.3377,  ..., -0.7511, -2.3585, -0.8001],
          [ 1.4355,  3.1091, -0.4843,  ..., -0.2537, -1.2017,  0.3102],
          ...,
          [-0.2717,  1.2955, -0.1360,  ...,  0.4590, -0.8015, -0.3853],
          [-0.8938,  1.0282,  1.2864,  ...,  0.0868, -1.8080, -0.7916],
          [-0.3702,  0.8338,  0.2453,  ...,  0.4522, -1.6840, -0.4447]],

         [[-0.8245, -0.3510, -0.5746,  ..., -0.2983,  0.9754, -0.5511],
          [-0.5012, -0.1667, -1.3464,  ..., -0.2514,  0.3567, -0.8778],
          [-0.6851,  0.3835, -2.6297,  ..., -0.2897,  1.0644,  0.0080],
          ...,
          [-0.1188,  0.6849, -1.4834,  ..., -0.3863,  0.2342, -0.6524],
          [-0.1495, -0.0715, -2.1277,  ..., -0.1427, -1.4979, -0.9024],
          [ 0.2886,  1.0222, -1.8740,  ..., -0.4773, -0.1798, -0.9583]],

         [[ 0.3444,  0.0273,  0.0736,  ..., -1.2545,  0.2919, -0.1958],
          [ 0.3155,  0.4886, -0.0817,  ..., -0.6271, -0.1437,  0.3061],
          [ 0.0337,  0.2995, -0.4871,  ..., -0.6719,  0.3822,  0.1488],
          ...,
          [-0.1552,  0.0079, -0.0330,  ..., -0.7431,  0.4124,  0.2261],
          [ 0.1261, -0.2049, -0.1360,  ..., -1.0366, -0.2439,  0.4584],
          [-0.1678, -0.0727,  0.0116,  ..., -0.8169,  0.1020,  0.1669]],

         ...,

         [[-0.3655, -0.2597, -0.7410,  ..., -0.8024,  0.4248, -0.4208],
          [-0.0382, -0.9516,  2.3876,  ...,  1.0405, -1.6202, -0.3938],
          [ 0.2023,  0.8736,  1.8099,  ..., -0.0363, -0.3845, -0.2317],
          ...,
          [-0.7035,  0.4357,  0.7795,  ..., -1.1353,  0.6211, -0.0097],
          [ 0.8421, -0.1923,  1.8707,  ...,  1.6848,  0.5796, -1.7162],
          [ 1.1454, -0.0537,  1.8554,  ...,  0.6414, -0.5141, -0.9465]],

         [[-1.1789, -2.8546,  0.1095,  ...,  1.7660,  1.5671, -1.5985],
          [ 0.2043,  0.9542, -0.5255,  ..., -0.7796,  0.6046, -0.0179],
          [ 0.0663,  0.5924, -0.6268,  ..., -0.5561,  0.5178,  0.0364],
          ...,
          [ 0.1530,  0.2313, -0.7300,  ..., -0.0290,  0.7031, -0.0167],
          [-0.3269,  0.4781, -0.8025,  ..., -0.3819,  0.8895,  0.2701],
          [-0.0928,  0.2595, -0.7860,  ..., -0.7578,  0.9077,  0.4650]],

         [[ 1.0171,  1.8497,  0.6378,  ..., -0.8035,  0.1293,  0.6040],
          [ 0.0156,  2.1213,  2.3286,  ...,  1.4849,  0.3781, -2.2845],
          [ 0.0148,  2.0272,  2.2091,  ...,  1.7553, -0.1371, -2.4600],
          ...,
          [-0.1495,  1.9038, -0.4051,  ...,  0.6891, -0.3348, -0.9027],
          [ 0.0852,  1.0575, -0.2086,  ...,  0.6766,  0.7277, -1.3756],
          [ 1.3943,  2.3425,  0.3973,  ..., -0.1781, -0.6685, -0.5852]]]],
       device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[ 2.8582e-01, -5.7216e-02,  7.6864e-03,  ...,  1.7198e-01,
           -2.0384e-01, -1.4675e-02],
          [ 5.7858e-02,  3.3398e-01, -4.4943e-01,  ..., -2.9586e-01,
            4.0076e-01,  1.6289e-01],
          [-3.2822e-01, -2.3642e-01,  5.2615e-02,  ..., -5.6616e-02,
            9.3489e-02,  4.1029e-01],
          ...,
          [ 5.1043e-01,  1.1239e-01,  1.1384e-01,  ...,  1.5822e-01,
           -2.3087e-01,  1.6409e-01],
          [-3.5905e-01, -1.7121e-02, -3.3058e-01,  ...,  5.2365e-01,
           -1.0570e-01, -2.3161e-01],
          [ 2.5310e-01,  3.4903e-01,  3.1141e-01,  ..., -1.0517e-01,
           -2.5150e-01,  3.6372e-01]],

         [[ 1.9847e-01, -5.7434e-02, -1.4368e-01,  ...,  4.2654e-02,
           -4.2432e-01,  3.0031e-02],
          [-6.5083e-01,  5.3207e-01,  1.1575e+00,  ...,  1.7449e-01,
           -1.8299e-01, -5.5981e-01],
          [ 5.5086e-01,  1.7719e-01, -4.3369e-01,  ..., -1.6054e-01,
           -6.7744e-01,  3.3818e-01],
          ...,
          [ 5.4245e-01,  1.8869e-01, -3.8369e-01,  ...,  5.7259e-01,
            4.3979e-01, -7.9122e-01],
          [-1.2102e-01,  6.4750e-01, -6.5491e-01,  ...,  4.7697e-01,
           -1.5043e-01,  4.5935e-01],
          [ 8.3022e-01,  1.8053e-01,  2.6405e-01,  ..., -9.6267e-01,
            4.0402e-02,  6.3098e-02]],

         [[ 4.5447e-02, -1.3822e-01, -5.6856e-02,  ..., -5.8098e-01,
            7.3407e-02,  3.3327e-02],
          [ 4.8486e-01,  3.0004e-01,  2.2786e-01,  ..., -6.6910e-01,
           -6.2471e-02, -1.1193e-02],
          [ 6.1447e-01,  1.9622e-01,  2.6388e-01,  ..., -6.4158e-01,
            1.2114e-01, -5.8668e-02],
          ...,
          [ 5.2907e-01,  4.2188e-02,  1.2375e-01,  ..., -3.8698e-01,
           -8.2161e-03,  1.2330e-02],
          [ 5.4567e-01,  3.7348e-01,  3.6478e-01,  ..., -2.9497e-01,
           -1.4520e-01,  1.0892e-01],
          [ 7.6291e-01,  3.8150e-01,  1.0935e-01,  ..., -6.2610e-02,
            3.5521e-01, -2.2801e-01]],

         ...,

         [[ 1.4935e-01,  5.5228e-01, -5.8729e-02,  ...,  1.0577e-01,
           -1.1574e+00, -1.3604e-02],
          [-5.1389e-01,  3.5528e-01, -1.9673e-01,  ..., -3.9717e-02,
           -7.1098e-01, -7.2981e-02],
          [-2.6157e-01,  3.2165e-01,  1.1155e+00,  ..., -1.7352e-01,
           -6.4596e-01, -6.2517e-02],
          ...,
          [ 2.2258e-01,  3.6877e-01,  1.1427e-01,  ...,  3.0327e-02,
           -5.1350e-01,  8.5692e-02],
          [-4.9558e-02,  4.0748e-01,  1.1954e-01,  ...,  3.3549e-01,
           -1.6801e-01, -3.9417e-01],
          [ 1.0553e-01, -6.5443e-01,  7.7109e-02,  ...,  6.8254e-02,
           -5.4058e-01, -2.9753e-01]],

         [[ 3.3667e-01, -2.2411e-01, -1.9435e-01,  ...,  3.1839e-01,
           -3.5641e+00, -1.4767e-01],
          [ 9.3890e-02, -1.3407e-01,  2.1106e-01,  ...,  5.0133e-01,
           -3.0602e-02,  6.8217e-02],
          [-1.3945e-01,  6.5872e-01, -3.0873e-01,  ..., -1.3002e-01,
            4.3156e-02,  1.0483e-01],
          ...,
          [ 2.0381e-01, -1.6960e-01,  1.2699e-01,  ..., -1.8644e-01,
           -5.4254e-02,  2.5791e-02],
          [ 3.7103e-01,  9.3224e-02, -5.2816e-02,  ..., -2.0285e-01,
            3.0910e-01, -4.9643e-02],
          [ 2.4020e-01,  6.1041e-01, -1.3586e-01,  ..., -1.0973e-01,
           -1.0275e-01, -1.1933e-01]],

         [[ 1.0385e-01, -6.1643e-02, -8.1606e-02,  ..., -1.9553e-01,
            1.6988e-01,  7.9386e-02],
          [-1.5159e-01, -3.4469e-02, -4.1837e-01,  ..., -1.9089e-01,
            3.7003e-01,  4.1017e-01],
          [-3.2069e-02,  1.7096e-03,  1.6785e-01,  ..., -1.9622e-01,
           -4.2147e-02,  3.2936e-02],
          ...,
          [ 2.0300e-01,  2.4707e-01, -1.9529e-01,  ..., -3.4163e-01,
            1.1657e-01,  2.1782e-01],
          [ 4.1669e-02,  3.8711e-01,  2.9755e-01,  ..., -1.9858e-01,
            1.6213e-01,  1.9509e-02],
          [-1.7768e-02,  1.2535e-01,  4.7679e-02,  ..., -1.9662e-01,
           -6.9517e-02,  1.8849e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-1.1258e-01, -1.1088e+00,  2.8973e-01,  ..., -6.3421e-01,
           -1.7726e-01,  3.7837e-03],
          [ 4.6296e-02, -2.7619e+00,  1.5496e-01,  ...,  1.8012e-01,
            1.5414e-01,  5.2871e-01],
          [ 1.6581e+00, -3.4251e+00,  1.3515e+00,  ..., -5.9394e-01,
           -1.3542e-01,  8.0459e-01],
          ...,
          [-1.8019e-01, -7.5081e-01, -8.3324e-01,  ...,  7.3177e-01,
           -3.3206e-01,  1.8634e-01],
          [ 5.0296e-01, -2.4352e+00, -4.8550e-01,  ...,  3.0387e-01,
           -5.1740e-01,  4.8016e-01],
          [ 6.5937e-02, -8.9230e-01,  4.7684e-01,  ...,  1.5388e+00,
            5.9400e-02, -2.0516e-01]],

         [[-5.2778e-01,  3.9466e-01, -3.2121e-01,  ...,  1.1577e+00,
           -5.7213e-01, -4.4179e-01],
          [-1.4074e+00,  1.4519e-01, -6.0831e-01,  ...,  5.8401e-01,
            4.8111e-01,  2.3295e-01],
          [ 7.7871e-01,  4.0712e-01, -1.4060e+00,  ...,  8.3845e-01,
            6.0770e-01,  4.0896e-01],
          ...,
          [-1.6353e+00, -5.3745e-01,  9.7606e-01,  ..., -2.2841e-01,
           -2.6943e-01,  1.3457e+00],
          [-1.6745e+00, -2.2275e+00, -1.1276e+00,  ...,  2.0199e-01,
            6.4847e-01, -3.5962e-01],
          [-2.7640e-01,  1.0089e+00, -2.0779e+00,  ...,  1.8831e-01,
            9.1014e-01, -2.6479e-01]],

         [[ 1.2186e+00,  3.0460e+00,  3.7205e+00,  ...,  6.2533e-01,
            1.6636e+00, -7.7160e-01],
          [-1.6707e+00,  1.7284e+00, -4.4143e+00,  ..., -3.4466e+00,
            3.8242e+00,  3.0921e-01],
          [-2.1330e+00,  8.3251e-01, -3.8796e+00,  ..., -3.3404e+00,
            4.1093e+00,  7.7233e-01],
          ...,
          [-4.6060e+00, -7.7079e-01, -2.9684e+00,  ..., -3.9691e+00,
            1.9058e+00,  4.0182e-01],
          [-2.7102e+00, -1.9871e+00, -3.9144e+00,  ..., -3.3751e+00,
            2.4657e+00,  1.6757e+00],
          [-4.6400e+00, -1.8988e+00, -4.4826e+00,  ..., -3.7084e+00,
            2.3719e+00,  7.5402e-01]],

         ...,

         [[ 1.3252e+00, -2.7270e+00, -2.7397e+00,  ...,  9.5941e-01,
            4.3134e-01,  2.7145e+00],
          [-1.3666e+00,  1.3817e+00,  7.7345e-01,  ...,  7.0400e-03,
           -2.1723e+00, -1.2306e+00],
          [-3.3139e+00,  2.7426e+00,  7.5008e-01,  ..., -6.9737e-01,
           -2.8426e+00, -8.4870e-01],
          ...,
          [-3.3808e+00,  3.6548e+00,  1.0053e+00,  ..., -2.0039e-01,
           -3.4632e+00, -7.8443e-01],
          [-3.2769e+00,  3.0825e+00,  1.8556e+00,  ..., -1.5148e+00,
           -4.2361e+00, -2.2974e+00],
          [-3.5654e+00,  5.0539e+00,  1.9514e+00,  ..., -1.8862e+00,
           -3.2081e+00, -1.6764e+00]],

         [[ 1.7317e+00,  4.5925e-01,  9.2368e-01,  ...,  2.3702e-02,
           -1.0098e+00, -3.0809e-01],
          [ 1.9795e+00,  7.9756e-01,  1.2470e+00,  ...,  4.3697e-01,
           -1.4722e+00, -1.6632e+00],
          [ 2.5675e+00,  6.1664e-01,  1.3403e+00,  ..., -1.5458e-01,
           -1.6354e+00, -1.2597e+00],
          ...,
          [ 1.8166e+00,  5.7684e-01,  1.0326e+00,  ...,  3.2898e-01,
           -2.0770e+00, -9.4286e-01],
          [ 2.6763e+00,  3.3288e-01,  1.1887e+00,  ...,  1.4417e-01,
           -1.4898e+00, -8.5120e-01],
          [ 2.0895e+00,  4.4764e-01,  8.9569e-01,  ..., -3.5181e-01,
           -2.2835e+00, -9.7895e-01]],

         [[-2.5208e-01,  1.5814e-01, -5.6552e-01,  ...,  3.1414e-01,
            2.8355e-01,  2.2655e-01],
          [ 1.5601e-01,  1.0595e+00,  1.6655e-01,  ..., -3.0340e-01,
            7.3512e-02, -3.7588e-01],
          [ 4.5730e-01,  7.6759e-01,  6.7810e-02,  ..., -3.8059e-01,
           -4.8543e-02, -1.7695e-01],
          ...,
          [ 3.0563e-01, -4.7971e-01, -7.2877e-01,  ...,  2.4584e-01,
            2.6441e-01,  7.5348e-01],
          [-7.7925e-01,  4.4917e-01, -8.4968e-01,  ...,  6.4183e-01,
            6.5270e-01, -4.6915e-01],
          [ 1.4354e+00, -2.2250e-01,  1.3349e-01,  ..., -1.4484e-02,
            4.1740e-01,  1.9530e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[ 4.2124e-03,  5.0817e-02, -1.2287e-01,  ...,  2.0558e-02,
            4.7877e-02, -5.4793e-01],
          [ 2.3754e-02,  5.2620e-02,  1.8422e-01,  ..., -4.3334e-01,
           -2.6523e-01,  7.3234e-01],
          [-5.8956e-01,  1.4897e-01,  4.0976e-01,  ...,  4.9391e-01,
            9.4187e-01,  5.9001e-01],
          ...,
          [ 7.7799e-01,  1.2943e+00,  2.9846e-01,  ...,  3.6765e-01,
            1.0873e-01,  8.7409e-01],
          [ 5.5421e-01, -6.7936e-01,  3.6075e-01,  ..., -4.4645e-01,
            1.1377e-01,  4.7029e-01],
          [-1.4159e-01,  3.3504e-01, -3.2649e-01,  ...,  2.2850e-02,
           -3.0386e-01,  5.8188e-01]],

         [[ 1.3280e-02, -2.8979e-02,  5.5474e-02,  ..., -4.4208e-02,
            6.2167e-03,  3.4543e-02],
          [ 9.9526e-03,  1.0667e-01, -3.8910e-01,  ...,  8.4887e-03,
           -1.0204e-01,  5.6042e-02],
          [-5.3040e-01, -4.0489e-02, -2.2295e-01,  ...,  2.6226e-01,
           -7.2703e-01, -3.2483e-01],
          ...,
          [ 8.4138e-02,  2.8657e-02,  1.4982e-01,  ...,  6.1274e-02,
           -1.4219e-01,  3.3446e-01],
          [ 6.4882e-02,  1.2559e+00, -6.3594e-01,  ...,  2.0117e+00,
            4.0344e-01, -2.6313e-01],
          [ 3.0646e-01,  5.3148e-02, -6.7682e-03,  ...,  2.5434e-01,
           -1.0172e-03,  2.6347e-01]],

         [[ 1.3993e-02, -8.0731e-01, -3.6560e-02,  ...,  7.3044e-02,
            8.4237e-03, -2.6865e-02],
          [-6.1524e-02, -7.5551e-01,  2.6525e-01,  ...,  2.5373e-01,
            1.0760e-01,  2.4019e-01],
          [-2.8497e-02, -1.3058e+00, -1.9403e-01,  ..., -6.1217e-01,
           -3.0187e-02,  3.7932e-01],
          ...,
          [-2.6733e-01, -1.9554e+00, -3.6751e-02,  ...,  6.9920e-02,
           -3.7356e-02,  7.0360e-02],
          [-1.6400e-01, -1.9207e+00,  7.5180e-01,  ...,  7.4416e-02,
            4.1090e-01, -5.3913e-01],
          [-4.9150e-01, -7.3422e-01, -6.0788e-01,  ..., -5.1220e-01,
            5.3674e-02, -9.2746e-02]],

         ...,

         [[ 1.4164e-02, -7.0936e-02,  1.3435e+00,  ..., -7.3744e-02,
            2.0103e-01, -2.8246e-02],
          [ 7.4353e-03,  1.5017e-01,  1.0513e+00,  ...,  1.1080e-01,
           -1.5743e-01,  6.2338e-02],
          [-2.2182e-01, -4.2225e-01,  1.2976e+00,  ...,  1.1032e-01,
            2.1166e-01, -1.7599e-01],
          ...,
          [-5.7066e-02,  4.6036e-01,  2.5958e+00,  ..., -3.7224e-01,
            1.4200e-01, -5.8588e-01],
          [-1.6861e-01, -1.0492e-01,  1.0817e+00,  ..., -2.9654e-01,
            4.1565e-01,  1.0558e+00],
          [-4.5311e-01,  2.1572e-01,  2.5959e+00,  ...,  1.0544e-01,
           -5.5539e-01, -8.8997e-02]],

         [[-3.1299e-02, -9.0694e-02, -1.5058e-01,  ...,  1.1410e-01,
            1.2924e-01,  1.9263e-01],
          [ 1.3471e-01,  1.8759e-01, -1.1682e-03,  ...,  6.0341e-01,
            6.3717e-02, -2.9939e-02],
          [ 3.2003e-01, -1.8200e-01,  3.5685e-01,  ..., -2.3825e-01,
           -5.0240e-03, -2.7442e-01],
          ...,
          [-6.7477e-01,  1.9573e-02, -2.9475e-01,  ..., -1.5251e-01,
            7.6870e-01,  2.2547e-01],
          [ 1.0340e-02, -2.6907e-02, -3.2965e-02,  ...,  2.9814e-02,
            4.7616e-02, -8.8962e-01],
          [-4.5525e-01,  1.1967e-01,  5.1916e-01,  ...,  1.1242e+00,
           -1.9753e-01, -5.9083e-01]],

         [[ 1.9867e-02,  7.5011e-03,  2.6127e-02,  ..., -3.7370e-03,
            2.2394e-01,  1.3843e-02],
          [-8.1746e-01, -3.7520e-01, -4.9655e-01,  ...,  1.2893e-01,
           -1.7921e+00,  3.8990e-01],
          [ 1.1536e+00, -2.9730e-01,  1.5741e-01,  ...,  2.0536e-02,
           -2.0358e+00,  1.8910e-01],
          ...,
          [ 2.2302e-01, -7.9008e-01, -4.9336e-01,  ...,  4.5226e-01,
           -2.0080e+00, -2.7697e-01],
          [ 1.9959e-01, -9.6224e-02, -4.0063e-01,  ...,  7.3084e-01,
           -2.1308e+00,  1.2104e-01],
          [-4.0285e-01, -3.4758e-02,  9.0270e-02,  ...,  2.2755e-01,
           -1.7261e+00, -1.9400e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 3.0508e-02, -2.2260e-01,  1.5264e-01,  ..., -8.9290e-01,
            7.5289e-01, -1.1935e+00],
          [-2.1560e-01,  1.6692e-01, -9.1999e-02,  ..., -5.7826e-01,
           -1.4156e-01,  3.9989e-01],
          [ 7.9431e-01, -8.6379e-01, -2.1202e-01,  ...,  1.0457e+00,
           -1.7295e+00,  2.5042e+00],
          ...,
          [-2.0043e+00,  2.7582e-01, -7.2417e-01,  ..., -5.6687e-01,
           -8.2403e-01,  5.5432e-01],
          [-2.2761e-01,  1.2259e+00, -2.9970e+00,  ...,  6.2500e-01,
            2.7706e-01,  2.3849e+00],
          [-1.5395e-01,  3.0106e-01, -2.4458e-01,  ...,  1.5246e+00,
           -2.5438e-01,  1.0539e+00]],

         [[ 8.1046e-01,  1.8663e-01, -6.4914e-03,  ..., -1.5619e-01,
           -1.0988e+00, -1.8967e-01],
          [ 1.2836e+00, -1.7571e+00,  5.6523e-01,  ...,  1.6203e-01,
            5.3114e+00, -4.3371e-01],
          [ 1.0058e+00, -1.4609e+00,  7.6332e-01,  ..., -2.0171e-01,
            6.9416e+00,  1.9571e+00],
          ...,
          [ 6.8441e-01, -1.1089e+00, -1.3350e+00,  ...,  2.8493e-01,
            6.0420e+00,  1.7264e+00],
          [ 3.6634e-01, -1.0104e-01, -2.1439e+00,  ...,  2.6209e-01,
            6.2713e+00,  2.7654e+00],
          [ 9.3607e-01, -1.5236e+00, -1.6343e+00,  ..., -1.6642e-01,
            6.7665e+00,  2.8936e+00]],

         [[ 3.4120e-01, -3.7081e-01, -3.0869e-01,  ...,  3.6410e-01,
            1.4399e+00,  2.5755e-01],
          [-8.2960e-01, -6.1458e+00, -9.4992e-01,  ..., -2.7348e+00,
           -3.3308e+00, -6.2488e+00],
          [-1.6905e+00, -5.8666e+00, -1.3552e+00,  ..., -3.8865e+00,
           -2.7315e+00, -6.3022e+00],
          ...,
          [-2.7337e+00, -6.8483e+00, -1.6320e+00,  ..., -2.6826e+00,
           -3.3623e+00, -4.6536e+00],
          [-1.3503e+00, -5.0896e+00, -1.5456e+00,  ..., -3.7086e+00,
           -3.3961e+00, -4.4885e+00],
          [-4.4174e+00, -7.9308e+00, -2.5600e+00,  ..., -3.4487e+00,
           -1.4507e+00, -5.1099e+00]],

         ...,

         [[ 2.2782e-01,  1.7694e+00,  5.3587e-01,  ...,  2.5913e-01,
            4.6895e-01, -1.6825e+00],
          [-1.5099e+00, -4.9805e+00, -4.3378e-01,  ..., -8.1371e-01,
           -1.6987e+00,  6.6884e+00],
          [ 7.1447e-01, -5.3976e+00,  2.6015e-02,  ..., -1.4007e+00,
           -1.7186e+00,  6.7642e+00],
          ...,
          [-2.4670e+00, -7.2994e+00,  1.0229e+00,  ..., -2.6515e+00,
           -1.9696e+00,  6.8497e+00],
          [-1.2711e+00, -5.8679e+00, -9.0149e-01,  ..., -1.3883e+00,
           -1.2481e+00,  7.0832e+00],
          [-1.3249e+00, -7.4333e+00,  9.6733e-01,  ..., -3.6964e+00,
           -3.3296e+00,  8.0495e+00]],

         [[ 5.4848e-02, -2.5668e-02,  1.4165e-01,  ..., -9.0587e-02,
           -8.5854e-02, -1.3366e-01],
          [ 1.8120e+00, -8.8290e-01,  1.5319e+00,  ..., -6.6818e-01,
            2.6287e-01, -2.2123e-01],
          [ 1.7768e+00,  1.8999e-01, -1.2643e+00,  ...,  5.9022e-01,
           -4.8334e-01, -6.0814e-01],
          ...,
          [ 8.9763e-02, -9.0987e-01, -7.7684e-01,  ..., -4.5400e-02,
            5.2621e-01,  3.5608e-01],
          [ 3.4736e-01, -6.2006e-01, -4.0945e-01,  ..., -8.0235e-01,
            7.8035e-01, -3.7293e-01],
          [-8.1439e-01, -7.3308e-01, -2.7807e-01,  ..., -6.0827e-01,
            9.7116e-01,  1.3668e-01]],

         [[ 3.9688e-01, -6.6984e-02,  1.8968e+00,  ..., -2.4403e-01,
           -2.1560e-01, -1.0126e+00],
          [ 2.8065e+00,  1.3883e+00, -2.6807e+00,  ...,  1.2111e+00,
            3.3058e-01,  2.3798e+00],
          [ 2.9053e+00,  6.0659e-01, -2.8907e+00,  ...,  9.0771e-01,
            4.6567e-01,  1.9132e+00],
          ...,
          [ 2.9192e+00,  8.8462e-01, -2.8681e+00,  ...,  1.4205e+00,
           -2.0973e-01,  2.6587e+00],
          [ 2.3439e+00,  1.7387e+00, -1.9872e+00,  ...,  2.2289e+00,
           -1.4917e+00,  3.7927e+00],
          [ 2.2603e+00,  1.6628e+00, -2.2689e+00,  ...,  8.3023e-01,
            9.3067e-01,  4.7189e+00]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[ 4.1882e-02,  6.4490e-02, -9.3272e-03,  ...,  1.6031e-02,
            1.0052e-01,  3.5980e-02],
          [ 3.2562e-01, -8.8795e-01, -4.9085e-02,  ...,  3.6603e-01,
            4.3858e-01, -6.1412e-01],
          [ 7.5744e-02, -4.6294e-01,  1.1798e+00,  ...,  1.4153e-02,
            1.2599e-01, -8.4334e-01],
          ...,
          [ 1.8532e-01,  6.7643e-02, -4.8084e-01,  ...,  3.4835e-01,
           -1.7105e-01, -2.2380e-01],
          [-2.3345e-01, -3.6575e-01,  3.9444e-02,  ..., -3.5445e-01,
           -2.8266e-01, -3.4278e-01],
          [-5.4808e-01, -3.0808e-01, -8.5808e-01,  ...,  2.2267e-01,
           -2.2272e-01,  1.7892e+00]],

         [[-3.3985e-02, -1.3486e-03,  8.3480e-02,  ..., -4.4081e-02,
           -3.3619e-02, -5.0814e-02],
          [ 1.6667e-01,  2.3407e-01,  3.0234e-01,  ...,  2.8978e-01,
            6.3130e-01,  3.1103e-01],
          [ 3.4845e-01, -1.3173e+00,  3.9288e-01,  ...,  3.7700e-01,
            5.9143e-01, -7.0978e-01],
          ...,
          [ 3.0984e-01,  8.3614e-02,  1.8268e-01,  ..., -5.9535e-02,
            1.1795e-01,  2.4208e-01],
          [ 1.1407e-01, -1.4961e+00, -4.7466e-01,  ...,  1.1233e+00,
            4.0220e-01,  2.3130e-01],
          [ 3.3782e-01, -7.3522e-01, -2.8408e-01,  ...,  1.5294e-01,
            1.7005e-01, -1.5641e-01]],

         [[ 3.5987e-02, -1.0232e-01, -4.4135e-02,  ..., -2.5785e-02,
            8.6479e-02, -1.4895e-01],
          [-4.2532e-01, -8.7885e-02,  4.6666e-01,  ..., -1.6055e-01,
            1.9573e-01, -3.9598e-01],
          [-2.4160e-01, -2.7997e-01,  6.6099e-02,  ..., -1.2111e-01,
            3.1895e-01, -8.4615e-02],
          ...,
          [-2.6233e-01, -4.7053e-01,  6.4113e-01,  ...,  3.2162e-01,
           -2.1453e-01, -1.4517e-01],
          [ 3.8219e-01,  4.5365e-02,  1.6591e-02,  ...,  2.6456e-01,
           -2.6929e-01,  2.3254e-01],
          [-1.4182e-01,  1.8406e-01, -9.6458e-01,  ..., -5.1126e-01,
            4.8666e-01, -2.5828e-02]],

         ...,

         [[-2.3814e-02,  1.2032e-01, -7.7499e-03,  ..., -2.3475e-02,
            5.9122e-02, -4.1239e-02],
          [ 5.3979e-01,  6.3833e-02, -8.8802e-01,  ...,  2.3857e-01,
           -3.4962e-01, -1.8191e-01],
          [ 3.6208e-01, -8.7884e-02,  2.2842e-01,  ...,  2.4545e-02,
           -2.7802e-01,  3.5545e-01],
          ...,
          [-2.0331e-02, -1.8373e-01, -2.5179e-01,  ..., -4.0274e-02,
            7.3222e-02, -3.7319e-04],
          [-1.8016e-01, -3.4777e-01, -3.3174e-01,  ...,  4.5794e-01,
           -1.3380e-01, -1.7557e-01],
          [ 1.7746e-01, -3.2240e-01,  2.8709e-01,  ..., -1.9587e-01,
            3.0646e-01, -3.8693e-01]],

         [[-1.5903e-01, -1.2174e-01, -6.9957e-02,  ..., -2.3985e-01,
           -1.6513e-02, -3.9757e-02],
          [-6.4740e-02, -2.3349e-02,  3.1592e-01,  ...,  1.0429e+00,
           -8.8360e-02,  8.3591e-01],
          [-2.1363e+00, -7.1707e-01, -5.2893e-01,  ...,  7.2560e-01,
           -1.5213e+00,  9.7697e-01],
          ...,
          [ 2.0967e-01, -6.3169e-01,  1.0254e+00,  ...,  3.4685e-01,
            1.1476e-01,  6.6754e-01],
          [ 3.9468e-01,  3.6171e-01,  1.4690e+00,  ...,  6.8382e-01,
            2.1922e-01,  6.0305e-01],
          [ 5.0124e-01,  4.3996e-01,  9.9185e-01,  ...,  2.1516e-01,
           -1.6881e-01,  3.6948e-01]],

         [[ 1.1859e-01, -8.5552e-02, -2.7463e-02,  ..., -1.7703e-02,
           -9.0103e-02, -9.4804e-02],
          [ 1.4253e-01,  4.5364e-01, -3.9422e-01,  ...,  2.6342e-01,
            6.0791e-01,  2.2922e-02],
          [ 2.5674e-01, -4.9790e-02,  2.4432e-01,  ...,  6.9650e-01,
            3.0487e-01, -2.1513e-02],
          ...,
          [-9.6273e-02, -1.2066e+00, -8.2821e-02,  ..., -2.2924e-01,
            4.0816e-01, -3.1922e-01],
          [-3.7043e-01,  3.5381e-01,  3.8201e-01,  ...,  1.6970e-01,
           -4.3569e-01, -1.4287e-01],
          [ 7.0933e-02, -1.3498e-01,  4.6551e-02,  ..., -5.1041e-01,
           -7.0377e-02,  1.3301e-02]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-8.8230e-01, -1.4005e-01,  3.3679e-01,  ..., -9.7407e-01,
            1.9440e-02, -2.9572e+00],
          [ 1.0084e+00,  1.0470e+00, -4.7232e+00,  ..., -2.5201e+00,
           -1.9014e+00,  8.8033e+00],
          [ 1.2540e+00,  4.6455e-01, -3.4971e+00,  ..., -1.8074e+00,
           -1.3556e-01,  1.0052e+01],
          ...,
          [-1.1896e-01, -4.2050e-01, -2.9757e+00,  ..., -2.9992e+00,
           -2.6175e+00,  8.8424e+00],
          [-3.3484e-01,  4.7369e-01, -4.5834e+00,  ..., -1.6147e+00,
           -2.4933e+00,  1.0348e+01],
          [ 9.3118e-01, -2.6527e+00, -8.8920e-01,  ..., -8.0500e-01,
           -4.1080e+00,  9.5769e+00]],

         [[ 3.7102e-01, -7.1686e-02,  4.7841e-01,  ..., -1.3458e-01,
           -6.9245e-02, -2.2292e+00],
          [-2.2462e+00,  1.5379e+00,  3.5270e+00,  ..., -1.5324e+00,
           -3.7864e+00,  6.7854e+00],
          [-4.6717e-01,  8.8104e-01,  2.0514e+00,  ...,  5.7600e-01,
           -2.9924e+00,  7.1979e+00],
          ...,
          [-3.3413e+00, -3.0525e-01,  3.8961e+00,  ..., -1.8470e-01,
           -1.0056e+00,  6.0956e+00],
          [-3.8802e+00, -5.7609e-02,  1.5106e+00,  ..., -5.5892e-01,
           -2.4805e+00,  6.4896e+00],
          [-2.2596e+00,  5.5595e-01,  3.5477e+00,  ..., -7.9328e-01,
           -2.3755e+00,  5.4057e+00]],

         [[ 1.1640e-01, -6.5411e-01, -2.1746e-01,  ...,  1.4848e-01,
            2.7450e-01, -1.7022e-01],
          [ 6.3319e-01,  2.4561e+00, -3.0855e-01,  ...,  7.4445e-01,
            8.7801e-01, -1.2477e+00],
          [ 3.7108e-01,  3.1974e+00,  1.9218e-02,  ..., -1.9633e+00,
            2.6208e-01, -6.9583e-01],
          ...,
          [-8.3210e-01,  3.1895e+00, -1.0288e-01,  ..., -6.3625e-01,
            7.0729e-01,  5.2103e-01],
          [-5.0214e-02,  3.4132e+00,  1.9143e-01,  ..., -3.9255e-01,
            1.5079e+00, -1.2919e+00],
          [ 7.7568e-01,  3.1053e+00, -2.3148e-01,  ..., -7.4599e-01,
            4.5266e-01, -3.9879e-01]],

         ...,

         [[-3.8907e-01,  2.1348e-02,  5.1327e-03,  ...,  1.2555e+00,
            6.2963e-02,  1.7806e+00],
          [-4.0056e-01, -1.2783e+00,  6.7603e-01,  ..., -1.8325e+00,
           -7.7475e-01, -1.6207e+00],
          [-9.4218e-02, -2.5972e+00,  1.3495e+00,  ..., -1.7981e+00,
           -2.3495e+00, -1.4134e+00],
          ...,
          [ 6.4192e-01, -7.7718e-01,  5.8877e-01,  ..., -2.8845e+00,
           -7.1077e-01, -1.4360e+00],
          [ 5.6177e-01,  3.9009e-01,  5.3228e-01,  ..., -1.8859e+00,
           -2.0232e+00, -2.1431e-01],
          [ 4.9119e-01,  2.7792e-01,  9.4523e-01,  ..., -3.2013e+00,
           -7.1947e-01,  6.3064e-01]],

         [[-3.3280e-01, -1.1855e-01,  2.2112e-01,  ...,  2.5834e-01,
           -3.4137e-02,  1.9058e-02],
          [-2.2876e-01, -8.1290e-01, -2.3541e-01,  ...,  2.0128e-01,
           -7.7041e-02,  1.6235e-01],
          [ 3.5779e-01, -1.2911e+00, -2.4718e-01,  ...,  8.0769e-01,
            1.6120e-01, -4.3466e-01],
          ...,
          [-3.7774e-01, -7.7144e-02,  1.0574e-01,  ...,  1.1080e+00,
            2.6553e-01,  1.8789e+00],
          [-1.5887e+00, -1.8580e+00,  9.7814e-01,  ...,  6.1661e-01,
            2.0212e+00,  6.2472e-01],
          [ 8.5004e-01, -1.1226e+00,  8.0483e-01,  ...,  1.1392e+00,
            1.7727e+00,  5.1945e-01]],

         [[ 3.3855e+00,  2.1669e+00, -2.1247e+00,  ..., -2.8605e+00,
           -3.8921e+00, -1.1778e+00],
          [-9.7418e-01, -4.1155e-01,  2.6531e+00,  ..., -4.8420e+00,
            1.1042e+01,  2.1530e+00],
          [ 1.5355e+00,  5.0286e-01,  5.4398e+00,  ..., -2.1528e+00,
            9.8180e+00, -2.0278e+00],
          ...,
          [-5.7341e+00, -1.2663e+00,  3.5021e+00,  ..., -3.7820e+00,
            1.3191e+01,  6.6441e+00],
          [-2.8189e+00, -2.5276e+00,  3.8678e+00,  ..., -4.2633e+00,
            9.9248e+00,  3.7757e+00],
          [-3.2653e+00, -1.6548e+00,  5.4356e+00,  ..., -5.3807e+00,
            1.1478e+01,  3.6439e+00]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[-2.8015e-03, -4.2843e-02,  2.1946e-02,  ...,  5.9664e-02,
            2.8541e-02,  7.5017e-02],
          [ 4.7948e-02, -2.7571e-01,  5.6533e-01,  ..., -1.1318e+00,
            4.9030e-01,  3.7480e-01],
          [-1.7491e-01, -1.1429e-01,  3.9419e-01,  ...,  3.6174e-01,
            6.4979e-01, -2.4254e-02],
          ...,
          [ 5.3288e-02, -1.5184e-01, -6.9785e-04,  ...,  1.3779e-01,
           -9.9556e-02,  4.6005e-01],
          [ 2.4405e-01, -3.2254e-01, -4.3253e-01,  ..., -2.5381e-01,
            5.0629e-01, -2.1757e-01],
          [ 4.3920e-01, -2.6082e-01,  4.6502e-02,  ..., -2.6076e-01,
           -4.6805e-01,  8.1107e-02]],

         [[-6.7511e-02, -1.7221e-02, -1.4186e-01,  ..., -4.4599e-02,
            4.3422e-02, -1.4441e-02],
          [-1.5872e-01,  2.0812e-01,  2.3949e-01,  ...,  3.2546e-01,
            2.2967e-01,  3.7792e-01],
          [ 6.9327e-01,  9.3650e-02, -1.0073e+00,  ...,  7.9522e-01,
           -2.5250e-01, -2.6677e-01],
          ...,
          [ 1.6569e-01, -5.1011e-01,  6.5104e-02,  ..., -3.7376e-02,
            1.4629e-01,  3.6823e-01],
          [-3.1743e-01, -3.3498e-01,  3.5960e-01,  ..., -8.5223e-01,
           -3.3373e-01, -3.6818e-01],
          [ 2.5580e-01, -4.3256e-01, -4.2976e-01,  ..., -7.5222e-01,
           -5.0993e-02,  9.3812e-03]],

         [[ 7.0874e-02,  8.7366e-02,  8.0504e-02,  ...,  1.9967e-02,
           -8.7679e-02,  1.5890e-03],
          [-8.3180e-01,  2.3641e-01, -4.9481e-01,  ..., -1.0852e-01,
            6.0099e-01, -4.9572e-02],
          [ 3.7039e-01,  2.2074e+00,  7.2645e-01,  ..., -1.3161e-01,
            1.3500e+00,  8.9254e-01],
          ...,
          [-5.4997e-01,  1.6803e-02, -2.1176e-01,  ...,  5.1826e-01,
           -8.0877e-01, -5.2248e-01],
          [-1.0285e+00,  4.6817e-01,  7.8376e-01,  ..., -7.7642e-01,
            1.1818e+00, -3.8882e-01],
          [ 5.9641e-02,  1.2021e-01,  2.2528e-01,  ...,  1.6642e-03,
           -2.0485e-01,  7.6625e-01]],

         ...,

         [[-7.1236e-03,  7.8853e-02, -8.0076e-02,  ...,  4.3391e-02,
            4.2254e-02, -1.3848e-01],
          [ 7.1242e-01,  2.3731e-01,  2.5584e-01,  ..., -5.2713e-01,
           -5.9722e-01,  6.7882e-01],
          [ 2.6651e-01, -3.6351e-01, -3.3804e-01,  ...,  2.8267e-01,
           -6.4981e-01,  2.8016e-01],
          ...,
          [-1.4381e-01, -1.6843e-01, -3.1744e-01,  ...,  3.1899e-01,
            4.8887e-03, -5.6601e-02],
          [ 6.6585e-01, -5.8674e-01, -6.6238e-01,  ..., -5.9955e-01,
            9.0017e-02, -6.8978e-01],
          [ 5.5239e-01,  9.5414e-01, -5.1927e-02,  ..., -2.9619e-01,
            5.8290e-01, -2.2422e-01]],

         [[-1.2629e-01, -4.8935e-02,  1.1181e-01,  ..., -6.4385e-02,
            4.6905e-02, -1.4047e-02],
          [-7.7673e-01, -8.3885e-01, -9.9223e-01,  ...,  7.1531e-03,
           -6.5990e-01,  5.3350e-02],
          [-1.3458e+00, -1.0354e+00, -6.1002e-01,  ...,  9.9597e-01,
            3.9742e-01,  2.1499e-01],
          ...,
          [-1.5937e-01,  6.1105e-01,  1.3167e+00,  ...,  9.3129e-04,
            1.4881e-01,  8.5424e-01],
          [ 1.0006e+00,  3.9797e-01,  6.0946e-01,  ...,  2.0232e+00,
            1.9867e-01,  5.5082e-01],
          [ 1.2180e-01, -6.8567e-01, -8.9236e-01,  ...,  6.9602e-01,
           -1.2220e+00,  2.4029e-01]],

         [[-3.7457e-03, -9.2607e-03, -2.4621e-02,  ..., -2.7746e-02,
            5.3264e-03, -1.2048e-02],
          [-3.3242e-02,  5.1789e-01, -9.5424e-02,  ..., -1.9423e-02,
            9.1498e-02, -3.0327e-01],
          [-3.9111e-01,  2.6925e-01, -8.8339e-01,  ..., -3.5068e-01,
            2.4670e-02, -8.1843e-01],
          ...,
          [ 1.2169e-01, -1.4590e-01, -1.0231e-01,  ..., -2.1047e-01,
           -9.8497e-02,  3.3168e-01],
          [-1.3166e-02, -1.5918e-01,  4.4679e-02,  ..., -2.8859e-01,
            7.7163e-02, -8.9428e-02],
          [-1.5484e-01,  4.7036e-01,  4.8949e-01,  ..., -5.6709e-01,
           -4.7696e-01,  5.4917e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 2.4425e-02, -2.9901e-01,  2.2618e-01,  ...,  1.7002e+00,
           -2.1425e-01, -7.5452e-02],
          [-6.7344e-01,  1.9831e+00,  3.7440e-02,  ..., -3.6278e+00,
            2.9602e-01, -1.7517e+00],
          [ 8.0027e-01,  8.3515e-02, -1.5843e+00,  ..., -4.8216e+00,
           -1.1108e+00, -2.5017e+00],
          ...,
          [ 6.5923e-01,  5.1771e-01,  1.1880e-01,  ..., -3.2431e+00,
           -2.4071e+00, -9.8829e-01],
          [ 9.1825e-01, -4.9183e-01,  9.0978e-01,  ..., -3.6721e+00,
           -4.4284e-01, -7.8973e-01],
          [-1.8351e+00, -9.1667e-01,  4.5614e-01,  ..., -4.2885e+00,
           -2.8697e-02, -5.2182e-01]],

         [[ 1.6053e-01,  9.7512e-01, -1.4212e+00,  ..., -1.1826e-01,
            2.6721e-01,  9.2271e-01],
          [-6.4489e-01, -4.8048e+00,  1.7009e+00,  ...,  4.8935e-01,
            6.5187e-01, -2.0910e+00],
          [-4.3839e+00, -6.9272e+00,  1.9357e+00,  ..., -3.3397e+00,
           -2.2280e+00, -4.7667e+00],
          ...,
          [ 6.6274e-02, -4.7992e+00,  2.9413e+00,  ...,  1.5957e+00,
           -9.6139e-04, -3.1260e+00],
          [ 4.0182e-02, -4.1015e+00,  2.5708e+00,  ...,  9.0249e-02,
            1.7288e-01, -2.9824e+00],
          [ 1.5543e+00, -6.1758e+00,  7.6686e+00,  ...,  9.0697e-01,
           -1.1387e+00, -6.2601e+00]],

         [[-6.6854e-01,  2.4147e-01, -4.0604e-02,  ...,  1.7321e-01,
            3.9602e-02, -2.9668e-01],
          [ 2.4421e+00, -9.1462e-01, -8.2885e-01,  ..., -8.3236e-01,
           -1.9725e-01, -2.7875e-01],
          [ 1.0662e+00, -6.5159e-01, -8.8043e-02,  ..., -1.3173e+00,
            1.0929e+00, -3.4334e-02],
          ...,
          [ 1.3428e+00, -1.4381e-01,  1.3997e+00,  ..., -9.9791e-01,
           -5.2218e-01, -1.3676e+00],
          [ 2.1379e+00, -1.6316e+00, -7.0705e-03,  ...,  1.3214e+00,
            2.4926e+00, -1.0005e+00],
          [ 3.3410e+00, -1.9271e+00, -2.6213e-01,  ..., -8.0559e-01,
            1.2947e+00, -5.7613e-01]],

         ...,

         [[-2.9694e-02,  1.2897e-01,  1.5171e-01,  ..., -1.0306e-01,
            2.4162e-02,  1.6164e-01],
          [ 8.7051e-01,  8.9658e-01, -2.1684e-01,  ...,  1.1384e+00,
           -3.0451e-01,  7.1239e-01],
          [ 2.0829e-01, -5.3203e-01,  1.8881e-01,  ...,  1.1098e+00,
           -1.0654e+00,  1.7870e+00],
          ...,
          [ 5.8492e-01,  1.2304e+00,  8.5839e-01,  ...,  1.8900e+00,
           -2.5703e-02,  6.3271e-01],
          [ 8.2500e-01,  5.2305e-01, -2.0326e-01,  ...,  8.7598e-01,
            7.1087e-01, -8.2672e-01],
          [-8.7811e-01,  1.5896e+00, -1.4590e+00,  ...,  1.2713e+00,
            5.9443e-01,  1.0977e+00]],

         [[-3.0105e+00,  4.0288e-01, -3.0686e-02,  ..., -4.7620e-01,
           -3.6181e-01,  1.2402e+00],
          [ 4.3091e+00, -3.6355e-02, -1.2733e+00,  ..., -1.3761e-01,
           -1.0568e+00,  3.7705e-01],
          [ 3.5537e+00, -1.2094e+00,  7.3616e-01,  ..., -1.2436e+00,
           -1.0722e+00,  1.5382e+00],
          ...,
          [ 5.0508e+00,  9.7999e-01, -9.1501e-01,  ..., -8.3978e-01,
           -7.1137e-01, -9.0957e-01],
          [ 5.3393e+00,  1.0491e+00,  1.3785e+00,  ...,  2.3665e-01,
            7.1385e-01, -8.4908e-01],
          [ 4.7278e+00, -7.3271e-01,  2.4610e-01,  ..., -8.1648e-01,
            8.1328e-01, -4.9938e-01]],

         [[-1.0533e-02, -2.4607e-01,  2.7443e-03,  ..., -1.7876e-01,
            3.2764e-01,  8.2148e-02],
          [ 8.6160e-01,  1.4717e-01,  1.1344e+00,  ..., -5.2077e-01,
            5.9950e-01, -3.6589e-01],
          [ 6.8364e-01, -1.4481e+00, -2.4772e-01,  ...,  7.1361e-01,
            8.1152e-01, -4.3211e-01],
          ...,
          [-6.3802e-01, -7.4226e-01,  2.0134e-02,  ..., -2.2923e-02,
           -3.0777e-01, -1.3769e+00],
          [ 7.0286e-01, -7.6536e-01, -3.0677e-01,  ...,  1.0340e+00,
            4.4649e-01, -8.3763e-01],
          [-9.1978e-01,  3.4823e-01,  7.0052e-01,  ..., -7.4284e-01,
            9.8658e-01, -5.9086e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[-0.0260, -0.0201,  0.0090,  ..., -0.0063, -0.0343,  0.3522],
          [ 1.0229,  0.3651, -0.2132,  ..., -1.4306,  0.3113, -0.5238],
          [ 1.5326, -0.5122,  0.1662,  ..., -0.8095, -0.3197,  0.2882],
          ...,
          [ 0.1370, -0.3330, -0.3261,  ..., -0.0797,  0.4292, -1.0034],
          [-0.1414,  0.3330,  0.3330,  ...,  0.0371,  0.1574, -0.5799],
          [-0.8609, -2.4307, -0.8873,  ..., -0.3572, -0.6777, -0.1879]],

         [[-0.0041, -0.0135,  0.0183,  ..., -0.0215,  0.0231,  0.0108],
          [ 1.3730,  0.6940,  0.7279,  ..., -0.9966,  0.9910, -0.2261],
          [ 0.3186,  0.4918,  1.7542,  ..., -0.2500, -0.0790, -0.1540],
          ...,
          [-0.5563, -0.5748,  0.8149,  ...,  0.5584,  1.4183,  0.3196],
          [ 0.8510, -0.7996, -0.0429,  ..., -0.2517,  0.5159, -0.4891],
          [ 0.3454, -1.2026, -0.2286,  ..., -0.5790,  0.3819,  0.1525]],

         [[-0.0566,  0.0091, -0.0393,  ..., -0.0407,  0.0051, -0.0818],
          [-0.6938, -1.2322, -1.1038,  ..., -0.1297,  1.1394, -0.9904],
          [-0.0594,  0.3513, -0.5823,  ..., -0.1540,  1.2982,  1.5220],
          ...,
          [-0.5360,  0.0138, -0.0441,  ..., -0.5549, -1.2158, -0.0624],
          [-0.2674, -0.5419, -0.1500,  ..., -0.7724,  0.6361, -0.8990],
          [ 0.1295,  0.5581,  0.0097,  ...,  0.6635, -1.1220, -0.1898]],

         ...,

         [[-0.3324, -0.1904, -0.0728,  ..., -0.4847,  0.2203,  0.1020],
          [ 1.2505, -0.1432, -0.2780,  ...,  1.0645, -0.9976,  0.3491],
          [-0.6492, -2.1424,  1.6660,  ..., -0.4581, -1.0233, -1.4598],
          ...,
          [-0.5372, -0.5731, -0.6616,  ...,  0.6041, -0.6421, -0.1693],
          [ 2.2314, -0.7433, -0.1765,  ...,  0.3478, -1.5081,  0.6716],
          [ 1.1512, -1.5820,  0.2295,  ..., -0.0358,  0.5782, -0.6780]],

         [[-0.0805, -0.1338, -0.0433,  ..., -0.1887, -0.1398,  0.1256],
          [-0.1725, -0.9856,  0.0283,  ...,  0.1987,  0.2887, -0.2125],
          [-0.2868, -1.2376, -0.8224,  ...,  1.6796,  0.9134,  0.3702],
          ...,
          [-0.2671,  0.1030, -0.0178,  ..., -0.5901, -0.1262,  0.1294],
          [-0.8149,  0.5106, -0.9661,  ...,  1.2941, -0.3053,  0.1000],
          [-0.2159,  0.3548, -0.1119,  ...,  0.1122,  0.0045,  0.7179]],

         [[-0.0283, -0.0370,  0.0912,  ...,  0.0785, -0.0412,  0.0123],
          [-0.3245,  0.4457, -0.2747,  ..., -0.8789, -0.0971,  0.6181],
          [-1.7313,  0.3132, -0.7923,  ..., -0.9594, -1.0080,  1.0855],
          ...,
          [-0.0484,  0.5334,  0.0299,  ...,  0.4920, -0.5462,  0.0646],
          [-1.3020,  1.6345,  0.7471,  ..., -1.0761, -0.1769,  0.0983],
          [-0.3621,  1.1414, -0.7577,  ..., -0.3690,  1.1279,  0.0493]]]],
       device='cuda:0', grad_fn=<PermuteBackward0>)), (tensor([[[[-3.3578e-01,  8.6494e-01, -1.5782e-01,  ...,  1.1243e+00,
           -1.7831e-01,  1.3638e-01],
          [ 6.2278e-01, -2.4989e+00,  1.1164e+00,  ..., -2.6288e+00,
           -1.5201e+00,  1.1154e+00],
          [ 2.5411e+00, -4.5835e+00,  1.0466e+00,  ..., -3.0950e+00,
           -1.3118e+00,  2.0691e+00],
          ...,
          [ 1.4910e-01, -5.1488e+00,  1.4900e+00,  ..., -3.5783e+00,
           -1.2938e+00,  1.6264e+00],
          [ 7.1818e-01, -4.5352e+00,  3.0051e+00,  ..., -3.4607e+00,
            5.2433e-01,  3.7942e-01],
          [-3.1768e-01, -4.0519e+00,  3.3414e+00,  ..., -4.1427e+00,
            3.7519e-01,  4.0434e-01]],

         [[ 6.0723e-02,  8.6226e-01, -6.4320e-01,  ..., -4.2258e-02,
            2.9290e-01,  8.5186e-03],
          [-8.5910e-01,  2.4301e+00, -2.0816e-03,  ..., -8.2199e-01,
           -1.9111e-01, -5.7361e-01],
          [-1.7353e-01, -1.0406e-02,  2.9922e+00,  ...,  1.7248e+00,
           -1.3871e+00, -9.8745e-01],
          ...,
          [ 1.4691e+00, -6.5255e-01,  1.0800e+00,  ...,  2.4355e+00,
            2.8071e-01, -9.0710e-01],
          [-4.0101e-01, -1.8545e-01,  1.2093e+00,  ...,  1.2111e-01,
            2.2229e-01, -8.8183e-01],
          [-1.4461e+00, -1.5290e+00, -7.7958e-01,  ..., -1.7141e-01,
            1.2369e+00,  1.6131e-01]],

         [[-3.1530e-01,  1.2470e-01, -9.8533e-01,  ..., -3.5174e-01,
           -6.0202e-02, -1.3882e-01],
          [-8.9693e-01,  2.8848e-01,  4.4991e+00,  ...,  5.4666e-01,
           -3.4291e-01,  2.7847e-01],
          [-5.8518e-01, -4.2858e-01,  3.9293e+00,  ..., -4.4146e-02,
           -1.5199e-01, -3.5618e-01],
          ...,
          [-8.3947e-01, -3.5598e-01,  2.7114e+00,  ...,  4.2037e-01,
            6.0501e-01,  1.4794e+00],
          [-1.1543e+00, -7.0097e-01,  3.8364e+00,  ...,  9.3257e-01,
            5.2762e-01, -4.1934e-02],
          [-2.1461e-01,  1.0343e-01,  2.4779e+00,  ..., -1.0749e+00,
           -1.1717e-01, -8.4882e-01]],

         ...,

         [[ 3.8411e-01,  8.0610e-02, -7.7487e-02,  ..., -3.1953e-02,
            2.1622e-01,  1.0605e-02],
          [-6.5354e-01, -3.1612e-02, -1.0638e+00,  ..., -5.8875e-01,
            1.0573e+00, -1.2601e-01],
          [ 2.7933e-01,  1.6202e+00,  5.9959e+00,  ..., -2.2502e+00,
           -4.2706e+00, -3.7783e-01],
          ...,
          [-1.3431e-01, -2.3244e+00, -1.0504e+00,  ..., -1.6168e+00,
           -3.2933e-01,  2.5949e-01],
          [ 7.4608e-01, -1.4878e+00, -6.3726e-01,  ..., -2.2942e+00,
           -9.9182e-02,  7.8056e-02],
          [-5.6625e-01, -1.2365e+00, -1.7069e+00,  ..., -3.6532e+00,
            5.9468e-01,  1.1633e+00]],

         [[ 2.0019e-01,  6.5883e-02,  3.2484e-01,  ...,  4.1552e-01,
            1.3110e-02,  2.2381e-01],
          [ 1.0125e+00,  2.4286e-01,  9.5984e-01,  ..., -4.9362e-01,
           -8.9649e-02,  4.7713e-01],
          [ 7.9506e-01,  3.0244e-01,  1.0567e+00,  ..., -8.4306e-01,
           -1.5668e+00,  3.5901e-01],
          ...,
          [ 5.9118e-01,  1.3707e-01,  1.1906e+00,  ..., -1.1724e+00,
           -1.5317e-01, -3.6017e-01],
          [ 1.5434e+00,  6.0157e-01,  9.3965e-01,  ..., -9.6186e-01,
           -1.3340e+00,  1.4163e+00],
          [ 1.2101e+00, -5.5004e-01,  1.9818e+00,  ..., -4.1132e-01,
            3.9219e-01,  1.3851e+00]],

         [[-3.0144e+00,  5.5559e-01,  5.5824e-01,  ..., -9.2879e-01,
            3.1234e-01,  2.1565e-01],
          [ 7.9342e+00, -1.3099e+00, -1.6237e+00,  ...,  2.3727e+00,
           -1.3910e-01,  2.4147e-01],
          [ 8.7548e+00,  3.1521e-01, -3.0962e+00,  ...,  2.1272e+00,
            2.3255e+00, -9.4848e-01],
          ...,
          [ 1.0275e+01, -4.5047e-01, -2.9359e+00,  ...,  1.8721e+00,
           -1.5067e+00, -5.9854e-01],
          [ 9.8252e+00, -1.5088e+00, -3.1082e+00,  ...,  1.3357e+00,
           -8.5859e-01, -1.5895e+00],
          [ 8.9484e+00, -3.9686e+00, -1.0389e+00,  ...,  2.1457e+00,
            1.6358e-01, -7.7658e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[ 5.1141e-02, -4.3904e-02,  1.0421e-02,  ..., -7.0718e-02,
            4.0156e-03, -8.7774e-02],
          [-3.6166e-02, -4.8795e-01,  9.8731e-03,  ...,  5.3657e-02,
           -3.1082e-01,  2.3718e-01],
          [ 2.2784e-01, -1.5567e-01, -3.7655e-01,  ..., -8.5816e-01,
           -2.3128e-01,  5.7785e-01],
          ...,
          [ 9.1042e-01,  1.2031e-01,  1.9260e-01,  ...,  4.4779e-01,
            9.7324e-01, -1.5580e+00],
          [-7.3678e-03, -7.9632e-03, -1.6174e+00,  ..., -2.6624e-01,
            1.3078e-01, -6.4462e-01],
          [-8.8761e-01,  4.9878e-01, -8.4630e-01,  ...,  4.2564e-01,
            6.2020e-01, -7.2440e-01]],

         [[ 6.9888e-02,  1.9044e-02, -2.6442e-02,  ..., -2.8091e-02,
            8.5949e-03, -8.9132e-03],
          [ 3.0326e-01,  9.8972e-01, -3.9176e-01,  ...,  7.6245e-02,
           -2.2104e-01,  4.6061e-01],
          [-8.6927e-01,  6.3877e-02,  1.8170e-01,  ..., -1.3091e-01,
           -3.3398e-02,  9.1734e-01],
          ...,
          [-4.1310e-01,  2.4178e-01, -1.0092e-02,  ...,  6.3579e-01,
           -8.6536e-01,  1.3382e+00],
          [-1.0261e+00,  2.0221e-01,  5.9535e-01,  ...,  5.4258e-02,
           -3.7277e-01,  6.7027e-01],
          [ 6.3411e-01,  2.9404e-01, -2.1462e-01,  ...,  5.9986e-01,
           -5.7667e-01,  2.7974e-01]],

         [[ 7.8181e-02,  1.2402e-02,  3.8624e-03,  ...,  2.5141e-02,
           -7.0518e-02, -6.1704e-02],
          [-1.8774e-02,  1.2058e-01, -3.6527e-01,  ..., -2.2870e-01,
            6.9500e-01,  1.5349e+00],
          [-1.6330e+00,  1.2746e+00,  9.4037e-01,  ..., -9.1634e-01,
           -3.8535e-01, -9.6021e-01],
          ...,
          [-1.0440e-01, -6.7721e-01, -9.5557e-01,  ..., -6.2216e-01,
            1.2931e-01,  5.3809e-01],
          [-1.6070e-01,  8.0107e-01, -6.2008e-01,  ..., -5.0174e-01,
            1.1805e+00,  9.2189e-01],
          [-1.6578e-01, -7.5677e-01, -4.0900e-01,  ..., -1.2848e+00,
            5.3773e-01, -8.9344e-01]],

         ...,

         [[ 4.7203e-03,  2.5398e-02,  1.7264e-02,  ..., -7.3335e-02,
           -1.9856e-02,  1.3880e-02],
          [-1.7961e+00, -1.7059e-01, -2.1769e-01,  ...,  5.8683e-01,
           -8.8736e-01, -1.7493e+00],
          [-5.4797e-01, -1.0059e+00, -6.7425e-01,  ...,  1.4185e-01,
            1.2526e+00, -4.5236e-01],
          ...,
          [-7.9344e-01, -6.5969e-01, -4.0526e-01,  ...,  6.3055e-01,
           -5.9347e-01, -1.1570e+00],
          [-2.6411e-01,  3.5041e-01,  9.5562e-01,  ...,  9.8356e-01,
            6.3222e-01, -1.1216e+00],
          [ 1.9708e-01, -8.2632e-01, -1.3375e-02,  ..., -6.5847e-01,
           -8.4580e-01, -1.7935e-03]],

         [[ 4.5729e-02, -1.4005e-02,  2.1403e-02,  ...,  3.1315e-02,
           -8.0189e-03,  2.6656e-03],
          [ 7.4114e-01,  1.2927e+00, -8.6742e-02,  ...,  4.5960e-01,
           -6.8074e-01,  1.6970e-01],
          [ 5.9978e-01,  1.2898e+00,  9.5936e-01,  ..., -5.8497e-01,
           -6.0349e-01,  6.4145e-01],
          ...,
          [ 3.8285e-01,  8.4960e-02,  4.7635e-01,  ..., -5.7194e-01,
           -3.3707e-01, -1.5005e-01],
          [-9.8847e-01, -8.6587e-01, -1.8648e-01,  ..., -1.0971e+00,
            2.6862e-01, -2.5423e-01],
          [-7.6857e-01, -3.6633e-02, -1.9877e+00,  ..., -4.6666e-02,
           -8.1418e-01, -2.1952e-01]],

         [[ 7.2741e-02, -1.9836e-01, -7.2501e-02,  ..., -3.1681e-02,
            1.9165e-01, -4.7206e-02],
          [-6.0899e-01, -1.0427e+00, -2.6415e-02,  ..., -6.7561e-01,
           -1.8160e-01,  1.0336e+00],
          [ 6.4349e-01,  5.6121e-02, -3.9025e-01,  ..., -9.6565e-01,
            1.0276e+00, -7.7055e-01],
          ...,
          [-1.0704e-01, -6.4438e-01, -1.1632e-01,  ...,  4.5279e-01,
           -3.9950e-01, -6.8103e-02],
          [ 8.0036e-02, -6.1467e-01,  1.7871e-01,  ..., -6.1554e-01,
            1.4653e-01, -1.2892e+00],
          [ 1.3268e+00, -1.8086e-01,  1.1401e+00,  ..., -1.1366e+00,
           -4.4702e-01, -4.2245e-02]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 1.0499e+00, -2.6094e-01, -1.4624e-01,  ...,  6.4066e-01,
            7.3481e-01, -3.1200e-01],
          [-4.3803e+00, -1.7161e+00,  1.2953e+00,  ...,  4.6987e-01,
           -5.0588e+00,  4.3037e-01],
          [-2.9824e+00, -1.4321e+00,  1.6383e+00,  ..., -1.0036e+00,
           -6.1600e+00, -6.0011e-01],
          ...,
          [-6.8617e+00, -4.6154e+00,  2.7333e-01,  ..., -8.7715e-01,
           -5.9364e+00, -1.4132e+00],
          [-6.7327e+00, -2.1619e+00,  8.2593e-02,  ...,  2.3139e-01,
           -4.8873e+00, -1.5720e+00],
          [-6.2051e+00, -3.0823e+00,  8.0622e-01,  ...,  4.3842e-01,
           -5.3239e+00, -3.8893e-01]],

         [[-1.4052e-01, -6.4056e-02,  1.6708e-01,  ..., -5.0385e-02,
           -8.8839e-01, -1.9052e-01],
          [ 2.1709e-01,  1.1619e+00,  5.7067e-01,  ..., -1.3953e+00,
            3.2450e-01,  1.2646e+00],
          [-1.5323e+00,  1.3859e+00,  1.5808e+00,  ...,  8.2068e-01,
            7.3862e-01,  1.8539e+00],
          ...,
          [ 6.0382e-01, -1.1719e+00,  2.8846e-01,  ...,  5.3128e-01,
           -1.8594e+00,  1.5157e-01],
          [-9.1678e-01, -4.0683e-01,  6.5201e-01,  ..., -3.7854e-01,
            3.6177e-01,  7.1991e-01],
          [-1.1788e+00, -8.7759e-01, -1.1967e+00,  ..., -3.4275e-01,
           -4.2240e-01,  4.9525e-01]],

         [[ 1.9705e-01,  3.1511e-01,  1.1296e+00,  ..., -4.5656e-01,
            4.3585e-01, -4.9803e-01],
          [-2.4683e-01, -1.8572e+00, -1.1624e+00,  ..., -3.6835e-01,
           -1.8631e+00,  2.6684e+00],
          [-8.0515e-01,  1.7749e-01, -2.0522e+00,  ..., -9.9005e-01,
           -1.4308e+00,  2.0506e+00],
          ...,
          [ 1.7602e+00, -9.8673e-01, -3.5310e+00,  ..., -4.7019e-01,
           -3.3827e+00,  2.8605e+00],
          [-7.9944e-01, -8.1130e-01, -2.5108e+00,  ..., -4.9938e-02,
           -1.6870e+00,  2.8972e+00],
          [-1.8172e-01, -1.3924e+00, -3.6942e+00,  ...,  1.2416e+00,
           -1.7045e+00,  3.0463e+00]],

         ...,

         [[-1.5123e-01,  7.1820e-02, -2.2811e-01,  ...,  4.1857e-04,
            1.5596e-01,  2.9428e-02],
          [-2.7472e+00,  8.9720e-01, -8.9011e-01,  ...,  2.0976e+00,
            4.3475e-01,  3.6547e-01],
          [-3.9650e+00, -1.7982e+00,  1.3452e+00,  ..., -4.0594e-03,
            1.0195e+00,  8.0354e-01],
          ...,
          [-9.3269e-01, -1.4746e-01, -1.1404e+00,  ...,  1.7352e+00,
           -3.7063e-01, -1.1835e+00],
          [-1.5874e+00, -4.5612e-01,  5.5681e-01,  ...,  1.8682e+00,
            4.1865e-01, -1.6453e+00],
          [-3.8839e-01,  1.3997e-01,  6.3983e-01,  ...,  4.3045e-01,
           -1.2466e+00, -2.4118e-01]],

         [[-3.5156e-01, -2.1933e+00,  1.1372e-01,  ..., -8.6831e-02,
           -4.2204e-02,  9.1871e-01],
          [ 1.6741e+00,  2.0308e+00,  2.3405e-02,  ...,  6.5852e-01,
           -1.0524e+00,  1.5907e-01],
          [ 1.9101e-01,  5.3070e+00, -5.7532e-01,  ..., -3.4579e+00,
           -2.8675e+00, -2.0603e+00],
          ...,
          [-6.3685e-01,  3.5661e+00, -5.8024e-01,  ...,  1.1246e-01,
           -3.5499e+00,  4.4231e-01],
          [ 1.2368e+00,  2.0089e+00, -5.1367e-01,  ..., -3.6685e-01,
           -1.5970e+00,  6.9527e-01],
          [ 1.7322e+00,  6.2705e+00, -2.0639e+00,  ..., -1.0150e-01,
           -2.8591e+00, -1.2347e+00]],

         [[ 3.6793e-01,  7.5982e-02, -1.3786e-01,  ...,  6.4646e-01,
            1.3555e-01,  2.5748e-01],
          [ 1.0353e-01,  3.1763e-02, -7.6792e-01,  ..., -4.7279e-01,
            6.0120e-01, -5.2789e-01],
          [-1.2707e+00, -1.6400e+00,  1.3822e+00,  ..., -1.4171e+00,
            1.8270e+00, -1.3366e+00],
          ...,
          [-1.9160e+00, -1.9547e-01,  1.0904e+00,  ..., -1.0474e+00,
           -4.5762e-01, -6.0643e-01],
          [-1.0998e+00, -9.2082e-01, -2.4341e-01,  ..., -2.0234e-01,
           -1.1176e-01, -1.7190e+00],
          [-2.0762e+00,  5.8654e-01,  2.3925e-01,  ..., -2.1948e+00,
           -7.5425e-01, -1.1309e+00]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[-3.6090e-02,  5.8404e-02, -4.7430e-02,  ..., -1.6448e-02,
            6.5290e-03,  2.0813e-02],
          [ 3.3981e-01, -2.0113e-01, -4.2505e-01,  ...,  2.1069e-01,
           -2.2217e-01,  1.4801e-02],
          [ 5.1794e-01, -2.6169e-01, -2.3763e-01,  ...,  7.0568e-01,
            5.7194e-01, -3.8482e-01],
          ...,
          [ 2.5554e-01,  1.4591e-01,  6.5668e-01,  ..., -9.7976e-01,
            7.0952e-01, -5.4556e-01],
          [ 4.0079e-01, -5.0342e-03, -7.9301e-01,  ...,  1.2865e-02,
           -1.8381e-01, -3.0065e-02],
          [-2.7750e-02, -1.1973e+00, -6.5840e-01,  ..., -2.8713e-01,
           -1.2795e-01, -5.5134e-01]],

         [[ 1.0002e-02, -2.6317e-02,  2.8788e-02,  ...,  2.0878e-02,
           -5.0190e-02,  1.7778e-02],
          [ 7.7976e-01,  3.3234e-01, -6.3060e-01,  ...,  9.9832e-02,
           -2.1819e-01,  6.5724e-01],
          [ 9.6898e-01, -9.7757e-02, -1.4808e-01,  ..., -8.0822e-03,
           -5.7036e-01,  2.0158e-01],
          ...,
          [ 1.8208e-01,  3.9556e-01, -9.8637e-01,  ...,  7.4350e-02,
            9.4404e-01,  1.4957e-01],
          [-4.6217e-01,  1.1374e+00, -1.9512e+00,  ..., -1.6531e-01,
            7.4679e-01, -1.0709e-01],
          [ 8.1311e-02,  9.7903e-01, -1.7733e+00,  ..., -1.1529e-03,
            2.6891e-01, -2.6389e-01]],

         [[ 2.8050e-02, -3.8394e-02,  4.2667e-02,  ...,  1.8323e-02,
           -7.2851e-03,  6.6780e-03],
          [ 8.1693e-03, -6.7934e-01, -1.6569e+00,  ...,  2.1917e+00,
           -8.0082e-01,  2.0747e+00],
          [-9.8902e-02, -9.6721e-01, -1.9402e+00,  ...,  2.0296e-01,
           -1.9917e+00, -5.7315e-01],
          ...,
          [-1.3532e+00, -4.1404e-01, -9.7093e-01,  ..., -2.2521e-01,
           -1.0802e+00,  1.7622e-01],
          [-8.9913e-01,  8.1342e-01, -4.2660e-01,  ..., -4.5484e-01,
           -4.4985e-01, -5.2983e-01],
          [ 1.4025e+00,  1.8062e-01, -1.2818e+00,  ...,  1.9920e-01,
           -5.2821e-01, -2.5628e-01]],

         ...,

         [[-1.8594e-01,  9.0682e-02,  4.9690e-02,  ...,  4.3440e-02,
            3.8206e-02, -1.2885e-01],
          [ 7.9794e-01,  6.0255e-01, -9.1034e-01,  ...,  1.2564e+00,
            2.1356e+00,  1.4350e+00],
          [ 1.6348e+00,  2.4815e+00, -7.9229e-01,  ...,  4.7704e-02,
            4.2612e-01,  4.7895e-01],
          ...,
          [-8.7949e-02,  5.8050e-01,  8.1942e-01,  ..., -6.4081e-01,
           -8.1574e-01,  3.7447e-01],
          [ 2.8842e-01,  4.7873e-01,  1.0562e-01,  ..., -3.5749e-01,
            4.3560e-01, -3.4558e-01],
          [ 5.2590e-01,  2.0273e-02, -2.1104e-01,  ...,  4.7948e-01,
           -4.2571e-01, -1.4939e+00]],

         [[-5.9095e-01, -5.1091e-03,  4.9711e-02,  ..., -6.8233e-03,
            1.0263e-02, -1.0583e-03],
          [-9.2941e-01, -5.9255e-01,  1.0062e+00,  ..., -3.8768e-01,
           -5.7188e-01,  8.9663e-01],
          [-3.1489e+00, -1.1034e-01, -4.2379e-01,  ...,  2.2262e+00,
            1.2680e+00,  2.7568e-01],
          ...,
          [-1.5769e+00,  1.9565e-01, -5.7385e-01,  ...,  1.3197e+00,
           -1.3651e-01, -3.4665e-01],
          [-7.1268e-01,  9.8680e-01, -7.7558e-01,  ...,  7.7028e-01,
            1.1927e-03, -6.3204e-01],
          [-1.0811e+00,  1.0461e+00,  1.1175e+00,  ...,  6.5761e-01,
           -1.2024e+00, -8.7618e-01]],

         [[ 4.1802e-03,  8.6289e-02, -4.7156e-02,  ...,  6.4784e-02,
            3.7624e-02, -3.6941e-02],
          [ 3.8048e-01, -1.1929e-01, -7.9124e-01,  ..., -4.4751e-01,
           -1.1959e+00, -4.3578e-01],
          [ 2.4210e-01, -7.7073e-01, -3.8349e-01,  ..., -7.1458e-01,
           -8.4916e-01,  5.2514e-01],
          ...,
          [ 1.4012e+00,  2.0892e-01, -2.9313e-01,  ..., -2.8728e-01,
           -7.7548e-01, -3.9739e-01],
          [ 5.1188e-01, -8.7199e-01,  8.7336e-01,  ..., -4.0334e-01,
           -3.4130e+00, -6.1112e-01],
          [-6.1907e-01, -8.8935e-01, -1.0951e-01,  ..., -7.4700e-01,
           -2.2108e+00, -3.7028e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-2.6786e-02, -2.3484e+00,  1.7383e-01,  ..., -2.3387e-01,
           -1.8498e-01,  8.0406e-02],
          [-6.3816e-01,  4.5202e+00, -4.8769e-01,  ..., -6.2377e-01,
           -1.1104e+00,  2.6615e-01],
          [ 1.6524e-01,  5.3601e+00, -1.5751e+00,  ...,  5.1479e-01,
            1.4985e-01,  7.5970e-01],
          ...,
          [ 1.1661e+00,  4.1413e+00,  8.5835e-01,  ..., -1.3668e+00,
            1.1824e+00, -3.7819e-01],
          [ 8.8634e-02,  4.8891e+00, -6.3595e-01,  ..., -1.6221e-01,
           -6.4306e-01,  1.7413e-01],
          [-2.2818e-01,  4.9449e+00, -1.4196e-01,  ...,  2.3876e-01,
           -7.2030e-01, -4.8695e-01]],

         [[-8.1155e-01,  2.2757e-01,  4.6853e-01,  ..., -5.2139e-01,
            1.0647e+00,  1.1201e+00],
          [ 7.4454e-01, -6.9863e-02, -1.2254e+00,  ..., -2.5475e-01,
            1.8946e+00, -6.1198e-01],
          [-4.1235e-01,  1.7778e-01, -8.0618e-01,  ..., -1.3256e+00,
            7.5310e-01,  1.6034e-01],
          ...,
          [-9.3006e-01,  7.3088e-01,  9.9920e-01,  ..., -4.3810e-01,
           -6.7822e-01, -1.3033e+00],
          [-6.1311e-01, -1.2247e-01, -7.6474e-01,  ...,  9.3844e-02,
            1.2695e+00, -1.8101e+00],
          [-3.0981e-01, -3.6292e-01, -7.4754e-01,  ...,  4.7145e-01,
            1.9980e+00, -1.5107e+00]],

         [[-8.6583e-01,  4.7891e-01,  2.1436e-02,  ...,  4.8864e-01,
           -2.2923e-01,  1.1553e+00],
          [ 2.7614e+00,  6.3526e-01,  1.0754e+00,  ...,  2.0836e+00,
            4.9057e-01,  1.7316e+00],
          [ 1.6238e+00,  1.3626e+00,  2.4669e-01,  ...,  1.3097e+00,
            7.0308e-01,  7.0389e-01],
          ...,
          [ 5.1406e-01, -7.3369e-02, -5.0398e-01,  ..., -4.1342e-01,
            9.2301e-01,  2.7006e-01],
          [ 8.8824e-01, -1.9462e-01, -2.1208e-01,  ...,  4.7626e-01,
            1.5474e+00, -2.8075e-01],
          [ 1.4880e+00,  9.3271e-01,  1.3420e+00,  ...,  1.0434e+00,
            4.5726e-01,  5.6136e-01]],

         ...,

         [[-3.0037e-01, -1.3118e-01,  1.4130e-01,  ...,  1.9293e-01,
            1.7425e+00, -2.8649e+00],
          [-1.0344e+00,  2.0608e-01, -3.1066e-01,  ..., -6.3949e-01,
           -4.2152e+00,  4.7219e+00],
          [-1.0119e+00, -3.4674e-01,  6.4960e-01,  ..., -1.1140e+00,
           -5.8159e+00,  5.5243e+00],
          ...,
          [ 9.9599e-01,  7.7899e-01,  6.0444e-01,  ..., -2.3092e-01,
           -5.5457e+00,  5.7162e+00],
          [-1.0503e-01, -3.1278e-02,  2.6437e-01,  ..., -7.0391e-01,
           -5.0515e+00,  4.5615e+00],
          [-2.0179e+00, -6.2885e-01,  4.4828e-01,  ..., -8.9845e-01,
           -5.6423e+00,  5.0915e+00]],

         [[ 1.8530e-01,  3.6999e-01,  2.1407e-01,  ..., -2.1567e-01,
            3.0667e-02, -1.3970e-01],
          [-4.0770e-01, -1.6711e+00, -1.0513e+00,  ...,  1.9773e-01,
            1.9950e-01,  1.3481e+00],
          [-1.2989e+00, -2.9822e-01, -7.9525e-01,  ..., -2.1794e-02,
           -3.6022e-02,  1.3295e+00],
          ...,
          [ 3.4041e-02, -1.3368e+00,  3.7133e-01,  ..., -1.5534e-01,
            4.7558e-01,  1.7969e+00],
          [-2.4719e-01, -2.1765e+00,  9.6421e-01,  ..., -1.7622e-02,
            8.2257e-01,  2.6881e+00],
          [-1.2455e+00, -1.3100e+00,  1.0811e+00,  ..., -5.0272e-01,
            5.9302e-01,  2.6775e+00]],

         [[ 3.7035e-01,  1.1777e-01,  6.1432e-01,  ...,  5.2012e-01,
            5.8044e-01, -3.3140e-01],
          [ 9.5881e-02, -1.2821e+00, -6.4381e-01,  ..., -2.1402e+00,
           -2.8265e+00,  7.5561e-01],
          [ 5.9308e-01,  1.6217e+00, -1.8073e+00,  ..., -1.3962e+00,
           -5.3295e+00,  3.7561e-01],
          ...,
          [ 3.4402e-01,  7.1351e-04, -1.0847e+00,  ..., -1.5327e+00,
           -4.7644e+00,  1.2317e+00],
          [ 2.0750e-01,  4.2706e-01, -8.2447e-01,  ..., -2.4184e+00,
           -4.3488e+00,  5.6796e-01],
          [ 2.1829e-01,  1.2217e+00, -7.6039e-01,  ..., -1.6768e+00,
           -5.0688e+00,  1.1125e+00]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[ 6.5136e-02, -1.0029e-02, -2.2108e-02,  ...,  1.1689e-01,
           -7.1039e-02, -3.8524e-02],
          [ 6.3169e-01,  5.2662e-01, -3.0022e-01,  ..., -2.2283e-01,
            2.2716e-01,  8.7254e-01],
          [-1.1552e+00, -9.5870e-01,  1.1238e+00,  ...,  8.6759e-01,
            1.5205e+00,  3.0764e-01],
          ...,
          [ 8.4761e-01, -1.3798e+00, -1.2645e+00,  ...,  3.3378e-01,
            2.3198e-01,  3.6541e-02],
          [ 3.6719e-03,  5.9048e-01,  2.4346e+00,  ...,  5.0436e-01,
            5.6727e-02, -1.4954e-01],
          [-1.0332e-01, -1.7860e-01,  5.9478e-01,  ...,  1.8129e-01,
            1.1238e-01,  5.6843e-01]],

         [[ 1.1602e-02,  2.9395e-02,  5.1380e-02,  ..., -4.0433e-04,
            5.4506e-03, -5.0795e-03],
          [-8.1963e-01,  1.2248e+00,  2.0325e+00,  ..., -4.0260e-01,
            1.3450e+00, -6.8389e-02],
          [ 4.5896e-01, -1.3906e+00,  1.5941e+00,  ..., -1.3976e+00,
            3.3245e-01,  5.1057e-01],
          ...,
          [ 3.7079e-01, -1.0431e+00,  7.7085e-01,  ..., -7.9126e-01,
            1.4188e+00, -2.7265e-01],
          [-1.4974e+00, -1.0762e+00, -1.5428e-01,  ...,  5.5193e-01,
            6.3376e-01, -5.6096e-01],
          [ 1.4492e+00, -3.8667e-01, -3.8344e-01,  ...,  2.6724e-01,
            5.4262e-01, -4.2422e-01]],

         [[ 5.2304e-02, -3.7324e-02,  6.3114e-02,  ...,  5.8532e-02,
           -6.3338e-02, -5.8268e-02],
          [-7.0165e-01,  5.3844e-01, -5.8273e-01,  ..., -5.1593e-01,
           -6.2940e-01, -1.8096e-01],
          [-2.7908e+00,  1.0558e-01, -5.3334e-01,  ..., -1.5726e-01,
            2.0836e-01,  6.0709e-01],
          ...,
          [-5.1843e-02,  5.4624e-01,  2.0531e-01,  ...,  1.7076e+00,
            5.7584e-01,  6.2327e-01],
          [-3.6931e-02, -7.4176e-01,  3.6640e-01,  ...,  7.2651e-01,
            7.1034e-01, -5.0419e-01],
          [-1.2319e-01,  1.2123e-01, -4.3683e-01,  ..., -2.6635e-01,
            1.0085e-01, -5.3599e-01]],

         ...,

         [[-9.0498e-02, -4.4621e-02,  4.2299e-02,  ..., -9.2541e-02,
            3.0748e-02,  1.1883e-02],
          [-6.4144e-01,  1.2397e+00, -1.8991e-01,  ...,  6.5226e-01,
            1.0988e+00,  8.0597e-01],
          [ 1.4071e-02,  1.8777e+00, -5.3901e-01,  ...,  6.4912e-01,
            5.2703e-01, -1.2243e+00],
          ...,
          [ 2.1232e-01,  1.3418e+00, -1.7877e-01,  ..., -1.4944e+00,
           -6.7317e-02,  8.2103e-01],
          [-9.6293e-01, -1.5243e+00,  1.1825e-01,  ...,  8.9055e-01,
           -2.7738e-01,  1.1285e-01],
          [-6.8869e-01, -6.7001e-01, -1.1208e+00,  ...,  8.9605e-01,
            1.1449e+00,  1.4222e-01]],

         [[ 1.4738e-01, -6.2497e-02,  1.3729e-01,  ...,  6.0634e-02,
            3.2148e-02, -1.2945e-01],
          [ 7.4502e-01,  3.8916e-01,  8.7663e-02,  ...,  9.7815e-01,
           -1.7061e+00,  6.5918e-01],
          [ 4.5599e-01, -8.9115e-01, -1.1831e+00,  ...,  6.0562e-01,
           -1.3010e+00,  5.5591e-01],
          ...,
          [ 5.1827e-02,  9.4449e-01, -1.0665e+00,  ..., -5.5041e-01,
            1.6840e+00,  3.5078e-01],
          [-2.1572e-01,  1.8307e+00, -1.0656e+00,  ..., -2.8209e-01,
            1.1047e+00, -3.8316e-01],
          [ 3.0666e-01,  1.2266e-01, -4.6971e-01,  ...,  7.9809e-01,
            7.9670e-01,  9.7310e-01]],

         [[ 2.0906e-01, -4.6573e-02, -6.0253e-02,  ...,  3.5716e-02,
            4.1975e-02,  2.2183e-02],
          [ 1.1508e+00, -4.2915e-02,  1.7938e+00,  ..., -2.8495e+00,
           -1.4776e+00,  1.4759e-01],
          [ 1.6741e-01, -6.4878e-01,  1.5115e+00,  ..., -2.5187e+00,
           -2.1148e+00,  4.7344e-01],
          ...,
          [-6.9654e-01, -5.8817e-01,  1.1112e+00,  ..., -5.3462e-01,
            2.5145e-01,  1.2652e+00],
          [ 4.9990e-01, -6.9669e-01, -7.5295e-01,  ..., -7.6324e-01,
            1.9860e-01, -1.0063e+00],
          [ 8.3529e-02, -2.6978e-01,  5.3723e-01,  ..., -6.3546e-01,
           -1.3855e+00,  1.8679e+00]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[ 2.5910e-02, -2.6886e-01, -4.6291e-01,  ...,  2.9839e-01,
            3.2603e-01,  3.5748e-01],
          [ 3.1461e-01,  1.1503e+00, -1.2219e+00,  ..., -5.6552e-01,
            5.9567e-02,  8.4989e-02],
          [ 6.7203e-01,  2.6852e+00, -1.6149e+00,  ..., -4.6337e-01,
           -1.2918e-01, -9.1665e-01],
          ...,
          [-2.3362e-02, -1.1149e+00, -1.5101e-01,  ..., -8.8944e-01,
           -3.1759e-01, -1.8159e+00],
          [ 1.3122e+00, -1.0426e+00, -1.0986e+00,  ...,  3.0899e-01,
           -8.2710e-01,  4.9917e-01],
          [ 9.7887e-01, -1.6260e-01, -3.5367e-01,  ..., -3.7958e-01,
           -5.1196e-01,  9.1400e-02]],

         [[-2.8813e-01,  1.6128e-01,  1.1108e-01,  ...,  5.3097e-02,
           -1.1501e+00, -1.5568e-01],
          [ 2.0048e-01, -6.1896e-01, -9.0740e-01,  ...,  1.5258e+00,
            5.4791e-02,  1.0740e+00],
          [ 1.5828e+00,  2.4231e+00,  3.0855e-01,  ...,  7.5064e-01,
            1.4106e+00, -2.2212e-01],
          ...,
          [-1.0661e+00,  1.2544e+00,  2.0837e-03,  ..., -8.5617e-01,
           -8.1316e-01,  4.9132e-01],
          [-2.6743e-02,  5.4352e-01, -2.2571e+00,  ..., -1.3261e+00,
            3.2357e-01,  7.3165e-01],
          [-7.4855e-01,  1.7948e+00, -1.0619e+00,  ..., -1.2624e+00,
            2.3455e-01,  1.4526e-01]],

         [[-1.2375e+00, -1.2080e-01,  5.5275e-01,  ..., -6.6115e-01,
            4.6177e-01, -2.7181e-01],
          [ 2.1969e+00, -1.5818e-01, -5.5458e-01,  ...,  1.1694e-01,
            2.1649e-02,  5.6354e-01],
          [ 2.0254e+00,  6.7081e-01, -4.4748e-01,  ...,  1.4967e+00,
           -1.2851e+00,  6.7248e-01],
          ...,
          [ 1.5326e+00, -3.3540e-01, -1.4324e+00,  ...,  1.2954e+00,
           -1.1284e+00,  4.7197e-01],
          [ 8.7941e-01,  2.4923e+00, -5.3935e-01,  ...,  8.4928e-01,
            1.0220e+00,  3.7455e-01],
          [ 5.1570e-01,  1.6997e+00, -5.6758e-02,  ...,  1.1494e+00,
           -6.0428e-01, -1.0172e+00]],

         ...,

         [[ 7.9977e-01, -9.0907e-01, -3.9427e-01,  ..., -1.0356e+00,
           -4.3145e-01,  4.9180e-01],
          [-1.4421e-01, -8.1030e-01, -1.1014e+00,  ...,  1.0065e+00,
           -9.9178e-01, -2.7286e+00],
          [ 2.2822e+00, -1.0592e+00, -3.2463e+00,  ...,  3.9767e-01,
           -1.2019e-01, -1.5965e+00],
          ...,
          [ 1.5727e+00, -1.2017e+00, -1.8086e+00,  ...,  2.2871e+00,
           -9.3960e-01, -5.4853e-01],
          [ 1.1026e+00, -1.0645e+00, -1.2469e+00,  ...,  1.6859e+00,
            1.0864e+00, -1.8350e+00],
          [ 9.6080e-01, -1.1839e+00, -5.1685e-01,  ...,  1.9371e+00,
            5.2251e-01, -6.6630e-01]],

         [[-8.9759e-01,  2.5559e+00,  3.0324e-01,  ...,  3.4015e-01,
            1.9589e+00, -5.4084e-01],
          [ 8.5055e-01, -2.9616e+00,  1.0523e+00,  ...,  3.1944e-01,
           -4.3024e+00, -5.3262e-01],
          [-2.1991e-01, -2.6238e+00, -5.0961e-01,  ..., -1.0572e-02,
           -5.2822e+00,  5.7309e-01],
          ...,
          [ 1.7774e+00, -3.0262e+00, -6.2749e-02,  ..., -3.9953e-01,
           -3.5927e+00, -1.8122e+00],
          [ 3.4358e-01, -2.7537e+00,  5.5012e-01,  ...,  9.0664e-01,
           -5.0459e+00, -5.0702e-01],
          [ 8.0552e-01, -3.1555e+00, -2.6789e-01,  ..., -6.1371e-01,
           -4.8818e+00, -4.0512e-01]],

         [[-2.0290e+00, -3.6629e-01, -1.1161e+00,  ..., -4.0021e-01,
            6.4893e-02,  2.4993e-01],
          [ 1.8800e+00,  6.3801e-01,  1.8007e+00,  ...,  1.2802e+00,
           -7.8783e-01,  9.2138e-01],
          [ 1.2802e+00,  4.2201e+00,  1.5276e+00,  ...,  2.1099e-01,
            3.6180e-03, -4.3861e-01],
          ...,
          [ 2.9910e+00,  1.1778e+00,  5.4143e-01,  ..., -6.6982e-01,
            6.7589e-01, -8.9655e-02],
          [ 2.0554e+00,  1.6939e+00,  1.7272e+00,  ...,  1.1095e-01,
           -4.6233e-01,  9.0875e-02],
          [ 3.3871e+00,  8.2065e-02,  5.6282e-01,  ..., -1.0472e-01,
           -6.4104e-03,  8.9617e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>), tensor([[[[-4.9252e-02, -8.7662e-02,  1.9245e-02,  ...,  1.1583e-01,
           -2.2480e-02,  2.4090e-02],
          [ 7.3998e-01,  9.2088e-02, -1.2068e+00,  ..., -1.1689e+00,
            9.8164e-01, -3.2493e-01],
          [ 1.8654e+00,  1.8521e+00, -1.1428e+00,  ..., -7.8538e-01,
           -1.3162e-01, -9.1184e-01],
          ...,
          [ 5.7013e-01,  3.7659e-02,  3.2336e-01,  ..., -6.8518e-01,
            7.8144e-01,  3.1578e-01],
          [ 3.5704e-01, -1.2859e+00,  8.4548e-01,  ...,  7.8927e-01,
           -4.1751e-01, -4.3254e-01],
          [-7.8941e-02,  4.2914e-01,  2.8057e-01,  ...,  1.4775e+00,
           -4.4810e-01, -3.5821e-01]],

         [[ 1.7039e-02,  1.7426e-02, -3.8615e-02,  ...,  3.8622e-02,
            1.0783e-02,  3.5067e-02],
          [-1.8710e+00,  2.5931e-01,  1.9364e+00,  ...,  1.2818e+00,
           -2.1004e+00,  8.0968e-01],
          [-2.0571e+00,  2.1252e+00,  7.7616e-01,  ..., -2.1614e-01,
            3.0920e-01, -1.3792e-01],
          ...,
          [ 5.2768e-01,  7.6009e-01,  1.3738e-01,  ...,  7.8285e-02,
           -1.0129e+00, -1.0930e+00],
          [-1.6376e+00,  1.4974e-01, -3.4155e-01,  ..., -5.2083e-01,
           -3.6068e-01, -6.0488e-01],
          [ 5.6162e-01, -1.3109e-01, -1.0858e+00,  ..., -5.4658e-01,
            3.5207e-01,  2.5588e-01]],

         [[ 3.8827e-02,  3.3007e-02, -8.3370e-02,  ...,  2.5783e-03,
           -1.6268e-02,  9.1886e-04],
          [ 6.7396e-01, -7.4608e-01, -1.1437e+00,  ..., -4.9611e-01,
            1.8052e-02, -9.8371e-01],
          [ 1.3927e+00,  2.0229e-01,  1.1575e+00,  ...,  9.9906e-01,
            6.1328e-01,  1.2476e+00],
          ...,
          [ 4.2289e-01,  2.9166e-01, -1.7195e-01,  ..., -9.1719e-01,
           -2.4385e-01, -4.0700e-01],
          [ 7.7449e-01,  1.9776e+00,  7.9667e-01,  ..., -9.7234e-02,
           -1.0700e+00, -2.4079e-01],
          [ 6.5026e-01,  3.4379e-01,  1.9756e+00,  ..., -7.0553e-01,
           -8.3279e-01,  5.4796e-01]],

         ...,

         [[-2.6492e-02,  2.4856e-02, -1.0858e-02,  ..., -1.9138e-02,
           -1.6549e-02,  2.3261e-02],
          [ 1.8255e+00,  2.1934e+00,  6.5594e-01,  ...,  2.6818e+00,
            8.3834e-01, -6.3335e-01],
          [-4.0868e+00,  2.9853e+00,  1.5028e+00,  ...,  2.8890e-01,
            9.7505e-01, -3.2399e+00],
          ...,
          [-7.2338e-01, -1.4377e+00, -5.4115e-01,  ...,  6.3100e-01,
            1.2142e+00, -9.7212e-01],
          [ 1.6900e-02,  2.4184e-02, -7.9526e-01,  ...,  1.2693e+00,
            7.9554e-01, -9.4520e-01],
          [ 4.1660e-01,  6.7636e-01, -7.9182e-01,  ...,  4.9160e-01,
           -5.0815e-02, -6.8480e-02]],

         [[-6.4021e-02, -3.2721e-02,  2.4496e-02,  ...,  2.8996e-03,
           -5.4889e-02, -1.0845e-01],
          [-4.6883e-01,  1.5272e-01,  9.2087e-02,  ...,  8.4295e-01,
           -2.0841e+00, -3.0567e-02],
          [ 3.6565e-01, -1.4821e+00,  9.6754e-01,  ..., -9.3504e-02,
           -8.7716e-01, -4.0122e-01],
          ...,
          [-4.2841e-01,  1.2406e+00, -4.3889e-01,  ...,  4.2291e-01,
           -2.8603e-01, -8.7372e-01],
          [ 1.4675e-01,  1.5823e+00, -6.8927e-01,  ...,  9.1472e-02,
           -2.9227e-01,  1.0596e-01],
          [-9.5282e-01,  1.2839e+00,  2.9624e-01,  ..., -3.9504e-01,
            1.3045e+00, -3.3763e-01]],

         [[-3.9009e-04,  4.5721e-02, -4.4665e-02,  ..., -1.7350e-02,
            1.0824e-02,  2.1425e-03],
          [-1.1021e+00, -1.0736e+00, -1.3544e-01,  ..., -6.8339e-01,
            2.5608e-01, -1.5562e+00],
          [-1.9336e+00, -2.8961e-01, -1.8405e-01,  ...,  1.8993e+00,
            4.5613e-02, -1.1907e+00],
          ...,
          [-3.1445e-01, -6.2644e-02,  3.0309e-01,  ...,  1.1828e+00,
           -1.3786e-02, -7.7986e-01],
          [-1.2743e+00,  4.1847e-01,  1.0675e-01,  ...,  1.1067e+00,
            1.1204e+00, -5.8878e-01],
          [ 5.5352e-01, -1.4446e+00,  6.2596e-01,  ..., -2.2840e-01,
            3.1991e-01, -1.0959e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-0.5113,  0.4863, -0.8478,  ..., -1.0067, -1.3222,  0.2112],
          [-0.2776,  0.8382,  1.8349,  ...,  2.3225,  0.6648,  0.2540],
          [ 1.3453, -0.5017,  1.2809,  ...,  4.1645, -1.7425, -0.8501],
          ...,
          [ 1.4667, -0.1365,  2.8200,  ...,  2.1293,  0.9827, -1.4491],
          [ 0.9350,  1.2324,  1.3160,  ...,  2.0101, -0.9095,  0.8805],
          [ 2.2349, -0.3502,  2.3561,  ...,  4.5559,  0.2889, -0.6012]],

         [[ 0.8578, -2.0592,  0.1552,  ...,  0.2365, -2.4501, -0.4252],
          [-0.5257,  2.9826, -0.0066,  ...,  1.6393, -0.1439, -0.3203],
          [ 0.1126,  6.1312, -1.8119,  ..., -1.3806,  1.5269,  0.0770],
          ...,
          [ 1.8046,  3.1371, -0.0435,  ..., -1.1984,  4.0211, -0.3047],
          [ 1.9017,  3.2856,  0.5921,  ..., -0.0992,  1.5096, -0.1924],
          [ 1.6355,  2.5339, -1.0278,  ...,  0.2948,  2.4720, -0.0789]],

         [[ 1.0126,  0.3729, -0.1703,  ..., -0.8086, -1.4068, -0.3650],
          [-0.7648, -1.0051,  0.2495,  ...,  0.6701, -0.5403, -0.2351],
          [-1.9147, -0.8003,  0.3430,  ...,  0.9768, -0.6854,  0.0126],
          ...,
          [-1.7370, -0.6033, -1.3323,  ...,  1.1623, -1.2578, -1.9024],
          [-1.4760, -0.7168, -0.2280,  ...,  1.5871, -0.8728, -0.5783],
          [-0.7148, -0.8849, -1.7579,  ...,  1.1032, -1.6794, -1.3470]],

         ...,

         [[ 0.2134, -0.5814,  0.4874,  ..., -0.6810,  1.0950,  0.3049],
          [-0.6622,  0.7595, -1.3103,  ..., -2.7334, -1.3875,  0.5355],
          [-0.9839,  0.5835, -2.7718,  ..., -0.7272, -0.6065,  2.6364],
          ...,
          [-0.9617,  1.3015, -1.3307,  ..., -2.1932, -0.2873,  0.0981],
          [-1.4241,  0.0142, -1.4190,  ..., -3.0027, -1.1089, -0.0661],
          [-1.7032,  0.8917, -1.3572,  ..., -3.3537,  1.2818,  0.3163]],

         [[ 0.2561,  0.5523,  0.5580,  ...,  0.7246,  0.0652,  0.8349],
          [-1.2620,  0.3286,  0.9976,  ...,  0.6765,  0.7315,  1.6506],
          [ 0.9963, -0.4575, -0.9883,  ...,  0.1085, -0.5890,  2.3868],
          ...,
          [-0.1060,  1.5261,  0.9926,  ..., -2.7033, -0.0864, -0.1516],
          [-3.7611, -0.0674,  1.6440,  ..., -0.0794,  0.9523,  1.6051],
          [ 0.3993,  0.0166,  0.7194,  ..., -0.1722,  0.4801,  0.2899]],

         [[-0.7296,  0.3057, -1.5837,  ..., -0.3948,  0.2267, -1.3335],
          [ 1.0017, -0.9907, -1.0231,  ...,  0.4469,  0.4642, -0.7319],
          [ 1.4362, -0.7307, -0.9360,  ..., -2.7721, -2.5525, -2.4089],
          ...,
          [-1.0814,  0.4487,  0.2135,  ...,  2.1611,  3.0587, -1.7137],
          [-0.4406, -0.4945, -0.1256,  ...,  1.1526,  1.1095,  0.3055],
          [-0.2776,  1.5215, -0.5332,  ...,  1.1829,  2.3322, -1.0183]]]],
       device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[-3.1461e-03,  4.0217e-02, -7.3887e-02,  ...,  5.2969e-02,
           -2.6784e-02, -8.3422e-02],
          [ 6.9965e-01,  8.4726e-01, -1.3827e-02,  ..., -1.4325e+00,
           -7.6671e-01,  1.1938e-01],
          [-6.3505e-01, -1.8155e-02, -1.1006e+00,  ..., -9.8568e-01,
           -1.9473e+00,  4.5972e-01],
          ...,
          [-1.8922e-02, -1.5553e+00, -9.6255e-01,  ...,  3.4609e-01,
            7.2390e-01, -7.2770e-01],
          [-5.7521e-01, -5.4535e-01, -1.5828e+00,  ...,  5.6929e-01,
           -6.4923e-01,  6.9475e-01],
          [ 1.0637e+00,  4.4950e-01,  5.3918e-02,  ..., -4.3740e-01,
           -3.2741e-02, -2.3343e-01]],

         [[ 4.5011e-02,  1.3549e-02,  2.6805e-02,  ..., -3.5686e-02,
           -3.5387e-02, -4.4578e-03],
          [ 8.8991e-01,  5.9314e-01, -7.2167e-02,  ..., -9.4840e-01,
            1.0844e+00,  1.2220e+00],
          [ 1.2285e+00,  1.8533e+00,  3.6534e-01,  ..., -8.7942e-01,
            3.1773e+00, -1.4437e-01],
          ...,
          [ 8.9366e-01,  5.2548e-03, -4.4113e-01,  ...,  2.4077e-01,
           -1.0397e+00,  3.4268e-01],
          [-8.1721e-02, -2.7034e-01, -7.1395e-01,  ..., -1.3899e-01,
            1.1629e+00, -3.0129e-01],
          [ 2.9050e-01,  1.2897e+00, -4.2347e-01,  ..., -3.6758e-01,
           -3.3340e-01,  5.5617e-01]],

         [[ 3.8629e-03,  2.0467e-02, -4.0017e-02,  ..., -1.7782e-03,
            3.0812e-02,  5.9422e-02],
          [-1.4746e+00, -7.6598e-01, -4.9963e-01,  ..., -1.1237e+00,
            4.0888e-01,  3.7366e-02],
          [ 3.1216e+00,  3.4719e+00, -1.5450e+00,  ...,  3.7155e-01,
            1.7554e+00,  1.0825e+00],
          ...,
          [-2.2221e+00,  1.7920e+00,  1.1178e+00,  ..., -3.7278e-01,
            2.8963e-01,  4.8757e-01],
          [ 1.8054e-01, -9.5324e-01, -4.6999e-01,  ..., -2.6873e-01,
           -3.4931e-01, -2.0291e-01],
          [-4.6330e-01,  9.4874e-01,  1.1230e+00,  ...,  2.4846e-01,
           -5.4433e-01,  1.9094e-01]],

         ...,

         [[-8.9160e-02,  4.4159e-02, -7.4543e-03,  ..., -4.1528e-02,
           -1.3703e-02,  1.0209e-02],
          [ 4.5365e-01, -9.8718e-02,  3.7648e-01,  ...,  5.5908e-01,
           -1.6181e+00,  2.0091e-01],
          [-1.0051e+00, -9.2160e-01, -1.8300e-01,  ...,  1.5032e-01,
           -3.5601e-01,  1.1393e+00],
          ...,
          [ 3.8664e-01,  8.8380e-01,  6.9114e-01,  ..., -7.1387e-01,
           -1.8584e+00, -1.8991e+00],
          [-2.6471e-01, -1.8073e-01,  1.5080e+00,  ...,  1.0224e+00,
           -5.7313e-01, -5.2180e-01],
          [-1.1091e+00, -1.0401e+00, -1.0134e+00,  ...,  6.1705e-01,
           -8.1685e-01, -7.5424e-02]],

         [[ 7.4462e-02,  1.1415e-02,  6.6753e-02,  ...,  2.6585e-02,
            5.9258e-02,  2.8430e-02],
          [ 8.3152e-01, -2.2886e+00,  2.2605e-01,  ...,  3.4404e+00,
           -1.3195e+00, -2.8798e+00],
          [ 1.4743e-02,  2.1297e-01, -2.4953e+00,  ..., -1.3823e+00,
            1.4250e+00, -9.5225e-01],
          ...,
          [ 9.1060e-01,  1.6273e+00,  1.8072e+00,  ..., -1.6699e+00,
            2.7229e+00, -9.9341e-01],
          [ 2.0697e+00, -1.0451e+00, -1.6918e+00,  ..., -7.6083e-01,
           -6.6113e-01, -1.1681e-01],
          [ 7.3505e-01,  9.9351e-01,  8.8695e-01,  ..., -2.8755e-01,
            1.9784e+00, -2.1858e+00]],

         [[-1.0551e-01,  3.7056e-02, -5.2910e-02,  ..., -8.2814e-02,
            8.1320e-02, -1.1580e-02],
          [ 7.9286e-01,  1.0883e+00, -2.8319e-01,  ..., -1.6656e-01,
            8.1304e-01,  2.0001e+00],
          [ 7.3818e-01,  2.7335e+00,  2.0108e+00,  ..., -8.0340e-01,
           -6.2005e-01,  3.4431e-01],
          ...,
          [-2.1225e-01,  1.5986e+00, -4.7704e-01,  ...,  1.7525e+00,
            2.5042e-01, -7.3810e-01],
          [-1.0866e+00,  2.4015e-01,  1.7822e-01,  ..., -5.7931e-01,
           -3.7555e-01, -4.3796e-01],
          [-6.3663e-01, -9.6326e-01,  6.6193e-01,  ..., -4.6179e-01,
           -1.1061e+00, -3.5319e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>)), (tensor([[[[-1.7171, -0.3168, -0.2978,  ...,  0.1637,  0.3394, -0.5041],
          [-0.5649, -0.0569, -0.1757,  ...,  0.7798, -0.9593,  0.1274],
          [ 0.6217,  0.6556, -0.2829,  ...,  1.7264, -0.9492, -0.3003],
          ...,
          [ 0.5127, -0.0550,  0.4555,  ..., -0.1798,  0.4189, -0.3237],
          [ 1.4306, -0.2665, -0.6924,  ...,  0.4072,  0.0356,  0.2801],
          [ 0.3647,  0.3681,  0.2000,  ..., -0.0545, -0.4125, -0.4987]],

         [[ 0.1125, -0.0720,  2.2769,  ...,  0.2411,  0.0791, -0.1868],
          [ 0.3754,  0.3120, -1.5438,  ...,  0.7578,  0.4480, -0.2125],
          [ 0.6421,  0.4066, -0.9517,  ...,  1.9527, -0.8776, -0.8329],
          ...,
          [ 1.4697,  0.6897, -2.0477,  ..., -0.3531, -0.8718,  0.1144],
          [ 0.4871, -0.1398, -1.8025,  ...,  0.0553,  0.1603, -0.1879],
          [ 1.1846,  0.8657, -1.2089,  ..., -0.3741, -0.3004,  0.4547]],

         [[-0.1971,  1.0731,  0.4710,  ..., -0.5335,  0.3173, -0.1099],
          [-0.6160, -0.6787, -0.1886,  ...,  1.2765, -0.2427,  0.0198],
          [-0.1753, -1.3161,  0.6825,  ...,  2.5880, -0.0398, -1.8585],
          ...,
          [-0.5711,  0.6618, -0.4905,  ...,  1.6065, -0.7117,  0.6734],
          [-0.4640,  0.6161,  0.2619,  ...,  2.0119,  1.0717, -0.6101],
          [-0.5517,  0.4712, -1.2532,  ...,  1.3872,  0.1236, -0.9117]],

         ...,

         [[ 0.5514,  0.9770, -0.8534,  ..., -0.7274,  0.6981,  0.8435],
          [-0.4633, -0.9677,  1.0696,  ..., -1.3604, -0.4768,  0.4836],
          [ 0.3272,  1.9423, -0.0427,  ..., -1.0563, -1.5149,  1.0777],
          ...,
          [-0.5303, -0.3247,  0.4965,  ..., -1.4606,  0.8117, -0.3339],
          [-0.8559,  0.3017,  0.1236,  ..., -1.0954,  1.3234, -0.2345],
          [-0.2782,  0.0201, -0.3758,  ..., -1.2297,  0.2962, -0.0464]],

         [[-0.4173,  0.3735,  0.3492,  ...,  0.7022,  0.0200, -0.0874],
          [-0.6046,  1.8319, -0.4594,  ..., -0.1303, -0.2883, -2.4248],
          [ 0.8539, -0.2973,  1.2250,  ...,  0.9134, -1.1105,  1.3329],
          ...,
          [-0.9636,  0.0477, -0.1330,  ...,  0.8147, -0.9786,  1.4489],
          [-0.2483, -0.0954, -1.6667,  ...,  1.4701,  0.9545,  0.7833],
          [-0.7331,  0.7880, -0.9207,  ...,  0.3971, -0.3445,  1.1242]],

         [[-0.7286, -0.0169,  0.4363,  ..., -0.1016,  0.0149, -0.0735],
          [-0.3077, -0.2044,  0.9706,  ...,  0.5311, -0.4540,  1.4855],
          [ 0.7991, -1.3786, -2.1901,  ..., -0.3034, -2.3266, -0.9113],
          ...,
          [ 1.4139, -1.7224,  0.0692,  ...,  0.2342, -2.1317,  0.4961],
          [-1.1990, -0.0822,  1.2112,  ..., -0.0050, -2.2928,  0.5032],
          [-0.4787,  0.2465, -0.0146,  ..., -0.7903, -1.8445,  0.4185]]]],
       device='cuda:0', grad_fn=<PermuteBackward0>), tensor([[[[ 8.6463e-02, -1.3589e-01, -2.0528e-01,  ..., -2.7072e-01,
            2.5288e-01, -1.4601e-01],
          [-1.3737e+00,  1.6262e+00, -3.9644e-01,  ...,  2.1092e+00,
           -1.2491e+00,  2.2709e+00],
          [ 9.3736e-02,  8.0827e-01,  4.1821e-01,  ...,  2.4003e+00,
           -6.1362e-01, -8.2722e-01],
          ...,
          [ 4.9494e-01, -1.4327e-01, -9.4550e-01,  ...,  2.5292e+00,
           -3.0418e+00,  1.7612e+00],
          [-9.5852e-01, -4.7755e-01, -8.8891e-02,  ...,  4.1432e+00,
           -4.8426e-01,  9.5120e-01],
          [ 1.7663e+00, -5.3790e-01,  2.6296e-01,  ...,  1.9204e+00,
           -6.1751e-01,  2.3158e+00]],

         [[ 8.4597e-02, -3.0458e-02,  4.1542e-02,  ..., -3.1874e-02,
           -9.8648e-02,  1.6476e-01],
          [-1.0960e+00,  1.6559e+00, -2.8393e-01,  ...,  1.0959e+00,
            7.6596e-01, -1.3732e+00],
          [-4.0602e-02,  4.3989e-01,  1.2150e-01,  ...,  1.6780e+00,
           -9.5549e-01, -8.1111e-01],
          ...,
          [ 1.4156e+00, -3.6635e-01,  3.3233e-01,  ...,  1.0604e+00,
           -1.4212e-01, -6.3028e-01],
          [ 1.3465e+00,  1.7579e+00,  5.1865e-01,  ...,  4.6014e-01,
            8.0379e-01,  1.0797e+00],
          [ 5.4950e-01,  1.1675e+00,  9.0108e-01,  ..., -4.9434e-01,
            5.4417e-01,  5.9731e-01]],

         [[-1.1752e-03,  2.7713e-02, -4.8585e-02,  ...,  4.0027e-02,
            2.4395e-02,  5.1549e-02],
          [ 6.0015e-01, -6.0283e-01, -3.2347e+00,  ..., -7.8959e-01,
            8.7755e-01, -2.7308e-01],
          [ 1.5926e+00, -3.3798e-02, -1.3016e+00,  ...,  1.7952e+00,
           -9.1188e-01,  5.9074e-01],
          ...,
          [-3.3754e-01,  4.6598e-01, -1.0240e+00,  ..., -1.0126e+00,
           -1.6857e+00, -1.9817e+00],
          [-7.5659e-01,  1.8637e-01, -1.3627e+00,  ...,  5.3148e-01,
           -1.6594e+00, -1.3408e+00],
          [ 4.6923e-01, -3.9860e-01, -1.3944e+00,  ...,  1.0094e+00,
            1.0449e-01, -8.4683e-01]],

         ...,

         [[ 5.7201e-04, -5.6053e-03,  1.0676e-01,  ...,  7.0169e-02,
           -1.8888e-02,  4.3763e-02],
          [-7.7385e-01,  2.5806e+00, -7.1687e-01,  ...,  1.3104e+00,
            1.0888e+00,  2.5205e-01],
          [ 2.4686e+00,  4.1175e-01,  2.8059e-01,  ..., -1.0749e+00,
            3.4196e+00, -2.0002e+00],
          ...,
          [-3.1486e-01, -1.0806e+00,  4.2761e-01,  ..., -8.3092e-02,
           -1.4934e+00, -2.3053e-01],
          [-1.4704e-01,  2.0300e+00,  4.0726e-01,  ..., -7.3362e-01,
            5.5142e-01, -1.8850e+00],
          [-2.5338e-01,  1.3151e+00, -4.0857e-01,  ..., -4.7356e-01,
           -1.0699e+00, -1.0685e+00]],

         [[-2.0099e-01, -6.8959e-02,  7.5206e-02,  ..., -5.4001e-02,
            4.2309e-02, -9.4743e-02],
          [ 1.9182e+00,  7.7161e-03, -1.2696e+00,  ..., -1.3677e+00,
            3.1515e+00, -3.5326e-01],
          [ 4.6426e-02, -1.0435e+00, -3.7551e-01,  ..., -2.0669e+00,
            8.7258e-01, -2.2646e-01],
          ...,
          [-1.0673e+00, -1.1056e+00, -2.9885e-01,  ..., -3.9416e-01,
           -2.2161e-01, -4.9800e-01],
          [ 5.8389e-01, -2.5311e-02, -1.2006e+00,  ...,  1.2141e+00,
           -1.7424e+00, -9.9682e-02],
          [-9.4210e-01,  2.2320e-01, -4.9680e-01,  ..., -3.3195e-02,
           -2.9205e-01, -5.4909e-02]],

         [[ 1.0339e-01, -1.4692e-01,  1.6582e-01,  ..., -1.4900e-01,
           -1.5749e-02, -1.8594e-01],
          [-1.5155e+00,  5.3209e-01,  6.6421e-01,  ..., -5.4169e-01,
            2.2736e+00, -2.5641e-01],
          [-9.2274e-02, -5.9834e-01, -1.2739e+00,  ..., -1.5369e+00,
           -5.6380e-01, -1.3742e+00],
          ...,
          [-9.9035e-01, -8.1416e-01, -1.5702e-01,  ...,  2.5352e-01,
           -2.9972e-01, -1.1698e-01],
          [-1.3327e+00, -1.1053e+00, -6.8016e-01,  ..., -8.5201e-01,
            1.1720e+00, -6.5619e-01],
          [ 2.2502e-02, -6.0447e-02,  4.3160e-02,  ...,  2.4946e-01,
            1.6794e-01, -6.6439e-01]]]], device='cuda:0',
       grad_fn=<PermuteBackward0>))), hidden_states=None, attentions=None, cross_attentions=None)

Let's see why token-based indexing is necessary.

In this example, we call invoke on two inputs with different tokenized lengths. We then incorrectly index into the transformer's input (the token ids) using normal Python indexing.

[3]:
from rich import print

# Demonstrate why plain positional indexing is wrong when several prompts of
# different tokenized lengths are batched together in one trace.
with model.trace() as tracer:
    # Invoke 1: a one-token prompt. It shares a batch with the longer prompt
    # below, so it has to be padded to the same length.
    with tracer.invoke('The') as invoker:
        # Plain Python indexing: position 0 of the transformer's input
        # (the printed values are token ids).
        incorrect_a = model.transformer.input[0][0][:, 0].save()

    # Invoke 2: the multi-token prompt used throughout this notebook.
    # (Previously two adjacent string literals were implicitly concatenated
    # here, silently duplicating the prompt.)
    with tracer.invoke('The Eiffel Tower is in the city of') as invoker:
        incorrect_b = model.transformer.input[0][0][:, 0].save()

print(f"Shorter input: {incorrect_a.value}")
print(f"Longer input: {incorrect_b.value}")
Shorter input: tensor([50256], device='cuda:0')
Longer input: tensor([464], device='cuda:0')

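Notice how we indexed into the first token for both strings but received a different result from each invoke. This is because, when there are multiple invocations, the inputs are batched together and shorter inputs are padded on the left side. Position 0 of the shorter input is therefore a padding token (50256, GPT-2's <|endoftext|> token, used here for padding) rather than "The". Because padding is on the left, the token-based indexing helpers index from the back of the sequence, so their indices line up with each prompt's real tokens.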

Let’s index into the same inputs correctly using token-based indexing.

[4]:
# Same two prompts as the previous cell, but indexed with token-based
# indexing (.t[<idx>]) instead of plain positional indexing.
with model.trace() as tracer:
    # Invoke 1: the short prompt, which is padded to share a batch with invoke 2.
    with tracer.invoke('The') as invoker:
        # .t[0] selects this prompt's first real token, accounting for padding.
        correct_a =  model.transformer.input[0][0].t[0].save()

    # Invoke 2: the longer prompt.
    with tracer.invoke('The Eiffel Tower is in the city of') as invoker:
        correct_b = model.transformer.input[0][0].t[0].save()

# Both saved values should now be the id of "The" for each prompt.
print(f"Shorter input: {correct_a.value}")
print(f"Longer input: {correct_b.value}")
Shorter input: tensor([464], device='cuda:0')
Longer input: tensor([464], device='cuda:0')

Now we have the correct tokens!