I have written this RSA implementation in C. So far everything seems to work, but something is wrong with the PEM loader: the public key can't be loaded from the file. Everything except the loader seems correct. Does anyone know what the issue is?
C:
#pragma once
#ifndef RSA_ALGORITHM_C_H
#define RSA_ALGORITHM_C_H
#include <ctype.h>
#include <errno.h>
#include <gmp.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <fcntl.h>
#include <unistd.h> // For secure random bytes
/*
 * Compiler-hint helpers. On GCC/Clang they map to attributes and __builtin_expect;
 * on other compilers they expand to nothing (or to the plain condition).
 * likely/unlikely are only defined if no earlier header already provides them, so
 * including this file after such a header does not trigger a macro redefinition.
 */
#if defined(__GNUC__) || defined(__clang__)
#define __attr_nodiscard __attribute__((warn_unused_result))
#define __attr_malloc __attribute__((malloc))
#define __attr_hot __attribute__((hot))
#define __attr_cold __attribute__((cold))
#ifndef likely
#define likely(x) __builtin_expect(!!(x), 1)
#endif
#ifndef unlikely
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
#else
#define __attr_nodiscard
#define __attr_malloc
#define __attr_hot
#define __attr_cold
#ifndef likely
#define likely(x) (x)
#endif
#ifndef unlikely
#define unlikely(x) (x)
#endif
#endif
/*
 * `__restrict__` is already a keyword on GCC and Clang in every language mode, so a
 * spelling is only supplied for other compilers. Redefining it unconditionally would turn
 * every later `__restrict__` into `restrict`, which is not a keyword in C89/gnu89 mode.
 */
#if !defined(__GNUC__) && !defined(__clang__)
#ifdef __cplusplus
#define __restrict__ __restrict
#else
#define __restrict__ restrict
#endif
#endif
/* `noexcept` only exists in C++; in C these expand to nothing. */
#ifdef __cplusplus
#define __noexcept noexcept
#define __const_noexcept const noexcept
#else
#define __noexcept
#define __const_noexcept
#endif
// ========== Secure Randomness ==========
// Fill `buf` with `len` random bytes read from /dev/urandom.
// Aborts the process if the device cannot be opened or read.
// read(2) may return fewer bytes than requested (large requests, signal interruption),
// so keep reading until the buffer is full instead of treating a short read as fatal.
__attr_hot static inline void get_secure_random_bytes(void *buf, size_t len) {
#ifdef O_CLOEXEC
    const int fd = open("/dev/urandom", O_RDONLY | O_CLOEXEC);
#else
    const int fd = open("/dev/urandom", O_RDONLY);
#endif
    if (fd < 0) {
        perror("open /dev/urandom");
        abort();
    }
    unsigned char *dst = (unsigned char *)buf;
    size_t filled = 0;
    while (filled < len) {
        const ssize_t r = read(fd, dst + filled, len - filled);
        if (r < 0 && errno == EINTR)
            continue; // interrupted before any data arrived: retry
        if (r <= 0) { // real error, or unexpected EOF on the device
            perror("read /dev/urandom");
            close(fd);
            abort();
        }
        filled += (size_t)r;
    }
    close(fd);
}
// Seed a GMP random state with 32 bytes from /dev/urandom.
// The seed fully determines every value later drawn from `rng`, so the temporary copies
// (stack buffer and mpz limbs) are overwritten before they are released.
// NOTE(review): rsa_generate_keypair pairs this with gmp_randinit_mt (Mersenne Twister),
// which is not a cryptographically secure generator even with a secure seed.
__attr_hot static inline void gmp_randseed_secure(gmp_randstate_t rng) {
    unsigned char seed_buf[32];
    get_secure_random_bytes(seed_buf, sizeof(seed_buf));
    mpz_t seed;
    mpz_init(seed);
    mpz_import(seed, sizeof(seed_buf), 1, 1, 0, 0, seed_buf);
    gmp_randseed(rng, seed);
    // Wipe the raw seed bytes; the volatile pointer keeps the stores from being elided.
    volatile unsigned char *vseed = seed_buf;
    for (size_t i = 0; i < sizeof(seed_buf); ++i)
        vseed[i] = 0;
    // Wipe the limbs too: mpz_clear releases them without zeroing.
    const size_t nlimbs = mpz_size(seed);
    if (nlimbs > 0) {
        volatile mp_limb_t *vlimbs = mpz_limbs_modify(seed, (mp_size_t)nlimbs);
        for (size_t i = 0; i < nlimbs; ++i)
            vlimbs[i] = 0;
        mpz_limbs_finish(seed, 0);
    }
    mpz_clear(seed);
}
// ===== ENUMS =====
/* Supported RSA modulus sizes; each enumerator's value is the bit count itself. */
typedef enum {
    KEYSIZE_1024 = 1024,
    KEYSIZE_2048 = 2048,
    KEYSIZE_3072 = 3072,
    KEYSIZE_4096 = 4096,
} KeySize;
/* Output encoding options: raw bytes, hexadecimal text, or Base64 text. */
typedef enum {
    OUTPUT_BINARY = 0,
    OUTPUT_HEX = 1,
    OUTPUT_BASE64 = 2,
} OutputFormat;
// ===== STRUCTS =====
/* RSA public key as GMP integers; initialize with rsa_public_key_init before use. */
typedef struct {
    mpz_t n; /* modulus */
    mpz_t e; /* public exponent */
} RSAPublicKey;
/* RSA private key, fields in the order they are serialized (PKCS#1 RSAPrivateKey). */
typedef struct {
    mpz_t n;    /* modulus */
    mpz_t e;    /* public exponent */
    mpz_t d;    /* private exponent */
    mpz_t p;    /* first prime factor */
    mpz_t q;    /* second prime factor */
    mpz_t dP;   /* CRT exponent, intended to be d mod (p-1) */
    mpz_t dQ;   /* CRT exponent, intended to be d mod (q-1) */
    mpz_t qInv; /* CRT coefficient, intended to be q^-1 mod p */
} RSAPrivateKey;
/* Matching public and private halves, as filled in by rsa_generate_keypair. */
typedef struct {
    RSAPublicKey public_key;   /* shareable half */
    RSAPrivateKey private_key; /* secret half */
} RSAKeyPair;
// ========== Utility Memory ==========
/*
 * SECURE_FREE(p, sz): if `p` is non-NULL, overwrite its first `sz` bytes with zeros through
 * a volatile pointer and then free(p). `p` is evaluated exactly once. `sz` is evaluated
 * exactly once, and only when `p` is non-NULL.
 * Single evaluation matters for calls like SECURE_FREE(s, strlen(s)). Re-evaluating
 * strlen(s) on every loop iteration would yield 0 after the first byte was cleared,
 * leaving the rest of the secret in memory.
 */
