SubwayTooter-Android-App/app/src/main/java/jp/juggler/subwaytooter/util/ProgressResponseBody.kt

268 lines
8.6 KiB
Kotlin

package jp.juggler.subwaytooter.util
import java.io.IOException
import java.io.InputStream
import java.nio.charset.Charset
import okhttp3.Interceptor
import okhttp3.MediaType
import okhttp3.Response
import okhttp3.ResponseBody
import okio.Buffer
import okio.BufferedSource
import okio.ByteString
import okio.Options
import okio.Sink
import okio.Source
import okio.Timeout
import java.nio.ByteBuffer
/**
 * A [ResponseBody] wrapper that reports download progress while the body is read with
 * [ResponseBody.bytes].
 *
 * Install it with [makeInterceptor], then read the body with [bytes] to receive progress
 * callbacks. Progress is only reported when the original source exposes a public `source`
 * field (as okio's RealBufferedSource does); otherwise the original source is used unchanged
 * and no intermediate progress is reported.
 */
class ProgressResponseBody private constructor(private val originalBody : ResponseBody) :
    ResponseBody() {

    companion object {

        internal val log = LogCategory("ProgressResponseBody")

        /**
         * Creates an interceptor that wraps every response body in a [ProgressResponseBody].
         *
         * Register it with `OkHttpClient.Builder#addInterceptor()`, for example
         * `builder.addInterceptor(ProgressResponseBody.makeInterceptor())`.
         *
         * The returned interceptor throws [RuntimeException] if the chain yields a null
         * response or a response whose body is null.
         */
        fun makeInterceptor() : Interceptor {
            return Interceptor { chain ->
                val originalResponse = chain.proceed(chain.request())
                    ?: throw RuntimeException("makeInterceptor: chain.proceed() returns null.")

                val originalBody = originalResponse.body()
                    ?: throw RuntimeException("makeInterceptor: originalResponse.body() returns null.")

                originalResponse.newBuilder()
                    .body(ProgressResponseBody(originalBody))
                    .build()
            }
        }

        /**
         * Reads the whole body of [response] and returns its bytes.
         *
         * If the body was wrapped by the interceptor from [makeInterceptor], [callback] is
         * invoked with `(bytesRead, bytesTotal)` as data arrives. Example:
         * ```
         * val data = ProgressResponseBody.bytes(response) { bytesRead, bytesTotal ->
         *     publishApiProgressRatio(bytesRead.toInt(), bytesTotal.toInt())
         * }
         * ```
         *
         * @throws RuntimeException if `response.body()` is null.
         * @throws IOException if reading the body fails.
         */
        @Throws(IOException::class)
        fun bytes(response : Response, callback : ProgressResponseBodyCallback) : ByteArray {
            val body = response.body() ?: throw RuntimeException("response.body() is null.")
            return bytes(body, callback)
        }

        /**
         * Installs [callback] on [body] when it is a [ProgressResponseBody], then reads all bytes.
         * For any other body type the bytes are read without progress reporting.
         */
        @Throws(IOException::class)
        private fun bytes(
            body : ResponseBody,
            callback : ProgressResponseBodyCallback
        ) : ByteArray {
            if(body is ProgressResponseBody) {
                body.callback = callback
            }
            return body.bytes()
        }
    }

    // Progress listener invoked as (bytesRead, bytesTotal). It is set by bytes() and defaults to a no-op.
    private var callback : ProgressResponseBodyCallback = { _, _ -> }

    /*
     * ResponseBody.bytes() is final, so it cannot be overridden. Instead, source() hands out a
     * wrapper that intercepts BufferedSource.readByteArray() and reports progress there.
     * The wrapper is created lazily and cached here.
     */
    private var wrappedSource : BufferedSource? = null

    override fun contentType() : MediaType? {
        return originalBody.contentType()
    }

    override fun contentLength() : Long {
        return originalBody.contentLength()
    }

    /**
     * Returns the (cached) source for this body.
     *
     * On the first call this looks up the public `source` field of the original source via
     * reflection. On success it returns a wrapper whose readByteArray() reports progress; on
     * failure it logs an error and returns the original source unchanged.
     * Returns null only when the original body's source() returns null.
     */
    override fun source() : BufferedSource? {
        if(wrappedSource == null) {
            val originalSource = originalBody.source() ?: return null

            wrappedSource = try {
                // If it is RealBufferedSource, its upstream Source is a public field named "source".
                val sourceField = originalSource.javaClass.getField("source")

                // The field exists, so build the progress-reporting wrapper.
                object : ForwardingBufferedSource(originalSource) {

                    @Throws(IOException::class)
                    override fun readByteArray() : ByteArray {
                        /*
                         * RealBufferedSource.readByteArray() does:
                         *   - buffer.writeAll(source);
                         *   - return buffer.readByteArray(buffer.size());
                         * We do the same through reflection, reporting progress as bytes arrive.
                         */
                        try {
                            val contentLength = originalBody.contentLength()
                            val buffer = originalSource.buffer()
                            val upstream = sourceField.get(originalSource) as Source?
                                ?: throw IllegalArgumentException("source == null")

                            // Same as Buffer.writeAll(), but counting bytes.
                            var nRead = 0L
                            // contentLength may be -1 (unknown) or 0; report a total of at least 1.
                            callback(0, contentLength.coerceAtLeast(1L))
                            while(true) {
                                val delta = upstream.read(buffer, 8192)
                                if(delta == -1L) break
                                nRead += delta
                                if(nRead > 0) {
                                    // Never report a total smaller than what has been read so far.
                                    callback(nRead, maxOf(contentLength, nRead))
                                }
                            }

                            // Progress at end of stream.
                            callback(nRead, maxOf(contentLength, nRead))
                            return buffer.readByteArray()
                        } catch(ex : Throwable) {
                            log.trace(ex)
                            log.e("readByteArray() failed. ")
                            // Fall back to the plain read; bytes already in the shared buffer are kept.
                            return originalSource.readByteArray()
                        }
                    }
                }
            } catch(ex : Throwable) {
                log.e("can't access to RealBufferedSource#source field.")
                originalSource
            }
        }
        return wrappedSource
    }

    /**
     * A [BufferedSource] that forwards every call to [originalSource].
     *
     * Subclasses override individual reads. Forwarding directly, rather than wrapping the source
     * in another buffered source, avoids double buffering.
     */
    internal open class ForwardingBufferedSource(private val originalSource : BufferedSource) :
        BufferedSource {

        override fun read(dst : ByteBuffer?) = originalSource.read(dst)
        override fun isOpen() = originalSource.isOpen
        override fun buffer() : Buffer? = originalSource.buffer()
        override fun exhausted() = originalSource.exhausted()
        override fun require(byteCount : Long) = originalSource.require(byteCount)
        override fun request(byteCount : Long) = originalSource.request(byteCount)
        override fun readByte() = originalSource.readByte()
        override fun readShort() = originalSource.readShort()
        override fun readShortLe() = originalSource.readShortLe()
        override fun readInt() = originalSource.readInt()
        override fun readIntLe() = originalSource.readIntLe()
        override fun readLong() = originalSource.readLong()
        override fun readLongLe() = originalSource.readLongLe()
        override fun readDecimalLong() = originalSource.readDecimalLong()
        override fun readHexadecimalUnsignedLong() = originalSource.readHexadecimalUnsignedLong()
        override fun skip(byteCount : Long) = originalSource.skip(byteCount)
        override fun readByteString() : ByteString? = originalSource.readByteString()
        override fun readByteString(byteCount : Long) : ByteString? =
            originalSource.readByteString(byteCount)
        override fun select(options : Options) = originalSource.select(options)
        override fun readByteArray() : ByteArray? = originalSource.readByteArray()
        override fun readByteArray(byteCount : Long) : ByteArray? =
            originalSource.readByteArray(byteCount)
        override fun read(sink : ByteArray) = originalSource.read(sink)
        override fun readFully(sink : ByteArray) = originalSource.readFully(sink)
        override fun read(sink : ByteArray, offset : Int, byteCount : Int) =
            originalSource.read(sink, offset, byteCount)
        override fun readFully(sink : Buffer, byteCount : Long) =
            originalSource.readFully(sink, byteCount)
        override fun readAll(sink : Sink) = originalSource.readAll(sink)
        override fun readUtf8() : String? = originalSource.readUtf8()
        override fun readUtf8(byteCount : Long) : String? = originalSource.readUtf8(byteCount)
        override fun readUtf8Line() : String? = originalSource.readUtf8Line()
        override fun readUtf8LineStrict() : String? = originalSource.readUtf8LineStrict()
        override fun readUtf8LineStrict(limit : Long) : String? =
            originalSource.readUtf8LineStrict(limit)
        override fun readUtf8CodePoint() = originalSource.readUtf8CodePoint()
        override fun readString(charset : Charset) : String? = originalSource.readString(charset)
        override fun readString(byteCount : Long, charset : Charset) : String? =
            originalSource.readString(byteCount, charset)
        override fun indexOf(b : Byte) = originalSource.indexOf(b)
        override fun indexOf(b : Byte, fromIndex : Long) = originalSource.indexOf(b, fromIndex)
        override fun indexOf(b : Byte, fromIndex : Long, toIndex : Long) =
            originalSource.indexOf(b, fromIndex, toIndex)
        override fun indexOf(bytes : ByteString) = originalSource.indexOf(bytes)
        override fun indexOf(bytes : ByteString, fromIndex : Long) =
            originalSource.indexOf(bytes, fromIndex)
        override fun indexOfElement(targetBytes : ByteString) =
            originalSource.indexOfElement(targetBytes)
        override fun indexOfElement(targetBytes : ByteString, fromIndex : Long) =
            originalSource.indexOfElement(targetBytes, fromIndex)
        override fun rangeEquals(offset : Long, bytes : ByteString) =
            originalSource.rangeEquals(offset, bytes)
        override fun rangeEquals(
            offset : Long,
            bytes : ByteString,
            bytesOffset : Int,
            byteCount : Int
        ) = originalSource.rangeEquals(offset, bytes, bytesOffset, byteCount)
        override fun inputStream() : InputStream? = originalSource.inputStream()
        override fun read(sink : Buffer, byteCount : Long) = originalSource.read(sink, byteCount)
        override fun timeout() : Timeout? = originalSource.timeout()
        override fun close() = originalSource.close()
    }
}