diff --git a/libavformat/crypto.c b/libavformat/crypto.c index 346c7061b4..9a48f2e6f5 100644 --- a/libavformat/crypto.c +++ b/libavformat/crypto.c @@ -54,10 +54,10 @@ typedef struct CryptoContext { int encrypt_ivlen; struct AVAES *aes_decrypt; struct AVAES *aes_encrypt; - + uint8_t *write_buf; + unsigned int write_buf_size; uint8_t pad[BLOCKSIZE]; int pad_len; - } CryptoContext; #define OFFSET(x) offsetof(CryptoContext, x) @@ -80,16 +80,16 @@ static const AVClass crypto_class = { .version = LIBAVUTIL_VERSION_INT, }; -static int set_aes_arg(CryptoContext *c, uint8_t **buf, int *buf_len, +static int set_aes_arg(URLContext *h, uint8_t **buf, int *buf_len, uint8_t *default_buf, int default_buf_len, const char *desc) { if (!*buf_len) { if (!default_buf_len) { - av_log(c, AV_LOG_ERROR, "%s not set\n", desc); + av_log(h, AV_LOG_ERROR, "%s not set\n", desc); return AVERROR(EINVAL); } else if (default_buf_len != BLOCKSIZE) { - av_log(c, AV_LOG_ERROR, + av_log(h, AV_LOG_ERROR, "invalid %s size (%d bytes, block size is %d)\n", desc, default_buf_len, BLOCKSIZE); return AVERROR(EINVAL); @@ -99,7 +99,7 @@ static int set_aes_arg(CryptoContext *c, uint8_t **buf, int *buf_len, return AVERROR(ENOMEM); *buf_len = default_buf_len; } else if (*buf_len != BLOCKSIZE) { - av_log(c, AV_LOG_ERROR, + av_log(h, AV_LOG_ERROR, "invalid %s size (%d bytes, block size is %d)\n", desc, *buf_len, BLOCKSIZE); return AVERROR(EINVAL); @@ -121,23 +121,21 @@ static int crypto_open2(URLContext *h, const char *uri, int flags, AVDictionary goto err; } - c->position = 0; - if (flags & AVIO_FLAG_READ) { - if ((ret = set_aes_arg(c, &c->decrypt_key, &c->decrypt_keylen, + if ((ret = set_aes_arg(h, &c->decrypt_key, &c->decrypt_keylen, c->key, c->keylen, "decryption key")) < 0) goto err; - if ((ret = set_aes_arg(c, &c->decrypt_iv, &c->decrypt_ivlen, + if ((ret = set_aes_arg(h, &c->decrypt_iv, &c->decrypt_ivlen, c->iv, c->ivlen, "decryption IV")) < 0) goto err; } if (flags & AVIO_FLAG_WRITE) { - if ((ret = set_aes_arg(c, &c->encrypt_key, &c->encrypt_keylen, + if ((ret = set_aes_arg(h, &c->encrypt_key, &c->encrypt_keylen, c->key, c->keylen, "encryption key")) < 0) if (ret < 0) goto err; - if ((ret = set_aes_arg(c, &c->encrypt_iv, &c->encrypt_ivlen, + if ((ret = set_aes_arg(h, &c->encrypt_iv, &c->encrypt_ivlen, c->iv, c->ivlen, "encryption IV")) < 0) goto err; } @@ -155,7 +153,7 @@ static int crypto_open2(URLContext *h, const char *uri, int flags, AVDictionary ret = AVERROR(ENOMEM); goto err; } - ret = av_aes_init(c->aes_decrypt, c->decrypt_key, BLOCKSIZE*8, 1); + ret = av_aes_init(c->aes_decrypt, c->decrypt_key, BLOCKSIZE * 8, 1); if (ret < 0) goto err; @@ -170,7 +168,7 @@ static int crypto_open2(URLContext *h, const char *uri, int flags, AVDictionary ret = AVERROR(ENOMEM); goto err; } - ret = av_aes_init(c->aes_encrypt, c->encrypt_key, BLOCKSIZE*8, 0); + ret = av_aes_init(c->aes_encrypt, c->encrypt_key, BLOCKSIZE * 8, 0); if (ret < 0) goto err; // for write, we must be streamed @@ -178,8 +176,6 @@ static int crypto_open2(URLContext *h, const char *uri, int flags, AVDictionary h->is_streamed = 1; } - c->pad_len = 0; - err: return ret; } @@ -338,7 +334,6 @@ static int crypto_write(URLContext *h, const unsigned char *buf, int size) { CryptoContext *c = h->priv_data; int total_size, blocks, pad_len, out_size; - uint8_t *out_buf; int ret = 0; total_size = size + c->pad_len; @@ -347,22 +342,23 @@ static int crypto_write(URLContext *h, const unsigned char *buf, int size) blocks = out_size / BLOCKSIZE; if (out_size) { - out_buf = av_malloc(out_size); - if (!out_buf) + av_fast_malloc(&c->write_buf, &c->write_buf_size, out_size); + + if (!c->write_buf) return AVERROR(ENOMEM); if (c->pad_len) { memcpy(&c->pad[c->pad_len], buf, BLOCKSIZE - c->pad_len); - av_aes_crypt(c->aes_encrypt, out_buf, c->pad, 1, c->encrypt_iv, 0); + av_aes_crypt(c->aes_encrypt, c->write_buf, c->pad, 1, c->encrypt_iv, 0); blocks--; } - av_aes_crypt(c->aes_encrypt, &out_buf[c->pad_len ? BLOCKSIZE : 0], - &buf[c->pad_len ? BLOCKSIZE - c->pad_len: 0], - blocks, c->encrypt_iv, 0); + av_aes_crypt(c->aes_encrypt, + &c->write_buf[c->pad_len ? BLOCKSIZE : 0], + &buf[c->pad_len ? BLOCKSIZE - c->pad_len : 0], + blocks, c->encrypt_iv, 0); - ret = ffurl_write(c->hd, out_buf, out_size); - av_free(out_buf); + ret = ffurl_write(c->hd, c->write_buf, out_size); if (ret < 0) return ret; @@ -378,22 +374,23 @@ static int crypto_write(URLContext *h, const unsigned char *buf, int size) static int crypto_close(URLContext *h) { CryptoContext *c = h->priv_data; - uint8_t out_buf[BLOCKSIZE]; - int ret, pad; + int ret = 0; if (c->aes_encrypt) { - pad = BLOCKSIZE - c->pad_len; + uint8_t out_buf[BLOCKSIZE]; + int pad = BLOCKSIZE - c->pad_len; + memset(&c->pad[c->pad_len], pad, pad); av_aes_crypt(c->aes_encrypt, out_buf, c->pad, 1, c->encrypt_iv, 0); - if ((ret = ffurl_write(c->hd, out_buf, BLOCKSIZE)) < 0) - return ret; + ret = ffurl_write(c->hd, out_buf, BLOCKSIZE); } if (c->hd) ffurl_close(c->hd); av_freep(&c->aes_decrypt); av_freep(&c->aes_encrypt); - return 0; + av_freep(&c->write_buf); + return ret; } const URLProtocol ff_crypto_protocol = {