#define SECURE_FREE(p, sz)                                                      \
    do                                                                          \
    {                                                                           \
        void *_sf_ptr = (void *)(p);                                            \
        if (likely(_sf_ptr))                                                    \
        {                                                                       \
            const size_t _sf_len = (size_t)(sz);                                \
            volatile unsigned char *_sf_vp = (volatile unsigned char *)_sf_ptr; \
            for (size_t _sf_i = 0; _sf_i < _sf_len; ++_sf_i)                    \
                _sf_vp[_sf_i] = 0;                                              \
            free(_sf_ptr);                                                      \
        }                                                                       \
    } while (0)
/*
 * SAFE_MALLOC(size): malloc that prints a diagnostic and aborts on failure
 * (GNU statement expression). `size` is evaluated once. A zero-byte request is rounded up
 * to one byte, because malloc(0) may legitimately return NULL, which would otherwise be
 * mistaken for out-of-memory.
 */
#define SAFE_MALLOC(size)                                                       \
    ({                                                                          \
        const size_t _sm_size = (size_t)(size);                                 \
        void *_sm_ptr = malloc(_sm_size ? _sm_size : 1);                        \
        if (unlikely(!_sm_ptr))                                                 \
        {                                                                       \
            fprintf(stderr, "Out of memory at %s:%d\n", __FILE__, __LINE__);    \
            abort();                                                            \
        }                                                                       \
        _sm_ptr;                                                                \
    })
