Make DownloadSession use String identifier instead of AnyObject.

This commit is contained in:
Brent Simmons 2024-05-25 22:47:55 -07:00
parent 391408f00c
commit fb5a1b28d0

View File

@ -7,37 +7,35 @@
// //
import Foundation import Foundation
import os
// Create a DownloadSessionDelegate, then create a DownloadSession. // To download things: call `download` with a set of identifiers (String). Redirects are followed automatically.
// To download things: call downloadObjects, with a set of represented objects, to download things. DownloadSession will call the various delegate methods.
public protocol DownloadSessionDelegate { public protocol DownloadSessionDelegate: AnyObject {
@MainActor func downloadSession(_ downloadSession: DownloadSession, requestForRepresentedObject: AnyObject) -> URLRequest?
@MainActor func downloadSession(_ downloadSession: DownloadSession, downloadDidCompleteForRepresentedObject: AnyObject, response: URLResponse?, data: Data, error: NSError?, completion: @escaping () -> Void)
@MainActor func downloadSession(_ downloadSession: DownloadSession, shouldContinueAfterReceivingData: Data, representedObject: AnyObject) -> Bool
@MainActor func downloadSession(_ downloadSession: DownloadSession, didReceiveUnexpectedResponse: URLResponse, representedObject: AnyObject)
@MainActor func downloadSession(_ downloadSession: DownloadSession, didReceiveNotModifiedResponse: URLResponse, representedObject: AnyObject)
@MainActor func downloadSession(_ downloadSession: DownloadSession, didDiscardDuplicateRepresentedObject: AnyObject)
@MainActor func downloadSessionDidCompleteDownloadObjects(_ downloadSession: DownloadSession)
@MainActor func downloadSession(_ downloadSession: DownloadSession, requestForIdentifier: String) -> URLRequest?
@MainActor func downloadSession(_ downloadSession: DownloadSession, downloadDidCompleteForIdentifier: String, response: URLResponse?, data: Data?, error: Error?)
@MainActor func downloadSession(_ downloadSession: DownloadSession, shouldContinueAfterReceivingData: Data, identifier: String) -> Bool
@MainActor func downloadSession(_ downloadSession: DownloadSession, didReceiveUnexpectedResponse: URLResponse, identifier: String)
@MainActor func downloadSession(_ downloadSession: DownloadSession, didReceiveNotModifiedResponse: URLResponse, identifier: String)
@MainActor func downloadSession(_ downloadSession: DownloadSession, didDiscardDuplicateIdentifier: String)
@MainActor func downloadSessionDidComplete(_ downloadSession: DownloadSession)
} }
@MainActor @objc public final class DownloadSession: NSObject { @MainActor @objc public final class DownloadSession: NSObject {
public weak var delegate: DownloadSessionDelegate?
public var downloadProgress = DownloadProgress(numberOfTasks: 0)
private var urlSession: URLSession! private var urlSession: URLSession!
private var tasksInProgress = Set<URLSessionTask>() private var tasksInProgress = Set<URLSessionTask>()
private var tasksPending = Set<URLSessionTask>() private var tasksPending = Set<URLSessionTask>()
private var taskIdentifierToInfoDictionary = [Int: DownloadInfo]() private var taskIdentifierToInfoDictionary = [Int: DownloadInfo]()
private let representedObjects = NSMutableSet() private var allIdentifiers = Set<String>()
private let delegate: DownloadSessionDelegate
private var redirectCache = [String: String]() private var redirectCache = [String: String]()
private var queue = [AnyObject]() private var queue = [String]()
public init(delegate: DownloadSessionDelegate) {
self.delegate = delegate
override public init() {
super.init() super.init()
let sessionConfiguration = URLSessionConfiguration.default let sessionConfiguration = URLSessionConfiguration.default
@ -45,22 +43,27 @@ public protocol DownloadSessionDelegate {
sessionConfiguration.timeoutIntervalForRequest = 15.0 sessionConfiguration.timeoutIntervalForRequest = 15.0
sessionConfiguration.httpShouldSetCookies = false sessionConfiguration.httpShouldSetCookies = false
sessionConfiguration.httpCookieAcceptPolicy = .never sessionConfiguration.httpCookieAcceptPolicy = .never
sessionConfiguration.httpMaximumConnectionsPerHost = 2 sessionConfiguration.httpMaximumConnectionsPerHost = 1
sessionConfiguration.httpCookieStorage = nil sessionConfiguration.httpCookieStorage = nil
sessionConfiguration.urlCache = nil sessionConfiguration.urlCache = nil
sessionConfiguration.httpAdditionalHeaders = UserAgent.headers sessionConfiguration.httpAdditionalHeaders = UserAgent.headers
urlSession = URLSession(configuration: sessionConfiguration, delegate: self, delegateQueue: OperationQueue.main) self.urlSession = URLSession(configuration: sessionConfiguration, delegate: self, delegateQueue: OperationQueue.main)
} }
deinit { deinit {
urlSession.invalidateAndCancel() urlSession.invalidateAndCancel()
} }
// MARK: - API // MARK: - API
public func cancelAll() { public func cancelAll() async {
urlSession.getTasksWithCompletionHandler { dataTasks, uploadTasks, downloadTasks in
clearDownloadProgress()
let (dataTasks, uploadTasks, downloadTasks) = await urlSession.tasks
for dataTask in dataTasks { for dataTask in dataTasks {
dataTask.cancel() dataTask.cancel()
} }
@ -71,15 +74,15 @@ public protocol DownloadSessionDelegate {
downloadTask.cancel() downloadTask.cancel()
} }
} }
}
@MainActor public func downloadObjects(_ objects: NSSet) { public func download(_ identifiers: Set<String>) {
for oneObject in objects {
if !representedObjects.contains(oneObject) { for identifier in identifiers {
representedObjects.add(oneObject) if !allIdentifiers.contains(identifier) {
addDataTask(oneObject as AnyObject) allIdentifiers.insert(identifier)
addDataTask(identifier)
} else { } else {
delegate.downloadSession(self, didDiscardDuplicateRepresentedObject: oneObject as AnyObject) delegate?.downloadSession(self, didDiscardDuplicateIdentifier: identifier)
} }
} }
} }
@ -92,24 +95,20 @@ extension DownloadSession: URLSessionTaskDelegate {
nonisolated public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { nonisolated public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
MainActor.assumeIsolated { MainActor.assumeIsolated {
tasksInProgress.remove(task)
guard let info = infoForTask(task) else { guard let info = infoForTask(task) else {
assertionFailure("Missing info for task in DownloadSession didCompleteWithError")
return return
} }
info.error = error delegate?.downloadSession(self, downloadDidCompleteForIdentifier: info.identifier, response: info.urlResponse, data: info.data, error: error)
removeTask(task)
delegate.downloadSession(self, downloadDidCompleteForRepresentedObject: info.representedObject, response: info.urlResponse, data: info.data as Data, error: error as NSError?) {
self.removeTask(task)
}
} }
} }
nonisolated public func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { nonisolated public func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) {
MainActor.assumeIsolated { MainActor.assumeIsolated {
if response.statusCode == 301 || response.statusCode == 308 { if response.statusCode == HTTPResponseCode.redirectTemporary || response.statusCode == HTTPResponseCode.redirectVeryTemporary {
if let oldURLString = task.originalRequest?.url?.absoluteString, let newURLString = request.url?.absoluteString { if let oldURLString = task.originalRequest?.url?.absoluteString, let newURLString = request.url?.absoluteString {
cacheRedirect(oldURLString, newURLString) cacheRedirect(oldURLString, newURLString)
} }
@ -127,17 +126,18 @@ extension DownloadSession: URLSessionDataDelegate {
nonisolated public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) { nonisolated public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse, completionHandler: @escaping (URLSession.ResponseDisposition) -> Void) {
MainActor.assumeIsolated { MainActor.assumeIsolated {
tasksInProgress.insert(dataTask) tasksInProgress.insert(dataTask)
tasksPending.remove(dataTask) tasksPending.remove(dataTask)
if let info = infoForTask(dataTask) { let info = infoForTask(dataTask)
info.urlResponse = response let identifier = info?.identifier
} info?.urlResponse = response
if response.forcedStatusCode == 304 { if response.forcedStatusCode == HTTPResponseCode.notModified {
if let representedObject = infoForTask(dataTask)?.representedObject { if let identifier {
delegate.downloadSession(self, didReceiveNotModifiedResponse: response, representedObject: representedObject) delegate?.downloadSession(self, didReceiveNotModifiedResponse: response, identifier: identifier)
} }
completionHandler(.cancel) completionHandler(.cancel)
@ -148,8 +148,8 @@ extension DownloadSession: URLSessionDataDelegate {
if !response.statusIsOK { if !response.statusIsOK {
if let representedObject = infoForTask(dataTask)?.representedObject { if let identifier {
delegate.downloadSession(self, didReceiveUnexpectedResponse: response, representedObject: representedObject) delegate?.downloadSession(self, didReceiveUnexpectedResponse: response, identifier: identifier)
} }
completionHandler(.cancel) completionHandler(.cancel)
@ -167,33 +167,33 @@ extension DownloadSession: URLSessionDataDelegate {
nonisolated public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { nonisolated public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
MainActor.assumeIsolated { MainActor.assumeIsolated {
guard let info = infoForTask(dataTask) else {
guard let delegate, let info = infoForTask(dataTask) else {
return return
} }
info.addData(data) info.addData(data)
if !delegate.downloadSession(self, shouldContinueAfterReceivingData: info.data as Data, representedObject: info.representedObject) { if !delegate.downloadSession(self, shouldContinueAfterReceivingData: info.data!, identifier: info.identifier) {
info.canceled = true info.canceled = true
dataTask.cancel() dataTask.cancel()
removeTask(dataTask) removeTask(dataTask)
} }
} }
} }
} }
// MARK: - Private // MARK: - Private
private extension DownloadSession { private extension DownloadSession {
@MainActor func addDataTask(_ representedObject: AnyObject) { func addDataTask(_ identifier: String) {
guard tasksPending.count < 500 else { guard tasksPending.count < 500 else {
queue.insert(representedObject, at: 0) queue.insert(identifier, at: 0)
return return
} }
guard let request = delegate.downloadSession(self, requestForRepresentedObject: representedObject) else { guard let request = delegate?.downloadSession(self, requestForIdentifier: identifier) else {
return return
} }
@ -209,36 +209,58 @@ private extension DownloadSession {
let task = urlSession.dataTask(with: requestToUse) let task = urlSession.dataTask(with: requestToUse)
let info = DownloadInfo(representedObject, urlRequest: requestToUse) let info = DownloadInfo(identifier, urlRequest: requestToUse)
taskIdentifierToInfoDictionary[task.taskIdentifier] = info taskIdentifierToInfoDictionary[task.taskIdentifier] = info
tasksPending.insert(task) tasksPending.insert(task)
task.resume() task.resume()
updateDownloadProgress()
} }
func addDataTaskFromQueueIfNecessary() { func addDataTaskFromQueueIfNecessary() {
guard tasksPending.count < 500, let representedObject = queue.popLast() else { return }
addDataTask(representedObject) guard tasksPending.count < 500, let identifier = queue.popLast() else {
return
}
addDataTask(identifier)
} }
func infoForTask(_ task: URLSessionTask) -> DownloadInfo? { func infoForTask(_ task: URLSessionTask) -> DownloadInfo? {
return taskIdentifierToInfoDictionary[task.taskIdentifier] return taskIdentifierToInfoDictionary[task.taskIdentifier]
} }
@MainActor func removeTask(_ task: URLSessionTask) { func removeTask(_ task: URLSessionTask) {
tasksInProgress.remove(task) tasksInProgress.remove(task)
tasksPending.remove(task) tasksPending.remove(task)
taskIdentifierToInfoDictionary[task.taskIdentifier] = nil taskIdentifierToInfoDictionary[task.taskIdentifier] = nil
addDataTaskFromQueueIfNecessary() addDataTaskFromQueueIfNecessary()
downloadProgress.completeTask()
updateDownloadProgress()
if tasksInProgress.count + tasksPending.count < 1 { if tasksInProgress.count + tasksPending.count < 1 {
representedObjects.removeAllObjects() assert(allIdentifiers.isEmpty)
delegate.downloadSessionDidCompleteDownloadObjects(self) assert(queue.isEmpty)
delegate?.downloadSessionDidComplete(self)
clearDownloadProgress()
} }
} }
func urlStringIsBlackListedRedirect(_ urlString: String) -> Bool { func updateDownloadProgress() {
downloadProgress.numberRemaining = tasksInProgress.count + tasksPending.count + queue.count
}
func clearDownloadProgress() {
downloadProgress = DownloadProgress(numberOfTasks: 0)
}
func urlStringIsDisallowedRedirect(_ urlString: String) -> Bool {
// Hotels and similar often do permanent redirects. We can catch some of those. // Hotels and similar often do permanent redirects. We can catch some of those.
@ -255,7 +277,8 @@ private extension DownloadSession {
} }
func cacheRedirect(_ oldURLString: String, _ newURLString: String) { func cacheRedirect(_ oldURLString: String, _ newURLString: String) {
if urlStringIsBlackListedRedirect(newURLString) {
if urlStringIsDisallowedRedirect(newURLString) {
return return
} }
redirectCache[oldURLString] = newURLString redirectCache[oldURLString] = newURLString
@ -295,9 +318,9 @@ private extension DownloadSession {
private final class DownloadInfo { private final class DownloadInfo {
let representedObject: AnyObject let identifier: String
let urlRequest: URLRequest let urlRequest: URLRequest
let data = NSMutableData() var data: Data?
var error: Error? var error: Error?
var urlResponse: URLResponse? var urlResponse: URLResponse?
var canceled = false var canceled = false
@ -306,15 +329,17 @@ private final class DownloadInfo {
return urlResponse?.forcedStatusCode ?? 0 return urlResponse?.forcedStatusCode ?? 0
} }
init(_ representedObject: AnyObject, urlRequest: URLRequest) { init(_ identifier: String, urlRequest: URLRequest) {
self.representedObject = representedObject self.identifier = identifier
self.urlRequest = urlRequest self.urlRequest = urlRequest
} }
func addData(_ d: Data) { func addData(_ d: Data) {
data.append(d) if data == nil {
data = Data()
}
data!.append(d)
} }
} }