This commit is contained in:
codl 2017-08-29 14:46:32 +02:00
parent 2c4d6b9f63
commit 007aec7529
No known key found for this signature in database
GPG Key ID: 6CD7C8891ED1233A
17 changed files with 472 additions and 230 deletions

19
app.py
View File

@ -6,7 +6,6 @@ from flask_migrate import Migrate
import version import version
from lib import cachebust from lib import cachebust
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from lib import get_viewer from lib import get_viewer
import os import os
import mimetypes import mimetypes
@ -29,7 +28,7 @@ app.config.update(default_config)
app.config.from_pyfile('config.py', True) app.config.from_pyfile('config.py', True)
metadata = MetaData(naming_convention = { metadata = MetaData(naming_convention={
"ix": 'ix_%(column_0_label)s', "ix": 'ix_%(column_0_label)s',
"uq": "uq_%(table_name)s_%(column_0_name)s", "uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s",
@ -55,12 +54,14 @@ if 'SENTRY_DSN' in app.config:
url_for = cachebust(app) url_for = cachebust(app)
@app.context_processor @app.context_processor
def inject_static(): def inject_static():
def static(filename, **kwargs): def static(filename, **kwargs):
return url_for('static', filename=filename, **kwargs) return url_for('static', filename=filename, **kwargs)
return {'st': static} return {'st': static}
def rate_limit_key(): def rate_limit_key():
viewer = get_viewer() viewer = get_viewer()
if viewer: if viewer:
@ -71,16 +72,25 @@ def rate_limit_key():
return address return address
return request.remote_addr return request.remote_addr
limiter = Limiter(app, key_func=rate_limit_key) limiter = Limiter(app, key_func=rate_limit_key)
@app.after_request @app.after_request
def install_security_headers(resp): def install_security_headers(resp):
csp = "default-src 'none'; img-src 'self' https:; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; frame-ancestors 'none'" csp = ("default-src 'none';"
"img-src 'self' https:;"
"script-src 'self';"
"style-src 'self' 'unsafe-inline';"
"connect-src 'self';"
"frame-ancestors 'none';"
)
if 'CSP_REPORT_URI' in app.config: if 'CSP_REPORT_URI' in app.config:
csp += "; report-uri " + app.config.get('CSP_REPORT_URI') csp += "; report-uri " + app.config.get('CSP_REPORT_URI')
if app.config.get('HTTPS'): if app.config.get('HTTPS'):
resp.headers.set('strict-transport-security', 'max-age={}'.format(60*60*24*365)) resp.headers.set('strict-transport-security',
'max-age={}'.format(60*60*24*365))
csp += "; upgrade-insecure-requests" csp += "; upgrade-insecure-requests"
resp.headers.set('Content-Security-Policy', csp) resp.headers.set('Content-Security-Policy', csp)
@ -91,4 +101,5 @@ def install_security_headers(resp):
return resp return resp
mimetypes.add_type('image/webp', '.webp') mimetypes.add_type('image/webp', '.webp')

33
dodo.py
View File

@ -1,10 +1,12 @@
from doit import create_after from doit import create_after
def reltouch(source_filename, dest_filename): def reltouch(source_filename, dest_filename):
from os import stat, utime from os import stat, utime
stat_res = stat(source_filename) stat_res = stat(source_filename)
utime(dest_filename, ns=(stat_res.st_atime_ns, stat_res.st_mtime_ns)) utime(dest_filename, ns=(stat_res.st_atime_ns, stat_res.st_mtime_ns))
def resize_image(basename, width, format): def resize_image(basename, width, format):
from PIL import Image from PIL import Image
with Image.open('assets/{}.png'.format(basename)) as im: with Image.open('assets/{}.png'.format(basename)) as im:
@ -13,23 +15,25 @@ def resize_image(basename, width, format):
else: else:
im = im.convert('RGB') im = im.convert('RGB')
height = im.height * width // im.width height = im.height * width // im.width
new = im.resize((width,height), resample=Image.LANCZOS) new = im.resize((width, height), resample=Image.LANCZOS)
if format == 'jpeg': if format == 'jpeg':
kwargs = dict( kwargs = dict(
optimize = True, optimize=True,
progressive = True, progressive=True,
quality = 80, quality=80,
) )
elif format == 'webp': elif format == 'webp':
kwargs = dict( kwargs = dict(
quality = 79, quality=79,
) )
elif format == 'png': elif format == 'png':
kwargs = dict( kwargs = dict(
optimize = True, optimize=True,
) )
new.save('static/{}-{}.{}'.format(basename, width, format), **kwargs) new.save('static/{}-{}.{}'.format(basename, width, format), **kwargs)
reltouch('assets/{}.png'.format(basename), 'static/{}-{}.{}'.format(basename, width, format)) reltouch('assets/{}.png'.format(basename),
'static/{}-{}.{}'.format(basename, width, format))
def task_logotype(): def task_logotype():
"""resize and convert logotype""" """resize and convert logotype"""
@ -45,9 +49,10 @@ def task_logotype():
clean=True, clean=True,
) )
def task_service_icon(): def task_service_icon():
"""resize and convert service icons""" """resize and convert service icons"""
widths = (20,40,80) widths = (20, 40, 80)
formats = ('webp', 'png') formats = ('webp', 'png')
for width in widths: for width in widths:
for format in formats: for format in formats:
@ -55,11 +60,13 @@ def task_service_icon():
yield dict( yield dict(
name='{}-{}.{}'.format(basename, width, format), name='{}-{}.{}'.format(basename, width, format),
actions=[(resize_image, (basename, width, format))], actions=[(resize_image, (basename, width, format))],
targets=['static/{}-{}.{}'.format(basename,width,format)], targets=[
'static/{}-{}.{}'.format(basename, width, format)],
file_dep=['assets/{}.png'.format(basename)], file_dep=['assets/{}.png'.format(basename)],
clean=True, clean=True,
) )
def task_copy(): def task_copy():
"copy assets verbatim" "copy assets verbatim"
@ -81,6 +88,7 @@ def task_copy():
clean=True, clean=True,
) )
def task_minify_css(): def task_minify_css():
"""minify css file with csscompressor""" """minify css file with csscompressor"""
@ -99,12 +107,16 @@ def task_minify_css():
clean=True, clean=True,
) )
@create_after('logotype') @create_after('logotype')
@create_after('service_icon') @create_after('service_icon')
@create_after('copy') @create_after('copy')
@create_after('minify_css') @create_after('minify_css')
def task_compress(): def task_compress():
"make gzip and brotli compressed versions of each static file for the server to lazily serve" """
make gzip and brotli compressed versions of each
static file for the server to lazily serve
"""
from glob import glob from glob import glob
from itertools import chain from itertools import chain
@ -146,6 +158,7 @@ def task_compress():
clean=True, clean=True,
) )
if __name__ == '__main__': if __name__ == '__main__':
import doit import doit
doit.run(globals()) doit.run(globals())

View File

@ -1,2 +1,2 @@
from app import app from app import app # noqa: F401
import routes import routes # noqa: F401

View File

@ -1,6 +1,7 @@
from flask import g, redirect, jsonify, make_response, abort, request from flask import g, redirect, jsonify, make_response, abort, request
from functools import wraps from functools import wraps
def require_auth(fun): def require_auth(fun):
@wraps(fun) @wraps(fun)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -9,11 +10,14 @@ def require_auth(fun):
return fun(*args, **kwargs) return fun(*args, **kwargs)
return wrapper return wrapper
def require_auth_api(fun): def require_auth_api(fun):
@wraps(fun) @wraps(fun)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not g.viewer: if not g.viewer:
return make_response((jsonify(status='error', error='not logged in'), 403)) return make_response((
jsonify(status='error', error='not logged in'),
403))
return fun(*args, **kwargs) return fun(*args, **kwargs)
return wrapper return wrapper

View File

