Merge pull request #59 from hughrun/sanitise-init

Sanitise init
This commit is contained in:
Hugh Rundle 2020-09-12 14:24:47 +10:00 committed by GitHub
commit 44f764ace0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 200 additions and 4 deletions

View File

@ -2,6 +2,7 @@
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
import json import json
import os import os
import re
import subprocess import subprocess
import sys import sys
import time import time
@ -29,7 +30,14 @@ def compulsory_input(tags, name, example):
else: else:
value = input(tags[0] + name + tags[2]) value = input(tags[0] + name + tags[2])
return value sanitised = sanitise_input(value, name, tags)
if len(value) > 0 and (sanitised == "ok" or sanitised == None):
return value
else:
if len(value) > 0 and sanitised != None:
print(sanitised)
value = ""
def digit_input(tags, name, example): def digit_input(tags, name, example):
@ -53,8 +61,108 @@ def yes_no_input(tags, name):
def optional_input(tags, name, example): def optional_input(tags, name, example):
value = input(tags[0] + name + tags[1] + example + tags[2])
return value incomplete = True
while incomplete:
value = input(tags[0] + name + tags[1] + example + tags[2])
sanitised = sanitise_input(value, name, tags)
if len(value) > 0 and (sanitised == "ok" or sanitised == None):
incomplete = False
return value
elif len(value) > 0 and sanitised != None:
print(sanitised)
else:
return ""
def sanitise_input(value, input_type, tags):
"""
Check that data entered when running --init complies with requirements
"""
if input_type == "Username":
return (
"Do not include '@' in username, please try again"
if value.startswith("@")
else "ok"
)
if input_type == "Base URL":
error = value.startswith("http") or value.find(".") == -1
return (
"Provide full domain without protocol prefix (e.g. "
+ tags[1]
+ "example.social"
+ tags[2]
+ ", not "
+ tags[1]
+ "http://example.social"
+ tags[2]
+ ")"
if error
else "ok"
)
if input_type == "Toots to keep":
l = value.split(",")
def check(s):
d = s.strip()
if not d.isdigit():
return False
allnum = map(check, l)
return (
"Toot IDs must be numeric and separated with commas"
if False in list(allnum)
else "ok"
)
if input_type == "Hashtags to keep":
l = value.split(",")
def check(s):
d = s.strip()
if d.isdigit():
return False
if not re.fullmatch(r"[\w]+", d, flags=re.IGNORECASE):
return False
complies = map(check, l)
return_string = (
"Hashtags must not include '#' and must match rules at "
+ tags[0]
+ "https://docs.joinmastodon.org/user/posting/#hashtags"
+ tags[2]
)
return return_string if False in list(complies) else "ok"
if input_type == "Visibility to keep":
l = value.split(",")
viz_options = set(["public", "unlisted", "private", "direct"])
def check(s):
d = [s.strip().lower()]
intersects = viz_options.intersection(d)
if len(intersects) == 0:
return False
complies = map(check, l)
return_string = "Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
return return_string if False in list(complies) else "ok"
if input_type == "Archive path":
path = (
os.path.expanduser(value)
if len(str(value)) > 0 and str(value)[0] == "~"
else value
)
response = (
"ok"
if os.path.exists(path)
else "That directory does not exist, please try again"
)
return response
def init(): def init():
@ -151,7 +259,7 @@ def version(vnum):
print("You are using release: \033[92mv", vnum, "\033[0m", sep="") print("You are using release: \033[92mv", vnum, "\033[0m", sep="")
print("The latest release is: \033[92m" + latest_version + "\033[0m") print("The latest release is: \033[92m" + latest_version + "\033[0m")
print( print(
"To upgrade to the most recent version run \033[92mpip3 install --update ephemetoot\033[0m" "To upgrade to the most recent version run \033[92mpip install --upgrade ephemetoot\033[0m"
) )
except Exception as e: except Exception as e:

View File

@ -291,6 +291,94 @@ def test_init(monkeypatch, tmpdir):
assert os.path.exists(os.path.join(current_dir, "config.yaml")) assert os.path.exists(os.path.join(current_dir, "config.yaml"))
def test_init_archive_path(tmpdir):
good_path = tmpdir.mkdir("archive_dir") # temporary directory for testing
wrong = ephemetoot.sanitise_input(
os.path.join(good_path, "/bad/path/"), "Archive path", None
)
ok = ephemetoot.sanitise_input(good_path, "Archive path", None)
also_ok = ephemetoot.sanitise_input("~/Desktop", "Archive path", None)
assert ok == "ok"
assert wrong == "That directory does not exist, please try again"
def test_init_sanitise_id_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input(
"987654321, toot_id_number", "Toots to keep", tags
)
also_wrong = ephemetoot.sanitise_input("toot_id_number", "Toots to keep", tags)
ok = ephemetoot.sanitise_input("1234598745, 999933335555", "Toots to keep", tags)
also_ok = ephemetoot.sanitise_input("1234598745", "Toots to keep", tags)
assert wrong == "Toot IDs must be numeric and separated with commas"
assert also_wrong == "Toot IDs must be numeric and separated with commas"
assert ok == "ok"
assert also_ok == "ok"
def test_init_sanitise_tag_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("#tag, another_tag", "Hashtags to keep", tags)
also_wrong = ephemetoot.sanitise_input("tag, another tag", "Hashtags to keep", tags)
still_wrong = ephemetoot.sanitise_input("tag, 12345", "Hashtags to keep", tags)
ok = ephemetoot.sanitise_input("tag123, another_TAG", "Hashtags to keep", tags)
also_ok = ephemetoot.sanitise_input("single_tag", "Hashtags to keep", tags)
error = (
"Hashtags must not include '#' and must match rules at "
+ tags[0]
+ "https://docs.joinmastodon.org/user/posting/#hashtags"
+ tags[2]
)
assert ok == "ok"
assert also_ok == "ok"
assert wrong == error
assert also_wrong == error
assert still_wrong == error
def test_init_sanitise_url():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("http://example.social", "Base URL", tags)
also_wrong = ephemetoot.sanitise_input("http://example.social", "Base URL", tags)
ok = ephemetoot.sanitise_input("example.social", "Base URL", tags)
assert (
wrong
== "Provide full domain without protocol prefix (e.g. \033[2mexample.social\033[0m, not \033[2mhttp://example.social\033[0m)"
)
assert ok == "ok"
def test_init_sanitise_username():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("@alice", "Username", tags)
ok = ephemetoot.sanitise_input("alice", "Username", tags)
assert wrong == "Do not include '@' in username, please try again"
assert ok == "ok"
def test_init_sanitise_visibility_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("nonexistent", "Visibility to keep", tags)
also_wrong = ephemetoot.sanitise_input("direct public", "Visibility to keep", tags)
ok = ephemetoot.sanitise_input("direct", "Visibility to keep", tags)
also_ok = ephemetoot.sanitise_input("direct, public", "Visibility to keep", tags)
error = (
"Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
)
assert ok == "ok"
assert also_ok == "ok"
assert wrong == error
assert also_wrong == error
def test_jsondefault(): def test_jsondefault():
d = ephemetoot.jsondefault(toot.created_at) d = ephemetoot.jsondefault(toot.created_at)
assert d == "2020-05-09T02:17:18.598000+00:00" assert d == "2020-05-09T02:17:18.598000+00:00"