// ========== ASN.1 DER ==========
__attr_hot static inline void encode_integer(const mpz_t x, uint8_t ** const out, size_t * const outlen) __noexcept
{
size_t count = (mpz_sizeinbase(x, 2) + 7) / 8;
if (unlikely(count == 0))
count = 1;
uint8_t * const bytes = (uint8_t *)SAFE_MALLOC(count + 2);
mpz_export(bytes, &count, 1, 1, 1, 0, x);
int prepend_zero = (bytes[0] & 0x80) ? 1 : 0;
if (likely(prepend_zero))
{
memmove(bytes + 1, bytes, count);
bytes[0] = 0x00;
count += 1;
}
*outlen = 2 + count;
*out = (uint8_t *)SAFE_MALLOC(*outlen);
(*out)[0] = 0x02;
(*out)[1] = (uint8_t)count;
memcpy(*out + 2, bytes, count);
SECURE_FREE(bytes, count + 1);
}
__attr_hot static inline void encode_sequence(uint8_t ** const fields, const size_t * const field_lens, const size_t num_fields, uint8_t ** const out, size_t * const outlen) __noexcept
{
size_t seqlen = 0;
for (size_t i = 0; i < num_fields; ++i)
seqlen += field_lens[i];
*outlen = 2 + seqlen;
*out = (uint8_t *)SAFE_MALLOC(*outlen);
(*out)[0] = 0x30;
(*out)[1] = (uint8_t)seqlen;
size_t offset = 2;
for (size_t i = 0; i < num_fields; ++i)
{
memcpy(*out + offset, fields[i], field_lens[i]);
offset += field_lens[i];
}
}
// ========== BASE64 ENCODE/DECODE ==========
// Standard Base64 alphabet: the character at index i encodes the 6-bit value i.
// base64_encode emits '=' separately for padding; it is not part of this table.
static const char b64_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/*
 * Encode `inlen` bytes from `in` as Base64 using b64_table, padding the last group with '='.
 * Returns a newly allocated NUL-terminated string of 4 * ceil(inlen / 3) characters (caller frees).
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *base64_encode(const uint8_t * __restrict__ in, const size_t inlen) __noexcept
{
    const size_t encoded_len = 4 * ((inlen + 2) / 3);
    char * const dst = (char *)SAFE_MALLOC(encoded_len + 1);
    size_t src_pos = 0;
    size_t dst_pos = 0;
    /* Full 3-byte groups become 4 characters each. */
    while (inlen - src_pos >= 3)
    {
        const uint32_t group = ((uint32_t)in[src_pos] << 16)
                             | ((uint32_t)in[src_pos + 1] << 8)
                             | (uint32_t)in[src_pos + 2];
        for (int shift = 18; shift >= 0; shift -= 6)
            dst[dst_pos++] = b64_table[(group >> shift) & 0x3F];
        src_pos += 3;
    }
    /* A trailing 1- or 2-byte group is padded out to 4 characters with '='. */
    const size_t remaining = inlen - src_pos;
    if (remaining == 1)
    {
        const uint32_t group = (uint32_t)in[src_pos] << 16;
        dst[dst_pos++] = b64_table[(group >> 18) & 0x3F];
        dst[dst_pos++] = b64_table[(group >> 12) & 0x3F];
        dst[dst_pos++] = '=';
        dst[dst_pos++] = '=';
    }
    else if (remaining == 2)
    {
        const uint32_t group = ((uint32_t)in[src_pos] << 16) | ((uint32_t)in[src_pos + 1] << 8);
        dst[dst_pos++] = b64_table[(group >> 18) & 0x3F];
        dst[dst_pos++] = b64_table[(group >> 12) & 0x3F];
        dst[dst_pos++] = b64_table[(group >> 6) & 0x3F];
        dst[dst_pos++] = '=';
    }
    dst[encoded_len] = '\0';
    return dst;
}
__attr_malloc __attr_nodiscard __attr_hot static inline uint8_t *base64_decode(const char *in, size_t * const outlen) __noexcept
{
static int dtable[256] = {0};
static int dtable_initialized = 0;
if (!dtable_initialized) {
for (size_t i = 0; i < 64; i++)
dtable[(uint8_t)b64_table[i]] = i + 1;
dtable_initialized = 1;
}
const size_t inlen = strlen(in);
size_t pad = 0, i = 0, j = 0;
if (inlen >= 2 && in[inlen - 1] == '=')
pad++;
if (inlen >= 2 && in[inlen - 2] == '=')
pad++;
*outlen = (inlen * 3) / 4 - pad;
uint8_t * const out = (uint8_t *)SAFE_MALLOC(*outlen);
for (i = 0, j = 0; i < inlen;) {
int val = 0;
for (int n = 0; n < 4 && i < inlen;) {
char c = in[i++];
if (likely(dtable[(uint8_t)c])) {
val = (val << 6) | (dtable[(uint8_t)c] - 1);
n++;
} else if (c == '=') {
val <<= 6;
n++;
}
}
if (likely(j < *outlen))
out[j++] = (val >> 16) & 0xFF;
if (likely(j < *outlen))
out[j++] = (val >> 8) & 0xFF;
if (likely(j < *outlen))
out[j++] = val & 0xFF;
}
return out;
}
// PEM WRAP/UNWRAP
/*
 * Wrap a Base64 string in PEM armor:
 *   -----BEGIN <type>-----\n  <base64 in lines of at most 64 chars, each ending in \n>  -----END <type>-----\n
 * Returns a newly allocated NUL-terminated string (caller frees).
 * The buffer size is computed exactly. The previous estimate
 * (32 + len + len/64*2 + 2*typelen) was 1-2 bytes short for Base64 bodies below 128
 * characters whose length is not a multiple of 64, which overflowed the heap.
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *pem_wrap(const char *base64, const char *type) __noexcept
{
    const size_t line_width = 64;
    const size_t typelen = strlen(type);
    const size_t b64len = strlen(base64);
    const size_t nlines = (b64len + line_width - 1) / line_width;
    /* "-----BEGIN " + type + "-----\n" + body + one '\n' per line + "-----END " + type + "-----\n" + NUL */
    const size_t outlen = (11 + typelen + 6) + (b64len + nlines) + (9 + typelen + 6) + 1;
    char * const out = (char *)SAFE_MALLOC(outlen);
    size_t offset = (size_t)snprintf(out, outlen, "-----BEGIN %s-----\n", type);
    for (size_t i = 0; i < b64len; i += line_width)
    {
        const size_t chunk = (b64len - i < line_width) ? b64len - i : line_width;
        memcpy(out + offset, base64 + i, chunk);
        offset += chunk;
        out[offset++] = '\n';
    }
    snprintf(out + offset, outlen - offset, "-----END %s-----\n", type);
    return out;
}
/*
 * Extract the Base64 body of the first PEM block in the NUL-terminated string `pem`.
 * The body is everything after the "-----BEGIN ..." line and before the next "-----END"
 * marker. Only alphanumerics, '+', '/' and '=' are kept; line breaks and other whitespace
 * are dropped.
 * Returns a newly allocated NUL-terminated string (caller frees). Returns NULL with a
 * diagnostic if a marker or the end of the BEGIN line is missing.
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *strip_pem(const char *pem) __noexcept
{
    const char * const begin = strstr(pem, "-----BEGIN");
    if (unlikely(!begin))
    {
        fprintf(stderr, "Failed to find PEM BEGIN header\n");
        return NULL;
    }
    /* The body starts on the line after the BEGIN header. */
    const char *body = strchr(begin, '\n');
    if (unlikely(!body))
    {
        fprintf(stderr, "Failed to find end of PEM BEGIN header line\n");
        return NULL;
    }
    ++body;
    /* Look for END only after the body start, so the span length can never be negative. */
    const char * const end = strstr(body, "-----END");
    if (unlikely(!end))
    {
        fprintf(stderr, "Failed to find PEM END header\n");
        return NULL;
    }
    char * const b64 = (char *)SAFE_MALLOC((size_t)(end - body) + 1);
    size_t k = 0;
    for (const char *p = body; p < end; ++p)
        if (isalnum((unsigned char)*p) || *p == '+' || *p == '/' || *p == '=')
            b64[k++] = *p;
    b64[k] = '\0';
    return b64;
}
// ========== HEX ENCODE/DECODE ==========
/*
 * Render `len` bytes as lowercase hexadecimal, two characters per byte.
 * Returns a newly allocated NUL-terminated string of 2*len characters (caller frees).
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *bytes_to_hex(const uint8_t * __restrict__ data, const size_t len) __noexcept
{
    static const char digits[] = "0123456789abcdef";
    char * const hex = (char *)SAFE_MALLOC(len * 2 + 1);
    char *cursor = hex;
    for (size_t idx = 0; idx < len; ++idx)
    {
        *cursor++ = digits[data[idx] >> 4];
        *cursor++ = digits[data[idx] & 0x0F];
    }
    hex[len * 2] = '\0';
    return hex;
}
__attr_malloc __attr_nodiscard __attr_hot static inline uint8_t *hex_to_bytes(const char *hex, size_t * const outlen) __noexcept
{
const size_t len = strlen(hex) / 2;
uint8_t * const out = (uint8_t *)SAFE_MALLOC(len);
for (size_t i = 0; i < len; ++i)
{
if (unlikely(sscanf(hex + 2 * i, "%2hhx", &out[i]) != 1))
{
fprintf(stderr, "Invalid hex string at position %zu\n", i * 2);
SECURE_FREE(out, len);
*outlen = 0;
return NULL;
}
}
*outlen = len;
return out;
}
// ========== FILE IO ==========
/* Return 1 if `filename` can be opened for reading, 0 otherwise. */
__attr_hot static inline int file_exists(const char *filename) __noexcept
{
    FILE * const probe = fopen(filename, "rb");
    if (unlikely(!probe))
        return 0;
    fclose(probe);
    return 1;
}
/*
 * Read the whole of `filename` into a newly allocated buffer (caller frees).
 * On success returns the buffer and stores the file size in *len.
 * The buffer always has one extra '\0' byte after the data, not counted in *len. The PEM
 * loaders cast the result to char * and search it with strstr, which needs a terminator.
 * On failure prints a diagnostic and returns NULL without touching *len.
 */