@ -6,6 +6,7 @@ import redis
import os.path import os.path
import mimetypes import mimetypes
class BrotliCache(object): class BrotliCache(object):
def __init__(self, redis_kwargs={}, max_wait=0.020, expire=60*60*6): def __init__(self, redis_kwargs={}, max_wait=0.020, expire=60*60*6):
self.redis = redis.StrictRedis(**redis_kwargs) self.redis = redis.StrictRedis(**redis_kwargs)
@ -32,8 +33,13 @@ class BrotliCache(object):
response.headers.set('x-brotli-cache', 'MISS') response.headers.set('x-brotli-cache', 'MISS')
lock_key = 'brotlicache:lock:{}'.format(digest) lock_key = 'brotlicache:lock:{}'.format(digest)
if self.redis.set(lock_key, 1, nx=True, ex=10): if self.redis.set(lock_key, 1, nx=True, ex=10):
mode = brotli_.MODE_TEXT if response.content_type.startswith('text/') else brotli_.MODE_GENERIC mode = (
t = Thread(target=self.compress, args=(cache_key, lock_key, body, mode)) brotli_.MODE_TEXT
if response.content_type.startswith('text/')
else brotli_.MODE_GENERIC)
t = Thread(
target=self.compress,
args=(cache_key, lock_key, body, mode))
t.start() t.start()
if self.max_wait > 0: if self.max_wait > 0:
t.join(self.max_wait) t.join(self.max_wait)
@ -50,8 +56,10 @@ class BrotliCache(object):
return response return response
def brotli(app, static = True, dynamic = True):
def brotli(app, static=True, dynamic=True):
original_static = app.view_functions['static'] original_static = app.view_functions['static']
def static_maybe_gzip_brotli(filename=None): def static_maybe_gzip_brotli(filename=None):
path = os.path.join(app.static_folder, filename) path = os.path.join(app.static_folder, filename)
for encoding, extension in (('br', '.br'), ('gzip', '.gz')): for encoding, extension in (('br', '.br'), ('gzip', '.gz')):
@ -59,10 +67,12 @@ def brotli(app, static = True, dynamic = True):
continue continue
encpath = path + extension encpath = path + extension
if os.path.isfile(encpath): if os.path.isfile(encpath):
resp = make_response(original_static(filename=filename + extension)) resp = make_response(
original_static(filename=filename + extension))
resp.headers.set('content-encoding', encoding) resp.headers.set('content-encoding', encoding)
resp.headers.set('vary', 'accept-encoding') resp.headers.set('vary', 'accept-encoding')
mimetype = mimetypes.guess_type(filename)[0] or 'application/octet-stream' mimetype = (mimetypes.guess_type(filename)[0]
or 'application/octet-stream')
resp.headers.set('content-type', mimetype) resp.headers.set('content-type', mimetype)
return resp return resp
return original_static(filename=filename) return original_static(filename=filename)

View File

@ -1,5 +1,7 @@
from flask import url_for, abort from flask import url_for, abort
import os import os
def cachebust(app): def cachebust(app):
@app.route('/static-cb/<int:timestamp>/<path:filename>') @app.route('/static-cb/<int:timestamp>/<path:filename>')
def static_cachebust(timestamp, filename): def static_cachebust(timestamp, filename):
@ -12,14 +14,16 @@ def cachebust(app):
abort(404) abort(404)
else: else:
resp = app.view_functions['static'](filename=filename) resp = app.view_functions['static'](filename=filename)
resp.headers.set('cache-control', 'public, immutable, max-age=%s' % (60*60*24*365,)) resp.headers.set(
'cache-control',
'public, immutable, max-age={}'.format(60*60*24*365))
if 'expires' in resp.headers: if 'expires' in resp.headers:
resp.headers.remove('expires') resp.headers.remove('expires')
return resp return resp
@app.context_processor @app.context_processor
def replace_url_for(): def replace_url_for():
return dict(url_for = cachebust_url_for) return dict(url_for=cachebust_url_for)
def cachebust_url_for(endpoint, **kwargs): def cachebust_url_for(endpoint, **kwargs):
if endpoint == 'static': if endpoint == 'static':

View File

@ -1,30 +1,6 @@
from datetime import timedelta, datetime from datetime import timedelta, datetime
from statistics import mean from scales import SCALES
SCALES = [
('minutes', timedelta(minutes=1)),
('hours', timedelta(hours=1)),
('days', timedelta(days=1)),
('weeks', timedelta(days=7)),
('months', timedelta(days=
# you, a fool: a month is 30 days
# me, wise:
mean((31,
mean((29 if year % 400 == 0
or (year % 100 != 0 and year % 4 == 0)
else 28
for year in range(400)))
,31,30,31,30,31,31,30,31,30,31))
)),
('years', timedelta(days=
# you, a fool: ok. a year is 365.25 days. happy?
# me, wise: absolutely not
mean((366 if year % 400 == 0
or (year % 100 != 0 and year % 4 == 0)
else 365
for year in range(400)))
)),
]
def decompose_interval(attrname): def decompose_interval(attrname):
scales = [scale[1] for scale in SCALES] scales = [scale[1] for scale in SCALES]
@ -69,7 +45,6 @@ def decompose_interval(attrname):
raise ValueError("Incorrect time interval", e) raise ValueError("Incorrect time interval", e)
setattr(self, attrname, value * getattr(self, scl_name)) setattr(self, attrname, value * getattr(self, scl_name))
setattr(cls, scl_name, scale) setattr(cls, scl_name, scale)
setattr(cls, sig_name, significand) setattr(cls, sig_name, significand)
@ -77,6 +52,7 @@ def decompose_interval(attrname):
return decorator return decorator
def relative(interval): def relative(interval):
# special cases # special cases
if interval > timedelta(seconds=-15) and interval < timedelta(0): if interval > timedelta(seconds=-15) and interval < timedelta(0):
@ -99,5 +75,6 @@ def relative(interval):
else: else:
return '{} ago'.format(output) return '{} ago'.format(output)
def relnow(time): def relnow(time):
return relative(time - datetime.now()) return relative(time - datetime.now())

View File

