/* ===================================================================
 *
 * Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 * ===================================================================
 */
#include <assert.h>     /* assert() is used in sub_rot() and expand_key() */
#include <stdlib.h>
#include <stdio.h>
#include <wmmintrin.h>

#include "common.h"
#include "endianess.h"
#include "block_base.h"

FAKE_INIT(raw_aesni)

#define MODULE_NAME AESNI
#define BLOCK_SIZE 16
struct block_state {
    __m128i *erk;       /** Round keys for encryption (11, 13 or 15 elements) **/
    __m128i *drk;       /** Round keys for decryption **/
    unsigned rounds;
};

typedef struct {
    BlockBase base_state;
    struct block_state algo_state;
} AESNI_State;
/*
 * See https://www.cosic.esat.kuleuven.be/ecrypt/AESday/slides/Use_of_the_AES_Instruction_Set.pdf
 */

enum SubType { OnlySub, SubRotXor };

static FUNC_SSE2 uint32_t sub_rot(uint32_t w, unsigned idx /** round/Nk **/, enum SubType subType)
{
    __m128i x, y, z;

    assert((idx>=1) && (idx<=10));

    x = _mm_set1_epi32((int)w);     // { w, w, w, w }
    y = _mm_set1_epi32(0);

    /* The RCON argument of AESKEYGENASSIST is an 8-bit immediate and must be
       a compile-time constant, hence the switch over the round index. */
    switch (idx) {
    case 1:  y = _mm_aeskeygenassist_si128(x, 0x01); break;
    case 2:  y = _mm_aeskeygenassist_si128(x, 0x02); break;
    case 3:  y = _mm_aeskeygenassist_si128(x, 0x04); break;
    case 4:  y = _mm_aeskeygenassist_si128(x, 0x08); break;
    case 5:  y = _mm_aeskeygenassist_si128(x, 0x10); break;
    case 6:  y = _mm_aeskeygenassist_si128(x, 0x20); break;
    case 7:  y = _mm_aeskeygenassist_si128(x, 0x40); break;
    case 8:  y = _mm_aeskeygenassist_si128(x, 0x80); break;
    case 9:  y = _mm_aeskeygenassist_si128(x, 0x1b); break;
    case 10: y = _mm_aeskeygenassist_si128(x, 0x36); break;
    }

    /** Y0 contains SubWord(W) **/
    /** Y1 contains RotWord(SubWord(W)) xor RCON **/
    z = y;
    if (subType == SubRotXor) {
        z = _mm_srli_si128(y, 4);
    }

    return (uint32_t)_mm_cvtsi128_si32(z);
}
static FUNC_SSE2 int expand_key(__m128i *erk, __m128i *drk, const uint8_t *key, unsigned Nk, unsigned Nr)
{
    uint32_t rk[4*(14+2)];
    unsigned tot_words;
    unsigned i;

    assert(
        ((Nk==4) && (Nr==10)) ||    /** AES-128 **/
        ((Nk==6) && (Nr==12)) ||    /** AES-192 **/
        ((Nk==8) && (Nr==14))       /** AES-256 **/
    );

    tot_words = 4*(Nr+1);

    for (i=0; i<Nk; i++) {
        rk[i] = LOAD_U32_LITTLE(key);
        key += 4;
    }

    for (i=Nk; i<tot_words; i++) {
        uint32_t tmp;

        tmp = rk[i-1];
        if (i % Nk == 0) {
            tmp = sub_rot(tmp, i/Nk, SubRotXor);
        } else {
            if ((i % Nk == 4) && (Nk == 8)) {   /* AES-256 only */
                tmp = sub_rot(tmp, i/Nk, OnlySub);
            }
        }
        rk[i] = rk[i-Nk] ^ tmp;
    }

    for (i=0; i<tot_words; i+=4) {
        *erk++ = _mm_loadu_si128((__m128i*)&rk[i]);
    }

    erk--;              /** Point to the last round **/
    *drk++ = *erk--;
    for (i=0; i<Nr-1; i++) {
        *drk++ = _mm_aesimc_si128(*erk--);
    }
    *drk = *erk;

    return 0;
}
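/*
 * Note on the decryption schedule built above (descriptive comment added for
 * clarity, not part of the original file): drk[] holds the round keys of the
 * "Equivalent Inverse Cipher" of FIPS-197 Section 5.3.5. The first and last
 * encryption round keys are reused as-is, while every intermediate key passes
 * through InvMixColumns (AESIMC):
 *
 *      drk[0]  = erk[Nr]
 *      drk[j]  = InvMixColumns(erk[Nr-j])      for 1 <= j <= Nr-1
 *      drk[Nr] = erk[0]
 *
 * This is what lets AESNI_decrypt below walk drk[] in forward order with
 * AESDEC/AESDECLAST, mirroring the structure of AESNI_encrypt.
 */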
static FUNC_SSE2 int AESNI_encrypt(const BlockBase *bb, const uint8_t *in, uint8_t *out, size_t data_len)
{
    unsigned rounds;
    __m128i r[14+1];
    const struct block_state *state;
    unsigned k;

    if ((bb == NULL) || (in == NULL) || (out == NULL))
        return ERR_NULL;

    state = &((AESNI_State*)bb)->algo_state;
    rounds = state->rounds;
    if (rounds > 14)
        return ERR_NR_ROUNDS;

    for (k=0; k<=rounds; k++) {
        r[k] = state->erk[k];
    }

    /** Encrypt 8 blocks (128 bytes) in parallel, when possible **/
    for (; data_len >= 8*16; data_len -= 8*16) {
        __m128i pt[8], data[8];
        unsigned j;

        pt[0] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[1] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[2] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[3] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[4] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[5] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[6] = _mm_loadu_si128((__m128i*)in); in+=16;
        pt[7] = _mm_loadu_si128((__m128i*)in); in+=16;

        data[0] = _mm_xor_si128(pt[0], r[0]);
        data[1] = _mm_xor_si128(pt[1], r[0]);
        data[2] = _mm_xor_si128(pt[2], r[0]);
        data[3] = _mm_xor_si128(pt[3], r[0]);
        data[4] = _mm_xor_si128(pt[4], r[0]);
        data[5] = _mm_xor_si128(pt[5], r[0]);
        data[6] = _mm_xor_si128(pt[6], r[0]);
        data[7] = _mm_xor_si128(pt[7], r[0]);

        for (j=1; j<10; j++) {
            data[0] = _mm_aesenc_si128(data[0], r[j]);
            data[1] = _mm_aesenc_si128(data[1], r[j]);
            data[2] = _mm_aesenc_si128(data[2], r[j]);
            data[3] = _mm_aesenc_si128(data[3], r[j]);
            data[4] = _mm_aesenc_si128(data[4], r[j]);
            data[5] = _mm_aesenc_si128(data[5], r[j]);
            data[6] = _mm_aesenc_si128(data[6], r[j]);
            data[7] = _mm_aesenc_si128(data[7], r[j]);
        }

        for (; j<rounds; j++) {
            data[0] = _mm_aesenc_si128(data[0], r[j]);
            data[1] = _mm_aesenc_si128(data[1], r[j]);
            data[2] = _mm_aesenc_si128(data[2], r[j]);
            data[3] = _mm_aesenc_si128(data[3], r[j]);
            data[4] = _mm_aesenc_si128(data[4], r[j]);
            data[5] = _mm_aesenc_si128(data[5], r[j]);
            data[6] = _mm_aesenc_si128(data[6], r[j]);
            data[7] = _mm_aesenc_si128(data[7], r[j]);
        }

        data[0] = _mm_aesenclast_si128(data[0], r[rounds]);
        data[1] = _mm_aesenclast_si128(data[1], r[rounds]);
        data[2] = _mm_aesenclast_si128(data[2], r[rounds]);
        data[3] = _mm_aesenclast_si128(data[3], r[rounds]);
        data[4] = _mm_aesenclast_si128(data[4], r[rounds]);
        data[5] = _mm_aesenclast_si128(data[5], r[rounds]);
        data[6] = _mm_aesenclast_si128(data[6], r[rounds]);
        data[7] = _mm_aesenclast_si128(data[7], r[rounds]);

        _mm_storeu_si128((__m128i*)out, data[0]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[1]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[2]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[3]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[4]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[5]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[6]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[7]); out+=16;
    }

    /** There are 7 blocks or fewer left **/
    for (; data_len>=BLOCK_SIZE; data_len-=BLOCK_SIZE, in+=BLOCK_SIZE, out+=BLOCK_SIZE) {
        __m128i pt, data;
        unsigned i;

        pt = _mm_loadu_si128((__m128i*)in);
        data = _mm_xor_si128(pt, r[0]);
        for (i=1; i<10; i++) {
            data = _mm_aesenc_si128(data, r[i]);
        }
        for (i=10; i<rounds; i+=2) {
            data = _mm_aesenc_si128(data, r[i]);
            data = _mm_aesenc_si128(data, r[i+1]);
        }
        data = _mm_aesenclast_si128(data, r[rounds]);
        _mm_storeu_si128((__m128i*)out, data);
    }

    if (data_len) {
        return ERR_NOT_ENOUGH_DATA;
    }

    return 0;
}
static FUNC_SSE2 int AESNI_decrypt(const BlockBase *bb, const uint8_t *in, uint8_t *out, size_t data_len)
{
    unsigned rounds;
    __m128i r[14+1];
    const struct block_state *state;
    unsigned k;

    if ((bb == NULL) || (in == NULL) || (out == NULL))
        return ERR_NULL;

    state = &((AESNI_State*)bb)->algo_state;
    rounds = state->rounds;
    if (rounds > 14)
        return ERR_NR_ROUNDS;

    for (k=0; k<=rounds; k++) {
        r[k] = state->drk[k];
    }

    /** Decrypt 8 blocks (128 bytes) in parallel, when possible **/
    for (; data_len >= 8*16; data_len -= 8*16) {
        __m128i ct[8], data[8];
        unsigned j;

        ct[0] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[1] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[2] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[3] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[4] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[5] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[6] = _mm_loadu_si128((__m128i*)in); in+=16;
        ct[7] = _mm_loadu_si128((__m128i*)in); in+=16;

        data[0] = _mm_xor_si128(ct[0], r[0]);
        data[1] = _mm_xor_si128(ct[1], r[0]);
        data[2] = _mm_xor_si128(ct[2], r[0]);
        data[3] = _mm_xor_si128(ct[3], r[0]);
        data[4] = _mm_xor_si128(ct[4], r[0]);
        data[5] = _mm_xor_si128(ct[5], r[0]);
        data[6] = _mm_xor_si128(ct[6], r[0]);
        data[7] = _mm_xor_si128(ct[7], r[0]);

        for (j=1; j<10; j++) {
            data[0] = _mm_aesdec_si128(data[0], r[j]);
            data[1] = _mm_aesdec_si128(data[1], r[j]);
            data[2] = _mm_aesdec_si128(data[2], r[j]);
            data[3] = _mm_aesdec_si128(data[3], r[j]);
            data[4] = _mm_aesdec_si128(data[4], r[j]);
            data[5] = _mm_aesdec_si128(data[5], r[j]);
            data[6] = _mm_aesdec_si128(data[6], r[j]);
            data[7] = _mm_aesdec_si128(data[7], r[j]);
        }

        for (; j<rounds; j++) {
            data[0] = _mm_aesdec_si128(data[0], r[j]);
            data[1] = _mm_aesdec_si128(data[1], r[j]);
            data[2] = _mm_aesdec_si128(data[2], r[j]);
            data[3] = _mm_aesdec_si128(data[3], r[j]);
            data[4] = _mm_aesdec_si128(data[4], r[j]);
            data[5] = _mm_aesdec_si128(data[5], r[j]);
            data[6] = _mm_aesdec_si128(data[6], r[j]);
            data[7] = _mm_aesdec_si128(data[7], r[j]);
        }

        data[0] = _mm_aesdeclast_si128(data[0], r[rounds]);
        data[1] = _mm_aesdeclast_si128(data[1], r[rounds]);
        data[2] = _mm_aesdeclast_si128(data[2], r[rounds]);
        data[3] = _mm_aesdeclast_si128(data[3], r[rounds]);
        data[4] = _mm_aesdeclast_si128(data[4], r[rounds]);
        data[5] = _mm_aesdeclast_si128(data[5], r[rounds]);
        data[6] = _mm_aesdeclast_si128(data[6], r[rounds]);
        data[7] = _mm_aesdeclast_si128(data[7], r[rounds]);

        _mm_storeu_si128((__m128i*)out, data[0]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[1]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[2]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[3]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[4]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[5]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[6]); out+=16;
        _mm_storeu_si128((__m128i*)out, data[7]); out+=16;
    }

    /** There are 7 blocks or fewer left **/
    for (; data_len>=BLOCK_SIZE; data_len-=BLOCK_SIZE, in+=BLOCK_SIZE, out+=BLOCK_SIZE) {
        __m128i ct, data;
        unsigned i;

        ct = _mm_loadu_si128((__m128i*)in);
        data = _mm_xor_si128(ct, r[0]);
        for (i=1; i<10; i++) {
            data = _mm_aesdec_si128(data, r[i]);
        }
        for (i=10; i<rounds; i+=2) {
            data = _mm_aesdec_si128(data, r[i]);
            data = _mm_aesdec_si128(data, r[i+1]);
        }
        data = _mm_aesdeclast_si128(data, r[rounds]);
        _mm_storeu_si128((__m128i*)out, data);
    }

    if (data_len) {
        return ERR_NOT_ENOUGH_DATA;
    }

    return 0;
}
EXPORT_SYM int AESNI_stop_operation(BlockBase *bb)
{
    AESNI_State *state;

    if (NULL == bb)
        return ERR_NULL;

    state = (AESNI_State*)bb;
    align_free(state->algo_state.erk);
    align_free(state->algo_state.drk);
    free(state);

    return 0;
}
EXPORT_SYM int AESNI_start_operation(const uint8_t key[], size_t key_len, AESNI_State **pResult)
{
    unsigned Nr;
    const unsigned Nb = 4;
    int result;
    struct block_state *state;
    BlockBase *block_base;

    if ((NULL == key) || (NULL == pResult))
        return ERR_NULL;

    switch (key_len) {
        case 16: Nr = 10; break;
        case 24: Nr = 12; break;
        case 32: Nr = 14; break;
        default: return ERR_KEY_SIZE;
    }

    *pResult = calloc(1, sizeof(AESNI_State));
    if (NULL == *pResult)
        return ERR_MEMORY;

    block_base = &((*pResult)->base_state);
    block_base->encrypt = &AESNI_encrypt;
    block_base->decrypt = &AESNI_decrypt;
    block_base->destructor = &AESNI_stop_operation;
    block_base->block_len = BLOCK_SIZE;

    state = &((*pResult)->algo_state);
    state->rounds = Nr;

    state->erk = align_alloc(Nb*(Nr+1)*sizeof(uint32_t), 16);
    if (state->erk == NULL) {
        result = ERR_MEMORY;
        goto error;
    }

    state->drk = align_alloc(Nb*(Nr+1)*sizeof(uint32_t), 16);
    if (state->drk == NULL) {
        result = ERR_MEMORY;
        goto error;
    }

    result = expand_key(state->erk, state->drk, key, (unsigned)key_len/4, Nr);
    if (result) {
        goto error;
    }
    return 0;

error:
    align_free(state->erk);
    align_free(state->drk);
    free(*pResult);
    return result;
}
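/*
 * Minimal usage sketch (added comment, not part of the original file; it
 * assumes the caller handles buffer sizing and error codes as defined in the
 * surrounding headers). Generic cipher-mode code is expected to drive this
 * module through the BlockBase function pointers set above, roughly:
 *
 *      AESNI_State *st = NULL;
 *      if (AESNI_start_operation(key, 16, &st) == 0) {
 *          st->base_state.encrypt(&st->base_state, in, out, 16*n);
 *          st->base_state.destructor(&st->base_state);
 *      }
 */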
#ifdef MAIN
#include <stdio.h>
int main(void)
{
    uint8_t *c, *d;
    uint8_t key[16] = { 0 };
    AESNI_State *s = NULL;
    int i;
    size_t q = 1000000*16;

    /** Simple throughput benchmark: encrypt 16 MB, 1000 times **/
    if (AESNI_start_operation(key, sizeof key, &s) != 0) {
        printf("start_operation failed\n");
        return 1;
    }
    c = calloc(1, q);
    d = malloc(q);
    if ((c == NULL) || (d == NULL)) {
        printf("Out of memory\n");
        return 1;
    }
    for (i=0; i<1000; i++)
        AESNI_encrypt(&s->base_state, c, d, q);
    printf("Done.\n");

    free(c);
    free(d);
    AESNI_stop_operation(&s->base_state);
    return 0;
}
#endif
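
/*
 * A minimal known-answer sketch, not part of the original build: it checks a
 * single AES-128 encryption against the FIPS-197 Appendix C.1 example vector.
 * The SELF_TEST guard is a hypothetical name; define at most one of MAIN or
 * SELF_TEST when compiling this file standalone.
 */
#ifdef SELF_TEST
#include <stdio.h>
#include <string.h>
int main(void)
{
    static const uint8_t key[16] = {
        0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
        0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f
    };
    static const uint8_t pt[16] = {
        0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
        0x88,0x99,0xaa,0xbb,0xcc,0xdd,0xee,0xff
    };
    static const uint8_t expected[16] = {
        0x69,0xc4,0xe0,0xd8,0x6a,0x7b,0x04,0x30,
        0xd8,0xcd,0xb7,0x80,0x70,0xb4,0xc5,0x5a
    };
    uint8_t ct[16];
    AESNI_State *state = NULL;

    if (AESNI_start_operation(key, sizeof key, &state) != 0) {
        printf("start_operation failed\n");
        return 1;
    }
    AESNI_encrypt(&state->base_state, pt, ct, sizeof ct);
    printf("AES-128 KAT: %s\n",
           memcmp(ct, expected, sizeof expected) == 0 ? "OK" : "FAILED");
    AESNI_stop_operation(&state->base_state);
    return 0;
}
#endif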