__attr_malloc __attr_nodiscard __attr_hot static inline uint8_t *read_file(const char *filename, size_t * const len) __noexcept
{
    FILE * const f = fopen(filename, "rb");
    if (unlikely(!f))
    {
        fprintf(stderr, "Failed to open file '%s' for reading: %s\n", filename, strerror(errno));
        return NULL;
    }
    long flen = -1;
    if (unlikely(fseek(f, 0, SEEK_END) != 0 || (flen = ftell(f)) < 0 || fseek(f, 0, SEEK_SET) != 0))
    {
        fprintf(stderr, "Failed to determine size of file '%s'\n", filename);
        fclose(f);
        return NULL;
    }
    const size_t size = (size_t)flen;
    /* +1 for the terminator; also keeps the request non-zero for empty files. */
    uint8_t * const buf = (uint8_t *)SAFE_MALLOC(size + 1);
    const size_t r = fread(buf, 1, size, f);
    fclose(f);
    if (unlikely(r != size))
    {
        fprintf(stderr, "Failed to read file '%s'\n", filename);
        SECURE_FREE(buf, size + 1);
        return NULL;
    }
    buf[size] = '\0';
    *len = size;
    return buf;
}
/*
 * Write `len` bytes from `data` to `filename` in binary mode, replacing any existing content.
 * Returns 1 on success, 0 with a diagnostic on failure.
 * fclose() is checked because buffered data is only flushed there, so a full disk or I/O
 * error may surface only as a close failure.
 */
__attr_hot static inline int write_file(const char *filename, const uint8_t * __restrict__ data, const size_t len) __noexcept
{
    FILE * const f = fopen(filename, "wb");
    if (unlikely(!f))
    {
        fprintf(stderr, "Failed to open '%s' for writing\n", filename);
        return 0;
    }
    const size_t w = fwrite(data, 1, len, f);
    const int close_rc = fclose(f);
    if (unlikely(w != len || close_rc != 0))
    {
        fprintf(stderr, "Failed to write all data to '%s'\n", filename);
        return 0;
    }
    return 1;
}
/*
 * Write the NUL-terminated string `text` to `filename` in text mode, replacing any existing
 * content. Returns 1 on success, 0 with a diagnostic on failure.
 * fclose() is checked because buffered data is only flushed there.
 */
__attr_hot static inline int write_text(const char *filename, const char *text) __noexcept
{
    FILE * const f = fopen(filename, "w");
    if (unlikely(!f))
    {
        fprintf(stderr, "Failed to open file '%s' for writing: %s\n", filename, strerror(errno));
        return 0;
    }
    const size_t textlen = strlen(text);
    const size_t w = fwrite(text, 1, textlen, f);
    const int close_rc = fclose(f);
    if (unlikely(w != textlen || close_rc != 0))
    {
        fprintf(stderr, "Error writing text to file '%s'\n", filename);
        return 0;
    }
    return 1;
}
// ========== RANDOM PRIME, MODINV, POWM ==========
/*
 * Generate a random probable prime of exactly `bits` bits (bits >= 2) using `rng`.
 * The two most significant bits are forced to 1. That puts the prime at or above
 * 3 * 2^(bits-2), so the product of two such primes has exactly 2*bits bits. With only
 * the top bit set, an RSA modulus can come out one bit short of the requested size.
 * Bit 0 is forced so every candidate is odd. Candidates are tested with
 * mpz_probab_prime_p(..., 40). Aborts on bits < 2 or after 100000 rejected candidates.
 */
__attr_hot static inline void random_prime(mpz_t out, const int bits, gmp_randstate_t rng) __noexcept
{
    if (unlikely(bits < 2))
    {
        fprintf(stderr, "random_prime: bit length must be at least 2\n");
        abort();
    }
    int tries = 0;
    while (1)
    {
        mpz_urandomb(out, rng, (mp_bitcnt_t)bits);
        mpz_setbit(out, (mp_bitcnt_t)(bits - 1));
        mpz_setbit(out, (mp_bitcnt_t)(bits - 2));
        mpz_setbit(out, 0);
        if (mpz_probab_prime_p(out, 40))
            return;
        if (unlikely(++tries > 100000))
        {
            fprintf(stderr, "Failed to generate a random prime after 100000 tries\n");
            abort();
        }
    }
}
/* out = a^-1 mod m; prints a diagnostic and aborts when a has no inverse modulo m. */
__attr_hot static inline void modinv(mpz_t out, const mpz_t a, const mpz_t m) __noexcept
{
    const int invertible = mpz_invert(out, a, m);
    if (likely(invertible))
        return;
    fprintf(stderr, "No modular inverse exists\n");
    abort();
}
/* out = base^exp mod `mod`; thin wrapper over GMP's mpz_powm (variable time). */
__attr_hot static inline void powm(mpz_t out, const mpz_t base, const mpz_t exp, const mpz_t mod) __noexcept { mpz_powm(out, base, exp, mod); }
// ========== PKCS#1 v1.5 Padding ==========
/*
 * PKCS#1 v1.5 encryption padding (block type 2). Writes
 *     0x00 0x02 PS 0x00 block
 * into `padded`, which must hold `mod_bytes` bytes. PS consists of mod_bytes - 3 - blocklen
 * non-zero random bytes, at least 8 by the size check.
 * Returns 1 on success, 0 with a diagnostic if the message does not fit.
 */
__attr_nodiscard __attr_hot static inline int pkcs1_pad(const uint8_t * __restrict__ block, const size_t blocklen, uint8_t * __restrict__ padded, const size_t mod_bytes) __noexcept
{
    /* Test mod_bytes first: for tiny moduli `mod_bytes - 11` would wrap around and accept anything. */
    if (unlikely(mod_bytes < 11 || blocklen > mod_bytes - 11))
    {
        fprintf(stderr, "Block too large for PKCS#1 v1.5 padding\n");
        return 0;
    }
    const size_t pad_len = mod_bytes - 3 - blocklen;
    uint8_t * const ps = padded + 2;
    padded[0] = 0x00;
    padded[1] = 0x02;
    /* Fill PS in place, then redraw any zero byte: PS must not contain the 0x00 separator. */
    get_secure_random_bytes(ps, pad_len);
    for (size_t i = 0; i < pad_len; ++i)
        while (ps[i] == 0)
            get_secure_random_bytes(&ps[i], 1);
    padded[2 + pad_len] = 0x00;
    memcpy(padded + 3 + pad_len, block, blocklen);
    return 1;
}
/*
 * Strip PKCS#1 v1.5 type-2 padding from `padded` (0x00 0x02 PS 0x00 message).
 * Requires paddedlen >= 11, and at least 8 padding bytes before the first 0x00 separator.
 * On success copies the message into `out`, stores its length in *outlen and returns 1.
 * The message is at most paddedlen - 11 bytes, so `out` must have at least that much room.
 * NOTE(review): the two distinct failure messages and the early exits are not constant time.
 */