@ -1,4 +1,3 @@
import mastodon
from mastodon import Mastodon from mastodon import Mastodon
from mastodon.Mastodon import MastodonAPIError from mastodon.Mastodon import MastodonAPIError
from model import MastodonApp, Account, OAuthToken, Post from model import MastodonApp, Account, OAuthToken, Post
@ -7,6 +6,7 @@ from app import db
from math import inf from math import inf
import iso8601 import iso8601
def get_or_create_app(instance_url, callback, website): def get_or_create_app(instance_url, callback, website):
instance_url = instance_url instance_url = instance_url
app = MastodonApp.query.get(instance_url) app = MastodonApp.query.get(instance_url)
@ -18,7 +18,8 @@ def get_or_create_app(instance_url, callback, website):
proto = 'http' proto = 'http'
if not app: if not app:
client_id, client_secret = Mastodon.create_app('forget', client_id, client_secret = Mastodon.create_app(
'forget',
scopes=('read', 'write'), scopes=('read', 'write'),
api_base_url='{}://{}'.format(proto, instance_url), api_base_url='{}://{}'.format(proto, instance_url),
redirect_uris=callback, redirect_uris=callback,
@ -31,18 +32,22 @@ def get_or_create_app(instance_url, callback, website):
app.protocol = proto app.protocol = proto
return app return app
def anonymous_api(app): def anonymous_api(app):
return Mastodon(app.client_id, return Mastodon(
client_secret = app.client_secret, app.client_id,
client_secret=app.client_secret,
api_base_url='{}://{}'.format(app.protocol, app.instance), api_base_url='{}://{}'.format(app.protocol, app.instance),
) )
def login_url(app, callback): def login_url(app, callback):
return anonymous_api(app).auth_request_url( return anonymous_api(app).auth_request_url(
redirect_uris=callback, redirect_uris=callback,
scopes=('read', 'write',) scopes=('read', 'write',)
) )
def receive_code(code, app, callback): def receive_code(code, app, callback):
api = anonymous_api(app) api = anonymous_api(app)
access_token = api.log_in( access_token = api.log_in(
@ -54,7 +59,7 @@ def receive_code(code, app, callback):
remote_acc = api.account_verify_credentials() remote_acc = api.account_verify_credentials()
acc = account_from_api_object(remote_acc, app.instance) acc = account_from_api_object(remote_acc, app.instance)
acc = db.session.merge(acc) acc = db.session.merge(acc)
token = OAuthToken(token = access_token) token = OAuthToken(token=access_token)
token = db.session.merge(token) token = db.session.merge(token)
token.account = acc token.account = acc
@ -64,12 +69,12 @@ def receive_code(code, app, callback):
def get_api_for_acc(account): def get_api_for_acc(account):
app = MastodonApp.query.get(account.mastodon_instance) app = MastodonApp.query.get(account.mastodon_instance)
for token in account.tokens: for token in account.tokens:
api = Mastodon(app.client_id, api = Mastodon(
client_secret = app.client_secret, app.client_id,
api_base_url = '{}://{}'.format(app.protocol, app.instance), client_secret=app.client_secret,
access_token = token.token, api_base_url='{}://{}'.format(app.protocol, app.instance),
ratelimit_method = 'throw', access_token=token.token,
#debug_requests = True, ratelimit_method='throw',
) )
# api.verify_credentials() # api.verify_credentials()
@ -91,15 +96,18 @@ def fetch_acc(acc, cursor=None):
print('no access, aborting') print('no access, aborting')
return None return None
newacc = account_from_api_object(api.account_verify_credentials(), acc.mastodon_instance) newacc = account_from_api_object(
api.account_verify_credentials(), acc.mastodon_instance)
acc = db.session.merge(newacc) acc = db.session.merge(newacc)
kwargs = dict(limit = 40) kwargs = dict(limit=40)
if cursor: if cursor:
kwargs.update(cursor) kwargs.update(cursor)
if 'max_id' not in kwargs: if 'max_id' not in kwargs:
most_recent_post = Post.query.with_parent(acc).order_by(db.desc(Post.created_at)).first() most_recent_post = (
Post.query.with_parent(acc)
.order_by(db.desc(Post.created_at)).first())
if most_recent_post: if most_recent_post:
kwargs['since_id'] = most_recent_post.mastodon_id kwargs['since_id'] = most_recent_post.mastodon_id
@ -120,27 +128,31 @@ def fetch_acc(acc, cursor=None):
return kwargs return kwargs
def post_from_api_object(obj, instance): def post_from_api_object(obj, instance):
return Post( return Post(
mastodon_instance = instance, mastodon_instance=instance,
mastodon_id = obj['id'], mastodon_id=obj['id'],
favourite = obj['favourited'], favourite=obj['favourited'],
has_media = 'media_attachments' in obj and bool(obj['media_attachments']), has_media=('media_attachments' in obj
created_at = iso8601.parse_date(obj['created_at']), and bool(obj['media_attachments'])),
author_id = account_from_api_object(obj['account'], instance).id, created_at=iso8601.parse_date(obj['created_at']),
direct = obj['visibility'] == 'direct', author_id=account_from_api_object(obj['account'], instance).id,
direct=obj['visibility'] == 'direct',
) )
def account_from_api_object(obj, instance): def account_from_api_object(obj, instance):
return Account( return Account(
mastodon_instance = instance, mastodon_instance=instance,
mastodon_id = obj['id'], mastodon_id=obj['id'],
screen_name = obj['username'], screen_name=obj['username'],
display_name = obj['display_name'], display_name=obj['display_name'],
avatar_url = obj['avatar'], avatar_url=obj['avatar'],
reported_post_count = obj['statuses_count'], reported_post_count=obj['statuses_count'],
) )
def refresh_posts(posts): def refresh_posts(posts):
acc = posts[0].author acc = posts[0].author
api = get_api_for_acc(acc) api = get_api_for_acc(acc)
@ -151,7 +163,8 @@ def refresh_posts(posts):
for post in posts: for post in posts:
try: try:
status = api.status(post.mastodon_id) status = api.status(post.mastodon_id)
new_post = db.session.merge(post_from_api_object(status, post.mastodon_instance)) new_post = db.session.merge(
post_from_api_object(status, post.mastodon_instance))
new_posts.append(new_post) new_posts.append(new_post)
except MastodonAPIError as e: except MastodonAPIError as e:
if str(e) == 'Endpoint not found.': if str(e) == 'Endpoint not found.':
@ -161,6 +174,7 @@ def refresh_posts(posts):
return new_posts return new_posts
def delete(post): def delete(post):
api = get_api_for_acc(post.author) api = get_api_for_acc(post.author)
api.status_delete(post.mastodon_id) api.status_delete(post.mastodon_id)

29
lib/scales.py Normal file
View File

@ -0,0 +1,29 @@
# flake8: noqa
from datetime import timedelta
from statistics import mean
def _is_leap(year):
    """Return True when *year* is a leap year under the Gregorian rule."""
    return year % 400 == 0 or (year % 100 != 0 and year % 4 == 0)


# The Gregorian calendar repeats every 400 years, so averaging over one
# full cycle gives the exact mean length of February and of a whole year.
_GREGORIAN_CYCLE = range(400)

# you, a fool: a month is 30 days
# me, wise: average the twelve month lengths, with February itself
# averaged over the leap-year cycle
_MEAN_FEBRUARY_DAYS = mean(
    29 if _is_leap(year) else 28 for year in _GREGORIAN_CYCLE)
_MEAN_MONTH_DAYS = mean(
    (31, _MEAN_FEBRUARY_DAYS, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31))

# you, a fool: ok. a year is 365.25 days. happy?
# me, wise: absolutely not -- century years are only leap every 400 years
_MEAN_YEAR_DAYS = mean(
    366 if _is_leap(year) else 365 for year in _GREGORIAN_CYCLE)

# (unit name, length of one unit) pairs, ordered from shortest to longest.
SCALES = [
    ('minutes', timedelta(minutes=1)),
    ('hours', timedelta(hours=1)),
    ('days', timedelta(days=1)),
    ('weeks', timedelta(days=7)),
    ('months', timedelta(days=_MEAN_MONTH_DAYS)),
    ('years', timedelta(days=_MEAN_YEAR_DAYS)),
]

View File

@ -1,17 +1,21 @@
from flask import request from flask import request
def set_session_cookie(session, response, secure=True): def set_session_cookie(session, response, secure=True):
response.set_cookie('forget_sid', session.id, response.set_cookie(
'forget_sid', session.id,
max_age=60*60*48, max_age=60*60*48,
httponly=True, httponly=True,
secure=secure) secure=secure)
def get_viewer_session(): def get_viewer_session():
from model import Session from model import Session
sid = request.cookies.get('forget_sid', None) sid = request.cookies.get('forget_sid', None)
if sid: if sid:
return Session.query.get(sid) return Session.query.get(sid)
def get_viewer(): def get_viewer():
session = get_viewer_session() session = get_viewer_session()
if session: if session:

View File

@ -8,6 +8,7 @@ import locale
from zipfile import ZipFile from zipfile import ZipFile
from io import BytesIO from io import BytesIO
def get_login_url(callback='oob', consumer_key=None, consumer_secret=None): def get_login_url(callback='oob', consumer_key=None, consumer_secret=None):
twitter = Twitter( twitter = Twitter(
auth=OAuth('', '', consumer_key, consumer_secret), auth=OAuth('', '', consumer_key, consumer_secret),
@ -16,33 +17,42 @@ def get_login_url(callback='oob', consumer_key=None, consumer_secret=None):
oauth_token = resp['oauth_token'] oauth_token = resp['oauth_token']
oauth_token_secret = resp['oauth_token_secret'] oauth_token_secret = resp['oauth_token_secret']
token = OAuthToken(token = oauth_token, token_secret = oauth_token_secret) token = OAuthToken(token=oauth_token, token_secret=oauth_token_secret)
db.session.merge(token) db.session.merge(token)
db.session.commit() db.session.commit()
return "https://api.twitter.com/oauth/authenticate?oauth_token=%s" % (oauth_token,) return (
"https://api.twitter.com/oauth/authenticate?oauth_token=%s"
% (oauth_token,))
def account_from_api_user_object(obj): def account_from_api_user_object(obj):
return Account( return Account(
twitter_id = obj['id_str'], twitter_id=obj['id_str'],
display_name = obj['name'], display_name=obj['name'],
screen_name = obj['screen_name'], screen_name=obj['screen_name'],
avatar_url = obj['profile_image_url_https'], avatar_url=obj['profile_image_url_https'],
reported_post_count = obj['statuses_count']) reported_post_count=obj['statuses_count'])
def receive_verifier(oauth_token, oauth_verifier, consumer_key=None, consumer_secret=None):
def receive_verifier(oauth_token, oauth_verifier,
consumer_key=None, consumer_secret=None):
temp_token = OAuthToken.query.get(oauth_token) temp_token = OAuthToken.query.get(oauth_token)
if not temp_token: if not temp_token:
raise Exception("OAuth token has expired") raise Exception("OAuth token has expired")
twitter = Twitter( twitter = Twitter(
auth=OAuth(temp_token.token, temp_token.token_secret, consumer_key, consumer_secret), auth=OAuth(temp_token.token, temp_token.token_secret,
consumer_key, consumer_secret),
format='', api_version=None) format='', api_version=None)
resp = url_decode(twitter.oauth.access_token(oauth_verifier = oauth_verifier)) resp = url_decode(
twitter.oauth.access_token(oauth_verifier=oauth_verifier))
db.session.delete(temp_token) db.session.delete(temp_token)
new_token = OAuthToken(token = resp['oauth_token'], token_secret = resp['oauth_token_secret']) new_token = OAuthToken(token=resp['oauth_token'],
token_secret=resp['oauth_token_secret'])
new_token = db.session.merge(new_token) new_token = db.session.merge(new_token)
new_twitter = Twitter( new_twitter = Twitter(
auth=OAuth(new_token.token, new_token.token_secret, consumer_key, consumer_secret)) auth=OAuth(new_token.token, new_token.token_secret,
consumer_key, consumer_secret))
remote_acct = new_twitter.account.verify_credentials() remote_acct = new_twitter.account.verify_credentials()
acct = account_from_api_user_object(remote_acct) acct = account_from_api_user_object(remote_acct)
acct = db.session.merge(acct) acct = db.session.merge(acct)
@ -52,15 +62,17 @@ def receive_verifier(oauth_token, oauth_verifier, consumer_key=None, consumer_se
return new_token return new_token
def get_twitter_for_acc(account):
def get_twitter_for_acc(account):
consumer_key = app.config['TWITTER_CONSUMER_KEY'] consumer_key = app.config['TWITTER_CONSUMER_KEY']
consumer_secret = app.config['TWITTER_CONSUMER_SECRET'] consumer_secret = app.config['TWITTER_CONSUMER_SECRET']
tokens = OAuthToken.query.with_parent(account).order_by(db.desc(OAuthToken.created_at)).all() tokens = (OAuthToken.query.with_parent(account)
.order_by(db.desc(OAuthToken.created_at)).all())
for token in tokens: for token in tokens:
t = Twitter( t = Twitter(
auth=OAuth(token.token, token.token_secret, consumer_key, consumer_secret)) auth=OAuth(token.token, token.token_secret,
consumer_key, consumer_secret))
try: try:
t.account.verify_credentials() t.account.verify_credentials()
return t return t
@ -79,24 +91,30 @@ def get_twitter_for_acc(account):
account.force_log_out() account.force_log_out()
return None return None
locale.setlocale(locale.LC_TIME, 'C') locale.setlocale(locale.LC_TIME, 'C')
def post_from_api_tweet_object(tweet, post=None): def post_from_api_tweet_object(tweet, post=None):
if not post: if not post:
post = Post() post = Post()
post.twitter_id = tweet['id_str'] post.twitter_id = tweet['id_str']
try: try:
post.created_at = datetime.strptime(tweet['created_at'], '%a %b %d %H:%M:%S %z %Y') post.created_at = datetime.strptime(
tweet['created_at'], '%a %b %d %H:%M:%S %z %Y')
except ValueError: except ValueError:
post.created_at = datetime.strptime(tweet['created_at'], '%Y-%m-%d %H:%M:%S %z') post.created_at = datetime.strptime(
#whyyy tweet['created_at'], '%Y-%m-%d %H:%M:%S %z')
# whyyy
post.author_id = 'twitter:{}'.format(tweet['user']['id_str']) post.author_id = 'twitter:{}'.format(tweet['user']['id_str'])
if 'favorited' in tweet: if 'favorited' in tweet:
post.favourite = tweet['favorited'] post.favourite = tweet['favorited']
if 'entities' in tweet: if 'entities' in tweet:
post.has_media = bool('media' in tweet['entities'] and tweet['entities']['media']) post.has_media = bool(
'media' in tweet['entities'] and tweet['entities']['media'])
return post return post
def fetch_acc(account, cursor): def fetch_acc(account, cursor):
t = get_twitter_for_acc(account) t = get_twitter_for_acc(account)
if not t: if not t:
@ -106,12 +124,19 @@ def fetch_acc(account, cursor):
user = t.account.verify_credentials() user = t.account.verify_credentials()
db.session.merge(account_from_api_user_object(user)) db.session.merge(account_from_api_user_object(user))
kwargs = { 'user_id': account.twitter_id, 'count': 200, 'trim_user': True, 'tweet_mode': 'extended' } kwargs = {
'user_id': account.twitter_id,
'count': 200,
'trim_user': True,
'tweet_mode': 'extended',
}
if cursor: if cursor:
kwargs.update(cursor) kwargs.update(cursor)
if 'max_id' not in kwargs: if 'max_id' not in kwargs:
most_recent_post = Post.query.order_by(db.desc(Post.created_at)).filter(Post.author_id == account.id).first() most_recent_post = (
Post.query.order_by(db.desc(Post.created_at))
.filter(Post.author_id == account.id).first())
if most_recent_post: if most_recent_post:
kwargs['since_id'] = most_recent_post.twitter_id kwargs['since_id'] = most_recent_post.twitter_id
@ -142,11 +167,14 @@ def refresh_posts(posts):
t = get_twitter_for_acc(posts[0].author) t = get_twitter_for_acc(posts[0].author)
if not t: if not t:
raise Exception('shit idk. twitter says no') raise Exception('shit idk. twitter says no')
tweets = t.statuses.lookup(_id=",".join((post.twitter_id for post in posts)), tweets = t.statuses.lookup(
trim_user = True, tweet_mode = 'extended') _id=",".join((post.twitter_id for post in posts)),
trim_user=True, tweet_mode='extended')
refreshed_posts = list() refreshed_posts = list()
for post in posts: for post in posts:
tweet = next((tweet for tweet in tweets if tweet['id_str'] == post.twitter_id), None) tweet = next(
(tweet for tweet in tweets if tweet['id_str'] == post.twitter_id),
None)
if not tweet: if not tweet:
db.session.delete(post) db.session.delete(post)
else: else:
@ -166,7 +194,9 @@ def chunk_twitter_archive(archive_id):
ta = TwitterArchive.query.get(archive_id) ta = TwitterArchive.query.get(archive_id)
with ZipFile(BytesIO(ta.body), 'r') as zipfile: with ZipFile(BytesIO(ta.body), 'r') as zipfile:
files = [filename for filename in zipfile.namelist() if filename.startswith('data/js/tweets/') and filename.endswith('.js')] files = [filename for filename in zipfile.namelist()
if filename.startswith('data/js/tweets/')
and filename.endswith('.js')]
files.sort() files.sort()

View File

@ -3,9 +3,9 @@ import re
version_re = re.compile('(?P<tag>.+)-(?P<commits>[0-9]+)-g(?P<hash>[0-9a-f]+)') version_re = re.compile('(?P<tag>.+)-(?P<commits>[0-9]+)-g(?P<hash>[0-9a-f]+)')
def url_for_version(ver): def url_for_version(ver):
match = version_re.match(ver) match = version_re.match(ver)
if not match: if not match:
return app.config['REPO_URL'] return app.config['REPO_URL']
return app.config['COMMIT_URL'].format(**match.groupdict()) return app.config['COMMIT_URL'].format(**match.groupdict())

128
model.py
View File

@ -4,12 +4,16 @@ from app import db
import secrets import secrets
from lib import decompose_interval from lib import decompose_interval
class TimestampMixin(object): class TimestampMixin(object):
created_at = db.Column(db.DateTime, server_default=db.func.now(), nullable=False) created_at = db.Column(db.DateTime, server_default=db.func.now(),
updated_at = db.Column(db.DateTime, server_default=db.func.now(), onupdate=db.func.now(), nullable=False) nullable=False)
updated_at = db.Column(db.DateTime, server_default=db.func.now(),
onupdate=db.func.now(), nullable=False)
def touch(self): def touch(self):
self.updated_at=db.func.now() self.updated_at = db.func.now()
class RemoteIDMixin(object): class RemoteIDMixin(object):
@property @property
@ -23,7 +27,9 @@ class RemoteIDMixin(object):
if not self.id: if not self.id:
return None return None
if self.service != "twitter": if self.service != "twitter":
raise Exception("tried to get twitter id for a {} {}".format(self.service, type(self))) raise Exception(
"tried to get twitter id for a {} {}"
.format(self.service, type(self)))
return self.id.split(":")[1] return self.id.split(":")[1]
@twitter_id.setter @twitter_id.setter
@ -35,7 +41,9 @@ class RemoteIDMixin(object):
if not self.id: if not self.id:
return None return None
if self.service != "mastodon": if self.service != "mastodon":
raise Exception("tried to get mastodon instance for a {} {}".format(self.service, type(self))) raise Exception(
"tried to get mastodon instance for a {} {}"
.format(self.service, type(self)))
return self.id.split(":", 1)[1].split('@')[1] return self.id.split(":", 1)[1].split('@')[1]
@mastodon_instance.setter @mastodon_instance.setter
@ -47,7 +55,9 @@ class RemoteIDMixin(object):
if not self.id: if not self.id:
return None return None
if self.service != "mastodon": if self.service != "mastodon":
raise Exception("tried to get mastodon id for a {} {}".format(self.service, type(self))) raise Exception(
"tried to get mastodon id for a {} {}"
.format(self.service, type(self)))
return self.id.split(":", 1)[1].split('@')[0] return self.id.split(":", 1)[1].split('@')[0]
@mastodon_id.setter @mastodon_id.setter
@ -61,13 +71,20 @@ class Account(TimestampMixin, RemoteIDMixin):
__tablename__ = 'accounts' __tablename__ = 'accounts'
id = db.Column(db.String, primary_key=True) id = db.Column(db.String, primary_key=True)
policy_enabled = db.Column(db.Boolean, server_default='FALSE', nullable=False) policy_enabled = db.Column(db.Boolean, server_default='FALSE',
policy_keep_latest = db.Column(db.Integer, server_default='100', nullable=False) nullable=False)
policy_keep_favourites = db.Column(db.Boolean, server_default='TRUE', nullable=False) policy_keep_latest = db.Column(db.Integer, server_default='100',
policy_keep_media = db.Column(db.Boolean, server_default='FALSE', nullable=False) nullable=False)
policy_delete_every = db.Column(db.Interval, server_default='30 minutes', nullable=False) policy_keep_favourites = db.Column(db.Boolean, server_default='TRUE',
policy_keep_younger = db.Column(db.Interval, server_default='365 days', nullable=False) nullable=False)
policy_keep_direct = db.Column(db.Boolean, server_default='TRUE', nullable=False) policy_keep_media = db.Column(db.Boolean, server_default='FALSE',
nullable=False)
policy_delete_every = db.Column(db.Interval, server_default='30 minutes',
nullable=False)
policy_keep_younger = db.Column(db.Interval, server_default='365 days',
nullable=False)
policy_keep_direct = db.Column(db.Boolean, server_default='TRUE',
nullable=False)
display_name = db.Column(db.String) display_name = db.Column(db.String)
screen_name = db.Column(db.String) screen_name = db.Column(db.String)
@ -96,7 +113,8 @@ class Account(TimestampMixin, RemoteIDMixin):
def validate_intervals(self, key, value): def validate_intervals(self, key, value):
if not (value == timedelta(0) or value >= timedelta(minutes=1)): if not (value == timedelta(0) or value >= timedelta(minutes=1)):
value = timedelta(minutes=1) value = timedelta(minutes=1)
if key == 'policy_delete_every' and datetime.now() + value < self.next_delete: if key == 'policy_delete_every' and \
datetime.now() + value < self.next_delete:
# make sure that next delete is not in the far future # make sure that next delete is not in the far future
self.next_delete = datetime.now() + value self.next_delete = datetime.now() + value
return value return value
@ -107,7 +125,6 @@ class Account(TimestampMixin, RemoteIDMixin):
return 0 return 0
return value return value
# backref: tokens # backref: tokens
# backref: twitter_archives # backref: twitter_archives
# backref: posts # backref: posts
@ -121,20 +138,24 @@ class Account(TimestampMixin, RemoteIDMixin):
def estimate_eligible_for_delete(self):
    """
    Count the posts of this account that match the expiration policy.

    The result is only an estimate: favourite status may have changed
    since a post was last refreshed, and refreshing every single post
    each time this number is needed is not feasible.
    """
    # the newest `policy_keep_latest` posts are excluded from the count
    protected = (
        Post.query.with_parent(self)
        .order_by(db.desc(Post.created_at))
        .limit(self.policy_keep_latest))

    # only posts at least `policy_keep_younger` old are candidates
    cutoff = db.func.now() - self.policy_keep_younger
    candidates = (
        Post.query.with_parent(self)
        .filter(Post.created_at <= cutoff)
        .except_(protected))

    if self.policy_keep_favourites:
        candidates = candidates.filter_by(favourite=False)
    if self.policy_keep_media:
        candidates = candidates.filter_by(has_media=False)

    return candidates.count()
def force_log_out(self): def force_log_out(self):
Session.query.with_parent(self).delete() Session.query.with_parent(self).delete()
db.session.commit() db.session.commit()
@ -150,22 +171,36 @@ class OAuthToken(db.Model, TimestampMixin):
token = db.Column(db.String, primary_key=True) token = db.Column(db.String, primary_key=True)
token_secret = db.Column(db.String, nullable=True) token_secret = db.Column(db.String, nullable=True)
account_id = db.Column(db.String, db.ForeignKey('accounts.id', ondelete='CASCADE', onupdate='CASCADE'), nullable=True, index=True) account_id = db.Column(db.String,
account = db.relationship(Account, backref=db.backref('tokens', order_by=lambda: db.desc(OAuthToken.created_at))) db.ForeignKey('accounts.id', ondelete='CASCADE',
onupdate='CASCADE'),
nullable=True, index=True)
account = db.relationship(
Account,
backref=db.backref('tokens',
order_by=lambda: db.desc(OAuthToken.created_at))
)
# note: account_id is nullable here because we don't know what account a token is for # note: account_id is nullable here because we don't know what account a
# until we call /account/verify_credentials with it # token is for until we call /account/verify_credentials with it
class Session(db.Model, TimestampMixin):
    """Login session for an account, stored in the ``sessions`` table."""
    __tablename__ = 'sessions'

    # random url-safe token generated when the row is created
    id = db.Column(
        db.String,
        default=lambda: secrets.token_urlsafe(),
        primary_key=True)

    # owning account; deletes and id changes of the account cascade here
    account_id = db.Column(
        db.String,
        db.ForeignKey('accounts.id', onupdate='CASCADE', ondelete='CASCADE'),
        index=True,
        nullable=False)
    # loaded with a join alongside the session; Account gets `sessions`
    account = db.relationship(Account, backref='sessions', lazy='joined')

    # second random token, independent from the session id
    csrf_token = db.Column(
        db.String,
        nullable=False,
        default=lambda: secrets.token_urlsafe())
class Post(db.Model, TimestampMixin, RemoteIDMixin): class Post(db.Model, TimestampMixin, RemoteIDMixin):
@ -173,9 +208,15 @@ class Post(db.Model, TimestampMixin, RemoteIDMixin):
id = db.Column(db.String, primary_key=True) id = db.Column(db.String, primary_key=True)
author_id = db.Column(db.String, db.ForeignKey('accounts.id', ondelete='CASCADE', onupdate='CASCADE'), nullable=False) author_id = db.Column(
author = db.relationship(Account, db.String,
backref=db.backref('posts', order_by=lambda: db.desc(Post.created_at))) db.ForeignKey('accounts.id',
ondelete='CASCADE', onupdate='CASCADE'),
nullable=False)
author = db.relationship(
Account,
backref=db.backref('posts',
order_by=lambda: db.desc(Post.created_at)))
favourite = db.Column(db.Boolean, server_default='FALSE', nullable=False) favourite = db.Column(db.Boolean, server_default='FALSE', nullable=False)
has_media = db.Column(db.Boolean, server_default='FALSE', nullable=False) has_media = db.Column(db.Boolean, server_default='FALSE', nullable=False)
@ -184,17 +225,27 @@ class Post(db.Model, TimestampMixin, RemoteIDMixin):
def __repr__(self): def __repr__(self):
return '<Post ({}, Author: {})>'.format(self.id, self.author_id) return '<Post ({}, Author: {})>'.format(self.id, self.author_id)
db.Index('ix_posts_author_id_created_at', Post.author_id, Post.created_at) db.Index('ix_posts_author_id_created_at', Post.author_id, Post.created_at)
class TwitterArchive(db.Model, TimestampMixin): class TwitterArchive(db.Model, TimestampMixin):
__tablename__ = 'twitter_archives' __tablename__ = 'twitter_archives'
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
account_id = db.Column(db.String, db.ForeignKey('accounts.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) account_id = db.Column(
account = db.relationship(Account, backref=db.backref('twitter_archives', order_by=lambda: db.desc(TwitterArchive.id))) db.String,
db.ForeignKey('accounts.id',
onupdate='CASCADE', ondelete='CASCADE'),
nullable=False)
account = db.relationship(
Account,
backref=db.backref('twitter_archives',
order_by=lambda: db.desc(TwitterArchive.id)))
body = db.deferred(db.Column(db.LargeBinary, nullable=False)) body = db.deferred(db.Column(db.LargeBinary, nullable=False))
chunks = db.Column(db.Integer) chunks = db.Column(db.Integer)
chunks_successful = db.Column(db.Integer, server_default='0', nullable=False) chunks_successful = db.Column(db.Integer,
server_default='0', nullable=False)
chunks_failed = db.Column(db.Integer, server_default='0', nullable=False) chunks_failed = db.Column(db.Integer, server_default='0', nullable=False)
def status(self): def status(self):
@ -204,8 +255,10 @@ class TwitterArchive(db.Model, TimestampMixin):
return 'successful' return 'successful'
return 'pending' return 'pending'
ProtoEnum = db.Enum('http', 'https', name='enum_protocol') ProtoEnum = db.Enum('http', 'https', name='enum_protocol')
class MastodonApp(db.Model, TimestampMixin): class MastodonApp(db.Model, TimestampMixin):
__tablename__ = 'mastodon_apps' __tablename__ = 'mastodon_apps'
@ -214,6 +267,7 @@ class MastodonApp(db.Model, TimestampMixin):
client_secret = db.Column(db.String, nullable=False) client_secret = db.Column(db.String, nullable=False)
protocol = db.Column(ProtoEnum, nullable=False) protocol = db.Column(ProtoEnum, nullable=False)
class MastodonInstance(db.Model): class MastodonInstance(db.Model):
""" """
this is for the autocomplete in the mastodon login form this is for the autocomplete in the mastodon login form

138
routes.py
View File

@ -1,4 +1,5 @@
from flask import render_template, url_for, redirect, request, g, Response, jsonify from flask import render_template, url_for, redirect, request, g, Response,\
jsonify
from datetime import datetime, timedelta from datetime import datetime, timedelta
import lib.twitter import lib.twitter
import lib.mastodon import lib.mastodon
@ -6,7 +7,7 @@ import lib
from lib.auth import require_auth, require_auth_api, csrf from lib.auth import require_auth, require_auth_api, csrf
from lib import set_session_cookie from lib import set_session_cookie
from lib import get_viewer_session, get_viewer from lib import get_viewer_session, get_viewer
from model import Account, Session, Post, TwitterArchive, MastodonApp, MastodonInstance from model import Session, TwitterArchive, MastodonApp, MastodonInstance
from app import app, db, sentry, limiter from app import app, db, sentry, limiter
import tasks import tasks
from zipfile import BadZipFile from zipfile import BadZipFile
@ -15,6 +16,7 @@ from urllib.error import URLError
import version import version
import lib.version import lib.version
@app.before_request @app.before_request
def load_viewer(): def load_viewer():
g.viewer = get_viewer_session() g.viewer = get_viewer_session()
@ -25,6 +27,7 @@ def load_viewer():
'service': g.viewer.account.service 'service': g.viewer.account.service
}) })
@app.context_processor @app.context_processor
def inject_version(): def inject_version():
return dict( return dict(
@ -32,6 +35,7 @@ def inject_version():
repo_url=lib.version.url_for_version(version.version), repo_url=lib.version.url_for_version(version.version),
) )
@app.context_processor @app.context_processor
def inject_sentry(): def inject_sentry():
if sentry: if sentry:
@ -41,6 +45,7 @@ def inject_sentry():
return dict(sentry_dsn=client_dsn) return dict(sentry_dsn=client_dsn)
return dict() return dict()
@app.after_request @app.after_request
def touch_viewer(resp): def touch_viewer(resp):
if 'viewer' in g and g.viewer: if 'viewer' in g and g.viewer:
@ -52,29 +57,40 @@ def touch_viewer(resp):
lib.brotli.brotli(app) lib.brotli.brotli(app)
@app.route('/')
def index():
    """
    Serve the front page.

    Logged-in viewers get ``logged_in.html``; everyone else gets
    ``index.html`` with up to five mastodon instances whose popularity is
    above 13. Error flags are true when the matching query parameter is
    present, whatever its value.
    """
    if not g.viewer:
        popular_instances = (
            MastodonInstance.query
            .filter(MastodonInstance.popularity > 13)
            .order_by(db.desc(MastodonInstance.popularity),
                      MastodonInstance.instance)
            .limit(5))
        return render_template(
            'index.html',
            mastodon_instances=popular_instances,
            twitter_login_error='twitter_login_error' in request.args)

    error_flags = {
        name: name in request.args
        for name in ('tweet_archive_failed', 'settings_error')
    }
    return render_template(
        'logged_in.html', scales=lib.interval.SCALES, **error_flags)
@app.route('/login/twitter')
@limiter.limit('3/minute')
def twitter_login_step1():
    """
    Begin the twitter login flow.

    Redirects to the login URL obtained from ``lib.twitter``; if that
    fails with a TwitterError or URLError, redirects back to the index
    page with the ``twitter_login_error`` flag set.
    """
    try:
        callback_url = url_for('twitter_login_step2', _external=True)
        twitter_config = app.config.get_namespace("TWITTER_")
        login_url = lib.twitter.get_login_url(
            callback=callback_url, **twitter_config)
        return redirect(login_url)
    except (TwitterError, URLError):
        failure_url = url_for(
            'index', twitter_login_error='', _anchor='log_in')
        return redirect(failure_url)
@app.route('/login/twitter/callback') @app.route('/login/twitter/callback')
@limiter.limit('3/minute') @limiter.limit('3/minute')
@ -82,9 +98,11 @@ def twitter_login_step2():
try: try:
oauth_token = request.args['oauth_token'] oauth_token = request.args['oauth_token']
oauth_verifier = request.args['oauth_verifier'] oauth_verifier = request.args['oauth_verifier']
token = lib.twitter.receive_verifier(oauth_token, oauth_verifier, **app.config.get_namespace("TWITTER_")) token = lib.twitter.receive_verifier(
oauth_token, oauth_verifier,
**app.config.get_namespace("TWITTER_"))
session = Session(account_id = token.account_id) session = Session(account_id=token.account_id)
db.session.add(session) db.session.add(session)
db.session.commit() db.session.commit()
@ -94,17 +112,21 @@ def twitter_login_step2():
set_session_cookie(session, resp, app.config.get('HTTPS')) set_session_cookie(session, resp, app.config.get('HTTPS'))
return resp return resp
except (TwitterError, URLError): except (TwitterError, URLError):
return redirect(url_for('index', twitter_login_error='', _anchor='log_in')) return redirect(
url_for('index', twitter_login_error='', _anchor='log_in'))
class TweetArchiveEmptyException(Exception):
    """
    Signals an uploaded tweet archive that cannot be imported.

    Handled by the upload view together with BadZipFile. Presumably raised
    when the archive contains no tweet files -- confirm at the raise site.
    """
@app.route('/upload_tweet_archive', methods=('POST',)) @app.route('/upload_tweet_archive', methods=('POST',))
@limiter.limit('10/10 minutes') @limiter.limit('10/10 minutes')
@require_auth @require_auth
def upload_tweet_archive(): def upload_tweet_archive():
ta = TwitterArchive(account = g.viewer.account, ta = TwitterArchive(
body = request.files['file'].read()) account=g.viewer.account,
body=request.files['file'].read())
db.session.add(ta) db.session.add(ta)
db.session.commit() db.session.commit()
@ -120,10 +142,12 @@ def upload_tweet_archive():
for filename in files: for filename in files:
tasks.import_twitter_archive_month.s(ta.id, filename).apply_async() tasks.import_twitter_archive_month.s(ta.id, filename).apply_async()
return redirect(url_for('index', _anchor='recent_archives')) return redirect(url_for('index', _anchor='recent_archives'))
except (BadZipFile, TweetArchiveEmptyException): except (BadZipFile, TweetArchiveEmptyException):
return redirect(url_for('index', tweet_archive_failed='', _anchor='tweet_archive_import')) return redirect(
url_for('index', tweet_archive_failed='',
_anchor='tweet_archive_import'))
@app.route('/settings', methods=('POST',)) @app.route('/settings', methods=('POST',))
@csrf @csrf
@ -138,9 +162,9 @@ def settings():
except ValueError: except ValueError:
return 400 return 400
return redirect(url_for('index', settings_saved='')) return redirect(url_for('index', settings_saved=''))
@app.route('/disable', methods=('POST',)) @app.route('/disable', methods=('POST',))
@csrf @csrf
@require_auth @require_auth
@ -150,24 +174,37 @@ def disable():
return redirect(url_for('index')) return redirect(url_for('index'))
@app.route('/enable', methods=('POST',)) @app.route('/enable', methods=('POST',))
@csrf @csrf
@require_auth @require_auth
def enable(): def enable():
if 'confirm' not in request.form and not g.viewer.account.policy_enabled:
risky = False
if not 'confirm' in request.form and not g.viewer.account.policy_enabled:
if g.viewer.account.policy_delete_every == timedelta(0): if g.viewer.account.policy_delete_every == timedelta(0):
approx = g.viewer.account.estimate_eligible_for_delete() approx = g.viewer.account.estimate_eligible_for_delete()
return render_template('warn.html', message=f"""You've set the time between deleting posts to 0. Every post that matches your expiration rules will be deleted within minutes. return render_template(
{ ("That's about " + str(approx) + " posts.") if approx > 0 else "" } 'warn.html',
Go ahead?""") message=f"""
You've set the time between deleting posts to 0. Every post
that matches your expiration rules will be deleted within
minutes.
{ ("That's about " + str(approx) + " posts.") if approx > 0
else "" }
Go ahead?
""")
if g.viewer.account.next_delete < datetime.now() - timedelta(days=365): if g.viewer.account.next_delete < datetime.now() - timedelta(days=365):
return render_template('warn.html', message="""Once you enable Forget, posts that match your expiration rules will be deleted <b>permanently</b>. We can't bring them back. Make sure that you won't miss them.""") return render_template(
'warn.html',
message="""
Once you enable Forget, posts that match your
expiration rules will be deleted <b>permanently</b>.
We can't bring them back. Make sure that you won't
miss them.
""")
if not g.viewer.account.policy_enabled: if not g.viewer.account.policy_enabled:
g.viewer.account.next_delete = datetime.now() + g.viewer.account.policy_delete_every g.viewer.account.next_delete = (
datetime.now() + g.viewer.account.policy_delete_every)
g.viewer.account.policy_enabled = True g.viewer.account.policy_enabled = True
db.session.commit() db.session.commit()
@ -184,6 +221,7 @@ def logout():
g.viewer = None g.viewer = None
return redirect(url_for('index')) return redirect(url_for('index'))
@app.route('/api/settings', methods=('PUT',)) @app.route('/api/settings', methods=('PUT',))
@require_auth_api @require_auth_api
def api_settings_put(): def api_settings_put():
@ -197,6 +235,7 @@ def api_settings_put():
db.session.commit() db.session.commit()
return jsonify(status='success', updated=updated) return jsonify(status='success', updated=updated)
@app.route('/api/viewer') @app.route('/api/viewer')
@require_auth_api @require_auth_api
def api_viewer(): def api_viewer():
@ -211,6 +250,7 @@ def api_viewer():
service=viewer.service, service=viewer.service,
) )
@app.route('/api/viewer/timers') @app.route('/api/viewer/timers')
@require_auth_api @require_auth_api
def api_viewer_timers(): def api_viewer_timers():
@ -224,23 +264,33 @@ def api_viewer_timers():
next_delete_rel=lib.interval.relnow(viewer.next_delete), next_delete_rel=lib.interval.relnow(viewer.next_delete),
) )
@app.route('/login/mastodon', methods=('GET', 'POST')) @app.route('/login/mastodon', methods=('GET', 'POST'))
def mastodon_login_step1(instance=None): def mastodon_login_step1(instance=None):
instances = MastodonInstance.query.filter(MastodonInstance.popularity > 1).order_by(db.desc(MastodonInstance.popularity), MastodonInstance.instance).limit(30) instances = (
MastodonInstance
.query.filter(MastodonInstance.popularity > 1)
.order_by(db.desc(MastodonInstance.popularity),
MastodonInstance.instance)
.limit(30))
instance_url = request.args.get('instance_url', None) or request.form.get('instance_url', None) instance_url = (request.args.get('instance_url', None)
or request.form.get('instance_url', None))
if not instance_url: if not instance_url:
return render_template('mastodon_login.html', instances=instances, return render_template(
address_error = request.method == 'POST', 'mastodon_login.html', instances=instances,
generic_error = 'error' in request.args address_error=request.method == 'POST',
generic_error='error' in request.args
) )
instance_url = instance_url.split("@")[-1].lower() instance_url = instance_url.split("@")[-1].lower()
callback = url_for('mastodon_login_step2', instance=instance_url, _external=True) callback = url_for('mastodon_login_step2',
instance=instance_url, _external=True)
app = lib.mastodon.get_or_create_app(instance_url, app = lib.mastodon.get_or_create_app(
instance_url,
callback, callback,
url_for('index', _external=True)) url_for('index', _external=True))
db.session.merge(app) db.session.merge(app)
@ -249,24 +299,26 @@ def mastodon_login_step1(instance=None):
return redirect(lib.mastodon.login_url(app, callback)) return redirect(lib.mastodon.login_url(app, callback))
@app.route('/login/mastodon/callback/<instance>') @app.route('/login/mastodon/callback/<instance>')
def mastodon_login_step2(instance): def mastodon_login_step2(instance_url):
code = request.args.get('code', None) code = request.args.get('code', None)
app = MastodonApp.query.get(instance) app = MastodonApp.query.get(instance_url)
if not code or not app: if not code or not app:
return redirect('mastodon_login_step1', error=True) return redirect('mastodon_login_step1', error=True)
callback = url_for('mastodon_login_step2', instance=instance, _external=True) callback = url_for('mastodon_login_step2',
instance=instance_url, _external=True)
token = lib.mastodon.receive_code(code, app, callback) token = lib.mastodon.receive_code(code, app, callback)
account = token.account account = token.account
sess = Session(account = account) sess = Session(account=account)
db.session.add(sess) db.session.add(sess)
i=MastodonInstance(instance=instance) instance = MastodonInstance(instance=instance_url)
i=db.session.merge(i) instance = db.session.merge(instance)
i.bump() instance.bump()
db.session.commit() db.session.commit()