__attr_nodiscard __attr_hot static inline int pkcs1_unpad(const uint8_t *padded, const size_t paddedlen, uint8_t *out, size_t * const outlen) __noexcept
{
    const int header_ok = paddedlen >= 11 && padded[0] == 0x00 && padded[1] == 0x02;
    if (unlikely(!header_ok))
    {
        fprintf(stderr, "Invalid PKCS#1 v1.5 padding detected\n");
        return 0;
    }
    const uint8_t * const sep = (const uint8_t *)memchr(padded + 2, 0x00, paddedlen - 2);
    if (unlikely(sep == NULL || (size_t)(sep - padded) < 10))
    {
        fprintf(stderr, "Padding format error in PKCS#1 v1.5\n");
        return 0;
    }
    const size_t start = (size_t)(sep - padded) + 1;
    *outlen = paddedlen - start;
    memcpy(out, padded + start, *outlen);
    return 1;
}
// ========== KEY INIT/CLEAR ==========
/* Initialize both GMP integers of `key` to 0; aborts with a diagnostic on a NULL key. */
__attr_hot static inline void rsa_public_key_init(RSAPublicKey * const key) __noexcept
{
    if (likely(key != NULL))
    {
        mpz_inits(key->n, key->e, NULL);
        return;
    }
    fprintf(stderr, "Null RSAPublicKey pointer passed to init\n");
    abort();
}
/* Initialize all eight GMP integers of `key` to 0; aborts with a diagnostic on a NULL key. */
__attr_hot static inline void rsa_private_key_init(RSAPrivateKey * const key) __noexcept
{
    if (likely(key != NULL))
    {
        mpz_inits(key->n, key->e, key->d, key->p, key->q,
                  key->dP, key->dQ, key->qInv, NULL);
        return;
    }
    fprintf(stderr, "Null RSAPrivateKey pointer passed to init\n");
    abort();
}
/* Release the GMP storage of `key`; a NULL key is ignored. */
__attr_cold static inline void rsa_public_key_clear(RSAPublicKey * const key) __noexcept
{
    if (likely(key != NULL))
        mpz_clears(key->n, key->e, NULL);
}
/*
 * Release the GMP storage of all eight fields of `key`; a NULL key is ignored.
 * NOTE(review): mpz_clear frees the limbs without overwriting them first.
 */
__attr_cold static inline void rsa_private_key_clear(RSAPrivateKey * const key) __noexcept
{
    if (likely(key != NULL))
        mpz_clears(key->n, key->e, key->d, key->p, key->q,
                   key->dP, key->dQ, key->qInv, NULL);
}
// ========== KEY GENERATION ==========
/*
 * Generate an RSA key pair with a modulus of exactly `bits` bits and public exponent 65537.
 * Fills kp->public_key (n, e) and kp->private_key (n, e, d, p, q, dP, dQ, qInv).
 * Every mpz field of *kp must already be initialized (e.g. with rsa_public_key_init /
 * rsa_private_key_init), because they are written with mpz_set/mpz_mod.
 * Aborts on a NULL `kp`, or if no suitable prime pair is found within 1000 attempts.
 *
 * p-1 and q-1 are kept in separate temporaries, so p and q are stored unmodified and the
 * CRT values are dP = d mod (p-1), dQ = d mod (q-1), qInv = q^-1 mod p (RFC 8017 3.2).
 * Prime pairs are rejected if p == q, if gcd(e, p-1) or gcd(e, q-1) is not 1 (d would not
 * exist), or if the product does not have exactly `bits` bits.
 * NOTE(review): the GMP generator is Mersenne Twister (gmp_randinit_mt), not a CSPRNG.
 */
__attr_hot static inline void rsa_generate_keypair(RSAKeyPair * const kp, const KeySize bits) __noexcept
{
    if (unlikely(!kp))
    {
        fprintf(stderr, "Null RSAKeyPair pointer passed to key generation\n");
        abort();
    }
    gmp_randstate_t rng;
    gmp_randinit_mt(rng);
    gmp_randseed_secure(rng);
    mpz_t p, q, n, p1, q1, phi, d, e, g;
    mpz_inits(p, q, n, p1, q1, phi, d, e, g, NULL);
    mpz_set_ui(e, 65537);
    const int half = (int)bits / 2;
    int tries = 0;
    for (;;)
    {
        if (unlikely(++tries > 1000))
        {
            fprintf(stderr, "Failed to generate a suitable prime pair for key generation\n");
            abort();
        }
        random_prime(p, half, rng);
        random_prime(q, half, rng);
        if (mpz_cmp(p, q) == 0)
            continue;
        mpz_sub_ui(p1, p, 1);
        mpz_sub_ui(q1, q, 1);
        mpz_gcd(g, e, p1);
        if (mpz_cmp_ui(g, 1) != 0)
            continue;
        mpz_gcd(g, e, q1);
        if (mpz_cmp_ui(g, 1) != 0)
            continue;
        mpz_mul(n, p, q);
        if (mpz_sizeinbase(n, 2) != (size_t)bits)
            continue;
        break;
    }
    mpz_mul(phi, p1, q1);
    modinv(d, e, phi);
    mpz_set(kp->public_key.n, n);
    mpz_set(kp->public_key.e, e);
    mpz_set(kp->private_key.n, n);
    mpz_set(kp->private_key.e, e);
    mpz_set(kp->private_key.d, d);
    mpz_set(kp->private_key.p, p);
    mpz_set(kp->private_key.q, q);
    mpz_mod(kp->private_key.dP, d, p1);
    mpz_mod(kp->private_key.dQ, d, q1);
    modinv(kp->private_key.qInv, q, p);
    mpz_clears(p, q, n, p1, q1, phi, d, e, g, NULL);
    gmp_randclear(rng);
}
// ========== ENCRYPTION/DECRYPTION ==========
/*
 * RSAES-PKCS1-v1_5 encryption of `mlen` bytes from `message` under `pub`.
 * Query mode: with out == NULL, stores the ciphertext size (modulus length in bytes) in
 * *outlen and returns 1.
 * Otherwise `out` must hold that many bytes. The ciphertext is written left-padded with
 * zeros to exactly the modulus length (I2OSP, RFC 8017 4.1), so its size does not depend
 * on the value of c.
 * Returns 1 on success, 0 with a diagnostic on error.
 */