126
tasks.py
View File

@ -1,14 +1,14 @@
from celery import Celery, Task from celery import Celery, Task
from app import app as flaskapp from app import app as flaskapp
from app import db from app import db
from model import Session, Account, TwitterArchive, Post, OAuthToken, MastodonInstance from model import Session, Account, TwitterArchive, Post, OAuthToken,\
MastodonInstance
import lib.twitter import lib.twitter
import lib.mastodon import lib.mastodon
from mastodon.Mastodon import MastodonRatelimitError from mastodon.Mastodon import MastodonRatelimitError
from twitter import TwitterError from twitter import TwitterError
from urllib.error import URLError from urllib.error import URLError
from datetime import timedelta, datetime from datetime import timedelta
from zipfile import ZipFile from zipfile import ZipFile
from io import BytesIO, TextIOWrapper from io import BytesIO, TextIOWrapper
import json import json
@ -16,7 +16,9 @@ from kombu import Queue
import random import random
import version import version
app = Celery('tasks', broker=flaskapp.config['CELERY_BROKER'], task_serializer='pickle')
app = Celery('tasks', broker=flaskapp.config['CELERY_BROKER'],
task_serializer='pickle')
app.conf.task_queues = ( app.conf.task_queues = (
Queue('default', routing_key='celery'), Queue('default', routing_key='celery'),
Queue('high_prio', routing_key='high'), Queue('high_prio', routing_key='high'),
@ -41,14 +43,20 @@ class DBTask(Task):
finally: finally:
db.session.close() db.session.close()
app.Task = DBTask app.Task = DBTask
def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return None."""
    return None
@app.task(autoretry_for=(TwitterError, URLError, MastodonRatelimitError)) @app.task(autoretry_for=(TwitterError, URLError, MastodonRatelimitError))
def fetch_acc(id, cursor=None): def fetch_acc(id, cursor=None):
acc = Account.query.get(id) acc = Account.query.get(id)
print(f'fetching {acc}') print(f'fetching {acc}')
try: try:
action = lambda acc, cursor: None action = noop
if(acc.service == 'twitter'): if(acc.service == 'twitter'):
action = lib.twitter.fetch_acc action = lib.twitter.fetch_acc
elif(acc.service == 'mastodon'): elif(acc.service == 'mastodon'):
@ -61,8 +69,10 @@ def fetch_acc(id, cursor=None):
acc.touch_fetch() acc.touch_fetch()
db.session.commit() db.session.commit()
@app.task @app.task
def queue_fetch_for_most_stale_accounts(min_staleness=timedelta(minutes=5), limit=20): def queue_fetch_for_most_stale_accounts(
min_staleness=timedelta(minutes=5), limit=20):
accs = Account.query\ accs = Account.query\
.join(Account.tokens).group_by(Account)\ .join(Account.tokens).group_by(Account)\
.filter(Account.last_fetch < db.func.now() - min_staleness)\ .filter(Account.last_fetch < db.func.now() - min_staleness)\
@ -70,7 +80,6 @@ def queue_fetch_for_most_stale_accounts(min_staleness=timedelta(minutes=5), limi
.limit(limit) .limit(limit)
for acc in accs: for acc in accs:
fetch_acc.s(acc.id).delay() fetch_acc.s(acc.id).delay()
#acc.touch_fetch()
db.session.commit() db.session.commit()
@ -92,8 +101,8 @@ def import_twitter_archive_month(archive_id, month_path):
post = lib.twitter.post_from_api_tweet_object(tweet) post = lib.twitter.post_from_api_tweet_object(tweet)
existing_post = db.session.query(Post).get(post.id) existing_post = db.session.query(Post).get(post.id)
if post.author_id != ta.account_id \ if post.author_id != ta.account_id or\
or existing_post and existing_post.author_id != ta.account_id: existing_post and existing_post.author_id != ta.account_id:
raise Exception("Shenanigans!") raise Exception("Shenanigans!")
post = db.session.merge(post) post = db.session.merge(post)
@ -111,81 +120,104 @@ def import_twitter_archive_month(archive_id, month_path):
@app.task
def periodic_cleanup():
    """
    Housekeeping task: purge stale rows and normalise instance scores.

    In one transaction this deletes sessions not updated for 48 hours,
    twitter archives not updated for 3 days and account-less OAuth tokens
    not updated for 1 day, disables the policy of enabled accounts that
    have no token left, and rescales mastodon instance popularity so the
    highest score is 40 whenever it has grown beyond that.
    """
    # delete sessions after 48 hours
    (Session.query
     .filter(Session.updated_at < (db.func.now() - timedelta(hours=48)))
     .delete(synchronize_session=False))

    # delete twitter archives after 3 days
    (TwitterArchive.query
     .filter(TwitterArchive.updated_at < (db.func.now() - timedelta(days=3)))
     .delete(synchronize_session=False))

    # delete anonymous oauth tokens after 1 day
    (OAuthToken.query
     .filter(OAuthToken.updated_at < (db.func.now() - timedelta(days=1)))
     .filter(OAuthToken.account_id == None)  # noqa: E711
     .delete(synchronize_session=False))

    # disable users with no tokens
    unreachable = (
        Account.query
        .outerjoin(Account.tokens)
        .group_by(Account).having(db.func.count(OAuthToken.token) == 0)
        .filter(Account.policy_enabled == True))  # noqa: E712
    for account in unreachable:
        account.policy_enabled = False

    # normalise mastodon instance popularity scores
    biggest_instance = (
        MastodonInstance.query
        .order_by(db.desc(MastodonInstance.popularity)).first())
    # first() yields None when no instance is known yet; without this
    # check the AttributeError would abort the task before the commit
    # below and none of the cleanup above would be persisted
    if biggest_instance is not None and biggest_instance.popularity > 40:
        MastodonInstance.query.update({
            MastodonInstance.popularity:
                MastodonInstance.popularity * 40 / biggest_instance.popularity
        })

    db.session.commit()
@app.task
def queue_deletes():
    """
    Queue one ``delete_from_account`` task per account that is due.

    An account is due when its policy is enabled and its ``next_delete``
    timestamp lies in the past.
    """
    due_accounts = Account.query.filter(
        Account.policy_enabled == True,  # noqa: E712
        Account.next_delete < db.func.now())
    for due in due_accounts:
        delete_from_account.s(due.id).apply_async()
@app.task(autoretry_for=(TwitterError, URLError, MastodonRatelimitError))
def delete_from_account(account_id):
    """
    Delete expired posts from one account according to its policy.

    Up to 100 random posts that are older than ``policy_keep_younger`` and
    not among the ``policy_keep_latest`` newest are refreshed, then checked
    against the keep-favourites / keep-media (and, for mastodon,
    keep-direct) settings. One eligible post is deleted per run, unless
    ``policy_delete_every`` is zero, in which case every eligible post is.
    """
    account = Account.query.get(account_id)
    newest_kept = (
        Post.query.with_parent(account)
        .order_by(db.desc(Post.created_at))
        .limit(account.policy_keep_latest))
    old_enough = (
        Post.created_at + account.policy_keep_younger <= db.func.now())
    posts = (
        Post.query.with_parent(account)
        .filter(old_enough)
        .except_(newest_kept)
        .order_by(db.func.random())
        .limit(100)
        .all())

    def deletable(post, honour_direct):
        # a post is protected as soon as one enabled keep-rule matches it
        if account.policy_keep_favourites and post.favourite:
            return False
        if account.policy_keep_media and post.has_media:
            return False
        if honour_direct and account.policy_keep_direct and post.direct:
            return False
        return True

    eligible = None
    action = noop

    if account.service == 'twitter':
        action = lib.twitter.delete
        # twitter posts are refreshed as one batch
        posts = refresh_posts(posts)
        eligible = [post for post in posts if deletable(post, False)]
    elif account.service == 'mastodon':
        action = lib.mastodon.delete
        # mastodon posts are refreshed one at a time, stopping at the
        # first one that may be deleted
        for post in posts:
            refreshed = refresh_posts((post,))
            if refreshed and deletable(post, True):
                eligible = refreshed
                break

    if eligible:
        delete_everything = (
            account.policy_delete_every == timedelta(0)
            and len(eligible) > 1)
        if delete_everything:
            print("deleting all {} eligible posts for {}"
                  .format(len(eligible), account))
            for post in eligible:
                account.touch_delete()
                action(post)
        else:
            post = random.choice(eligible)  # nosec
            print("deleting {}".format(post))
            account.touch_delete()
            action(post)

    db.session.commit()
def refresh_posts(posts): def refresh_posts(posts):
posts = list(posts) posts = list(posts)
if len(posts) == 0: if len(posts) == 0:
@ -196,27 +228,36 @@ def refresh_posts(posts):
elif posts[0].service == 'mastodon': elif posts[0].service == 'mastodon':
return lib.mastodon.refresh_posts(posts) return lib.mastodon.refresh_posts(posts)
# `throws` is documented as a tuple of expected error classes, hence the
# trailing comma: `(MastodonRatelimitError)` is just the class itself
@app.task(autoretry_for=(TwitterError, URLError),
          throws=(MastodonRatelimitError,))
def refresh_account(account_id):
    """
    Refresh the least recently updated posts of one account.

    Fetches the account by id, picks its posts with the oldest
    ``updated_at`` (5 for mastodon accounts, 100 otherwise), passes them
    to ``refresh_posts``, stamps the account through ``touch_refresh`` and
    commits.
    """
    account = Account.query.get(account_id)

    limit = 5 if account.service == 'mastodon' else 100
    posts = (Post.query.with_parent(account)
             .order_by(db.asc(Post.updated_at)).limit(limit).all())

    # the return value is not needed here, only the refresh itself
    refresh_posts(posts)
    account.touch_refresh()
    db.session.commit()
# TwitterError and URLError trigger an automatic retry;
# MastodonRatelimitError is declared as an expected error.  `throws` is
# written as a real one-element tuple (note the trailing comma).
@app.task(autoretry_for=(TwitterError, URLError),
          throws=(MastodonRatelimitError,))
def refresh_account_with_oldest_post():
    """Refresh the account owning the least recently updated post.

    Only posts whose author joins to at least one row of
    ``Account.tokens`` are considered, because of the inner join on
    tokens.  The author of the post with the smallest ``updated_at`` is
    then handed to ``refresh_account``.

    :returns: None. Does nothing when the query matches no post.
    """
    post = (Post.query.outerjoin(Post.author).join(Account.tokens)
            .group_by(Post).order_by(db.asc(Post.updated_at)).first())
    if post is None:
        # first() yields None on an empty result; nothing to refresh.
        return
    refresh_account(post.author_id)
# TwitterError and URLError trigger an automatic retry;
# MastodonRatelimitError is declared as an expected error.  `throws` is
# written as a real one-element tuple (note the trailing comma).
@app.task(autoretry_for=(TwitterError, URLError),
          throws=(MastodonRatelimitError,))
def refresh_account_with_longest_time_since_refresh():
    """Refresh the account with the smallest ``last_refresh`` value.

    Only accounts that join to at least one row of ``Account.tokens``
    are considered.  The first account in ascending ``last_refresh``
    order is handed to ``refresh_account``.

    :returns: None. Does nothing when the query matches no account.
    """
    acc = (Account.query.join(Account.tokens).group_by(Account)
           .order_by(db.asc(Account.last_refresh)).first())
    if acc is None:
        # first() yields None on an empty result; nothing to refresh.
        return
    refresh_account(acc.id)
@ -228,4 +269,3 @@ app.add_periodic_task(90, refresh_account_with_longest_time_since_refresh)
# When this module is executed directly (rather than imported), start a
# worker for the task app defined in this file.
if __name__ == '__main__':
    app.worker_main()

View File

@ -1,5 +1,5 @@
#!/bin/bash
# Write version.py, containing a `version = "<git describe output>"`
# assignment, into the parent of the directory holding this script.

# Quote the path so directories containing spaces or glob characters
# work, and abort if the cd fails so version.py is never written into
# whatever directory the caller happened to be in.
cd "$(dirname "$0")/.." || exit 1

# `git describe` output is read by Python via input() and rendered
# through a Jinja2 template into a single assignment statement.
git describe --tags --long --always | python -c 'from jinja2 import Template; print(Template("version = \"{{input}}\"").render(input=input()))' > version.py

View File

@ -1 +1 @@
# Version string exposed by this module.
# NOTE(review): presumably regenerated from `git describe` by a helper
# script rather than edited by hand — confirm before changing manually.
version = 'v0.0.8'