__attr_nodiscard __attr_hot static inline int rsa_encrypt(const uint8_t * __restrict__ message, const size_t mlen, const RSAPublicKey * const pub, uint8_t * __restrict__ out, size_t * const outlen) __noexcept
{
    if (unlikely(!message || !pub || !outlen))
    {
        fprintf(stderr, "NULL pointer argument to rsa_encrypt\n");
        return 0;
    }
    const size_t mod_bytes = (mpz_sizeinbase(pub->n, 2) + 7) / 8;
    if (!out) { *outlen = mod_bytes; return 1; } // Query mode
    /* PKCS#1 v1.5 needs at least 11 bytes of overhead; also rejects a zero modulus. */
    if (unlikely(mod_bytes < 11))
    {
        fprintf(stderr, "RSA modulus too small for PKCS#1 v1.5 in rsa_encrypt\n");
        return 0;
    }
    uint8_t * const block = (uint8_t *)SAFE_MALLOC(mod_bytes);
    if (unlikely(!pkcs1_pad(message, mlen, block, mod_bytes)))
    {
        SECURE_FREE(block, mod_bytes);
        return 0;
    }
    mpz_t m, c;
    mpz_init(m);
    mpz_init(c);
    mpz_import(m, mod_bytes, 1, 1, 1, 0, block);
    int ok = 0;
    if (unlikely(mpz_cmp(m, pub->n) >= 0))
    {
        fprintf(stderr, "Message representative out of range in rsa_encrypt\n");
    }
    else
    {
        powm(c, m, pub->e, pub->n);
        memset(out, 0, mod_bytes);
        if (mpz_sgn(c) != 0)
        {
            /* c < n, so it needs at most mod_bytes bytes; right-align it in `out`. */
            const size_t clen = (mpz_sizeinbase(c, 2) + 7) / 8;
            mpz_export(out + (mod_bytes - clen), NULL, 1, 1, 1, 0, c);
        }
        *outlen = mod_bytes;
        ok = 1;
    }
    mpz_clear(m);
    mpz_clear(c);
    SECURE_FREE(block, mod_bytes);
    return ok;
}
/*
 * RSAES-PKCS1-v1_5 decryption of the `enclen`-byte big-endian ciphertext `enc` with `priv`.
 * Uses only n and d (the CRT fields are not consulted).
 * On success writes the message to `out`, its length to *outlen, and returns 1.
 * `out` must have room for at least (modulus bytes - 11) bytes, the largest message
 * pkcs1_unpad can produce.
 * The ciphertext must be smaller than n (RFC 8017 5.1.2), and the key must have an odd
 * modulus and a positive exponent. Those are the preconditions of mpz_powm_sec, GMP's
 * side-channel-resistant exponentiation, used here because d is secret.
 * NOTE(review): distinct failure messages can act as a padding oracle if exposed to an attacker.
 */
__attr_nodiscard __attr_hot static inline int rsa_decrypt(const uint8_t * __restrict__ enc, const size_t enclen, const RSAPrivateKey * const priv, uint8_t * __restrict__ out, size_t * const outlen) __noexcept
{
    if (unlikely(!enc || !priv || !out || !outlen))
    {
        fprintf(stderr, "NULL pointer argument to rsa_decrypt\n");
        return 0;
    }
    const size_t mod_bytes = (mpz_sizeinbase(priv->n, 2) + 7) / 8;
    if (unlikely(mod_bytes < 11 || !mpz_odd_p(priv->n) || mpz_sgn(priv->d) <= 0))
    {
        fprintf(stderr, "Invalid RSA private key in rsa_decrypt\n");
        return 0;
    }
    mpz_t c, m;
    mpz_init(c);
    mpz_init(m);
    mpz_import(c, enclen, 1, 1, 1, 0, enc);
    int res = 0;
    if (unlikely(mpz_cmp(c, priv->n) >= 0))
    {
        fprintf(stderr, "Ciphertext representative out of range in rsa_decrypt\n");
    }
    else
    {
        mpz_powm_sec(m, c, priv->d, priv->n);
        uint8_t * const padded = (uint8_t *)SAFE_MALLOC(mod_bytes);
        size_t paddedlen = 0;
        mpz_export(padded, &paddedlen, 1, 1, 1, 0, m);
        if (paddedlen < mod_bytes)
        {
            /* Restore the leading zero bytes that mpz_export drops. */
            memmove(padded + (mod_bytes - paddedlen), padded, paddedlen);
            memset(padded, 0, mod_bytes - paddedlen);
            paddedlen = mod_bytes;
        }
        res = pkcs1_unpad(padded, paddedlen, out, outlen);
        SECURE_FREE(padded, mod_bytes);
    }
    mpz_clear(c);
    mpz_clear(m);
    return res;
}
// ========== PEM/DER SERIALIZATION (PUBLIC KEY) ==========
/*
 * Serialize `pub` as a PKCS#1 RSAPublicKey, SEQUENCE { n INTEGER, e INTEGER }, inside
 * "RSA PUBLIC KEY" PEM armor. Returns a newly allocated string (caller frees), or NULL
 * with a diagnostic for a NULL key.
 * The Base64 length is measured once before wiping, so the wipe length does not depend on
 * how SECURE_FREE evaluates its size argument.
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *rsa_public_key_to_pem(const RSAPublicKey * const pub) __noexcept
{
    if (unlikely(!pub))
    {
        fprintf(stderr, "Null RSAPublicKey pointer passed to serialization\n");
        return NULL;
    }
    uint8_t *en_n, *en_e, *seq;
    size_t len_n, len_e, len_seq;
    encode_integer(pub->n, &en_n, &len_n);
    encode_integer(pub->e, &en_e, &len_e);
    encode_sequence((uint8_t *[]){en_n, en_e}, (size_t[]){len_n, len_e}, 2, &seq, &len_seq);
    char * const b64 = base64_encode(seq, len_seq);
    const size_t b64len = strlen(b64);
    char * const pem = pem_wrap(b64, "RSA PUBLIC KEY");
    SECURE_FREE(en_n, len_n);
    SECURE_FREE(en_e, len_e);
    SECURE_FREE(seq, len_seq);
    SECURE_FREE(b64, b64len);
    return pem;
}
/*
 * Serialize `priv` as a two-prime PKCS#1 RSAPrivateKey inside "RSA PRIVATE KEY" PEM armor:
 * SEQUENCE { version(0), n, e, d, p, q, dP, dQ, qInv }.
 * Returns a newly allocated string (caller frees), or NULL with a diagnostic for a NULL key.
 * The intermediate DER and Base64 buffers hold the secret key and are wiped before freeing.
 * The Base64 length is measured once up front, so the whole string is cleared regardless of
 * how SECURE_FREE evaluates its size argument.
 */
__attr_malloc __attr_nodiscard __attr_hot static inline char *rsa_private_key_to_pem(const RSAPrivateKey * const priv) __noexcept
{
    if (unlikely(!priv))
    {
        fprintf(stderr, "Null RSAPrivateKey pointer passed to serialization\n");
        return NULL;
    }
    uint8_t *en_vals[9], *seq;
    size_t lens[9], len_seq;
    mpz_t zero;
    mpz_init_set_ui(zero, 0);
    encode_integer(zero, &en_vals[0], &lens[0]); // version
    encode_integer(priv->n, &en_vals[1], &lens[1]);
    encode_integer(priv->e, &en_vals[2], &lens[2]);
    encode_integer(priv->d, &en_vals[3], &lens[3]);
    encode_integer(priv->p, &en_vals[4], &lens[4]);
    encode_integer(priv->q, &en_vals[5], &lens[5]);
    encode_integer(priv->dP, &en_vals[6], &lens[6]);
    encode_integer(priv->dQ, &en_vals[7], &lens[7]);
    encode_integer(priv->qInv, &en_vals[8], &lens[8]);
    encode_sequence(en_vals, lens, 9, &seq, &len_seq);
    char * const b64 = base64_encode(seq, len_seq);
    const size_t b64len = strlen(b64);
    char * const pem = pem_wrap(b64, "RSA PRIVATE KEY");
    for (int i = 0; i < 9; ++i)
        SECURE_FREE(en_vals[i], lens[i]);
    SECURE_FREE(seq, len_seq);
    SECURE_FREE(b64, b64len);
    mpz_clear(zero);
    return pem;
}
// ========== FILE SAVE/LOAD PEM ==========
/*
 * Write `pub` to `filename` as PEM text. Returns 1 on success, 0 on failure.
 * The PEM length is measured once before wiping, so the wipe length does not depend on how
 * SECURE_FREE evaluates its size argument.
 */
__attr_nodiscard __attr_hot static inline int rsa_public_key_save_pem(const char *filename, const RSAPublicKey * const pub) __noexcept
{
    char * const pem = rsa_public_key_to_pem(pub);
    if (unlikely(!pem))
        return 0;
    const size_t pemlen = strlen(pem);
    const int r = write_text(filename, pem);
    SECURE_FREE(pem, pemlen);
    return r;
}
/*
 * Write `priv` to `filename` as PKCS#1 "RSA PRIVATE KEY" PEM text.
 * Returns 1 on success, 0 with a diagnostic on failure.
 * A newly created file gets mode 0600 (owner read/write only). fopen() would create it with
 * 0666 minus the umask, typically leaving the private key readable by other users.
 * An already existing file keeps its current permissions, because open(2) applies the mode
 * only on creation. Short writes and EINTR are retried, and the close result is checked.
 * The PEM text is wiped before being freed; its length is measured once so the whole
 * buffer is cleared.
 */
__attr_nodiscard __attr_hot static inline int rsa_private_key_save_pem(const char *filename, const RSAPrivateKey * const priv) __noexcept
{
    char * const pem = rsa_private_key_to_pem(priv);
    if (unlikely(!pem))
        return 0;
    const size_t pemlen = strlen(pem);
    int ok = 0;
    const int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0600);
    if (unlikely(fd < 0))
    {
        fprintf(stderr, "Failed to open file '%s' for writing: %s\n", filename, strerror(errno));
    }
    else
    {
        size_t done = 0;
        ok = 1;
        while (done < pemlen)
        {
            const ssize_t w = write(fd, pem + done, pemlen - done);
            if (w < 0 && errno == EINTR)
                continue;
            if (w <= 0)
            {
                ok = 0;
                break;
            }
            done += (size_t)w;
        }
        if (close(fd) != 0)
            ok = 0;
        if (unlikely(!ok))
            fprintf(stderr, "Error writing text to file '%s'\n", filename);
    }
    SECURE_FREE(pem, pemlen);
    return ok;
}
// ========== DER PARSING HELPERS & LOADERS ==========
/*
 * Read a one-byte DER tag and its length field starting at buf[*offset].
 * On success stores the tag in *tag and the content length in *len, leaves *offset at the
 * first content byte, and returns 1. Returns 0 in these cases:
 *   - the header is truncated;
 *   - the length uses the indefinite form 0x80 (forbidden in DER);
 *   - the length needs more bytes than a size_t holds;
 *   - the length announces more content than remains in the buffer.
 * Bounds are checked as `x > buflen - *offset`, so a huge announced length cannot wrap
 * `*offset + *len` around and pass the check.
 * Only low-tag-number (single byte) tags are supported.
 */
__attr_hot static inline int der_read_tag_len(const uint8_t *buf, size_t buflen, size_t *offset, uint8_t *tag, size_t *len)
{
    if (*offset > buflen || buflen - *offset < 2)
        return 0;
    *tag = buf[*offset];
    (*offset)++;
    const uint8_t first = buf[*offset];
    (*offset)++;
    size_t value = first;
    if (first & 0x80)
    {
        const size_t n = first & 0x7F;
        if (n == 0 || n > sizeof(size_t) || n > buflen - *offset)
            return 0;
        value = 0;
        for (size_t i = 0; i < n; ++i)
        {
            value = (value << 8) | buf[*offset];
            (*offset)++;
        }
    }
    if (value > buflen - *offset)
        return 0;
    *len = value;
    return 1;
}
/*
 * Read a DER INTEGER (tag 0x02) at buf[*offset] into `out` as an unsigned big-endian
 * magnitude, and advance *offset past it. Returns 1 on success, 0 on a bad header or a
 * tag other than 0x02.
 */
__attr_hot static inline int der_read_integer(const uint8_t *buf, size_t buflen, size_t *offset, mpz_t out)
{
    uint8_t tag = 0;
    size_t len = 0;
    const int header_ok = der_read_tag_len(buf, buflen, offset, &tag, &len);
    if (!header_ok || tag != 0x02)
        return 0;
    const uint8_t *content = buf + *offset;
    mpz_import(out, len, 1, 1, 1, 0, content);
    *offset += len;
    return 1;
}
/*
 * Read a DER header at buf[*offset] and require the SEQUENCE tag (0x30).
 * On success *seqlen holds the content length, *offset points at the first element, and
 * the function returns 1; otherwise it returns 0.
 */
__attr_hot static inline int der_expect_sequence(const uint8_t *buf, size_t buflen, size_t *offset, size_t *seqlen)
{
    uint8_t tag = 0;
    return der_read_tag_len(buf, buflen, offset, &tag, seqlen) && tag == 0x30;
}
/*
 * Shared first half of the PEM loaders: read `filename`, extract the Base64 body between
 * the BEGIN/END lines, and decode it. Returns a newly allocated DER buffer of *derlen bytes
 * (caller frees), or NULL on failure.
 */
static inline uint8_t *pem_file_to_der(const char *filename, size_t * const derlen)
{
    size_t filelen = 0;
    uint8_t * const raw = read_file(filename, &filelen);
    if (!raw)
        return NULL;
    /* strip_pem uses strstr and therefore needs a C string. Terminate a copy explicitly
       instead of relying on whatever follows the file data in read_file's buffer. */
    char * const text = (char *)SAFE_MALLOC(filelen + 1);
    memcpy(text, raw, filelen);
    text[filelen] = '\0';
    SECURE_FREE(raw, filelen);
    char * const b64 = strip_pem(text);
    SECURE_FREE(text, filelen + 1);
    if (!b64)
        return NULL;
    const size_t b64len = strlen(b64);
    uint8_t * const der = base64_decode(b64, derlen);
    SECURE_FREE(b64, b64len);
    return der;
}

/*
 * Load a PKCS#1 "RSA PUBLIC KEY" PEM file, SEQUENCE { n, e }, into `pub`.
 * `pub` is initialized here with rsa_public_key_init, so it must not already be
 * initialized. On failure it is cleared again, a diagnostic is printed, and 0 is returned.
 * Both INTEGERs must lie inside the outer SEQUENCE and be non-zero.
 * NOTE(review): a file whose DER lengths were truncated to a single byte (fields of 128
 * bytes or more written in short form) is corrupt and cannot be parsed; regenerate it.
 */
__attr_hot static inline int rsa_public_key_load_pem(const char *filename, RSAPublicKey *pub)
{
    size_t derlen = 0;
    uint8_t * const der = pem_file_to_der(filename, &derlen);
    if (!der)
        return 0;
    rsa_public_key_init(pub);
    size_t off = 0, seqlen = 0;
    int ok = der_expect_sequence(der, derlen, &off, &seqlen);
    if (ok)
    {
        /* Parse the fields against the SEQUENCE's own end, not the whole buffer. */
        const size_t seq_end = off + seqlen;
        ok = der_read_integer(der, seq_end, &off, pub->n) &&
             der_read_integer(der, seq_end, &off, pub->e) &&
             mpz_sgn(pub->n) > 0 && mpz_sgn(pub->e) > 0;
    }
    SECURE_FREE(der, derlen);
    if (!ok)
    {
        fprintf(stderr, "Failed to parse RSA public key from '%s'\n", filename);
        rsa_public_key_clear(pub);
    }
    return ok;
}

/*
 * Load a two-prime PKCS#1 "RSA PRIVATE KEY" PEM file into `priv`. The fields are, in order:
 * version, n, e, d, p, q, dP, dQ, qInv. The version value is read but not checked.
 * `priv` is initialized here with rsa_private_key_init, so it must not already be
 * initialized. On failure it is cleared again, a diagnostic is printed, and 0 is returned.
 * All INTEGERs must lie inside the outer SEQUENCE; n and d must be non-zero.
 */
__attr_hot static inline int rsa_private_key_load_pem(const char *filename, RSAPrivateKey *priv)
{
    size_t derlen = 0;
    uint8_t * const der = pem_file_to_der(filename, &derlen);
    if (!der)
        return 0;
    rsa_private_key_init(priv);
    mpz_t version;
    mpz_init(version);
    size_t off = 0, seqlen = 0;
    int ok = der_expect_sequence(der, derlen, &off, &seqlen);
    if (ok)
    {
        const size_t seq_end = off + seqlen;
        mpz_ptr const fields[] = { version, priv->n, priv->e, priv->d, priv->p,
                                   priv->q, priv->dP, priv->dQ, priv->qInv };
        for (size_t i = 0; ok && i < sizeof fields / sizeof fields[0]; ++i)
            ok = der_read_integer(der, seq_end, &off, fields[i]);
        ok = ok && mpz_sgn(priv->n) > 0 && mpz_sgn(priv->d) > 0;
    }
    mpz_clear(version);
    SECURE_FREE(der, derlen);
    if (!ok)
    {
        fprintf(stderr, "Failed to parse RSA private key from '%s'\n", filename);
        rsa_private_key_clear(priv);
    }
    return ok;
}
#endif // RSA_ALGORITHM_C_H